aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--README.md21
-rw-r--r--WORKSPACE20
-rw-r--r--configure.py261
-rw-r--r--tensorflow/BUILD96
-rw-r--r--tensorflow/c/BUILD1
-rw-r--r--tensorflow/c/c_api_experimental.cc49
-rw-r--r--tensorflow/c/c_api_experimental.h8
-rw-r--r--tensorflow/c/c_api_experimental_test.cc46
-rw-r--r--tensorflow/compiler/aot/tests/make_test_graphs.py4
-rw-r--r--tensorflow/compiler/jit/BUILD2
-rw-r--r--tensorflow/compiler/jit/build_xla_ops_pass.cc11
-rw-r--r--tensorflow/compiler/jit/build_xla_ops_pass_test.cc112
-rw-r--r--tensorflow/compiler/jit/deadness_analysis.cc107
-rw-r--r--tensorflow/compiler/jit/deadness_analysis_test.cc31
-rw-r--r--tensorflow/compiler/jit/xla_compile_on_demand_op.cc1
-rw-r--r--tensorflow/compiler/jit/xla_device_ops.h6
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc1
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py7
-rw-r--r--tensorflow/compiler/tests/fused_batchnorm_test.py40
-rw-r--r--tensorflow/compiler/tests/image_ops_test.py43
-rw-r--r--tensorflow/compiler/tests/lstm.py2
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc75
-rw-r--r--tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc18
-rw-r--r--tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc5
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_ops.cc9
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.cc30
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.h51
-rw-r--r--tensorflow/compiler/tf2xla/type_util.h8
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.cc12
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc8
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h24
-rw-r--r--tensorflow/compiler/xla/executable_run_options.cc10
-rw-r--r--tensorflow/compiler/xla/executable_run_options.h8
-rw-r--r--tensorflow/compiler/xla/literal.h16
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD17
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc278
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h37
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc23
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/BUILD14
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc283
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc45
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto8
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc55
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h18
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h5
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h8
-rw-r--r--tensorflow/compiler/xla/service/hlo_memory_scheduler.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc29
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h11
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer_analysis.cc2
-rw-r--r--tensorflow/compiler/xla/shape_util.cc7
-rw-r--r--tensorflow/contrib/BUILD53
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc1
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc1
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc1
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc1
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc1
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc1
-rw-r--r--tensorflow/contrib/bigtable/python/ops/bigtable_api.py8
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py2
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py18
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py24
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt2
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt1
-rw-r--r--tensorflow/contrib/compiler/BUILD2
-rw-r--r--tensorflow/contrib/compiler/xla.py13
-rw-r--r--tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py4
-rw-r--r--tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py9
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_elements.py6
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_test.py4
-rw-r--r--tensorflow/contrib/crf/python/ops/crf.py10
-rw-r--r--tensorflow/contrib/data/BUILD38
-rw-r--r--tensorflow/contrib/data/ops/indexed_dataset_ops.cc80
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD41
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py9
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py9
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py43
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py15
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py5
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/BUILD26
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py103
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py14
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py7
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py10
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/resample_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py11
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py8
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py5
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/test_utils.py73
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py3
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD57
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py12
-rw-r--r--tensorflow/contrib/data/python/ops/error_ops.py9
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py12
-rw-r--r--tensorflow/contrib/data/python/ops/indexed_dataset_ops.py28
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py16
-rw-r--r--tensorflow/contrib/data/python/ops/optimization.py19
-rw-r--r--tensorflow/contrib/data/python/ops/parsing_ops.py4
-rw-r--r--tensorflow/contrib/data/python/ops/prefetching_ops.py228
-rw-r--r--tensorflow/contrib/data/python/ops/random_ops.py2
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py12
-rw-r--r--tensorflow/contrib/data/python/ops/scan_ops.py4
-rw-r--r--tensorflow/contrib/data/python/ops/shuffle_ops.py11
-rw-r--r--tensorflow/contrib/data/python/ops/sliding.py4
-rw-r--r--tensorflow/contrib/data/python/ops/stats_ops.py8
-rw-r--r--tensorflow/contrib/data/python/ops/threadpool.py13
-rw-r--r--tensorflow/contrib/data/python/ops/unique.py9
-rw-r--r--tensorflow/contrib/decision_trees/proto/BUILD1
-rw-r--r--tensorflow/contrib/distribute/README.md3
-rw-r--r--tensorflow/contrib/distribute/python/BUILD5
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py2
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py3
-rw-r--r--tensorflow/contrib/distribute/python/examples/simple_estimator_example.py21
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py177
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py8
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2.py28
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py175
-rw-r--r--tensorflow/contrib/distribute/python/values.py398
-rw-r--r--tensorflow/contrib/distribute/python/values_test.py5
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/BUILD51
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py98
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py323
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py150
-rw-r--r--tensorflow/contrib/eager/python/examples/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/examples/gan/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py162
-rw-r--r--tensorflow/contrib/eager/python/examples/linear_regression/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD1
-rw-r--r--tensorflow/contrib/estimator/BUILD46
-rw-r--r--tensorflow/contrib/estimator/__init__.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees.py6
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py19
-rw-r--r--tensorflow/contrib/estimator/python/estimator/early_stopping.py35
-rw-r--r--tensorflow/contrib/estimator/python/estimator/hooks_test.py2
-rw-r--r--tensorflow/contrib/factorization/BUILD10
-rw-r--r--tensorflow/contrib/framework/python/ops/variables_test.py28
-rw-r--r--tensorflow/contrib/fused_conv/BUILD43
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc4
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py891
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py945
-rw-r--r--tensorflow/contrib/graph_editor/tests/transform_test.py2
-rw-r--r--tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py4
-rw-r--r--tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py6
-rw-r--r--tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py6
-rw-r--r--tensorflow/contrib/layers/python/layers/embedding_ops.py8
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions_test.py12
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors_test.py10
-rw-r--r--tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py8
-rw-r--r--tensorflow/contrib/lite/BUILD19
-rw-r--r--tensorflow/contrib/lite/build_def.bzl42
-rw-r--r--tensorflow/contrib/lite/delegates/flex/BUILD (renamed from tensorflow/contrib/lite/delegates/eager/BUILD)0
-rw-r--r--tensorflow/contrib/lite/delegates/flex/buffer_map.cc (renamed from tensorflow/contrib/lite/delegates/eager/buffer_map.cc)8
-rw-r--r--tensorflow/contrib/lite/delegates/flex/buffer_map.h (renamed from tensorflow/contrib/lite/delegates/eager/buffer_map.h)12
-rw-r--r--tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc (renamed from tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc)6
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate.cc (renamed from tensorflow/contrib/lite/delegates/eager/delegate.cc)34
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate.h (renamed from tensorflow/contrib/lite/delegates/eager/delegate.h)26
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate_data.cc (renamed from tensorflow/contrib/lite/delegates/eager/delegate_data.cc)6
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate_data.h (renamed from tensorflow/contrib/lite/delegates/eager/delegate_data.h)16
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc (renamed from tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc)6
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate_test.cc (renamed from tensorflow/contrib/lite/delegates/eager/delegate_test.cc)14
-rw-r--r--tensorflow/contrib/lite/delegates/flex/kernel.cc (renamed from tensorflow/contrib/lite/delegates/eager/kernel.cc)30
-rw-r--r--tensorflow/contrib/lite/delegates/flex/kernel.h (renamed from tensorflow/contrib/lite/delegates/eager/kernel.h)12
-rw-r--r--tensorflow/contrib/lite/delegates/flex/kernel_test.cc (renamed from tensorflow/contrib/lite/delegates/eager/kernel_test.cc)16
-rw-r--r--tensorflow/contrib/lite/delegates/flex/test_util.cc (renamed from tensorflow/contrib/lite/delegates/eager/test_util.cc)47
-rw-r--r--tensorflow/contrib/lite/delegates/flex/test_util.h (renamed from tensorflow/contrib/lite/delegates/eager/test_util.h)20
-rw-r--r--tensorflow/contrib/lite/delegates/flex/util.cc (renamed from tensorflow/contrib/lite/delegates/eager/util.cc)6
-rw-r--r--tensorflow/contrib/lite/delegates/flex/util.h (renamed from tensorflow/contrib/lite/delegates/eager/util.h)10
-rw-r--r--tensorflow/contrib/lite/delegates/flex/util_test.cc (renamed from tensorflow/contrib/lite/delegates/eager/util_test.cc)6
-rw-r--r--tensorflow/contrib/lite/examples/android/BUILD1
-rw-r--r--tensorflow/contrib/lite/examples/android/app/README.md37
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental.cc5
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental.h2
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc2
-rw-r--r--tensorflow/contrib/lite/g3doc/_book.yaml1
-rw-r--r--tensorflow/contrib/lite/g3doc/performance.md186
-rw-r--r--tensorflow/contrib/lite/g3doc/performance_benchmarks.md174
-rw-r--r--tensorflow/contrib/lite/interpreter.cc9
-rw-r--r--tensorflow/contrib/lite/interpreter.h7
-rw-r--r--tensorflow/contrib/lite/java/aar_with_jni.bzl53
-rw-r--r--tensorflow/contrib/lite/java/demo/README.md4
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/BUILD1
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java26
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java29
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/BUILD1
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java2
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java104
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java48
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc15
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h9
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java48
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java9
-rw-r--r--tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java15
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc113
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc169
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc51
-rw-r--r--tensorflow/contrib/lite/kernels/concatenation.cc39
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc70
-rw-r--r--tensorflow/contrib/lite/kernels/dequantize.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/div.cc27
-rw-r--r--tensorflow/contrib/lite/kernels/fake_quant.cc13
-rw-r--r--tensorflow/contrib/lite/kernels/gather.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc32
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h54
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h74
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h104
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h941
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h60
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h799
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h75
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h103
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h134
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h1067
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h906
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/softmax.h23
-rw-r--r--tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc28
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor.h24
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h33
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_test.cc36
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h14
-rw-r--r--tensorflow/contrib/lite/kernels/log_softmax_test.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc48
-rw-r--r--tensorflow/contrib/lite/kernels/pack.cc9
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/select.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/softmax_test.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/sparse_to_dense.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/split.cc27
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc48
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/transpose.cc23
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_test.cc24
-rw-r--r--tensorflow/contrib/lite/kernels/unpack.cc9
-rw-r--r--tensorflow/contrib/lite/model.cc8
-rw-r--r--tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD1
-rw-r--r--tensorflow/contrib/lite/python/BUILD2
-rw-r--r--tensorflow/contrib/lite/python/convert.py21
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model.py12
-rw-r--r--tensorflow/contrib/lite/python/interpreter.py4
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc4
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h2
-rw-r--r--tensorflow/contrib/lite/python/lite.py94
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py177
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py12
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h147
-rw-r--r--tensorflow/contrib/lite/testing/BUILD31
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py12
-rw-r--r--tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py249
-rw-r--r--tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py130
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_flags.h4
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_util.h2
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc8
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.h4
-rw-r--r--tensorflow/contrib/lite/toco/args.h4
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/python_api.md11
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc15
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc4
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.h2
-rw-r--r--tensorflow/contrib/lite/toco/python/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc22
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.h4
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc2
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc26
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.h6
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.cc24
-rw-r--r--tensorflow/contrib/lite/toco/toco_flags.proto16
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc8
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/BUILD12
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc14
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h12
-rwxr-xr-xtensorflow/contrib/lite/tools/make/download_dependencies.sh2
-rw-r--r--tensorflow/contrib/lite/util.cc6
-rw-r--r--tensorflow/contrib/lite/util.h8
-rw-r--r--tensorflow/contrib/lite/util_test.cc16
-rw-r--r--tensorflow/contrib/makefile/Makefile3
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt3
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py19
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning.py3
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_test.py22
-rw-r--r--tensorflow/contrib/opt/BUILD5
-rw-r--r--tensorflow/contrib/opt/python/training/addsign_test.py12
-rw-r--r--tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/external_optimizer_test.py22
-rw-r--r--tensorflow/contrib/opt/python/training/model_average_optimizer_test.py3
-rw-r--r--tensorflow/contrib/opt/python/training/powersign_test.py12
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo_test.py40
-rw-r--r--tensorflow/contrib/predictor/BUILD3
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms.py3
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py115
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph_test.py37
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py4
-rw-r--r--tensorflow/contrib/saved_model/BUILD6
-rw-r--r--tensorflow/contrib/session_bundle/exporter_test.py6
-rw-r--r--tensorflow/contrib/tensor_forest/BUILD2
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py13
-rw-r--r--tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py14
-rw-r--r--tensorflow/contrib/tensorrt/python/trt_convert_test.py2
-rw-r--r--tensorflow/contrib/timeseries/examples/BUILD1
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/BUILD7
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head_test.py2
-rw-r--r--tensorflow/contrib/tpu/BUILD33
-rw-r--r--tensorflow/contrib/tpu/__init__.py4
-rw-r--r--tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc666
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc3
-rw-r--r--tensorflow/contrib/tpu/profiler/op_profile.proto8
-rw-r--r--tensorflow/contrib/tpu/proto/optimization_parameters.proto33
-rw-r--r--tensorflow/contrib/tpu/python/ops/tpu_ops.py27
-rw-r--r--tensorflow/contrib/tpu/python/tpu/async_checkpoint.py202
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py668
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py55
-rw-r--r--tensorflow/contrib/tpu/python/tpu/session_support.py58
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py25
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py7
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py17
-rw-r--r--tensorflow/contrib/tpu/utils/BUILD30
-rw-r--r--tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc255
-rw-r--r--tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h90
-rw-r--r--tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.cc98
-rw-r--r--tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h38
-rw-r--r--tensorflow/contrib/training/BUILD1
-rw-r--r--tensorflow/contrib/training/python/training/device_setter_test.py8
-rw-r--r--tensorflow/contrib/training/python/training/tensor_queue_dataset.py4
-rw-r--r--tensorflow/core/BUILD58
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt21
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt58
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt25
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt13
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt13
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt35
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt26
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt10
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt28
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Tile.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt6
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt6
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt6
-rw-r--r--tensorflow/core/common_runtime/constant_folding.cc34
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc82
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc57
-rw-r--r--tensorflow/core/common_runtime/direct_session.h23
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc28
-rw-r--r--tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc16
-rw-r--r--tensorflow/core/common_runtime/executor.h6
-rw-r--r--tensorflow/core/common_runtime/graph_optimizer.cc4
-rw-r--r--tensorflow/core/common_runtime/graph_optimizer.h5
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator.h1
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc4
-rw-r--r--tensorflow/core/common_runtime/ring_reducer.cc75
-rw-r--r--tensorflow/core/common_runtime/ring_reducer_test.cc83
-rw-r--r--tensorflow/core/common_runtime/threadpool_device.cc5
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc5
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc107
-rw-r--r--tensorflow/core/framework/common_shape_fns.h3
-rw-r--r--tensorflow/core/framework/function.cc8
-rw-r--r--tensorflow/core/framework/function.h5
-rw-r--r--tensorflow/core/framework/function_testlib.cc17
-rw-r--r--tensorflow/core/framework/node_def_util.h1
-rw-r--r--tensorflow/core/framework/op.h20
-rw-r--r--tensorflow/core/framework/op_def_builder.cc24
-rw-r--r--tensorflow/core/framework/op_def_builder.h14
-rw-r--r--tensorflow/core/framework/resource_mgr.cc9
-rw-r--r--tensorflow/core/framework/resource_mgr.h117
-rw-r--r--tensorflow/core/framework/run_handler.cc249
-rw-r--r--tensorflow/core/framework/run_handler.h95
-rw-r--r--tensorflow/core/framework/run_handler_util.cc57
-rw-r--r--tensorflow/core/framework/run_handler_util.h43
-rw-r--r--tensorflow/core/framework/run_handler_util_test.cc93
-rw-r--r--tensorflow/core/framework/tensor.cc2
-rw-r--r--tensorflow/core/framework/tensor.h2
-rw-r--r--tensorflow/core/framework/tensor_test.cc3
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc23
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc24
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc2
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass_test.cc4
-rw-r--r--tensorflow/core/grappler/graph_view.cc29
-rw-r--r--tensorflow/core/grappler/graph_view_test.cc29
-rw-r--r--tensorflow/core/grappler/grappler_item.cc1
-rw-r--r--tensorflow/core/grappler/grappler_item.h9
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.cc8
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.h2
-rw-r--r--tensorflow/core/grappler/grappler_item_builder_test.cc23
-rw-r--r--tensorflow/core/grappler/op_types.cc4
-rw-r--r--tensorflow/core/grappler/op_types.h1
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD2
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc33
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD75
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion.cc13
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc11
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils.cc32
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_test_utils.cc49
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_test_utils.h36
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc24
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h29
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc12
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc289
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h55
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc84
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc5
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc14
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc21
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion.cc30
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion_test.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc13
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc41
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc112
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/BUILD3
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc29
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc36
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h23
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc451
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.h35
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc205
-rw-r--r--tensorflow/core/grappler/optimizers/debug_stripper.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/debug_stripper_test.cc29
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc50
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.h4
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer_test.cc156
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc33
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc32
-rw-r--r--tensorflow/core/grappler/optimizers/remapper.cc8
-rw-r--r--tensorflow/core/grappler/optimizers/shape_optimizer.cc12
-rw-r--r--tensorflow/core/grappler/utils.cc39
-rw-r--r--tensorflow/core/grappler/utils.h110
-rw-r--r--tensorflow/core/grappler/utils/functions.cc55
-rw-r--r--tensorflow/core/grappler/utils/functions.h5
-rw-r--r--tensorflow/core/grappler/utils_test.cc23
-rw-r--r--tensorflow/core/kernels/BUILD79
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_complex.cc10
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_real.cc9
-rw-r--r--tensorflow/core/kernels/batching_util/BUILD20
-rw-r--r--tensorflow/core/kernels/boosted_trees/boosted_trees.proto13
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.cc26
-rw-r--r--tensorflow/core/kernels/collective_ops.cc21
-rw-r--r--tensorflow/core/kernels/conv_ops.cc321
-rw-r--r--tensorflow/core/kernels/conv_ops.h44
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc26
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc26
-rw-r--r--tensorflow/core/kernels/cwise_op_xdivy.cc38
-rw-r--r--tensorflow/core/kernels/cwise_op_xlogy.cc41
-rw-r--r--tensorflow/core/kernels/cwise_ops.h45
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.cc4
-rw-r--r--tensorflow/core/kernels/data/BUILD1
-rw-r--r--tensorflow/core/kernels/data/experimental/BUILD (renamed from tensorflow/contrib/data/kernels/BUILD)90
-rw-r--r--tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/assert_next_dataset_op.cc)5
-rw-r--r--tensorflow/core/kernels/data/experimental/csv_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/csv_dataset_op.cc)3
-rw-r--r--tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc)5
-rw-r--r--tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc (renamed from tensorflow/contrib/data/kernels/identity_indexed_dataset.cc)7
-rw-r--r--tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc)6
-rw-r--r--tensorflow/core/kernels/data/experimental/indexed_dataset.cc (renamed from tensorflow/contrib/data/kernels/indexed_dataset.cc)14
-rw-r--r--tensorflow/core/kernels/data/experimental/indexed_dataset.h (renamed from tensorflow/contrib/data/kernels/indexed_dataset.h)6
-rw-r--r--tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/lmdb_dataset_op.cc)3
-rw-r--r--tensorflow/core/kernels/data/experimental/prefetching_kernels.cc (renamed from tensorflow/contrib/data/kernels/prefetching_kernels.cc)23
-rw-r--r--tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/threadpool_dataset_op.cc)7
-rw-r--r--tensorflow/core/kernels/data/experimental/unique_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/unique_dataset_op.cc)7
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.cc9
-rw-r--r--tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/group_by_window_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc111
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/data/multi_device_iterator_ops.cc34
-rw-r--r--tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc3
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc42
-rw-r--r--tensorflow/core/kernels/data/scan_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc24
-rw-r--r--tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc32
-rw-r--r--tensorflow/core/kernels/gather_nd_op_cpu_impl.h6
-rw-r--r--tensorflow/core/kernels/matmul_op.cc8
-rw-r--r--tensorflow/core/kernels/mkl_batch_matmul_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc2
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc2
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc2
-rw-r--r--tensorflow/core/kernels/mkl_matmul_op.cc6
-rw-r--r--tensorflow/core/kernels/mkl_slice_op.cc358
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc12
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc64
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.h10
-rw-r--r--tensorflow/core/kernels/slice_op.cc199
-rw-r--r--tensorflow/core/kernels/string_length_op.cc23
-rw-r--r--tensorflow/core/kernels/string_util.cc63
-rw-r--r--tensorflow/core/kernels/string_util.h45
-rw-r--r--tensorflow/core/kernels/tensor_array.cc3
-rw-r--r--tensorflow/core/kernels/tensor_array.h3
-rw-r--r--tensorflow/core/kernels/training_op_helpers.cc45
-rw-r--r--tensorflow/core/kernels/training_op_helpers.h37
-rw-r--r--tensorflow/core/kernels/training_ops.cc8
-rw-r--r--tensorflow/core/kernels/transpose_op.cc10
-rw-r--r--tensorflow/core/kernels/unicode_script_op.cc53
-rw-r--r--tensorflow/core/ops/array_ops.cc122
-rw-r--r--tensorflow/core/ops/boosted_trees_ops.cc2
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt997
-rw-r--r--tensorflow/core/ops/dataset_ops.cc13
-rw-r--r--tensorflow/core/ops/experimental_dataset_ops.cc (renamed from tensorflow/contrib/data/ops/dataset_ops.cc)161
-rw-r--r--tensorflow/core/ops/math_grad.cc34
-rw-r--r--tensorflow/core/ops/math_grad_test.cc40
-rw-r--r--tensorflow/core/ops/math_ops.cc14
-rw-r--r--tensorflow/core/ops/nn_ops.cc18
-rw-r--r--tensorflow/core/ops/ops.pbtxt620
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc72
-rw-r--r--tensorflow/core/ops/string_ops.cc6
-rw-r--r--tensorflow/core/platform/default/build_config.bzl45
-rw-r--r--tensorflow/core/profiler/BUILD1
-rw-r--r--tensorflow/core/protobuf/config.proto5
-rw-r--r--tensorflow/core/protobuf/rewriter_config.proto6
-rw-r--r--tensorflow/core/util/mkl_util.h12
-rw-r--r--tensorflow/core/util/port.cc4
-rw-r--r--tensorflow/core/util/tensor_bundle/BUILD5
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle.cc52
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc64
-rw-r--r--tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README3
-rw-r--r--tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001bin0 -> 1080 bytes
-rw-r--r--tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.indexbin0 -> 211 bytes
-rw-r--r--tensorflow/examples/android/BUILD1
-rw-r--r--tensorflow/go/op/wrappers.go4698
-rw-r--r--tensorflow/java/maven/libtensorflow/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml2
-rw-r--r--tensorflow/java/maven/pom.xml2
-rw-r--r--tensorflow/java/maven/proto/pom.xml2
-rw-r--r--tensorflow/java/maven/spark-tensorflow-connector/pom.xml2
-rw-r--r--tensorflow/java/maven/tensorflow-hadoop/pom.xml2
-rw-r--r--tensorflow/java/maven/tensorflow/pom.xml2
-rw-r--r--tensorflow/python/BUILD11
-rw-r--r--tensorflow/python/autograph/converters/builtin_functions.py9
-rw-r--r--tensorflow/python/autograph/converters/builtin_functions_test.py16
-rw-r--r--tensorflow/python/autograph/converters/return_statements.py14
-rw-r--r--tensorflow/python/autograph/converters/return_statements_test.py12
-rw-r--r--tensorflow/python/autograph/pyct/templates.py2
-rw-r--r--tensorflow/python/autograph/pyct/templates_test.py12
-rw-r--r--tensorflow/python/client/session_ref.cc40
-rw-r--r--tensorflow/python/client/session_test.py18
-rw-r--r--tensorflow/python/compat/compat.py2
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD326
-rw-r--r--tensorflow/python/data/kernel_tests/batch_dataset_op_test.py10
-rw-r--r--tensorflow/python/data/kernel_tests/cache_dataset_op_test.py5
-rw-r--r--tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_ops_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/filter_dataset_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/inputs_test.py149
-rw-r--r--tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/map_dataset_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/multi_device_iterator_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/optional_ops_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/range_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py7
-rw-r--r--tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py124
-rw-r--r--tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/shard_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/test_base.py109
-rw-r--r--tensorflow/python/data/kernel_tests/window_dataset_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/zip_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py229
-rw-r--r--tensorflow/python/data/ops/multi_device_iterator_ops.py4
-rw-r--r--tensorflow/python/data/ops/readers.py12
-rw-r--r--tensorflow/python/debug/BUILD1
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli_test.py20
-rw-r--r--tensorflow/python/debug/cli/stepper_cli_test.py4
-rw-r--r--tensorflow/python/debug/lib/debug_utils_test.py4
-rw-r--r--tensorflow/python/debug/lib/dist_session_debug_grpc_test.py4
-rw-r--r--tensorflow/python/debug/lib/grpc_large_data_test.py12
-rw-r--r--tensorflow/python/debug/lib/session_debug_file_test.py4
-rw-r--r--tensorflow/python/debug/lib/session_debug_grpc_test.py46
-rw-r--r--tensorflow/python/debug/lib/session_debug_testlib.py90
-rw-r--r--tensorflow/python/debug/lib/stepper_test.py14
-rw-r--r--tensorflow/python/debug/wrappers/dumping_wrapper_test.py2
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper_test.py14
-rw-r--r--tensorflow/python/distribute/distribute_coordinator.py4
-rw-r--r--tensorflow/python/distribute/estimator_training.py2
-rw-r--r--tensorflow/python/eager/backprop.py2
-rw-r--r--tensorflow/python/eager/function.py21
-rw-r--r--tensorflow/python/eager/function_test.py37
-rw-r--r--tensorflow/python/estimator/BUILD3
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py127
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py535
-rw-r--r--tensorflow/python/estimator/canned/dnn.py183
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined.py7
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined_test.py268
-rw-r--r--tensorflow/python/estimator/canned/dnn_test.py161
-rw-r--r--tensorflow/python/estimator/canned/dnn_testing_utils.py116
-rw-r--r--tensorflow/python/estimator/canned/linear.py83
-rw-r--r--tensorflow/python/estimator/canned/linear_test.py138
-rw-r--r--tensorflow/python/estimator/canned/linear_testing_utils.py184
-rw-r--r--tensorflow/python/estimator/estimator.py71
-rw-r--r--tensorflow/python/estimator/estimator_test.py150
-rw-r--r--tensorflow/python/estimator/keras.py39
-rw-r--r--tensorflow/python/estimator/keras_test.py28
-rw-r--r--tensorflow/python/estimator/util.py8
-rw-r--r--tensorflow/python/feature_column/feature_column.py35
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py12
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py632
-rw-r--r--tensorflow/python/feature_column/feature_column_v2_test.py1874
-rw-r--r--tensorflow/python/framework/function_test.py2
-rw-r--r--tensorflow/python/framework/graph_util_test.py8
-rw-r--r--tensorflow/python/framework/subscribe_test.py4
-rw-r--r--tensorflow/python/framework/test_util.py84
-rw-r--r--tensorflow/python/grappler/item_test.py2
-rw-r--r--tensorflow/python/grappler/memory_optimizer_test.py10
-rw-r--r--tensorflow/python/grappler/tf_optimizer_test.py2
-rwxr-xr-xtensorflow/python/keras/BUILD2
-rw-r--r--tensorflow/python/keras/backend.py14
-rw-r--r--tensorflow/python/keras/callbacks_test.py40
-rw-r--r--tensorflow/python/keras/engine/base_layer.py161
-rw-r--r--tensorflow/python/keras/engine/training.py9
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py352
-rw-r--r--tensorflow/python/keras/engine/training_eager_test.py14
-rw-r--r--tensorflow/python/keras/engine/training_generator.py11
-rw-r--r--tensorflow/python/keras/engine/training_test.py12
-rw-r--r--tensorflow/python/keras/layers/core.py51
-rw-r--r--tensorflow/python/keras/layers/core_test.py45
-rw-r--r--tensorflow/python/keras/metrics.py16
-rw-r--r--tensorflow/python/keras/models.py9
-rw-r--r--tensorflow/python/keras/optimizers_test.py17
-rw-r--r--tensorflow/python/kernel_tests/BUILD41
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py134
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py236
-rw-r--r--tensorflow/python/kernel_tests/dense_update_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/depthwise_conv_op_test.py12
-rw-r--r--tensorflow/python/kernel_tests/distributions/bernoulli_test.py12
-rw-r--r--tensorflow/python/kernel_tests/distributions/normal_test.py8
-rw-r--r--tensorflow/python/kernel_tests/identity_op_py_test.py2
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py73
-rw-r--r--tensorflow/python/kernel_tests/list_ops_test.py26
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/scatter_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/softplus_op_test.py7
-rw-r--r--tensorflow/python/kernel_tests/softsign_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/string_length_op_test.py27
-rw-r--r--tensorflow/python/kernel_tests/unicode_script_op_test.py57
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py4
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py43
-rw-r--r--tensorflow/python/layers/base.py16
-rw-r--r--tensorflow/python/layers/convolutional_test.py36
-rw-r--r--tensorflow/python/layers/core_test.py6
-rw-r--r--tensorflow/python/ops/control_flow_ops.py16
-rw-r--r--tensorflow/python/ops/distributions/distribution.py34
-rw-r--r--tensorflow/python/ops/embedding_ops.py8
-rw-r--r--tensorflow/python/ops/gradients_test.py2
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_test_util.py16
-rw-r--r--tensorflow/python/ops/math_grad.py34
-rw-r--r--tensorflow/python/ops/math_grad_test.py88
-rw-r--r--tensorflow/python/ops/math_ops_test.py71
-rw-r--r--tensorflow/python/ops/matmul_benchmark.py8
-rw-r--r--tensorflow/python/ops/parallel_for/pfor.py6
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py2
-rw-r--r--tensorflow/python/ops/string_ops.py13
-rw-r--r--tensorflow/python/ops/variable_scope.py126
-rw-r--r--tensorflow/python/ops/variables.py341
-rw-r--r--tensorflow/python/ops/while_v2.py4
-rw-r--r--tensorflow/python/saved_model/loader_test.py14
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py56
-rw-r--r--tensorflow/python/tools/BUILD9
-rw-r--r--tensorflow/python/tools/freeze_graph_test.py6
-rw-r--r--tensorflow/python/training/basic_session_run_hooks.py5
-rw-r--r--tensorflow/python/training/checkpointable/util.py2
-rw-r--r--tensorflow/python/training/evaluation.py68
-rw-r--r--tensorflow/python/training/learning_rate_decay_test.py4
-rw-r--r--tensorflow/python/training/monitored_session_test.py28
-rw-r--r--tensorflow/python/training/optimizer.py7
-rw-r--r--tensorflow/python/training/quantize_training_test.py3
-rw-r--r--tensorflow/python/training/queue_runner_test.py22
-rw-r--r--tensorflow/python/training/saver_test.py217
-rw-r--r--tensorflow/python/training/server_lib_same_variables_no_clear_test.py4
-rw-r--r--tensorflow/python/training/server_lib_test.py18
-rw-r--r--tensorflow/python/training/session_manager.py5
-rw-r--r--tensorflow/python/training/session_manager_test.py98
-rw-r--r--tensorflow/python/training/supervisor.py7
-rw-r--r--tensorflow/python/training/supervisor_test.py52
-rw-r--r--tensorflow/python/training/sync_replicas_optimizer_test.py17
-rw-r--r--tensorflow/python/training/training_ops_test.py32
-rw-r--r--tensorflow/python/training/training_util_test.py4
-rw-r--r--tensorflow/python/util/function_utils.py23
-rw-r--r--tensorflow/python/util/function_utils_test.py95
-rw-r--r--tensorflow/python/util/nest.py16
-rw-r--r--tensorflow/python/util/nest_test.py34
-rw-r--r--tensorflow/python/util/util.cc356
-rw-r--r--tensorflow/python/util/util.h9
-rw-r--r--tensorflow/python/util/util.i12
-rw-r--r--tensorflow/requirements.txt2
-rw-r--r--tensorflow/tensorflow.bzl80
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt24
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt148
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt46
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt58
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt58
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt105
-rw-r--r--tensorflow/tools/api/golden/tensorflow.image.pbtxt251
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt268
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt289
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt55
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt268
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt289
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt1
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-variable-scope.pbtxt105
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt17
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt130
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt92
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.variable_scope.pbtxt9
-rw-r--r--tensorflow/tools/api/tests/BUILD1
-rw-r--r--tensorflow/tools/api/tests/api_compatibility_test.py39
-rw-r--r--tensorflow/tools/benchmark/README.md2
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.042
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.rocm97
-rwxr-xr-xtensorflow/tools/ci_build/builds/docker_test.sh9
-rwxr-xr-xtensorflow/tools/ci_build/builds/pip.sh4
-rwxr-xr-xtensorflow/tools/ci_build/builds/run_pip_tests.sh1
-rwxr-xr-xtensorflow/tools/ci_build/builds/with_the_same_user6
-rwxr-xr-xtensorflow/tools/ci_build/ci_build.sh11
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_cc_core.sh1
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py2_core.sh1
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh1
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_py3_core.sh1
-rwxr-xr-xtensorflow/tools/ci_build/linux/libtensorflow.sh3
-rwxr-xr-xtensorflow/tools/ci_build/linux/libtensorflow_cpu.sh1
-rwxr-xr-xtensorflow/tools/ci_build/linux/libtensorflow_docker.sh6
-rwxr-xr-x[-rw-r--r--]tensorflow/tools/ci_build/linux/libtensorflow_rocm.sh (renamed from tensorflow/contrib/data/python/ops/contrib_op_loader.py)18
-rwxr-xr-xtensorflow/tools/ci_build/linux/rocm/run_cc_core.sh39
-rwxr-xr-xtensorflow/tools/ci_build/linux/rocm/run_py3_core.sh39
-rwxr-xr-xtensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh1
-rwxr-xr-xtensorflow/tools/ci_build/osx/libtensorflow_cpu.sh1
-rwxr-xr-xtensorflow/tools/ci_build/osx/libtensorflow_gpu.sh1
-rwxr-xr-xtensorflow/tools/ci_build/osx/libtensorflow_rocm.sh36
-rwxr-xr-xtensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh41
-rw-r--r--tensorflow/tools/dist_test/server/BUILD1
-rw-r--r--tensorflow/tools/docs/BUILD1
-rw-r--r--tensorflow/tools/lib_package/BUILD40
-rw-r--r--tensorflow/tools/pip_package/BUILD31
-rw-r--r--tensorflow/tools/pip_package/pip_smoke_test.py1
-rw-r--r--tensorflow/tools/pip_package/setup.py5
-rw-r--r--tensorflow/tools/quantization/BUILD78
-rw-r--r--tensorflow/tools/quantization/graph_to_dot.py68
-rw-r--r--tensorflow/tools/quantization/quantize_graph.py1302
-rw-r--r--tensorflow/tools/quantization/quantize_graph_test.py966
-rwxr-xr-xtensorflow/workspace.bzl400
-rw-r--r--third_party/flatbuffers/BUILD.bazel3
-rw-r--r--third_party/flatbuffers/workspace.bzl8
-rw-r--r--third_party/gpus/crosstool/BUILD.tpl14
-rw-r--r--third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl158
-rwxr-xr-xthird_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl241
-rw-r--r--third_party/gpus/cuda_configure.bzl35
-rw-r--r--third_party/gpus/rocm/BUILD0
-rw-r--r--third_party/gpus/rocm/BUILD.tpl99
-rw-r--r--third_party/gpus/rocm/build_defs.bzl.tpl45
-rw-r--r--third_party/gpus/rocm/rocm_config.h.tpl21
-rw-r--r--third_party/gpus/rocm_configure.bzl784
-rw-r--r--third_party/icu/BUILD1
-rw-r--r--third_party/icu/BUILD.bazel88
-rw-r--r--third_party/icu/workspace.bzl15
-rw-r--r--third_party/mkl/BUILD23
-rw-r--r--third_party/mkl/build_defs.bzl41
-rw-r--r--third_party/mkl_dnn/BUILD6
-rw-r--r--third_party/mkl_dnn/build_defs.bzl2
-rw-r--r--third_party/ngraph/ngraph.BUILD43
-rw-r--r--third_party/ngraph/ngraph_tf.BUILD51
-rw-r--r--third_party/py/python_configure.bzl4
-rw-r--r--third_party/toolchains/BUILD4
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD2
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD14
-rw-r--r--tools/bazel.rc17
847 files changed, 27798 insertions, 19877 deletions
diff --git a/README.md b/README.md
index e3092e551e..57efb876c9 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ subscribing to
[announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce).
## Installation
-*See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.*
+*See [Installing TensorFlow](https://www.tensorflow.org/install) for instructions on how to install our release binaries or how to build from source.*
People who are a little more adventurous can also try our nightly binaries:
@@ -48,15 +48,12 @@ $ python
```
```python
>>> import tensorflow as tf
+>>> tf.enable_eager_execution()
+>>> tf.add(1, 2)
+3
>>> hello = tf.constant('Hello, TensorFlow!')
->>> sess = tf.Session()
->>> sess.run(hello)
+>>> hello.numpy()
'Hello, TensorFlow!'
->>> a = tf.constant(10)
->>> b = tf.constant(32)
->>> sess.run(a + b)
-42
->>> sess.close()
```
Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/).
@@ -106,13 +103,13 @@ The TensorFlow project strives to abide by generally accepted best practices in
## For more information
+* [TensorFlow Website](https://www.tensorflow.org)
+* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
+* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
+* [TensorFlow Twitter](https://twitter.com/tensorflow)
* [TensorFlow Blog](https://medium.com/tensorflow)
* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
-* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
-* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
-* [TensorFlow Twitter](https://twitter.com/tensorflow)
-* [TensorFlow Website](https://www.tensorflow.org)
* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
diff --git a/WORKSPACE b/WORKSPACE
index 11605871f3..17961829a6 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -9,26 +9,10 @@ http_archive(
"https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13
],
)
-load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
-closure_repositories()
-http_archive(
- name = "io_bazel_rules_python",
- strip_prefix = "rules_python-8b5d0683a7d878b28fffe464779c8a53659fc645",
- urls = [
- "https://github.com/bazelbuild/rules_python/archive/8b5d0683a7d878b28fffe464779c8a53659fc645.tar.gz",
- ],
-)
-load("@io_bazel_rules_python//python:pip.bzl", "pip_repositories")
-pip_repositories()
+load("@io_bazel_rules_closure//closure:defs.bzl", "closure_repositories")
-load("@io_bazel_rules_python//python:pip.bzl", "pip_import")
-pip_import(
- name = "pip_deps",
- requirements = "//tensorflow:requirements.txt",
-)
-load("@pip_deps//:requirements.bzl", "pip_install")
-pip_install()
+closure_repositories()
# We must check the bazel version before trying to parse any other BUILD
# files, in case the parsing of those build files depends on the bazel
diff --git a/configure.py b/configure.py
index f0b9fada5e..796c6231e8 100644
--- a/configure.py
+++ b/configure.py
@@ -41,7 +41,6 @@ _DEFAULT_CUDA_PATH = '/usr/local/cuda'
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
-_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine()
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
@@ -49,10 +48,14 @@ _SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16]
_DEFAULT_PROMPT_ASK_ATTEMPTS = 10
-_TF_WORKSPACE_ROOT = os.path.abspath(os.path.dirname(__file__))
_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
-_TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME)
-_TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE')
+_TF_WORKSPACE_ROOT = ''
+_TF_BAZELRC = ''
+
+if platform.machine() == 'ppc64le':
+ _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/'
+else:
+ _DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/%s-linux-gnu' % platform.machine()
class UserInputError(Exception):
@@ -153,14 +156,18 @@ def get_python_path(environ_cp, python_bin_path):
if environ_cp.get('PYTHONPATH'):
python_paths = environ_cp.get('PYTHONPATH').split(':')
try:
- library_paths = run_shell(
- [python_bin_path, '-c',
- 'import site; print("\\n".join(site.getsitepackages()))']).split('\n')
+ library_paths = run_shell([
+ python_bin_path, '-c',
+ 'import site; print("\\n".join(site.getsitepackages()))'
+ ]).split('\n')
except subprocess.CalledProcessError:
- library_paths = [run_shell(
- [python_bin_path, '-c',
- 'from distutils.sysconfig import get_python_lib;'
- 'print(get_python_lib())'])]
+ library_paths = [
+ run_shell([
+ python_bin_path, '-c',
+ 'from distutils.sysconfig import get_python_lib;'
+ 'print(get_python_lib())'
+ ])
+ ]
all_paths = set(python_paths + library_paths)
@@ -187,8 +194,7 @@ def setup_python(environ_cp):
environ_cp, 'PYTHON_BIN_PATH', ask_python_bin_path,
default_python_bin_path)
# Check if the path is valid
- if os.path.isfile(python_bin_path) and os.access(
- python_bin_path, os.X_OK):
+ if os.path.isfile(python_bin_path) and os.access(python_bin_path, os.X_OK):
break
elif not os.path.exists(python_bin_path):
print('Invalid python path: %s cannot be found.' % python_bin_path)
@@ -230,15 +236,16 @@ def setup_python(environ_cp):
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
# Write tools/python_bin_path.sh
- with open(os.path.join(
- _TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f:
+ with open(
+ os.path.join(_TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'),
+ 'w') as f:
f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
-def reset_tf_configure_bazelrc(workspace_path):
+def reset_tf_configure_bazelrc():
"""Reset file that contains customized config settings."""
open(_TF_BAZELRC, 'w').close()
- bazelrc_path = os.path.join(workspace_path, '.bazelrc')
+ bazelrc_path = os.path.join(_TF_WORKSPACE_ROOT, '.bazelrc')
data = []
if os.path.exists(bazelrc_path):
@@ -250,7 +257,7 @@ def reset_tf_configure_bazelrc(workspace_path):
continue
f.write('%s\n' % l)
if is_windows():
- tf_bazelrc_path = _TF_BAZELRC.replace("\\", "/")
+ tf_bazelrc_path = _TF_BAZELRC.replace('\\', '/')
else:
tf_bazelrc_path = _TF_BAZELRC
f.write('import %s\n' % tf_bazelrc_path)
@@ -261,8 +268,8 @@ def cleanup_makefile():
These files could interfere with Bazel parsing.
"""
- makefile_download_dir = os.path.join(
- _TF_WORKSPACE_ROOT, 'tensorflow', 'contrib', 'makefile', 'downloads')
+ makefile_download_dir = os.path.join(_TF_WORKSPACE_ROOT, 'tensorflow',
+ 'contrib', 'makefile', 'downloads')
if os.path.isdir(makefile_download_dir):
for root, _, filenames in os.walk(makefile_download_dir):
for f in filenames:
@@ -330,9 +337,8 @@ def get_var(environ_cp,
'Environment variable %s must be set as a boolean indicator.\n'
'The following are accepted as TRUE : %s.\n'
'The following are accepted as FALSE: %s.\n'
- 'Current value is %s.' % (
- var_name, ', '.join(true_strings), ', '.join(false_strings),
- var))
+ 'Current value is %s.' % (var_name, ', '.join(true_strings),
+ ', '.join(false_strings), var))
while var is None:
user_input_origin = get_input(question)
@@ -355,8 +361,12 @@ def get_var(environ_cp,
return var
-def set_build_var(environ_cp, var_name, query_item, option_name,
- enabled_by_default, bazel_config_name=None):
+def set_build_var(environ_cp,
+ var_name,
+ query_item,
+ option_name,
+ enabled_by_default,
+ bazel_config_name=None):
"""Set if query_item will be enabled for the build.
Ask user if query_item will be enabled. Default is used if no input is given.
@@ -379,8 +389,8 @@ def set_build_var(environ_cp, var_name, query_item, option_name,
elif bazel_config_name is not None:
# TODO(mikecase): Migrate all users of configure.py to use --config Bazel
# options and not to set build configs through environment variables.
- write_to_bazelrc('build:%s --define %s=true'
- % (bazel_config_name, option_name))
+ write_to_bazelrc(
+ 'build:%s --define %s=true' % (bazel_config_name, option_name))
def set_action_env_var(environ_cp,
@@ -447,7 +457,8 @@ def check_bazel_version(min_version):
if which('bazel') is None:
print('Cannot find bazel. Please install bazel.')
sys.exit(0)
- curr_version = run_shell(['bazel', '--batch', '--bazelrc=/dev/null', 'version'])
+ curr_version = run_shell(
+ ['bazel', '--batch', '--bazelrc=/dev/null', 'version'])
for line in curr_version.split('\n'):
if 'Build label: ' in line:
@@ -499,6 +510,7 @@ def set_cc_opt_flags(environ_cp):
write_to_bazelrc('build:opt --host_copt=-march=native')
write_to_bazelrc('build:opt --define with_default_optimizations=true')
+
def set_tf_cuda_clang(environ_cp):
"""set TF_CUDA_CLANG action_env.
@@ -581,16 +593,14 @@ def set_clang_cuda_compiler_path(environ_cp):
clang_cuda_compiler_path)
-def prompt_loop_or_load_from_env(
- environ_cp,
- var_name,
- var_default,
- ask_for_var,
- check_success,
- error_msg,
- suppress_default_error=False,
- n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS
-):
+def prompt_loop_or_load_from_env(environ_cp,
+ var_name,
+ var_default,
+ ask_for_var,
+ check_success,
+ error_msg,
+ suppress_default_error=False,
+ n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS):
"""Loop over user prompts for an ENV param until receiving a valid response.
For the env param var_name, read from the environment or verify user input
@@ -629,9 +639,7 @@ def prompt_loop_or_load_from_env(
)
for _ in range(n_ask_attempts):
- val = get_from_env_or_user_or_default(environ_cp,
- var_name,
- full_query,
+ val = get_from_env_or_user_or_default(environ_cp, var_name, full_query,
default)
if check_success(val):
break
@@ -639,9 +647,9 @@ def prompt_loop_or_load_from_env(
print(error_msg % val)
environ_cp[var_name] = ''
else:
- raise UserInputError('Invalid %s setting was provided %d times in a row. '
- 'Assuming to be a scripting mistake.' %
- (var_name, n_ask_attempts))
+ raise UserInputError(
+ 'Invalid %s setting was provided %d times in a row. '
+ 'Assuming to be a scripting mistake.' % (var_name, n_ask_attempts))
environ_cp[var_name] = val
return val
@@ -650,8 +658,8 @@ def prompt_loop_or_load_from_env(
def create_android_ndk_rule(environ_cp):
"""Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule."""
if is_windows() or is_cygwin():
- default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' %
- environ_cp['APPDATA'])
+ default_ndk_path = cygpath(
+ '%s/Android/Sdk/ndk-bundle' % environ_cp['APPDATA'])
elif is_macos():
default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME']
else:
@@ -668,8 +676,7 @@ def create_android_ndk_rule(environ_cp):
ask_for_var='Please specify the home path of the Android NDK to use.',
check_success=valid_ndk_path,
error_msg=('The path %s or its child file "source.properties" '
- 'does not exist.')
- )
+ 'does not exist.'))
write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path)
write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL',
check_ndk_level(android_ndk_home_path))
@@ -703,9 +710,9 @@ def create_android_sdk_rule(environ_cp):
api_levels = [x.replace('android-', '') for x in api_levels]
def valid_api_level(api_level):
- return os.path.exists(os.path.join(android_sdk_home_path,
- 'platforms',
- 'android-' + api_level))
+ return os.path.exists(
+ os.path.join(android_sdk_home_path, 'platforms',
+ 'android-' + api_level))
android_api_level = prompt_loop_or_load_from_env(
environ_cp,
@@ -720,9 +727,8 @@ def create_android_sdk_rule(environ_cp):
versions = sorted(os.listdir(build_tools))
def valid_build_tools(version):
- return os.path.exists(os.path.join(android_sdk_home_path,
- 'build-tools',
- version))
+ return os.path.exists(
+ os.path.join(android_sdk_home_path, 'build-tools', version))
android_build_tools_version = prompt_loop_or_load_from_env(
environ_cp,
@@ -736,10 +742,8 @@ def create_android_sdk_rule(environ_cp):
write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION',
android_build_tools_version)
- write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL',
- android_api_level)
- write_action_env_to_bazelrc('ANDROID_SDK_HOME',
- android_sdk_home_path)
+ write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', android_api_level)
+ write_action_env_to_bazelrc('ANDROID_SDK_HOME', android_sdk_home_path)
def check_ndk_level(android_ndk_home_path):
@@ -798,6 +802,7 @@ def reformat_version_sequence(version_str, sequence_count):
Args:
version_str: String, the version string.
sequence_count: int, an integer.
+
Returns:
string, reformatted version string.
"""
@@ -841,12 +846,19 @@ def set_tf_cuda_version(environ_cp):
if is_windows():
cuda_rt_lib_paths = ['lib/x64/cudart.lib']
elif is_linux():
- cuda_rt_lib_paths = ['%s/libcudart.so.%s' % (x, tf_cuda_version)
- for x in ['lib64', 'lib/x86_64-linux-gnu']]
+ cuda_rt_lib_paths = [
+ '%s/libcudart.so.%s' % (x, tf_cuda_version) for x in [
+ 'lib64',
+ 'lib/powerpc64le-linux-gnu',
+ 'lib/x86_64-linux-gnu',
+ ]
+ ]
elif is_macos():
cuda_rt_lib_paths = ['lib/libcudart.%s.dylib' % tf_cuda_version]
- cuda_toolkit_paths_full = [os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths]
+ cuda_toolkit_paths_full = [
+ os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths
+ ]
if any([os.path.exists(x) for x in cuda_toolkit_paths_full]):
break
@@ -919,8 +931,8 @@ def set_tf_cudnn_version(environ_cp):
cudnn_path_from_ldconfig)
if cudnn_path_from_ldconfig:
cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1)
- if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig,
- tf_cudnn_version)):
+ if os.path.exists(
+ '%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)):
cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig)
break
@@ -1166,6 +1178,7 @@ def get_native_cuda_compute_capabilities(environ_cp):
Args:
environ_cp: copy of the os.environ.
+
Returns:
string of native cuda compute capabilities, separated by comma.
"""
@@ -1290,8 +1303,7 @@ def set_computecpp_toolkit_path(environ_cp):
else:
sycl_rt_lib_path = ''
- sycl_rt_lib_path_full = os.path.join(toolkit_path,
- sycl_rt_lib_path)
+ sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path)
exists = os.path.exists(sycl_rt_lib_path_full)
if not exists:
print('Invalid SYCL %s library path. %s cannot be found' %
@@ -1319,8 +1331,8 @@ def set_trisycl_include_dir(environ_cp):
ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
'include directory. (Use --config=sycl_trisycl '
'when building with Bazel) '
- '[Default is %s]: '
- ) % (_DEFAULT_TRISYCL_INCLUDE_DIR)
+ '[Default is %s]: ') % (
+ _DEFAULT_TRISYCL_INCLUDE_DIR)
while True:
trisycl_include_dir = get_from_env_or_user_or_default(
@@ -1329,13 +1341,12 @@ def set_trisycl_include_dir(environ_cp):
if os.path.exists(trisycl_include_dir):
break
- print('Invalid triSYCL include directory, %s cannot be found'
- % (trisycl_include_dir))
+ print('Invalid triSYCL include directory, %s cannot be found' %
+ (trisycl_include_dir))
# Set TRISYCL_INCLUDE_DIR
environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir
- write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR',
- trisycl_include_dir)
+ write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
def set_mpi_home(environ_cp):
@@ -1345,8 +1356,9 @@ def set_mpi_home(environ_cp):
default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home))
def valid_mpi_path(mpi_home):
- exists = (os.path.exists(os.path.join(mpi_home, 'include')) and
- os.path.exists(os.path.join(mpi_home, 'lib')))
+ exists = (
+ os.path.exists(os.path.join(mpi_home, 'include')) and
+ os.path.exists(os.path.join(mpi_home, 'lib')))
if not exists:
print('Invalid path to the MPI Toolkit. %s or %s cannot be found' %
(os.path.join(mpi_home, 'include'),
@@ -1395,10 +1407,6 @@ def set_other_mpi_vars(environ_cp):
raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home)
-def set_grpc_build_flags():
- write_to_bazelrc('build --define grpc_no_ares=true')
-
-
def set_system_libs_flag(environ_cp):
syslibs = environ_cp.get('TF_SYSTEM_LIBS', '')
if syslibs and syslibs != '':
@@ -1431,14 +1439,20 @@ def set_windows_build_flags(environ_cp):
# TODO(pcloudy): Remove this flag when upgrading Bazel to 0.16.0
# Short object file path will be enabled by default.
write_to_bazelrc('build --experimental_shortened_obj_file_path=true')
+ # When building zip file for some py_binary and py_test targets, don't
+ # include its dependencies. This is for:
+ # 1. Running python tests against the system installed TF pip package.
+ # 2. Avoiding redundant files in
+ # //tensorflow/tools/pip_package:simple_console_windows,
+ # which is a py_binary used during creating TF pip package.
+ # See https://github.com/tensorflow/tensorflow/issues/22390
+ write_to_bazelrc('build --define=no_tensorflow_py_deps=true')
if get_var(
environ_cp, 'TF_OVERRIDE_EIGEN_STRONG_INLINE', 'Eigen strong inline',
- True,
- ('Would you like to override eigen strong inline for some C++ '
- 'compilation to reduce the compilation time?'),
- 'Eigen strong inline overridden.',
- 'Not overriding eigen strong inline, '
+ True, ('Would you like to override eigen strong inline for some C++ '
+ 'compilation to reduce the compilation time?'),
+ 'Eigen strong inline overridden.', 'Not overriding eigen strong inline, '
'some compilations could take more than 20 mins.'):
# Due to a known MSVC compiler issue
# https://github.com/tensorflow/tensorflow/issues/10521
@@ -1454,29 +1468,32 @@ def config_info_line(name, help_text):
def main():
+ global _TF_WORKSPACE_ROOT
+ global _TF_BAZELRC
+
parser = argparse.ArgumentParser()
- parser.add_argument("--workspace",
- type=str,
- default=_TF_WORKSPACE_ROOT,
- help="The absolute path to your active Bazel workspace.")
+ parser.add_argument(
+ '--workspace',
+ type=str,
+ default=os.path.abspath(os.path.dirname(__file__)),
+ help='The absolute path to your active Bazel workspace.')
args = parser.parse_args()
+ _TF_WORKSPACE_ROOT = args.workspace
+ _TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME)
+
# Make a copy of os.environ to be clear when functions and getting and setting
# environment variables.
environ_cp = dict(os.environ)
check_bazel_version('0.15.0')
- reset_tf_configure_bazelrc(args.workspace)
+ reset_tf_configure_bazelrc()
cleanup_makefile()
setup_python(environ_cp)
if is_windows():
- environ_cp['TF_NEED_AWS'] = '0'
- environ_cp['TF_NEED_GCP'] = '0'
- environ_cp['TF_NEED_HDFS'] = '0'
environ_cp['TF_NEED_JEMALLOC'] = '0'
- environ_cp['TF_NEED_KAFKA'] = '0'
environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
environ_cp['TF_NEED_COMPUTECPP'] = '0'
environ_cp['TF_NEED_OPENCL'] = '0'
@@ -1486,40 +1503,24 @@ def main():
# Windows.
environ_cp['TF_DOWNLOAD_CLANG'] = '0'
environ_cp['TF_ENABLE_XLA'] = '0'
- environ_cp['TF_NEED_GDR'] = '0'
- environ_cp['TF_NEED_VERBS'] = '0'
environ_cp['TF_NEED_MPI'] = '0'
environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'
if is_macos():
environ_cp['TF_NEED_JEMALLOC'] = '0'
environ_cp['TF_NEED_TENSORRT'] = '0'
+ environ_cp['TF_ENABLE_XLA'] = '0'
# The numpy package on ppc64le uses OpenBLAS which has multi-threading
# issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at
# runtime to allow the Tensorflow testcases which compare numpy
# results to Tensorflow results to succeed.
if is_ppc64le():
- write_action_env_to_bazelrc("OMP_NUM_THREADS", 1)
-
- set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
- 'with_jemalloc', True)
- set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform',
- 'with_gcp_support', True, 'gcp')
- set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System',
- 'with_hdfs_support', True, 'hdfs')
- set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform',
- 'with_aws_support', True, 'aws')
- set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
- 'with_kafka_support', True, 'kafka')
+ write_action_env_to_bazelrc('OMP_NUM_THREADS', 1)
+
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
- False, 'xla')
- set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',
- False, 'gdr')
- set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support',
- False, 'verbs')
- set_build_var(environ_cp, 'TF_NEED_NGRAPH', 'nGraph',
- 'with_ngraph_support', False, 'ngraph')
+ True, 'xla')
+
set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False)
if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
@@ -1531,6 +1532,13 @@ def main():
else:
set_trisycl_include_dir(environ_cp)
+ set_action_env_var(environ_cp, 'TF_NEED_ROCM', 'ROCm', False)
+ if (environ_cp.get('TF_NEED_ROCM') == '1' and
+ 'LD_LIBRARY_PATH' in environ_cp and
+ environ_cp.get('LD_LIBRARY_PATH') != '1'):
+ write_action_env_to_bazelrc('LD_LIBRARY_PATH',
+ environ_cp.get('LD_LIBRARY_PATH'))
+
set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False)
if (environ_cp.get('TF_NEED_CUDA') == '1' and
'TF_CUDA_CONFIG_REPO' not in environ_cp):
@@ -1571,12 +1579,24 @@ def main():
write_to_bazelrc('build --config=download_clang')
write_to_bazelrc('test --config=download_clang')
+ # SYCL / ROCm / CUDA are mutually exclusive.
+ # At most 1 GPU platform can be configured.
+ gpu_platform_count = 0
+ if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
+ gpu_platform_count += 1
+ if environ_cp.get('TF_NEED_ROCM') == '1':
+ gpu_platform_count += 1
+ if environ_cp.get('TF_NEED_CUDA') == '1':
+ gpu_platform_count += 1
+ if gpu_platform_count >= 2:
+ raise UserInputError('SYCL / CUDA / ROCm are mututally exclusive. '
+ 'At most 1 GPU platform can be configured.')
+
set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
if environ_cp.get('TF_NEED_MPI') == '1':
set_mpi_home(environ_cp)
set_other_mpi_vars(environ_cp)
- set_grpc_build_flags()
set_cc_opt_flags(environ_cp)
set_system_libs_flag(environ_cp)
if is_windows():
@@ -1585,13 +1605,10 @@ def main():
# Add a config option to build TensorFlow 2.0 API.
write_to_bazelrc('build:v2 --define=tf_api_version=2')
- if get_var(
- environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace',
- False,
- ('Would you like to interactively configure ./WORKSPACE for '
- 'Android builds?'),
- 'Searching for NDK and SDK installations.',
- 'Not configuring the WORKSPACE for Android builds.'):
+ if get_var(environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', False,
+ ('Would you like to interactively configure ./WORKSPACE for '
+ 'Android builds?'), 'Searching for NDK and SDK installations.',
+ 'Not configuring the WORKSPACE for Android builds.'):
create_android_ndk_rule(environ_cp)
create_android_sdk_rule(environ_cp)
@@ -1604,6 +1621,10 @@ def main():
'more details.')
config_info_line('mkl', 'Build with MKL support.')
config_info_line('monolithic', 'Config for mostly static monolithic build.')
+ config_info_line('gdr', 'Build with GDR support.')
+ config_info_line('verbs', 'Build with libverbs support.')
+ config_info_line('ngraph', 'Build with Intel nGraph support.')
+
if __name__ == '__main__':
main()
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 3610eea42a..5f73da68a2 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -225,60 +225,6 @@ config_setting(
)
config_setting(
- name = "with_gcp_support",
- define_values = {"with_gcp_support": "true"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_hdfs_support",
- define_values = {"with_hdfs_support": "true"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_aws_support",
- define_values = {"with_aws_support": "true"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_kafka_support",
- define_values = {"with_kafka_support": "true"},
- visibility = ["//visibility:public"],
-)
-
-# Crosses between platforms and file system libraries not supported on those
-# platforms due to limitations in nested select() statements.
-config_setting(
- name = "with_gcp_support_windows_override",
- define_values = {"with_gcp_support": "true"},
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_hdfs_support_windows_override",
- define_values = {"with_hdfs_support": "true"},
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_aws_support_windows_override",
- define_values = {"with_aws_support": "true"},
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_kafka_support_windows_override",
- define_values = {"with_kafka_support": "true"},
- values = {"cpu": "x64_windows"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
name = "with_cuda_support_windows_override",
define_values = {"using_cuda_nvcc": "true"},
values = {"cpu": "x64_windows"},
@@ -286,48 +232,6 @@ config_setting(
)
config_setting(
- name = "with_gcp_support_android_override",
- define_values = {"with_gcp_support": "true"},
- values = {"crosstool_top": "//external:android/crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_hdfs_support_android_override",
- define_values = {"with_hdfs_support": "true"},
- values = {"crosstool_top": "//external:android/crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_aws_support_android_override",
- define_values = {"with_aws_support": "true"},
- values = {"crosstool_top": "//external:android/crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_gcp_support_ios_override",
- define_values = {"with_gcp_support": "true"},
- values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_hdfs_support_ios_override",
- define_values = {"with_hdfs_support": "true"},
- values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_aws_support_ios_override",
- define_values = {"with_aws_support": "true"},
- values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
name = "with_xla_support",
define_values = {"with_xla_support": "true"},
visibility = ["//visibility:public"],
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 43c279bd80..17e2e292eb 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -246,6 +246,7 @@ tf_cc_test(
":c_api_experimental",
":c_test_util",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 3bcc62cf2d..d4b78138e9 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
using tensorflow::FunctionDef;
using tensorflow::Node;
@@ -8508,6 +8509,20 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
VLOG(1) << "Enqueuing is done.";
}
+TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
+ tensorflow::ServerDef server_def;
+ if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
+ &server_def)) {
+ status->status = tensorflow::errors::Internal(
+ "Invalid text proto for ServerDef: ", text_proto);
+ return nullptr;
+ }
+ status->status = tensorflow::Status();
+ TF_Buffer* ret = TF_NewBuffer();
+ TF_CHECK_OK(MessageToBuffer(server_def, ret));
+ return ret;
+}
+
TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
TF_Status* status) {
auto* opts = TFE_NewContextOptions();
@@ -8723,35 +8738,7 @@ void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
TF_DeleteStatus(status);
}
-TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx) {
- // Intentionally LOG into INFO below for ease of debugging.
- VLOG(1) << "TFE_RunConstOp called";
-
- auto* status = TF_NewStatus();
- auto* op = TFE_NewOp(ctx, "Const", status);
- CheckOk(status);
- TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
-
- auto* tensor =
- TF_AllocateTensor(TF_FLOAT, /*shape.data()*/ nullptr, /*shape.size()*/ 0,
- TF_DataTypeSize(TF_FLOAT) * 1);
- auto* ptr = reinterpret_cast<char*>(TF_TensorData(tensor));
- *reinterpret_cast<float*>(ptr) = 17.0;
-
- TFE_OpSetAttrTensor(op, "value", tensor, status);
- CheckOk(status);
- TF_DeleteTensor(tensor);
- VLOG(1) << "New op created";
-
- TFE_TensorHandle* retval;
- int num_retvals = 1;
- TFE_Execute(op, &retval, &num_retvals, status);
- CheckOk(status);
- CHECK_EQ(num_retvals, 1);
- VLOG(1) << "Op executed";
-
- TFE_DeleteOp(op);
- TF_DeleteStatus(status);
-
- return retval;
+TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
+ const char* errMsg) {
+ status->status = tensorflow::errors::Internal(errMsg);
}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index a3ca847d96..d98d532e32 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -131,6 +131,8 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
int tensor_id,
TF_Tensor* tensor,
TF_Status* status);
+// Create a serialized tensorflow.ServerDef proto.
+TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status);
// TODO: remove this API in favor of the next one.
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
@@ -178,10 +180,8 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
TFE_TensorHandle* handle);
-// Returns a const scalar tensor.
-// Caller owns both the input and the output tensor handles.
-// TODO: Remove this API with hard-coded tensor computation.
-TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx);
+TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
+ const char* errMsg);
#ifdef __cplusplus
} /* end extern "C" */
diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc
index 30fcfd401d..c6effd3969 100644
--- a/tensorflow/c/c_api_experimental_test.cc
+++ b/tensorflow/c/c_api_experimental_test.cc
@@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
namespace tensorflow {
namespace {
@@ -116,5 +118,49 @@ TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) {
TF_DeleteStatus(s);
}
+TEST(CAPI_EXPERIMENTAL, GetServerDefTest) {
+ const string expected_text_proto(R"(cluster {
+ job {
+ name: "worker"
+ tasks {
+ key: 0
+ value: "tpuserver:0"
+ }
+ tasks {
+ key: 1
+ value: "localhost:1"
+ }
+ }
+}
+job_name: "worker"
+task_index: 1
+protocol: "grpc"
+)");
+
+ TF_Status* status = TF_NewStatus();
+ TF_Buffer* result = TFE_GetServerDef(expected_text_proto.c_str(), status);
+ EXPECT_EQ(TF_GetCode(status), TF_OK);
+
+ ServerDef actual;
+ ASSERT_TRUE(actual.ParseFromArray(result->data, result->length));
+ string actual_text_proto;
+ tensorflow::protobuf::TextFormat::PrintToString(actual, &actual_text_proto);
+ EXPECT_EQ(expected_text_proto, actual_text_proto);
+
+ const string malformed_text_proto(R"(cluster {
+ job {
+ name: "worker")");
+ TF_Buffer* null_result =
+ TFE_GetServerDef(malformed_text_proto.c_str(), status);
+ EXPECT_NE(TF_GetCode(status), TF_OK);
+ EXPECT_TRUE(tensorflow::str_util::StrContains(
+ TF_Message(status), "Invalid text proto for ServerDef"));
+ EXPECT_EQ(null_result, nullptr);
+
+ // Cleanup
+ TF_DeleteBuffer(result);
+ TF_DeleteStatus(status);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index de135d7a23..64b861a730 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -47,7 +47,7 @@ def tfadd(_):
def tfadd_with_ckpt(out_dir):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
- y = variables.Variable(constant_op.constant([0]), name='y_saved')
+ y = variables.VariableV1(constant_op.constant([0]), name='y_saved')
math_ops.add(x, y, name='x_y_sum')
init_op = variables.initialize_all_variables()
@@ -62,7 +62,7 @@ def tfadd_with_ckpt(out_dir):
def tfadd_with_ckpt_saver(out_dir):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
- y = variables.Variable(constant_op.constant([0]), name='y_saved')
+ y = variables.VariableV1(constant_op.constant([0]), name='y_saved')
math_ops.add(x, y, name='x_y_sum')
init_op = variables.initialize_all_variables()
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 4e184729ef..5bf4af1014 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -478,6 +478,7 @@ tf_cc_test(
name = "compilation_passes_test",
size = "small",
srcs = [
+ "build_xla_ops_pass_test.cc",
"encapsulate_subgraphs_pass_test.cc",
"encapsulate_xla_computations_pass_test.cc",
"mark_for_compilation_pass_test.cc",
@@ -486,6 +487,7 @@ tf_cc_test(
deps = [
":common",
":compilation_passes",
+ ":node_matchers",
":xla_cluster_util",
":xla_gpu_device",
"//tensorflow/cc:cc_ops",
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
index 13a518d0e8..9e3fd93cda 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -112,16 +112,9 @@ static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
old_node->out_edges().end());
for (const Edge* edge : out_edges) {
- Node* dst = edge->dst();
- int src_output = edge->src_output();
- int dst_input = edge->dst_input();
+ // TODO(sanjoy): This does not update NodeDef inputs.
+ g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
g->RemoveEdge(edge);
-
- if (edge->IsControlEdge()) {
- g->AddControlEdge(new_node, dst);
- } else {
- g->AddEdge(new_node, src_output, dst, dst_input);
- }
}
}
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
new file mode 100644
index 0000000000..b7cb4506b9
--- /dev/null
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/jit/node_matchers.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+using ::tensorflow::testing::FindNodeByName;
+using ::tensorflow::testing::matchers::CtrlDeps;
+using ::tensorflow::testing::matchers::NodeWith;
+using ::tensorflow::testing::matchers::Op;
+
+Status BuildXlaOps(const Scope& s, std::unique_ptr<Graph>* result) {
+ auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_RETURN_IF_ERROR(s.ToGraph(graph.get()));
+
+ // Assign all nodes to the CPU device.
+ static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
+ for (Node* n : graph->nodes()) {
+ if (n->assigned_device_name().empty()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
+ }
+
+ GraphOptimizationPassOptions opt_options;
+ opt_options.graph = &graph;
+ BuildXlaOpsPass pass;
+ TF_RETURN_IF_ERROR(pass.Run(opt_options));
+ *result = std::move(graph);
+ return Status::OK();
+}
+
+Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
+ const string& node_name, Node** result) {
+ NodeDef call_node;
+ call_node.set_name(node_name);
+ call_node.set_op(callee_name);
+ AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node);
+ AddNodeAttr(kXlaNumConstantArgsAttr, 0, &call_node);
+ AddNodeAttr(kXlaNumResourceArgsAttr, 0, &call_node);
+ Status s;
+ *result = graph->AddNode(call_node, &s);
+ return s;
+}
+
+Node* MakeWrite(const Scope& scope, const string& id) {
+ Output var_handle =
+ ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
+ Output value_to_write =
+ ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f);
+ ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle,
+ value_to_write);
+ return assign_op.operation.node();
+}
+
+FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) {
+ FunctionDefLibrary flib_def;
+ FunctionDef func = FunctionDefHelper::Create(
+ /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"},
+ /*attr_def*/
+ {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)},
+ /*ret_def=*/{{"out", "out:output:0"}});
+ *flib_def.add_function() = std::move(func);
+ return flib_def;
+}
+
+TEST(BuildXlaOps, ControlDepsPreserved) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("cluster_0");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+ Node* call;
+ TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
+ Node* write_op = MakeWrite(root, "write");
+ root.graph()->AddControlEdge(call, write_op);
+
+ std::unique_ptr<Graph> graph;
+ TF_ASSERT_OK(BuildXlaOps(root, &graph));
+
+ Node* write_op_new = FindNodeByName(graph.get(), write_op->name());
+ ASSERT_NE(write_op_new, nullptr);
+ EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")))));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index 9128b48da3..25e2e9a7af 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h"
+#include "absl/algorithm/container.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -383,6 +384,8 @@ class PredicateFactory {
}
Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
+ Predicate* MakeInternedAndOr(std::vector<Predicate*> simplified_ops,
+ Predicate::Kind pred_kind);
// Predicate instances are interned, meaning that there is only a single
// instance of a Predicate object with a given content. This makes checking
@@ -429,11 +432,40 @@ class PredicateFactory {
interned_symbol_instances_;
};
+Predicate* PredicateFactory::MakeInternedAndOr(
+ std::vector<Predicate*> simplified_ops, Predicate::Kind pred_kind) {
+ std::stable_sort(
+ simplified_ops.begin(), simplified_ops.end(),
+ [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
+
+ auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
+ if (it != interned_and_or_instances_.end()) {
+ return it->second.get();
+ }
+
+ simplified_ops.shrink_to_fit();
+ // NB! Because we'll use a non-owning reference to simplified_ops in the
+ // key for interned_and_or_instances_ we need to be careful to std::move()
+ // it all the way through.
+ absl::Span<Predicate* const> operands_slice = simplified_ops;
+ std::unique_ptr<Predicate> new_pred =
+ pred_kind == Predicate::Kind::kAnd
+ ? Make<AndPredicate>(std::move(simplified_ops))
+ : Make<OrPredicate>(std::move(simplified_ops));
+
+ Predicate* new_pred_ptr = new_pred.get();
+ interned_and_or_instances_.emplace(
+ SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred));
+ return new_pred_ptr;
+}
+
// Common code to create AndPredicate or OrPredicate instances.
Predicate* PredicateFactory::MakeAndOrImpl(
absl::Span<Predicate* const> operands, bool is_and) {
Predicate::Kind pred_kind =
is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
+ Predicate::Kind other_pred_kind =
+ is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
gtl::FlatSet<Predicate*> simplified_ops_set;
std::vector<Predicate*> simplified_ops;
for (Predicate* op : operands) {
@@ -472,30 +504,63 @@ Predicate* PredicateFactory::MakeAndOrImpl(
}
}
- std::stable_sort(
- simplified_ops.begin(), simplified_ops.end(),
- [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
+ // If all ops contain the same subop, then factor it out thanks to the
+ // distributive property. Such as:
+ // - (A & B) | (A & C) | (A & D) => A & (B | C | D)
+ // - (A | B) & (A | C) & (A | D) => A | (B & C & D)
+ //
+ // First find any predicates contained in all subops.
+ std::vector<Predicate*> common_inner_operands;
+ gtl::FlatSet<Predicate*> common_inner_operands_set;
+ for (Predicate* op : simplified_ops) {
+ if (op->kind() != other_pred_kind) {
+ common_inner_operands.clear();
+ break;
+ }
- auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
- if (it == interned_and_or_instances_.end()) {
- simplified_ops.shrink_to_fit();
- // NB! Because we'll use a non-owning reference to simplified_ops in the
- // key for interned_and_or_instances_ we need to be careful to std::move()
- // it all the way through.
- absl::Span<Predicate* const> operands_slice = simplified_ops;
- std::unique_ptr<Predicate> new_pred =
- is_and ? Make<AndPredicate>(std::move(simplified_ops))
- : Make<OrPredicate>(std::move(simplified_ops));
+ if (common_inner_operands.empty()) {
+ common_inner_operands.insert(common_inner_operands.end(),
+ op->GetOperands().begin(),
+ op->GetOperands().end());
+ } else {
+ std::vector<Predicate*> sub_ops_intersection;
+ common_inner_operands.clear();
+ absl::c_copy_if(op->GetOperands(),
+ std::back_inserter(common_inner_operands),
+ [&](Predicate* sub_op) {
+ return common_inner_operands_set.count(sub_op) == 1;
+ });
+ }
+ if (common_inner_operands.empty()) break;
+ common_inner_operands_set.clear();
+ common_inner_operands_set.insert(common_inner_operands.begin(),
+ common_inner_operands.end());
+ }
- Predicate* new_pred_ptr = new_pred.get();
- CHECK(interned_and_or_instances_
- .emplace(SignatureForAndOr(pred_kind, operands_slice),
- std::move(new_pred))
- .second);
- return new_pred_ptr;
- } else {
- return it->second.get();
+ if (common_inner_operands.empty()) {
+ return MakeInternedAndOr(std::move(simplified_ops), pred_kind);
}
+
+ // For all predicates that can be factored out, remove them and recreate the
+ // subops.
+ std::vector<Predicate*> factored_ops;
+ for (Predicate* op : simplified_ops) {
+ std::vector<Predicate*> new_sub_op_ops;
+ absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops),
+ [&](Predicate* sub_op) {
+ return std::find(common_inner_operands.begin(),
+ common_inner_operands.end(),
+ sub_op) == common_inner_operands.end();
+ });
+ factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and));
+ }
+
+ Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and);
+ std::vector<Predicate*> outer_ops;
+ outer_ops.push_back(new_inner_op);
+ outer_ops.insert(outer_ops.end(), common_inner_operands.begin(),
+ common_inner_operands.end());
+ return MakeAndOrImpl(outer_ops, !is_and);
}
class DeadnessAnalysisImpl : public DeadnessAnalysis {
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
index 28a56044d5..617e31488c 100644
--- a/tensorflow/compiler/jit/deadness_analysis_test.cc
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -384,10 +384,31 @@ TEST(DeadnessAnalysisTest, OrOfAnd) {
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
}
-TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) {
- // This demonstrates one of the weaknesses in the current approach -- since we
- // only do some basic simplifications we can't see that "(A|B)&C" ==
- // "(A&C)|(B&C)".
+TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) {
+ // (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) == #true
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "A");
+ ops::Switch sw_1 = CreateSwitch(root, "B");
+ Output add0 =
+ ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true);
+ Output add1 =
+ ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false);
+ ops::Merge or2(root.WithOpName("or2"), {add0, add1});
+ Output add3 =
+ ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false);
+ ops::Merge or4(root.WithOpName("or4"), {add3, sw_0.output_true});
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ PredicateMapTy predicate_map;
+ TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
+ EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true");
+}
+
+TEST(DeadnessAnalysisTest, AndOrDistributive) {
+ // (A|B)&C == (A&C)|(B&C)
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw_0 = CreateSwitch(root, "0");
@@ -408,7 +429,7 @@ TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
- EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node()));
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node()));
}
TEST(DeadnessAnalysisTest, Ternary) {
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 3c160aefe5..b98c0cb028 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -34,6 +34,7 @@ std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
OptionalTensor& optional = variables[i];
optional.name = handle.name();
if (LookupResource(ctx, handle, &variable).ok()) {
+ core::ScopedUnref scoped_unref(variable);
tf_shared_lock lock(*variable->mu());
optional.present = true;
optional.value = *variable->tensor();
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 2ccee79761..6967ad1f03 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -100,9 +100,15 @@ class XlaAssignVariableOp : public AsyncOpKernel {
Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \
ResourceHandleOp<Var>); \
REGISTER_KERNEL_BUILDER( \
+ Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"), \
+ ResourceHandlesOp<Var>); \
+ REGISTER_KERNEL_BUILDER( \
Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \
ReadVariableOp); \
REGISTER_KERNEL_BUILDER( \
+ Name("_ReadVariablesOp").Device(DEVICE).HostMemory("resources"), \
+ ReadVariablesOp); \
+ REGISTER_KERNEL_BUILDER( \
Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"), \
DestroyResourceOp); \
REGISTER_KERNEL_BUILDER(Name("Shape") \
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index f5c8bdd6ee..4f6fc4e068 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -49,6 +49,7 @@ std::map<int, OptionalTensor> SnapshotResourceVariables(
ResourceHandle handle = HandleFromInput(ctx, i);
OptionalTensor& tensor = snapshot[i];
if (LookupResource(ctx, handle, &variable).ok()) {
+ core::ScopedUnref scoped_unref(variable);
tf_shared_lock lock(*variable->mu());
tensor.name = handle.name();
tensor.present = true;
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index e219cf3d88..1b39d53dc0 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -1445,6 +1445,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([4, 0], dtype=np.int32),
expected=np.zeros([4, 0], dtype=dtype))
+ x = np.arange(3).reshape((3, 1, 1, 1)).astype(dtype)
+ self._testBinary(
+ array_ops.broadcast_to,
+ x,
+ np.array((3, 7, 8, 9), dtype=np.int32),
+ expected=np.tile(x, (1, 7, 8, 9)))
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index 8c018cccb8..374942a0b3 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -29,6 +29,11 @@ from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn
from tensorflow.python.platform import test
+DATA_FORMATS = (
+ ("_data_format_NHWC", "NHWC"),
+ ("_data_format_NCHW", "NCHW"),
+)
+
class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
@@ -65,12 +70,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
grad_offset = np.sum(grad_y, axis=(0, 1, 2))
return grad_x, grad_scale, grad_offset
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testInference(self, data_format):
channel = 3
x_shape = [2, 2, 6, channel]
@@ -170,30 +170,15 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
self.assertAllClose(var_val, var_ref, atol=1e-3)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testLearning(self, data_format):
self._testLearning(False, data_format)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testLearningWithGradientChecker(self, data_format):
self._testLearning(True, data_format)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testGradientTraining(self, data_format):
# TODO(b/64270657): Use gradient_checker here in addition to comparing with
# this reference implementation.
@@ -241,12 +226,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)
- @parameterized.named_parameters(
- ("_data_format_NHWC", "NHWC"),
- ("_data_format_NCHW", "NCHW"),
- ("_data_format_HWNC", "HWNC"),
- ("_data_format_HWCN", "HWCN"),
- )
+ @parameterized.named_parameters(*DATA_FORMATS)
def testGradientInference(self, data_format):
# TODO(b/64270657): Use gradient_checker here in addition to comparing with
# this reference implementation.
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index bbe746e28f..68fdb5caf4 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -724,6 +724,49 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(num_valid, 2)
self.assertAllClose(indices_tf[:num_valid], [3, 0])
+ def testNMS3Then1WithScoreMaxThresh(self):
+ # Three boxes are selected based on IOU.
+ # One is filtered out by score threshold.
+ # One is filtered out by max_output_size.
+
+ with compat.forward_compatibility_horizon(2018, 8, 8):
+ boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ boxes_np = np.array(boxes_data, dtype=np.float32)
+
+ scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ scores_np = np.array(scores_data, dtype=np.float32)
+ max_output_size = 1
+ iou_threshold_np = np.array(0.5, dtype=np.float32)
+ score_threshold_np = np.array(0.4, dtype=np.float32)
+
+ with self.cached_session() as sess:
+ boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
+ scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
+ iou_threshold = array_ops.placeholder(iou_threshold_np.dtype,
+ iou_threshold_np.shape)
+ score_threshold = array_ops.placeholder(score_threshold_np.dtype,
+ score_threshold_np.shape)
+ with self.test_scope():
+ selected_indices = image_ops.non_max_suppression_padded(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True)
+ inputs_feed = {
+ boxes: boxes_np,
+ scores: scores_np,
+ iou_threshold: iou_threshold_np,
+ score_threshold: score_threshold_np
+ }
+ (indices_tf, num_valid) = sess.run(
+ selected_indices, feed_dict=inputs_feed)
+
+ self.assertEqual(indices_tf.size, max_output_size)
+ self.assertEqual(num_valid, 1)
+ self.assertAllClose(indices_tf[:num_valid], [3])
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/compiler/tests/lstm.py b/tensorflow/compiler/tests/lstm.py
index 43c469d032..73b3638e80 100644
--- a/tensorflow/compiler/tests/lstm.py
+++ b/tensorflow/compiler/tests/lstm.py
@@ -117,7 +117,7 @@ def LSTMLayer(cell_name, weights, m, c, x_seq, pad_seq):
def RandomVar(shape, name=None):
"""Returns a variable of the given shape initialized to random values."""
- return variables.Variable(
+ return variables.VariableV1(
random_ops.random_uniform(shape), dtype=dtypes.float32, name=name)
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index f792c52032..36c6f5d316 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -31,11 +31,13 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -89,7 +91,44 @@ Status FunctionalizeControlFlowForFunction(
}
});
const FunctionBody* body = flr->GetFunctionBody(handle);
- const FunctionDef& fdef = body->fdef;
+
+ // Call graph optimizer. The most important optimization we need is constant
+ // folding, which will replace ops like Shape/BroadcastGradientArgs with
+ // constant shape input. Without this optimization, those ops might become
+ // dynamic input for then/else body function and XLA will complain that input
+ // is not compile time constant. We enable function inlining as well, because
+ // otherwise we won't be able to infer shape for any node depending on
+ // function call nodes.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_opt_", func_name),
+ *body->graph, fld);
+ }
+ // Optimizer accepts std::unique_ptr<Graph>* as input and might change
+ // underlying pointer, thus we create a new Graph and copy from body->graph.
+ std::unique_ptr<Graph> optimized_graph(new Graph(fld));
+ CopyGraph(*body->graph, optimized_graph.get());
+ OptimizerOptions opts;
+ opts.set_opt_level(OptimizerOptions::L0);
+ opts.set_do_function_inlining(true);
+ opts.set_do_constant_folding(true);
+ GraphOptimizer optimizer(opts);
+ auto cf_consider_fn = [](const Node* n) {
+ // Skip SymbolicGradient op when doing constant folding.
+ // Enabling SymbolicGradient op in constant folding requires
+ // flr->device() to be non-null, and here we have not constructed
+ // proper Device object yet (it will be constructed in XlaCompiler).
+ return n->type_string() != FunctionLibraryDefinition::kGradientOp;
+ };
+ optimizer.Optimize(flr, flr->env(),
+ /*device=*/nullptr, &optimized_graph,
+ /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
+ cf_consider_fn);
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_after_opt_", func_name),
+ *optimized_graph, fld);
+ }
// If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes
@@ -97,7 +136,7 @@ Status FunctionalizeControlFlowForFunction(
// it.
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
nodes_to_associated_functions;
- for (auto* n : body->graph->nodes()) {
+ for (auto* n : optimized_graph->nodes()) {
auto associated_functions = GetAssociatedFunctions(*n, flr);
if (!associated_functions.empty()) {
nodes_to_associated_functions.push_back({n, associated_functions});
@@ -108,7 +147,8 @@ Status FunctionalizeControlFlowForFunction(
auto associated_functions = iter.second;
for (auto& associated_function : associated_functions) {
string name = associated_function.func_name();
- string canonicalized_name = Canonicalize(name, AttrSlice(&attrs));
+ string canonicalized_name =
+ Canonicalize(name, AttrSlice(&associated_function.attrs()));
auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
string new_name;
if (iter != canonicalized_name_to_new_name->end()) {
@@ -116,9 +156,17 @@ Status FunctionalizeControlFlowForFunction(
// but still rewrite the node.
new_name = iter->second;
} else {
- new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+ if (associated_function.type() ==
+ AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
+ // For SymbolicGradient, `name` is always "SymbolicGradient",
+ // which is not very informative. Use node name instead.
+ new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_"));
+ } else {
+ new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+ }
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
- name, new_name, attrs, fld, flr, canonicalized_name_to_new_name));
+ name, new_name, associated_function.attrs(), fld, flr,
+ canonicalized_name_to_new_name));
(*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
}
// Notice that if "n" is a function call, RewriteAssociatedFunction() will
@@ -126,7 +174,7 @@ Status FunctionalizeControlFlowForFunction(
// That's fine because in that case, associated_functions will only have
// one member and the loop will only run once.
TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
- body->graph, n, fld, associated_function, new_name));
+ optimized_graph.get(), n, fld, associated_function, new_name));
}
}
@@ -134,22 +182,17 @@ Status FunctionalizeControlFlowForFunction(
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
- *body->graph, fld);
+ *optimized_graph, fld);
}
- TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld));
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
- *body->graph, fld);
+ *optimized_graph, fld);
}
FunctionDef functionalized_fdef;
- TF_RETURN_IF_ERROR(
- GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef));
-
- // Copy signature and ret from original FunctionDef.
- *functionalized_fdef.mutable_signature() = fdef.signature();
- *functionalized_fdef.mutable_ret() = fdef.ret();
- functionalized_fdef.mutable_signature()->set_name(new_func_name);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
+ &functionalized_fdef));
// Add rewritten FunctionDef into library.
if (func_name == new_func_name) {
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index b3ad0aea84..a267c0c72f 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -34,12 +34,6 @@ class FusedBatchNormOp : public XlaOpKernel {
OP_REQUIRES(
ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format: ", data_format_str));
- OP_REQUIRES(ctx,
- (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW ||
- data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN),
- errors::InvalidArgument(
- "Unsupported data format ", ToString(data_format_),
- "; supported formats are NHWC, NCHW, HWNC and HWCN"));
}
void Compile(XlaOpKernelContext* ctx) override {
@@ -110,12 +104,6 @@ class FusedBatchNormGradOp : public XlaOpKernel {
OP_REQUIRES(
ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format: ", data_format_str));
- OP_REQUIRES(ctx,
- (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW ||
- data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN),
- errors::InvalidArgument(
- "Unsupported data format ", ToString(data_format_),
- "; supported formats are NHWC, NCHW, HWNC and HWCN"));
}
void Compile(XlaOpKernelContext* ctx) override {
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index 66676452d0..a988d3c33e 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -103,6 +103,24 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
XLA_MAKE_BINARY(FloorDiv,
FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto is_zero = xla::Eq(x, zero);
+ return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y)));
+}
+XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
+static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto is_zero = xla::Eq(x, zero);
+ return xla::Select(is_zero, zero, xla::Div(x, y));
+}
+XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
// Implementation of FloorMod. Pseudo-code:
// T trunc_mod = std::fmod(x, y);
// return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index 4bd7c74dca..696c1c39be 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -64,10 +64,9 @@ class BroadcastToOp : public XlaOpKernel {
output_shape.DebugString()));
broadcast_dims.push_back(broadcast_shape.size());
- if (output_dims[i] == input_dims[i] || input_dims[i] == 1) {
+ if (output_dims[i] == input_dims[i]) {
broadcast_shape.push_back(output_dims[i]);
- }
- if (output_dims[i] != input_dims[i]) {
+ } else if (output_dims[i] != input_dims[i]) {
// Add dimensions [I, O/I], which we will later flatten to just
// [O]. We must do this in two phases since XLA broadcasting does not
// support tiling.
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index 33a73fe5fd..921b4340c0 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -355,6 +355,9 @@ class NonMaxSuppressionOp : public XlaOpKernel {
OP_REQUIRES(
context, output_size >= 0,
errors::InvalidArgument("Need output_size >= 0, got ", output_size));
+ OP_REQUIRES(context, output_size <= kint32max,
+ errors::InvalidArgument("Need output_size <= kint32Max, got ",
+ output_size));
xla::XlaOp score_thresh = context->Input("score_threshold");
xla::XlaOp iou_thresh = context->Input("iou_threshold");
@@ -439,12 +442,14 @@ class NonMaxSuppressionOp : public XlaOpKernel {
xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}),
xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes}));
- // num_valid is scalar.
- xla::XlaOp num_valid = xla::Reduce(
+ // num_valid is scalar. Value should be bound by output_size.
+ xla::XlaOp num_valid_total = xla::Reduce(
ones_included,
/*init_value=*/xla::ConstantR0<int>(builder, 0),
/*computation=*/CreateScalarAddComputation(xla::S32, builder),
/*dimensions_to_reduce=*/{0});
+ xla::XlaOp num_valid =
+ xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size));
xla::XlaOp output_tuple = TopK(scores_included, output_size);
xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1);
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index d6f42bac86..01dd3ba10f 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -336,9 +336,9 @@ bool HasAssociatedFunction(const NodeDef& node_def,
}
if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
- // Skip gradient op. Gradient op has "f" attr, which is set to the function
- // we are getting gradient for. That function is not associated with the op.
- return false;
+ // Gradient op has "f" attr, which is set to the function we are getting
+ // gradient for. We need to functionalize the gradient function.
+ return true;
}
for (const auto& iter : node_def.attr()) {
@@ -357,17 +357,18 @@ std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
if (flr->GetFunctionLibraryDefinition()->Contains(op)) {
// This is a function call node.
AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
- results.emplace_back(AssociatedFunctionInfo(op, attrs));
+ results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
} else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
- // Skip gradient op. Gradient op has "f" attr, which is set to the function
- // we are getting gradient for. That function is not associated with the op.
+ // This is a SymbolicGradient op.
+ AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
+ results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
} else {
// Collect all function attrs for the node.
for (auto& iter : node.attrs()) {
if (iter.second.has_func()) {
VLOG(2) << "Found function attr for node " << node.name() << ": "
<< iter.first << " = " << iter.second.func().name();
- results.emplace_back(AssociatedFunctionInfo(
+ results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
iter.second.func().name(), iter.second.func().attr(), iter.first));
}
}
@@ -410,6 +411,21 @@ Status RewriteAssociatedFunction(
graph->RemoveNode(node);
break;
}
+ case AssociatedFunctionInfo::kSymbolicGradient: {
+ NameAttrList func;
+ TF_RETURN_IF_ERROR(GetNodeAttr(
+ node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
+ GradientDef gradient_def;
+ gradient_def.set_function_name(func.name());
+ gradient_def.set_gradient_func(rewritten_function_name);
+ string original_grad_func = fld->FindGradient(func.name());
+ if (original_grad_func.empty()) {
+ TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
+ } else if (original_grad_func != rewritten_function_name) {
+ TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
+ }
+ break;
+ }
case AssociatedFunctionInfo::kFunctionAttr: {
// Change function attr to rewritten functions.
NameAttrList func;
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index 6065d0bb9a..53eab8b63e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -65,21 +65,33 @@ uint32 GetXLARandomSeed();
class AssociatedFunctionInfo {
public:
enum AssociatedFunctionType {
- kFunctionCallNode = 0,
- kFunctionAttr = 1,
+ kFunctionAttr = 0,
+ kFunctionCallNode = 1,
+ kSymbolicGradient = 2,
};
- // The node is a function call.
- AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs)
- : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {}
-
// The function is an attr of the node.
- AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs,
- const string& attr_name)
- : type_(kFunctionAttr),
- func_name_(func_name),
- attrs_(attrs),
- attr_name_(attr_name) {}
+ static AssociatedFunctionInfo FunctionAttr(const string& func_name,
+ const AttrValueMap& attrs,
+ const string& attr_name) {
+ return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name);
+ }
+
+ // The node is a function call.
+ static AssociatedFunctionInfo FunctionCall(const string& func_name,
+ const AttrValueMap& attrs) {
+ // attr_name will not be used in this case.
+ return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs,
+ /*attr_name=*/"");
+ }
+
+ // The node is a SymbolicGradient op.
+ static AssociatedFunctionInfo SymbolicGradient(const string& func_name,
+ const AttrValueMap& attrs) {
+ // attr_name will not be used in this case.
+ return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs,
+ /*attr_name=*/"");
+ }
AssociatedFunctionType type() const { return type_; }
@@ -90,6 +102,13 @@ class AssociatedFunctionInfo {
const AttrValueMap& attrs() const { return attrs_; }
private:
+ AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name,
+ const AttrValueMap& attrs, const string& attr_name)
+ : type_(type),
+ func_name_(func_name),
+ attrs_(attrs),
+ attr_name_(attr_name) {}
+
// Available for all instances.
AssociatedFunctionType type_;
string func_name_;
@@ -105,14 +124,18 @@ bool HasAssociatedFunction(const NodeDef& node_def,
// Gets functions associated with the node. Current cases:
// 1. For function call node, its function name;
-// 2. For nodes like XlaWhile/XlaIf, all their function attributes.
+// 2. For SymbolicGradient op, returned func_name will be "SymbolicGradient",
+// and returned attrs will be this node's attributes;
+// 3. For nodes like XlaWhile/XlaIf, all their function attributes.
std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
const Node& node, FunctionLibraryRuntime* flr);
// Changes associated functions for the node. Current cases:
// 1. For function call node, creates a new node with the new function name and
// remove the old node;
-// 2. For nodes like XlaWhile/XlaIf, modify their function attributes.
+// 2. For SymbolicGradient op, add or replace GradientDef in
+// FunctionLibraryDefinition;
+// 3. For nodes like XlaWhile/XlaIf, modify their function attributes.
Status RewriteAssociatedFunction(
Graph* graph, Node* node, FunctionLibraryDefinition* fld,
const AssociatedFunctionInfo& associated_function,
diff --git a/tensorflow/compiler/tf2xla/type_util.h b/tensorflow/compiler/tf2xla/type_util.h
index bda667eb1f..6354216eee 100644
--- a/tensorflow/compiler/tf2xla/type_util.h
+++ b/tensorflow/compiler/tf2xla/type_util.h
@@ -25,6 +25,14 @@ namespace tensorflow {
// Converts a Tensorflow DataType to an XLA PrimitiveType.
Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type);
+// N.B.: there is intentionally no function to convert an XLA PrimitiveType to
+// a TensorFlow DataType. The mapping from TF types to XLA types is not
+// one-to-one: for example, both DT_INT8 and DT_QINT8 map to xla::S8. So the
+// inverse would not be a well-defined function. If you find that you want the
+// inverse mapping, then most likely you should be preserving the original
+// TensorFlow type, rather than trying to convert an XLA type into a TensorFlow
+// type.
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 25cc37edc4..ff0ec76a7f 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -97,13 +97,11 @@ std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
<< "Computation should have progran shape.";
auto program_shape = computation.proto().program_shape();
- // Create and run a program which produces a tuple with one element per
- // parameter, then return the tuple's constituent buffers.
- std::vector<Shape> param_shapes(program_shape.parameters().begin(),
- program_shape.parameters().end());
- auto fake_input_tuple =
- MakeFakeDataOrDie(ShapeUtil::MakeTupleShape(param_shapes), client);
- return client->DeconstructTuple(*fake_input_tuple).ValueOrDie();
+ std::vector<std::unique_ptr<GlobalData>> results;
+ for (const Shape& shape : program_shape.parameters()) {
+ results.push_back(MakeFakeDataOrDie(shape, client));
+ }
+ return results;
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 95ff6432a5..5277de6a85 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1278,7 +1278,7 @@ XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
XlaOp XlaBuilder::CustomCall(const string& call_target_name,
absl::Span<const XlaOp> operands,
- const Shape& shape) {
+ const Shape& shape, const string& opaque) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
@@ -1289,6 +1289,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
}
*instr.mutable_shape() = shape;
instr.set_custom_call_target(call_target_name);
+ instr.set_custom_call_opaque(opaque);
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
});
}
@@ -2681,8 +2682,9 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
}
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape) {
- return builder->CustomCall(call_target_name, operands, shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque) {
+ return builder->CustomCall(call_target_name, operands, shape, opaque);
}
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index d0c59fa6f2..1da6ddd318 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -577,11 +577,9 @@ class XlaBuilder {
absl::Span<const XlaOp> operands);
// Enqueues a custom call instruction onto the computation.
- // During code generation, a call instruction is emitted which targets a
- // symbol with the name |call_target_name|. The |operands| are passed to the
- // call instruction. |shape| is the resultant shape.
XlaOp CustomCall(const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque);
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
@@ -1195,7 +1193,8 @@ class XlaBuilder {
friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
absl::Span<const XlaOp> operands);
friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque);
friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(const XlaOp& operand);
@@ -1717,12 +1716,17 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
absl::Span<const XlaOp> operands);
-// Enqueues a custom call instruction onto the computation.
-// During code generation, a call instruction is emitted which targets a
-// symbol with the name |call_target_name|. The |operands| are passed to the
-// call instruction. |shape| is the resultant shape.
+// Enqueues a custom call instruction onto the computation. A custom call
+// invokes code external to XLA. The |operands| are passed to the external code,
+// and the external code is expected to produce a result of the given
+// |shape|. The exact mechanism is backend-specific. For example, in the CPU
+// backend, a call instruction is emitted which targets a symbol with the name
+// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings,
+// but |call_target_name| should be short as it may be used in labels. |opaque|
+// can encode arbitrarily large amounts of information.
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque = "");
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc
index a472747bd1..0f9b591c70 100644
--- a/tensorflow/compiler/xla/executable_run_options.cc
+++ b/tensorflow/compiler/xla/executable_run_options.cc
@@ -45,6 +45,16 @@ stream_executor::Stream* ExecutableRunOptions::stream() const {
return stream_;
}
+ExecutableRunOptions& ExecutableRunOptions::set_host_to_device_stream(
+ stream_executor::Stream* stream) {
+ host_to_device_stream_ = stream;
+ return *this;
+}
+
+stream_executor::Stream* ExecutableRunOptions::host_to_device_stream() const {
+ return host_to_device_stream_;
+}
+
ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool(
const Eigen::ThreadPoolDevice* intra_op_thread_pool) {
intra_op_thread_pool_ = intra_op_thread_pool;
diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index 416131be00..ba3217f31b 100644
--- a/tensorflow/compiler/xla/executable_run_options.h
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -65,6 +65,13 @@ class ExecutableRunOptions {
ExecutableRunOptions& set_stream(stream_executor::Stream* stream);
stream_executor::Stream* stream() const;
+ // If set, this is the stream to perform any pre-computation transfers on.
+ // The platform of the stream must match the platform the executable was
+ // built for. A value of nullptr indicates the option has not been set.
+ ExecutableRunOptions& set_host_to_device_stream(
+ stream_executor::Stream* stream);
+ stream_executor::Stream* host_to_device_stream() const;
+
// Sets the thread pool device on which to run Eigen subcomputations.
// Does not take ownership.
ExecutableRunOptions& set_intra_op_thread_pool(
@@ -90,6 +97,7 @@ class ExecutableRunOptions {
const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
ExecutionProfile* execution_profile_ = nullptr;
int rng_seed_ = 0;
+ stream_executor::Stream* host_to_device_stream_ = nullptr;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index 1e0a2ad0dd..3cd3541fe1 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -203,6 +203,10 @@ class LiteralBase {
// Returns the count of the elements in the array at the given shape index in
// this literal.
int64 element_count(const ShapeIndex& index = {}) const {
+ if (index.empty()) {
+ // Common case, avoid GetSubshape().
+ return ShapeUtil::ElementsIn(shape());
+ }
return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
}
@@ -852,9 +856,9 @@ class BorrowingLiteral : public LiteralBase {
template <typename NativeT>
absl::Span<const NativeT> LiteralBase::Piece::data() const {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- CHECK_EQ(subshape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>())
+ DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ DCHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
<< "Attempting to access "
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
@@ -865,9 +869,9 @@ absl::Span<const NativeT> LiteralBase::Piece::data() const {
template <typename NativeT>
absl::Span<NativeT> LiteralBase::Piece::data() {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- CHECK_EQ(subshape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>())
+ DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ DCHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
<< "Attempting to access "
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 2775527e0c..51968d13d4 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -655,6 +655,7 @@ cc_library(
deps = [
":cudnn_convolution_algorithm_picker",
":cudnn_convolution_rewriter",
+ ":cudnn_fused_convolution_rewriter",
":fusion_merger",
":gpu_constants",
":gpu_copy_insertion",
@@ -967,3 +968,19 @@ tf_cc_test(
"@com_google_absl//absl/strings",
],
)
+
+cc_library(
+ name = "cudnn_fused_convolution_rewriter",
+ srcs = ["cudnn_fused_convolution_rewriter.cc"],
+ hdrs = ["cudnn_fused_convolution_rewriter.h"],
+ deps = [
+ ":backend_configs",
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/compiler/xla/service:pattern_matcher",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc
new file mode 100644
index 0000000000..3761c19cfc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc
@@ -0,0 +1,278 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Describes a matched pattern:
+// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+// Where side_input has the shape of output buffer, and bias is a 1D array with
+// the dimension of number of output features.
+struct ConvWithRelu {
+ HloInstruction* maximum;
+ HloCustomCallInstruction* conv;
+ HloInstruction* bias;
+ HloInstruction* side_input;
+ HloConstantInstruction* alpha_conv;
+ HloConstantInstruction* alpha_side_input;
+};
+
+absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
+ using match::Add;
+ using match::AddAnyOrder;
+ using match::AnyOf;
+ using match::Broadcast;
+ using match::Constant;
+ using match::GetTupleElement;
+ using match::Maximum;
+ using match::MultiplyAnyOrder;
+ using match::Op;
+
+ // The pattern we want to match:
+ // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+ //
+ // With its variants involving commute/reassociation of adds, multiplies, and
+ // max, and omission of alpha1, side_input, alpha2, or bias.
+
+ HloInstruction* relu_input;
+
+ // Match max(0, relu_input).
+ auto zero_pattern = Broadcast(match::ConstantScalar(0));
+ if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
+ !Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
+ return absl::nullopt;
+ }
+ HloInstruction* conv_instr = nullptr;
+ HloInstruction* alpha_conv_instr = nullptr;
+ HloInstruction* alpha_side_input_instr = nullptr;
+ HloInstruction* bias_broadcast_instr = nullptr;
+ HloInstruction* bias = nullptr;
+ HloInstruction* side_input = nullptr;
+
+ // These nodes will not be in the returned value, but we need to check them
+ // for single use.
+ HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr,
+ *mul1 = nullptr, *mul2 = nullptr;
+
+ const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
+ const auto conv_pattern = [&] {
+ auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr));
+ auto conv_pattern = GetTupleElement(
+ &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
+ return AnyOf<HloInstruction>(
+ MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
+ }();
+ const auto side_input_pattern = [&] {
+ auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr));
+ // If bias is already matched, match arbitrary additional input as side
+ // input. Note this may force a cheap operation (e.g. broadcast) to be
+ // materialized into a large buffer, as large as the output buffer.
+ //
+ // TODO(timshen): If in practice there are significant false positives, we
+ // should fix it.
+ auto side_input_pattern = Op(&side_input);
+ return AnyOf<HloInstruction>(
+ MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern),
+ side_input_pattern);
+ }();
+
+ {
+ // Try to match any of the following form of add, in any association:
+ // addends[0]
+ // addends[0] + addends[1]
+ // addends[0] + addends[1] + addends[2]
+ //
+ // Then try to match each addend with one of the three patterns: bias, conv,
+ // or side_input. Notice that side_input matching must go last, as it
+ // also matches a conv or a bias.
+ HloInstruction* addends[3] = {nullptr, nullptr, nullptr};
+ auto add3_pattern = [&] {
+ auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1]));
+ return AnyOf<HloInstruction>(
+ AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern,
+ Op(&addends[0]));
+ }();
+ CHECK(Match(relu_input, add3_pattern));
+ for (auto addend : addends) {
+ if (addend) {
+ if (bias == nullptr && Match(addend, bias_pattern)) {
+ CHECK(bias);
+ } else if (conv_instr == nullptr && Match(addend, conv_pattern)) {
+ CHECK(conv_instr);
+ } else if (side_input == nullptr && Match(addend, side_input_pattern)) {
+ CHECK(side_input);
+ } else {
+ return absl::nullopt;
+ }
+ }
+ }
+ }
+
+ if (conv_instr == nullptr) {
+ return absl::nullopt;
+ }
+
+ for (HloInstruction* instr :
+ {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) {
+ if (instr && instr->user_count() > 1) {
+ return absl::nullopt;
+ }
+ }
+
+ auto conv = Cast<HloCustomCallInstruction>(conv_instr);
+ auto bias_broadcast =
+ CastOrNull<HloBroadcastInstruction>(bias_broadcast_instr);
+
+ if (conv->custom_call_target() != kCudnnConvForwardCallTarget) {
+ return absl::nullopt;
+ }
+
+ if (bias_broadcast) {
+ // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}.
+ if (bias_broadcast_instr->dimensions().size() != 1) {
+ return absl::nullopt;
+ }
+ if (bias_broadcast_instr->dimensions(0) !=
+ conv->convolution_dimension_numbers().output_feature_dimension()) {
+ return absl::nullopt;
+ }
+ }
+
+ return ConvWithRelu{
+ instr,
+ conv,
+ bias,
+ side_input,
+ CastOrNull<HloConstantInstruction>(alpha_conv_instr),
+ CastOrNull<HloConstantInstruction>(alpha_side_input_instr)};
+}
+
+StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu(
+ ConvWithRelu match) {
+ auto conv = match.conv;
+
+ HloComputation* computation = conv->parent();
+ PrimitiveType element_type = conv->operand(0)->shape().element_type();
+
+ const auto get_alpha_value =
+ [](HloConstantInstruction* instr) -> StatusOr<double> {
+ TF_ASSIGN_OR_RETURN(
+ auto alpha,
+ Cast<HloConstantInstruction>(instr)->literal().Convert(F64));
+ return alpha.GetFirstElement<double>();
+ };
+
+ double alpha_conv = 1;
+ if (match.alpha_conv) {
+ TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv));
+ }
+
+ double alpha_side_input;
+ if (match.side_input) {
+ if (match.alpha_side_input) {
+ TF_ASSIGN_OR_RETURN(alpha_side_input,
+ get_alpha_value(match.alpha_side_input));
+ } else {
+ alpha_side_input = 1;
+ }
+ } else {
+ CHECK(match.alpha_side_input == nullptr);
+ alpha_side_input = 0;
+ }
+
+ auto bias = match.bias;
+ if (!bias) {
+ auto zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
+
+ int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions(
+ conv->convolution_dimension_numbers().output_feature_dimension());
+ bias = computation->AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShapeWithDescendingLayout(element_type,
+ {num_output_feature}),
+ zero, {}));
+ }
+
+ CHECK(bias);
+ std::vector<HloInstruction*> args = {conv->mutable_operand(0),
+ conv->mutable_operand(1), bias};
+ if (match.side_input) {
+ args.push_back(match.side_input);
+ }
+ auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall(
+ conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget));
+ new_conv->set_window(conv->window());
+ new_conv->set_convolution_dimension_numbers(
+ conv->convolution_dimension_numbers());
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
+ conv->backend_config<CudnnConvBackendConfig>());
+ config.set_activation_mode(
+ static_cast<int64>(se::dnn::ActivationMode::kRelu));
+ config.set_conv_result_scale(alpha_conv);
+ config.set_side_input_scale(alpha_side_input);
+ TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
+
+ VLOG(1) << "Rewriting " << conv->name() << " to " << new_conv->name();
+ return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0),
+ new_conv, 0);
+}
+
+} // namespace
+
+StatusOr<bool> CudnnFusedConvolutionRewriter::Run(HloModule* module) {
+ bool changed = false;
+ for (HloComputation* computation : module->MakeNonfusionComputations()) {
+ std::vector<ConvWithRelu> matches;
+ int num_forward_convs = 0;
+ for (auto instr : computation->instructions()) {
+ auto match = FindConvWithRelu(instr);
+ if (match.has_value()) {
+ matches.push_back(*match);
+ }
+ if (auto call = DynCast<HloCustomCallInstruction>(instr)) {
+ if (call->custom_call_target() == kCudnnConvForwardCallTarget) {
+ num_forward_convs++;
+ }
+ }
+ }
+ VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size()
+ << " out of " << num_forward_convs << " forward convs.";
+ std::vector<std::pair<HloInstruction*, std::unique_ptr<HloInstruction>>>
+ replacements;
+ for (const ConvWithRelu& match : matches) {
+ TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match));
+ replacements.push_back({match.maximum, std::move(new_instr)});
+ changed = true;
+ }
+ for (auto& replacement : replacements) {
+ TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+ replacement.first, std::move(replacement.second)));
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h
new file mode 100644
index 0000000000..bd12aadded
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+class CudnnFusedConvolutionRewriter : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "cudnn-fused-convolution-rewriter";
+ }
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 01a18f4f8e..0b3b429710 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
@@ -208,6 +209,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<CudnnConvolutionRewriter>();
+ pipeline.AddPass<CudnnFusedConvolutionRewriter>();
pipeline.AddPass<PadInsertion>();
if (IsVoltaOrLater(*stream_exec)) {
pipeline.AddPass<PadForTensorCores>();
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index 2d270f630b..e3869b5c36 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -37,15 +37,32 @@ static constexpr int64 kDesiredNumFeaturesFactor = 8;
// there's additional room for speedups. Achieving those speedups without also
// slowing other things down will likely require a more sophisticated heuristic,
// possibly some form of auto-tuning.
-static constexpr double kMaxBytesTouchedIncrease = 1.2;
+//
+// This value should be >= 4/3, otherwise the "dims of size 3 padded up to 4"
+// special case inside PadShape won't fire.
+static constexpr double kMaxBytesTouchedIncrease = 1.35;
// Pads the given dimensions in the given shape up to a multiple of
// kDesiredNumFeaturesFactor.
static Shape PadShape(Shape s, absl::Span<const int64> dims) {
for (int64 dim : dims) {
int64 dim_to_pad_size = s.dimensions(dim);
- int64 new_dim_to_pad_size =
- RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor);
+
+ // Round dim_to_pad_size up to the next multiple of
+ // kDesiredNumFeaturesFactor.
+ //
+ // Special case: dims of size 3 are rounded up to 4, not
+ // kDesiredNumFeaturesFactor. Empirically (and on the advice of nvidia),
+ // this helps, but as of writing, it's not supported by anything in the
+ // cudnn docs.
+ int64 new_dim_to_pad_size;
+ if (dim_to_pad_size == 3) {
+ new_dim_to_pad_size = 4;
+ } else {
+ new_dim_to_pad_size =
+ RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor);
+ }
+
s.set_dimensions(dim, new_dim_to_pad_size);
}
return s;
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 7e77dc9ac6..b42a19e3a2 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -30,7 +30,8 @@ namespace gpu {
namespace {
bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
- CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget);
+ CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
+ conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget);
return window_util::HasSymmetricPadding(conv.window()) &&
!window_util::HasNegativePadding(conv.window()) &&
!window_util::HasDilation(conv.window());
@@ -385,7 +386,8 @@ StatusOr<bool> PadInsertion::RunOnComputation(HloComputation* computation) {
}
for (HloInstruction* instruction : convs) {
const auto& target = instruction->custom_call_target();
- if (target == kCudnnConvForwardCallTarget) {
+ if (target == kCudnnConvForwardCallTarget ||
+ target == kCudnnConvBiasActivationForwardCallTarget) {
changed |= CanonicalizeForwardConvolution(instruction);
} else if (target == kCudnnConvBackwardFilterCallTarget) {
changed |= CanonicalizeBackwardFilterConvolution(instruction);
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index 5da6f232d5..a725533567 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -209,3 +209,17 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+tf_cc_test(
+ name = "cudnn_fused_convolution_rewriter_test",
+ srcs = ["cudnn_fused_convolution_rewriter_test.cc"],
+ tags = tf_cuda_tests_tags(),
+ deps = [
+ ":gpu_codegen_test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc
new file mode 100644
index 0000000000..5632cac186
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc
@@ -0,0 +1,283 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/strings/str_replace.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class CudnnFusedConvolutionRewriterTest : public HloTestBase {
+ protected:
+ string GetOptimizedHlo(absl::string_view hlo_string) {
+ return backend()
+ .compiler()
+ ->RunHloPasses(ParseHloString(hlo_string, GetModuleConfigForTest())
+ .ConsumeValueOrDie(),
+ backend().default_stream_executor(),
+ backend().memory_allocator())
+ .ConsumeValueOrDie()
+ ->ToString();
+ }
+
+ void TestMatchWithAllTypes(absl::string_view hlo_string) {
+ for (absl::string_view type : {"f16", "f32", "f64"}) {
+ const string hlo_with_new_type =
+ absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+ const string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
+ EXPECT_EQ(absl::string_view::npos,
+ optimized_hlo_string.find("__cudnn$convForward"))
+ << optimized_hlo_string;
+ EXPECT_NE(absl::string_view::npos,
+ optimized_hlo_string.find("__cudnn$convBiasActivationForward"))
+ << optimized_hlo_string;
+ EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01}))
+ << optimized_hlo_string;
+ }
+ }
+
+ void TestNotMatchWithAllTypes(absl::string_view hlo_string) {
+ for (absl::string_view type : {"f16", "f32", "f64"}) {
+ const string hlo_with_new_type =
+ absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+ string optimized_hlo = GetOptimizedHlo(hlo_with_new_type);
+ EXPECT_NE(absl::string_view::npos,
+ optimized_hlo.find("__cudnn$convForward"))
+ << optimized_hlo;
+ EXPECT_EQ(absl::string_view::npos,
+ optimized_hlo.find("__cudnn$convBiasActivationForward"))
+ << optimized_hlo;
+ }
+ }
+};
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) {
+ // max(0, conv(x, w));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) {
+ // max(0, conv(x, w) + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) {
+ // max(0, conv(x, w) + side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ add1 = TYPE[1,3,3,64] add(conv, side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) {
+ // max(0, conv(x, w) + side_input + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+ bias = TYPE[64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ add2 = TYPE[1,3,3,64] add(add1, side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) {
+ // max(0, 0.999994934 * conv(x, w));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
+ scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv)
+ ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) {
+ // max(0, conv(x, w) + 0.899994934 * side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) {
+ // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+ alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest,
+ TestScaledConvAndScaledSideInputWithBias) {
+ // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+ alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+ bias = TYPE[64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias)
+ add2 = TYPE[1,3,3,64] add(add1, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) {
+ // max(0.1, conv(x, w)) shouldn't match.
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ point_one = TYPE[] constant(0.1)
+ point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={}
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchBroadcastedBiasOnly) {
+ // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match.
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input1 = TYPE[1,3,3,64] parameter(2)
+ side_input2 = TYPE[1,3,3,64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ add1 = TYPE[1,3,3,64] add(conv, side_input2)
+ add2 = TYPE[1,3,3,64] add(add1, side_input1)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index a07eaaf997..2bd04259c0 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -827,33 +827,34 @@ class BufferIntervalTree {
// interval.
std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) {
std::vector<Chunk> result;
- if (node_count_ > 0) {
- ChunksOverlappingInTimeHelper(start, end, &node_storage_[0], &result);
+ if (node_count_ == 0) {
+ return result;
+ }
+ std::vector<BufferIntervalTreeNode*> visiting_stack;
+ visiting_stack.push_back(&node_storage_[0]);
+ while (!visiting_stack.empty()) {
+ BufferIntervalTreeNode* top = visiting_stack.back();
+ visiting_stack.pop_back();
+ if (start > top->subtree_end) {
+ continue;
+ }
+ if (top->left != nullptr) {
+ visiting_stack.push_back(top->left);
+ }
+ if (top->start <= end && top->end >= start) {
+ result.push_back(top->chunk);
+ }
+ if (end < top->start) {
+ continue;
+ }
+ if (top->right != nullptr) {
+ visiting_stack.push_back(top->right);
+ }
}
return result;
}
private:
- void ChunksOverlappingInTimeHelper(int64 start, int64 end,
- BufferIntervalTreeNode* visiting_node,
- std::vector<Chunk>* result) {
- if (start > visiting_node->subtree_end) {
- return;
- }
- if (visiting_node->left != nullptr) {
- ChunksOverlappingInTimeHelper(start, end, visiting_node->left, result);
- }
- if (visiting_node->start <= end && visiting_node->end >= start) {
- result->push_back(visiting_node->chunk);
- }
- if (end < visiting_node->start) {
- return;
- }
- if (visiting_node->right != nullptr) {
- ChunksOverlappingInTimeHelper(start, end, visiting_node->right, result);
- }
- }
-
int64 node_count_ = 0;
std::vector<BufferIntervalTreeNode> node_storage_;
};
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index b19ec12638..caaca16f71 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
-// Next ID: 53
+// Next ID: 54
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -124,9 +124,13 @@ message HloInstructionProto {
// The string representation of the infeed configuration.
bytes infeed_config = 27;
- // Name of a global symbol to call, only present for kCustomCall.
+ // Name of a external target (eg, global symbol) to call, only present for
+ // kCustomCall.
string custom_call_target = 28;
+ // Opaque string, only present for kCustomCall.
+ string custom_call_opaque = 53;
+
// Shape of outfeed request.
xla.Shape outfeed_shape = 29;
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index e9e70b2c57..0e5920af7a 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -272,10 +272,11 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
<< "instruction " << instruction->name()
<< " has control successors and cannot be removed";
- TF_RET_CHECK(instruction_iterators_.count(instruction) != 0);
- auto inst_it = instruction_iterators_.at(instruction);
- (*inst_it)->set_parent(nullptr);
- instructions_.erase(inst_it);
+ auto inst_it = instruction_iterators_.find(instruction);
+ TF_RET_CHECK(inst_it != instruction_iterators_.end());
+ (*inst_it->second)->set_parent(nullptr);
+ instructions_.erase(inst_it->second);
+ instruction_iterators_.erase(inst_it);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index e7c98aae23..936a53bd7e 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -227,7 +227,7 @@ class HloComputation {
void UpdateReachabilityThroughInstruction(
const HloInstruction* instruction, HloReachabilityMap* reachability_map);
- int64 instruction_count() const { return instructions_.size(); }
+ int64 instruction_count() const { return instruction_iterators_.size(); }
// Creates and returns a list of the embedded computations called by this
// computation. This includes all embedded computations called directly or
@@ -439,7 +439,7 @@ class HloComputation {
// instruction pointer to location in the list for fast lookup.
using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
InstructionList instructions_;
- std::unordered_map<const HloInstruction*, InstructionList::iterator>
+ tensorflow::gtl::FlatMap<const HloInstruction*, InstructionList::iterator>
instruction_iterators_;
std::vector<HloInstruction*> param_instructions_;
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index b91b2406e2..d7c39b2778 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -496,6 +496,61 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
return Status::OK();
}
+Status HloEvaluator::HandleReal(HloInstruction* real) {
+ auto operand = real->operand(0);
+ switch (operand->shape().element_type()) {
+ case BF16: {
+ auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
+ real, [](bfloat16 elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case C64: {
+ auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
+ real, [](complex64 elem_operand) { return std::real(elem_operand); },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F16: {
+ auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
+ real, [](Eigen::half elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F32: {
+ auto result_or = ElementWiseUnaryOpImpl<float, float>(
+ real, [](float elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F64: {
+ auto result_or = ElementWiseUnaryOpImpl<double, double>(
+ real, [](double elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ default:
+ LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
+ << PrimitiveType_Name(operand->shape().element_type());
+ }
+
+ return Status::OK();
+}
+
+Status HloEvaluator::HandleImag(HloInstruction* imag) {
+ auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
+ imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
+ GetEvaluatedLiteralFor(imag->operand(0)));
+
+ TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
+ return Status::OK();
+}
+
Status HloEvaluator::HandleCompare(HloInstruction* compare) {
HloOpcode opcode = compare->opcode();
auto lhs = compare->operand(0);
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 21e676d671..6c2662ebae 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -184,6 +184,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleSort(HloInstruction* sort) override;
+ Status HandleReal(HloInstruction* real) override;
+
+ Status HandleImag(HloInstruction* imag) override;
+
Status HandleReduce(HloInstruction* reduce) override;
// Returns the already-evaluated literal result for the instruction.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 04cdc6901c..b2d12c94b8 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -89,6 +89,8 @@ bool SafeLess(const NativeT& a, const NativeT& b) {
// to this rule, notably:
// - HandleCompare and HandleIsFinite: where the resulting literal type is
// always boolean.
+// - HandleImag and HandleReal: where the resulting literal type is always float
+// and the operand is always complex, or real in the case of HandleReal.
// These operations are handled outside of the parent HloEvaluator handlers
// instead of from within TypedVisitor.
//
@@ -329,14 +331,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleFloor<ReturnT>(floor);
}
- Status HandleImag(HloInstruction* imag) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag],
- ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) {
- return std::imag(elem_operand);
- }));
- return Status::OK();
- }
-
Status HandleLog(HloInstruction* log) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) {
@@ -684,14 +678,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
- Status HandleReal(HloInstruction* real) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[real],
- ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) {
- return std::real(elem_operand);
- }));
- return Status::OK();
- }
-
template <typename NativeT, typename std::enable_if<std::is_floating_point<
NativeT>::value>::type* = nullptr>
Status HandleRemainder(HloInstruction* remainder) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index ad58833e4d..23787dbc8a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -379,7 +379,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
case HloOpcode::kCustomCall:
instruction = CreateCustomCall(proto.shape(), all_operands(),
- proto.custom_call_target());
+ proto.custom_call_target(),
+ proto.custom_call_opaque());
if (proto.has_window()) {
static_cast<HloCustomCallInstruction*>(instruction.get())
->set_window(proto.window());
@@ -1108,9 +1109,9 @@ bool HloInstruction::HasSideEffect() const {
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target) {
- return absl::make_unique<HloCustomCallInstruction>(shape, operands,
- custom_call_target);
+ absl::string_view custom_call_target, absl::string_view opaque) {
+ return absl::make_unique<HloCustomCallInstruction>(
+ shape, operands, custom_call_target, opaque);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
@@ -2423,7 +2424,7 @@ template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
const InternalCompareFunction* operand_order,
bool ignore_control_predecessors) {
- visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds());
+ visitor->ReserveVisitStates(root->GetModule()->instruction_count());
// dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
//
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index d615df0831..009bd3bab3 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -718,10 +718,11 @@ class HloInstruction {
HloComputation* computation);
// Creates a custom call instruction that applies the given custom call target
- // to the given operands. "shape" is the resultant shape.
+ // to the given operands. "opaque" can be an arbitrary string with a
+ // backend-specific interpretation. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target);
+ absl::string_view custom_call_target, absl::string_view opaque = "");
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e92882c22a..cd71bc3323 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1830,9 +1830,10 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
HloCustomCallInstruction::HloCustomCallInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target)
+ absl::string_view custom_call_target, absl::string_view opaque)
: HloInstruction(HloOpcode::kCustomCall, shape),
custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
+ opaque_(opaque.begin(), opaque.end()),
feature_group_count_(1) {
for (auto operand : operands) {
AppendOperand(operand);
@@ -1849,6 +1850,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
*convolution_dimension_numbers_;
}
proto.set_custom_call_target(custom_call_target_);
+ proto.set_custom_call_opaque(opaque_);
proto.set_feature_group_count(feature_group_count_);
return proto;
}
@@ -1872,6 +1874,11 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
// an HloComputation.
extra.push_back(
StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
+ // If the opaque string becomes enormous we may want to reconsider printing
+ // this inline and consider other options.
+ if (!opaque_.empty()) {
+ extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\""));
+ }
return extra;
}
@@ -1897,7 +1904,8 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
if (feature_group_count_ != casted_other.feature_group_count_) {
return false;
}
- return custom_call_target_ == casted_other.custom_call_target_;
+ return custom_call_target_ == casted_other.custom_call_target_ &&
+ opaque_ == casted_other.opaque_;
}
std::unique_ptr<HloInstruction>
@@ -1905,7 +1913,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
auto cloned = absl::make_unique<HloCustomCallInstruction>(
- shape, new_operands, custom_call_target());
+ shape, new_operands, custom_call_target(), opaque());
if (window_ != nullptr) {
cloned->set_window(*window_);
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 2d7bc83855..9c22f5db7e 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1070,7 +1070,8 @@ class HloCustomCallInstruction : public HloInstruction {
public:
explicit HloCustomCallInstruction(const Shape& shape,
absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target);
+ absl::string_view custom_call_target,
+ absl::string_view opaque);
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;
@@ -1090,6 +1091,7 @@ class HloCustomCallInstruction : public HloInstruction {
convolution_dimension_numbers_ =
absl::make_unique<ConvolutionDimensionNumbers>(dnums);
}
+ const string& opaque() const { return opaque_; }
const string& custom_call_target() const { return custom_call_target_; }
void set_feature_group_count(int64 feature_group_count) {
feature_group_count_ = feature_group_count;
@@ -1109,8 +1111,10 @@ class HloCustomCallInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // Name of a global symbol to call, only present for kCustomCall.
+ // Name of a global symbol to call.
string custom_call_target_;
+ // Opaque string interpreted by the backend.
+ string opaque_;
// Describes the window in a windowed operation such as convolution.
std::unique_ptr<Window> window_;
// Describes the dimension numbers used for a convolution.
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index c7ec88d450..6a4e766788 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -400,7 +400,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
memory_by_computation) {
// These variables are a hack to prevent overflows.
int64 cumulative_total_size = 0;
- int64 total_hlos = computation.parent()->NumUniqueInstructionIds();
+ int64 total_hlos = computation.parent()->instruction_count();
tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index f1dc08bafa..23d41d91d6 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -92,14 +92,18 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
}
bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
- // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b'
- // is live into the module.
+ // Entry parameter should always be defined before other instructions.
const HloModule* module = b.defining_instruction()->parent()->parent();
if (b.defining_instruction()->parent() == module->entry_computation() &&
b.defining_instruction()->opcode() == HloOpcode::kParameter) {
return false;
}
+ if (a.defining_instruction()->parent() == module->entry_computation() &&
+ a.defining_instruction()->opcode() == HloOpcode::kParameter) {
+ return true;
+ }
+
// Phi values require special handling. Because XLA does not have a phi
// instruction, the definition instruction of the phis values are
// placeholders: either the subcomputation parameter (body or condition) or
@@ -316,7 +320,7 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const {
for (auto predecessor : all) {
if (predecessors_.at(computation)
->IsReachable(predecessor, instruction)) {
- pieces.push_back(absl::StrFormat(" %s", predecessor->name()));
+ pieces.push_back(absl::StrFormat(" %s", predecessor->name()));
}
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 00970bcda3..b045adc964 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -174,6 +174,26 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
}
+TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) {
+ // Entry parameter should always be defined before other instruction.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ module->AddEntryComputation(builder.Build());
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ DependencyHloOrdering ordering(module.get());
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(param),
+ dataflow->GetValueDefinedAt(constant)));
+ EXPECT_TRUE(!ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
+ dataflow->GetValueDefinedAt(param)));
+}
+
TEST_F(HloOrderingTest, ValuesInWhileComputations) {
// Tests the ordering of values (defined by dataflow analysis) in the body and
// condition of a while instruction. HLO code:
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 37197b273b..25b70740e3 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1266,11 +1266,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
case HloOpcode::kCustomCall: {
optional<string> custom_call_target;
+ optional<string> opaque;
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
optional<int64> feature_group_count;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
+ attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/false,
AttrTy::kConvolutionDimensionNumbers, &dnums};
@@ -1279,8 +1281,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
- shape, operands, *custom_call_target));
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateCustomCall(shape, operands, *custom_call_target,
+ opaque.has_value() ? *opaque : ""));
if (window.has_value()) {
instruction->set_window(*window);
}
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index cca50fab54..96db96bdb9 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1004,6 +1004,18 @@ ENTRY CustomCall {
)"
},
+// CustomCall with opaque value.
+{
+"CustomCallWithOpaque",
+R"(HloModule custom_call
+
+ENTRY CustomCall {
+ constant = f32[1]{0} constant({12345})
+ ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", opaque="this string is opaque"
+}
+
+)"
+},
// Variables with non-default names
{
"NonDefaultNames",
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 3fdc2cee9a..e884122fcb 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -188,13 +188,20 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
bool InstructionFusion::CanFuseOnAllPaths(
HloInstruction* producer, HloInstruction* consumer,
- const HloInstructionSet& do_not_duplicate) {
+ const HloInstructionSet& do_not_fuse,
+ tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>, bool>*
+ result_cache) {
if (consumer == producer) {
return true;
}
if (!consumer->IsFusible()) {
return false;
}
+ auto cache_it = result_cache->find(std::make_pair(producer, consumer));
+ if (cache_it != result_cache->end()) {
+ return cache_it->second;
+ }
+ bool result = true;
for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
auto* consumer_operand = consumer->mutable_operand(i);
// If the operand is not on a path to the producer, it doesn't matter
@@ -202,20 +209,23 @@ bool InstructionFusion::CanFuseOnAllPaths(
if (!reachability_->IsReachable(producer, consumer_operand)) {
continue;
}
- if (do_not_duplicate.count(consumer_operand) > 0 ||
- !ShouldFuse(consumer, i)) {
- return false;
+ if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) {
+ result = false;
+ break;
}
// The producer is reachable from consumer_operand which means we need
// to be able to fuse consumer_operand into consumer in order for
// producer to be fusible into consumer on all paths.
// Perform the recursive step: make sure producer can be fused into
// consumer_operand on all paths.
- if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) {
- return false;
+ if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_fuse,
+ result_cache)) {
+ result = false;
+ break;
}
}
- return true;
+ result_cache->emplace(std::make_pair(producer, consumer), result);
+ return result;
}
InstructionFusion::HloInstructionSet
@@ -231,6 +241,8 @@ InstructionFusion::ComputeGloballyUnfusible(
// fusing operations that require duplication later depending on
// is_expensive_().
HloInstructionSet do_not_duplicate;
+ tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>, bool>
+ can_fuse_on_all_paths_result_cache;
for (HloInstruction* consumer : post_order) {
for (HloInstruction* producer : consumer->operands()) {
if (do_not_duplicate.count(producer) > 0) {
@@ -286,7 +298,8 @@ InstructionFusion::ComputeGloballyUnfusible(
// A will be not allowed to be fused into B, as it cannot be fused via
// all paths.
if (producer->IsFusible() &&
- CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) {
+ CanFuseOnAllPaths(producer, consumer, do_not_duplicate,
+ &can_fuse_on_all_paths_result_cache)) {
continue;
}
do_not_duplicate.insert(producer);
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index 7e1196fb7f..c1ec3b18a1 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -151,8 +151,15 @@ class InstructionFusion : public HloModulePass {
// Whether or not we can fuse producer into consumer on all paths
// from the producer to the consumer where nodes are HLOs and edges are uses.
- bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer,
- const HloInstructionSet& do_not_fuse);
+ //
+ // A map from <producer, consumer> to a bool is required as the result cache
+ // to store and query the results of calls to this function, in order to avoid
+ // repeated computations.
+ bool CanFuseOnAllPaths(
+ HloInstruction* producer, HloInstruction* consumer,
+ const HloInstructionSet& do_not_fuse,
+ tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>,
+ bool>* result_cache);
// Computes the set of nodes that we do not want to fuse into any of their
// consumers based on a global analysis of the HLO graph.
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
index eaa09591b7..ec52a24d78 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
@@ -54,7 +54,7 @@ Status LogicalBufferAnalysis::Analyze() {
// so reserve 10% more than the number of instructions to avoid frequent
// resizes.
logical_buffers_.clear();
- logical_buffers_.reserve((module_->NumUniqueInstructionIds() * 11) / 10);
+ logical_buffers_.reserve((module_->instruction_count() * 11) / 10);
// We filter out fusion computations, and get to them through fusion
// instructions. This is because it's possible to have orphaned (unreachable)
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 96c80fd577..020c167ee9 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -422,8 +422,11 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
- CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape);
- CHECK_EQ(shape.dimensions_size(), Rank(shape));
+ DCHECK(IsArray(shape)) << ShapeUtil::HumanString(shape);
+ DCHECK_EQ(shape.dimensions_size(), Rank(shape));
+ if (shape.dimensions().size() == 1) {
+ return shape.dimensions()[0];
+ }
return std::accumulate<decltype(shape.dimensions().begin()), int64>(
shape.dimensions().begin(), shape.dimensions().end(), 1LL,
std::multiplies<int64>());
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index ae5ca32bcf..98dff965a9 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -112,26 +112,14 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python/estimator:estimator_py",
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({
- "//tensorflow:with_kafka_support_windows_override": [],
- "//tensorflow:with_kafka_support": [
- "//tensorflow/contrib/kafka",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support": [
- "//tensorflow/contrib/kinesis",
- ],
- "//conditions:default": [],
- }) + if_not_windows_cuda([
- "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols
- ]) + if_not_windows([
- ]) + select({
"//tensorflow:linux_s390x": [],
"//tensorflow:windows": [],
"//conditions:default": [
"//tensorflow/contrib/bigtable",
"//tensorflow/contrib/cloud:cloud_py",
+ "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols
+ "//tensorflow/contrib/kafka",
+ "//tensorflow/contrib/kinesis",
"//tensorflow/contrib/tensorrt:init_py",
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
],
@@ -144,7 +132,6 @@ cc_library(
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_kernels",
"//tensorflow/contrib/coder:all_kernels",
- "//tensorflow/contrib/data/kernels:dataset_kernels",
"//tensorflow/contrib/factorization/kernels:all_kernels",
"//tensorflow/contrib/hadoop:dataset_kernels",
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
@@ -159,20 +146,14 @@ cc_library(
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([
"//tensorflow/contrib/nccl:nccl_kernels",
]) + select({
- "//tensorflow:with_kafka_support_windows_override": [],
- "//tensorflow:with_kafka_support": [
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
"//tensorflow/contrib/kafka:dataset_kernels",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support": [
"//tensorflow/contrib/kinesis:dataset_kernels",
+ "//tensorflow/contrib/tensorrt:trt_engine_op_kernel",
],
- "//conditions:default": [],
- }) + if_not_windows([
- "//tensorflow/contrib/tensorrt:trt_engine_op_kernel",
- ]),
+ }),
)
cc_library(
@@ -181,8 +162,6 @@ cc_library(
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib",
"//tensorflow/contrib/coder:all_ops",
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
"//tensorflow/contrib/factorization:all_ops",
"//tensorflow/contrib/framework:all_ops",
"//tensorflow/contrib/hadoop:dataset_ops_op_lib",
@@ -198,18 +177,12 @@ cc_library(
"//tensorflow/contrib/text:all_ops",
"//tensorflow/contrib/tpu:all_ops",
] + select({
- "//tensorflow:with_kafka_support_windows_override": [],
- "//tensorflow:with_kafka_support": [
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
"//tensorflow/contrib/kafka:dataset_ops_op_lib",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support": [
"//tensorflow/contrib/kinesis:dataset_ops_op_lib",
+ "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
],
- "//conditions:default": [],
- }) + if_not_windows([
- "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
- ]),
+ }),
)
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
index 11f530e82a..2c6317157d 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
@@ -28,6 +28,7 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
DatasetBase** output) override {
BigtableTableResource* table;
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table));
+ core::ScopedUnref scoped_unref(table);
std::vector<string> column_families;
std::vector<string> columns;
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
index 5cab729d9c..92a3658667 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
@@ -31,6 +31,7 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
BigtableTableResource* resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ core::ScopedUnref scoped_unref(resource);
*output = new Dataset(ctx, resource, std::move(prefix));
}
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
index 4dc4647bd2..bd8805a382 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
@@ -34,6 +34,7 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
BigtableTableResource* resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ core::ScopedUnref scoped_unref(resource);
*output =
new Dataset(ctx, resource, std::move(start_key), std::move(end_key));
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
index 736775bdac..01608dc6bc 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
@@ -38,6 +38,7 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
BigtableTableResource* resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ core::ScopedUnref scoped_unref(resource);
OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
errors::InvalidArgument(
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
index 208b7b3e08..9b60e0a667 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
@@ -28,6 +28,7 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
BigtableTableResource* resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ core::ScopedUnref scoped_unref(resource);
*output = new Dataset(ctx, resource);
}
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
index 9407855fe8..688289a4e2 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
@@ -67,6 +67,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
BigtableTableResource* resource;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ core::ScopedUnref scoped_unref(resource);
const uint64 num_outputs = columns.size() + 1;
std::vector<PartialTensorShape> output_shapes;
diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
index 3e1b622867..cf56822ff4 100644
--- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
+++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
@@ -575,7 +575,7 @@ def _normalize_columns(columns, provided_kwargs):
return normalized
-class _BigtableKeyDataset(dataset_ops.Dataset):
+class _BigtableKeyDataset(dataset_ops.DatasetSource):
"""_BigtableKeyDataset is an abstract class representing the keys of a table.
"""
@@ -645,7 +645,7 @@ class _BigtableSampleKeysDataset(_BigtableKeyDataset):
table=self._table._resource) # pylint: disable=protected-access
-class _BigtableLookupDataset(dataset_ops.Dataset):
+class _BigtableLookupDataset(dataset_ops.DatasetSource):
"""_BigtableLookupDataset represents a dataset that retrieves values for keys.
"""
@@ -678,7 +678,7 @@ class _BigtableLookupDataset(dataset_ops.Dataset):
columns=self._columns)
-class _BigtableScanDataset(dataset_ops.Dataset):
+class _BigtableScanDataset(dataset_ops.DatasetSource):
"""_BigtableScanDataset represents a dataset that retrieves keys and values.
"""
@@ -715,7 +715,7 @@ class _BigtableScanDataset(dataset_ops.Dataset):
probability=self._probability)
-class _BigtableSampleKeyPairsDataset(dataset_ops.Dataset):
+class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
"""_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table.
"""
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index c050c2ed7f..a2f708081a 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -170,7 +170,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
def testObliviousFeatureSplitGeneration(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Feature ID |
# i0 | (0.2, 0.12) | 1 | 1 |
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index c7eb2493a8..8531e97f90 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -402,13 +402,13 @@ class GradientBoostedDecisionTreeModel(object):
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
self._num_quantiles = num_quantiles
- self._max_tree_depth = variables.Variable(
+ self._max_tree_depth = variables.VariableV1(
initial_value=self._learner_config.constraints.max_tree_depth)
- self._attempted_trees = variables.Variable(
+ self._attempted_trees = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
trainable=False,
name="attempted_trees")
- self._finalized_trees = variables.Variable(
+ self._finalized_trees = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
trainable=False,
name="finalized_trees")
@@ -770,28 +770,28 @@ class GradientBoostedDecisionTreeModel(object):
fc_name_idx += 1
# Create ensemble stats variables.
- num_layer_examples = variables.Variable(
+ num_layer_examples = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layer_examples",
trainable=False)
- num_layer_steps = variables.Variable(
+ num_layer_steps = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layer_steps",
trainable=False)
- num_layers = variables.Variable(
+ num_layers = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layers",
trainable=False)
- active_tree = variables.Variable(
+ active_tree = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="active_tree",
trainable=False)
- active_layer = variables.Variable(
+ active_layer = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="active_layer",
trainable=False)
# Variable that becomes false once bias centering is done.
- continue_centering = variables.Variable(
+ continue_centering = variables.VariableV1(
initial_value=self._center_bias,
name="continue_centering",
trainable=False)
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index 9d9941f696..6d20a2e7f4 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -239,7 +239,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -503,7 +503,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -607,7 +607,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -711,7 +711,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -783,7 +783,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -847,7 +847,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1090,7 +1090,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1194,7 +1194,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1299,7 +1299,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1405,7 +1405,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1524,7 +1524,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1656,7 +1656,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index ebcabb4223..c6d6f04168 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -353,7 +353,7 @@ endif()
# MKL Support
if (tensorflow_ENABLE_MKL_SUPPORT)
- add_definitions(-DINTEL_MKL -DEIGEN_USE_VML)
+ add_definitions(-DINTEL_MKL -DEIGEN_USE_VML -DENABLE_MKL)
include(mkl)
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkl_STATIC_LIBRARIES})
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkl_copy_shared_to_destination)
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index c0763f4c0e..2975b167ec 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -132,7 +132,6 @@ tensorflow/contrib/cudnn_rnn/python
tensorflow/contrib/cudnn_rnn/python/layers
tensorflow/contrib/cudnn_rnn/python/ops
tensorflow/contrib/data
-tensorflow/contrib/data/kernels
tensorflow/contrib/data/python
tensorflow/contrib/data/python/kernel_tests
tensorflow/contrib/data/python/kernel_tests/serialization
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index f51bfc1b22..f83386b8a4 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -65,7 +65,7 @@ py_library(
"//tensorflow/python:summary_op_util",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
- "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/estimator:estimator_py",
],
)
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
index 1e30525159..873b03580d 100644
--- a/tensorflow/contrib/compiler/xla.py
+++ b/tensorflow/contrib/compiler/xla.py
@@ -293,7 +293,8 @@ def _compile_internal(computation, inputs=None):
saved_use_resource = vscope.use_resource
vscope.set_use_resource(True)
- outputs = computation(*computation_inputs)
+ with _disable_summary_context():
+ outputs = computation(*computation_inputs)
# Restore variable scope after computation.
vscope.set_use_resource(saved_use_resource)
@@ -371,13 +372,13 @@ def _disable_summary_context():
Yields:
None.
"""
- origional_skip_summary_func = summary_op_util.skip_summary
+ original_skip_summary_func = summary_op_util.skip_summary
summary_op_util.skip_summary = lambda: True
try:
yield
finally:
- summary_op_util.skip_summary = origional_skip_summary_func
+ summary_op_util.skip_summary = original_skip_summary_func
class _CapturedObject(object):
@@ -436,8 +437,7 @@ class _ModelFnWrapper(object):
if mode == model_fn_lib.ModeKeys.TRAIN:
train_step, captured_scaffold_fn = self._make_train_step(
features, labels, params)
- with _disable_summary_context():
- (loss,) = compile(train_step)
+ (loss,) = compile(train_step)
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=loss,
@@ -446,8 +446,7 @@ class _ModelFnWrapper(object):
elif mode == model_fn_lib.ModeKeys.EVAL:
eval_step, captured_eval_metric_fn, captured_scaffold_fn = (
self._make_eval_step(features, labels, params))
- with _disable_summary_context():
- outputs = compile(eval_step)
+ outputs = compile(eval_step)
loss = outputs[0]
# Calculate eval_metric_ops if eval_metric_fn is set and captured.
diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py
index d1af15f7e4..67f8ac2b93 100644
--- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py
+++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py
@@ -102,9 +102,9 @@ def _project_multipliers_wrt_euclidean_norm(multipliers, radius):
0.0,
(radius - standard_ops.reduce_sum(multipliers)) / standard_ops.maximum(
1.0, standard_ops.reduce_sum(inactive)))
- multipliers += scale * inactive
+ multipliers = multipliers + (scale * inactive)
new_inactive = standard_ops.cast(multipliers > 0, multipliers.dtype)
- multipliers *= new_inactive
+ multipliers = multipliers * new_inactive
return (iteration, multipliers, new_inactive, inactive)
iteration = standard_ops.constant(0)
diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
index 2c673d9347..a6cb1f62f0 100644
--- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
+++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
@@ -175,9 +175,9 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix):
scale = (1.0 - standard_ops.reduce_sum(
matrix, axis=0, keepdims=True)) / standard_ops.maximum(
1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True))
- matrix += scale * inactive
+ matrix = matrix + (scale * inactive)
new_inactive = standard_ops.cast(matrix > 0, matrix.dtype)
- matrix *= new_inactive
+ matrix = matrix * new_inactive
return (iteration, matrix, new_inactive, inactive)
iteration = standard_ops.constant(0)
@@ -210,8 +210,9 @@ def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix):
# For numerical reasons, make sure that the largest matrix element is zero
# before exponentiating.
- log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keepdims=True)
- log_matrix -= standard_ops.log(
+ log_matrix = log_matrix - standard_ops.reduce_max(
+ log_matrix, axis=0, keepdims=True)
+ log_matrix = log_matrix - standard_ops.log(
standard_ops.reduce_sum(
standard_ops.exp(log_matrix), axis=0, keepdims=True))
return log_matrix
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
index 6c9ab6aeb8..9c5871da34 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -31,7 +31,7 @@ from __future__ import division
from __future__ import print_function
from copy import deepcopy
-from tensorflow.python.ops.variables import Variable
+from tensorflow.python.ops.variables import VariableV1
from tensorflow.python.client.session import Session
from tensorflow.python.framework import ops
@@ -55,7 +55,7 @@ def copy_variable_to_graph(org_instance, to_graph, scope=''):
TypeError: If `org_instance` is not a `Variable`.
"""
- if not isinstance(org_instance, Variable):
+ if not isinstance(org_instance, VariableV1):
raise TypeError(str(org_instance) + ' is not a Variable')
#The name of the new variable
@@ -88,7 +88,7 @@ def copy_variable_to_graph(org_instance, to_graph, scope=''):
#Initialize the new variable
with to_graph.as_default():
- new_var = Variable(
+ new_var = VariableV1(
init_value,
trainable,
name=new_name,
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_test.py b/tensorflow/contrib/copy_graph/python/util/copy_test.py
index 05744bec4e..ba97c78456 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_test.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_test.py
@@ -36,7 +36,7 @@ class CopyVariablesTest(test.TestCase):
with graph1.as_default():
#Define a Variable in graph1
- some_var = variables.Variable(2)
+ some_var = variables.VariableV1(2)
#Initialize session
sess1 = session_lib.Session()
#Initialize the Variable
@@ -72,7 +72,7 @@ class CopyOpsTest(test.TestCase):
with graph1.as_default():
#Initialize a basic expression y = ax + b
x = array_ops.placeholder("float")
- a = variables.Variable(3.0)
+ a = variables.VariableV1(3.0)
b = constant_op.constant(4.0)
ax = math_ops.multiply(x, a)
y = math_ops.add(ax, b)
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index 2a91dcb63a..43bb43129b 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -56,7 +56,6 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
@@ -214,10 +213,11 @@ def crf_log_norm(inputs, sequence_lengths, transition_params):
log_norm)
return log_norm
- max_seq_len = array_ops.shape(inputs)[1]
- return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1),
- true_fn=_single_seq_fn,
- false_fn=_multi_seq_fn)
+ return utils.smart_cond(
+ pred=math_ops.equal(inputs.shape[1].value or
+ array_ops.shape(inputs)[1], 1),
+ true_fn=_single_seq_fn,
+ false_fn=_multi_seq_fn)
def crf_log_likelihood(inputs,
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index 9f710613dd..38f1c65a4d 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -4,17 +4,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load(
- "//tensorflow:tensorflow.bzl",
- "tf_custom_op_library",
- "tf_gen_op_libs",
- "if_not_windows",
-)
-load(
- "//tensorflow/core:platform/default/build_config_root.bzl",
- "if_static",
-)
-
py_library(
name = "data",
srcs = ["__init__.py"],
@@ -25,30 +14,3 @@ py_library(
"//tensorflow/python:util",
],
)
-
-cc_library(
- name = "lib_proto_parsing_for_dataset_ops",
- deps = if_not_windows(["//tensorflow/core:lib_proto_parsing"]),
-)
-
-tf_custom_op_library(
- name = "_dataset_ops.so",
- srcs = [
- "ops/dataset_ops.cc",
- "ops/indexed_dataset_ops.cc",
- ],
- deps = [
- "//tensorflow/contrib/data/kernels:dataset_kernels",
- "//tensorflow/contrib/data/kernels:indexed_dataset",
- ] + if_static(
- extra_deps = [":lib_proto_parsing_for_dataset_ops"],
- otherwise = [],
- ),
-)
-
-tf_gen_op_libs(
- op_lib_names = [
- "dataset_ops",
- "indexed_dataset_ops",
- ],
-)
diff --git a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc b/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
deleted file mode 100644
index cd9b7c68a0..0000000000
--- a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-
-namespace tensorflow {
-
-REGISTER_OP("IdentityIndexedDataset")
- .Input("size: uint64")
- .Output("handle: variant")
- .SetIsStateful()
- .SetShapeFn(
- shape_inference::ScalarShape); // TODO(saeta): check input shapes.
-
-///////////////////////////////////////////////////////////////////////////////
-// IndexedDataset Internals
-///////////////////////////////////////////////////////////////////////////////
-
-// Creates the handle.
-REGISTER_OP("MaterializedIndexDatasetHandle")
- .Output("handle: resource")
- .Attr("container: string")
- .Attr("shared_name: string")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
-
-// Actually materialize the materialize handle.
-REGISTER_OP("IndexedDatasetMaterialize")
- .Input("dataset: variant")
- .Input("materialized: resource")
- .SetShapeFn(shape_inference::NoOutputs);
-
-namespace {
-
-Status GetShapeFn(shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- std::vector<PartialTensorShape> output_shapes;
- TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
- if (output_shapes.size() != c->num_outputs()) {
- return errors::InvalidArgument(
- "`output_shapes` must be the same length as `output_types` (",
- output_shapes.size(), " vs. ", c->num_outputs());
- }
- for (size_t i = 0; i < output_shapes.size(); ++i) {
- shape_inference::ShapeHandle output_shape_handle;
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
- output_shapes[i], &output_shape_handle));
- c->set_output(static_cast<int>(i), output_shape_handle);
- }
- return Status::OK();
-}
-
-} // namespace
-
-REGISTER_OP("IndexedDatasetGet")
- .Input("materialized: resource")
- .Input("index: uint64")
- .Output("components: output_types")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(GetShapeFn)
- .Doc(R"doc(
-Gets the element at `index` from `materialized` IndexedDataset.
-)doc");
-
-} // namespace tensorflow
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index c15e8d8861..33784afa3f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -31,6 +31,7 @@ py_test(
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -54,6 +55,7 @@ py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -77,6 +79,7 @@ py_test(
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
@@ -97,6 +100,7 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
],
@@ -112,6 +116,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:random_seed",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -130,6 +135,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -139,12 +145,12 @@ py_test(
name = "indexed_dataset_ops_test",
srcs = ["indexed_dataset_ops_test.py"],
deps = [
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
- "//tensorflow/contrib/data/python/ops:gen_dataset_ops",
"//tensorflow/contrib/data/python/ops:indexed_dataset_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -170,6 +176,7 @@ py_test(
"//tensorflow/python:script_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@six_archive//:six",
],
@@ -189,8 +196,8 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/estimator",
"//tensorflow/python/estimator:estimator_py",
],
)
@@ -216,6 +223,7 @@ py_test(
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//third_party/py/numpy",
],
)
@@ -241,6 +249,7 @@ py_test(
"//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -260,6 +269,7 @@ py_test(
"//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -284,6 +294,7 @@ py_test(
"//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
],
)
@@ -302,6 +313,7 @@ py_test(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
@@ -317,6 +329,7 @@ cuda_py_test(
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
@@ -342,6 +355,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -367,6 +381,7 @@ py_library(
"//tensorflow/python:lib",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:readers",
],
@@ -413,6 +428,7 @@ py_test(
"//tensorflow/python:random_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -435,6 +451,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
@@ -455,6 +472,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -472,6 +490,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -491,6 +510,7 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python/data/kernel_tests:test_base",
"@org_sqlite//:python",
],
)
@@ -535,6 +555,7 @@ py_library(
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/kernel_tests:test_base",
],
)
@@ -551,6 +572,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:script_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -569,6 +591,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -589,6 +612,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -606,17 +630,8 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:lib",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
],
)
-
-py_library(
- name = "test_utils",
- srcs = ["test_utils.py"],
- deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/util:nest",
- ],
-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index e2508de9e9..fed7de5f2b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -40,12 +41,8 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase, parameterized.TestCase):
+class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
def testDenseToSparseBatchDataset(self):
components = np.random.randint(12, size=(100,)).astype(np.int32)
@@ -723,7 +720,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
-class RestructuredDatasetTest(test.TestCase):
+class RestructuredDatasetTest(test_base.DatasetTestBase):
def test_assert_element_shape(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 48971f2ccc..ae401f786c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -22,6 +22,7 @@ import random
import numpy as np
from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -35,7 +36,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class GroupByReducerTest(test.TestCase):
+class GroupByReducerTest(test_base.DatasetTestBase):
def checkResults(self, dataset, shapes, values):
self.assertEqual(shapes, dataset.output_shapes)
@@ -198,7 +199,7 @@ class GroupByReducerTest(test.TestCase):
self.assertEqual(y, 45)
-class GroupByWindowTest(test.TestCase):
+class GroupByWindowTest(test_base.DatasetTestBase):
def testSimple(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
@@ -345,7 +346,7 @@ class GroupByWindowTest(test.TestCase):
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though should be made to use a
# different batch size per key.
-class BucketTest(test.TestCase):
+class BucketTest(test_base.DatasetTestBase):
def _dynamicPad(self, bucket, window, window_size):
# TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
@@ -570,7 +571,7 @@ def _get_record_shape(sparse):
return tensor_shape.TensorShape([None])
-class BucketBySequenceLength(test.TestCase):
+class BucketBySequenceLength(test_base.DatasetTestBase):
def testBucket(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index f8e74e4583..5b3c512b64 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -30,6 +30,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -43,37 +44,7 @@ from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
-class CsvDatasetOpTest(test.TestCase):
-
- def _get_next(self, dataset):
- # Returns a no argument function whose result is fed to self.evaluate to
- # yield the next element
- it = dataset.make_one_shot_iterator()
- if context.executing_eagerly():
- return it.get_next
- else:
- get_next = it.get_next()
- return lambda: get_next
-
- def _assert_datasets_equal(self, ds1, ds2):
- assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, '
- '%s') % (ds1.output_shapes,
- ds2.output_shapes)
- assert ds1.output_types == ds2.output_types
- assert ds1.output_classes == ds2.output_classes
- next1 = self._get_next(ds1)
- next2 = self._get_next(ds2)
- # Run through datasets and check that outputs match, or errors match.
- while True:
- try:
- op1 = self.evaluate(next1())
- except (errors.OutOfRangeError, ValueError) as e:
- # If op1 throws an exception, check that op2 throws same exception.
- with self.assertRaises(type(e)):
- self.evaluate(next2())
- break
- op2 = self.evaluate(next2())
- self.assertAllEqual(op1, op2)
+class CsvDatasetOpTest(test_base.DatasetTestBase):
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []
@@ -108,7 +79,7 @@ class CsvDatasetOpTest(test.TestCase):
"""Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
dataset_actual, dataset_expected = self._make_test_datasets(
inputs, **kwargs)
- self._assert_datasets_equal(dataset_actual, dataset_expected)
+ self.assertDatasetsEqual(dataset_actual, dataset_expected)
def _verify_output_or_err(self,
dataset,
@@ -116,7 +87,7 @@ class CsvDatasetOpTest(test.TestCase):
expected_err_re=None):
if expected_err_re is None:
# Verify that output is expected, without errors
- nxt = self._get_next(dataset)
+ nxt = self.getNext(dataset)
expected_output = [[
v.encode('utf-8') if isinstance(v, str) else v for v in op
] for op in expected_output]
@@ -128,7 +99,7 @@ class CsvDatasetOpTest(test.TestCase):
else:
# Verify that OpError is produced as expected
with self.assertRaisesOpError(expected_err_re):
- nxt = self._get_next(dataset)
+ nxt = self.getNext(dataset)
while True:
try:
self.evaluate(nxt())
@@ -354,7 +325,7 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['1,,3,4', '5,6,,8']]
ds_actual, ds_expected = self._make_test_datasets(
inputs, record_defaults=record_defaults)
- self._assert_datasets_equal(
+ self.assertDatasetsEqual(
ds_actual.repeat(5).prefetch(1),
ds_expected.repeat(5).prefetch(1))
@@ -377,7 +348,7 @@ class CsvDatasetOpTest(test.TestCase):
ds = readers.make_csv_dataset(
file_path, batch_size=1, shuffle=False, num_epochs=1)
- nxt = self._get_next(ds)
+ nxt = self.getNext(ds)
result = list(self.evaluate(nxt()).values())
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
index a2ab3de52e..722e87e555 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -25,7 +26,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def testRestructureDataset(self):
components = (array_ops.placeholder(dtypes.int32),
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index eb110324d1..bc10c21472 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -20,13 +20,14 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import random_seed
from tensorflow.python.platform import test
-class DirectedInterleaveDatasetTest(test.TestCase):
+class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
def testBasic(self):
selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
index f3968cdc15..cc22ea1df7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import get_single_element
from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class GetSingleElementTest(test.TestCase, parameterized.TestCase):
+class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("Zero", 0, 1),
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
index 9c508d686d..d4d3d4adb2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
@@ -19,29 +19,30 @@ from __future__ import print_function
import unittest
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import indexed_dataset_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.platform import test
-class IndexedDatasetOpsTest(test.TestCase):
+class IndexedDatasetOpsTest(test_base.DatasetTestBase):
def testLowLevelIndexedDatasetOps(self):
- identity = gen_dataset_ops.identity_indexed_dataset(
+ identity = ged_ops.experimental_identity_indexed_dataset(
ops.convert_to_tensor(16, dtype=dtypes.uint64))
- handle = gen_dataset_ops.materialized_index_dataset_handle(
+ handle = ged_ops.experimental_materialized_index_dataset_handle(
container="",
shared_name="",
output_types=[dtypes.uint64],
output_shapes=[[]])
- materialize = gen_dataset_ops.indexed_dataset_materialize(identity, handle)
+ materialize = ged_ops.experimental_indexed_dataset_materialize(
+ identity, handle)
index = array_ops.placeholder(dtypes.uint64)
- get_op = gen_dataset_ops.indexed_dataset_get(
+ get_op = ged_ops.experimental_indexed_dataset_get(
handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
with self.cached_session() as sess:
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index b9e74dfddb..28bd670ab5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -25,6 +25,7 @@ import time
from six.moves import zip_longest
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -36,7 +37,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class ParallelInterleaveDatasetTest(test.TestCase):
+class ParallelInterleaveDatasetTest(test_base.DatasetTestBase):
def setUp(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 704c0d1eb2..58a1d7c93b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import iterator_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn
@@ -33,7 +34,7 @@ from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
-class CheckpointInputPipelineHookTest(test.TestCase):
+class CheckpointInputPipelineHookTest(test_base.DatasetTestBase):
@staticmethod
def _model_fn(features, labels, mode, config):
@@ -42,7 +43,7 @@ class CheckpointInputPipelineHookTest(test.TestCase):
del config
global_step = training_util.get_or_create_global_step()
update_global_step_op = global_step.assign_add(1)
- latest_feature = variables.Variable(
+ latest_feature = variables.VariableV1(
0, name='latest_feature', dtype=dtypes.int64)
store_latest_feature_op = latest_feature.assign(features)
ops.add_to_collection('my_vars', global_step)
diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
index 1cc5ddc9a2..d2a72272db 100644
--- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
@@ -22,6 +22,7 @@ import os
import shutil
from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,7 @@ from tensorflow.python.util import compat
prefix_path = "tensorflow/core/lib"
-class LMDBDatasetTest(test.TestCase):
+class LMDBDatasetTest(test_base.DatasetTestBase):
def setUp(self):
super(LMDBDatasetTest, self).setUp()
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index e8519381d6..385c4ef6ea 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -41,7 +42,7 @@ from tensorflow.python.util import compat
_NUMPY_RANDOM_SEED = 42
-class MapDatasetTest(test.TestCase):
+class MapDatasetTest(test_base.DatasetTestBase):
def testMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 25aea0393f..751e6d5b30 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -21,6 +21,7 @@ import time
from tensorflow.contrib.data.python.ops import map_defun
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,7 +34,8 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapDefunTest(test.TestCase):
+
+class MapDefunTest(test_base.DatasetTestBase):
def testMapDefunSimple(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
index a2fc244ced..d7b5edcd9a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -15,11 +15,30 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
py_test(
+ name = "hoist_random_uniform_test",
+ size = "small",
+ srcs = ["hoist_random_uniform_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
name = "latency_all_edges_test",
size = "small",
srcs = ["latency_all_edges_test.py"],
@@ -40,7 +59,6 @@ py_test(
srcs = ["map_vectorization_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/kernel_tests:test_utils",
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
@@ -50,6 +68,7 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -68,6 +87,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -85,6 +105,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -104,6 +125,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -120,6 +142,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -134,6 +157,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
index d10da80442..fe1b5280ba 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -18,12 +18,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class AssertNextDatasetTest(test.TestCase):
+class AssertNextDatasetTest(test_base.DatasetTestBase):
def testAssertNext(self):
dataset = dataset_ops.Dataset.from_tensors(0).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
new file mode 100644
index 0000000000..b43efb5c7c
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
@@ -0,0 +1,103 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for HostState optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @staticmethod
+ def map_functions():
+ plus_one = lambda x: x + 1
+
+ def random(_):
+ return random_ops.random_uniform([],
+ minval=1,
+ maxval=10,
+ dtype=dtypes.float32,
+ seed=42)
+
+ def random_with_assert(x):
+ y = random(x)
+ assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
+ with ops.control_dependencies([assert_op]):
+ return y
+
+ twice_random = lambda x: (random(x) + random(x)) / 2.
+
+ tests = [("PlusOne", plus_one, False), ("RandomUniform", random, True),
+ ("RandomWithAssert", random_with_assert, True),
+ ("TwiceRandom", twice_random, False)]
+ return tuple(tests)
+
+ @parameterized.named_parameters(*map_functions.__func__())
+ def testHoisting(self, function, will_optimize):
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(
+ ["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
+
+ dataset = dataset.apply(optimization.optimize(["hoist_random_uniform"]))
+ self._testDataset(dataset)
+
+ def testAdditionalInputs(self):
+ a = constant_op.constant(1, dtype=dtypes.float32)
+ b = constant_op.constant(0, dtype=dtypes.float32)
+ some_tensor = math_ops.mul(a, b)
+
+ def random_with_capture(_):
+ return some_tensor + random_ops.random_uniform(
+ [], minval=1, maxval=10, dtype=dtypes.float32, seed=42)
+
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(
+ ["Zip[0]", "Map"])).map(random_with_capture).apply(
+ optimization.optimize(["hoist_random_uniform"]))
+ self._testDataset(dataset)
+
+ def _testDataset(self, dataset):
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ previous_result = 0
+ with self.cached_session() as sess:
+ for _ in range(5):
+ result = sess.run(get_next)
+ self.assertLessEqual(1, result)
+ self.assertLessEqual(result, 10)
+ # This checks if the result is somehow random by checking if we are not
+ # generating the same values.
+ self.assertNotEqual(previous_result, result)
+ previous_result = result
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
index db380c02a9..e4f18222fd 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
@@ -34,8 +34,8 @@ class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
optimization.assert_next(
["LatencyStats", "Map", "LatencyStats", "Prefetch",
"LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
- optimization.optimize(["latency_all_edges"])).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
+ stats_ops.set_stats_aggregator(stats_aggregator)).apply(
+ optimization.optimize(["latency_all_edges"]))
iterator = dataset.make_initializable_iterator()
get_next = iterator.get_next()
summary_t = stats_aggregator.get_summary()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index e75edf6086..e9e3fc81e5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -28,7 +29,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
+class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
index dd547db086..f7907eb890 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class MapParallelizationTest(test.TestCase, parameterized.TestCase):
+class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
index 5b493f44c9..a5ea85f454 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -22,9 +22,9 @@ import time
from absl.testing import parameterized
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import test_utils
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -36,7 +36,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
+class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
def _get_test_datasets(self,
base_dataset,
@@ -85,7 +85,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
[3, 4]]).repeat(5)
unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
num_parallel_calls)
- self._assert_datasets_equal(unoptimized, optimized)
+ self.assertDatasetsEqual(unoptimized, optimized)
def testOptimizationBadMapFn(self):
# Test map functions that give an error
@@ -112,7 +112,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
# TODO(rachelim): when this optimization works, turn on expect_optimized
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_equal(optimized, unoptimized)
+ self.assertDatasetsEqual(optimized, unoptimized)
def testOptimizationIgnoreStateful(self):
@@ -124,7 +124,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
[3, 4]]).repeat(5)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_raise_same_error(
+ self.assertDatasetsRaiseSameError(
unoptimized, optimized, errors.InvalidArgumentError,
[("OneShotIterator", "OneShotIterator_1", 1),
("IteratorGetNext", "IteratorGetNext_1", 1)])
@@ -138,7 +138,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_equal(unoptimized, optimized)
+ self.assertDatasetsEqual(unoptimized, optimized)
def testOptimizationIgnoreRaggedMap(self):
# Don't optimize when the output of the map fn shapes are unknown.
@@ -148,7 +148,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_raise_same_error(
+ self.assertDatasetsRaiseSameError(
unoptimized, optimized, errors.InvalidArgumentError,
[("OneShotIterator", "OneShotIterator_1", 1),
("IteratorGetNext", "IteratorGetNext_1", 1)])
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
index 3b62a7e468..33c250ab2a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
@@ -23,12 +23,13 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class ModelDatasetTest(test.TestCase):
+class ModelDatasetTest(test_base.DatasetTestBase):
def testModelMap(self):
k = 1024 * 1024
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
index 507feda3ad..b9e60cfa4e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -26,7 +27,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class NoopEliminationTest(test.TestCase):
+class NoopEliminationTest(test_base.DatasetTestBase):
def testNoopElimination(self):
a = constant_op.constant(1, dtype=dtypes.int64)
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
index a3fb824ce9..04f499f8c5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -28,7 +29,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class OptimizeDatasetTest(test.TestCase):
+class OptimizeDatasetTest(test_base.DatasetTestBase):
def testOptimizationDefault(self):
dataset = dataset_ops.Dataset.range(10).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
index c4623bca73..66ccaceea5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -72,7 +73,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
i += 1
-class ParseExampleTest(test.TestCase):
+class ParseExampleTest(test_base.DatasetTestBase):
def _test(self,
input_tensor,
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 33a64ea767..7a6a7a709a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -22,6 +22,7 @@ import threading
from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -35,7 +36,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-class PrefetchingKernelsOpsTest(test.TestCase):
+class PrefetchingKernelsOpsTest(test_base.DatasetTestBase):
def setUp(self):
self._event = threading.Event()
@@ -244,7 +245,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
sess.run(destroy_op)
-class PrefetchToDeviceTest(test.TestCase):
+class PrefetchToDeviceTest(test_base.DatasetTestBase):
def testPrefetchToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
@@ -445,7 +446,7 @@ class PrefetchToDeviceTest(test.TestCase):
sess.run(next_element)
-class CopyToDeviceTest(test.TestCase):
+class CopyToDeviceTest(test_base.DatasetTestBase):
def testCopyToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index db8fe6aa1b..2e901587f4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from tensorflow.contrib.data.python.ops import counter
from tensorflow.contrib.data.python.ops import enumerate_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -27,7 +28,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
-class RangeDatasetTest(test.TestCase):
+class RangeDatasetTest(test_base.DatasetTestBase):
def testEnumerateDataset(self):
components = (["a", "b"], [1, 2], [37.0, 38])
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index ed75b27a44..66ed547b6d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
@@ -242,7 +243,7 @@ class ReadBatchFeaturesTest(
self.assertEqual(32, shape[0])
-class MakeCsvDatasetTest(test.TestCase):
+class MakeCsvDatasetTest(test_base.DatasetTestBase):
def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
return readers.make_csv_dataset(
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
index 08b9f03816..f443b5501b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
@@ -25,6 +25,7 @@ import zlib
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import constant_op
@@ -32,11 +33,10 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class FixedLengthRecordDatasetTestBase(test.TestCase):
+class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing FixedLengthRecordDataset."""
def setUp(self):
@@ -63,7 +63,7 @@ class FixedLengthRecordDatasetTestBase(test.TestCase):
return filenames
-class ReadBatchFeaturesTestBase(test.TestCase):
+class ReadBatchFeaturesTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing `make_batched_feature_dataset`."""
def setUp(self):
@@ -273,7 +273,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
self.assertAllEqual(expected_batch[i], actual_batch[i])
-class TextLineDatasetTestBase(test.TestCase):
+class TextLineDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing TextLineDataset."""
def _lineText(self, f, l):
@@ -313,7 +313,7 @@ class TextLineDatasetTestBase(test.TestCase):
return filenames
-class TFRecordDatasetTestBase(test.TestCase):
+class TFRecordDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing TFRecordDataset."""
def setUp(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index 16b1441baa..32474bd411 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
@@ -24,6 +24,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.data.python.ops import resampling
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -57,7 +58,7 @@ def _time_resampling(
return end_time - start_time
-class ResampleTest(test.TestCase, parameterized.TestCase):
+class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("InitialDistributionKnown", True),
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index dde678bd54..bdf80eae4e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -22,6 +22,7 @@ import itertools
import numpy as np
from tensorflow.contrib.data.python.ops import scan_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -33,7 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ScanDatasetTest(test.TestCase):
+class ScanDatasetTest(test_base.DatasetTestBase):
def _counting_dataset(self, start, scan_fn):
return dataset_ops.Dataset.from_tensors(0).repeat().apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
index 14cd3e9c4a..a10f85263a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -90,6 +91,16 @@ class StatsDatasetSerializationTest(
lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
None, num_outputs)
+ def _build_dataset_stats_aggregator(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ return dataset_ops.Dataset.range(10).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+
+ def test_set_stats_aggregator_not_support_checkpointing(self):
+ with self.assertRaisesRegexp(errors.UnimplementedError,
+ "does not support checkpointing"):
+ self.run_core_tests(self._build_dataset_stats_aggregator, None, 10)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
index 440e48db30..c97002a255 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
@@ -20,13 +20,14 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
-class ShuffleAndRepeatTest(test.TestCase):
+class ShuffleAndRepeatTest(test_base.DatasetTestBase):
def _build_ds(self, seed, count=5, num_elements=20):
return dataset_ops.Dataset.range(num_elements).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 90d18dca2a..c5a7862322 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -21,6 +21,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.ops import sliding
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class SlideDatasetTest(test.TestCase, parameterized.TestCase):
+class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("1", 20, 14, 7, 1),
@@ -197,11 +198,6 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
sliding.sliding_window_batch(
window_size=1, stride=1, window_shift=1, window_stride=1))
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSlideSparse(self):
def _sparse(i):
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
index 1f5c725a92..319a2ea263 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
@@ -24,12 +24,13 @@ import os
import sqlite3
from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class SqlDatasetTestBase(test.TestCase):
+class SqlDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing SqlDataset."""
def _createSqlDataset(self, output_types, num_repeats=1):
@@ -92,5 +93,3 @@ class SqlDatasetTestBase(test.TestCase):
9007199254740992.0)])
conn.commit()
conn.close()
-
-
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index b1b4c23510..80f2625927 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -19,10 +19,10 @@ from __future__ import print_function
from tensorflow.core.framework import summary_pb2
-from tensorflow.python.platform import test
+from tensorflow.python.data.kernel_tests import test_base
-class StatsDatasetTestBase(test.TestCase):
+class StatsDatasetTestBase(test_base.DatasetTestBase):
"""Base class for testing statistics gathered in `StatsAggregator`."""
def _assertSummaryContains(self, summary_str, tag):
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
deleted file mode 100644
index 4c3353fe40..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test utilities for tf.data functionality."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import re
-
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class DatasetTestBase(test.TestCase):
- """Base class for dataset tests."""
-
- def _assert_datasets_equal(self, dataset1, dataset2):
- # TODO(rachelim): support sparse tensor outputs
- next1 = dataset1.make_one_shot_iterator().get_next()
- next2 = dataset2.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- while True:
- try:
- op1 = sess.run(next1)
- except errors.OutOfRangeError:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next2)
- break
- op2 = sess.run(next2)
-
- op1 = nest.flatten(op1)
- op2 = nest.flatten(op2)
- assert len(op1) == len(op2)
- for i in range(len(op1)):
- self.assertAllEqual(op1[i], op2[i])
-
- def _assert_datasets_raise_same_error(self,
- dataset1,
- dataset2,
- exception_class,
- replacements=None):
- # We are defining next1 and next2 in the same line so that we get identical
- # file:line_number in the error messages
- # pylint: disable=line-too-long
- next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next()
- # pylint: enable=line-too-long
- with self.cached_session() as sess:
- try:
- sess.run(next1)
- raise ValueError(
- "Expected dataset to raise an error of type %s, but it did not." %
- repr(exception_class))
- except exception_class as e:
- expected_message = e.message
- for old, new, count in replacements:
- expected_message = expected_message.replace(old, new, count)
- # Check that the first segment of the error messages are the same.
- with self.assertRaisesRegexp(exception_class,
- re.escape(expected_message)):
- sess.run(next2)
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 8d335e87d5..08de3a9143 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -24,6 +24,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import threadpool
from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,8 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
+class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase,
+ parameterized.TestCase):
@parameterized.named_parameters(
("1", 1, None),
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
index f994c8563f..8856ce5afb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -25,7 +26,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class UniqueDatasetTest(test.TestCase):
+class UniqueDatasetTest(test_base.DatasetTestBase):
def _testSimpleHelper(self, dtype, test_cases):
"""Test the `unique()` transformation on a list of test cases.
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 8b7b3ac0f7..79134c7bc6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class WindowDatasetTest(test.TestCase, parameterized.TestCase):
+class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _structuredDataset(self, structure, shape, dtype):
if structure is None:
diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
index 867ee2ba37..fca546a570 100644
--- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import os
from tensorflow.contrib.data.python.ops import writers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import dtypes
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class TFRecordWriterTest(test.TestCase):
+class TFRecordWriterTest(test_base.DatasetTestBase):
def setUp(self):
super(TFRecordWriterTest, self).setUp()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index a14781cd93..5cd1ed542b 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -78,7 +78,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":batching",
- ":gen_dataset_ops",
":interleave_ops",
":optimization",
":parsing_ops",
@@ -86,6 +85,7 @@ py_library(
"//tensorflow/python:constant_op",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
"//tensorflow/python:platform",
@@ -148,8 +148,7 @@ py_library(
srcs = ["error_ops.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
@@ -179,12 +178,11 @@ py_library(
srcs = ["interleave_ops.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
":random_ops",
"//tensorflow/contrib/stateless",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
@@ -199,9 +197,8 @@ py_library(
srcs = ["optimization.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
@@ -304,8 +301,7 @@ py_library(
srcs = ["threadpool.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
@@ -321,9 +317,8 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
@@ -342,47 +337,11 @@ py_library(
],
)
-tf_gen_op_wrapper_py(
- name = "gen_dataset_ops",
- out = "gen_dataset_ops.py",
- deps = [
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
- ],
-)
-
-tf_kernel_library(
- name = "dataset_ops_kernels",
- deps = [
- "//tensorflow/contrib/data/kernels:dataset_kernels",
- "//tensorflow/core:framework",
- ],
- alwayslink = 1,
-)
-
-tf_custom_op_py_library(
- name = "contrib_op_loader",
- srcs = ["contrib_op_loader.py"],
- dso = ["//tensorflow/contrib/data:_dataset_ops.so"],
- kernels = [
- ":dataset_ops_kernels",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":gen_dataset_ops",
- "//tensorflow/contrib/util:util_py",
- "//tensorflow/python:platform",
- ],
-)
-
py_library(
name = "indexed_dataset_ops",
srcs = ["indexed_dataset_ops.py"],
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
@@ -394,7 +353,7 @@ py_library(
name = "prefetching_ops",
srcs = ["prefetching_ops.py"],
deps = [
- ":contrib_op_loader",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 367c159dc5..7a0f221284 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -345,12 +345,12 @@ def _padded_batch_sparse_window(dataset, padded_shape):
dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
-class _UnbatchDataset(dataset_ops.Dataset):
+class _UnbatchDataset(dataset_ops.UnaryDataset):
"""A dataset that splits the elements of its input into multiple elements."""
def __init__(self, input_dataset):
"""See `unbatch()` for more details."""
- super(_UnbatchDataset, self).__init__()
+ super(_UnbatchDataset, self).__init__(input_dataset)
flat_shapes = nest.flatten(input_dataset.output_shapes)
if any(s.ndims == 0 for s in flat_shapes):
raise ValueError("Cannot unbatch an input with scalar components.")
@@ -514,12 +514,12 @@ def padded_batch_and_drop_remainder(batch_size,
return _apply_fn
-class _DenseToSparseBatchDataset(dataset_ops.Dataset):
+class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
def __init__(self, input_dataset, batch_size, row_shape):
"""See `Dataset.dense_to_sparse_batch()` for more details."""
- super(_DenseToSparseBatchDataset, self).__init__()
+ super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
if not isinstance(input_dataset.output_types, dtypes.DType):
raise TypeError("DenseToSparseDataset requires an input whose elements "
"have a single component, whereas the input has %r." %
@@ -548,7 +548,7 @@ class _DenseToSparseBatchDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
-class _RestructuredDataset(dataset_ops.Dataset):
+class _RestructuredDataset(dataset_ops.UnaryDataset):
"""An internal helper for changing the structure and shape of a dataset."""
def __init__(self,
@@ -583,7 +583,7 @@ class _RestructuredDataset(dataset_ops.Dataset):
ValueError: If either `output_types` or `output_shapes` is not compatible
with the structure of `dataset`.
"""
- super(_RestructuredDataset, self).__init__()
+ super(_RestructuredDataset, self).__init__(dataset)
self._input_dataset = dataset
if not allow_unsafe_cast:
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index b4a7521e08..f962e623ee 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -17,9 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
def ignore_errors():
@@ -51,16 +50,16 @@ def ignore_errors():
return _apply_fn
-class _IgnoreErrorsDataset(dataset_ops.Dataset):
+class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that silently ignores errors when computing its input."""
def __init__(self, input_dataset):
"""See `Dataset.ignore_errors()` for details."""
- super(_IgnoreErrorsDataset, self).__init__()
+ super(_IgnoreErrorsDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
def _as_variant_tensor(self):
- return gen_dataset_ops.ignore_errors_dataset(
+ return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 020167e4d1..7cae33beb3 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -282,12 +282,12 @@ def window_dataset(window_size):
return _apply_fn
-class _GroupByReducerDataset(dataset_ops.Dataset):
+class _GroupByReducerDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that groups its input and performs a reduction."""
def __init__(self, input_dataset, key_func, reducer):
"""See `group_by_reducer()` for details."""
- super(_GroupByReducerDataset, self).__init__()
+ super(_GroupByReducerDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
@@ -416,12 +416,12 @@ class _GroupByReducerDataset(dataset_ops.Dataset):
**dataset_ops.flat_structure(self))
-class _GroupByWindowDataset(dataset_ops.Dataset):
+class _GroupByWindowDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that groups its input and performs a windowed reduction."""
def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
"""See `group_by_window()` for details."""
- super(_GroupByWindowDataset, self).__init__()
+ super(_GroupByWindowDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
@@ -525,12 +525,12 @@ class Reducer(object):
return self._finalize_func
-class _MapXDataset(dataset_ops.Dataset):
+class _MapXDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that maps a function over elements in its input."""
def __init__(self, input_dataset, map_func):
"""See `map_x_dataset()` for details."""
- super(_MapXDataset, self).__init__()
+ super(_MapXDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = dataset_ops.StructuredFunctionWrapper(
diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
index a0932b4081..9c06474a2f 100644
--- a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
@@ -19,14 +19,13 @@ from __future__ import print_function
import abc
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
class MaterializedIndexedDataset(object):
@@ -57,7 +56,7 @@ class MaterializedIndexedDataset(object):
A tensor containing the values corresponding to `index`.
"""
# TODO(saeta): nest.pack_sequence_as(...)
- return gen_dataset_ops.indexed_dataset_get(
+ return ged_ops.experimental_indexed_dataset_get(
self._materialized_resource,
index,
output_types=nest.flatten(
@@ -90,16 +89,18 @@ class IndexedDataset(dataset_ops.Dataset):
container = ""
if shared_name is None:
shared_name = ""
- materialized_resource = gen_dataset_ops.materialized_index_dataset_handle(
- container=container,
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_types(self.output_shapes, self.output_classes)))
+ materialized_resource = (
+ ged_ops.experimental_materialized_index_dataset_handle(
+ container=container,
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_types(self.output_shapes,
+ self.output_classes))))
with ops.colocate_with(materialized_resource):
- materializer = gen_dataset_ops.indexed_dataset_materialize(
+ materializer = ged_ops.experimental_indexed_dataset_materialize(
self._as_variant_tensor(), materialized_resource)
return MaterializedIndexedDataset(materialized_resource, materializer,
self.output_classes, self.output_types,
@@ -170,4 +171,7 @@ class IdentityIndexedDataset(IndexedDataset):
return tensor_shape.scalar()
def _as_variant_tensor(self):
- return gen_dataset_ops.identity_indexed_dataset(self._size)
+ return ged_ops.experimental_identity_indexed_dataset(self._size)
+
+ def _inputs(self):
+ return []
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 92d4251a86..1ee9db1aa8 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -18,8 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import stateless
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import random_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
@@ -28,6 +26,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
@@ -167,12 +166,17 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
# pylint: disable=protected-access
- return gen_dataset_ops.directed_interleave_dataset(
- self._selector_input._as_variant_tensor(),
- [data_input._as_variant_tensor() for data_input in self._data_inputs],
- **dataset_ops.flat_structure(self))
+ return (
+ gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
+ self._selector_input._as_variant_tensor(), [
+ data_input._as_variant_tensor()
+ for data_input in self._data_inputs
+ ], **dataset_ops.flat_structure(self)))
# pylint: enable=protected-access
+ def _inputs(self):
+ return [self._selector_input] + self._data_inputs
+
@property
def output_classes(self):
return self._data_inputs[0].output_classes
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 73840452df..30348ede36 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -17,12 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
# A constant that can be used to enable auto-tuning.
AUTOTUNE = -1
@@ -54,7 +53,7 @@ def model():
Returns:
A `Dataset` transformation function, which can be passed to
- @{tf.data.Dataset.apply}.
+ `tf.data.Dataset.apply`.
"""
def _apply_fn(dataset):
@@ -84,12 +83,12 @@ def optimize(optimizations=None):
return _apply_fn
-class _AssertNextDataset(dataset_ops.Dataset):
+class _AssertNextDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that asserts which transformations happen next."""
def __init__(self, input_dataset, transformations):
"""See `assert_next()` for details."""
- super(_AssertNextDataset, self).__init__()
+ super(_AssertNextDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if transformations is None:
raise ValueError("At least one transformation should be specified")
@@ -97,7 +96,7 @@ class _AssertNextDataset(dataset_ops.Dataset):
transformations, dtype=dtypes.string, name="transformations")
def _as_variant_tensor(self):
- return contrib_gen_dataset_ops.assert_next_dataset(
+ return gen_experimental_dataset_ops.experimental_assert_next_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._transformations,
**dataset_ops.flat_structure(self))
@@ -115,12 +114,12 @@ class _AssertNextDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
-class _ModelDataset(dataset_ops.Dataset):
+class _ModelDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and models performance."""
def __init__(self, input_dataset):
"""See `optimize()` for details."""
- super(_ModelDataset, self).__init__()
+ super(_ModelDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
def _as_variant_tensor(self):
@@ -141,12 +140,12 @@ class _ModelDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
-class _OptimizeDataset(dataset_ops.Dataset):
+class _OptimizeDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and applies optimizations."""
def __init__(self, input_dataset, optimizations):
"""See `optimize()` for details."""
- super(_OptimizeDataset, self).__init__()
+ super(_OptimizeDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if optimizations is None:
optimizations = []
diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py
index 2701605e64..cfbba701b0 100644
--- a/tensorflow/contrib/data/python/ops/parsing_ops.py
+++ b/tensorflow/contrib/data/python/ops/parsing_ops.py
@@ -26,11 +26,11 @@ from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import parsing_ops
-class _ParseExampleDataset(dataset_ops.Dataset):
+class _ParseExampleDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that parses `example` dataset into a `dict` dataset."""
def __init__(self, input_dataset, features, num_parallel_calls):
- super(_ParseExampleDataset, self).__init__()
+ super(_ParseExampleDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if not all(types == dtypes.string
for types in nest.flatten(input_dataset.output_types)):
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 5222011d04..46f82e453a 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -19,8 +19,6 @@ from __future__ import print_function
import warnings
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
@@ -31,9 +29,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
@@ -65,7 +63,7 @@ def function_buffering_resource(string_arg,
"""
if shared_name is None:
shared_name = ""
- return gen_dataset_ops.function_buffering_resource(
+ return ged_ops.experimental_function_buffering_resource(
string_arg=string_arg,
target_device=target_device,
shared_name=shared_name,
@@ -79,14 +77,14 @@ def function_buffering_resource(string_arg,
def function_buffering_resource_get_next(function_buffer_resource,
output_types,
name=None):
- return gen_dataset_ops.function_buffering_resource_get_next(
+ return ged_ops.experimental_function_buffering_resource_get_next(
function_buffer_resource=function_buffer_resource,
output_types=output_types,
name=name)
def function_buffering_resource_reset(function_buffer_resource, name=None):
- return gen_dataset_ops.function_buffering_resource_reset(
+ return ged_ops.experimental_function_buffering_resource_reset(
function_buffer_resource=function_buffer_resource, name=name)
@@ -137,7 +135,7 @@ class _PrefetchToDeviceIterator(object):
ret = remote_iterator.get_next()
return nest.flatten(sparse.serialize_sparse_tensors(ret))
- iterator_device = gen_dataset_ops.iterator_get_device(
+ iterator_device = ged_ops.experimental_iterator_get_device(
self._input_iterator._iterator_resource)
with ops.device(device):
@@ -163,10 +161,11 @@ class _PrefetchToDeviceIterator(object):
if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
- flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
+ flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
self._buffering_resource,
- output_types=nest.flatten(sparse.as_dense_types(
- self.output_types, self.output_classes)), name=name)
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ name=name)
ret = sparse.deserialize_sparse_tensors(
nest.pack_sequence_as(self.output_types, flat_ret),
@@ -220,7 +219,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
buffer_size):
with ops.device("/device:CPU:0"):
super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
- input_iterator_handle = core_gen_dataset_ops.iterator_to_string_handle(
+ input_iterator_handle = gen_dataset_ops.iterator_to_string_handle(
self._resource)
self._device = device
@@ -239,7 +238,8 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
self._buffering_resource = function_buffering_resource(
f=_prefetch_fn,
output_types=self._flat_output_types,
- target_device=gen_dataset_ops.iterator_get_device(self._resource),
+ target_device=ged_ops.experimental_iterator_get_device(
+ self._resource),
string_arg=input_iterator_handle,
buffer_size=buffer_size,
shared_name=iterator_ops._generate_shared_name(
@@ -253,7 +253,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
# TODO(b/77291417): Fix
with context.execution_mode(context.SYNC):
with ops.device(self._device):
- ret = gen_dataset_ops.function_buffering_resource_get_next(
+ ret = ged_ops.experimental_function_buffering_resource_get_next(
function_buffer_resource=self._buffering_resource,
output_types=self._flat_output_types)
return sparse.deserialize_sparse_tensors(
@@ -262,10 +262,11 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
# pylint: enable=protected-access
-class _PrefetchToDeviceDataset(dataset_ops.Dataset):
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
"""A `Dataset` whose iterator prefetches elements to another device."""
def __init__(self, input_dataset, device, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._device = device
self._buffer_size = buffer_size if buffer_size is not None else 1
@@ -374,7 +375,7 @@ def copy_to_device(target_device, source_device="/cpu:0"):
# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
# all inputs to the Op are in host memory, thereby avoiding some unnecessary
# Sends and Recvs.
-class _CopyToDeviceDataset(dataset_ops.Dataset):
+class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that copies elements to another device."""
def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
@@ -385,6 +386,7 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
target_device: The name of the device to which elements would be copied.
source_device: Device where input_dataset would be placed.
"""
+ super(_CopyToDeviceDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._target_device = target_device
spec = framework_device.DeviceSpec().from_string(self._target_device)
@@ -408,12 +410,12 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
"""
# pylint: disable=protected-access
ds_variant = self._input_dataset._as_variant_tensor()
- resource = core_gen_dataset_ops.anonymous_iterator(
+ resource = gen_dataset_ops.anonymous_iterator(
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
with ops.control_dependencies(
- [core_gen_dataset_ops.make_iterator(ds_variant, resource)]):
- return core_gen_dataset_ops.iterator_to_string_handle(resource)
+ [gen_dataset_ops.make_iterator(ds_variant, resource)]):
+ return gen_dataset_ops.iterator_to_string_handle(resource)
@function.Defun()
def _remote_init_func():
@@ -462,7 +464,7 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
Returns:
Tensor constant 0
"""
- iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2(
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
string_handle,
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
@@ -503,7 +505,7 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
with ops.device(self._target_device):
- return core_gen_dataset_ops.generator_dataset(
+ return gen_dataset_ops.generator_dataset(
self._init_captured_args,
self._next_captured_args,
self._finalize_captured_args,
@@ -524,187 +526,3 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
@property
def output_classes(self):
return self._input_dataset.output_classes
-
-
-class _PerDeviceGenerator(dataset_ops.Dataset):
- """A `dummy` generator dataset."""
-
- def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
- source_device, target_device, output_shapes, output_types,
- output_classes):
- self._target_device = target_device
- self._output_types = output_types
- self._output_shapes = output_shapes
- self._output_classes = output_classes
- self._flat_output_shapes = nest.flatten(
- sparse.as_dense_shapes(self._output_shapes, self._output_classes))
- self._flat_output_types = nest.flatten(
- sparse.as_dense_types(self._output_types, self._output_classes))
-
- multi_device_iterator_string_handle = (
- gen_dataset_ops.multi_device_iterator_to_string_handle(
- multi_device_iterator_resource))
-
- @function.Defun()
- def _init_func():
- return multi_device_iterator_string_handle
-
- @function.Defun()
- def _remote_init_func():
- return functional_ops.remote_call(
- target=source_device,
- args=_init_func.captured_inputs,
- Tout=[dtypes.string],
- f=_init_func)
-
- self._init_func = _remote_init_func
- self._init_captured_args = _remote_init_func.captured_inputs
-
- @function.Defun(dtypes.string)
- def _next_func(string_handle):
- multi_device_iterator = (
- gen_dataset_ops.multi_device_iterator_from_string_handle(
- string_handle=string_handle,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes))
- return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
- multi_device_iterator=multi_device_iterator,
- shard_num=shard_num,
- incarnation_id=incarnation_id,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
-
- @function.Defun(dtypes.string)
- def _remote_next_func(string_handle):
- return functional_ops.remote_call(
- target=source_device,
- args=[string_handle] + _next_func.captured_inputs,
- Tout=self._flat_output_types,
- f=_next_func)
-
- self._next_func = _remote_next_func
- self._next_captured_args = _remote_next_func.captured_inputs
-
- @function.Defun(dtypes.string)
- def _finalize_func(unused_string_handle):
- return array_ops.constant(0, dtypes.int64)
-
- @function.Defun(dtypes.string)
- def _remote_finalize_func(string_handle):
- return functional_ops.remote_call(
- target=source_device,
- args=[string_handle] + _finalize_func.captured_inputs,
- Tout=[dtypes.int64],
- f=_finalize_func)
-
- self._finalize_func = _remote_finalize_func
- self._finalize_captured_args = _remote_finalize_func.captured_inputs
-
- def _as_variant_tensor(self):
- with ops.device(self._target_device):
- return core_gen_dataset_ops.generator_dataset(
- self._init_captured_args,
- self._next_captured_args,
- self._finalize_captured_args,
- init_func=self._init_func,
- next_func=self._next_func,
- finalize_func=self._finalize_func,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_classes(self):
- return self._output_classes
-
-
-class MultiDeviceIterator(object):
- """An iterator over multiple devices."""
-
- def __init__(self,
- dataset,
- devices,
- max_buffer_size=1,
- prefetch_buffer_size=1,
- source_device="/cpu:0"):
- """Constructs a MultiDeviceIterator.
-
- Args:
- dataset: The input dataset to be iterated over.
- devices: The list of devices to fetch data to.
- max_buffer_size: Maximum size of the host side per device buffer to keep.
- prefetch_buffer_size: if > 1, then we setup a buffer on each device
- to prefetch into.
- source_device: The host device to place the `dataset` on.
- """
- self._dataset = dataset
- self._devices = devices
- self._source_device = source_device
- self._source_device_tensor = ops.convert_to_tensor(source_device)
-
- self._flat_output_shapes = nest.flatten(
- sparse.as_dense_shapes(self._dataset.output_shapes,
- self._dataset.output_classes))
- self._flat_output_types = nest.flatten(
- sparse.as_dense_types(self._dataset.output_types,
- self._dataset.output_classes))
-
- # Create the MultiDeviceIterator.
- with ops.device(self._source_device):
- self._multi_device_iterator_resource = (
- gen_dataset_ops.multi_device_iterator(
- devices=self._devices,
- shared_name="",
- container="",
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes))
-
- # The incarnation ID is used to ensure consistency between the per-device
- # iterators and the multi-device iterator.
- self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
- self._dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._multi_device_iterator_resource,
- max_buffer_size=max_buffer_size)
-
- # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
- # initialize the device side of the pipeline. This would allow the
- # MultiDeviceIterator to choose, for example, to move some transformations
- # into the device side from its input. It might be useful in rewriting.
- # Create the per device iterators.
- self._device_iterators = []
- i = 0
- for device in self._devices:
- ds = _PerDeviceGenerator(
- i, self._multi_device_iterator_resource, self._incarnation_id,
- self._source_device_tensor, device, self._dataset.output_shapes,
- self._dataset.output_types, self._dataset.output_classes)
- if prefetch_buffer_size > 0:
- ds = ds.prefetch(prefetch_buffer_size)
- with ops.device(device):
- self._device_iterators.append(ds.make_initializable_iterator())
- i += 1
-
- device_iterator_initializers = [
- iterator.initializer for iterator in self._device_iterators
- ]
- self._initializer = control_flow_ops.group(*device_iterator_initializers)
-
- def get_next(self):
- result = []
- i = 0
- for device in self._devices:
- with ops.device(device):
- result.append(self._device_iterators[i].get_next())
- i += 1
- return result
-
- @property
- def initializer(self):
- return self._initializer
diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py
index e670c4c835..344a0763c8 100644
--- a/tensorflow/contrib/data/python/ops/random_ops.py
+++ b/tensorflow/contrib/data/python/ops/random_ops.py
@@ -25,7 +25,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
-class RandomDataset(dataset_ops.Dataset):
+class RandomDataset(dataset_ops.DatasetSource):
"""A `Dataset` of pseudorandom values."""
def __init__(self, seed=None):
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 785b395707..360971e200 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -23,7 +23,6 @@ import csv
import numpy as np
from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.contrib.data.python.ops import parsing_ops
@@ -38,6 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import deprecation
@@ -508,7 +508,7 @@ def make_csv_dataset(
_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB
-class CsvDataset(dataset_ops.Dataset):
+class CsvDataset(dataset_ops.DatasetSource):
"""A Dataset comprising lines from one or more CSV files."""
def __init__(self,
@@ -629,7 +629,7 @@ class CsvDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
# Constructs graph node for the dataset op.
- return contrib_gen_dataset_ops.csv_dataset(
+ return gen_experimental_dataset_ops.experimental_csv_dataset(
filenames=self._filenames,
record_defaults=self._record_defaults,
buffer_size=self._buffer_size,
@@ -924,7 +924,7 @@ def _get_file_names(file_pattern, shuffle):
return file_names
-class SqlDataset(dataset_ops.Dataset):
+class SqlDataset(dataset_ops.DatasetSource):
"""A `Dataset` consisting of the results from a SQL query."""
def __init__(self, driver_name, data_source_name, query, output_types):
@@ -985,7 +985,7 @@ class SqlDataset(dataset_ops.Dataset):
return self._output_types
-class LMDBDataset(dataset_ops.Dataset):
+class LMDBDataset(dataset_ops.DatasetSource):
"""A LMDB Dataset that reads the lmdb file."""
def __init__(self, filenames):
@@ -1013,7 +1013,7 @@ class LMDBDataset(dataset_ops.Dataset):
filenames, dtype=dtypes.string, name="filenames")
def _as_variant_tensor(self):
- return contrib_gen_dataset_ops.lmdb_dataset(
+ return gen_experimental_dataset_ops.experimental_lmdb_dataset(
self._filenames,
output_types=nest.flatten(self.output_types),
output_shapes=nest.flatten(self.output_shapes))
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index 6b002b4a53..c52582cd35 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -27,12 +27,12 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_dataset_ops
-class _ScanDataset(dataset_ops.Dataset):
+class _ScanDataset(dataset_ops.UnaryDataset):
"""A dataset that scans a function across its input."""
def __init__(self, input_dataset, initial_state, scan_func):
"""See `scan()` for details."""
- super(_ScanDataset, self).__init__()
+ super(_ScanDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
with ops.name_scope("initial_state"):
diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py
index 4356721704..985d1d87d0 100644
--- a/tensorflow/contrib/data/python/ops/shuffle_ops.py
+++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py
@@ -25,16 +25,11 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
-class _ShuffleAndRepeatDataset(dataset_ops.Dataset):
+class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that fuses `shuffle` and `repeat`."""
- def __init__(self,
- input_dataset,
- buffer_size,
- count=None,
- seed=None):
- """See `Dataset.map()` for details."""
- super(_ShuffleAndRepeatDataset, self).__init__()
+ def __init__(self, input_dataset, buffer_size, count=None, seed=None):
+ super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py
index b0d6a16c20..bcc383587c 100644
--- a/tensorflow/contrib/data/python/ops/sliding.py
+++ b/tensorflow/contrib/data/python/ops/sliding.py
@@ -26,12 +26,12 @@ from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
-class _SlideDataset(dataset_ops.Dataset):
+class _SlideDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that passes a sliding window over its input."""
def __init__(self, input_dataset, window_size, window_shift, window_stride):
"""See `sliding_window_batch` for details."""
- super(_SlideDataset, self).__init__()
+ super(_SlideDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._window_size = ops.convert_to_tensor(
window_size, dtype=dtypes.int64, name="window_stride")
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 7410ee8e05..bc47c5989d 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -84,11 +84,11 @@ class StatsAggregator(object):
return gen_dataset_ops.stats_aggregator_summary(self._resource)
-class _SetStatsAggregatorDataset(dataset_ops.Dataset):
+class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and sets given stats_aggregator."""
def __init__(self, input_dataset, stats_aggregator):
- super(_SetStatsAggregatorDataset, self).__init__()
+ super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._stats_aggregator = stats_aggregator
@@ -173,11 +173,11 @@ def latency_stats(tag):
return _apply_fn
-class _StatsDataset(dataset_ops.Dataset):
+class _StatsDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and also records statistics."""
def __init__(self, input_dataset, op_function, tag):
- super(_StatsDataset, self).__init__()
+ super(_StatsDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._op_function = op_function
self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index dc67accdcf..f73c3fd9cb 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -19,10 +19,9 @@ from __future__ import print_function
import threading
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
_uid_counter = 0
@@ -47,7 +46,7 @@ class PrivateThreadPool(object):
"""Creates a `PrivateThreadPool` with the given number of threads."""
if context.executing_eagerly():
shared_name = _generate_shared_name("privatethreadpool")
- self._resource = gen_dataset_ops.thread_pool_handle(
+ self._resource = ged_ops.experimental_thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name,
@@ -55,22 +54,22 @@ class PrivateThreadPool(object):
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._resource, handle_device=context.context().device_name)
else:
- self._resource = gen_dataset_ops.thread_pool_handle(
+ self._resource = ged_ops.experimental_thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name)
-class _ThreadPoolDataset(dataset_ops.Dataset):
+class _ThreadPoolDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and sets a custom threadpool."""
def __init__(self, input_dataset, thread_pool):
- super(_ThreadPoolDataset, self).__init__()
+ super(_ThreadPoolDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._thread_pool = thread_pool
def _as_variant_tensor(self):
- return gen_dataset_ops.thread_pool_dataset(
+ return ged_ops.experimental_thread_pool_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._thread_pool._resource, # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index e0d606311c..ed363a7090 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -17,10 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_experimental_dataset_ops
def unique():
@@ -47,12 +46,12 @@ def unique():
return _apply_fn
-class _UniqueDataset(dataset_ops.Dataset):
+class _UniqueDataset(dataset_ops.UnaryDataset):
"""A `Dataset` contains the unique elements from its input."""
def __init__(self, input_dataset):
"""See `unique()` for details."""
- super(_UniqueDataset, self).__init__()
+ super(_UniqueDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
dtypes.string):
@@ -61,7 +60,7 @@ class _UniqueDataset(dataset_ops.Dataset):
"`tf.int32`, `tf.int64`, or `tf.string` component.")
def _as_variant_tensor(self):
- return gen_dataset_ops.unique_dataset(
+ return gen_experimental_dataset_ops.experimental_unique_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD
index 3b50a48336..06940a90d5 100644
--- a/tensorflow/contrib/decision_trees/proto/BUILD
+++ b/tensorflow/contrib/decision_trees/proto/BUILD
@@ -17,7 +17,6 @@ tf_proto_library(
name = "generic_tree_model",
srcs = ["generic_tree_model.proto"],
cc_api_version = 2,
- java_api_version = 2,
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 91a27f97b7..2e025765e4 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -231,7 +231,8 @@ The same `input_fn` will be used for all workers if you use
important to shuffle your dataset in your `input_fn`.
`MirroredStrategy` will insert a `tf.dataset.Dataset.shard` call in you
-`input_fn`. As a result, each worker gets a fraction of your input data.
+`input_fn` if `auto_shard_dataset` is set to `True`. As a result, each worker
+gets a fraction of your input data.
### Performance Tips
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 48a7593ab4..cfb9d42a6f 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -28,6 +28,7 @@ py_library(
"//tensorflow/python:device_util",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
@@ -453,7 +454,7 @@ cuda_py_test(
cuda_py_test(
name = "estimator_training_test",
- size = "large",
+ size = "enormous",
srcs = ["estimator_training_test.py"],
additional_deps = [
":combinations",
@@ -651,8 +652,8 @@ py_library(
name = "prefetching_ops_v2",
srcs = ["prefetching_ops_v2.py"],
deps = [
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
"//tensorflow/contrib/data/python/ops:prefetching_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index c900b41e14..9809204f8f 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -216,7 +216,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
"""Configures the object.
Args:
- session_config: a @{tf.ConfigProto}
+ session_config: a `tf.ConfigProto`
cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
cluster configurations.
task_type: the current task type, such as "worker".
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 244d1fcec8..82ca041cc2 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -59,6 +59,7 @@ from tensorflow.python.training import adagrad
from tensorflow.python.training import adam
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import rmsprop
from tensorflow.python.util import tf_inspect
@@ -354,6 +355,8 @@ gradient_descent_optimizer_v1_fn = NamedObject(
"GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
adagrad_optimizer_v1_fn = NamedObject(
"AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
+rmsprop_optimizer_v1_fn = NamedObject(
+ "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))
optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn,
adagrad_optimizer_v1_fn]
diff --git a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py
index 44a69ed23a..79a9803d75 100644
--- a/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py
+++ b/tensorflow/contrib/distribute/python/examples/simple_estimator_example.py
@@ -22,6 +22,8 @@ from __future__ import print_function
import tensorflow as tf
+from tensorflow.python.keras import metrics as metrics_module
+
def build_model_fn_optimizer():
"""Simple model_fn with optimizer."""
@@ -45,7 +47,10 @@ def build_model_fn_optimizer():
return y * y
if mode == tf.estimator.ModeKeys.EVAL:
- return tf.estimator.EstimatorSpec(mode, loss=loss_fn())
+ acc_obj = metrics_module.BinaryAccuracy()
+ acc_obj.update_state(labels, labels)
+ return tf.estimator.EstimatorSpec(
+ mode, loss=loss_fn(), eval_metric_ops={"Accuracy": acc_obj})
assert mode == tf.estimator.ModeKeys.TRAIN
@@ -61,18 +66,26 @@ def main(_):
["/device:GPU:0", "/device:GPU:1"])
config = tf.estimator.RunConfig(train_distribute=distribution,
eval_distribute=distribution)
+ # Since there are 2 devices and 10 samples, we set steps=5.
+ steps = 5
- def input_fn():
+ def train_input_fn():
features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
labels = tf.data.Dataset.from_tensors([1.]).repeat(10)
return tf.data.Dataset.zip((features, labels))
estimator = tf.estimator.Estimator(
model_fn=build_model_fn_optimizer(), config=config)
- estimator.train(input_fn=input_fn, steps=10)
+ estimator.train(input_fn=train_input_fn, steps=steps)
+
+ def eval_input_fn():
+ features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
+ labels = tf.data.Dataset.from_tensors([1.]).repeat(10)
+ return tf.data.Dataset.zip((features, labels))
- eval_result = estimator.evaluate(input_fn=input_fn, steps=10)
+ eval_result = estimator.evaluate(input_fn=eval_input_fn, steps=steps)
print("Eval result: {}".format(eval_result))
+ assert eval_result["Accuracy"] == 1.0
def predict_input_fn():
predict_features = tf.data.Dataset.from_tensors([[1.]]).repeat(10)
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 2e6cd43fd4..3aab2c521f 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -173,13 +173,42 @@ def batch_wrapper(dataset, batch_size, distribution):
return dataset.batch(batch_size)
-def all_combinations():
+def get_model():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+ return model
+
+
+def get_dataset(distribution):
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = batch_wrapper(dataset, 10, distribution)
+ return dataset
+
+
+strategies = [combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus,
+ combinations.tpu_strategy_one_step]
+
+
+def strategy_combinations():
return combinations.combine(
- distribution=[combinations.default_strategy,
- combinations.one_device_strategy,
- combinations.mirrored_strategy_with_gpu_and_cpu,
- combinations.mirrored_strategy_with_two_gpus,
- combinations.tpu_strategy_one_step],
+ distribution=strategies,
+ mode=['graph'])
+
+
+def strategy_and_optimizer_combinations():
+ return combinations.combine(
+ distribution=strategies,
+ optimizer=[combinations.adagrad_optimizer_v1_fn,
+ combinations.adam_optimizer_v1_fn,
+ combinations.gradient_descent_optimizer_v1_fn,
+ combinations.rmsprop_optimizer_v1_fn],
mode=['graph'])
@@ -205,6 +234,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
keras_model = simple_functional_model()
keras_model.compile(
loss='categorical_crossentropy',
+ metrics=[keras.metrics.CategoricalAccuracy()],
optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01))
config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
model_dir=self._base_dir,
@@ -229,6 +259,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
keras_model = simple_sequential_model()
keras_model.compile(
loss='categorical_crossentropy',
+ metrics=[keras.metrics.CategoricalAccuracy()],
optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01))
config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
model_dir=self._base_dir,
@@ -358,13 +389,11 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_calling_model_with_numpy_arrays(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
- metrics = ['mae']
+ metrics = ['mae', keras.metrics.CategoricalAccuracy()]
strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
'/device:GPU:0'])
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
@@ -390,23 +419,17 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
# with batch_size
model.predict(inputs, batch_size=8)
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_calling_model_on_same_dataset(self, distribution):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
- metrics = ['mae']
+ metrics = ['mae', keras.metrics.CategoricalAccuracy()]
model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = batch_wrapper(dataset, 10, distribution)
+ dataset = get_dataset(distribution)
# Call fit with validation data
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
@@ -432,7 +455,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001)
loss = 'mse'
- metrics = ['mae']
+ metrics = ['mae', keras.metrics.CategoricalAccuracy()]
strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
'/device:CPU:0'])
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
@@ -459,23 +482,17 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_fit_eval_and_predict_methods_on_dataset(self, distribution):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
- metrics = ['mae']
+ metrics = ['mae', keras.metrics.CategoricalAccuracy()]
model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = batch_wrapper(dataset, 10, distribution)
+ dataset = get_dataset(distribution)
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
@@ -484,37 +501,23 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
validation_data=dataset, validation_steps=2)
- def test_raise_error_for_stateful_metrics(self):
-
- class ExampleStatefulMetric(keras.layers.Layer):
-
- def __init__(self, name='true_positives', **kwargs):
- super(ExampleStatefulMetric, self).__init__(name=name, **kwargs)
- self.stateful = True
-
- def __call__(self, y_true, y_pred):
- return y_pred - y_true
-
+ @combinations.generate(strategy_and_optimizer_combinations())
+ def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
- optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
- metrics = ['mae', ExampleStatefulMetric()]
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
- '/device:GPU:0'])
- with self.assertRaisesRegexp(
- NotImplementedError, 'Stateful metrics are not supported with '
- 'DistributionStrategy.'):
- model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+ model.compile(optimizer(), loss, distribute=distribution)
+
+ dataset = get_dataset(distribution)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(dataset, steps=2, verbose=1)
+ model.predict(dataset, steps=2)
def test_unsupported_features(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
@@ -524,11 +527,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = get_dataset(strategy)
# Test with validation split
with self.assertRaisesRegexp(
@@ -565,9 +564,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_calling_with_unsupported_predefined_callbacks(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
@@ -576,11 +573,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
'/device:GPU:0'])
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = get_dataset(strategy)
def schedule(_):
return 0.001
@@ -604,9 +597,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_dataset_input_shape_validation(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
loss = 'mse'
@@ -640,17 +631,13 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
mode=['graph']))
def test_dataset_input_shape_fully_defined(self, distribution):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
loss = 'mse'
model.compile(optimizer, loss, distribute=distribution)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = get_dataset(distribution)
# Input shapes are not fully known. Batch dimension is unknown as we are
# not using the drop_remainder argument.
dataset = dataset.repeat(100).batch(10)
@@ -722,7 +709,7 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase):
class NormalizationLayerWithDistributionStrategyTest(
test.TestCase, parameterized.TestCase):
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_batchnorm_correctness(self, distribution):
with self.cached_session():
model = keras.models.Sequential()
@@ -750,7 +737,37 @@ class NormalizationLayerWithDistributionStrategyTest(
class CorrectnessWithDistributionStrategyTest(test.TestCase,
parameterized.TestCase):
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
+ def test_metric_correctness(self, distribution):
+ with self.cached_session():
+ keras.backend.set_image_data_format('channels_last')
+ num_samples = 10000
+
+ x_train = np.random.randint(0, 2, num_samples)
+ x_train = np.reshape(x_train, (num_samples, 1))
+ y_train = x_train
+ x_train = x_train.astype('float32')
+ y_train = y_train.astype('float32')
+
+ # Create identity model.
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Dense(1, input_shape=(1,), kernel_initializer='ones'))
+ model.compile(
+ loss=keras.losses.mean_squared_error,
+ optimizer=gradient_descent.GradientDescentOptimizer(0.5),
+ metrics=[keras.metrics.BinaryAccuracy()],
+ distribute=distribution)
+
+ batch_size = 64
+ batch_size //= distribution.num_towers
+ train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
+ train_dataset = batch_wrapper(train_dataset, batch_size, distribution)
+
+ history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10)
+ self.assertEqual(history.history['binary_accuracy'], [1.0])
+
+ @combinations.generate(strategy_combinations())
def test_correctness(self, distribution):
with self.cached_session():
keras.backend.set_image_data_format('channels_last')
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 0c6805d682..4d7516063c 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -347,6 +347,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
set, the `configure` method will try to find the best one.
prefetch_on_device: optional boolean to specify whether to prefetch input
data to devices.
+ auto_shard_dataset: whether to auto-shard the dataset when there are
+ multiple workers.
"""
def __init__(self,
@@ -354,11 +356,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
num_gpus=None,
num_gpus_per_worker=None,
cross_tower_ops=None,
- prefetch_on_device=None):
+ prefetch_on_device=None,
+ auto_shard_dataset=False):
super(MirroredStrategy, self).__init__()
self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
+ self._auto_shard_dataset = auto_shard_dataset
# Rememeber num GPUs which might be needed by `configure` method.
if num_gpus is not None and num_gpus_per_worker is not None:
raise ValueError(
@@ -477,7 +481,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
if self._cluster_spec:
return values.MultiWorkerDataset(
partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
- self._prefetch_on_device)
+ self._prefetch_on_device, self._auto_shard_dataset)
else:
return values.PerDeviceDataset(
self._call_dataset_fn(dataset_fn), self._devices,
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
index 1ff60c0762..8d949943b7 100644
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
@@ -19,8 +19,6 @@ from __future__ import print_function
import warnings
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
@@ -30,6 +28,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import nest
@@ -42,10 +41,9 @@ class _PrefetchToDeviceIterator(object):
one_shot: If true, we make a one shot iterator that's already initialized.
devices: Devices on which to prefetch.
buffer_size: Size of the prefetching buffer.
- shared_name: (Optional.) If non-empty, the returned iterator will be
- shared under the given name across multiple sessions that share the
- same devices (e.g. when using a remote server). Only used if one_shot
- is False.
+ shared_name: (Optional.) If non-empty, the returned iterator will be shared
+ under the given name across multiple sessions that share the same devices
+ (e.g. when using a remote server). Only used if one_shot is False.
Returns:
An Iterator type object.
@@ -82,7 +80,7 @@ class _PrefetchToDeviceIterator(object):
ret = remote_iterator.get_next()
return nest.flatten(sparse.serialize_sparse_tensors(ret))
- target_device = gen_dataset_ops.iterator_get_device(
+ target_device = ged_ops.experimental_iterator_get_device(
self._input_iterator._iterator_resource)
self._buffering_resources = []
for device in nest.flatten(self._devices):
@@ -102,7 +100,8 @@ class _PrefetchToDeviceIterator(object):
reset_ops = []
for buffer_resource in self._buffering_resources:
reset_ops.append(
- prefetching_ops.function_buffering_resource_reset(buffer_resource))
+ ged_ops.experimental_function_buffering_resource_reset(
+ buffer_resource))
with ops.control_dependencies(reset_ops):
self._initializer = self._input_iterator.make_initializer(
self._input_dataset)
@@ -118,10 +117,11 @@ class _PrefetchToDeviceIterator(object):
# batches) is not divisible by number of devices.
# How do we handle that more gracefully / let the user know?
for buffer_resource in self._buffering_resources:
- flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
+ flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
buffer_resource,
- output_types=data_nest.flatten(sparse.as_dense_types(
- self.output_types, self.output_classes)), name=name)
+ output_types=data_nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ name=name)
ret = sparse.deserialize_sparse_tensors(
data_nest.pack_sequence_as(self.output_types, flat_ret),
@@ -152,13 +152,16 @@ class _PrefetchToDeviceIterator(object):
@property
def output_types(self):
return self._input_dataset.output_types
+
+
# pylint: enable=protected-access
-class _PrefetchToDeviceDataset(dataset_ops.Dataset):
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
"""A `Dataset` whose iterator prefetches elements to other device(s)."""
def __init__(self, input_dataset, devices, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._devices = devices
self._buffer_size = buffer_size if buffer_size is not None else 1
@@ -222,6 +225,7 @@ def prefetch_to_devices(devices, buffer_size=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
+
def _apply_fn(dataset):
return _PrefetchToDeviceDataset(dataset, devices, buffer_size)
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index a6762e5e87..1b555482d3 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -37,9 +38,13 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import device_util
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest
+_TPU_INITIALIZE_SYSTEM_COLLECTION = "TPU_STRATEGY_INITIALIZE"
+
+
def get_tpu_system_metadata(tpu_cluster_resolver):
"""Retrieves TPU system metadata given a TPUClusterResolver."""
master = tpu_cluster_resolver.master()
@@ -56,6 +61,58 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
return tpu_system_metadata
+# TODO(jhseu): Deduplicate with MirroredStrategy?
+def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args,
+ **kwargs): # pylint: disable=g-missing-docstring
+ # Figure out what collections this variable should be added to.
+ # We'll add the TPUMirroredVariable to those collections instead.
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ # TODO(jhseu): Should we have different behavior for different
+ # synchronization settings?
+
+ # Get aggregation value
+ # TODO(jhseu): Support aggregation in a tower context.
+ aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
+ if aggregation not in [
+ vs.VariableAggregation.NONE,
+ vs.VariableAggregation.SUM,
+ vs.VariableAggregation.MEAN,
+ vs.VariableAggregation.ONLY_FIRST_TOWER,
+ ]:
+ raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
+ .format(aggregation, kwargs["name"]))
+
+ # Ignore user-specified caching device, not needed for mirrored variables.
+ kwargs.pop("caching_device", None)
+
+ # TODO(josh11b,apassos): It would be better if variable initialization
+ # was never recorded on the tape instead of having to do this manually
+ # here.
+ with tape.stop_recording():
+ index = real_mirrored_creator(devices, *args, **kwargs)
+ result = values.TPUMirroredVariable(index, index[devices[0]], aggregation)
+
+ if not context.executing_eagerly():
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the member variables
+ # to the TRAINABLE_VARIABLES collection, so we manually remove
+ # them and replace with the MirroredVariable. We can't set
+ # "trainable" to False for next_creator() since that causes functions
+ # like implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ for v in index.values():
+ l.remove(v)
+ g.add_to_collections(collections, result)
+ return result
+
+
+# TODO(jhseu): Stop inheriting from OneDeviceStrategy.
class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""Experimental TPU distribution strategy implementation."""
@@ -82,6 +139,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# TODO(sourabhbajaj): Change this from num_cores to metadata_override
self._num_cores_override = num_cores
+ # TODO(jhseu): Switch to DeviceAssignment to support pods and model
+ # parallelism.
+ device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices)
+ if "device:TPU:" in d.name}
+ self._device_index = values.PerDevice(device_map)
+ self._tpu_devices = sorted(device_map.keys())
+ # Only create variables for the number of towers we're running.
+ self._tpu_devices = self._tpu_devices[:self.num_towers]
+
# TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps.
self.steps_per_run = steps_per_run
@@ -239,6 +305,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
return ctx
def _call_for_each_tower(self, fn, *args, **kwargs):
+ # TODO(jhseu): Consider making it so call_for_each_tower implies that we're
+ # in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
kwargs.pop('run_concurrently', None)
with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access
return fn(*args, **kwargs)
@@ -248,7 +316,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# TODO(priyag): Add appopriate call here when eager is supported for TPUs.
raise NotImplementedError('Eager mode not supported in TPUStrategy.')
else:
- return [tpu.initialize_system()]
+ # TODO(jhseu): We need this hack because DistributionStrategies must be
+ # pickleable for copy.deepcopy(). Remove when initialize_system goes away.
+ graph = ops.get_default_graph()
+ tpu_init = graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION)
+ if tpu_init:
+ return tpu_init
+ graph.add_to_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION,
+ tpu.initialize_system())
+ return graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION)
def finalize(self):
if context.executing_eagerly():
@@ -257,21 +333,53 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
else:
return [tpu.shutdown_system()]
+ def _get_devices_from(self, colocate_with=None):
+ # TODO(jhseu): Change this when we support model parallelism.
+ return self._tpu_devices
+
+ def _create_variable(self, next_creator, *args, **kwargs):
+ """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
+ colocate_with = kwargs.pop("colocate_with", None)
+ devices = self._get_devices_from(colocate_with)
+
+ def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring
+ index = {}
+ for i, d in enumerate(devices):
+ with ops.device(d):
+ if i > 0:
+ # Give replicas meaningful distinct names:
+ var0name = index[devices[0]].name.split(":")[0]
+ # We append a / to variable names created on towers with id > 0 to
+ # ensure that we ignore the name scope and instead use the given
+ # name as the absolute name of the variable.
+ kwargs["name"] = "%s/replica_%d/" % (var0name, i)
+ # Initialize replicas with the same value:
+ if context.executing_eagerly():
+ kwargs["initial_value"] = array_ops.identity(
+ index[devices[0]].value())
+ else:
+ def initial_value_fn(device=d):
+ with ops.device(device):
+ return array_ops.identity(index[devices[0]].initial_value)
+ kwargs["initial_value"] = initial_value_fn
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ v = next_creator(*args, **kwargs)
+ assert not isinstance(v, values.TPUMirroredVariable)
+ index[d] = v
+ return index
+
+ return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args,
+ **kwargs)
+
def _reduce(self, aggregation, value, destinations):
- graph = ops.get_default_graph()
- cf_context = graph._get_control_flow_context() # pylint: disable=protected-access
- # If we're inside the ReplicateContext, reduction should be done using
- # CrossReplicaSum while outside we can directly use an add_n op.
- while cf_context:
- if isinstance(cf_context, tpu.TPUReplicateContext):
- if aggregation == vs.VariableAggregation.MEAN:
- # TODO(jhseu): Revisit once we support model-parallelism.
- value *= (1. / self.num_towers)
- elif aggregation != vs.VariableAggregation.SUM:
- raise NotImplementedError(
- 'Currently only support sum & mean in TPUStrategy.')
- return tpu_ops.cross_replica_sum(value)
- cf_context = cf_context.outer_context
+ if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
+ if aggregation == vs.VariableAggregation.MEAN:
+ # TODO(jhseu): Revisit once we support model-parallelism.
+ value *= (1. / self.num_towers)
+ elif aggregation != vs.VariableAggregation.SUM:
+ raise NotImplementedError(
+ "Currently only support sum & mean in TPUStrategy.")
+ return tpu_ops.cross_replica_sum(value)
# Validate that the destination is same as the host device
# Note we don't do this when in replicate context as the reduction is
@@ -290,6 +398,35 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
return output * (1. / len(value))
return output
+ def _update(self, var, fn, *args, **kwargs):
+ # TODO(jhseu): Consider supporting grouped==False.
+ assert isinstance(var, values.TPUMirroredVariable)
+ if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
+ return fn(var, *args, **kwargs)
+
+ # Otherwise, we revert to MirroredStrategy behavior and update each variable
+ # directly.
+ updates = {}
+ for d, v in var._index.items(): # pylint: disable=protected-access
+ name = "update_%d" % self._device_index.get(d)
+ with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
+ # If args and kwargs are not mirrored, the value is returned as is.
+ updates[d] = fn(v,
+ *values.select_device_mirrored(d, args),
+ **values.select_device_mirrored(d, kwargs))
+
+ # Make a single control dependency to keep the variables mirrored. If one
+ # assignment is fetched, then run all assignments.
+ sorted_keys = sorted(updates.keys())
+ update_tuple = control_flow_ops.tuple([updates[d] for d in sorted_keys])
+ for i, d in enumerate(sorted_keys):
+ updates[d] = update_tuple[i]
+ return values.regroup(updates, values.Mirrored)
+
+ def read_var(self, var):
+ assert isinstance(var, values.TPUMirroredVariable)
+ return var.read_value()
+
def _unwrap(self, value):
if isinstance(value, list):
return value
@@ -323,6 +460,14 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def should_save_summary(self):
return True
+ @property
+ def worker_devices(self):
+ return self._tpu_devices
+
+ @property
+ def parameter_devices(self):
+ return self._tpu_devices
+
def get_host_cpu_device(self, host_id):
if self._tpu_cluster_resolver.get_master() in ('', 'local'):
return '/replica:0/task:0/device:CPU:0'
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index fafa6384a1..c18faeb67d 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -22,17 +22,20 @@ from __future__ import division
from __future__ import print_function
import collections
+import contextlib
import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
@@ -453,6 +456,384 @@ ops.register_tensor_conversion_function(MirroredVariable,
_tensor_conversion_mirrored)
+def _enclosing_tpu_context():
+ # pylint: disable=protected-access
+ tpu_context = ops.get_default_graph()._get_control_flow_context()
+ # pylint: enable=protected-access
+ while tpu_context is not None and not isinstance(
+ tpu_context, control_flow_ops.XLAControlFlowContext):
+ tpu_context = tpu_context.outer_context
+ return tpu_context
+
+
+# TODO(jhseu): Deduplicate code. We copy code because we don't want to
+# inherit from DistributedDelegate. DistributedDelegate will not work in a
+# tpu.replicate() because it assumes that you're in a device context where you
+# can operate on a single version of the variable, but a tpu.replicate()
+# operates on all variables and is replicated during a rewrite pass.
+class TPUMirroredVariable(checkpointable.CheckpointableBase):
+ """Holds a map from device to TPU variables whose values are kept in sync."""
+
+ def __init__(self, index, primary_var, aggregation):
+ # Use a weakref to make it easy to map from the contained values
+ # to the container without introducing a reference cycle.
+ for v in six.itervalues(index):
+ v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
+ self._index = {device_util.canonicalize(key): value
+ for key, value in six.iteritems(index)}
+ self._primary_var = primary_var
+ self._common_name = self._primary_var.name.split(":")[0]
+ self._aggregation = aggregation
+ # Needed for GradientTape
+ self._trainable = self._primary_var.trainable
+
+ def _get(self, device=None):
+ """Returns the value for the current device or raises a ValueError."""
+ if device is None:
+ tower_context = distribution_strategy_context.get_tower_context()
+ if tower_context:
+ device = tower_context.device
+ else:
+ device = distribute_lib.get_update_device()
+ if device is None:
+ return self._get_cross_tower()
+ device = device_util.canonicalize(device)
+ try:
+ return self._index[device]
+ except KeyError as e:
+ six.raise_from(
+ ValueError("Device %s not found in %s (current device %s)" %
+ (device, self._index.keys(), device_util.current())), e)
+
+ # pylint: disable=multiple-statements
+ def __add__(self, o): return self.read_value() + o
+ def __radd__(self, o): return o + self.read_value()
+ def __sub__(self, o): return self.read_value() - o
+ def __rsub__(self, o): return o - self.read_value()
+ def __mul__(self, o): return self.read_value() * o
+ def __rmul__(self, o): return o * self.read_value()
+ def __truediv__(self, o): return self.read_value() / o
+ def __rtruediv__(self, o): return o / self.read_value()
+ def __floordiv__(self, o): return self.read_value() // o
+ def __rfloordiv__(self, o): return o // self.read_value()
+ def __mod__(self, o): return self.read_value() % o
+ def __rmod__(self, o): return o % self.read_value()
+ def __lt__(self, o): return self.read_value() < o
+ def __le__(self, o): return self.read_value() <= o
+ def __gt__(self, o): return self.read_value() > o
+ def __ge__(self, o): return self.read_value() >= o
+ def __and__(self, o): return self.read_value() & o
+ def __rand__(self, o): return o & self.read_value()
+ def __or__(self, o): return self.read_value() | o
+ def __ror__(self, o): return o | self.read_value()
+ def __xor__(self, o): return self.read_value() ^ o
+ def __rxor__(self, o): return o ^ self.read_value()
+ def __getitem__(self, o): return self.read_value()[o]
+ def __pow__(self, o, modulo=None): return pow(self.read_value(), o, modulo)
+ def __rpow__(self, o): return pow(o, self.read_value())
+ def __invert__(self): return ~self.read_value()
+ def __neg__(self): return -self.read_value()
+ def __abs__(self): return abs(self.read_value())
+
+ def __div__(self, o):
+ try:
+ return self.read_value().__div__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rdiv__(self, o):
+ try:
+ return self.read_value().__rdiv__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __matmul__(self, o):
+ try:
+ return self.read_value().__matmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rmatmul__(self, o):
+ try:
+ return self.read_value().__rmatmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ @property
+ def handle(self):
+ # If we're in a tpu.rewrite(), return the replicated handle.
+ tpu_context = _enclosing_tpu_context()
+ if tpu_context is not None:
+ return tpu_context.get_replicated_var_handle(
+ self._common_name, nest.flatten(self._index))
+
+ device = distribute_lib.get_update_device()
+ if device is None:
+ return self._primary_var.handle
+ device = device_util.canonicalize(device)
+ try:
+ return self._index[device].handle
+ except KeyError as e:
+ six.raise_from(
+ ValueError("Device %s not found in %s (current device %s)" %
+ (device, self._index.keys(), device_util.current())), e)
+
+ # The arguments to update() are automatically unwrapped so the update()
+ # function would normally see regular variables, not MirroredVariables.
+ # However, the update function can still operate on wrapped MirroredVariables
+ # through object members, captured arguments, etc. This is more likely in an
+ # update_non_slot() function (like OptimizerV2._finish), which can
+ # update several non-slot variables in one call.
+ def _assign_func(self, *args, **kwargs):
+ if distribution_strategy_context.get_distribution_strategy().__class__.__name__ != "TPUStrategy":
+ raise ValueError("You may only assign to a TPUMirroredVariable within a "
+ "TPUStrategy.")
+ f = kwargs.pop("f")
+ if distribution_strategy_context.get_cross_tower_context():
+ if _enclosing_tpu_context() is not None:
+ return distribution_strategy_context.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+
+ update_device = distribute_lib.get_update_device()
+ # We are calling update on the mirrored variable in cross tower context.
+ if update_device is not None:
+ # We are calling an assign function on the mirrored variable in cross
+ # tower context.
+ v = self._get(device=update_device)
+ return f(v, *args, **kwargs)
+
+ return distribution_strategy_context.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+ else:
+ _assert_tower_context()
+ # We are calling an assign function on the mirrored variable in tower
+ # context.
+ # We reduce the value we want to assign/add/sub. More details about how we
+ # handle the different use cases can be found in the _reduce method.
+ # We call the function on each of the mirrored variables with the reduced
+ # value.
+ if self._aggregation == vs.VariableAggregation.NONE:
+ raise ValueError("You must specify an aggregation method to update a "
+ "TPUMirroredVariable in Tower Context.")
+
+ def merge_fn(strategy, value, *other_args, **other_kwargs):
+ return strategy.update(
+ self, f,
+ strategy.reduce(
+ aggregation=self._aggregation, value=value, destinations=self),
+ *other_args, **other_kwargs)
+
+ return distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, *args, **kwargs)
+
+ @contextlib.contextmanager
+ def _handle_graph(self, handle):
+ # Note: might have an eager tensor but not be executing eagerly when
+ # building functions.
+ if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor)
+ or ops.has_default_graph()):
+ yield
+ else:
+ with handle.graph.as_default():
+ yield
+
+ @property
+ def trainable(self):
+ return self._trainable
+
+ def _read_variable_op(self, parent_op=None):
+ if self.trainable:
+ tape.variable_accessed(self)
+ if parent_op is not None:
+ with ops.control_dependencies([parent_op]):
+ return gen_resource_variable_ops.read_variable_op(
+ self.handle, self.dtype)
+
+ return gen_resource_variable_ops.read_variable_op(
+ self.handle, self.dtype)
+
+ def read_value(self):
+ return self._read_variable_op()
+
+ def assign_sub(self, *args, **kwargs):
+ def assign_sub_fn(var, delta, **kw):
+ name = kw.pop("name", None)
+ read_value = kw.pop("read_value", True)
+ with self._handle_graph(var.handle):
+ op = gen_resource_variable_ops.assign_sub_variable_op(
+ var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op(parent_op=op)
+ return op
+
+ return self._assign_func(f=assign_sub_fn, *args, **kwargs)
+
+ def assign_add(self, *args, **kwargs):
+ def assign_add_fn(var, delta, **kw):
+ name = kw.pop("name", None)
+ read_value = kw.pop("read_value", True)
+ with self._handle_graph(var.handle):
+ op = gen_resource_variable_ops.assign_add_variable_op(
+ var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op(parent_op=op)
+ return op
+
+ return self._assign_func(f=assign_add_fn, *args, **kwargs)
+
+ def assign(self, *args, **kwargs):
+ def assign_fn(var, value, **kw):
+ name = kw.pop("name", None)
+ read_value = kw.pop("read_value", True)
+ with self._handle_graph(var.handle):
+ op = gen_resource_variable_ops.assign_variable_op(
+ var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op(parent_op=op)
+ return op
+
+ return self._assign_func(f=assign_fn, *args, **kwargs)
+
+ @property
+ def aggregation(self):
+ return self._aggregation
+
+ @property
+ def constraint(self):
+ return None
+
+ @property
+ def initializer(self):
+ return control_flow_ops.group(
+ [v.initializer for v in nest.flatten(self._index)])
+
+ @property
+ def graph(self):
+ return self._primary_var.graph
+
+ @property
+ def _shared_name(self):
+ return self._common_name
+
+ @property
+ def _unique_id(self):
+ return self._primary_var._unique_id # pylint: disable=protected-access
+
+ @property
+ def name(self):
+ return self._primary_var.name
+
+ @property
+ def dtype(self):
+ return self._primary_var.dtype
+
+ @property
+ def shape(self):
+ return self._primary_var.shape
+
+ def get_shape(self):
+ return self._primary_var.get_shape()
+
+ def to_proto(self, export_scope=None):
+ return self._primary_var.to_proto(export_scope=export_scope)
+
+ def _get_cross_tower(self):
+ device = device_util.canonicalize(device_util.current())
+ if device in self._index:
+ return self._index[device]
+ return self._primary_var
+
+ def _as_graph_element(self):
+ # pylint: disable=protected-access
+ if distribution_strategy_context.get_cross_tower_context():
+ return self._primary_var._as_graph_element()
+ return self._read_variable_op()
+
+ def _gather_saveables_for_checkpoint(self):
+ """Overrides CheckpointableBase method.
+
+ This allows both name-based and object-based save and restore of
+ MirroredVariables.
+
+ Returns:
+ A dictionary mapping attribute names to `SaveableObject` factories.
+ """
+ def _saveable_factory(name=self._common_name):
+ return _MirroredSaveable(self, self._primary_var, name)
+ return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+
+ def _should_act_as_resource_variable(self):
+ """Pass resource_variable_ops.is_resource_variable check."""
+ pass
+
+ # Needed to pass ResourceVariable checks.
+ @property
+ def op(self):
+ return self._primary_var.op
+
+ @property
+ def _in_graph_mode(self):
+ return self._primary_var._in_graph_mode # pylint: disable=protected-access
+
+ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+ """Converts a variable to a tensor."""
+ # pylint: disable=protected-access
+ if _enclosing_tpu_context() is None:
+ return self._get()._dense_var_to_tensor(dtype, name, as_ref)
+ # pylint: enable=protected-access
+ if dtype is not None and dtype != self.dtype:
+ raise NotImplementedError
+ if as_ref:
+ return self.handle
+ else:
+ return self.read_value()
+
+ def is_initialized(self, name=None):
+ """Identifies if all the component variables are initialized.
+
+ Args:
+ name: Name of the final `logical_and` op.
+
+ Returns:
+ The op that evaluates to True or False depending on if all the
+ component variables are initialized.
+ """
+ # TODO(jhseu): Do we need TPU context implementation?
+
+ # We have to cast the self._index.values() to a `list` because when we
+ # use `model_to_estimator` to run tf.keras models, self._index.values() is
+ # of type `dict_values` and not `list`.
+ values_list = nest.flatten(self._index)
+ result = values_list[0].is_initialized()
+ # We iterate through the list of values except the last one to allow us to
+ # name the final `logical_and` op the same name that is passed by the user
+ # to the `is_initialized` op. For distributed variables, the
+ # `is_initialized` op is a `logical_and` op.
+ for v in values_list[1:-1]:
+ result = math_ops.logical_and(result, v.is_initialized())
+ result = math_ops.logical_and(result, values_list[-1].is_initialized(),
+ name=name)
+ return result
+
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion_tpu_mirrored(var, dtype=None, name=None, as_ref=False):
+ return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
+
+
+ops.register_tensor_conversion_function(TPUMirroredVariable,
+ _tensor_conversion_tpu_mirrored)
+ops.register_dense_tensor_like_type(TPUMirroredVariable)
+
+
class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
"""Class for defining how to restore a TowerLocalVariable."""
@@ -726,14 +1107,14 @@ class PerDeviceDataset(object):
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
dataset_iterator = self._dataset.make_one_shot_iterator()
- return PerDeviceDataIterator(
- dataset_iterator, self._devices, self._prefetch_on_device)
+ return PerDeviceDataIterator(dataset_iterator, self._devices,
+ self._prefetch_on_device)
def make_initializable_iterator(self):
"""Get an initializable iterator for the distributed PerDeviceDataset."""
dataset_iterator = self._dataset.make_initializable_iterator()
- return PerDeviceDataIterator(
- dataset_iterator, self._devices, self._prefetch_on_device)
+ return PerDeviceDataIterator(dataset_iterator, self._devices,
+ self._prefetch_on_device)
class MultiWorkerDataIterator(object):
@@ -793,7 +1174,8 @@ class MultiWorkerDataset(object):
eager mode.
"""
- def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None):
+ def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None,
+ auto_shard=False):
"""Initialize the MultiWorkerDataset object.
Args:
@@ -801,6 +1183,7 @@ class MultiWorkerDataset(object):
worker_device_map: a dict mapping from each worker to a list of devices
that belong to this worker.
prefetch_on_device: whether to prefetch to devices.
+ auto_shard: whether to auto-shard the dataset.
"""
self._worker_device_map = worker_device_map
self._datasets = {}
@@ -810,8 +1193,9 @@ class MultiWorkerDataset(object):
six.iteritems(worker_device_map)):
with ops.device(worker):
worker_input = dataset_fn()
- worker_input = input_ops.auto_shard_dataset(
- worker_input, len(worker_device_map), i)
+ if auto_shard:
+ worker_input = input_ops.auto_shard_dataset(
+ worker_input, len(worker_device_map), i)
self._datasets[worker] = PerDeviceDataset(
worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 15a85a28f5..ae3e134333 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -375,8 +375,9 @@ class PerDeviceDatasetTest(test.TestCase):
combined_expected = []
for expected_value in expected_values:
next_element = iterator.get_next()
- combined_actual.extend(self.evaluate([
- values.select_device(d, next_element) for d in devices]))
+ combined_actual.extend(
+ self.evaluate(
+ [values.select_device(d, next_element) for d in devices]))
combined_expected.extend(expected_value)
self.assertEqual(set(combined_expected), set(combined_actual))
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py b/tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py
index 3c988dad8a..be7c756bea 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/moving_stats_test.py
@@ -38,8 +38,8 @@ class MovingReduceMeanVarianceTest(test.TestCase):
true_stddev = np.array([[1.1, 0.5]])
with self.cached_session() as sess:
# Start "x" out with this mean.
- mean_var = variables.Variable(array_ops.zeros_like(true_mean))
- variance_var = variables.Variable(array_ops.ones_like(true_stddev))
+ mean_var = variables.VariableV1(array_ops.zeros_like(true_mean))
+ variance_var = variables.VariableV1(array_ops.ones_like(true_stddev))
x = random_ops.random_normal(shape, dtype=np.float64, seed=0)
x = true_stddev * x + true_mean
ema, emv = moving_stats.assign_moving_mean_variance(
@@ -115,7 +115,7 @@ class MovingLogExponentialMovingMeanExpTest(test.TestCase):
# Start "x" out with this mean.
x = random_ops.random_normal(shape, dtype=np.float64, seed=0)
x = true_stddev * x + true_mean
- log_mean_exp_var = variables.Variable(array_ops.zeros_like(true_mean))
+ log_mean_exp_var = variables.VariableV1(array_ops.zeros_like(true_mean))
variables.global_variables_initializer().run()
log_mean_exp = moving_stats.assign_log_moving_mean_exp(
log_mean_exp_var, x, decay=decay)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD
deleted file mode 100644
index 42ecea034d..0000000000
--- a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD
+++ /dev/null
@@ -1,51 +0,0 @@
-# Description:
-# Internal testing utilities, e.g., computing the correct answer to
-# put in a unit test.
-
-licenses(["notice"]) # Apache 2.0
-
-py_library(
- name = "correlation_matrix_volumes_py",
- srcs = [
- "correlation_matrix_volumes_lib.py",
- ],
- deps = [
- "//tensorflow/contrib/distributions:distributions_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:math_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_binary(
- name = "correlation_matrix_volumes",
- srcs = [
- "correlation_matrix_volumes.py",
- ],
- deps = [
- ":correlation_matrix_volumes_py",
- ],
-)
-
-py_test(
- name = "correlation_matrix_volumes_test",
- size = "medium",
- srcs = ["correlation_matrix_volumes_test.py"],
- tags = [
- "no_pip",
- "optonly",
- ],
- deps = [
- ":correlation_matrix_volumes_py",
- # For statistical testing
- "//tensorflow/contrib/distributions:distributions_py",
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
- ],
-)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py
deleted file mode 100644
index 2eab51cd30..0000000000
--- a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Executable to estimate the volume of various sets of correlation matrices.
-
-See correlation_matrix_volumes_lib.py for purpose and methodology.
-
-Invocation example:
-```
-python correlation_matrix_volumes.py --num_samples 1e7
-```
-
-This will compute 10,000,000-sample confidence intervals for the
-volumes of several sets of correlation matrices. Which sets, and the
-desired statistical significance, are hard-coded in this source file.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import pprint
-
-from absl import app
-from absl import flags
-
-from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr
-
-FLAGS = flags.FLAGS
-
-# Float to support giving the number of samples in scientific notation.
-# The production run used for the LKJ test used 1e7 samples.
-flags.DEFINE_float('num_samples', 1e4, 'Number of samples to use.')
-
-
-def ctv_debatched(det_bounds, dim, num_samples, error_rate=1e-6, seed=42):
- # This wrapper undoes the batching in compute_true_volumes, because
- # apparently several 5x5x9x1e7 Tensors of float32 can strain RAM.
- bounds = {}
- for db in det_bounds:
- bounds[db] = corr.compute_true_volumes(
- [db], dim, num_samples, error_rate=error_rate, seed=seed)[db]
- return bounds
-
-
-# The particular bounds in all three of these functions were chosen by
-# a somewhat arbitrary walk through an empirical tradeoff, for the
-# purpose of testing the LKJ distribution. Setting the determinant
-# bound lower
-# - Covers more of the testee's sample space, and
-# - Increases the probability that the rejection sampler will hit, thus
-# - Decreases the relative error (at a fixed sample count) in the
-# rejection-based volume estimate;
-# but also
-# - Increases the variance of the estimator used in the LKJ test.
-# This latter variance is also affected by the dimension and the
-# tested concentration parameter, and can be compensated for with more
-# compute (expensive) or a looser discrepancy limit (unsatisfying).
-# The values here are the projection of the points in that test design
-# space that ended up getting chosen.
-def compute_3x3_volumes(num_samples):
- det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45]
- return ctv_debatched(
- det_bounds, 3, num_samples, error_rate=5e-7, seed=46)
-
-
-def compute_4x4_volumes(num_samples):
- det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45]
- return ctv_debatched(
- det_bounds, 4, num_samples, error_rate=5e-7, seed=47)
-
-
-def compute_5x5_volumes(num_samples):
- det_bounds = [0.01, 0.2, 0.25, 0.3, 0.35, 0.4]
- return ctv_debatched(
- det_bounds, 5, num_samples, error_rate=5e-7, seed=48)
-
-
-def main(_):
- full_bounds = {}
- full_bounds[3] = compute_3x3_volumes(int(FLAGS.num_samples))
- full_bounds[4] = compute_4x4_volumes(int(FLAGS.num_samples))
- full_bounds[5] = compute_5x5_volumes(int(FLAGS.num_samples))
- pprint.pprint(full_bounds)
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py
deleted file mode 100644
index 455e71f00c..0000000000
--- a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py
+++ /dev/null
@@ -1,323 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Estimating the volume of the correlation matrices with bounded determinant.
-
-Why? Because lkj_test.py tests the sampler for the LKJ distribution
-by estimating the same volume another way.
-
-How? Rejection sampling. Or, more precisely, importance sampling,
-proposing from the uniform distribution on symmetric matrices with
-diagonal 1s and entries in [-1, 1]. Such a matrix is a correlation
-matrix if and only if it is also positive semi-definite.
-
-The samples can then be converted into a confidence interval on the
-volume in question by the [Clopper-Pearson
-method](https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval),
-also implemented here.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import importlib
-import sys
-
-import numpy as np
-
-from tensorflow.python.client import session
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.distributions import uniform
-from tensorflow.python.ops.distributions import util
-from tensorflow.python.platform import tf_logging
-
-__all__ = [
- "correlation_matrix_volume_rejection_samples",
- "compute_true_volumes",
-]
-
-
-def try_import(name): # pylint: disable=invalid-name
- module = None
- try:
- module = importlib.import_module(name)
- except ImportError as e:
- tf_logging.warning("Could not import %s: %s" % (name, str(e)))
- return module
-
-optimize = try_import("scipy.optimize")
-stats = try_import("scipy.stats")
-
-
-def _psd_mask(x):
- """Computes whether each square matrix in the input is positive semi-definite.
-
- Args:
- x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`.
-
- Returns:
- mask: A floating-point `Tensor` of shape `[B1, ... Bn]`. Each
- scalar is 1 if the corresponding matrix was PSD, otherwise 0.
- """
- # Allegedly
- # https://scicomp.stackexchange.com/questions/12979/testing-if-a-matrix-is-positive-semi-definite
- # it is more efficient to test for positive semi-definiteness by
- # trying to compute the Cholesky decomposition -- the matrix is PSD
- # if you succeed and not PSD if you fail. However, TensorFlow's
- # Cholesky raises an exception if _any_ of the input matrices are
- # not PSD, from which I don't know how to extract _which ones_, so I
- # proceed by explicitly computing all the eigenvalues and checking
- # whether they are all positive or not.
- #
- # Also, as was discussed in the answer, it is somewhat dangerous to
- # treat SPD-ness as binary in floating-point arithmetic. Cholesky
- # factorization can complete and 'look' like everything is fine
- # (e.g., O(1) entries and a diagonal of all ones) but the matrix can
- # have an exponential condition number.
- eigenvalues, _ = linalg_ops.self_adjoint_eig(x)
- return math_ops.cast(
- math_ops.reduce_min(eigenvalues, axis=-1) >= 0, dtype=x.dtype)
-
-
-def _det_large_enough_mask(x, det_bounds):
- """Returns whether the input matches the given determinant limit.
-
- Args:
- x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`.
- det_bounds: A floating-point `Tensor` that must broadcast to shape
- `[B1, ..., Bn]`, giving the desired lower bound on the
- determinants in `x`.
-
- Returns:
- mask: A floating-point `Tensor` of shape [B1, ..., Bn]. Each
- scalar is 1 if the corresponding matrix had determinant above
- the corresponding bound, otherwise 0.
- """
- # For the curious: I wonder whether it is possible and desirable to
- # use a Cholesky decomposition-based algorithm for this, since the
- # only matrices whose determinant this code cares about will be PSD.
- # Didn't figure out how to code that in TensorFlow.
- #
- # Expert opinion is that it would be about twice as fast since
- # Cholesky is roughly half the cost of Gaussian Elimination with
- # Partial Pivoting. But this is less of an impact than the switch in
- # _psd_mask.
- return math_ops.cast(
- linalg_ops.matrix_determinant(x) > det_bounds, dtype=x.dtype)
-
-
-def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed):
- """Returns a uniformly random `Tensor` of "correlation-like" matrices.
-
- A "correlation-like" matrix is a symmetric square matrix with all entries
- between -1 and 1 (inclusive) and 1s on the main diagonal. Of these,
- the ones that are positive semi-definite are exactly the correlation
- matrices.
-
- Args:
- num_rows: Python `int` dimension of the correlation-like matrices.
- batch_shape: `Tensor` or Python `tuple` of `int` shape of the
- batch to return.
- dtype: `dtype` of the `Tensor` to return.
- seed: Random seed.
-
- Returns:
- matrices: A `Tensor` of shape `batch_shape + [num_rows, num_rows]`
- and dtype `dtype`. Each entry is in [-1, 1], and each matrix
- along the bottom two dimensions is symmetric and has 1s on the
- main diagonal.
- """
- num_entries = num_rows * (num_rows + 1) / 2
- ones = array_ops.ones(shape=[num_entries], dtype=dtype)
- # It seems wasteful to generate random values for the diagonal since
- # I am going to throw them away, but `fill_triangular` fills the
- # diagonal, so I probably need them.
- # It's not impossible that it would be more efficient to just fill
- # the whole matrix with random values instead of messing with
- # `fill_triangular`. Then would need to filter almost half out with
- # `matrix_band_part`.
- unifs = uniform.Uniform(-ones, ones).sample(batch_shape, seed=seed)
- tril = util.fill_triangular(unifs)
- symmetric = tril + array_ops.matrix_transpose(tril)
- diagonal_ones = array_ops.ones(
- shape=util.pad(batch_shape, axis=0, back=True, value=num_rows),
- dtype=dtype)
- return array_ops.matrix_set_diag(symmetric, diagonal_ones)
-
-
-def correlation_matrix_volume_rejection_samples(
- det_bounds, dim, sample_shape, dtype, seed):
- """Returns rejection samples from trying to get good correlation matrices.
-
- The proposal being rejected from is the uniform distribution on
- "correlation-like" matrices. We say a matrix is "correlation-like"
- if it is a symmetric square matrix with all entries between -1 and 1
- (inclusive) and 1s on the main diagonal. Of these, the ones that
- are positive semi-definite are exactly the correlation matrices.
-
- The rejection algorithm, then, is to sample a `Tensor` of
- `sample_shape` correlation-like matrices of dimensions `dim` by
- `dim`, and check each one for (i) being a correlation matrix (i.e.,
- PSD), and (ii) having determinant at least the corresponding entry
- of `det_bounds`.
-
- Args:
- det_bounds: A `Tensor` of lower bounds on the determinants of
- acceptable matrices. The shape must broadcast with `sample_shape`.
- dim: A Python `int` dimension of correlation matrices to sample.
- sample_shape: Python `tuple` of `int` shape of the samples to
- compute, excluding the two matrix dimensions.
- dtype: The `dtype` in which to do the computation.
- seed: Random seed.
-
- Returns:
- weights: A `Tensor` of shape `sample_shape`. Each entry is 0 if the
- corresponding matrix was not a correlation matrix, or had too
- small of a determinant. Otherwise, the entry is the
- multiplicative inverse of the density of proposing that matrix
- uniformly, i.e., the volume of the set of `dim` by `dim`
- correlation-like matrices.
- volume: The volume of the set of `dim` by `dim` correlation-like
- matrices.
- """
- with ops.name_scope("rejection_sampler"):
- rej_proposals = _uniform_correlation_like_matrix(
- dim, sample_shape, dtype, seed=seed)
- rej_proposal_volume = 2. ** (dim * (dim - 1) / 2.)
- # The density of proposing any given point is 1 / rej_proposal_volume;
- # The weight of that point should be scaled by
- # 1 / density = rej_proposal_volume.
- rej_weights = rej_proposal_volume * _psd_mask(
- rej_proposals) * _det_large_enough_mask(rej_proposals, det_bounds)
- return rej_weights, rej_proposal_volume
-
-
-def _clopper_pearson_confidence_interval(samples, error_rate):
- """Computes a confidence interval for the mean of the given 1-D distribution.
-
- Assumes (and checks) that the given distribution is Bernoulli, i.e.,
- takes only two values. This licenses using the CDF of the binomial
- distribution for the confidence, which is tighter (for extreme
- probabilities) than the DKWM inequality. The method is known as the
- [Clopper-Pearson method]
- (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval).
-
- Assumes:
-
- - The given samples were drawn iid from the distribution of interest.
-
- - The given distribution is a Bernoulli, i.e., supported only on
- low and high.
-
- Guarantees:
-
- - The probability (over the randomness of drawing the given sample)
- that the true mean is outside the returned interval is no more
- than the given error_rate.
-
- Args:
- samples: `np.ndarray` of samples drawn iid from the distribution
- of interest.
- error_rate: Python `float` admissible rate of mistakes.
-
- Returns:
- low: Lower bound of confidence interval.
- high: Upper bound of confidence interval.
-
- Raises:
- ValueError: If `samples` has rank other than 1 (batch semantics
- are not implemented), or if `samples` contains values other than
- `low` or `high` (as that makes the distribution not Bernoulli).
- """
- # TODO(b/78025336) Migrate this confidence interval function
- # to statistical_testing.py. In order to do that
- # - Get the binomial CDF from the Binomial distribution
- # - Implement scalar root finding in TF. Batch bisection search
- # shouldn't be too hard, and is definitely good enough for this
- # problem. Batching the Brent algorithm (from scipy) that is used
- # here may be more involved, but may also not be necessary---it's
- # only used here because scipy made it convenient. In particular,
- # robustness is more important than speed here, which may make
- # bisection search actively better.
- # - The rest is just a matter of rewriting in the appropriate style.
- if optimize is None or stats is None:
- raise ValueError(
- "Scipy is required for computing Clopper-Pearson confidence intervals")
- if len(samples.shape) != 1:
- raise ValueError("Batch semantics not implemented")
- n = len(samples)
- low = np.amin(samples)
- high = np.amax(samples)
- successes = np.count_nonzero(samples - low)
- failures = np.count_nonzero(samples - high)
- if successes + failures != n:
- uniques = np.unique(samples)
- msg = ("Purportedly Bernoulli distribution had distinct samples"
- " {}, {}, and {}".format(uniques[0], uniques[1], uniques[2]))
- raise ValueError(msg)
- def p_small_enough(p):
- prob = stats.binom.logcdf(successes, n, p)
- return prob - np.log(error_rate / 2.)
- def p_big_enough(p):
- prob = stats.binom.logsf(successes, n, p)
- return prob - np.log(error_rate / 2.)
- high_p = optimize.brentq(
- p_small_enough, float(successes) / n, 1., rtol=1e-9)
- low_p = optimize.brentq(
- p_big_enough, 0., float(successes) / n, rtol=1e-9)
- low_interval = low + (high - low) * low_p
- high_interval = low + (high - low) * high_p
- return (low_interval, high_interval)
-
-
-def compute_true_volumes(
- det_bounds, dim, num_samples, error_rate=1e-6, seed=42):
- """Returns confidence intervals for the desired correlation matrix volumes.
-
- The confidence intervals are computed by the [Clopper-Pearson method]
- (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval).
-
- Args:
- det_bounds: A rank-1 numpy array of lower bounds on the
- determinants of acceptable matrices. Entries must be unique.
- dim: A Python `int` dimension of correlation matrices to sample.
- num_samples: The number of samples to draw.
- error_rate: The statistical significance of the returned
- confidence intervals. The significance is broadcast: Each
- returned interval separately may be incorrect with probability
- (under the sample of correlation-like matrices drawn internally)
- at most `error_rate`.
- seed: Random seed.
-
- Returns:
- bounds: A Python `dict` mapping each determinant bound to the low, high
- tuple giving the confidence interval.
- """
- bounds = {}
- with session.Session() as sess:
- rej_weights, _ = correlation_matrix_volume_rejection_samples(
- det_bounds, dim, [num_samples, len(det_bounds)], np.float32, seed=seed)
- rej_weights = sess.run(rej_weights)
- for rw, det in zip(np.rollaxis(rej_weights, 1), det_bounds):
- template = ("Estimating volume of {}x{} correlation "
- "matrices with determinant >= {}.")
- print(template.format(dim, dim, det))
- sys.stdout.flush()
- bounds[det] = _clopper_pearson_confidence_interval(
- rw, error_rate=error_rate)
- return bounds
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py
deleted file mode 100644
index 8f99300e63..0000000000
--- a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for correlation_matrix_volumes_lib.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr
-from tensorflow.contrib.distributions.python.ops import statistical_testing as st
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.platform import test
-
-
-# NxN correlation matrices are determined by the N*(N-1)/2
-# lower-triangular entries. In addition to being between -1 and 1,
-# they must also obey the constraint that the determinant of the
-# resulting symmetric matrix is non-negative. In 2x2, we can even
-# analytically compute the volume when the determinant is bounded to >
-# epsilon, as that boils down to the one lower-triangular entry being
-# less than 1 - epsilon in absolute value.
-def two_by_two_volume(det_bound):
- return 2 * np.sqrt(1.0 - det_bound)
-
-
-# The post
-# https://psychometroscar.com/the-volume-of-a-3-x-3-correlation-matrix/
-# derives (with elementary calculus) that the volume (with respect to
-# Lebesgue^3 measure) of the set of 3x3 correlation matrices is
-# pi^2/2. The same result is also obtained by [1].
-def three_by_three_volume():
- return np.pi**2 / 2.
-
-
-# The volume of the unconstrained set of correlation matrices is also
-# the normalization constant of the LKJ distribution from [2]. As
-# part of defining the distribution, that reference a derives general
-# formula for this volume for all dimensions. A TensorFlow
-# computation thereof gave the below result for 4x4:
-def four_by_four_volume():
- # This constant computed as math_ops.exp(lkj.log_norm_const(4, [1.0]))
- return 11.6973076
-
-# [1] Rousseeuw, P. J., & Molenberghs, G. (1994). "The shape of
-# correlation matrices." The American Statistician, 48(4), 276-279.
-
-# [2] Daniel Lewandowski, Dorota Kurowicka, and Harry Joe, "Generating
-# random correlation matrices based on vines and extended onion
-# method," Journal of Multivariate Analysis 100 (2009), pp 1989-2001.
-
-
-class CorrelationMatrixVolumesTest(test.TestCase):
-
- def testRejection2D(self):
- num_samples = int(1e5) # Chosen for a small min detectable discrepancy
- det_bounds = np.array(
- [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32)
- exact_volumes = two_by_two_volume(det_bounds)
- (rej_weights,
- rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
- det_bounds, 2, [num_samples, 9], dtype=np.float32, seed=43)
- # shape of rej_weights: [num_samples, 9, 2, 2]
- chk1 = st.assert_true_mean_equal_by_dkwm(
- rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
- false_fail_rate=1e-6)
- chk2 = check_ops.assert_less(
- st.min_discrepancy_of_true_means_detectable_by_dkwm(
- num_samples, low=0., high=rej_proposal_volume,
- # Correct the false fail rate due to different broadcasting
- false_fail_rate=1.1e-7, false_pass_rate=1e-6),
- 0.036)
- with ops.control_dependencies([chk1, chk2]):
- rej_weights = array_ops.identity(rej_weights)
- self.evaluate(rej_weights)
-
- def testRejection3D(self):
- num_samples = int(1e5) # Chosen for a small min detectable discrepancy
- det_bounds = np.array([0.0], dtype=np.float32)
- exact_volumes = np.array([three_by_three_volume()], dtype=np.float32)
- (rej_weights,
- rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
- det_bounds, 3, [num_samples, 1], dtype=np.float32, seed=44)
- # shape of rej_weights: [num_samples, 1, 3, 3]
- chk1 = st.assert_true_mean_equal_by_dkwm(
- rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
- false_fail_rate=1e-6)
- chk2 = check_ops.assert_less(
- st.min_discrepancy_of_true_means_detectable_by_dkwm(
- num_samples, low=0., high=rej_proposal_volume,
- false_fail_rate=1e-6, false_pass_rate=1e-6),
- # Going for about a 3% relative error
- 0.15)
- with ops.control_dependencies([chk1, chk2]):
- rej_weights = array_ops.identity(rej_weights)
- self.evaluate(rej_weights)
-
- def testRejection4D(self):
- num_samples = int(1e5) # Chosen for a small min detectable discrepancy
- det_bounds = np.array([0.0], dtype=np.float32)
- exact_volumes = [four_by_four_volume()]
- (rej_weights,
- rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
- det_bounds, 4, [num_samples, 1], dtype=np.float32, seed=45)
- # shape of rej_weights: [num_samples, 1, 4, 4]
- chk1 = st.assert_true_mean_equal_by_dkwm(
- rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
- false_fail_rate=1e-6)
- chk2 = check_ops.assert_less(
- st.min_discrepancy_of_true_means_detectable_by_dkwm(
- num_samples, low=0., high=rej_proposal_volume,
- false_fail_rate=1e-6, false_pass_rate=1e-6),
- # Going for about a 10% relative error
- 1.1)
- with ops.control_dependencies([chk1, chk2]):
- rej_weights = array_ops.identity(rej_weights)
- self.evaluate(rej_weights)
-
- def testVolumeEstimation2D(self):
- # Test that the confidence intervals produced by
- # corr.compte_true_volumes are sound, in the sense of containing
- # the exact volume.
- num_samples = int(1e5) # Chosen by symmetry with testRejection2D
- det_bounds = np.array(
- [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32)
- volume_bounds = corr.compute_true_volumes(
- det_bounds, 2, num_samples, error_rate=1e-6, seed=47)
- exact_volumes = two_by_two_volume(det_bounds)
- for det, volume in zip(det_bounds, exact_volumes):
- computed_low, computed_high = volume_bounds[det]
- self.assertLess(computed_low, volume)
- self.assertGreater(computed_high, volume)
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index 6f02c90368..97c299a911 100644
--- a/tensorflow/contrib/eager/python/examples/BUILD
+++ b/tensorflow/contrib/eager/python/examples/BUILD
@@ -6,6 +6,7 @@ package(default_visibility = ["//tensorflow:internal"])
py_library(
name = "examples_pip",
deps = [
+ "//tensorflow/contrib/eager/python/examples/densenet",
"//tensorflow/contrib/eager/python/examples/gan:mnist",
"//tensorflow/contrib/eager/python/examples/l2hmc",
"//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets",
diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD
index c61ec2dbae..d64c8eb9ce 100644
--- a/tensorflow/contrib/eager/python/examples/gan/BUILD
+++ b/tensorflow/contrib/eager/python/examples/gan/BUILD
@@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
py_binary(
name = "mnist",
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
index c38a1597b8..1c925e455b 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
@@ -45,6 +45,17 @@ def step(dynamics, optimizer, samples):
return loss, samples
+# To be defunnable, the function cannot return an Operation, so the above
+# function is used for defun or eager, and this function is used in graph to be
+# able to run the gradient updates.
+def graph_step(dynamics, optimizer, samples):
+ loss, grads, samples, _ = l2hmc.loss_and_grads(
+ dynamics, samples, loss_fn=l2hmc.compute_loss)
+ train_op = optimizer.apply_gradients(zip(grads, dynamics.variables))
+
+ return train_op, loss, samples
+
+
def warmup(dynamics,
optimizer,
n_iters=1,
@@ -134,51 +145,48 @@ class L2hmcBenchmark(tf.test.Benchmark):
"""Benchmark Graph performance."""
hparams = get_default_hparams()
- tf.reset_default_graph()
- with tf.Graph().as_default():
- energy_fn, _, _ = l2hmc.get_scg_energy_fn()
- dynamics = l2hmc.Dynamics(
- x_dim=hparams.x_dim,
- minus_loglikelihood_fn=energy_fn,
- n_steps=hparams.n_steps,
- eps=hparams.eps)
- x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim])
- loss, x_out, _ = l2hmc.compute_loss(dynamics, x)
-
- global_step = tf.Variable(0., name="global_step", trainable=False)
- learning_rate = tf.train.exponential_decay(
- hparams.learning_rate, global_step, 1000, 0.96, staircase=True)
- optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
- train_op = optimizer.minimize(loss, global_step=global_step)
-
- # Single thread; fairer comparison against eager
- session_conf = tf.ConfigProto(
- intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
-
- with tf.Session(config=session_conf) as sess:
- sess.run(tf.global_variables_initializer())
-
- # Warmup to reduce initialization effect when timing
- samples = npr.normal(size=[hparams.n_samples, hparams.x_dim])
- for _ in range(hparams.n_warmup_iters):
- _, _, _, _ = sess.run(
- [x_out, loss, train_op, learning_rate], feed_dict={x: samples})
-
- # Training
- start_time = time.time()
- for i in range(hparams.n_iters):
- samples, loss_np, _, _ = sess.run(
- [x_out, loss, train_op, learning_rate], feed_dict={x: samples})
- print("Iteration %d: loss %.4f" % (i, loss_np))
- wall_time = time.time() - start_time
- examples_per_sec = hparams.n_samples / wall_time
-
- self.report_benchmark(
- name="graph_train_%s" % ("gpu"
- if tf.test.is_gpu_available() else "cpu"),
- iters=hparams.n_iters,
- extras={"examples_per_sec": examples_per_sec},
- wall_time=wall_time)
+ tf.enable_resource_variables()
+ for sample_size in [10, 25, 50, 100, 200]:
+ hparams.n_samples = sample_size
+ tf.reset_default_graph()
+ with tf.Graph().as_default():
+ energy_fn, _, _ = l2hmc.get_scg_energy_fn()
+ x = tf.random_normal([hparams.n_samples, hparams.x_dim],
+ dtype=tf.float32)
+ dynamics = l2hmc.Dynamics(
+ x_dim=hparams.x_dim,
+ minus_loglikelihood_fn=energy_fn,
+ n_steps=hparams.n_steps,
+ eps=hparams.eps)
+ loss, _, _ = l2hmc.compute_loss(dynamics, x)
+
+ optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
+ train_op, loss, _ = graph_step(dynamics, optimizer, x)
+
+ # Single thread; fairer comparison against eager
+ session_conf = tf.ConfigProto(inter_op_parallelism_threads=1)
+
+ with tf.Session(config=session_conf) as sess:
+ sess.run(tf.global_variables_initializer())
+
+ # Warmup to reduce initialization effect when timing
+ for _ in range(hparams.n_warmup_iters):
+ _, _ = sess.run([train_op, loss])
+
+ # Training
+ start_time = time.time()
+ for i in range(hparams.n_iters):
+ _, loss_np = sess.run([train_op, loss])
+ print("Iteration %d: loss %.4f" % (i, loss_np))
+ wall_time = (time.time() - start_time) / hparams.n_iters
+ examples_per_sec = hparams.n_samples / wall_time
+
+ self.report_benchmark(
+ name="graph_train_%s_%d" %
+ ("gpu" if tf.test.is_gpu_available() else "cpu", sample_size),
+ iters=hparams.n_iters,
+ extras={"examples_per_sec": examples_per_sec},
+ wall_time=wall_time)
def benchmark_eager(self):
self._benchmark_eager()
@@ -190,32 +198,44 @@ class L2hmcBenchmark(tf.test.Benchmark):
"""Benchmark Eager performance."""
hparams = get_default_hparams()
- energy_fn, _, _ = l2hmc.get_scg_energy_fn()
- dynamics = l2hmc.Dynamics(
- x_dim=hparams.x_dim,
- minus_loglikelihood_fn=energy_fn,
- n_steps=hparams.n_steps,
- eps=hparams.eps)
- optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
- step_fn = tfe.defun(step) if defun else step
-
- # Warmup to reduce initialization effect when timing
- warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, step_fn=step_fn)
-
- # Training
- samples = tf.random_normal(
- shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32)
- start_time = time.time()
- fit(dynamics, samples, optimizer, step_fn=step_fn, n_iters=hparams.n_iters)
- wall_time = time.time() - start_time
- examples_per_sec = hparams.n_samples / wall_time
-
- self.report_benchmark(
- name="eager_train_%s%s" % ("gpu" if tf.test.is_gpu_available() else
- "cpu", "_defun" if defun else ""),
- iters=hparams.n_iters,
- extras={"examples_per_sec": examples_per_sec},
- wall_time=wall_time)
+ for sample_size in [10, 25, 50, 100, 200]:
+ hparams.n_samples = sample_size
+ energy_fn, _, _ = l2hmc.get_scg_energy_fn()
+ dynamics = l2hmc.Dynamics(
+ x_dim=hparams.x_dim,
+ minus_loglikelihood_fn=energy_fn,
+ n_steps=hparams.n_steps,
+ eps=hparams.eps)
+ optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
+ step_fn = tfe.defun(step) if defun else step
+
+ # Warmup to reduce initialization effect when timing
+ warmup(
+ dynamics,
+ optimizer,
+ n_iters=hparams.n_warmup_iters,
+ n_samples=hparams.n_samples,
+ step_fn=step_fn)
+
+ # Training
+ samples = tf.random_normal(
+ shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32)
+ start_time = time.time()
+ fit(dynamics,
+ samples,
+ optimizer,
+ step_fn=step_fn,
+ n_iters=hparams.n_iters)
+ wall_time = (time.time() - start_time) / hparams.n_iters
+ examples_per_sec = hparams.n_samples / wall_time
+
+ self.report_benchmark(
+ name="eager_train_%s%s_%d" %
+ ("gpu" if tf.test.is_gpu_available() else "cpu",
+ "_defun" if defun else "", sample_size),
+ iters=hparams.n_iters,
+ extras={"examples_per_sec": examples_per_sec},
+ wall_time=wall_time)
del dynamics
diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD
index 2f6cfdf31e..74ce9e84f0 100644
--- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD
+++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD
@@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
py_binary(
name = "linear_regression",
diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD
index f83eb5c476..d500b632eb 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD
+++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD
@@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
py_binary(
name = "rnn_colorbot",
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD
index 4b4792cd49..2cc2fcbfeb 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD
@@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
py_binary(
name = "rnn_ptb",
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 6db311d52d..1ea00fb7f3 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -132,21 +132,11 @@ py_library(
srcs = ["python/estimator/dnn_with_layer_annotations.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:layers",
- "//tensorflow/python:nn",
- "//tensorflow/python:partitioned_variables",
- "//tensorflow/python:summary",
- "//tensorflow/python:variable_scope",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:optimizers",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model:utils",
],
)
@@ -162,22 +152,13 @@ py_test(
],
deps = [
":dnn_with_layer_annotations",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:dnn",
"//tensorflow/python/estimator:dnn_testing_utils",
"//tensorflow/python/estimator:export_export",
"//tensorflow/python/estimator:numpy_io",
"//tensorflow/python/estimator:pandas_io",
"//tensorflow/python/estimator:prediction_keys",
- "//tensorflow/python/feature_column",
"@six_archive//:six",
],
)
@@ -283,9 +264,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:exporter",
],
)
@@ -297,7 +276,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":exporter",
- "//tensorflow/python:platform",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:exporter",
],
@@ -502,7 +481,6 @@ py_library(
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:optimizers",
- "//tensorflow/python/ops/losses",
"@six_archive//:six",
],
)
@@ -557,13 +535,10 @@ py_library(
srcs = ["python/estimator/saved_model_estimator.py"],
deps = [
":export",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:export",
"//tensorflow/python/estimator:model_fn",
- "//tensorflow/python/saved_model",
],
)
@@ -578,16 +553,7 @@ py_test(
deps = [
":export",
":saved_model_estimator",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:platform",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:export_export",
"//tensorflow/python/estimator:export_output",
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 78914ecaca..419609b1af 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -76,7 +76,7 @@ _allowed_symbols = [
'stop_if_no_decrease_hook',
'build_raw_supervised_input_receiver_fn',
'build_supervised_input_receiver_fn_from_input_fn',
- 'SavedModelEstimator'
+ 'SavedModelEstimator',
'DNNClassifierWithLayerAnnotations',
'DNNRegressorWithLayerAnnotations',
]
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 11f60c8238..a1f1c5f3d7 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -34,18 +34,19 @@ def _validate_input_fn_and_repeat_dataset(train_input_fn):
return _input_fn
-# pylint: disable=protected-access
def _is_classification_head(head):
"""Infers if the head is a classification head."""
# Check using all classification heads defined in canned/head.py. However, it
# is not a complete list - it does not check for other classification heads
# not defined in the head library.
+ # pylint: disable=protected-access
return isinstance(head,
(head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss,
head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss))
+ # pylint: enable=protected-access
-class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase):
+class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase): # pylint: disable=protected-access
"""An Estimator for Tensorflow Boosted Trees models."""
def __init__(self,
@@ -113,6 +114,7 @@ class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase):
are requested.
"""
# HParams for the model.
+ # pylint: disable=protected-access
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
tree_complexity, min_node_weight, center_bias, pruning_mode)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
index 3fd9f12c61..5faf0aacfe 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
@@ -75,7 +75,9 @@ def make_input_layer_with_layer_annotations(original_input_layer):
weight_collections=None,
trainable=True,
cols_to_vars=None,
- cols_to_output_tensors=None):
+ scope=None,
+ cols_to_output_tensors=None,
+ from_template=False):
"""Returns a dense `Tensor` as input layer based on given `feature_columns`.
Generally a single example in training data is described with
@@ -111,9 +113,12 @@ def make_input_layer_with_layer_annotations(original_input_layer):
'some_variable:0' shape=(5, 10), <tf.Variable 'some_variable:1'
shape=(5, 10)]} If a column creates no variables, its value will be an
empty list.
+ scope: A name or variable scope to use
cols_to_output_tensors: If not `None`, must be a dictionary that will be
filled with a mapping from '_FeatureColumn' to the associated output
`Tensor`s.
+ from_template: True if the method is being instantiated from a
+ `make_template`.
Returns:
A `Tensor` which represents input layer of a model. Its shape
@@ -131,7 +136,9 @@ def make_input_layer_with_layer_annotations(original_input_layer):
weight_collections=weight_collections,
trainable=trainable,
cols_to_vars=cols_to_vars,
- cols_to_output_tensors=local_cols_to_output_tensors)
+ scope=scope,
+ cols_to_output_tensors=local_cols_to_output_tensors,
+ from_template=from_template)
if cols_to_output_tensors is not None:
cols_to_output_tensors = local_cols_to_output_tensors
@@ -296,9 +303,9 @@ def DNNClassifierWithLayerAnnotations( # pylint: disable=invalid-name
def _model_fn(features, labels, mode, config):
with _monkey_patch(
- feature_column_lib, 'input_layer',
+ feature_column_lib, '_internal_input_layer',
make_input_layer_with_layer_annotations(
- feature_column_lib.input_layer)):
+ feature_column_lib._internal_input_layer)): # pylint: disable=protected-access
return original.model_fn(features, labels, mode, config)
return estimator.Estimator(
@@ -417,9 +424,9 @@ def DNNRegressorWithLayerAnnotations( # pylint: disable=invalid-name
def _model_fn(features, labels, mode, config):
with _monkey_patch(
- feature_column_lib, 'input_layer',
+ feature_column_lib, '_internal_input_layer',
make_input_layer_with_layer_annotations(
- feature_column_lib.input_layer)):
+ feature_column_lib._internal_input_layer)): # pylint: disable=protected-access
return original.model_fn(features, labels, mode, config)
return estimator.Estimator(
diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping.py b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
index e6e25e319f..cafe8279c7 100644
--- a/tensorflow/contrib/estimator/python/estimator/early_stopping.py
+++ b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
@@ -57,6 +57,13 @@ def make_early_stopping_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
should_stop_fn: `callable`, function that takes no arguments and returns a
@@ -109,6 +116,13 @@ def stop_if_higher_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
metric_name: `str`, metric to track. "loss", "accuracy", etc.
@@ -158,6 +172,13 @@ def stop_if_lower_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
metric_name: `str`, metric to track. "loss", "accuracy", etc.
@@ -207,6 +228,13 @@ def stop_if_no_increase_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
metric_name: `str`, metric to track. "loss", "accuracy", etc.
@@ -257,6 +285,13 @@ def stop_if_no_decrease_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
metric_name: `str`, metric to track. "loss", "accuracy", etc.
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
index c6c6cad95a..62ffad56da 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
@@ -294,7 +294,7 @@ class InMemoryEvaluatorHookTest(test.TestCase):
def model_fn(features, labels, mode):
_, _ = features, labels
- w = variables.Variable(
+ w = variables.VariableV1(
initial_value=[0.],
trainable=False,
collections=[ops.GraphKeys.SAVEABLE_OBJECTS])
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index 9e1f14f990..510f292508 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -64,7 +64,6 @@ tf_custom_op_py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
- "//tensorflow/python/estimator",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/feature_column:feature_column_py",
"//third_party/py/numpy",
@@ -155,6 +154,8 @@ tf_py_test(
],
tags = [
"no_pip", # b/38283730
+ "noasan", # b/116875897
+ "nomsan",
"notsan", # Flaky: b/30756419
],
)
@@ -178,7 +179,11 @@ tf_py_test(
"//tensorflow/python:random_seed",
"//tensorflow/python:variables",
],
- tags = ["notsan"], # b/62863147
+ tags = [
+ "noasan", # b/116875897
+ "nomsan",
+ "notsan", # b/62863147
+ ],
)
py_library(
@@ -277,6 +282,7 @@ tf_py_test(
"manual",
"noasan", # times out b/63678675
"nomsan",
+ "notsan", # b/116875897
],
)
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py
index f9b0efd1da..c223df5b6e 100644
--- a/tensorflow/contrib/framework/python/ops/variables_test.py
+++ b/tensorflow/contrib/framework/python/ops/variables_test.py
@@ -192,7 +192,7 @@ class GlobalStepTest(test.TestCase):
def test_invalid_dtype(self):
with ops.Graph().as_default() as g:
self.assertEquals(None, variables_lib2.get_global_step())
- variables_lib.Variable(
+ variables_lib.VariableV1(
0.0,
trainable=False,
dtype=dtypes.float32,
@@ -205,7 +205,7 @@ class GlobalStepTest(test.TestCase):
def test_invalid_shape(self):
with ops.Graph().as_default() as g:
self.assertEquals(None, variables_lib2.get_global_step())
- variables_lib.Variable(
+ variables_lib.VariableV1(
[0],
trainable=False,
dtype=dtypes.int32,
@@ -229,7 +229,7 @@ class GlobalStepTest(test.TestCase):
def test_get_global_step(self):
with ops.Graph().as_default() as g:
self.assertEquals(None, variables_lib2.get_global_step())
- variables_lib.Variable(
+ variables_lib.VariableV1(
0,
trainable=False,
dtype=dtypes.int32,
@@ -607,10 +607,10 @@ class ModelVariablesTest(test.TestCase):
with self.cached_session():
with variable_scope.variable_scope('A'):
variables_lib2.local_variable([5])
- a = variables_lib.Variable([5])
+ a = variables_lib.VariableV1([5])
with variable_scope.variable_scope('B'):
variables_lib2.local_variable([5])
- b = variables_lib.Variable([5])
+ b = variables_lib.VariableV1([5])
self.assertEquals([a], variables_lib2.get_trainable_variables('A'))
self.assertEquals([b], variables_lib2.get_trainable_variables('B'))
@@ -953,7 +953,7 @@ class AssignFromCheckpointTest(test.TestCase):
# Create a set of variables to save in the checkpoint.
for var_name in var_names_to_values:
var_value = var_names_to_values[var_name]
- var_list.append(variables_lib.Variable(var_value, name=var_name))
+ var_list.append(variables_lib.VariableV1(var_value, name=var_name))
saver = saver_lib.Saver(var_list)
init_op = variables_lib.variables_initializer(var_list)
sess.run(init_op)
@@ -1106,7 +1106,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
# Create a set of variables to save in the checkpoint.
for var_name in var_names_to_values:
var_value = var_names_to_values[var_name]
- var_list.append(variables_lib.Variable(var_value, name=var_name))
+ var_list.append(variables_lib.VariableV1(var_value, name=var_name))
saver = saver_lib.Saver(var_list)
init_op = variables_lib.variables_initializer(var_list)
sess.run(init_op)
@@ -1297,7 +1297,7 @@ class AssignFromCheckpointFnTest(test.TestCase):
class ZeroInitializerOpTest(test.TestCase):
def _testZeroInitializer(self, shape, initializer, use_init):
- var = variables_lib.Variable(initializer)
+ var = variables_lib.VariableV1(initializer)
var_zero = variables_lib2.zero_initializer(var)
with self.cached_session() as sess:
with self.assertRaisesOpError('Attempting to use uninitialized value'):
@@ -1350,12 +1350,12 @@ class FilterVariablesTest(test.TestCase):
g = ops.Graph()
with g.as_default():
var_list = []
- var_list.append(variables_lib.Variable(0, name='conv1/weights'))
- var_list.append(variables_lib.Variable(0, name='conv1/biases'))
- var_list.append(variables_lib.Variable(0, name='conv2/weights'))
- var_list.append(variables_lib.Variable(0, name='conv2/biases'))
- var_list.append(variables_lib.Variable(0, name='clfs/weights'))
- var_list.append(variables_lib.Variable(0, name='clfs/biases'))
+ var_list.append(variables_lib.VariableV1(0, name='conv1/weights'))
+ var_list.append(variables_lib.VariableV1(0, name='conv1/biases'))
+ var_list.append(variables_lib.VariableV1(0, name='conv2/weights'))
+ var_list.append(variables_lib.VariableV1(0, name='conv2/biases'))
+ var_list.append(variables_lib.VariableV1(0, name='clfs/weights'))
+ var_list.append(variables_lib.VariableV1(0, name='clfs/biases'))
self._var_list = var_list
def _test_filter_variables(self,
diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD
index 0f0813c07f..490da9b33b 100644
--- a/tensorflow/contrib/fused_conv/BUILD
+++ b/tensorflow/contrib/fused_conv/BUILD
@@ -17,11 +17,14 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_kernel_library",
+ "tf_custom_op_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+)
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
-load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
-load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
tf_custom_op_py_library(
@@ -109,13 +112,13 @@ tf_gen_op_wrapper_py(
deps = [":fused_conv2d_bias_activation_op_op_lib"],
)
-cuda_py_test(
- name = "fused_conv2d_bias_activation_op_test",
- size = "large",
- srcs = ["python/ops/fused_conv2d_bias_activation_op_test.py"],
- additional_deps = [
+py_library(
+ name = "fused_conv2d_bias_activation_op_test_base",
+ testonly = 1,
+ srcs = ["python/ops/fused_conv2d_bias_activation_op_test_base.py"],
+ visibility = ["//tensorflow/compiler/tf2xla:internal"],
+ deps = [
":fused_conv_py",
- "//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
@@ -128,16 +131,27 @@ cuda_py_test(
"//tensorflow/python:random_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+cuda_py_test(
+ name = "fused_conv2d_bias_activation_op_test",
+ size = "large",
+ srcs = ["python/ops/fused_conv2d_bias_activation_op_test.py"],
+ additional_deps = [
+ ":fused_conv2d_bias_activation_op_test_base",
+ "//tensorflow/python:client_testlib",
],
tags = [
- "manual",
- "requires_cudnn6",
+ "no_pip",
+ "requires-gpu-sm70",
],
)
cuda_py_test(
name = "fused_conv2d_bias_activation_benchmark",
- size = "large",
srcs = ["python/ops/fused_conv2d_bias_activation_benchmark.py"],
additional_deps = [
":fused_conv_py",
@@ -155,7 +169,6 @@ cuda_py_test(
],
main = "python/ops/fused_conv2d_bias_activation_benchmark.py",
tags = [
- "manual",
- "requires_cudnn6",
+ "requires-gpu-sm70",
],
)
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index e9e6464d06..93b1aaa85e 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -111,8 +111,8 @@ class FusedConv2DBiasActivationOp : public OpKernel {
context,
(GetTensorDim(strides, data_format_, 'N') == 1 &&
GetTensorDim(strides, data_format_, 'C') == 1),
- errors::InvalidArgument("Convolutional strides are not supported in "
- "the batch or depth dimensions."));
+ errors::Unimplemented("Convolutional strides are not supported in "
+ "the batch and depth dimensions."));
// Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here.
constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
index 4894298694..e5c8a34fc1 100644
--- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
@@ -12,896 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Functional tests for fused conv2d bias and activation operation."""
+
+"""Tests for fused convolutions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors_impl
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import random_ops
+from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op_test_base as test_base
from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
-
-
-def GetShrunkInceptionShapes(shrink=10):
- """Iterator for smaller versions of convolution shapes in 2015 Inception.
-
- Relative to inception, each depth value is `depth // shrink`.
-
- Args:
- shrink: Factor to shrink each depth value by relative to Inception.
-
- Yields:
- Tuple (input_size, filter_size, out_size, stride, padding), the convolution
- parameters of Inception layers.
- """
- input_sizes = [[4, 5, 5, 1248], [4, 8, 8, 384], [4, 8, 8, 384], [
- 4, 8, 8, 2048
- ], [4, 8, 8, 448], [4, 8, 8, 2048], [4, 8, 8, 2048], [4, 8, 8, 2048], [
- 4, 8, 8, 1760
- ], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 17, 17, 192], [
- 4, 17, 17, 192
- ], [4, 17, 17, 1248], [4, 17, 17, 128], [4, 17, 17, 1248], [4, 17, 17, 224], [
- 4, 17, 17, 192
- ], [4, 17, 17, 192], [4, 17, 17, 1216], [4, 17, 17, 1216], [4, 17, 17, 224], [
- 4, 17, 17, 192
- ], [4, 17, 17, 192], [4, 17, 17, 1152], [4, 17, 17, 1152], [4, 17, 17, 192], [
- 4, 17, 17, 160
- ], [4, 17, 17, 1152], [4, 17, 17, 1024], [4, 17, 17, 128], [4, 17, 17, 1024],
- [4, 17, 17, 128], [4, 17, 17, 1024], [4, 17, 17, 128], [
- 4, 17, 17, 768
- ], [4, 17, 17, 128], [4, 17, 17, 128], [4, 17, 17, 768],
- [4, 17, 17, 768], [4, 35, 35, 96], [4, 35, 35, 288], [
- 4, 35, 35, 64
- ], [4, 35, 35, 288], [4, 35, 35, 256], [4, 35, 35, 48], [
- 4, 35, 35, 256
- ], [4, 35, 35, 96], [4, 35, 35, 192], [4, 35, 35, 192], [
- 4, 35, 35, 192
- ], [4, 73, 73, 64], [4, 73, 73, 64], [4, 147, 147, 24]]
- filter_sizes = [[1, 1, 1248, 128], [1, 3, 384, 384], [3, 1, 384, 384], [
- 1, 1, 2048, 192
- ], [3, 3, 448, 384], [1, 1, 2048, 320], [1, 1, 2048, 448], [1, 1, 2048, 384],
- [1, 1, 1760, 384], [1, 1, 1760, 192], [1, 1, 1760, 448], [
- 1, 1, 1760, 320
- ], [3, 3, 192, 192], [3, 3, 192, 192], [1, 1, 1248, 192], [
- 3, 3, 128, 320
- ], [1, 1, 1248, 128], [1, 3, 224, 224], [3, 1, 192, 256], [
- 1, 3, 192, 256
- ], [1, 1, 1216, 192], [1, 1, 1216, 96], [3, 1, 224, 224], [
- 3, 3, 192, 224
- ], [1, 3, 192, 192], [1, 1, 1152, 192], [1, 1, 1152, 128], [
- 3, 1, 192, 192
- ], [3, 3, 160, 192], [1, 1, 1152, 160], [1, 1, 1024, 128], [
- 1, 3, 128, 192
- ], [1, 1, 1024, 160], [3, 1, 128, 192], [1, 1, 1024, 256], [
- 3, 1, 128, 128
- ], [1, 1, 768, 192], [1, 3, 128, 128], [3, 3, 128, 128], [
- 1, 1, 768, 128
- ], [1, 1, 768, 320], [3, 3, 96, 96], [3, 3, 288, 384], [
- 3, 3, 64, 96
- ], [1, 1, 288, 64], [1, 1, 256, 64], [5, 5, 48, 64],
- [1, 1, 256, 48], [3, 3, 96, 96], [1, 1, 192, 32], [
- 1, 1, 192, 64
- ], [1, 1, 192, 48], [3, 3, 64, 192], [1, 1, 64,
- 64], [1, 1, 24, 64]]
- out_sizes = [[4, 5, 5, 128], [4, 8, 8, 384], [4, 8, 8, 384], [4, 8, 8, 192], [
- 4, 8, 8, 384
- ], [4, 8, 8, 320], [4, 8, 8, 448], [4, 8, 8, 384], [4, 8, 8, 384], [
- 4, 8, 8, 192
- ], [4, 8, 8, 448], [4, 8, 8, 320], [4, 8, 8, 192], [4, 17, 17, 192], [
- 4, 17, 17, 192
- ], [4, 8, 8, 320], [4, 17, 17, 128], [4, 17, 17, 224], [4, 17, 17, 256], [
- 4, 17, 17, 256
- ], [4, 17, 17, 192], [4, 17, 17, 96], [4, 17, 17, 224], [4, 17, 17, 224], [
- 4, 17, 17, 192
- ], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 192], [
- 4, 17, 17, 160
- ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 160], [4, 17, 17, 192], [
- 4, 17, 17, 256
- ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 128], [
- 4, 17, 17, 128
- ], [4, 17, 17, 320], [4, 17, 17, 96], [4, 17, 17, 384], [4, 35, 35, 96], [
- 4, 35, 35, 64
- ], [4, 35, 35, 64], [4, 35, 35, 64], [4, 35, 35, 48], [4, 35, 35, 96],
- [4, 35, 35, 32], [4, 35, 35, 64], [4, 35, 35, 48],
- [4, 71, 71, 192], [4, 73, 73, 64], [4, 147, 147, 64]]
- strides = [
- 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1
- ]
- # Shrink sizes to make the test faster
- for i in input_sizes:
- i[3] //= shrink
- for f in filter_sizes:
- f[2] //= shrink
- f[3] //= shrink
- for o in out_sizes:
- o[3] //= shrink
- # pylint: disable=invalid-name
- VALID = "VALID"
- SAME = "SAME"
- # pylint: enable=invalid-name
- paddings = [
- SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
- VALID, SAME, SAME, VALID, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
- SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
- SAME, SAME, SAME, SAME, SAME, VALID, VALID, SAME, SAME, SAME, SAME, SAME,
- SAME, SAME, SAME, SAME, VALID, VALID, VALID
- ]
- for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
- paddings):
- yield i, f, o, s, p
-
-
-def GetTestConfigs():
- """Get all the valid tests configs to run.
-
- Returns:
- all the valid test configs as tuples of data_format and use_gpu.
- """
- test_configs = [("NCHW", True), ("NHWC", True)]
- return test_configs
-
-
-class FusedConv2DBiasActivationTest(test.TestCase):
-
- def _DtypesToTest(self, use_gpu):
- return [dtypes.float32]
-
- def _FilterFormatsToTest(self, use_gpu):
- return ["HWIO", "OIHW"]
-
- def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, bias,
- strides, padding, activation_mode, data_format,
- filter_format, dtype):
- """Verifies the output values of the convolution function.
-
- Args:
- tensor_in_sizes: Input tensor dimensions in
- [batch, input_rows, input_cols, input_depth].
- filter_in_sizes: Filter tensor dimensions in
- [kernel_rows, kernel_cols, input_depth, output_depth].
- bias: 1-D bias tensor of length output_depth.
- strides: Stride: [col_stride, row_stride]
- padding: Padding type.
- activation_mode: Activation mode.
- data_format: Format of the data tensors.
- filter_format: Filter format to use for the fused convolution.
- dtype: Data type for inputs and outputs.
- Returns:
- Symbolic tensor value and reference value that can be used to
- execute the computation and verify the results.
- """
- input_size = np.prod(tensor_in_sizes)
- filter_size = np.prod(filter_in_sizes)
- bias_size = filter_in_sizes[-1] # equals to output depth
- # Initializes the input tensor with array containing incrementing
- # numbers from 1.
- x1 = [f * 1.0 for f in range(1, input_size + 1)]
- x2 = [f * 1.0 for f in range(1, filter_size + 1)]
- # This is to guarantee that there is always negative values after
- # bias add so that we can test whether relu works correctly.
- x3 = bias
- with self.test_session(use_gpu=True):
- t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
- t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
- fused_t2 = t2
- if filter_format == "OIHW":
- fused_t2 = HwioToOihw(t2)
- t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype)
- strides = [1] + strides + [1]
- if data_format == "NCHW":
- t1 = test_util.NHWCToNCHW(t1)
- strides = test_util.NHWCToNCHW(strides)
- output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- t1,
- fused_t2,
- t3,
- strides=strides,
- padding=padding,
- data_format=data_format,
- filter_format=filter_format,
- activation_mode=activation_mode)
- ref_conv_output = nn_ops.conv2d(
- t1, t2, strides=strides, padding=padding, data_format=data_format)
- ref_bias_output = nn_ops.bias_add(
- ref_conv_output, t3, data_format=data_format)
- ref_output = nn_ops.relu(ref_bias_output)
- if data_format == "NCHW":
- output = test_util.NCHWToNHWC(output)
- ref_output = test_util.NCHWToNHWC(ref_output)
-
- return output, ref_output
-
- def _CompareFwdValues(self, tensor_in_sizes, filter_in_sizes, conv_strides,
- padding):
- """Verifies that CPU and GPU produce the same values.
-
- Args:
- tensor_in_sizes: Input tensor dimensions in
- [batch, input_rows, input_cols, input_depth].
- filter_in_sizes: Filter tensor dimensions in
- [kernel_rows, kernel_cols, input_depth, output_depth].
- conv_strides: [row_stride, col_stride] for the convolution;
- padding: Padding type.
- """
- x1 = np.random.rand(*tensor_in_sizes).astype(np.float32)
- x2 = np.random.rand(*filter_in_sizes).astype(np.float32)
- x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32)
-
- def _SetupVal(data_format, use_gpu):
- with self.test_session(use_gpu=use_gpu):
- t1 = constant_op.constant(x1, shape=tensor_in_sizes)
- t2 = constant_op.constant(x2, shape=filter_in_sizes)
- t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]])
- strides = [1] + conv_strides + [1]
- if data_format == "NCHW":
- t1 = test_util.NHWCToNCHW(t1)
- strides = test_util.NHWCToNCHW(strides)
- output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- t1,
- t2,
- t3,
- strides=strides,
- padding=padding,
- data_format=data_format,
- activation_mode="Relu")
-
- if data_format == "NCHW":
- output = test_util.NCHWToNHWC(output)
- return output
-
- tensors = []
- for (data_format, use_gpu) in GetTestConfigs():
- tensors.append(_SetupVal(data_format, use_gpu))
- with self.cached_session() as sess:
- values = sess.run(tensors)
- for i in range(1, len(values)):
- self.assertAllClose(values[0], values[i], rtol=1e-3, atol=1e-3)
-
- def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, bias, strides,
- padding):
- tensors = []
- ref_tensors = []
- for (data_format, use_gpu) in GetTestConfigs():
- for dtype in self._DtypesToTest(use_gpu):
- for filter_format in self._FilterFormatsToTest(use_gpu):
- result, expected = self._SetupValuesForDevice(
- tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu",
- data_format, filter_format, dtype)
- tensors.append(result)
- ref_tensors.append(expected)
- with self.cached_session() as sess:
- values = sess.run(tensors)
- ref_values = sess.run(ref_tensors)
- for i in range(len(tensors)):
- conv = tensors[i]
- value = values[i]
- ref_value = ref_values[i]
- tf_logging.info("expected = ", ref_value)
- tf_logging.info("actual = ", value)
- tol = 1e-5
- if value.dtype == np.float16:
- tol = 1e-3
- self.assertAllClose(
- np.ravel(ref_value), np.ravel(value), atol=tol, rtol=tol)
- self.assertShapeEqual(value, conv)
-
- def testConv2D1x1Filter(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2D1x1Filter test.")
- return
- # expected_output = [
- # 0.0, 0.0, 0.0, 21.0, 0.0, 0.0, 57.0, 0.0, 0.0, 93.0, 41.0, 0.0, 129.0,
- # 86.0, 43.0, 165.0, 131.0, 97.0
- # ]
- medians = [-45.0, -130.0, -215.0]
- self._VerifyValues(
- tensor_in_sizes=[1, 2, 3, 3],
- filter_in_sizes=[1, 1, 3, 3],
- bias=medians,
- strides=[1, 1],
- padding="VALID")
-
- def testConv2DEmpty(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2DEmpty test.")
- return
- # expected_output = []
- self._VerifyValues(
- tensor_in_sizes=[0, 2, 3, 3],
- filter_in_sizes=[1, 1, 3, 3],
- bias=[0.0, 0.0, 0.0],
- strides=[1, 1],
- padding="VALID")
-
- def testConv2D2x2Filter(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2D2x2Filter test.")
- return
- # expected_output = [0.0, 0.0, 0.0, 401.0, 533.0, 665.0]
- self._VerifyValues(
- tensor_in_sizes=[1, 2, 3, 3],
- filter_in_sizes=[2, 2, 3, 3],
- bias=[-2500.0, -2500.0, -2500.0],
- strides=[1, 1],
- padding="VALID")
-
- def testConv2D1x2Filter(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2D1x2Filter test.")
- return
- # expected_output = [
- # 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 190.0, 265.0, 340.0, 343.0, 436.0, 529.0
- # ]
- self._VerifyValues(
- tensor_in_sizes=[1, 2, 3, 3],
- filter_in_sizes=[1, 2, 3, 3],
- bias=[-500.0, -500.0, -500.0],
- strides=[1, 1],
- padding="VALID")
-
- def testConv2D2x2FilterStride2(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2D2x2FilterStride2 test.")
- return
- # expected_output = [0.0, 67.0, 163.0]
- self._VerifyValues(
- tensor_in_sizes=[1, 2, 3, 3],
- filter_in_sizes=[2, 2, 3, 3],
- bias=[-2300.0, -2300.0, -2300.0],
- strides=[2, 2],
- padding="VALID")
-
- def testConv2D2x2FilterStride2Same(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2D2x2FilterStride2Same test.")
- return
- # expected_output = [0.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0]
- self._VerifyValues(
- tensor_in_sizes=[1, 2, 3, 3],
- filter_in_sizes=[2, 2, 3, 3],
- bias=[-2300.0, -1000.0, -1000.0],
- strides=[2, 2],
- padding="SAME")
-
- def testConv2D2x2FilterStride1x2(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2D2x2FilterStride1x2 test.")
- return
- # expected_output = [0.0, 0.0, 8.0, 28.0, 48.0, 68.0]
- self._VerifyValues(
- tensor_in_sizes=[1, 3, 6, 1],
- filter_in_sizes=[2, 2, 1, 1],
- bias=[-90.0],
- strides=[1, 2],
- padding="VALID")
-
- def testConv2DKernelSmallerThanStrideValid(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2DKernelSmallerThanStrideValid test.")
- return
- # expected_output = [0, 0, 175, 205]
- self._VerifyValues(
- tensor_in_sizes=[1, 7, 7, 1],
- filter_in_sizes=[2, 2, 1, 1],
- bias=[-100.0],
- strides=[3, 3],
- padding="VALID")
-
- def testConv2DKernelSmallerThanStrideSame(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2DKernelSmallerThanStrideSame test.")
- return
- # expected = [0, 0, 2, 4]
- self._VerifyValues(
- tensor_in_sizes=[1, 3, 3, 1],
- filter_in_sizes=[1, 1, 1, 1],
- bias=[-5.0],
- strides=[2, 2],
- padding="SAME")
-
- # expected = [0, 0, 4, 6]
- self._VerifyValues(
- tensor_in_sizes=[1, 4, 4, 1],
- filter_in_sizes=[1, 1, 1, 1],
- bias=[-5.0],
- strides=[2, 2],
- padding="SAME")
-
- # expected = [4, 0, 1, 0]
- self._VerifyValues(
- tensor_in_sizes=[1, 4, 4, 1],
- filter_in_sizes=[2, 2, 1, 1],
- bias=[-40.0],
- strides=[3, 3],
- padding="SAME")
-
- def testConv2DKernelSizeMatchesInputSize(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping Conv2DKernelSizeMatchesInputSize test.")
- return
- # expected = [0, 5]
- self._VerifyValues(
- tensor_in_sizes=[1, 2, 2, 1],
- filter_in_sizes=[2, 2, 1, 2],
- bias=[-50.0, -55.0],
- strides=[1, 1],
- padding="VALID")
-
- # expected = [0, 2, 282, 322]
- self._VerifyValues(
- tensor_in_sizes=[1, 8, 8, 1],
- filter_in_sizes=[2, 2, 1, 1],
- bias=[-200.0],
- strides=[4, 4],
- padding="SAME")
-
- def testShapeFunctionEdgeCases(self):
- # All shapes unknown.
- c1 = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- strides=[1, 1, 1, 1],
- padding="SAME",
- activation_mode="Relu")
- self.assertEqual([None, None, None, None], c1.get_shape().as_list())
-
- # Incorrect input shape.
- with self.assertRaises(ValueError):
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32, shape=[1, 3]),
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- strides=[1, 1, 1, 1],
- padding="SAME",
- activation_mode="Relu")
-
- # Incorrect filter shape.
- with self.assertRaises(ValueError):
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32, shape=[1, 3]),
- array_ops.placeholder(dtypes.float32),
- strides=[1, 1, 1, 1],
- padding="SAME",
- activation_mode="Relu")
-
- # Depth mismatch.
- with self.assertRaises(ValueError):
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
- array_ops.placeholder(dtypes.float32, shape=[4, 4, 2, 2]),
- array_ops.placeholder(dtypes.float32),
- strides=[1, 1, 1, 1],
- padding="SAME",
- activation_mode="Relu")
-
- def testOpEdgeCases(self, gpu_only=True):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping OpEdgeCases tests.")
- return
- with self.cached_session() as sess:
- # Illegal strides.
- with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- "Convolutional strides are not supported in "
- "the batch or depth dimensions."):
- sess.run(
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- strides=[2, 1, 1, 1],
- padding="SAME",
- activation_mode="Relu"))
- with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- "Convolutional strides are not supported in "
- "the batch or depth dimensions."):
- sess.run(
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- strides=[1, 1, 1, 2],
- padding="SAME",
- activation_mode="Relu"))
-
- # Illegal activation mode.
- with self.assertRaisesRegexp(ValueError,
- "Op passed string 'Tanh' not in:"):
- sess.run(
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- array_ops.placeholder(dtypes.float32),
- strides=[1, 1, 1, 1],
- padding="SAME",
- activation_mode="Tanh"))
-
- # Filter larger than input.
- with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
- sess.run(
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
- array_ops.placeholder(dtypes.float32, shape=[20, 21, 3, 2]),
- array_ops.placeholder(dtypes.float32, shape=[2]),
- strides=[1, 1, 1, 1],
- padding="VALID",
- activation_mode="Relu"))
- with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
- sess.run(
- fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
- array_ops.placeholder(dtypes.float32, shape=[21, 20, 3, 2]),
- array_ops.placeholder(dtypes.float32, shape=[2]),
- strides=[1, 1, 1, 1],
- padding="VALID",
- activation_mode="Relu"))
-
-
-def GetInceptionFwdTest(input_size, filter_size, stride, padding,
- gpu_only=True):
-
- def Test(self):
- if gpu_only and not test.is_gpu_available():
- tf_logging.info("Skipping InceptionFwd %s", (input_size, filter_size,
- stride, padding))
- return
- tf_logging.info("Testing InceptionFwd %s", (input_size, filter_size, stride,
- padding))
- self._CompareFwdValues(input_size, filter_size, [stride, stride], padding)
-
- return Test
-
-
-def CalculateConvolvedOutputDim(input_dim, filter_dim, stride, padding_type):
- """Calculates the size of an output dimension of a strided convolution.
-
- Given the sizes of the corresponding dimension of the input and filter shapes,
- and the stride and padding_types, calculates the size of the output dimension.
- This function can be called separately for each input dimension.
-
- Args:
- input_dim: An `int` specifying the size of the input dimension.
- filter_dim: An `int` specifying the size of the filter dimension.
- stride: An `int` specifying the step size of the convolution along the
- input dimension.
- padding_type: either 'VALID' or 'SAME'.
-
- Returns:
- The size of the output dimension.
- """
- if padding_type == "VALID":
- return (input_dim - filter_dim + stride) // stride
- else: # padding_type == 'SAME'
- return (input_dim + stride - 1) // stride
-
-
-def NchwVectCToNchw(in_tensor):
- # [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W]
- t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3])
- n = in_tensor.shape.dims[0].value
- c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
- h = in_tensor.shape.dims[2].value
- w = in_tensor.shape.dims[3].value
- return array_ops.reshape(t, [n, c, h, w])
-
-
-def OihwVectIToHwio(in_tensor):
- # [O, I / 4, H, W, 4] => [O, I / 4, 4, H, W] == [O, I, H, W]
- t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0])
- o = in_tensor.shape.dims[0].value
- i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
- h = in_tensor.shape.dims[2].value
- w = in_tensor.shape.dims[3].value
- return array_ops.reshape(t, [h, w, i, o])
-
-
-def NchwToNchwVectC(in_tensor):
- n, c, h, w = in_tensor.shape.as_list()
- assert c % 4 == 0
- t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w])
- return array_ops.transpose(t, [0, 1, 3, 4, 2])
-
-
-def HwioToOihw(in_tensor):
- return array_ops.transpose(in_tensor, [3, 2, 0, 1])
-
-
-def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel,
- padding, strides, side_input_scale,
- side_input, biases, apply_relu):
- """Simulates the int8 fused 2-D convolution op using separate float ops.
-
- The arguments and return values have the same format, meanings and
- restrictions as the actual op.
- Args:
- conv_input_scale: A scalar 'float'.
- conv_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
- kernel: A `Tensor` of type `qint8` in OIHW_VECT_I layout.
- padding: A `string` from: `"SAME", "VALID"`.
- strides: A list of `ints`.
- side_input_scale: A scalar 'float'.
- side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
- biases: A `Tensor` of type `float32` in NCHW layout.
- apply_relu: A boolean to specify whether to apply "Relu" activation function
- that clips outputs to the range [0, 127], or "None" activation that clips
- to the range [-128, 127].
- Returns:
- A `Tensor` of type `qint8` in NCHW_VECT_C layout.
- """
- conv_result = nn_ops.conv2d(
- NchwVectCToNchw(gen_array_ops.dequantize(conv_input, -128, 127)),
- OihwVectIToHwio(gen_array_ops.dequantize(kernel, -128, 127)),
- strides=strides,
- padding=padding,
- data_format="NCHW") * conv_input_scale
-
- conv_and_side_inputs = conv_result + side_input_scale * NchwVectCToNchw(
- gen_array_ops.dequantize(side_input, -128, 127))
-
- output = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW")
- if apply_relu:
- output = nn_ops.relu(output)
-
- result, _, _ = gen_array_ops.quantize_v2(
- NchwToNchwVectC(output), -128, 127, dtypes.qint8)
- return result
-
-
-class FusedConvInt8Tests(test.TestCase):
- _test_params = [
- {
- "batch_size": 1,
- "input_channels": 4,
- "output_channels": 4,
- "input_height": 8,
- "input_width": 8,
- "filter_height": 6,
- "filter_width": 6,
- "vertical_stride": 2,
- "horizontal_stride": 2,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.0,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 1,
- "input_channels": 4,
- "output_channels": 4,
- "input_height": 6,
- "input_width": 6,
- "filter_height": 6,
- "filter_width": 6,
- "vertical_stride": 2,
- "horizontal_stride": 2,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.0,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 2,
- "input_channels": 8,
- "output_channels": 16,
- "input_height": 8,
- "input_width": 8,
- "filter_height": 3,
- "filter_width": 3,
- "vertical_stride": 2,
- "horizontal_stride": 2,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.0,
- "bias_scale": 1,
- "padding_type": "VALID"
- },
- {
- "batch_size": 2,
- "input_channels": 8,
- "output_channels": 16,
- "input_height": 8,
- "input_width": 8,
- "filter_height": 3,
- "filter_width": 3,
- "vertical_stride": 2,
- "horizontal_stride": 2,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.0,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 2,
- "input_channels": 8,
- "output_channels": 16,
- "input_height": 8,
- "input_width": 8,
- "filter_height": 3,
- "filter_width": 3,
- "vertical_stride": 2,
- "horizontal_stride": 2,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "VALID"
- },
- {
- "batch_size": 2,
- "input_channels": 16,
- "output_channels": 16,
- "input_height": 9,
- "input_width": 9,
- "filter_height": 3,
- "filter_width": 3,
- "vertical_stride": 1,
- "horizontal_stride": 1,
- "conv_input_scale": 0.001,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 3,
- "input_channels": 8,
- "output_channels": 8,
- "input_height": 9,
- "input_width": 9,
- "filter_height": 5,
- "filter_width": 5,
- "vertical_stride": 1,
- "horizontal_stride": 1,
- "conv_input_scale": 0.001,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 3,
- "input_channels": 8,
- "output_channels": 8,
- "input_height": 9,
- "input_width": 9,
- "filter_height": 7,
- "filter_width": 1,
- "vertical_stride": 2,
- "horizontal_stride": 1,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- {
- "batch_size": 3,
- "input_channels": 8,
- "output_channels": 8,
- "input_height": 9,
- "input_width": 9,
- "filter_height": 1,
- "filter_width": 7,
- "vertical_stride": 1,
- "horizontal_stride": 1,
- "conv_input_scale": 0.002,
- "side_input_scale": 0.5,
- "bias_scale": 1,
- "padding_type": "SAME"
- },
- ]
-
- def runTest(self, test_param, apply_relu):
- batch_size = test_param["batch_size"]
- input_channels = test_param["input_channels"]
- output_channels = test_param["output_channels"]
- input_height = test_param["input_height"]
- input_width = test_param["input_width"]
- filter_height = test_param["filter_height"]
- filter_width = test_param["filter_width"]
- vertical_stride = test_param["vertical_stride"]
- horizontal_stride = test_param["horizontal_stride"]
- conv_input_scale = test_param["conv_input_scale"]
- side_input_scale = test_param["side_input_scale"]
- bias_scale = test_param["bias_scale"]
- padding_type = test_param["padding_type"]
-
- conv_input, _, _ = gen_array_ops.quantize_v2(
- random_ops.random_uniform(
- [batch_size, input_channels // 4, input_height, input_width, 4],
- minval=-0.0,
- maxval=1.0,
- dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
-
- kernel, _, _ = gen_array_ops.quantize_v2(
- random_ops.random_uniform(
- [
- output_channels, input_channels // 4, filter_height,
- filter_width, 4
- ],
- minval=-1.0,
- maxval=1.0,
- dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
-
- output_height = CalculateConvolvedOutputDim(input_height, filter_height,
- vertical_stride, padding_type)
- output_width = CalculateConvolvedOutputDim(input_width, filter_width,
- horizontal_stride, padding_type)
- tf_logging.info("output_height=", output_height, ", output_width=",
- output_width)
-
- side_input, _, _ = gen_array_ops.quantize_v2(
- random_ops.random_uniform(
- [batch_size, output_channels // 4, output_height, output_width, 4],
- minval=0.0,
- maxval=1.0,
- dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
-
- biases = random_ops.random_uniform(
- [output_channels],
- minval=-10 * bias_scale,
- maxval=20 * bias_scale,
- dtype=dtypes.float32)
-
- strides = [1, 1, vertical_stride, horizontal_stride]
-
- actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
- conv_input,
- kernel,
- biases,
- strides=strides,
- padding=padding_type,
- conv_input_scale=conv_input_scale,
- side_input_scale=side_input_scale,
- side_input=side_input,
- activation_mode="Relu" if apply_relu else "None",
- data_format="NCHW_VECT_C",
- filter_format="OIHW_VECT_I")
- expected = SimulateFusedConv2dBiasActivationInt8(
- conv_input_scale, conv_input, kernel, padding_type, strides,
- side_input_scale, side_input, biases, apply_relu)
- with self.test_session(use_gpu=True) as sess:
- actual_y, expected_y = sess.run([actual, expected])
- self.assertAllClose(actual_y, expected_y, rtol=0, atol=1)
+# Instantiate the two test suites from test_base, mixing in test.TestCase as
+# the test framework.
+class FusedConv2DBiasActivationTest(test_base.FusedConv2DBiasActivationTest,
+ test.TestCase):
+ pass
- def testFusedConvInt8(self):
- if not test.is_gpu_available(
- cuda_only=True, min_cuda_compute_capability=(6, 1)):
- tf_logging.info("int8 test skipped because not run with --config=cuda or "
- "no GPUs with compute capability >= 6.1 are available.")
- return
- for apply_relu in [True, False]:
- for test_param in self._test_params:
- self.runTest(test_param, apply_relu)
+class FusedConvInt8Tests(test_base.FusedConvInt8Tests, test.TestCase):
+ pass
-if __name__ == "__main__":
- for index, (input_size_, filter_size_, output_size_, stride_,
- padding_) in enumerate(GetShrunkInceptionShapes()):
- setattr(FusedConv2DBiasActivationTest, "testInceptionFwd_" + str(index),
- GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_))
- # TODO(b/35359731)
- # Fwd, BckInput, and BackFilter to test that for certain input parameter
- # set, winograd nonfused algorithm will be excluded from conv autotune. If
- # in such case, winograd nonfused algorithm is added as one option of the
- # conv autotune, and cuDNN version is smaller than 7, the following tests
- # will fail.
- ishape = [1, 400, 400, 1]
- fshape = [1, 1, 1, 256]
- oshape = [1, 400, 400, 256]
- setattr(FusedConv2DBiasActivationTest,
- "testInceptionFwd_No_Winograd_Nonfused",
- GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True))
+if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py
new file mode 100644
index 0000000000..35fc65e4ba
--- /dev/null
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test_base.py
@@ -0,0 +1,945 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Provides test suites that can be run to test fused convolutions.
+
+Each of the two test suites in this module, FusedConv2DBiasActivationTest and
+FusedConvInt8Tests, should be "instantiated" by declaring a class which inherits
+from the FusedConv test and a class that provides the standard test.TestCase
+API.
+
+See e.g. fused_conv2d_bias_activation_op_test.py in this folder.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import numpy as np
+
+from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+def _GetShrunkInceptionShapes(shrink=10):
+ """Iterator for smaller versions of convolution shapes in 2015 Inception.
+
+ Relative to inception, each depth value is `depth // shrink`.
+
+ Args:
+ shrink: Factor to shrink each depth value by relative to Inception.
+
+ Yields:
+ Tuple (input_size, filter_size, out_size, stride, padding), the convolution
+ parameters of Inception layers.
+ """
+ input_sizes = [[4, 5, 5, 1248], [4, 8, 8, 384], [4, 8, 8, 384], [
+ 4, 8, 8, 2048
+ ], [4, 8, 8, 448], [4, 8, 8, 2048], [4, 8, 8, 2048], [4, 8, 8, 2048], [
+ 4, 8, 8, 1760
+ ], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 8, 8, 1760], [4, 17, 17, 192], [
+ 4, 17, 17, 192
+ ], [4, 17, 17, 1248], [4, 17, 17, 128], [4, 17, 17, 1248], [4, 17, 17, 224], [
+ 4, 17, 17, 192
+ ], [4, 17, 17, 192], [4, 17, 17, 1216], [4, 17, 17, 1216], [4, 17, 17, 224], [
+ 4, 17, 17, 192
+ ], [4, 17, 17, 192], [4, 17, 17, 1152], [4, 17, 17, 1152], [4, 17, 17, 192], [
+ 4, 17, 17, 160
+ ], [4, 17, 17, 1152], [4, 17, 17, 1024], [4, 17, 17, 128], [4, 17, 17, 1024],
+ [4, 17, 17, 128], [4, 17, 17, 1024], [4, 17, 17, 128], [
+ 4, 17, 17, 768
+ ], [4, 17, 17, 128], [4, 17, 17, 128], [4, 17, 17, 768],
+ [4, 17, 17, 768], [4, 35, 35, 96], [4, 35, 35, 288], [
+ 4, 35, 35, 64
+ ], [4, 35, 35, 288], [4, 35, 35, 256], [4, 35, 35, 48], [
+ 4, 35, 35, 256
+ ], [4, 35, 35, 96], [4, 35, 35, 192], [4, 35, 35, 192], [
+ 4, 35, 35, 192
+ ], [4, 73, 73, 64], [4, 73, 73, 64], [4, 147, 147, 24]]
+ filter_sizes = [[1, 1, 1248, 128], [1, 3, 384, 384], [3, 1, 384, 384], [
+ 1, 1, 2048, 192
+ ], [3, 3, 448, 384], [1, 1, 2048, 320], [1, 1, 2048, 448], [1, 1, 2048, 384],
+ [1, 1, 1760, 384], [1, 1, 1760, 192], [1, 1, 1760, 448], [
+ 1, 1, 1760, 320
+ ], [3, 3, 192, 192], [3, 3, 192, 192], [1, 1, 1248, 192], [
+ 3, 3, 128, 320
+ ], [1, 1, 1248, 128], [1, 3, 224, 224], [3, 1, 192, 256], [
+ 1, 3, 192, 256
+ ], [1, 1, 1216, 192], [1, 1, 1216, 96], [3, 1, 224, 224], [
+ 3, 3, 192, 224
+ ], [1, 3, 192, 192], [1, 1, 1152, 192], [1, 1, 1152, 128], [
+ 3, 1, 192, 192
+ ], [3, 3, 160, 192], [1, 1, 1152, 160], [1, 1, 1024, 128], [
+ 1, 3, 128, 192
+ ], [1, 1, 1024, 160], [3, 1, 128, 192], [1, 1, 1024, 256], [
+ 3, 1, 128, 128
+ ], [1, 1, 768, 192], [1, 3, 128, 128], [3, 3, 128, 128], [
+ 1, 1, 768, 128
+ ], [1, 1, 768, 320], [3, 3, 96, 96], [3, 3, 288, 384], [
+ 3, 3, 64, 96
+ ], [1, 1, 288, 64], [1, 1, 256, 64], [5, 5, 48, 64],
+ [1, 1, 256, 48], [3, 3, 96, 96], [1, 1, 192, 32], [
+ 1, 1, 192, 64
+ ], [1, 1, 192, 48], [3, 3, 64, 192], [1, 1, 64,
+ 64], [1, 1, 24, 64]]
+ out_sizes = [[4, 5, 5, 128], [4, 8, 8, 384], [4, 8, 8, 384], [4, 8, 8, 192], [
+ 4, 8, 8, 384
+ ], [4, 8, 8, 320], [4, 8, 8, 448], [4, 8, 8, 384], [4, 8, 8, 384], [
+ 4, 8, 8, 192
+ ], [4, 8, 8, 448], [4, 8, 8, 320], [4, 8, 8, 192], [4, 17, 17, 192], [
+ 4, 17, 17, 192
+ ], [4, 8, 8, 320], [4, 17, 17, 128], [4, 17, 17, 224], [4, 17, 17, 256], [
+ 4, 17, 17, 256
+ ], [4, 17, 17, 192], [4, 17, 17, 96], [4, 17, 17, 224], [4, 17, 17, 224], [
+ 4, 17, 17, 192
+ ], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 192], [
+ 4, 17, 17, 160
+ ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 160], [4, 17, 17, 192], [
+ 4, 17, 17, 256
+ ], [4, 17, 17, 128], [4, 17, 17, 192], [4, 17, 17, 128], [4, 17, 17, 128], [
+ 4, 17, 17, 128
+ ], [4, 17, 17, 320], [4, 17, 17, 96], [4, 17, 17, 384], [4, 35, 35, 96], [
+ 4, 35, 35, 64
+ ], [4, 35, 35, 64], [4, 35, 35, 64], [4, 35, 35, 48], [4, 35, 35, 96],
+ [4, 35, 35, 32], [4, 35, 35, 64], [4, 35, 35, 48],
+ [4, 71, 71, 192], [4, 73, 73, 64], [4, 147, 147, 64]]
+ strides = [
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1
+ ]
+ # Shrink sizes to make the test faster
+ for i in input_sizes:
+ i[3] //= shrink
+ for f in filter_sizes:
+ f[2] //= shrink
+ f[3] //= shrink
+ for o in out_sizes:
+ o[3] //= shrink
+ # pylint: disable=invalid-name
+ VALID = "VALID"
+ SAME = "SAME"
+ # pylint: enable=invalid-name
+ paddings = [
+ SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
+ VALID, SAME, SAME, VALID, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
+ SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME, SAME,
+ SAME, SAME, SAME, SAME, SAME, VALID, VALID, SAME, SAME, SAME, SAME, SAME,
+ SAME, SAME, SAME, SAME, VALID, VALID, VALID
+ ]
+ for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
+ paddings):
+ yield i, f, o, s, p
+
+
+def _GetTestConfigs():
+ """Get all the valid tests configs to run.
+
+ Returns:
+ all the valid test configs as tuples of data_format and use_gpu.
+ """
+ test_configs = [("NCHW", True), ("NHWC", True)]
+ return test_configs
+
+
+def _IotaNdF32Constant(dim_sizes):
+
+ def MakeList(dims):
+ if len(dims) == 1:
+ return [float(1 + f) for f in range(dims[0])]
+ return [MakeList(dims[1:]) for _ in range(dims[0])]
+
+ return constant_op.constant(MakeList(dim_sizes), dtype=dtypes.float32)
+
+
+def _GetInceptionFwdTest(input_size,
+ filter_size,
+ stride,
+ padding,
+ gpu_only=True):
+
+ def Test(self):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping InceptionFwd %s",
+ (input_size, filter_size, stride, padding))
+ return
+ tf_logging.info("Testing InceptionFwd %s",
+ (input_size, filter_size, stride, padding))
+ self.CompareFwdValues(input_size, filter_size, [stride, stride], padding)
+
+ return Test
+
+
+class FusedConv2DBiasActivationTest(object):
+
+ @contextlib.contextmanager
+ def test_scope(self): # pylint: disable=invalid-name
+ """Can be overridden in base classes to provide a test scope."""
+ yield
+
+ def _DtypesToTest(self, use_gpu):
+ return [dtypes.float32]
+
+ def _FilterFormatsToTest(self, use_gpu):
+ return ["HWIO", "OIHW"]
+
+ def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, bias,
+ strides, padding, activation_mode, data_format,
+ filter_format, dtype):
+ """Verifies the output values of the convolution function.
+
+ Args:
+ tensor_in_sizes: Input tensor dimensions in
+ [batch, input_rows, input_cols, input_depth].
+ filter_in_sizes: Filter tensor dimensions in
+ [kernel_rows, kernel_cols, input_depth, output_depth].
+ bias: 1-D bias tensor of length output_depth.
+ strides: Stride: [col_stride, row_stride]
+ padding: Padding type.
+ activation_mode: Activation mode.
+ data_format: Format of the data tensors.
+ filter_format: Filter format to use for the fused convolution.
+ dtype: Data type for inputs and outputs.
+ Returns:
+ Symbolic tensor value and reference value that can be used to
+ execute the computation and verify the results.
+ """
+ input_size = np.prod(tensor_in_sizes)
+ filter_size = np.prod(filter_in_sizes)
+ bias_size = filter_in_sizes[-1] # equals to output depth
+ # Initializes the input tensor with array containing incrementing
+ # numbers from 1.
+ x1 = [f * 1.0 for f in range(1, input_size + 1)]
+ x2 = [f * 1.0 for f in range(1, filter_size + 1)]
+ # This is to guarantee that there are always negative values after
+ # bias add so that we can test whether relu works correctly.
+ x3 = bias
+ with self.cached_session(use_gpu=True), self.test_scope():
+ t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
+ t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
+ fused_t2 = t2
+ if filter_format == "OIHW":
+ fused_t2 = _HwioToOihw(t2)
+ t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype)
+ strides = [1] + strides + [1]
+ if data_format == "NCHW":
+ t1 = test_util.NHWCToNCHW(t1)
+ strides = test_util.NHWCToNCHW(strides)
+ output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ t1,
+ fused_t2,
+ t3,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ filter_format=filter_format,
+ activation_mode=activation_mode)
+ ref_conv_output = nn_ops.conv2d(
+ t1, t2, strides=strides, padding=padding, data_format=data_format)
+ ref_bias_output = nn_ops.bias_add(
+ ref_conv_output, t3, data_format=data_format)
+ ref_output = nn_ops.relu(ref_bias_output)
+ if data_format == "NCHW":
+ output = test_util.NCHWToNHWC(output)
+ ref_output = test_util.NCHWToNHWC(ref_output)
+
+ return output, ref_output
+
+ def CompareFwdValues(self, tensor_in_sizes, filter_in_sizes, conv_strides,
+ padding):
+ """Verifies that CPU and GPU produce the same values.
+
+ Args:
+ tensor_in_sizes: Input tensor dimensions in
+ [batch, input_rows, input_cols, input_depth].
+ filter_in_sizes: Filter tensor dimensions in
+ [kernel_rows, kernel_cols, input_depth, output_depth].
+ conv_strides: [row_stride, col_stride] for the convolution;
+ padding: Padding type.
+ """
+ x1 = np.random.rand(*tensor_in_sizes).astype(np.float32)
+ x2 = np.random.rand(*filter_in_sizes).astype(np.float32)
+ x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32)
+
+ def _SetupVal(data_format, use_gpu):
+ with self.cached_session(use_gpu=use_gpu), self.test_scope():
+ t1 = constant_op.constant(x1, shape=tensor_in_sizes)
+ t2 = constant_op.constant(x2, shape=filter_in_sizes)
+ t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]])
+ strides = [1] + conv_strides + [1]
+ if data_format == "NCHW":
+ t1 = test_util.NHWCToNCHW(t1)
+ strides = test_util.NHWCToNCHW(strides)
+ output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ t1,
+ t2,
+ t3,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ activation_mode="Relu")
+
+ if data_format == "NCHW":
+ output = test_util.NCHWToNHWC(output)
+ return output
+
+ tensors = []
+ for (data_format, use_gpu) in _GetTestConfigs():
+ tensors.append(_SetupVal(data_format, use_gpu))
+ with self.cached_session() as sess, self.test_scope():
+ values = sess.run(tensors)
+ for i in range(1, len(values)):
+ self.assertAllClose(values[0], values[i], rtol=1e-3, atol=1e-3)
+
+ def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, bias, strides,
+ padding):
+ tensors = []
+ ref_tensors = []
+ for (data_format, use_gpu) in _GetTestConfigs():
+ for dtype in self._DtypesToTest(use_gpu):
+ for filter_format in self._FilterFormatsToTest(use_gpu):
+ result, expected = self._SetupValuesForDevice(
+ tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu",
+ data_format, filter_format, dtype)
+ tensors.append(result)
+ ref_tensors.append(expected)
+ with self.cached_session() as sess, self.test_scope():
+ values = sess.run(tensors)
+ ref_values = sess.run(ref_tensors)
+ for i in range(len(tensors)):
+ conv = tensors[i]
+ value = values[i]
+ ref_value = ref_values[i]
+ tf_logging.info("expected = %s", ref_value)
+ tf_logging.info("actual = %s", value)
+ tol = 1e-5
+ if value.dtype == np.float16:
+ tol = 1e-3
+ self.assertAllClose(
+ np.ravel(ref_value), np.ravel(value), atol=tol, rtol=tol)
+ self.assertShapeEqual(value, conv)
+
+ def testConv2D1x1Filter(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D1x1Filter test.")
+ return
+ # expected_output = [
+ # 0.0, 0.0, 0.0, 21.0, 0.0, 0.0, 57.0, 0.0, 0.0, 93.0, 41.0, 0.0, 129.0,
+ # 86.0, 43.0, 165.0, 131.0, 97.0
+ # ]
+ medians = [-45.0, -130.0, -215.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[1, 1, 3, 3],
+ bias=medians,
+ strides=[1, 1],
+ padding="VALID")
+
+ def testConv2DEmpty(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2DEmpty test.")
+ return
+ # expected_output = []
+ self._VerifyValues(
+ tensor_in_sizes=[0, 2, 3, 3],
+ filter_in_sizes=[1, 1, 3, 3],
+ bias=[0.0, 0.0, 0.0],
+ strides=[1, 1],
+ padding="VALID")
+
+ def testConv2D2x2Filter(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D2x2Filter test.")
+ return
+ # expected_output = [0.0, 0.0, 0.0, 401.0, 533.0, 665.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[2, 2, 3, 3],
+ bias=[-2500.0, -2500.0, -2500.0],
+ strides=[1, 1],
+ padding="VALID")
+
+ def testConv2D1x2Filter(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D1x2Filter test.")
+ return
+ # expected_output = [
+ # 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 190.0, 265.0, 340.0, 343.0, 436.0, 529.0
+ # ]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[1, 2, 3, 3],
+ bias=[-500.0, -500.0, -500.0],
+ strides=[1, 1],
+ padding="VALID")
+
+ def testConv2D2x2FilterStride2(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D2x2FilterStride2 test.")
+ return
+ # expected_output = [0.0, 67.0, 163.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[2, 2, 3, 3],
+ bias=[-2300.0, -2300.0, -2300.0],
+ strides=[2, 2],
+ padding="VALID")
+
+ def testConv2D2x2FilterStride2Same(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D2x2FilterStride2Same test.")
+ return
+ # expected_output = [0.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[2, 2, 3, 3],
+ bias=[-2300.0, -1000.0, -1000.0],
+ strides=[2, 2],
+ padding="SAME")
+
+ def testConv2D2x2FilterStride1x2(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2D2x2FilterStride1x2 test.")
+ return
+ # expected_output = [0.0, 0.0, 8.0, 28.0, 48.0, 68.0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 3, 6, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ bias=[-90.0],
+ strides=[1, 2],
+ padding="VALID")
+
+ def testConv2DKernelSmallerThanStrideValid(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2DKernelSmallerThanStrideValid test.")
+ return
+ # expected_output = [0, 0, 175, 205]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 7, 7, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ bias=[-100.0],
+ strides=[3, 3],
+ padding="VALID")
+
+ def testConv2DKernelSmallerThanStrideSame(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2DKernelSmallerThanStrideSame test.")
+ return
+ # expected = [0, 0, 2, 4]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 3, 3, 1],
+ filter_in_sizes=[1, 1, 1, 1],
+ bias=[-5.0],
+ strides=[2, 2],
+ padding="SAME")
+
+ # expected = [0, 0, 4, 6]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 4, 4, 1],
+ filter_in_sizes=[1, 1, 1, 1],
+ bias=[-5.0],
+ strides=[2, 2],
+ padding="SAME")
+
+ # expected = [4, 0, 1, 0]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 4, 4, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ bias=[-40.0],
+ strides=[3, 3],
+ padding="SAME")
+
+ def testConv2DKernelSizeMatchesInputSize(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping Conv2DKernelSizeMatchesInputSize test.")
+ return
+ # expected = [0, 5]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 2, 2, 1],
+ filter_in_sizes=[2, 2, 1, 2],
+ bias=[-50.0, -55.0],
+ strides=[1, 1],
+ padding="VALID")
+
+ # expected = [0, 2, 282, 322]
+ self._VerifyValues(
+ tensor_in_sizes=[1, 8, 8, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ bias=[-200.0],
+ strides=[4, 4],
+ padding="SAME")
+
+ def testShapeFunctionEdgeCases(self):
+ # All shapes unknown.
+ c1 = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu")
+ self.assertEqual([None, None, None, None], c1.get_shape().as_list())
+
+ # Incorrect input shape.
+ with self.assertRaises(ValueError):
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32, shape=[1, 3]),
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu")
+
+ # Incorrect filter shape.
+ with self.assertRaises(ValueError):
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32),
+ array_ops.placeholder(dtypes.float32, shape=[1, 3]),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu")
+
+ # Depth mismatch.
+ with self.assertRaises(ValueError):
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
+ array_ops.placeholder(dtypes.float32, shape=[4, 4, 2, 2]),
+ array_ops.placeholder(dtypes.float32),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu")
+
+ def testOpEdgeCases(self, gpu_only=True):
+ if gpu_only and not test.is_gpu_available():
+ tf_logging.info("Skipping OpEdgeCases tests.")
+ return
+ with self.cached_session() as sess, self.test_scope():
+ # Illegal strides.
+ with self.assertRaisesRegexp(
+ errors_impl.UnimplementedError,
+ ".*strides.*in the batch and depth dimensions"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ _IotaNdF32Constant([1, 1, 1, 1]),
+ _IotaNdF32Constant([1, 1, 1, 1]),
+ _IotaNdF32Constant([1]),
+ strides=[2, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Relu"))
+ with self.assertRaisesRegexp(
+ errors_impl.UnimplementedError,
+ ".*strides.*in the batch and depth dimensions"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ _IotaNdF32Constant([1, 1, 1, 1]),
+ _IotaNdF32Constant([1, 1, 1, 1]),
+ _IotaNdF32Constant([1]),
+ strides=[1, 1, 1, 2],
+ padding="SAME",
+ activation_mode="Relu"))
+
+ # Illegal activation mode.
+ with self.assertRaisesRegexp(ValueError,
+ "Op passed string 'Tanh' not in:"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ _IotaNdF32Constant([1, 1, 1, 1]),
+ _IotaNdF32Constant([1, 1, 1, 1]),
+ _IotaNdF32Constant([1]),
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ activation_mode="Tanh"))
+
+ # Filter larger than input.
+ with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ _IotaNdF32Constant([32, 20, 20, 3]),
+ _IotaNdF32Constant([20, 21, 3, 2]),
+ _IotaNdF32Constant([2]),
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ activation_mode="Relu"))
+ with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
+ sess.run(
+ fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ _IotaNdF32Constant([32, 20, 20, 3]),
+ _IotaNdF32Constant([21, 20, 3, 2]),
+ _IotaNdF32Constant([2]),
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ activation_mode="Relu"))
+
+
+# Add InceptionFwd tests to FusedConv2DBiasActivationTest.
+for index, (input_size_, filter_size_, output_size_, stride_,
+ padding_) in enumerate(_GetShrunkInceptionShapes()):
+ setattr(FusedConv2DBiasActivationTest, "testInceptionFwd_" + str(index),
+ _GetInceptionFwdTest(input_size_, filter_size_, stride_, padding_))
+
+# TODO(b/35359731)
+# Fwd, BckInput, and BackFilter to test that for certain input parameter
+# set, winograd nonfused algorithm will be excluded from conv autotune. If
+# in such case, winograd nonfused algorithm is added as one option of the
+# conv autotune, and cuDNN version is smaller than 7, the following tests
+# will fail.
+ishape = [1, 400, 400, 1]
+fshape = [1, 1, 1, 256]
+oshape = [1, 400, 400, 256]
+setattr(FusedConv2DBiasActivationTest, "testInceptionFwd_No_Winograd_Nonfused",
+ _GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True))
+
+
+def _CalculateConvolvedOutputDim(input_dim, filter_dim, stride, padding_type):
+ """Calculates the size of an output dimension of a strided convolution.
+
+ Given the sizes of the corresponding dimension of the input and filter shapes,
+ and the stride and padding_types, calculates the size of the output dimension.
+ This function can be called separately for each input dimension.
+
+ Args:
+ input_dim: An `int` specifying the size of the input dimension.
+ filter_dim: An `int` specifying the size of the filter dimension.
+ stride: An `int` specifying the step size of the convolution along the
+ input dimension.
+ padding_type: either 'VALID' or 'SAME'.
+
+ Returns:
+ The size of the output dimension.
+ """
+ if padding_type == "VALID":
+ return (input_dim - filter_dim + stride) // stride
+ else: # padding_type == 'SAME'
+ return (input_dim + stride - 1) // stride
+
+
+def _NchwVectCToNchw(in_tensor):
+ # [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W]
+ t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3])
+ n = in_tensor.shape.dims[0].value
+ c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
+ h = in_tensor.shape.dims[2].value
+ w = in_tensor.shape.dims[3].value
+ return array_ops.reshape(t, [n, c, h, w])
+
+
+def _OihwVectIToHwio(in_tensor):
+ # [O, I / 4, H, W, 4] => [O, I / 4, 4, H, W] == [O, I, H, W]
+ t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0])
+ o = in_tensor.shape.dims[0].value
+ i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value
+ h = in_tensor.shape.dims[2].value
+ w = in_tensor.shape.dims[3].value
+ return array_ops.reshape(t, [h, w, i, o])
+
+
+def _NchwToNchwVectC(in_tensor):
+ n, c, h, w = in_tensor.shape.as_list()
+ assert c % 4 == 0
+ t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w])
+ return array_ops.transpose(t, [0, 1, 3, 4, 2])
+
+
+def _HwioToOihw(in_tensor):
+ return array_ops.transpose(in_tensor, [3, 2, 0, 1])
+
+
+def _SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel,
+ padding, strides, side_input_scale,
+ side_input, biases, apply_relu):
+ """Simulates the int8 fused 2-D convolution op using separate float ops.
+
+ The arguments and return values have the same format, meanings and
+ restrictions as the actual op.
+ Args:
+ conv_input_scale: A scalar 'float'.
+ conv_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
+ kernel: A `Tensor` of type `qint8` in OIHW_VECT_I layout.
+ padding: A `string` from: `"SAME", "VALID"`.
+ strides: A list of `ints`.
+ side_input_scale: A scalar 'float'.
+ side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout.
+ biases: A `Tensor` of type `float32` in NCHW layout.
+ apply_relu: A boolean to specify whether to apply "Relu" activation function
+ that clips outputs to the range [0, 127], or "None" activation that clips
+ to the range [-128, 127].
+ Returns:
+ A `Tensor` of type `qint8` in NCHW_VECT_C layout.
+ """
+ conv_result = nn_ops.conv2d(
+ _NchwVectCToNchw(gen_array_ops.dequantize(conv_input, -128, 127)),
+ _OihwVectIToHwio(gen_array_ops.dequantize(kernel, -128, 127)),
+ strides=strides,
+ padding=padding,
+ data_format="NCHW") * conv_input_scale
+
+ conv_and_side_inputs = conv_result + side_input_scale * _NchwVectCToNchw(
+ gen_array_ops.dequantize(side_input, -128, 127))
+
+ output = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW")
+ if apply_relu:
+ output = nn_ops.relu(output)
+
+ result, _, _ = gen_array_ops.quantize_v2(
+ _NchwToNchwVectC(output), -128, 127, dtypes.qint8)
+ return result
+
+
+# TODO(b/114580749): XLA:CPU/GPU don't support int8 at the moment, so this test
+# doesn't currently use XLA.
+class FusedConvInt8Tests(object):
+ _test_params = [
+ {
+ "batch_size": 1,
+ "input_channels": 4,
+ "output_channels": 4,
+ "input_height": 8,
+ "input_width": 8,
+ "filter_height": 6,
+ "filter_width": 6,
+ "vertical_stride": 2,
+ "horizontal_stride": 2,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.0,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 1,
+ "input_channels": 4,
+ "output_channels": 4,
+ "input_height": 6,
+ "input_width": 6,
+ "filter_height": 6,
+ "filter_width": 6,
+ "vertical_stride": 2,
+ "horizontal_stride": 2,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.0,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 2,
+ "input_channels": 8,
+ "output_channels": 16,
+ "input_height": 8,
+ "input_width": 8,
+ "filter_height": 3,
+ "filter_width": 3,
+ "vertical_stride": 2,
+ "horizontal_stride": 2,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.0,
+ "bias_scale": 1,
+ "padding_type": "VALID"
+ },
+ {
+ "batch_size": 2,
+ "input_channels": 8,
+ "output_channels": 16,
+ "input_height": 8,
+ "input_width": 8,
+ "filter_height": 3,
+ "filter_width": 3,
+ "vertical_stride": 2,
+ "horizontal_stride": 2,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.0,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 2,
+ "input_channels": 8,
+ "output_channels": 16,
+ "input_height": 8,
+ "input_width": 8,
+ "filter_height": 3,
+ "filter_width": 3,
+ "vertical_stride": 2,
+ "horizontal_stride": 2,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "VALID"
+ },
+ {
+ "batch_size": 2,
+ "input_channels": 16,
+ "output_channels": 16,
+ "input_height": 9,
+ "input_width": 9,
+ "filter_height": 3,
+ "filter_width": 3,
+ "vertical_stride": 1,
+ "horizontal_stride": 1,
+ "conv_input_scale": 0.001,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 3,
+ "input_channels": 8,
+ "output_channels": 8,
+ "input_height": 9,
+ "input_width": 9,
+ "filter_height": 5,
+ "filter_width": 5,
+ "vertical_stride": 1,
+ "horizontal_stride": 1,
+ "conv_input_scale": 0.001,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 3,
+ "input_channels": 8,
+ "output_channels": 8,
+ "input_height": 9,
+ "input_width": 9,
+ "filter_height": 7,
+ "filter_width": 1,
+ "vertical_stride": 2,
+ "horizontal_stride": 1,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ {
+ "batch_size": 3,
+ "input_channels": 8,
+ "output_channels": 8,
+ "input_height": 9,
+ "input_width": 9,
+ "filter_height": 1,
+ "filter_width": 7,
+ "vertical_stride": 1,
+ "horizontal_stride": 1,
+ "conv_input_scale": 0.002,
+ "side_input_scale": 0.5,
+ "bias_scale": 1,
+ "padding_type": "SAME"
+ },
+ ]
+
+ @contextlib.contextmanager
+ def test_scope(self): # pylint: disable=invalid-name
+ """Can be overridden in base classes to provide a test scope."""
+ yield
+
+ def runTest(self, test_param, apply_relu):
+ batch_size = test_param["batch_size"]
+ input_channels = test_param["input_channels"]
+ output_channels = test_param["output_channels"]
+ input_height = test_param["input_height"]
+ input_width = test_param["input_width"]
+ filter_height = test_param["filter_height"]
+ filter_width = test_param["filter_width"]
+ vertical_stride = test_param["vertical_stride"]
+ horizontal_stride = test_param["horizontal_stride"]
+ conv_input_scale = test_param["conv_input_scale"]
+ side_input_scale = test_param["side_input_scale"]
+ bias_scale = test_param["bias_scale"]
+ padding_type = test_param["padding_type"]
+
+ with self.cached_session(use_gpu=True) as sess, self.test_scope():
+ conv_input, _, _ = gen_array_ops.quantize_v2(
+ random_ops.random_uniform(
+ [batch_size, input_channels // 4, input_height, input_width, 4],
+ minval=-0.0,
+ maxval=1.0,
+ dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8)
+
+ kernel, _, _ = gen_array_ops.quantize_v2(
+ random_ops.random_uniform([
+ output_channels, input_channels // 4, filter_height, filter_width,
+ 4
+ ],
+ minval=-1.0,
+ maxval=1.0,
+ dtype=dtypes.float32), -1.0, 1.0,
+ dtypes.qint8)
+
+ output_height = _CalculateConvolvedOutputDim(
+ input_height, filter_height, vertical_stride, padding_type)
+ output_width = _CalculateConvolvedOutputDim(
+ input_width, filter_width, horizontal_stride, padding_type)
+ tf_logging.info("output_height=%s, output_width=%s", output_height,
+ output_width)
+
+ side_input, _, _ = gen_array_ops.quantize_v2(
+ random_ops.random_uniform([
+ batch_size, output_channels // 4, output_height, output_width, 4
+ ],
+ minval=0.0,
+ maxval=1.0,
+ dtype=dtypes.float32), -1.0, 1.0,
+ dtypes.qint8)
+
+ biases = random_ops.random_uniform([output_channels],
+ minval=-10 * bias_scale,
+ maxval=20 * bias_scale,
+ dtype=dtypes.float32)
+
+ strides = [1, 1, vertical_stride, horizontal_stride]
+
+ actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation(
+ conv_input,
+ kernel,
+ biases,
+ strides=strides,
+ padding=padding_type,
+ conv_input_scale=conv_input_scale,
+ side_input_scale=side_input_scale,
+ side_input=side_input,
+ activation_mode="Relu" if apply_relu else "None",
+ data_format="NCHW_VECT_C",
+ filter_format="OIHW_VECT_I")
+
+ expected = _SimulateFusedConv2dBiasActivationInt8(
+ conv_input_scale, conv_input, kernel, padding_type, strides,
+ side_input_scale, side_input, biases, apply_relu)
+
+ actual_y, expected_y = sess.run([actual, expected])
+ self.assertAllClose(actual_y, expected_y, rtol=0, atol=1)
+
+ def testFusedConvInt8(self):
+ if not test.is_gpu_available(
+ cuda_only=True, min_cuda_compute_capability=(6, 1)):
+ tf_logging.info("int8 test skipped because not run with --config=cuda or "
+ "no GPUs with compute capability >= 6.1 are available.")
+ return
+ for apply_relu in [True, False]:
+ for test_param in self._test_params:
+ self.runTest(test_param, apply_relu)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py
index 97f38c923f..0ebcdc2688 100644
--- a/tensorflow/contrib/graph_editor/tests/transform_test.py
+++ b/tensorflow/contrib/graph_editor/tests/transform_test.py
@@ -214,7 +214,7 @@ class TransformTest(test.TestCase):
def test_graph_replace_gradients(self):
ops.reset_default_graph()
- w = variables.Variable(0.0, name="w")
+ w = variables.VariableV1(0.0, name="w")
y = math_ops.multiply(math_ops.multiply(w, w, name="mul1"), w, name="mul2")
g = gradients_impl.gradients(y, w, name="grad")[0]
diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
index 6e0e628655..bf398b838d 100644
--- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
+++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
@@ -19,14 +19,14 @@ from __future__ import print_function
from tensorflow.contrib.hadoop.python.ops import gen_dataset_ops
from tensorflow.contrib.hadoop.python.ops import hadoop_op_loader # pylint: disable=unused-import
-from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-class SequenceFileDataset(Dataset):
+class SequenceFileDataset(dataset_ops.DatasetSource):
"""A Sequence File Dataset that reads the sequence file."""
def __init__(self, filenames):
diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
index a1624614d1..7129f09e8b 100644
--- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
+++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
@@ -17,15 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import
from tensorflow.contrib.kafka.python.ops import gen_dataset_ops
-from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-class KafkaDataset(Dataset):
+class KafkaDataset(dataset_ops.DatasetSource):
"""A Kafka Dataset that consumes the message.
"""
diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
index ca2df95ba4..75806dbbeb 100644
--- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
+++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
@@ -17,15 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import
from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops
-from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-class KinesisDataset(Dataset):
+class KinesisDataset(dataset_ops.DatasetSource):
"""A Kinesis Dataset that consumes the message.
Kinesis is a managed service provided by AWS for data streaming.
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py
index 17ee8c0733..60e1d85ea9 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py
@@ -112,11 +112,9 @@ def safe_embedding_lookup_sparse(embedding_weights,
dtype = sparse_weights.dtype if sparse_weights is not None else None
if isinstance(embedding_weights, variables.PartitionedVariable):
embedding_weights = list(embedding_weights)
- if not isinstance(embedding_weights[0],
- resource_variable_ops.ResourceVariable):
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
contrib_tensor_util.assert_same_float_dtype(embedding_weights +
[sparse_weights])
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 85af9de4e4..3b7ae72e9c 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -2360,7 +2360,7 @@ class BatchNormTest(test.TestCase):
batch_size * height * width, expected_var)
images = constant_op.constant(
image_values, shape=image_shape, dtype=dtypes.float32)
- is_training = variables_lib.Variable(True)
+ is_training = variables_lib.VariableV1(True)
output = _layers.batch_norm(
images,
decay=0.1,
@@ -2507,7 +2507,7 @@ class BatchNormTest(test.TestCase):
batch_size * height * width, expected_var)
images = constant_op.constant(
image_values, shape=image_shape, dtype=dtypes.float32)
- is_training = variables_lib.Variable(True)
+ is_training = variables_lib.VariableV1(True)
output = _layers.batch_norm(
images,
decay=0.1,
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
index 33180b778a..a160cb54a3 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
@@ -162,9 +162,9 @@ class GraphActionsTest(test.TestCase):
Tuple of 3 `Tensor` objects, 2 input and 1 output.
"""
variables_lib.create_global_step()
- in0 = variables.Variable(1.0)
+ in0 = variables.VariableV1(1.0)
in1 = variables_lib.local_variable(2.0)
- fake_table = variables.Variable(
+ fake_table = variables.VariableV1(
3.0,
trainable=False,
collections=['fake_tables'],
@@ -312,8 +312,8 @@ class GraphActionsTest(test.TestCase):
def test_evaluate_ready_for_local_init(self):
with ops.Graph().as_default() as g, self.session(g):
variables_lib.create_global_step()
- v = variables.Variable(1.0)
- variables.Variable(
+ v = variables.VariableV1(1.0)
+ variables.VariableV1(
v + 1, collections=[ops.GraphKeys.LOCAL_VARIABLES], trainable=False)
ready_for_local_init_op = variables.report_uninitialized_variables(
variables.global_variables())
@@ -456,9 +456,9 @@ class GraphActionsTrainTest(test.TestCase):
Tuple of 3 `Tensor` objects, 2 input and 1 output.
"""
variables_lib.create_global_step()
- in0 = variables.Variable(1.0)
+ in0 = variables.VariableV1(1.0)
in1 = variables_lib.local_variable(2.0)
- fake_table = variables.Variable(
+ fake_table = variables.VariableV1(
3.0,
trainable=False,
collections=['fake_tables'],
diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py
index 83e48a36e7..d4a7169bb6 100644
--- a/tensorflow/contrib/learn/python/learn/monitors_test.py
+++ b/tensorflow/contrib/learn/python/learn/monitors_test.py
@@ -247,7 +247,7 @@ class MonitorsTest(test.TestCase):
def test_logging_trainable(self):
with ops.Graph().as_default() as g, self.session(g):
- var = variables.Variable(constant_op.constant(42.0), name='foo')
+ var = variables.VariableV1(constant_op.constant(42.0), name='foo')
var.initializer.run()
cof = constant_op.constant(1.0)
loss = math_ops.subtract(
@@ -261,7 +261,7 @@ class MonitorsTest(test.TestCase):
with ops.Graph().as_default() as g, self.session(g):
log_dir = 'log/dir'
summary_writer = testing.FakeSummaryWriter(log_dir, g)
- var = variables.Variable(0.0)
+ var = variables.VariableV1(0.0)
var.initializer.run()
tensor = state_ops.assign_add(var, 1.0)
summary_op = summary.scalar('my_summary', tensor)
@@ -526,8 +526,8 @@ class MonitorsTest(test.TestCase):
monitor0 = learn.monitors.GraphDump()
monitor1 = learn.monitors.GraphDump()
with ops.Graph().as_default() as g, self.session(g):
- const_var = variables.Variable(42.0, name='my_const')
- counter_var = variables.Variable(0.0, name='my_counter')
+ const_var = variables.VariableV1(42.0, name='my_const')
+ counter_var = variables.VariableV1(0.0, name='my_counter')
assign_add = state_ops.assign_add(counter_var, 1.0, name='my_assign_add')
variables.global_variables_initializer().run()
@@ -569,7 +569,7 @@ class MonitorsTest(test.TestCase):
monitor = learn.monitors.CaptureVariable(
var_name='my_assign_add:0', every_n=8, first_n=2)
with ops.Graph().as_default() as g, self.session(g):
- var = variables.Variable(0.0, name='my_var')
+ var = variables.VariableV1(0.0, name='my_var')
var.initializer.run()
state_ops.assign_add(var, 1.0, name='my_assign_add')
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index 9ecf023e03..8466dc36d1 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -125,7 +125,7 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero):
],
example_ids=[str(i) for i in range(num_examples)])
- weights = variables_lib.Variable(
+ weights = variables_lib.VariableV1(
array_ops.zeros([dim], dtype=dtypes.float32))
variables_dict = dict(
sparse_features_weights=[weights],
@@ -184,7 +184,7 @@ def make_dense_examples_and_variables_dicts(dense_features_values, weights,
dense_tensors.append(dense_tensor)
# Add variables of shape [feature_column_dimension].
dense_weights.append(
- variables_lib.Variable(
+ variables_lib.VariableV1(
array_ops.zeros(
[dense_tensor.get_shape().as_list()[1]], dtype=dtypes.float32)))
@@ -341,7 +341,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
examples = make_example_dict(example_protos, example_weights)
# Explicitly make age a [1]-shaped Variable (which cannot be
# partitioned), while making gender a PartitionedVariable.
- age_weights = variables_lib.Variable(
+ age_weights = variables_lib.VariableV1(
array_ops.zeros([1], dtype=dtypes.float32))
with variable_scope.variable_scope(
name_or_scope=('variables/shard_{}'.format(num_shards)
@@ -801,7 +801,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
labels=[1.0, 0.0])
# Replace with a variable of size 1 instead of 2.
variables['dense_features_weights'] = [
- variables_lib.Variable(array_ops.zeros(
+ variables_lib.VariableV1(array_ops.zeros(
[1], dtype=dtypes.float32))
]
options = dict(
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index f320b53d94..f3ebe3b245 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -26,6 +26,14 @@ config_setting(
},
)
+# Enables inclusion of TensorFlow kernels via the TF Lite Flex delegate.
+# WARNING: This build flag is experimental and subject to change.
+config_setting(
+ name = "with_tflite_flex",
+ define_values = {"with_tflite_flex": "true"},
+ visibility = ["//visibility:public"],
+)
+
cc_library(
name = "schema_fbs_version",
hdrs = ["version.h"],
@@ -157,6 +165,10 @@ cc_library(
"stderr_reporter.h",
],
copts = tflite_copts(),
+ defines = select({
+ ":with_tflite_flex": ["TFLITE_FLEX"],
+ "//conditions:default": [],
+ }),
linkopts = [
] + select({
"//tensorflow:android": [
@@ -180,7 +192,12 @@ cc_library(
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
"//tensorflow/contrib/lite/profiling:profiler",
"//tensorflow/contrib/lite/schema:schema_fbs",
- ],
+ ] + select({
+ ":with_tflite_flex": [
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
+ ],
+ "//conditions:default": [],
+ }),
)
cc_library(
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index fc4d9b4f17..7ef26de69f 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -301,7 +301,7 @@ def generated_test_conversion_modes():
"""Returns a list of conversion modes."""
# TODO(nupurgarg): Add "pb2lite" when it's in open source. b/113614050.
- return ["toco-extended", ""]
+ return ["toco-flex", ""]
def generated_test_models_all():
"""Generates a list of all tests with the different converters.
@@ -335,7 +335,7 @@ def gen_zip_test(name, test_name, conversion_mode, **kwargs):
# TODO(nupurgarg): Comment in when pb2lite is in open source. b/113614050.
# if conversion_mode == "pb2lite":
# toco = "//tensorflow/contrib/lite/experimental/pb2lite:pb2lite"
- flags = "--ignore_toco_errors --run_with_extended"
+ flags = "--ignore_toco_errors --run_with_flex"
kwargs["tags"].append("skip_already_failing")
kwargs["tags"].append("no_oss")
kwargs["tags"].append("notap")
@@ -391,3 +391,41 @@ def gen_selected_ops(name, model):
(tool, model, out, tflite_path[2:]),
tools = [tool],
)
+
+def gen_full_model_test(conversion_modes, models, data, test_suite_tag):
+ """Generates Python test targets for testing TFLite models.
+
+ Args:
+ conversion_modes: List of conversion modes to test the models on.
+ models: List of models to test.
+ data: List of BUILD targets linking the data.
+ test_suite_tag: Tag identifying the model test suite.
+ """
+ options = [
+ (conversion_mode, model)
+ for model in models
+ for conversion_mode in conversion_modes
+ ]
+
+ for conversion_mode, model_name in options:
+ native.py_test(
+ name = "model_coverage_test_%s_%s" % (model_name, conversion_mode.lower()),
+ srcs = ["model_coverage_test.py"],
+ main = "model_coverage_test.py",
+ args = [
+ "--model_name=%s" % model_name,
+ "--converter_mode=%s" % conversion_mode,
+ ],
+ data = data,
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_windows",
+ "notap",
+ ] + [test_suite_tag],
+ deps = [
+ "//tensorflow/contrib/lite/testing:model_coverage_lib",
+ "//tensorflow/contrib/lite/python:lite",
+ "//tensorflow/python:client_testlib",
+ ],
+ )
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/flex/BUILD
index bf5d91899c..bf5d91899c 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/flex/BUILD
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc b/tensorflow/contrib/lite/delegates/flex/buffer_map.cc
index e5a19c3997..63e39196d9 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc
+++ b/tensorflow/contrib/lite/delegates/flex/buffer_map.cc
@@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
+#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h"
#include "tensorflow/c/c_api_internal.h"
-#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/contrib/lite/delegates/flex/util.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/log_memory.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace {
// A tensor buffer that is allocated, deallocated and populated by TF Lite.
class TfLiteTensorBuffer : public tensorflow::TensorBuffer {
@@ -107,5 +107,5 @@ void BufferMap::SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor) {
id_to_tensor_[tensor_index] = std::move(tensor);
}
-} // namespace eager
+} // namespace flex
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/flex/buffer_map.h
index aaaa045840..4ce886568a 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map.h
+++ b/tensorflow/contrib/lite/delegates/flex/buffer_map.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_
#include <map>
@@ -21,12 +21,12 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
namespace tflite {
-namespace eager {
+namespace flex {
// Maps a TF Lite tensor index into a TensorFlow tensor.
//
// The TF Lite interpreter assigns integer indices to each of its tensors, but
-// the Eager delegate deals in terms of TensorFlow tensors. This class maps
+// the Flex delegate deals in terms of TensorFlow tensors. This class maps
// from indices to tensors and allows the creation of new tensors to be
// associated with a given index.
class BufferMap {
@@ -55,7 +55,7 @@ class BufferMap {
std::map<int, tensorflow::Tensor> id_to_tensor_;
};
-} // namespace eager
+} // namespace flex
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc b/tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc
index a046943e56..bb80e25e80 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc
+++ b/tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
+#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/util.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace {
using ::testing::ElementsAre;
@@ -164,7 +164,7 @@ TEST(BufferMapTest, TensorFlowOverwritesTfLite) {
}
} // namespace
-} // namespace eager
+} // namespace flex
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.cc b/tensorflow/contrib/lite/delegates/flex/delegate.cc
index 45fc158157..ba065a8ff5 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.cc
+++ b/tensorflow/contrib/lite/delegates/flex/delegate.cc
@@ -12,19 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#include <vector>
#include "tensorflow/contrib/lite/context_util.h"
-#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
-#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
-#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h"
+#include "tensorflow/contrib/lite/delegates/flex/kernel.h"
+#include "tensorflow/contrib/lite/delegates/flex/util.h"
#include "tensorflow/contrib/lite/util.h"
#include "tensorflow/core/lib/core/status.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace delegate {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
@@ -32,7 +32,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
TfLiteIntArray* plan;
TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
- // Add all custom ops starting with "Eager" to list of supported nodes.
+ // Add all custom ops starting with "Flex" to list of supported nodes.
std::vector<int> supported_nodes;
for (int node_index : TfLiteIntArrayView(plan)) {
TfLiteNode* node;
@@ -40,7 +40,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) {
TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
context, node_index, &node, &registration));
- if (IsEagerOp(registration->custom_name)) {
+ if (IsFlexOp(registration->custom_name)) {
supported_nodes.push_back(node_index);
}
}
@@ -81,28 +81,28 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
}
} // namespace delegate
-} // namespace eager
+} // namespace flex
-std::unique_ptr<EagerDelegate> EagerDelegate::Create() {
- std::unique_ptr<eager::DelegateData> delegate_data;
- if (!eager::DelegateData::Create(&delegate_data).ok()) {
+std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
+ std::unique_ptr<flex::DelegateData> delegate_data;
+ if (!flex::DelegateData::Create(&delegate_data).ok()) {
fprintf(stderr, "Unable to initialize TensorFlow context.\n");
return nullptr;
}
- return std::unique_ptr<EagerDelegate>(
- new EagerDelegate(std::move(delegate_data)));
+ return std::unique_ptr<FlexDelegate>(
+ new FlexDelegate(std::move(delegate_data)));
}
-EagerDelegate::EagerDelegate(std::unique_ptr<eager::DelegateData> delegate_data)
+FlexDelegate::FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data)
: TfLiteDelegate{
/*data_=*/delegate_data.get(),
- /*nullptr,*/ &eager::delegate::Prepare,
- /*CopyFromBufferHandle=*/&eager::delegate::CopyFromBufferHandle,
+ /*nullptr,*/ &flex::delegate::Prepare,
+ /*CopyFromBufferHandle=*/&flex::delegate::CopyFromBufferHandle,
/*CopyToBufferHandle=*/nullptr,
/*FreeBufferHandle=*/nullptr},
delegate_data_(std::move(delegate_data)) {}
-EagerDelegate::~EagerDelegate() {}
+FlexDelegate::~FlexDelegate() {}
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/flex/delegate.h
index 70f3c15af4..1017780dc7 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.h
+++ b/tensorflow/contrib/lite/delegates/flex/delegate.h
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_
#include "tensorflow/contrib/lite/c/c_api_internal.h"
-#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h"
namespace tflite {
@@ -24,12 +24,12 @@ namespace tflite {
// Delegate that can be used to extract parts of a graph that are designed to be
// executed by TensorFlow's runtime via Eager.
//
-// The interpreter must be constructed after the EagerDelegate and destructed
-// before the EagerDelegate. This delegate may be used with multiple
+// The interpreter must be constructed after the FlexDelegate and destructed
+// before the FlexDelegate. This delegate may be used with multiple
// interpreters, but it is *not* thread-safe.
//
// Usage:
-// auto delegate = EagerDelegate::Create();
+// auto delegate = FlexDelegate::Create();
// ... build interpreter ...
//
// if (delegate) {
@@ -39,21 +39,21 @@ namespace tflite {
// ... run inference ...
// ... destroy interpreter ...
// ... destroy delegate ...
-class EagerDelegate : public TfLiteDelegate {
+class FlexDelegate : public TfLiteDelegate {
public:
// Creates a delegate that supports TF ops.
//
- // If the underyling TF Eager context creation fails, returns null.
- static std::unique_ptr<EagerDelegate> Create();
+ // If the underyling TF Flex context creation fails, returns null.
+ static std::unique_ptr<FlexDelegate> Create();
- ~EagerDelegate();
+ ~FlexDelegate();
private:
- explicit EagerDelegate(std::unique_ptr<eager::DelegateData> delegate_data);
+ explicit FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data);
- std::unique_ptr<eager::DelegateData> delegate_data_;
+ std::unique_ptr<flex::DelegateData> delegate_data_;
};
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc b/tensorflow/contrib/lite/delegates/flex/delegate_data.cc
index 0fd5c976f8..8f985f770c 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc
+++ b/tensorflow/contrib/lite/delegates/flex/delegate_data.cc
@@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/lib/core/status.h"
namespace tflite {
-namespace eager {
+namespace flex {
tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
std::vector<tensorflow::Device*> devices;
@@ -43,5 +43,5 @@ DelegateData::DelegateData(tensorflow::EagerContext* eager_context)
DelegateData::~DelegateData() {}
-} // namespace eager
+} // namespace flex
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.h b/tensorflow/contrib/lite/delegates/flex/delegate_data.h
index 772d26f44e..8d75f0b0ef 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data.h
+++ b/tensorflow/contrib/lite/delegates/flex/delegate_data.h
@@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
-#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
+#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h"
#include "tensorflow/core/common_runtime/eager/context.h"
namespace tflite {
-namespace eager {
+namespace flex {
-// Data kept by the Eager delegate for the lifetime of an Interpreter.
+// Data kept by the Flex delegate for the lifetime of an Interpreter.
class DelegateData {
public:
// Create a new DelegateData, initialized with a newly-created EagerContext.
@@ -29,7 +29,7 @@ class DelegateData {
~DelegateData();
- // The EagerContext that is required for execution of Eager Ops.
+ // The EagerContext that is required for execution of Flex Ops.
tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); }
// Map from TF Lite tensor index to TensorFlow tensor for a given context.
@@ -46,7 +46,7 @@ class DelegateData {
std::unordered_map<const TfLiteContext*, BufferMap> buffer_map_;
};
-} // namespace eager
+} // namespace flex
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc
index def063309f..30b10f435a 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
+++ b/tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace {
TEST(DelegateDataTest, Basic) {
@@ -39,7 +39,7 @@ TEST(DelegateDataTest, Basic) {
}
} // namespace
-} // namespace eager
+} // namespace flex
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/flex/delegate_test.cc
index 43ec5d53b8..1813952cef 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/flex/delegate_test.cc
@@ -12,23 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/delegates/eager/test_util.h"
+#include "tensorflow/contrib/lite/delegates/flex/test_util.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace {
using ::testing::ContainsRegex;
using ::testing::ElementsAre;
-class DelegateTest : public testing::EagerModelTest {
+class DelegateTest : public testing::FlexModelTest {
public:
DelegateTest() {
- delegate_ = EagerDelegate::Create();
+ delegate_ = FlexDelegate::Create();
interpreter_.reset(new Interpreter(&error_reporter_));
}
@@ -46,7 +46,7 @@ class DelegateTest : public testing::EagerModelTest {
}
private:
- std::unique_ptr<EagerDelegate> delegate_;
+ std::unique_ptr<FlexDelegate> delegate_;
};
TEST_F(DelegateTest, FullGraph) {
@@ -236,7 +236,7 @@ TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
}
} // namespace
-} // namespace eager
+} // namespace flex
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/flex/kernel.cc
index 48a2f56baf..e4f1aea990 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.cc
+++ b/tensorflow/contrib/lite/delegates/flex/kernel.cc
@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
+#include "tensorflow/contrib/lite/delegates/flex/kernel.h"
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/builtin_ops.h"
#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
-#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
-#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h"
+#include "tensorflow/contrib/lite/delegates/flex/util.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/string.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@@ -28,10 +28,10 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
-// Note: this is part of TF Lite's Eager delegation code which is to be
+// Note: this is part of TF Lite's Flex delegation code which is to be
// completed soon.
-// This is the TF Lite op that is created by the eager delegate to handle
+// This is the TF Lite op that is created by the flex delegate to handle
// execution of a supported subgraph. The usual flow is that the delegate
// informs the interpreter of supported nodes in a graph, and each supported
// subgraph is replaced with one instance of this kernel.
@@ -46,7 +46,7 @@ limitations under the License.
// corresponding TensorFlow/Eager Op.
namespace tflite {
-namespace eager {
+namespace flex {
namespace kernel {
// Controls the lifetime of tensor handles in a vector.
@@ -72,11 +72,11 @@ class VectorOfHandles {
// Executes the TensorFlow op given by 'op_name', with the attributes specified
// in 'nodedef'. Inputs and outputs are given as indices into the 'buffer_map'.
-tensorflow::Status ExecuteEagerOp(tensorflow::EagerContext* eager_context,
- BufferMap* buffer_map, const string& op_name,
- const tensorflow::NodeDef& nodedef,
- const std::vector<int>& inputs,
- const std::vector<int>& outputs) {
+tensorflow::Status ExecuteFlexOp(tensorflow::EagerContext* eager_context,
+ BufferMap* buffer_map, const string& op_name,
+ const tensorflow::NodeDef& nodedef,
+ const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
const tensorflow::AttrTypeMap* attr_types;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
tensorflow::AttrTypeMapForOp(op_name.c_str(), &attr_types),
@@ -258,13 +258,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Execute the TensorFlow Ops sequentially.
for (const auto& node_data : op_data->nodes) {
if (node_data.nodedef.op().empty()) {
- context->ReportError(context, "Invalid NodeDef in Eager op '%s'",
+ context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
node_data.name.c_str());
return kTfLiteError;
}
auto status =
- ExecuteEagerOp(eager_context, buffer_map, node_data.name,
- node_data.nodedef, node_data.inputs, node_data.outputs);
+ ExecuteFlexOp(eager_context, buffer_map, node_data.name,
+ node_data.nodedef, node_data.inputs, node_data.outputs);
TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
}
@@ -295,5 +295,5 @@ TfLiteRegistration GetKernel() {
return registration;
}
-} // namespace eager
+} // namespace flex
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/flex/kernel.h
index 2478abccaa..ac9313a37b 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.h
+++ b/tensorflow/contrib/lite/delegates/flex/kernel.h
@@ -12,23 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_
#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
-namespace eager {
+namespace flex {
// Return the registration object used to initialize and execute ops that will
// be delegated to TensorFlow's Eager runtime. This TF Lite op is created by
-// the eager delegate to handle execution of a supported subgraph. The usual
+// the flex delegate to handle execution of a supported subgraph. The usual
// flow is that the delegate informs the interpreter of supported nodes in a
// graph, and each supported subgraph is replaced with one instance of this
// kernel.
TfLiteRegistration GetKernel();
-} // namespace eager
+} // namespace flex
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc b/tensorflow/contrib/lite/delegates/flex/kernel_test.cc
index 66f2226626..94a6f8b61a 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel_test.cc
+++ b/tensorflow/contrib/lite/delegates/flex/kernel_test.cc
@@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
+#include "tensorflow/contrib/lite/delegates/flex/kernel.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
-#include "tensorflow/contrib/lite/delegates/eager/test_util.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h"
+#include "tensorflow/contrib/lite/delegates/flex/test_util.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace {
using ::testing::ContainsRegex;
@@ -31,12 +31,12 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate,
TfLiteIntArray* size_and_nodes =
ConvertVectorToTfLiteIntArray(supported_nodes);
TF_LITE_ENSURE_STATUS(context->ReplaceSubgraphsWithDelegateKernels(
- context, eager::GetKernel(), size_and_nodes, delegate));
+ context, flex::GetKernel(), size_and_nodes, delegate));
TfLiteIntArrayFree(size_and_nodes);
return kTfLiteOk;
}
-class KernelTest : public testing::EagerModelTest {
+class KernelTest : public testing::FlexModelTest {
public:
KernelTest() {
CHECK(DelegateData::Create(&delegate_data_).ok());
@@ -167,7 +167,7 @@ TEST_F(KernelTest, WrongSetOfNodes) {
ASSERT_FALSE(Invoke());
ASSERT_THAT(error_reporter().error_messages(),
- ContainsRegex("Invalid NodeDef in Eager op"));
+ ContainsRegex("Invalid NodeDef in Flex op"));
}
TEST_F(KernelTest, MixedGraph) {
@@ -220,7 +220,7 @@ TEST_F(KernelTest, SplitGraph) {
}
} // namespace
-} // namespace eager
+} // namespace flex
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/flex/test_util.cc
index d47be761fb..69c336a01a 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.cc
+++ b/tensorflow/contrib/lite/delegates/flex/test_util.cc
@@ -13,25 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/test_util.h"
+#include "tensorflow/contrib/lite/delegates/flex/test_util.h"
#include "absl/memory/memory.h"
#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace testing {
-bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
+bool FlexModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
-void EagerModelTest::SetShape(int tensor_index,
- const std::vector<int>& values) {
+void FlexModelTest::SetShape(int tensor_index, const std::vector<int>& values) {
ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
}
-std::vector<int> EagerModelTest::GetShape(int tensor_index) {
+std::vector<int> FlexModelTest::GetShape(int tensor_index) {
std::vector<int> result;
auto* dims = interpreter_->tensor(tensor_index)->dims;
result.reserve(dims->size);
@@ -41,13 +40,13 @@ std::vector<int> EagerModelTest::GetShape(int tensor_index) {
return result;
}
-TfLiteType EagerModelTest::GetType(int tensor_index) {
+TfLiteType FlexModelTest::GetType(int tensor_index) {
return interpreter_->tensor(tensor_index)->type;
}
-void EagerModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
- const std::vector<int>& outputs,
- TfLiteType type, const std::vector<int>& dims) {
+void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
+ const std::vector<int>& outputs, TfLiteType type,
+ const std::vector<int>& dims) {
interpreter_->AddTensors(num_tensors);
for (int i = 0; i < num_tensors; ++i) {
TfLiteQuantizationParams quant;
@@ -66,8 +65,8 @@ void EagerModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
}
-void EagerModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
- const std::vector<int>& outputs) {
+void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
reg.builtin_code = BuiltinOperator_MUL;
reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
@@ -90,8 +89,8 @@ void EagerModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
kTfLiteOk);
}
-void EagerModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
- const std::vector<int>& outputs) {
+void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
auto attr = [](const string& key, const string& value) {
return " attr{ key: '" + key + "' value {" + value + "}}";
};
@@ -107,28 +106,28 @@ void EagerModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
if (op == kUnpack) {
string attributes =
type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
- AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs);
+ AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs);
} else if (op == kIdentity) {
string attributes = type_attribute;
- AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs);
+ AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs);
} else if (op == kAdd) {
string attributes = type_attribute;
- AddTfOp("EagerAdd", "Add", attributes, inputs, outputs);
+ AddTfOp("FlexAdd", "Add", attributes, inputs, outputs);
} else if (op == kMul) {
string attributes = type_attribute;
- AddTfOp("EagerMul", "Mul", attributes, inputs, outputs);
+ AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
} else if (op == kNonExistent) {
AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
} else if (op == kIncompatibleNodeDef) {
// "Cast" op is created without attributes - making it incompatible.
- AddTfOp("EagerCast", "Cast", "", inputs, outputs);
+ AddTfOp("FlexCast", "Cast", "", inputs, outputs);
}
}
-void EagerModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
- const string& nodedef_str,
- const std::vector<int>& inputs,
- const std::vector<int>& outputs) {
+void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
+ const string& nodedef_str,
+ const std::vector<int>& inputs,
+ const std::vector<int>& outputs) {
static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
reg.builtin_code = BuiltinOperator_CUSTOM;
reg.custom_name = tflite_name;
@@ -154,5 +153,5 @@ void EagerModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
}
} // namespace testing
-} // namespace eager
+} // namespace flex
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/flex/test_util.h
index 816db41931..a8c81b90a3 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.h
+++ b/tensorflow/contrib/lite/delegates/flex/test_util.h
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace testing {
enum TfOpType {
@@ -35,12 +35,12 @@ enum TfOpType {
};
// This class creates models with TF and TFLite ops. In order to use this class
-// to test the Eager delegate, implement a function that calls
+// to test the Flex delegate, implement a function that calls
// interpreter->ModifyGraphWithDelegate.
-class EagerModelTest : public ::testing::Test {
+class FlexModelTest : public ::testing::Test {
public:
- EagerModelTest() {}
- ~EagerModelTest() {}
+ FlexModelTest() {}
+ ~FlexModelTest() {}
bool Invoke();
@@ -104,7 +104,7 @@ class EagerModelTest : public ::testing::Test {
private:
// Helper method to add a TensorFlow op. tflite_names needs to start with
- // "Eager" in order to work with the Eager delegate.
+ // "Flex" in order to work with the Flex delegate.
void AddTfOp(const char* tflite_name, const string& tf_name,
const string& nodedef_str, const std::vector<int>& inputs,
const std::vector<int>& outputs);
@@ -113,7 +113,7 @@ class EagerModelTest : public ::testing::Test {
};
} // namespace testing
-} // namespace eager
+} // namespace flex
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_TEST_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/flex/util.cc
index 051246bf86..829bc388bf 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.cc
+++ b/tensorflow/contrib/lite/delegates/flex/util.cc
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/contrib/lite/delegates/flex/util.h"
namespace tflite {
-namespace eager {
+namespace flex {
TfLiteStatus ConvertStatus(TfLiteContext* context,
const tensorflow::Status& status) {
@@ -100,5 +100,5 @@ TfLiteType GetTensorFlowLiteType(TF_DataType type) {
}
}
-} // namespace eager
+} // namespace flex
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/flex/util.h
index 930cb99cb9..7f910e7316 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.h
+++ b/tensorflow/contrib/lite/delegates/flex/util.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
-#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/contrib/lite/c/c_api_internal.h"
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
namespace tflite {
-namespace eager {
+namespace flex {
// Converts a tensorflow:Status into a TfLiteStatus. If the original status
// represented an error, reports it using the given 'context'.
@@ -41,7 +41,7 @@ TF_DataType GetTensorFlowDataType(TfLiteType type);
// Returns the TfLiteType that corresponds to the given TF C API Data type.
TfLiteType GetTensorFlowLiteType(TF_DataType);
-} // namespace eager
+} // namespace flex
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/flex/util_test.cc
index aebc91149c..5f049e7b0a 100644
--- a/tensorflow/contrib/lite/delegates/eager/util_test.cc
+++ b/tensorflow/contrib/lite/delegates/flex/util_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/contrib/lite/delegates/flex/util.h"
#include <cstdarg>
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
-namespace eager {
+namespace flex {
namespace {
using tensorflow::DT_FLOAT;
@@ -132,7 +132,7 @@ TEST(UtilTest, TypeConversionsFromTensorFlow) {
}
} // namespace
-} // namespace eager
+} // namespace flex
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD
index 4d2437e7d3..d180cb4785 100644
--- a/tensorflow/contrib/lite/examples/android/BUILD
+++ b/tensorflow/contrib/lite/examples/android/BUILD
@@ -28,6 +28,7 @@ android_binary(
srcs = glob([
"app/src/main/java/**/*.java",
]),
+ aapt_version = "aapt",
# Package assets from assets dir as well as all model targets.
# Remove undesired models (and corresponding Activities in source)
# to reduce APK size.
diff --git a/tensorflow/contrib/lite/examples/android/app/README.md b/tensorflow/contrib/lite/examples/android/app/README.md
index dc31171672..7347147f99 100644
--- a/tensorflow/contrib/lite/examples/android/app/README.md
+++ b/tensorflow/contrib/lite/examples/android/app/README.md
@@ -1,8 +1,43 @@
# TF Lite Android App Example
+A simple Android example that demonstrates image classification and object
+detection using the camera, as well as speech recognition using the microphone.
+
+## Building in Android Studio with TensorFlow Lite AAR from JCenter.
+The build.gradle is configured to use TensorFlow Lite's nightly build.
+
+If you see a build error related to compatibility with Tensorflow Lite's Java
+API (example: method X is undefined for type Interpreter), there has likely been
+a backwards compatible change to the API. You will need to pull new app code
+that's compatible with the nightly build and may need to first wait a few days
+for our external and internal code to merge.
+
## Building from Source with Bazel
-1. Install [Bazel](https://docs.bazel.build/versions/master/install.html), the Android NDK and SDK. The recommended versions are specified on this [webpage](https://www.tensorflow.org/lite/demo_android).
+1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel):
+
+ 1. [Install Bazel and Android Prerequisites](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites).
+ It's easiest with Android Studio.
+
+ - You'll need at least SDK version 23.
+ - Make sure to install the latest version of Bazel. Some distributions
+ ship with Bazel 0.5.4, which is too old.
+ - Bazel requires Android Build Tools `26.0.1` or higher.
+ - You also need to install the Android Support Repository, available
+ through Android Studio under `Android SDK Manager -> SDK Tools ->
+ Android Support Repository`.
+
+ 2. [Edit your `WORKSPACE`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#edit-workspace)
+ to add SDK and NDK targets.
+
+ NOTE: As long as you have the SDK and NDK installed, the `./configure`
+ script will create these rules for you. Answer "Yes" when the script asks
+ to automatically configure the `./WORKSPACE`.
+
+ - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that
+ you have installed.
+ - By default, Android Studio will install the SDK to `~/Android/Sdk` and
+ the NDK to `~/Android/Sdk/ndk-bundle`.
2. Build this demo app with Bazel. The demo needs C++11. We configure the fat_apk_cpu flag to package support for 4 hardware variants. You may replace it with --config=android_arm64 on a 64-bit device and --config=android_arm for 32-bit device:
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
index 0f16595811..29f8701f53 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
@@ -21,9 +21,8 @@ limitations under the License.
extern "C" {
#endif // __cplusplus
-TFL_Status TFL_InterpreterResetVariableTensorsToZero(
- TFL_Interpreter* interpreter) {
- return interpreter->impl->ResetVariableTensorsToZero();
+TFL_Status TFL_InterpreterResetVariableTensors(TFL_Interpreter* interpreter) {
+ return interpreter->impl->ResetVariableTensors();
}
void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options,
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
index b8de7b9964..fca5d92f77 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
@@ -25,7 +25,7 @@ extern "C" {
typedef TfLiteBuiltinOperator TFL_BuiltinOperator;
// Resets all variable tensors to zero.
-TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensorsToZero(
+TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensors(
TFL_Interpreter* interpreter);
// Adds an op registration for a builtin operator.
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
index d86ad00d6d..1b1bedb754 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
@@ -44,7 +44,7 @@ TEST(CApiExperimentalSimple, Smoke) {
TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options);
ASSERT_NE(interpreter, nullptr);
ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk);
- EXPECT_EQ(TFL_InterpreterResetVariableTensorsToZero(interpreter), kTfLiteOk);
+ EXPECT_EQ(TFL_InterpreterResetVariableTensors(interpreter), kTfLiteOk);
EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteOk);
TFL_DeleteInterpreter(interpreter);
diff --git a/tensorflow/contrib/lite/g3doc/_book.yaml b/tensorflow/contrib/lite/g3doc/_book.yaml
index beaa5c479a..de6914e536 100644
--- a/tensorflow/contrib/lite/g3doc/_book.yaml
+++ b/tensorflow/contrib/lite/g3doc/_book.yaml
@@ -57,6 +57,7 @@ upper_tabs:
path: /lite/tfmobile/optimizing
- name: API
+ skip_translation: true
contents:
- title: API
path: /api_docs/python/tf/contrib/lite
diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md
index 28cb6aba6e..0ae9400068 100644
--- a/tensorflow/contrib/lite/g3doc/performance.md
+++ b/tensorflow/contrib/lite/g3doc/performance.md
@@ -1,174 +1,38 @@
-# Performance
+# Performance best practices
-This document lists TensorFlow Lite performance benchmarks when running well
-known models on some Android and iOS devices.
+Mobile and embedded devices have limited computational resources and it is important to keep your application resource efficient. We have compiled a list of best practices and strategies you can use to optimize your model and application when using Tensorflow Lite.
-These performance benchmark numbers were generated with the
-[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark)
-and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
+## Choose the most efficient model for the problem
+Some models may be too large to run on embedded devices. Instead of large models it is better to use a slightly less precise but smaller model for embedded devices. Smaller models not only use less disk space and memory but are generally faster and more energy efficient. One example of models optimized for mobile devices are [MobileNets](https://arxiv.org/abs/1704.04861), which are optimized for mobile vision applications. Tensorflow Lite [models page](models.md) lists several other models that have been optimized specifically for mobile and embedded devices.
-# Android performance benchmarks
+You can retrain the listed models on your own dataset by using transfer learning. Check out our transfer learning tutorial for
+[image classification] (https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0) and
+ [object detection](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193).
-For Android benchmarks, the CPU affinity is set to use big cores on the device to
-reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)).
-It assumes that models were download and unzipped to the
-`/data/local/tmp/tflite_models` directory. The benchmark binary is built
-using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android)
-and assumed in the `/data/local/tmp` directory.
+## Profile your model
+Before starting any optimization, it is a good practice to profile and benchmark your model. Tensorflow Lite [benchmarking tool](../tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time.
-To run the benchmark:
+## Profile and optimize operators in the graph
+If a particular operator appears frequently in the model and based on profiling you find the operator consuming the most amount of time, you can look into optimizing the operator.
+ This scenario should be rare as Tensorflow Lite has optimized versions for most ops. However you may be able to write a faster version of a custom op, if you know the constraints in which the operator is executed. Check out our [custom operator documentation](custom_operators.md).
-```
-adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \
- --num_threads=1 \
- --graph=/data/local/tmp/tflite_models/${GRAPH} \
- --warmup_runs=1 \
- --num_runs=50 \
- --use_nnapi=false
-```
+## Quantize your model
+If your model uses floating point weights or activations then it may be possible to reduce the size of model up to ~4x by using quantization and other model optimizations. Check out our [model optimization toolkit](https://www.tensorflow.org/performance/model_optimization) for details about optimizing your model. Fully quantized models can be remarkably power efficient as well.
-Here, `${GRAPH}` is the name of model and `${CPU_MASK}` is the CPU affinity
-chosen according to the following table:
+## Tweak the number of threads
+Tensorflow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing the number of threads will however make your model use more resources and power. For some applications latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](../interpreter.h) threads.
-Device | CPU_MASK |
--------| ----------
-Pixel 2 | f0 |
-Pixel xl | 0c |
+## Eliminate redundant copies
+Tensorflow Lite is optimized to reduce redundant copies. The APIs allow user to [mmap a model file](https://github.com/tensorflow/tensorflow/blob/9982fd6c8831cbd2f58954f79ea71f26660393bc/tensorflow/contrib/lite/model.h#L152) and avoid copies. If your application is not careful, there can be redundant copies when feeding the input to the model and reading output from the model. Make sure to eliminate redundant copies. If you are using higher level APIs like Java API, make sure to carefully check the documentation for performance caveats. For example, the Java API is a lot faster if ByteBuffers are used as [inputs](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java#L151).
-<table>
- <thead>
- <tr>
- <th>Model Name</th>
- <th>Device </th>
- <th>Mean inference time (std dev)</th>
- </tr>
- </thead>
- <tr>
- <td rowspan = 2>
- <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
- </td>
- <td>Pixel 2 </td>
- <td>166.5 ms (2.6 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>122.9 ms (1.8 ms) </td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz">Mobilenet_1.0_224 (quant)</a>
- </td>
- <td>Pixel 2 </td>
- <td>69.5 ms (0.9 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>78.9 ms (2.2 ms) </td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
- </td>
- <td>Pixel 2 </td>
- <td>273.8 ms (3.5 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>210.8 ms (4.2 ms)</td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
- </td>
- <td>Pixel 2 </td>
- <td>234.0 ms (2.1 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>158.0 ms (2.1 ms)</td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
- </td>
- <td>Pixel 2 </td>
- <td>2846.0 ms (15.0 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>1973.0 ms (15.0 ms) </td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
- </td>
- <td>Pixel 2 </td>
- <td>3180.0 ms (11.7 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>2262.0 ms (21.0 ms) </td>
- </tr>
+## Profile your application with platform specific tools
+Platform specific tools like [Android profiler](https://developer.android.com/studio/profile/android-profiler) and [Instruments](https://help.apple.com/instruments/mac/current/) provide a wealth of profiling information that can be used to debug your app. Sometimes the performance bug may be not in the model but in parts of application code that interact with the model. Make sure to familiarize yourself with platform specific profiling tools and best practices for your platform.
- </table>
+## Use hardware accelerators available on the device
+Tensorflow Lite is working on adding support for accelerators like GPU and provides acceleration through [NNAPI](https://developer.android.com/ndk/guides/neuralnetworks/) on Android.
+You can utilize these hardware accelerator backends to improve the speed and efficiency of your model. To enable NNAPI call [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/interpreter.h#L334) on the interpreter instance.
-# iOS benchmarks
-
-To run iOS benchmarks, the [benchmark
-app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios)
-was modified to include the appropriate model and `benchmark_params.json` was
-modified to set `num_threads` to 1.
-
-<table>
- <thead>
- <tr>
- <th>Model Name</th>
- <th>Device </th>
- <th>Mean inference time (std dev)</th>
- </tr>
- </thead>
- <tr>
- <td>
- <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
- </td>
- <td>iPhone 8 </td>
- <td>32.2 ms (0.8 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)">Mobilenet_1.0_224 (quant)</a>
- </td>
- <td>iPhone 8 </td>
- <td>24.4 ms (0.8 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
- </td>
- <td>iPhone 8 </td>
- <td>60.3 ms (0.6 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
- </td>
- <td>iPhone 8 </td>
- <td>44.3 (0.7 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
- </td>
- <td>iPhone 8</td>
- <td>562.4 ms (18.2 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
- </td>
- <td>iPhone 8 </td>
- <td>661.0 ms (29.2 ms)</td>
- </tr>
- </table>
+## Need more help
+The Tensorflow team is happy to help diagnose and address specific performance issues you may be facing. Please file a bug on [github](https://github.com/tensorflow/tensorflow/issues) with details of the issue.
diff --git a/tensorflow/contrib/lite/g3doc/performance_benchmarks.md b/tensorflow/contrib/lite/g3doc/performance_benchmarks.md
new file mode 100644
index 0000000000..28cb6aba6e
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/performance_benchmarks.md
@@ -0,0 +1,174 @@
+
+# Performance
+
+This document lists TensorFlow Lite performance benchmarks when running well
+known models on some Android and iOS devices.
+
+These performance benchmark numbers were generated with the
+[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark)
+and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
+
+# Android performance benchmarks
+
+For Android benchmarks, the CPU affinity is set to use big cores on the device to
+reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)).
+
+It assumes that models were download and unzipped to the
+`/data/local/tmp/tflite_models` directory. The benchmark binary is built
+using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android)
+and assumed in the `/data/local/tmp` directory.
+
+To run the benchmark:
+
+```
+adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \
+ --num_threads=1 \
+ --graph=/data/local/tmp/tflite_models/${GRAPH} \
+ --warmup_runs=1 \
+ --num_runs=50 \
+ --use_nnapi=false
+```
+
+Here, `${GRAPH}` is the name of model and `${CPU_MASK}` is the CPU affinity
+chosen according to the following table:
+
+Device | CPU_MASK |
+-------| ----------
+Pixel 2 | f0 |
+Pixel xl | 0c |
+
+<table>
+ <thead>
+ <tr>
+ <th>Model Name</th>
+ <th>Device </th>
+ <th>Mean inference time (std dev)</th>
+ </tr>
+ </thead>
+ <tr>
+ <td rowspan = 2>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>166.5 ms (2.6 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>122.9 ms (1.8 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz">Mobilenet_1.0_224 (quant)</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>69.5 ms (0.9 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>78.9 ms (2.2 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>273.8 ms (3.5 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>210.8 ms (4.2 ms)</td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>234.0 ms (2.1 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>158.0 ms (2.1 ms)</td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>2846.0 ms (15.0 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>1973.0 ms (15.0 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>3180.0 ms (11.7 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>2262.0 ms (21.0 ms) </td>
+ </tr>
+
+ </table>
+
+# iOS benchmarks
+
+To run iOS benchmarks, the [benchmark
+app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios)
+was modified to include the appropriate model and `benchmark_params.json` was
+modified to set `num_threads` to 1.
+
+<table>
+ <thead>
+ <tr>
+ <th>Model Name</th>
+ <th>Device </th>
+ <th>Mean inference time (std dev)</th>
+ </tr>
+ </thead>
+ <tr>
+ <td>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>32.2 ms (0.8 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)">Mobilenet_1.0_224 (quant)</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>24.4 ms (0.8 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>60.3 ms (0.6 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>44.3 (0.7 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
+ </td>
+ <td>iPhone 8</td>
+ <td>562.4 ms (18.2 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>661.0 ms (29.2 ms)</td>
+ </tr>
+ </table>
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 2657bcd42b..88e41ffc55 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -451,16 +451,15 @@ TfLiteStatus Interpreter::AllocateTensors() {
// Reset the variable tensors to zero after (re)allocating the tensors.
// Developers shouldn't rely on the side effect of this function to reset
- // variable tesnsors. They should call `ResetVariableTensorsToZero` directly
+ // variable tesnsors. They should call `ResetVariableTensors` directly
// instead.
- ResetVariableTensorsToZero();
+ ResetVariableTensors();
return kTfLiteOk;
}
-// TODO(ycling): Consider to provide other functions to initialize variable
-// tensors to non-zero values.
-TfLiteStatus Interpreter::ResetVariableTensorsToZero() {
+// TODO(ycling): Support non-zero default values.
+TfLiteStatus Interpreter::ResetVariableTensors() {
for (auto& tensor : tensors_) {
if (!tensor.is_variable) {
continue;
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index aa2bc4def6..7ef736d01b 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -421,9 +421,12 @@ class Interpreter {
allow_buffer_handle_output_ = allow_buffer_handle_output;
}
- // Reset all variable tensors to zero.
+ // Reset all variable tensors to the default value.
+ // If a variable tensor doesn't have a buffer, reset it to zero.
+ // TODO(b/115961645): Implement - If a variable tensor has a buffer, reset it
+ // to the value of the buffer.
// WARNING: This is an experimental API and subject to change.
- TfLiteStatus ResetVariableTensorsToZero();
+ TfLiteStatus ResetVariableTensors();
// Retrieve an operator's description of its work, for profiling purposes.
const char* OpProfilingString(const TfLiteRegistration& op_reg,
diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl
index db837cf29e..9d2aead266 100644
--- a/tensorflow/contrib/lite/java/aar_with_jni.bzl
+++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl
@@ -3,12 +3,12 @@
load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
def aar_with_jni(name, android_library):
- # Generate dummy AndroidManifest.xml for dummy apk usage
- # (dummy apk is generated by <name>_dummy_app_for_so target below)
- native.genrule(
- name = name + "_binary_manifest_generator",
- outs = [name + "_generated_AndroidManifest.xml"],
- cmd = """
+ # Generate dummy AndroidManifest.xml for dummy apk usage
+ # (dummy apk is generated by <name>_dummy_app_for_so target below)
+ native.genrule(
+ name = name + "_binary_manifest_generator",
+ outs = [name + "_generated_AndroidManifest.xml"],
+ cmd = """
cat > $(OUTS) <<EOF
<manifest
xmlns:android="http://schemas.android.com/apk/res/android"
@@ -17,27 +17,28 @@ cat > $(OUTS) <<EOF
</manifest>
EOF
""",
- )
+ )
- # Generate dummy apk including .so files and later we extract out
- # .so files and throw away the apk.
- android_binary(
- name = name + "_dummy_app_for_so",
- manifest = name + "_generated_AndroidManifest.xml",
- custom_package = "dummy.package.for.so",
- deps = [android_library],
- # In some platforms we don't have an Android SDK/NDK and this target
- # can't be built. We need to prevent the build system from trying to
- # use the target in that case.
- tags = ["manual"],
- )
+ # Generate dummy apk including .so files and later we extract out
+ # .so files and throw away the apk.
+ android_binary(
+ name = name + "_dummy_app_for_so",
+ aapt_version = "aapt",
+ manifest = name + "_generated_AndroidManifest.xml",
+ custom_package = "dummy.package.for.so",
+ deps = [android_library],
+ # In some platforms we don't have an Android SDK/NDK and this target
+ # can't be built. We need to prevent the build system from trying to
+ # use the target in that case.
+ tags = ["manual"],
+ )
- native.genrule(
- name = name,
- srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"],
- outs = [name + ".aar"],
- tags = ["manual"],
- cmd = """
+ native.genrule(
+ name = name,
+ srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"],
+ outs = [name + ".aar"],
+ tags = ["manual"],
+ cmd = """
cp $(location {}.aar) $(location :{}.aar)
chmod +w $(location :{}.aar)
origdir=$$PWD
@@ -46,4 +47,4 @@ unzip $$origdir/$(location :{}_dummy_app_for_so_unsigned.apk) "lib/*"
cp -r lib jni
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
""".format(android_library, name, name, name, name),
- )
+ )
diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md
index 6a3f0651d0..c04b2a6194 100644
--- a/tensorflow/contrib/lite/java/demo/README.md
+++ b/tensorflow/contrib/lite/java/demo/README.md
@@ -1,4 +1,6 @@
-# TF Lite Android App
+# TF Lite Android Image Classifier App Example
+
+A simple Android example that demonstrates image classification using the camera.
## Building in Android Studio with TensorFlow Lite AAR from JCenter.
The build.gradle is configured to use TensorFlow Lite's nightly build.
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
index 220d6c2159..5ad738389e 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
@@ -7,6 +7,7 @@ licenses(["notice"]) # Apache 2.0
android_binary(
name = "TfLiteCameraDemo",
srcs = glob(["java/**/*.java"]),
+ aapt_version = "aapt",
assets = [
"//tensorflow/contrib/lite/java/demo/app/src/main/assets:labels_mobilenet_quant_v1_224.txt",
"@tflite_mobilenet//:mobilenet_quant_v1_224.tflite",
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
index 4f5662bc2d..3596e42011 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
@@ -58,9 +58,9 @@ import android.view.View;
import android.view.ViewGroup;
import android.widget.CompoundButton;
import android.widget.NumberPicker;
-import android.widget.ToggleButton;
import android.widget.TextView;
import android.widget.Toast;
+import android.widget.ToggleButton;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
@@ -305,22 +305,24 @@ public class Camera2BasicFragment extends Fragment
textView = (TextView) view.findViewById(R.id.text);
toggle = (ToggleButton) view.findViewById(R.id.button);
- toggle.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() {
- public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) {
- classifier.setUseNNAPI(isChecked);
- }
- });
+ toggle.setOnCheckedChangeListener(
+ new CompoundButton.OnCheckedChangeListener() {
+ public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) {
+ backgroundHandler.post(() -> classifier.setUseNNAPI(isChecked));
+ }
+ });
np = (NumberPicker) view.findViewById(R.id.np);
np.setMinValue(1);
np.setMaxValue(10);
np.setWrapSelectorWheel(true);
- np.setOnValueChangedListener(new NumberPicker.OnValueChangeListener() {
- @Override
- public void onValueChange(NumberPicker picker, int oldVal, int newVal){
- classifier.setNumThreads(newVal);
- }
- });
+ np.setOnValueChangedListener(
+ new NumberPicker.OnValueChangeListener() {
+ @Override
+ public void onValueChange(NumberPicker picker, int oldVal, int newVal) {
+ backgroundHandler.post(() -> classifier.setNumThreads(newVal));
+ }
+ });
}
/** Load the model and labels. */
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
index 7bb6afd9d8..2d11a57434 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
@@ -59,9 +59,15 @@ public abstract class ImageClassifier {
private static final int DIM_PIXEL_SIZE = 3;
- /* Preallocated buffers for storing image data in. */
+ /** Preallocated buffers for storing image data in. */
private int[] intValues = new int[getImageSizeX() * getImageSizeY()];
+ /** Options for configuring the Interpreter. */
+ private final Interpreter.Options tfliteOptions = new Interpreter.Options();
+
+ /** The loaded TensorFlow Lite model. */
+ private MappedByteBuffer tfliteModel;
+
/** An instance of the driver class to run model inference with Tensorflow Lite. */
protected Interpreter tflite;
@@ -89,7 +95,8 @@ public abstract class ImageClassifier {
/** Initializes an {@code ImageClassifier}. */
ImageClassifier(Activity activity) throws IOException {
- tflite = new Interpreter(loadModelFile(activity));
+ tfliteModel = loadModelFile(activity);
+ tflite = new Interpreter(tfliteModel, tfliteOptions);
labelList = loadLabelList(activity);
imgData =
ByteBuffer.allocateDirect(
@@ -150,20 +157,28 @@ public abstract class ImageClassifier {
}
}
+ private void recreateInterpreter() {
+ if (tflite != null) {
+ tflite.close();
+ tflite = new Interpreter(tfliteModel, tfliteOptions);
+ }
+ }
+
public void setUseNNAPI(Boolean nnapi) {
- if (tflite != null)
- tflite.setUseNNAPI(nnapi);
+ tfliteOptions.setUseNNAPI(nnapi);
+ recreateInterpreter();
}
- public void setNumThreads(int num_threads) {
- if (tflite != null)
- tflite.setNumThreads(num_threads);
+ public void setNumThreads(int numThreads) {
+ tfliteOptions.setNumThreads(numThreads);
+ recreateInterpreter();
}
/** Closes tflite to release resources. */
public void close() {
tflite.close();
tflite = null;
+ tfliteModel = null;
}
/** Reads label list from Assets. */
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
index b2e3a9bd7d..058240aada 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
@@ -8,6 +8,7 @@ android_binary(
srcs = [
"OvicBenchmarkerActivity.java",
],
+ aapt_version = "aapt",
assets = [
"//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata",
"//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt",
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
index 4cf51bb0fa..fd610b054f 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
@@ -74,7 +74,7 @@ public class OvicClassifier {
}
labelList = loadLabelList(labelInputStream);
// OVIC uses one thread for CPU inference.
- tflite = new Interpreter(model, 1);
+ tflite = new Interpreter(model, new Interpreter.Options().setNumThreads(1));
inputDims = TestHelper.getInputDims(tflite, 0);
if (inputDims.length != 4) {
throw new RuntimeException("The model's input dimensions must be 4 (BWHC).");
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index b84720ae8e..5cc6e754f3 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -56,16 +56,47 @@ import org.checkerframework.checker.nullness.qual.NonNull;
*/
public final class Interpreter implements AutoCloseable {
+ /** An options class for controlling runtime interpreter behavior. */
+ public static class Options {
+ public Options() {}
+
+ /**
+ * Sets the number of threads to be used for ops that support multi-threading. Defaults to a
+ * platform-dependent value.
+ */
+ public Options setNumThreads(int numThreads) {
+ this.numThreads = numThreads;
+ return this;
+ }
+
+ /** Sets whether to use NN API (if available) for op execution. Defaults to false (disabled). */
+ public Options setUseNNAPI(boolean useNNAPI) {
+ this.useNNAPI = useNNAPI;
+ return this;
+ }
+
+ /**
+ * Sets whether to allow float16 precision for FP32 calculation when possible. Defaults to false
+ * (disallow).
+ * WARNING: This is an experimental API and subject to change.
+ */
+ public Options setAllowFp16PrecisionForFp32(boolean allow) {
+ this.allowFp16PrecisionForFp32 = allow;
+ return this;
+ }
+
+ int numThreads = -1;
+ boolean useNNAPI = false;
+ boolean allowFp16PrecisionForFp32 = false;
+ }
+
/**
* Initializes a {@code Interpreter}
*
* @param modelFile: a File of a pre-trained TF Lite model.
*/
public Interpreter(@NonNull File modelFile) {
- if (modelFile == null) {
- return;
- }
- wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath());
+ this(modelFile, /*options = */ null);
}
/**
@@ -73,12 +104,22 @@ public final class Interpreter implements AutoCloseable {
*
* @param modelFile: a file of a pre-trained TF Lite model
* @param numThreads: number of threads to use for inference
+ * @deprecated Prefer using the {@link #Interpreter(File,Options)} constructor. This method will
+ * be removed in a future release.
*/
+ @Deprecated
public Interpreter(@NonNull File modelFile, int numThreads) {
- if (modelFile == null) {
- return;
- }
- wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), numThreads);
+ this(modelFile, new Options().setNumThreads(numThreads));
+ }
+
+ /**
+ * Initializes a {@code Interpreter} and specifies the number of threads used for inference.
+ *
+ * @param modelFile: a file of a pre-trained TF Lite model
+ * @param options: a set of options for customizing interpreter behavior
+ */
+ public Interpreter(@NonNull File modelFile, Options options) {
+ wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
}
/**
@@ -89,7 +130,7 @@ public final class Interpreter implements AutoCloseable {
* direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
*/
public Interpreter(@NonNull ByteBuffer byteBuffer) {
- wrapper = new NativeInterpreterWrapper(byteBuffer);
+ this(byteBuffer, /* options= */ null);
}
/**
@@ -99,9 +140,13 @@ public final class Interpreter implements AutoCloseable {
* <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
* {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
* direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
+ *
+ * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
+ * will be removed in a future release.
*/
+ @Deprecated
public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) {
- wrapper = new NativeInterpreterWrapper(byteBuffer, numThreads);
+ this(byteBuffer, new Options().setNumThreads(numThreads));
}
/**
@@ -109,20 +154,25 @@ public final class Interpreter implements AutoCloseable {
*
* <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
* Interpreter}.
+ *
+ * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
+ * will be removed in a future release.
*/
+ @Deprecated
public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) {
- wrapper = new NativeInterpreterWrapper(mappedByteBuffer);
+ this(mappedByteBuffer, /* options= */ null);
}
/**
- * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file and
- * specifies the number of threads used for inference.
+ * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of custom
+ * {@link #Options}.
*
- * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
- * Interpreter}.
+ * <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
+ * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
+ * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
*/
- public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer, int numThreads) {
- wrapper = new NativeInterpreterWrapper(mappedByteBuffer, numThreads);
+ public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
+ wrapper = new NativeInterpreterWrapper(byteBuffer, options);
}
/**
@@ -232,20 +282,34 @@ public final class Interpreter implements AutoCloseable {
/**
* Returns native inference timing.
- * <p>IllegalArgumentException will be thrown if the model is not initialized by the
- * {@link Interpreter}.
+ *
+ * <p>IllegalArgumentException will be thrown if the model is not initialized by the {@link
+ * Interpreter}.
*/
public Long getLastNativeInferenceDurationNanoseconds() {
checkNotClosed();
return wrapper.getLastNativeInferenceDurationNanoseconds();
}
- /** Turns on/off Android NNAPI for hardware acceleration when it is available. */
+ /**
+ * Turns on/off Android NNAPI for hardware acceleration when it is available.
+ *
+ * @deprecated Prefer using {@link Options#setUseNNAPI(boolean)} directly for enabling NN API.
+ * This method will be removed in a future release.
+ */
+ @Deprecated
public void setUseNNAPI(boolean useNNAPI) {
checkNotClosed();
wrapper.setUseNNAPI(useNNAPI);
}
+ /**
+ * Sets the number of threads to be used for ops that support multi-threading.
+ *
+ * @deprecated Prefer using {@link Options#setNumThreads(int)} directly for controlling thread
+ * multi-threading. This method will be removed in a future release.
+ */
+ @Deprecated
public void setNumThreads(int numThreads) {
checkNotClosed();
wrapper.setNumThreads(numThreads);
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index fa25082304..9bc44bf797 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -23,7 +23,7 @@ import java.util.HashMap;
import java.util.Map;
/**
- * A wrapper wraps native interpreter and controls model execution.
+ * An internal wrapper that wraps native interpreter and controls model execution.
*
* <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
* explicitly freed by invoking the {@link #close()} method when the {@code
@@ -32,36 +32,32 @@ import java.util.Map;
final class NativeInterpreterWrapper implements AutoCloseable {
NativeInterpreterWrapper(String modelPath) {
- this(modelPath, /* numThreads= */ -1);
+ this(modelPath, /* options= */ null);
}
- NativeInterpreterWrapper(String modelPath, int numThreads) {
+ NativeInterpreterWrapper(String modelPath, Interpreter.Options options) {
+ if (options == null) {
+ options = new Interpreter.Options();
+ }
errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
modelHandle = createModel(modelPath, errorHandle);
- interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
+ interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
isMemoryAllocated = true;
inputTensors = new Tensor[getInputCount(interpreterHandle)];
outputTensors = new Tensor[getOutputCount(interpreterHandle)];
+ if (options.allowFp16PrecisionForFp32) {
+ setAllowFp16PrecisionForFp32(options.allowFp16PrecisionForFp32);
+ }
}
- /**
- * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer}. The ByteBuffer should
- * not be modified after the construction of a {@code NativeInterpreterWrapper}. The {@code
- * ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a direct
- * {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
- */
NativeInterpreterWrapper(ByteBuffer byteBuffer) {
- this(byteBuffer, /* numThreads= */ -1);
+ this(byteBuffer, /* options= */ null);
}
- /**
- * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer} and specifies the
- * number of inference threads. The ByteBuffer should not be modified after the construction of a
- * {@code NativeInterpreterWrapper}. The {@code ByteBuffer} can be either a {@code
- * MappedByteBuffer} that memory-maps a model file, or a direct {@code ByteBuffer} of
- * nativeOrder() that contains the bytes content of a model.
- */
- NativeInterpreterWrapper(ByteBuffer buffer, int numThreads) {
+ NativeInterpreterWrapper(ByteBuffer buffer, Interpreter.Options options) {
+ if (options == null) {
+ options = new Interpreter.Options();
+ }
if (buffer == null
|| (!(buffer instanceof MappedByteBuffer)
&& (!buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()))) {
@@ -72,10 +68,16 @@ final class NativeInterpreterWrapper implements AutoCloseable {
modelByteBuffer = buffer;
errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
- interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
+ interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
isMemoryAllocated = true;
inputTensors = new Tensor[getInputCount(interpreterHandle)];
outputTensors = new Tensor[getOutputCount(interpreterHandle)];
+ if (options.useNNAPI) {
+ setUseNNAPI(options.useNNAPI);
+ }
+ if (options.allowFp16PrecisionForFp32) {
+ setAllowFp16PrecisionForFp32(options.allowFp16PrecisionForFp32);
+ }
}
/** Releases resources associated with this {@code NativeInterpreterWrapper}. */
@@ -163,6 +165,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
useNNAPI(interpreterHandle, useNNAPI);
}
+ void setAllowFp16PrecisionForFp32(boolean allow) {
+ allowFp16PrecisionForFp32(interpreterHandle, allow);
+ }
+
void setNumThreads(int numThreads) {
numThreads(interpreterHandle, numThreads);
}
@@ -327,6 +333,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private static native void numThreads(long interpreterHandle, int numThreads);
+ private static native void allowFp16PrecisionForFp32(long interpreterHandle, boolean allow);
+
private static native long createErrorReporter(int size);
private static native long createModel(String modelPathOrBuffer, long errorHandle);
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index fdcf00a0a0..abb7320bc5 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -59,7 +59,6 @@ std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
return outputs;
}
-
int getDataType(TfLiteType data_type) {
switch (data_type) {
case kTfLiteFloat32:
@@ -234,10 +233,18 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
}
JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
+ JNIEnv* env, jclass clazz, jlong handle, jboolean allow) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return;
+ interpreter->SetAllowFp16PrecisionForFp32(static_cast<bool>(allow));
+}
+
+JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint num_threads) {
+ jclass clazz,
+ jlong handle,
+ jint num_threads) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return;
interpreter->SetNumThreads(static_cast<int>(num_threads));
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index 06b35d77c8..aa809dff8a 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -120,6 +120,15 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
+ * Signature: (JZ)V
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allowFp16PrecisionForFp32(
+ JNIEnv* env, jclass clazz, jlong handle, jboolean allow);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
* Signature: (JI)V
*/
JNIEXPORT void JNICALL
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index 9070b788b6..a98fca0132 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -55,11 +55,23 @@ public final class InterpreterTest {
}
@Test
+ public void testInterpreterWithOptions() throws Exception {
+ Interpreter interpreter =
+ new Interpreter(MODEL_FILE, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true));
+ assertThat(interpreter).isNotNull();
+ assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ interpreter.close();
+ }
+
+ @Test
public void testRunWithMappedByteBufferModel() throws Exception {
Path path = MODEL_FILE.toPath();
FileChannel fileChannel =
(FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
- MappedByteBuffer mappedByteBuffer =
+ ByteBuffer mappedByteBuffer =
fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
Interpreter interpreter = new Interpreter(mappedByteBuffer);
float[] oneD = {1.23f, 6.54f, 7.81f};
@@ -106,7 +118,7 @@ public final class InterpreterTest {
byteBuffer.order(ByteOrder.nativeOrder());
fileChannel.read(byteBuffer);
try {
- Interpreter interpreter = new Interpreter(byteBuffer);
+ new Interpreter(byteBuffer);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
@@ -304,40 +316,16 @@ public final class InterpreterTest {
}
@Test
- public void testTurnOffNNAPI() throws Exception {
- Path path = MODEL_FILE.toPath();
- FileChannel fileChannel =
- (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
- MappedByteBuffer mappedByteBuffer =
- fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
- Interpreter interpreter = new Interpreter(mappedByteBuffer);
- interpreter.setUseNNAPI(true);
- float[] oneD = {1.23f, 6.54f, 7.81f};
- float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
- float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
- float[][][][] fourD = {threeD, threeD};
- float[][][][] parsedOutputs = new float[2][8][8][3];
- interpreter.run(fourD, parsedOutputs);
- float[] outputOneD = parsedOutputs[0][0][0];
- float[] expected = {3.69f, 19.62f, 23.43f};
- assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
- interpreter.setUseNNAPI(false);
- interpreter.run(fourD, parsedOutputs);
- outputOneD = parsedOutputs[0][0][0];
- assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
- interpreter.close();
- fileChannel.close();
- }
-
- @Test
public void testTurnOnNNAPI() throws Exception {
Path path = MODEL_FILE.toPath();
FileChannel fileChannel =
(FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
MappedByteBuffer mappedByteBuffer =
fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
- Interpreter interpreter = new Interpreter(mappedByteBuffer);
- interpreter.setUseNNAPI(true);
+ Interpreter interpreter =
+ new Interpreter(
+ mappedByteBuffer,
+ new Interpreter.Options().setUseNNAPI(true).setAllowFp16PrecisionForFp32(true));
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index 9c4a5acd79..270bd6703a 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -63,6 +63,15 @@ public final class NativeInterpreterWrapperTest {
}
@Test
+ public void testConstructorWithOptions() {
+ NativeInterpreterWrapper wrapper =
+ new NativeInterpreterWrapper(
+ FLOAT_MODEL_PATH, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true));
+ assertThat(wrapper).isNotNull();
+ wrapper.close();
+ }
+
+ @Test
public void testConstructorWithInvalidModel() {
try {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
index 38b740021b..af20e3280b 100644
--- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
@@ -19,21 +19,6 @@ package org.tensorflow.lite;
public class TestHelper {
/**
- * Turns on/off NNAPI of an {@code Interpreter}.
- *
- * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
- * IllegalArgumentException} will be thrown.
- * @param useNNAPI a boolean value indicating to turn on or off NNAPI.
- */
- public static void setUseNNAPI(Interpreter interpreter, boolean useNNAPI) {
- if (interpreter != null && interpreter.wrapper != null) {
- interpreter.wrapper.setUseNNAPI(useNNAPI);
- } else {
- throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI.");
- }
- }
-
- /**
* Gets the last inference duration in nanoseconds. It returns null if there is no previous
* inference run or the last inference run failed.
*
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index b2d9b84979..cf9441aee3 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -348,18 +348,22 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
} break;
case kTfLiteInt16: {
- optimized_ops::Tanh(GetTensorData<int16_t>(input), GetTensorShape(input),
- data->input_left_shift,
- GetTensorData<int16_t>(output),
- GetTensorShape(output));
+ TanhParams params;
+ params.input_left_shift = data->input_left_shift;
+ optimized_ops::Tanh(params, GetTensorShape(input),
+ GetTensorData<int16_t>(input), GetTensorShape(output),
+ GetTensorData<int16_t>(output));
return kTfLiteOk;
} break;
case kTfLiteUInt8: {
- optimized_ops::Tanh(GetTensorData<uint8_t>(input), GetTensorShape(input),
- input->params.zero_point, data->input_range_radius,
- data->input_multiplier, data->input_left_shift,
- GetTensorData<uint8_t>(output),
- GetTensorShape(output));
+ TanhParams params;
+ params.input_zero_point = input->params.zero_point;
+ params.input_range_radius = data->input_range_radius;
+ params.input_multiplier = data->input_multiplier;
+ params.input_left_shift = data->input_left_shift;
+ optimized_ops::Tanh(params, GetTensorShape(input),
+ GetTensorData<uint8_t>(input), GetTensorShape(output),
+ GetTensorData<uint8_t>(output));
return kTfLiteOk;
} break;
default:
@@ -385,17 +389,21 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt16: {
+ LogisticParams params;
optimized_ops::Logistic(
- GetTensorData<int16>(input), GetTensorShape(input),
- GetTensorData<int16_t>(output), GetTensorShape(output));
+ params, GetTensorShape(input), GetTensorData<int16_t>(input),
+ GetTensorShape(output), GetTensorData<int16_t>(output));
break;
}
case kTfLiteUInt8: {
+ LogisticParams params;
+ params.input_zero_point = input->params.zero_point;
+ params.input_range_radius = data->input_range_radius;
+ params.input_multiplier = data->input_multiplier;
+ params.input_left_shift = data->input_left_shift;
optimized_ops::Logistic(
- GetTensorData<uint8_t>(input), GetTensorShape(input),
- input->params.zero_point, data->input_range_radius,
- data->input_multiplier, data->input_left_shift,
- GetTensorData<uint8_t>(output), GetTensorShape(output));
+ params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(output), GetTensorData<uint8_t>(output));
break;
}
default:
@@ -459,11 +467,13 @@ void Softmax3DFloat(const TfLiteTensor* input, TfLiteTensor* output,
const int batch_size = input->dims->data[0];
const int intermediate_size = input->dims->data[1];
const int input_size = input->dims->data[2];
+ SoftmaxParams op_params;
+ op_params.beta = params->beta;
optimized_ops::Softmax(
+ op_params, GetTensorShape({batch_size, intermediate_size, 1, input_size}),
GetTensorData<float>(input),
GetTensorShape({batch_size, intermediate_size, 1, input_size}),
- params->beta, GetTensorData<float>(output),
- GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+ GetTensorData<float>(output));
}
void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
@@ -473,10 +483,14 @@ void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
// tensor is 4D in a special way. We will convert a (Y) shape into a (1,
// 1, 1, Y) shape.
const int input_size = input->dims->data[0];
- optimized_ops::Softmax(
- GetTensorData<uint8_t>(input), GetTensorShape({1, 1, 1, input_size}),
- data->input_multiplier, data->input_left_shift, data->diff_min,
- GetTensorData<uint8_t>(output), GetTensorShape({1, 1, 1, input_size}));
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ optimized_ops::Softmax(op_params, GetTensorShape({1, 1, 1, input_size}),
+ GetTensorData<uint8_t>(input),
+ GetTensorShape({1, 1, 1, input_size}),
+ GetTensorData<uint8_t>(output));
}
void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
@@ -486,11 +500,15 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
// 1, 1, Y) shape.
const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1];
- optimized_ops::Softmax(GetTensorData<uint8_t>(input),
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ optimized_ops::Softmax(op_params,
+ GetTensorShape({batch_size, 1, 1, input_size}),
+ GetTensorData<uint8_t>(input),
GetTensorShape({batch_size, 1, 1, input_size}),
- data->input_multiplier, data->input_left_shift,
- data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorShape({batch_size, 1, 1, input_size}));
+ GetTensorData<uint8_t>(output));
}
void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
@@ -498,28 +516,36 @@ void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
const int batch_size = input->dims->data[0];
const int intermediate_size = input->dims->data[1];
const int input_size = input->dims->data[2];
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
optimized_ops::Softmax(
+ op_params, GetTensorShape({batch_size, intermediate_size, 1, input_size}),
GetTensorData<uint8_t>(input),
GetTensorShape({batch_size, intermediate_size, 1, input_size}),
- data->input_multiplier, data->input_left_shift, data->diff_min,
- GetTensorData<uint8_t>(output),
- GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+ GetTensorData<uint8_t>(output));
}
// Takes a 4D tensor and perform softmax along the forth dimension.
void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params) {
- optimized_ops::Softmax(GetTensorData<float>(input), GetTensorShape(input),
- params->beta, GetTensorData<float>(output),
- GetTensorShape(output));
+ SoftmaxParams op_params;
+ op_params.beta = params->beta;
+ optimized_ops::Softmax(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(output),
+ GetTensorData<float>(output));
}
void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
- optimized_ops::Softmax(GetTensorData<uint8_t>(input), GetTensorShape(input),
- data->input_multiplier, data->input_left_shift,
- data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorShape(output));
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ optimized_ops::Softmax(op_params, GetTensorShape(input),
+ GetTensorData<uint8_t>(input), GetTensorShape(output),
+ GetTensorData<uint8_t>(output));
}
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
@@ -591,17 +617,20 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
case kTfLiteFloat32:
+ SoftmaxParams op_params;
optimized_ops::LogSoftmax(
- GetTensorData<float>(input), GetTensorShape(input),
- GetTensorData<float>(output), GetTensorShape(output));
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(output), GetTensorData<float>(output));
return kTfLiteOk;
case kTfLiteUInt8:
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.reverse_scaling_divisor = data->reverse_scaling_divisor;
+ op_params.reverse_scaling_right_shift = data->reverse_scaling_right_shift;
+ op_params.diff_min = data->diff_min;
optimized_ops::LogSoftmax(
- GetTensorData<uint8_t>(input), GetTensorShape(input),
- data->input_multiplier, data->input_left_shift,
- data->reverse_scaling_divisor, data->reverse_scaling_right_shift,
- data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorShape(output));
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(output), GetTensorData<uint8_t>(output));
return kTfLiteOk;
default:
context->ReportError(context, "Only float32 supported currently., got %d",
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 541f320138..66b947771c 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -770,51 +770,29 @@ TfLiteStatus EvalFloat(
}
// Loop through the sequence.
- if (forward_sequence) {
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr,
- input_to_forget_weights->data.f, input_to_cell_weights->data.f,
- input_to_output_weights->data.f, aux_input_ptr,
- aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
- aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
- recurrent_to_cell_weights->data.f,
- recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
- output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- activation_state->data.f, cell_state->data.f, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch,
- output_ptr_time);
- }
- } else {
- // Loop through the sequence backwards.
- for (int t = max_time - 1; t >= 0; t--) {
- const float* input_ptr = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr,
- input_to_forget_weights->data.f, input_to_cell_weights->data.f,
- input_to_output_weights->data.f, aux_input_ptr,
- aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
- aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
- recurrent_to_cell_weights->data.f,
- recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
- output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- activation_state->data.f, cell_state->data.f, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch,
- output_ptr_time);
- }
+ const int input_step = n_batch * n_input;
+ const int output_step = n_batch * n_output;
+ for (int t = 0; t < max_time; t++) {
+ // If this is the forward_sequence, step forward, otherwise step backwards.
+ const int t_rel = forward_sequence ? t : max_time - t - 1;
+ const float* input_ptr = input->data.f + t_rel * input_step;
+ float* output_ptr_time = output->data.f + t_rel * output_step;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
+ input_to_cell_weights->data.f, input_to_output_weights->data.f,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
+ aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
+ recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
+ recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+ cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+ input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+ output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, aux_input_size, n_output,
+ activation_state->data.f, cell_state->data.f, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ output_ptr_time);
}
return kTfLiteOk;
}
@@ -991,72 +969,41 @@ TfLiteStatus EvalHybrid(
aux_input_to_output_weights_scale =
aux_input_to_output_weights->params.scale;
}
- if (forward_sequence) {
- // Feed the sequence into the LSTM step-by-step.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr = input->data.f + t * n_batch * n_input;
- float* output_ptr = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- aux_input_ptr, aux_input_to_input_weights_ptr,
- aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
- aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
- aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
- aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
- recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
- recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
- recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
- recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
- cell_to_input_weights_scale, cell_to_forget_weights_ptr,
- cell_to_forget_weights_scale, cell_to_output_weights_ptr,
- cell_to_output_weights_scale, input_gate_bias_ptr,
- forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
- projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
- recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_aux_input_ptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
- output_ptr);
- }
- } else {
- // Loop through the sequence backwards.
- for (int t = max_time - 1; t >= 0; t--) {
- const float* input_ptr = input->data.f + t * n_batch * n_input;
- float* output_ptr = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- aux_input_ptr, aux_input_to_input_weights_ptr,
- aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
- aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
- aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
- aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
- recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
- recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
- recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
- recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
- cell_to_input_weights_scale, cell_to_forget_weights_ptr,
- cell_to_forget_weights_scale, cell_to_output_weights_ptr,
- cell_to_output_weights_scale, input_gate_bias_ptr,
- forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
- projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
- recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_aux_input_ptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
- output_ptr);
- }
+
+ // Feed the sequence into the LSTM step-by-step.
+ const int input_step = n_batch * n_input;
+ const int output_step = n_batch * n_output;
+ for (int t = 0; t < max_time; t++) {
+ // If this is the forward_sequence, step forward, otherwise step backwards.
+ const int t_rel = forward_sequence ? t : max_time - t - 1;
+ const float* input_ptr = input->data.f + t_rel * input_step;
+ float* output_ptr = output->data.f + t_rel * output_step;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+ aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+ aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+ aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+ recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+ recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+ recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+ recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+ cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+ cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+ cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
+ n_input, aux_input_size, n_output, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_cell_weights_ptr, quantized_input_ptr,
+ quantized_aux_input_ptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 4cd96348a2..f765235e04 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -83,20 +83,24 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \
&input2_multiplier, &input2_shift); \
\
+ ComparisonParams op_params; \
+ op_params.left_shift = left_shift; \
+ op_params.input1_offset = input1_offset; \
+ op_params.input1_multiplier = input1_multiplier; \
+ op_params.input1_shift = -input1_shift; \
+ op_params.input2_offset = input2_offset; \
+ op_params.input2_multiplier = input2_multiplier; \
+ op_params.input2_shift = -input2_shift; \
if (requires_broadcast) { \
- reference_ops::Broadcast##opname( \
- left_shift, GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, input1_multiplier, input1_shift, \
- GetTensorData<uint8_t>(input2), GetTensorDims(input2), \
- input2_offset, input2_multiplier, input2_shift, \
- GetTensorData<bool>(output), GetTensorDims(output)); \
+ reference_ops::Broadcast4DSlow##opname##WithScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
+ GetTensorShape(input2), GetTensorData<uint8_t>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)); \
} else { \
- reference_ops::opname( \
- left_shift, GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, input1_multiplier, input1_shift, \
- GetTensorData<uint8_t>(input2), GetTensorDims(input2), \
- input2_offset, input2_multiplier, input2_shift, \
- GetTensorData<bool>(output), GetTensorDims(output)); \
+ reference_ops::opname##WithScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
+ GetTensorShape(input2), GetTensorData<uint8_t>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)); \
} \
} \
}
@@ -108,16 +112,19 @@ TF_LITE_QUANTIZE_COMPARISON(Less);
TF_LITE_QUANTIZE_COMPARISON(LessEqual);
#undef TF_LITE_QUANTIZE_COMPARISON
-#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \
- requires_broadcast \
- ? reference_ops::Broadcast##opname( \
- GetTensorData<type>(input1), GetTensorDims(input1), \
- GetTensorData<type>(input2), GetTensorDims(input2), \
- GetTensorData<bool>(output), GetTensorDims(output)) \
- : reference_ops::opname( \
- GetTensorData<type>(input1), GetTensorDims(input1), \
- GetTensorData<type>(input2), GetTensorDims(input2), \
- GetTensorData<bool>(output), GetTensorDims(output));
+#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \
+ { \
+ ComparisonParams op_params; \
+ requires_broadcast \
+ ? reference_ops::Broadcast4DSlow##opname##NoScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<type>(input1), \
+ GetTensorShape(input2), GetTensorData<type>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)) \
+ : reference_ops::opname##NoScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<type>(input1), \
+ GetTensorShape(input2), GetTensorData<type>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)); \
+ }
TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
index 25ea556d5a..7ad3399ffd 100644
--- a/tensorflow/contrib/lite/kernels/concatenation.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -100,20 +100,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// allocate and populate these during Prepare().
// TODO(ycling): Activation function parameter is ignored. For now we dont have
// a model with a Concatenation with fused activation function.
-#define TF_LITE_CONCATENATION(type, scalar) \
- VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
- type::Concatenation<FusedActivationFunctionType::kNone, scalar>( \
- RemapDim(NumDimensions(output), axis), all_inputs.data(), \
- all_inputs.dims(), node->inputs->size, GetTensorData<scalar>(output), \
- GetTensorDims(output))
-
-#define TF_LITE_CONCATENATION_QUANTIZED(type) \
- VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
- type::Concatenation( \
- RemapDim(NumDimensions(output), axis), all_inputs.data(), \
- all_inputs.dims(), all_inputs.zero_point(), all_inputs.scale(), \
- node->inputs->size, GetTensorData<uint8>(output), GetTensorDims(output), \
- output->params.zero_point, output->params.scale)
+#define TF_LITE_CONCATENATION(type, scalar) \
+ { \
+ VectorOfTensors<scalar> all_inputs(*context, *node->inputs); \
+ tflite::ConcatenationParams op_params; \
+ op_params.axis = axis; \
+ op_params.inputs_count = node->inputs->size; \
+ type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(), \
+ GetTensorShape(output), \
+ GetTensorData<scalar>(output)); \
+ }
+
+#define TF_LITE_CONCATENATION_QUANTIZED(type) \
+ { \
+ VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
+ tflite::ConcatenationParams op_params; \
+ op_params.axis = axis; \
+ op_params.input_zeropoint = all_inputs.zero_point(); \
+ op_params.input_scale = all_inputs.scale(); \
+ op_params.inputs_count = node->inputs->size; \
+ op_params.output_zeropoint = output->params.zero_point; \
+ op_params.output_scale = output->params.scale; \
+ type::ConcatenationWithScaling(op_params, all_inputs.shapes(), \
+ all_inputs.data(), GetTensorShape(output), \
+ GetTensorData<uint8>(output)); \
+ }
switch (output->type) { // Already know in/outtypes are same.
case kTfLiteFloat32:
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 101b4fc961..dbcadbee14 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -86,6 +86,18 @@ struct OpData {
bool run_multithreaded_kernel;
};
+inline PaddingType RuntimePaddingType(TfLitePadding padding) {
+ switch (padding) {
+ case TfLitePadding::kTfLitePaddingSame:
+ return PaddingType::kSame;
+ case TfLitePadding::kTfLitePaddingValid:
+ return PaddingType::kValid;
+ case TfLitePadding::kTfLitePaddingUnknown:
+ default:
+ return PaddingType::kNone;
+ }
+}
+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to use as scratch space for im2col, and
@@ -487,18 +499,18 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
} else {
effective_kernel_type = kernel_type;
}
+ ConvParams op_params;
+ op_params.padding_type = RuntimePaddingType(params->padding);
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
switch (effective_kernel_type) {
case kReference: {
- ConvParams op_params;
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = data->padding.width;
- op_params.padding_values.height = data->padding.height;
- op_params.stride_width = params->stride_width;
- op_params.stride_height = params->stride_height;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
reference_ops::Conv(op_params, GetTensorShape(input),
GetTensorData<float>(input), GetTensorShape(filter),
GetTensorData<float>(filter), GetTensorShape(bias),
@@ -508,16 +520,6 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
break;
}
case kGenericOptimized: {
- ConvParams op_params;
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = data->padding.width;
- op_params.padding_values.height = data->padding.height;
- op_params.stride_width = params->stride_width;
- op_params.stride_height = params->stride_height;
- op_params.dilation_width_factor = params->dilation_width_factor;
- op_params.dilation_height_factor = params->dilation_height_factor;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
optimized_ops::Conv(op_params, GetTensorShape(input),
GetTensorData<float>(input), GetTensorShape(filter),
GetTensorData<float>(filter), GetTensorShape(bias),
@@ -534,25 +536,21 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
filter_data = GetTensorData<float>(filter);
}
multithreaded_ops::Conv(
- *eigen_support::GetThreadPoolDevice(context),
- GetTensorData<float>(input), GetTensorDims(input), filter_data,
- GetTensorDims(filter), GetTensorData<float>(bias),
- GetTensorDims(bias), params->stride_width, params->stride_height,
- data->padding.width, data->padding.height, params->padding,
- output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ *eigen_support::GetThreadPoolDevice(context), op_params,
+ GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(filter), filter_data, GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
break;
}
case kCblasOptimized: {
- cblas_ops::Conv(GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height,
- data->padding.width, data->padding.height,
- output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ cblas_ops::Conv(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(filter),
+ GetTensorData<float>(filter), GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
break;
}
}
diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc
index 3a08f48b00..59bf64e0af 100644
--- a/tensorflow/contrib/lite/kernels/dequantize.cc
+++ b/tensorflow/contrib/lite/kernels/dequantize.cc
@@ -77,13 +77,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
- auto zero_point = op_context.input->params.zero_point;
- auto scale = op_context.input->params.scale;
-
- optimized_ops::Dequantize(GetTensorData<uint8_t>(op_context.input),
- GetTensorDims(op_context.input), zero_point, scale,
- GetTensorData<float>(op_context.output),
- GetTensorDims(op_context.output));
+ tflite::DequantizationParams op_params;
+ op_params.zero_point = op_context.input->params.zero_point;
+ op_params.scale = op_context.input->params.scale;
+ optimized_ops::Dequantize(op_params, GetTensorShape(op_context.input),
+ GetTensorData<uint8_t>(op_context.input),
+ GetTensorShape(op_context.output),
+ GetTensorData<float>(op_context.output));
if (IsConstantTensor(op_context.input)) {
op_data->float_dequantized_weights_initialized = true;
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index 7945c095b1..8d4bb51006 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -81,24 +81,27 @@ template <KernelType kernel_type>
void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) {
-#define TF_LITE_DIV(type, opname, data_type) \
- data_type output_activation_min, output_activation_max; \
- CalculateActivationRange(params->activation, &output_activation_min, \
- &output_activation_max); \
- type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
- GetTensorData<data_type>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+#define TF_LITE_DIV(type, opname, data_type) \
+ tflite::ArithmeticParams op_params; \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
if (output->type == kTfLiteInt32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_DIV(reference_ops, BroadcastDiv, int32_t);
+ TF_LITE_DIV(reference_ops, BroadcastDiv4DSlow, int32_t);
} else {
TF_LITE_DIV(reference_ops, Div, int32_t);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_DIV(optimized_ops, BroadcastDiv, int32_t);
+ TF_LITE_DIV(optimized_ops, BroadcastDiv4DSlow, int32_t);
} else {
TF_LITE_DIV(optimized_ops, Div, int32_t);
}
@@ -106,13 +109,13 @@ void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
} else if (output->type == kTfLiteFloat32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_DIV(reference_ops, BroadcastDiv, float);
+ TF_LITE_DIV(reference_ops, BroadcastDiv4DSlow, float);
} else {
TF_LITE_DIV(reference_ops, Div, float);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_DIV(optimized_ops, BroadcastDiv, float);
+ TF_LITE_DIV(optimized_ops, BroadcastDiv4DSlow, float);
} else {
TF_LITE_DIV(optimized_ops, Div, float);
}
diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc
index f9bc3747cb..b51af72fe6 100644
--- a/tensorflow/contrib/lite/kernels/fake_quant.cc
+++ b/tensorflow/contrib/lite/kernels/fake_quant.cc
@@ -68,11 +68,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params =
reinterpret_cast<TfLiteFakeQuantParams*>(node->builtin_data);
- reference_ops::FakeQuant(GetTensorData<float>(op_context.input),
- GetTensorDims(op_context.input), params->min,
- params->max, params->num_bits,
- GetTensorData<float>(op_context.output),
- GetTensorDims(op_context.output));
+ tflite::FakeQuantParams op_params;
+ op_params.num_bits = params->num_bits;
+ op_params.minmax.min = params->min;
+ op_params.minmax.max = params->max;
+ reference_ops::FakeQuant(op_params, GetTensorShape(op_context.input),
+ GetTensorData<float>(op_context.input),
+ GetTensorShape(op_context.output),
+ GetTensorData<float>(op_context.output));
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc
index badd2de11a..b5afeb1a7b 100644
--- a/tensorflow/contrib/lite/kernels/gather.cc
+++ b/tensorflow/contrib/lite/kernels/gather.cc
@@ -84,11 +84,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const int input_rank = NumDimensions(input);
-#define TF_LITE_GATHER(data_type, index_type) \
- optimized_ops::Gather( \
- GetTensorData<data_type>(input), GetTensorDims(input), input_rank, \
- GetTensorData<index_type>(positions), GetTensorDims(positions), \
- GetTensorData<data_type>(output), GetTensorDims(output));
+#define TF_LITE_GATHER(data_type, index_type) \
+ { \
+ tflite::GatherParams op_params; \
+ op_params.input_rank = input_rank; \
+ optimized_ops::Gather( \
+ op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
+ GetTensorShape(positions), GetTensorData<index_type>(positions), \
+ GetTensorShape(output), GetTensorData<data_type>(output)); \
+ }
switch (input->type) {
case kTfLiteFloat32:
TF_LITE_GATHER(float, int32_t);
diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
index 3624c20ae3..2252ca1bcc 100644
--- a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
@@ -43,11 +43,15 @@ void RunLogSoftmaxFloatReference(const uint8* input_data,
// Reference data generated via Dequant of input into float, and then applying
// float LogSoftmax.
- reference_ops::Dequantize(
- input_data, ToRuntimeDims(shape_common), input_offset, input_scale,
- reference_dequant_data.data(), ToRuntimeDims(shape_common));
- optimized_ops::LogSoftmax(reference_dequant_data.data(), shape_common,
- reference_output_float_data.data(), shape_common);
+ DequantizationParams dq_params;
+ dq_params.zero_point = input_offset;
+ dq_params.scale = input_scale;
+ reference_ops::Dequantize(dq_params, shape_common, input_data, shape_common,
+ reference_dequant_data.data());
+ SoftmaxParams sm_params;
+ optimized_ops::LogSoftmax(sm_params, shape_common,
+ reference_dequant_data.data(), shape_common,
+ reference_output_float_data.data());
// Work with quantized scaling for LogSoftmax, under which 255 represents 0,
// and -16 gets nudged up to 0.
for (int i = 0; i < ref_buffer_size; i++) {
@@ -129,14 +133,16 @@ void RunOneLogSoftmaxTest(const uint8* input_data,
const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
input_beta_left_shift);
- optimized_ops::LogSoftmax(input_data, shape_common, input_beta_multiplier,
- input_beta_left_shift, reverse_scaling_divisor,
- reverse_scaling_right_shift, diff_min,
- optimized_logsoftmax_output.data(), shape_common);
- reference_ops::LogSoftmax(
- input_data, shape_common, input_beta_multiplier, input_beta_left_shift,
- reverse_scaling_divisor, reverse_scaling_right_shift, diff_min,
- reference_quant_logsoftmax_output.data(), shape_common);
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ optimized_ops::LogSoftmax(params, shape_common, input_data, shape_common,
+ optimized_logsoftmax_output.data());
+ reference_ops::LogSoftmax(params, shape_common, input_data, shape_common,
+ reference_quant_logsoftmax_output.data());
CheckOutputData(optimized_logsoftmax_output.data(),
reference_float_logsoftmax_output.data(), shape_common,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
index 40d42bbae9..2d96da65c3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
@@ -31,20 +31,29 @@ limitations under the License.
namespace tflite {
namespace cblas_ops {
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
gemmlowp::ScopedProfilingLabel label("Conv/cblas");
const float* gemm_input_data = nullptr;
- const Dims<4>* gemm_input_dims = nullptr;
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+ const RuntimeShape* gemm_input_shape = nullptr;
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
if (need_im2col) {
@@ -55,18 +64,17 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
op_params.padding_values.height = pad_height;
op_params.stride_width = stride_width;
op_params.stride_height = stride_height;
- op_params.dilation_width_factor = 1;
- op_params.dilation_height_factor = 1;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
optimized_ops::Im2col(op_params, filter_height, filter_width, 0,
- DimsToShape(input_dims), input_data,
- DimsToShape(im2col_dims), im2col_data);
+ input_shape, input_data, im2col_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else {
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- gemm_input_dims = &input_dims;
+ gemm_input_shape = &input_shape;
}
// The following code computes matrix multiplication c = a * transponse(b)
@@ -78,10 +86,10 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
const float* a = gemm_input_data;
const float* b = filter_data;
float* c = output_data;
- int m = gemm_input_dims->sizes[1] * gemm_input_dims->sizes[2] *
- gemm_input_dims->sizes[3];
- int n = output_dims.sizes[0];
- int k = gemm_input_dims->sizes[0];
+ const int gemm_input_dims = gemm_input_shape->DimensionsCount();
+ int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
+ int n = output_shape.Dims(3);
+ int k = gemm_input_shape->Dims(gemm_input_dims - 1);
// The stride of matrix a, b and c respectively.
int stride_a = k;
int stride_b = k;
@@ -91,8 +99,8 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
stride_a, b, stride_b, 0.0f, c, stride_c);
optimized_ops::AddBiasAndEvalActivationFunction(
- output_activation_min, output_activation_max, DimsToShape(bias_dims),
- bias_data, DimsToShape(output_dims), output_data);
+ output_activation_min, output_activation_max, bias_shape, bias_data,
+ output_shape, output_data);
}
} // namespace cblas_ops
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
index 114575a96a..d8dd7bba89 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -1092,80 +1092,6 @@ inline void DepthwiseConv(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height,
- int dilation_width_factor, int dilation_height_factor,
- int pad_width, int pad_height, int depth_multiplier,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- tflite::DepthwiseParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = dilation_width_factor;
- op_params.dilation_height_factor = dilation_height_factor;
- op_params.depth_multiplier = depth_multiplier;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
- bias_data, DimsToShape(output_dims), output_data);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride_width, stride_height, 1, 1, pad_width,
- pad_height, depth_multiplier, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride_width, stride_height, pad_width, pad_height,
- depth_multiplier, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int depth_multiplier,
- float* output_data, const Dims<4>& output_dims) {
- DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride, stride, pad_width, pad_height,
- depth_multiplier, output_data, output_dims);
-}
-
} // namespace optimized_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index f892b8f661..803eff292a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -24,9 +24,6 @@ limitations under the License.
namespace tflite {
namespace optimized_ops {
-// TODO(b/80418076): Move to legacy ops file, along with invocations.
-static constexpr int kDepthwiseReverseShift = -1;
-
// Implementation of quantized DepthwiseConv
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
@@ -1701,6 +1698,8 @@ inline void DepthwiseConv(
const int output_shift = params.output_shift;
const int dilation_width_factor = params.dilation_width_factor;
const int dilation_height_factor = params.dilation_height_factor;
+ TFLITE_DCHECK_GE(dilation_width_factor, 1);
+ TFLITE_DCHECK_GE(dilation_height_factor, 1);
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
@@ -1994,105 +1993,6 @@ inline void DepthwiseConv(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height,
- int dilation_width_factor, int dilation_height_factor,
- int pad_width, int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- tflite::DepthwiseParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = dilation_width_factor;
- op_params.dilation_height_factor = dilation_height_factor;
- op_params.depth_multiplier = depth_multiplier;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = output_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.output_shift = kDepthwiseReverseShift * output_shift;
-
- DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
- bias_data, DimsToShape(output_dims), output_data);
-}
-
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width,
- stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width,
- stride_height, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
- filter_dims, filter_offset, bias_data, bias_dims, stride,
- stride, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
} // namespace optimized_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
index b6151c40b3..4218be20a4 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -19,6 +19,8 @@ limitations under the License.
#include <sys/types.h>
#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -28,9 +30,857 @@ namespace optimized_ops {
// Unoptimized reference ops:
using reference_ops::ArgMax;
+using reference_ops::ArgMinMax;
+using reference_ops::Broadcast4DSlowGreater;
+using reference_ops::Broadcast4DSlowGreaterEqual;
+using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
+using reference_ops::Broadcast4DSlowGreaterWithScaling;
+using reference_ops::Broadcast4DSlowLess;
+using reference_ops::Broadcast4DSlowLessEqual;
+using reference_ops::Broadcast4DSlowLessEqualWithScaling;
+using reference_ops::Broadcast4DSlowLessWithScaling;
+using reference_ops::BroadcastAdd4DSlow;
+using reference_ops::BroadcastGreater;
+using reference_ops::BroadcastGreaterEqual;
+using reference_ops::BroadcastLess;
+using reference_ops::BroadcastLessEqual;
+using reference_ops::BroadcastMul4DSlow;
+using reference_ops::BroadcastSub4DSlow;
+using reference_ops::Concatenation;
+using reference_ops::ConcatenationWithScaling;
+using reference_ops::DepthConcatenation;
+using reference_ops::Dequantize;
+using reference_ops::Div;
+using reference_ops::FakeQuant;
+using reference_ops::Gather;
+using reference_ops::Greater;
+using reference_ops::GreaterEqual;
+using reference_ops::GreaterEqualWithScaling;
+using reference_ops::GreaterWithScaling;
+using reference_ops::Less;
+using reference_ops::LessEqual;
+using reference_ops::LessEqualWithScaling;
+using reference_ops::LessWithScaling;
+using reference_ops::Mean;
+using reference_ops::RankOneSelect;
using reference_ops::Relu1;
using reference_ops::Relu6;
+using reference_ops::ReluX;
+using reference_ops::Select;
using reference_ops::SpaceToBatchND;
+using reference_ops::Split;
+using reference_ops::StridedSlice;
+using reference_ops::TensorFlowSplit;
+using reference_ops::Transpose;
+
+static constexpr int kDepthwiseReverseShift = -1;
+
+template <typename Scalar, int N>
+VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
+ const int size = FlatSize(dims);
+ return VectorMap<Scalar>(data, size, 1);
+}
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
+ const Dims<N>& dims) {
+ const int rows = dims.sizes[0];
+ int cols = 1;
+ for (int d = 1; d < N; d++) {
+ cols *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
+ const Dims<N>& dims) {
+ const int cols = dims.sizes[N - 1];
+ int rows = 1;
+ for (int d = 0; d < N - 1; d++) {
+ rows *= dims.sizes[d];
+ }
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar, int N>
+ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
+ const Dims<N>& dims) {
+ const int rows = dims.sizes[0];
+ int cols = 1;
+ for (int d = 1; d < N; d++) {
+ cols *= dims.sizes[d];
+ }
+ return ArrayMap<Scalar>(data, rows, cols);
+}
+
+// TODO(b/62193649): this function is only needed as long
+// as we have the --variable_batch hack.
+template <typename Scalar, int N>
+MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
+ const Dims<N>& dims,
+ int rows) {
+ const int flatsize = FlatSize(dims);
+ TFLITE_DCHECK((flatsize % rows) == 0);
+ const int cols = flatsize / rows;
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
+ for (int i = 0; i < 4; i++) {
+ if (dims1.sizes[i] != dims2.sizes[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, 1, 1, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, pad_width, pad_height,
+ depth_multiplier, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ float* output_data, const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height,
+ depth_multiplier, output_data, output_dims);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kDepthwiseReverseShift * output_shift;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
+ filter_dims, filter_offset, bias_data, bias_dims, stride,
+ stride, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+inline void AddBiasAndEvalActivationFunction(const float* bias_data,
+ const Dims<4>& bias_dims,
+ float* array_data,
+ const Dims<4>& array_dims,
+ float output_activation_min,
+ float output_activation_max) {
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(array_dims), array_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AddBiasAndEvalActivationFunction(const float* bias_data,
+ const Dims<4>& bias_dims,
+ float* array_data,
+ const Dims<4>& array_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims,
+ output_activation_min,
+ output_activation_max);
+}
+
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FullyConnectedParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), weights_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data, const Dims<4>& weights_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
+ bias_dims, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+inline void FullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
+ const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
+ int32 output_multiplier, int output_shift, int32 output_activation_min,
+ int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims, gemm_context);
+}
+
+inline void ShuffledFullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims,
+ const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
+ const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims,
+ uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), shuffled_weights_data,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(output_dims), output_data,
+ shuffled_input_workspace_data, gemm_context);
+}
+
+template <typename T>
+inline void ExtractPatchIntoBufferColumn(
+ const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int in_width, int in_height, int in_depth, int single_buffer_length,
+ int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) {
+ ExtractPatchIntoBufferColumn(
+ DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width,
+ stride_height, pad_width, pad_height, in_width, in_height, in_depth,
+ single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte);
+}
+
+template <typename T>
+void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+
+ DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), DimsToShape(output_dims),
+ im2col_data);
+}
+
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, int kheight,
+ int kwidth, uint8 zero_byte, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+
+ Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims),
+ input_data, DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int kheight, int kwidth,
+ uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
+ Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
+ kwidth, zero_byte, output_data, output_dims);
+}
+
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ float* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
+ const int8_t* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* scaling_factors_ptr,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ int8_t* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
+ input_data, DimsToShape(filter_dims), filter_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float* output_data, const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, dilation_width_factor,
+ dilation_height_factor, pad_width, pad_height, output_activation_min,
+ output_activation_max, output_data, output_dims, im2col_data,
+ im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, 1, 1, pad_width, pad_height,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
+ output_dims, im2col_data, im2col_dims);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ uint8* im2col_data, const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, stride, pad_width,
+ pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int kheight, int kwidth,
+ uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
+ Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
+ kwidth, zero_byte, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("ConvAsGemm");
+
+ const auto input_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
+ const auto filter_matrix_map =
+ MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ auto output_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+
+ Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
+
+ AddBiasAndEvalActivationFunction<Ac>(bias_data, bias_dims, output_data,
+ output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit");
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ const int input_rows = input_dims.sizes[0];
+ const int input_cols = FlatSizeSkipDim(input_dims, 0);
+ const int filter_rows = filter_dims.sizes[3];
+ const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
+ const int output_rows = output_dims.sizes[0];
+ const int output_cols = FlatSizeSkipDim(output_dims, 0);
+ TFLITE_DCHECK_EQ(output_rows, filter_rows);
+ TFLITE_DCHECK_EQ(output_cols, input_cols);
+ TFLITE_DCHECK_EQ(filter_cols, input_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
+ filter_data, output_rows, filter_cols, filter_cols);
+ gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
+ input_data, filter_cols, output_cols, filter_cols);
+ gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
+ output_data, output_rows, output_cols, output_rows);
+ const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
+ bias_data, output_rows, output_offset, output_multiplier, -output_shift,
+ output_activation_min, output_activation_max);
+ gemmlowp::GemmWithOutputPipeline<uint8, uint8,
+ gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
+ gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
+ input_offset, output_pipeline);
+}
+
+inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+template <typename T>
+void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), DimsToShape(output_dims),
+ im2col_data);
+}
+
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+ const float* prev_activ_data,
+ const Dims<4>& prev_activ_dims, const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims, const float* prev_state_data,
+ const Dims<4>& prev_state_dims, float* output_state_data,
+ const Dims<4>& output_state_dims, float* output_activ_data,
+ const Dims<4>& output_activ_dims, float* concat_temp_data,
+ const Dims<4>& concat_temp_dims, float* activ_temp_data,
+ const Dims<4>& activ_temp_dims) {
+ tflite::LstmCellParams op_params;
+ // Float LSTM cell does not need parameters to be set: leave untouched.
+
+ LstmCell(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(prev_activ_dims), prev_activ_data,
+ DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(prev_state_dims), prev_state_data,
+ DimsToShape(output_state_dims), output_state_data,
+ DimsToShape(output_activ_dims), output_activ_data,
+ DimsToShape(concat_temp_dims), concat_temp_data,
+ DimsToShape(activ_temp_dims), activ_temp_data);
+}
+
+template <int StateIntegerBits>
+void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
+ const uint8* prev_activ_data_uint8,
+ const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
+ const Dims<4>& weights_dims, const int32* bias_data_int32,
+ const Dims<4>& bias_dims, const int16* prev_state_data_int16,
+ const Dims<4>& prev_state_dims, int16* output_state_data_int16,
+ const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
+ const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
+ const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
+ const Dims<4>& activ_temp_dims, int32 weights_zero_point,
+ int32 accum_multiplier, int accum_shift,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights_zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
+
+ LstmCell<StateIntegerBits>(
+ op_params, DimsToShape(input_dims), input_data_uint8,
+ DimsToShape(prev_activ_dims), prev_activ_data_uint8,
+ DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
+ DimsToShape(output_state_dims), output_state_data_int16,
+ DimsToShape(output_activ_dims), output_activ_data_uint8,
+ DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+ DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
+}
+
+template <typename T>
+void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
@@ -574,6 +1424,14 @@ void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
filter_width, filter_height, output_data, output_dims);
}
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+ float beta, float* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.beta = beta;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Softmax(const float* input_data, const Dims<4>& input_dims,
float beta, float* output_data,
const Dims<4>& output_dims) {
@@ -581,6 +1439,16 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
int32 input_beta_multiplier, int32 input_beta_left_shift,
int diff_min, uint8* output_data,
@@ -590,12 +1458,33 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ // No params currently used for float LogSoftmax.
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
LogSoftmax(input_data, DimsToShape(input_dims), output_data,
DimsToShape(output_dims));
}
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
int32 input_multiplier, int32 input_left_shift,
int32 reverse_scaling_divisor,
@@ -607,6 +1496,18 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ LogisticParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Logistic(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
@@ -622,6 +1523,20 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
+ int16* output_data, const RuntimeShape& output_shape) {
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
int16* output_data, const Dims<4>& output_dims) {
Logistic(input_data, DimsToShape(input_dims), output_data,
@@ -634,6 +1549,18 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims,
output_data);
}
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
int32 input_zero_point, int32 input_range_radius,
int32 input_multiplier, int input_left_shift,
@@ -643,6 +1570,14 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+ int input_left_shift, int16* output_data,
+ const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
int input_left_shift, int16* output_data,
const Dims<4>& output_dims) {
@@ -777,7 +1712,6 @@ inline void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
DimsToShape(output_dims), output_data);
}
-// Legacy Dims<4>.
inline void LocalResponseNormalization(const float* input_data,
const Dims<4>& input_dims, int range,
float bias, float alpha, float beta,
@@ -793,7 +1727,6 @@ inline void LocalResponseNormalization(const float* input_data,
DimsToShape(output_dims), output_data);
}
-// Legacy Dims<4> version.
template <typename SrcT, typename DstT>
void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
const Dims<4>& output_dims) {
@@ -801,14 +1734,12 @@ void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
output_data);
}
-// Legacy Dims<4> version.
inline void Floor(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
output_data);
}
-// Legacy Dims<4>
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
const int32* output_size_data,
const Dims<4>& output_size_dims, float* output_data,
@@ -820,7 +1751,6 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims), output_data);
}
-// Legacy Dims<4>
inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
const int32* output_size_data,
const Dims<4>& output_size_dims, uint8* output_data,
@@ -850,7 +1780,6 @@ inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
output_data, output_dims, /*align_corners=*/false);
}
-// Legacy Dims<4>.
template <typename T>
inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
const int32* block_shape_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index b5d001cc9e..4139cf4eba 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -69,13 +69,13 @@ struct MatMulConvFunctor {
template <class T>
class EigenTensorConvFunctor {
private:
- Eigen::PaddingType TfLitePadding2EigenPadding(TfLitePadding padding) {
+ Eigen::PaddingType RuntimePadding2EigenPadding(PaddingType padding) {
switch (padding) {
- case kTfLitePaddingValid:
+ case PaddingType::kValid:
return Eigen::PADDING_VALID;
- case kTfLitePaddingSame:
+ case PaddingType::kSame:
return Eigen::PADDING_SAME;
- case kTfLitePaddingUnknown:
+ case PaddingType::kNone:
assert(false); // should never get here.
return Eigen::PADDING_VALID;
}
@@ -89,7 +89,7 @@ class EigenTensorConvFunctor {
int input_width, int input_depth, const T* filter_data,
int filter_height, int filter_width, int filter_count,
int stride_rows, int stride_cols, int pad_width,
- int pad_height, TfLitePadding padding, T* output_data,
+ int pad_height, PaddingType padding, T* output_data,
int output_height, int output_width) {
const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 &&
stride_rows == 1 && stride_cols == 1);
@@ -127,28 +127,38 @@ class EigenTensorConvFunctor {
input_depth, filter_count);
output.device(device) =
Eigen::SpatialConvolution(input, filter, stride_cols, stride_rows,
- TfLitePadding2EigenPadding(padding));
+ RuntimePadding2EigenPadding(padding));
}
}
};
-inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data,
- const Dims<4>& input_dims, const float* filter_data,
- const Dims<4>& filter_dims, const float* bias_data,
- const Dims<4>& bias_dims, int stride_width, int stride_height,
- int pad_width, int pad_height, TfLitePadding padding,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void Conv(const Eigen::ThreadPoolDevice& device,
+ const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const PaddingType padding = params.padding_type;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
EigenTensorConvFunctor<float> conv_functor;
conv_functor(device, input_data, im2col_data, batches, input_height,
input_width, input_depth, filter_data, filter_height,
@@ -157,8 +167,8 @@ inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data,
output_width);
optimized_ops::AddBiasAndEvalActivationFunction(
- output_activation_min, output_activation_max, DimsToShape(bias_dims),
- bias_data, DimsToShape(output_dims), output_data);
+ output_activation_min, output_activation_max, bias_shape, bias_data,
+ output_shape, output_data);
}
} // namespace multithreaded_ops
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 0999738396..77f84e0c1c 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -52,13 +52,10 @@ using reference_ops::Broadcast4DSlowLessEqual;
using reference_ops::Broadcast4DSlowLessEqualWithScaling;
using reference_ops::Broadcast4DSlowLessWithScaling;
using reference_ops::BroadcastAdd4DSlow;
-using reference_ops::BroadcastGreater;
-using reference_ops::BroadcastGreaterEqual;
-using reference_ops::BroadcastLess;
-using reference_ops::BroadcastLessEqual;
using reference_ops::BroadcastMul4DSlow;
using reference_ops::BroadcastSub4DSlow;
using reference_ops::Concatenation;
+using reference_ops::ConcatenationWithScaling;
using reference_ops::DepthConcatenation;
using reference_ops::Dequantize;
using reference_ops::Div;
@@ -81,7 +78,6 @@ using reference_ops::Select;
using reference_ops::SpaceToBatchND;
using reference_ops::Split;
using reference_ops::StridedSlice;
-using reference_ops::TensorFlowSplit;
using reference_ops::Transpose;
// TODO(b/80247582) Remove this constant.
@@ -111,12 +107,6 @@ VectorMap<Scalar> MapAsVector(Scalar* data, const RuntimeShape& shape) {
return VectorMap<Scalar>(data, size, 1);
}
-template <typename Scalar, int N>
-VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
- const int size = FlatSize(dims);
- return VectorMap<Scalar>(data, size, 1);
-}
-
// Make a local VectorMap typedef allowing to map a float array
// as a Eigen matrix expression. The same explanation as for VectorMap
// above also applies here.
@@ -144,28 +134,6 @@ MatrixMap<Scalar> MapAsMatrixWithFirstDimAsCols(Scalar* data,
return MatrixMap<Scalar>(data, rows, cols);
}
-template <typename Scalar, int N>
-MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
- const Dims<N>& dims) {
- const int rows = dims.sizes[0];
- int cols = 1;
- for (int d = 1; d < N; d++) {
- cols *= dims.sizes[d];
- }
- return MatrixMap<Scalar>(data, rows, cols);
-}
-
-template <typename Scalar, int N>
-MatrixMap<Scalar> MapAsMatrixWithLastDimAsCols(Scalar* data,
- const Dims<N>& dims) {
- const int cols = dims.sizes[N - 1];
- int rows = 1;
- for (int d = 0; d < N - 1; d++) {
- rows *= dims.sizes[d];
- }
- return MatrixMap<Scalar>(data, rows, cols);
-}
-
template <typename Scalar>
using ArrayMap = typename std::conditional<
std::is_const<Scalar>::value,
@@ -173,17 +141,6 @@ using ArrayMap = typename std::conditional<
Eigen::Dynamic, Eigen::Dynamic>>,
Eigen::Map<Eigen::Array<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
-template <typename Scalar, int N>
-ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
- const Dims<N>& dims) {
- const int rows = dims.sizes[0];
- int cols = 1;
- for (int d = 1; d < N; d++) {
- cols *= dims.sizes[d];
- }
- return ArrayMap<Scalar>(data, rows, cols);
-}
-
template <typename Scalar>
ArrayMap<Scalar> MapAsArrayWithLastDimAsRows(Scalar* data,
const RuntimeShape& shape) {
@@ -205,20 +162,6 @@ struct TTypes {
UnalignedConstMatrix;
};
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-// TODO(b/62193649): this function is only needed as long
-// as we have the --variable_batch hack.
-template <typename Scalar, int N>
-MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
- const Dims<N>& dims,
- int rows) {
- const int flatsize = FlatSize(dims);
- TFLITE_DCHECK((flatsize % rows) == 0);
- const int cols = flatsize / rows;
- return MatrixMap<Scalar>(data, rows, cols);
-}
-
// TODO(b/62193649): this function is only needed as long
// as we have the --variable_batch hack.
template <typename Scalar>
@@ -270,15 +213,6 @@ SaturatingRoundingMultiplyByPOTParam(
SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}
-inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
- for (int i = 0; i < 4; i++) {
- if (dims1.sizes[i] != dims2.sizes[i]) {
- return false;
- }
- }
- return true;
-}
-
inline void AddBiasAndEvalActivationFunction(float output_activation_min,
float output_activation_max,
const RuntimeShape& bias_shape,
@@ -352,33 +286,6 @@ inline void AddBiasAndEvalActivationFunction(float output_activation_min,
#endif
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void AddBiasAndEvalActivationFunction(const float* bias_data,
- const Dims<4>& bias_dims,
- float* array_data,
- const Dims<4>& array_dims,
- float output_activation_min,
- float output_activation_max) {
- AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
- DimsToShape(bias_dims), bias_data,
- DimsToShape(array_dims), array_data);
-}
-
-// Note: This to be converted to RuntimeShapes along with Conv.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AddBiasAndEvalActivationFunction(const float* bias_data,
- const Dims<4>& bias_dims,
- float* array_data,
- const Dims<4>& array_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims,
- output_activation_min,
- output_activation_max);
-}
-
template <typename Lhs, typename Rhs, typename Result>
void Gemm(const Eigen::MatrixBase<Lhs>& lhs, const Eigen::MatrixBase<Rhs>& rhs,
Eigen::MatrixBase<Result>* result) {
@@ -925,38 +832,6 @@ inline void FullyConnected(
output_data);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- tflite::FullyConnectedParams op_params;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- FullyConnected(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(weights_dims), weights_data,
- DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
- output_data);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data, const Dims<4>& weights_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
- bias_dims, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
#ifdef USE_NEON
inline void FullyConnectedAsGEMV(
const RuntimeShape& input_shape, const uint8* input_data,
@@ -1203,33 +1078,6 @@ inline void FullyConnected(
input_offset, output_pipeline);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- tflite::FullyConnectedParams op_params;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = output_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.output_shift = kReverseShift * output_shift;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- FullyConnected(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
- bias_data, DimsToShape(output_dims), output_data,
- gemm_context);
-}
-
inline void FullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& filter_shape,
@@ -1317,54 +1165,6 @@ inline void FullyConnected(
input_offset, output_pipeline);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void FullyConnected(
- const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
- const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
- int32 output_multiplier, int output_shift, int32 output_activation_min,
- int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- tflite::FullyConnectedParams op_params;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = output_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.output_shift = kReverseShift * output_shift;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- FullyConnected(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
- bias_data_int32, DimsToShape(output_dims), output_data,
- gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims, gemm_context);
-}
-
// Internal function doing the actual arithmetic work for
// ShuffledFullyConnected.
// May be called either directly by it (single-threaded case) or may be used
@@ -1809,29 +1609,6 @@ inline void ShuffledFullyConnected(
gemm_context->workers_pool()->Execute(tasks);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void ShuffledFullyConnected(
- const uint8* input_data, const Dims<4>& input_dims,
- const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
- const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- int16* output_data, const Dims<4>& output_dims,
- uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
- tflite::FullyConnectedParams op_params;
- op_params.output_multiplier = output_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.output_shift = kReverseShift * output_shift;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(weights_dims), shuffled_weights_data,
- DimsToShape(bias_dims), bias_data,
- DimsToShape(output_dims), output_data,
- shuffled_input_workspace_data, gemm_context);
-}
-
template <typename T>
inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
int h, int b, int kheight, int kwidth,
@@ -1922,20 +1699,6 @@ inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T>
-inline void ExtractPatchIntoBufferColumn(
- const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int in_width, int in_height, int in_depth, int single_buffer_length,
- int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) {
- ExtractPatchIntoBufferColumn(
- DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width,
- stride_height, pad_width, pad_height, in_width, in_height, in_depth,
- single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte);
-}
-
template <typename T>
void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
const RuntimeShape& input_shape, const T* input_data,
@@ -2019,30 +1782,6 @@ void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T>
-void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
- const Dims<4>& filter_dims, int stride_width,
- int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- const Dims<4>& output_dims, uint8 zero_byte,
- T* im2col_data) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = dilation_width_factor;
- op_params.dilation_height_factor = dilation_height_factor;
-
- DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), DimsToShape(output_dims),
- im2col_data);
-}
-
template <typename T>
void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
const RuntimeShape& input_shape, const T* input_data,
@@ -2078,36 +1817,6 @@ void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T>
-void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, int kheight,
- int kwidth, uint8 zero_byte, T* output_data,
- const Dims<4>& output_dims) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = 1;
- op_params.dilation_height_factor = 1;
-
- Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims),
- input_data, DimsToShape(output_dims), output_data);
-}
-
-// legacy, for compatibility with old checked-in code
-template <typename T>
-void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int kheight, int kwidth,
- uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
- Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
- kwidth, zero_byte, output_data, output_dims);
-}
-
inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
const float* filter_data, const RuntimeShape& bias_shape,
@@ -2171,33 +1880,6 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
output_data);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = dilation_width_factor;
- op_params.dilation_height_factor = dilation_height_factor;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
- filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
- output_data, DimsToShape(im2col_dims), im2col_data);
-}
-
inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
const RuntimeShape& input_shape,
const int8_t* input_data,
@@ -2278,82 +1960,6 @@ inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
output_data);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
- const int8_t* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* scaling_factors_ptr,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- int8_t* im2col_data, const Dims<4>& im2col_dims) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
- input_data, DimsToShape(filter_dims), filter_data,
- DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
- output_data, DimsToShape(im2col_dims), im2col_data);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float* output_data, const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
- stride_width, stride_height, dilation_width_factor,
- dilation_height_factor, pad_width, pad_height, output_activation_min,
- output_activation_max, output_data, output_dims, im2col_data,
- im2col_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
- stride_width, stride_height, 1, 1, pad_width, pad_height,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
- output_dims, im2col_data, im2col_dims);
-}
-
inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape,
@@ -2445,192 +2051,6 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
input_offset, output_pipeline);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims,
- uint8* im2col_data, const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = dilation_width_factor;
- op_params.dilation_height_factor = dilation_height_factor;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = output_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.output_shift = kReverseShift * output_shift;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
- filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
- output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
- pad_width, pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims, gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width, stride_height,
- pad_width, pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims, gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, stride, pad_width,
- pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims, gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int kheight, int kwidth,
- uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
- Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
- kwidth, zero_byte, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- float* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("ConvAsGemm");
-
- const auto input_matrix_map =
- MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- const auto filter_matrix_map =
- MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
- auto output_matrix_map =
- MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
-
- Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
-
- AddBiasAndEvalActivationFunction<Ac>(bias_data, bias_dims, output_data,
- output_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit");
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- const int input_rows = input_dims.sizes[0];
- const int input_cols = FlatSizeSkipDim(input_dims, 0);
- const int filter_rows = filter_dims.sizes[3];
- const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
- const int output_rows = output_dims.sizes[0];
- const int output_cols = FlatSizeSkipDim(output_dims, 0);
- TFLITE_DCHECK_EQ(output_rows, filter_rows);
- TFLITE_DCHECK_EQ(output_cols, input_cols);
- TFLITE_DCHECK_EQ(filter_cols, input_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
- gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
- filter_data, output_rows, filter_cols, filter_cols);
- gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
- input_data, filter_cols, output_cols, filter_cols);
- gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
- output_data, output_rows, output_cols, output_rows);
- const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
- bias_data, output_rows, output_offset, output_multiplier, -output_shift,
- output_activation_min, output_activation_max);
- gemmlowp::GemmWithOutputPipeline<uint8, uint8,
- gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
- gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset,
- input_offset, output_pipeline);
-}
-
template <typename T>
inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
const RuntimeShape& unextended_input_shape,
@@ -3547,21 +2967,6 @@ void BroadcastDiv4DSlow(const ArithmeticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <typename T>
-void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
-
- BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
// TODO(aselle): This is not actually optimized yet.
inline void SubNonBroadcast(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
@@ -3755,31 +3160,6 @@ inline void LstmCell(
output_state_map.tanh();
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
- const float* prev_activ_data,
- const Dims<4>& prev_activ_dims, const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims, const float* prev_state_data,
- const Dims<4>& prev_state_dims, float* output_state_data,
- const Dims<4>& output_state_dims, float* output_activ_data,
- const Dims<4>& output_activ_dims, float* concat_temp_data,
- const Dims<4>& concat_temp_dims, float* activ_temp_data,
- const Dims<4>& activ_temp_dims) {
- tflite::LstmCellParams op_params;
- // Float LSTM cell does not need parameters to be set: leave untouched.
-
- LstmCell(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(prev_activ_dims), prev_activ_data,
- DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
- bias_data, DimsToShape(prev_state_dims), prev_state_data,
- DimsToShape(output_state_dims), output_state_data,
- DimsToShape(output_activ_dims), output_activ_data,
- DimsToShape(concat_temp_dims), concat_temp_data,
- DimsToShape(activ_temp_dims), activ_temp_data);
-}
-
// Quantized LSTM cell. Currently just a copy of the reference impl in
// reference_ops.h. See the big function comment there, not replicating it
// here.
@@ -4070,37 +3450,6 @@ inline void LstmCell(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <int StateIntegerBits>
-void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
- const uint8* prev_activ_data_uint8,
- const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
- const Dims<4>& weights_dims, const int32* bias_data_int32,
- const Dims<4>& bias_dims, const int16* prev_state_data_int16,
- const Dims<4>& prev_state_dims, int16* output_state_data_int16,
- const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
- const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
- const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
- const Dims<4>& activ_temp_dims, int32 weights_zero_point,
- int32 accum_multiplier, int accum_shift,
- gemmlowp::GemmContext* gemm_context) {
- tflite::LstmCellParams op_params;
- op_params.weights_zero_point = weights_zero_point;
- op_params.accum_multiplier = accum_multiplier;
- op_params.accum_shift = accum_shift;
-
- LstmCell<StateIntegerBits>(
- op_params, DimsToShape(input_dims), input_data_uint8,
- DimsToShape(prev_activ_dims), prev_activ_data_uint8,
- DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
- bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
- DimsToShape(output_state_dims), output_state_data_int16,
- DimsToShape(output_activ_dims), output_activ_data_uint8,
- DimsToShape(concat_temp_dims), concat_temp_data_uint8,
- DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
-}
-
inline int NodeOffset(int b, int h, int w, int height, int width) {
return (b * height + h) * width + w;
}
@@ -4560,16 +3909,6 @@ inline void Softmax(const SoftmaxParams& params,
out_mat.array().rowwise() *= scale;
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
- float beta, float* output_data,
- const RuntimeShape& output_shape) {
- SoftmaxParams params;
- params.beta = beta;
- Softmax(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Softmax(const SoftmaxParams& params,
const RuntimeShape& input_shape, const uint8* input_data,
const RuntimeShape& output_shape, uint8* output_data) {
@@ -4781,19 +4120,6 @@ inline void Softmax(const SoftmaxParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_beta_multiplier, int32 input_beta_left_shift,
- int diff_min, uint8* output_data,
- const RuntimeShape& output_shape) {
- SoftmaxParams params;
- params.input_multiplier = input_beta_multiplier;
- params.input_left_shift = input_beta_left_shift;
- params.diff_min = diff_min;
- Softmax(params, input_shape, input_data, output_shape, output_data);
-}
-
// TODO(myenik): This is the same as the reference implementation, not actually
// optimized yet.
inline void LogSoftmax(const SoftmaxParams& params,
@@ -4831,15 +4157,6 @@ inline void LogSoftmax(const SoftmaxParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy
-inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
- SoftmaxParams params;
- // No params currently used for float LogSoftmax.
- LogSoftmax(params, input_shape, input_data, output_shape, output_data);
-}
-
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
@@ -5044,22 +4361,6 @@ inline void LogSoftmax(const SoftmaxParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_multiplier, int32 input_left_shift,
- int32 reverse_scaling_divisor,
- int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const RuntimeShape& output_shape) {
- SoftmaxParams params;
- params.input_multiplier = input_multiplier;
- params.input_left_shift = input_left_shift;
- params.reverse_scaling_divisor = reverse_scaling_divisor;
- params.reverse_scaling_right_shift = reverse_scaling_right_shift;
- params.diff_min = diff_min;
- LogSoftmax(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic");
@@ -5218,20 +4519,6 @@ inline void Logistic(const LogisticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
- LogisticParams params;
- params.input_zero_point = input_zero_point;
- params.input_range_radius = input_range_radius;
- params.input_multiplier = input_multiplier;
- params.input_left_shift = input_left_shift;
- Logistic(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Logistic(const LogisticParams& params,
const RuntimeShape& input_shape, const int16* input_data,
const RuntimeShape& output_shape, int16* output_data) {
@@ -5293,24 +4580,6 @@ inline void Logistic(const LogisticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy version.
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
- const RuntimeShape& output_shape, int16* output_data) {
- LogisticParams params;
- // No params currently needed by int16 Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy version.
-inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
- int16* output_data, const RuntimeShape& output_shape) {
- LogisticParams params;
- // No params currently needed by int16 Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Tanh");
@@ -5478,20 +4747,6 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
- TanhParams params;
- params.input_zero_point = input_zero_point;
- params.input_range_radius = input_range_radius;
- params.input_multiplier = input_multiplier;
- params.input_left_shift = input_left_shift;
- Tanh(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
const int16* input_data, const RuntimeShape& output_shape,
int16* output_data) {
@@ -5593,16 +4848,6 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
- int input_left_shift, int16* output_data,
- const RuntimeShape& output_shape) {
- TanhParams params;
- params.input_left_shift = input_left_shift;
- Tanh(params, input_shape, input_data, output_shape, output_data);
-}
-
template <typename SrcT, typename DstT>
inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
const RuntimeShape& output_shape, DstT* output_data) {
@@ -6485,27 +5730,6 @@ void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T>
-void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
- const Dims<4>& filter_dims, int stride_width,
- int stride_height, int pad_width, int pad_height,
- const Dims<4>& output_dims, uint8 zero_byte,
- T* im2col_data) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
-
- TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), DimsToShape(output_dims),
- im2col_data);
-}
-
inline void TransposeConv(
const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
@@ -6529,27 +5753,6 @@ inline void TransposeConv(
Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
-
- TransposeConv(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
- output_data, DimsToShape(im2col_dims), im2col_data);
-}
-
} // namespace optimized_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
index a8428528c9..11224270a4 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
@@ -94,81 +94,6 @@ inline void DepthwiseConv(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height,
- int dilation_width_factor, int dilation_height_factor,
- int pad_width, int pad_height, int depth_multiplier,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- tflite::DepthwiseParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = dilation_width_factor;
- op_params.dilation_height_factor = dilation_height_factor;
- op_params.depth_multiplier = depth_multiplier;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
- bias_data, DimsToShape(output_dims), output_data);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride_width, stride_height, 1, 1, pad_width,
- pad_height, depth_multiplier, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride_width, stride_height, pad_width, pad_height,
- depth_multiplier, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int depth_multiplier,
- float* output_data, const Dims<4>& output_dims) {
- DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride, stride, pad_width, pad_height,
- depth_multiplier, output_data, output_dims);
-}
-
} // end namespace reference_ops
} // end namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
index e8fc566502..eab28e6c84 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -25,9 +25,6 @@ limitations under the License.
namespace tflite {
namespace reference_ops {
-// TODO(b/80418076): Move to legacy ops file, along with invocations.
-static constexpr int kDepthwiseReverseShift = -1;
-
inline void DepthwiseConv(
const DepthwiseParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& filter_shape,
@@ -109,106 +106,6 @@ inline void DepthwiseConv(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height,
- int dilation_width_factor, int dilation_height_factor,
- int pad_width, int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- tflite::DepthwiseParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = dilation_width_factor;
- op_params.dilation_height_factor = dilation_height_factor;
- op_params.depth_multiplier = depth_multiplier;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = output_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.output_shift = kDepthwiseReverseShift * output_shift;
-
- DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
- bias_data, DimsToShape(output_dims), output_data);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width,
- stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width,
- stride_height, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy, for compatibility with old checked-in code.
-template <FusedActivationFunctionType Ac>
-void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
- filter_dims, filter_offset, bias_data, bias_dims, stride,
- stride, pad_width, pad_height, depth_multiplier,
- output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
} // end namespace reference_ops
} // end namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h b/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h
index 23325e8c4c..3c7fd29256 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h
@@ -62,39 +62,6 @@ inline void FullyConnected(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- tflite::FullyConnectedParams op_params;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- FullyConnected(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(weights_dims), weights_data,
- DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
- output_data);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data, const Dims<4>& weights_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
- bias_dims, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
inline void FullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& filter_shape,
@@ -144,32 +111,6 @@ inline void FullyConnected(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, void* gemm_context) {
- tflite::FullyConnectedParams op_params;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = output_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.output_shift = kReverseShift * output_shift;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- FullyConnected(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
- bias_data, DimsToShape(output_dims), output_data,
- gemm_context);
-}
-
inline void FullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& filter_shape,
@@ -224,32 +165,6 @@ inline void FullyConnected(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, int16* output_data,
- const Dims<4>& output_dims, void* gemm_context) {
- tflite::FullyConnectedParams op_params;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = output_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.output_shift = kReverseShift * output_shift;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- FullyConnected(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
- bias_data, DimsToShape(output_dims), output_data,
- gemm_context);
-}
-
inline void ShuffledFullyConnected(
const FullyConnectedParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& weights_shape,
@@ -405,55 +320,6 @@ inline void ShuffledFullyConnected(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void ShuffledFullyConnected(
- const uint8* input_data, const Dims<4>& input_dims,
- const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
- const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- int16* output_data, const Dims<4>& output_dims,
- uint8* shuffled_input_workspace_data, void* gemm_context) {
- tflite::FullyConnectedParams op_params;
- op_params.output_multiplier = output_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.output_shift = kReverseShift * output_shift;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(weights_dims), shuffled_weights_data,
- DimsToShape(bias_dims), bias_data,
- DimsToShape(output_dims), output_data,
- shuffled_input_workspace_data, gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, void* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims, gemm_context);
-}
-
} // namespace reference_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index 683ccdc74d..be99240b1f 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -19,6 +19,8 @@ limitations under the License.
#include <sys/types.h>
#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -26,6 +28,1070 @@ namespace tflite {
namespace reference_ops {
+static constexpr int kDepthwiseReverseShift = -1;
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, 1, 1, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, pad_width, pad_height,
+ depth_multiplier, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ float* output_data, const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, pad_width, pad_height,
+ depth_multiplier, output_data, output_dims);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kDepthwiseReverseShift * output_shift;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// Legacy, for compatibility with old checked-in code.
+template <FusedActivationFunctionType Ac>
+void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
+ filter_dims, filter_offset, bias_data, bias_dims, stride,
+ stride, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ float* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float* output_data, const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, dilation_width_factor,
+ dilation_height_factor, pad_width, pad_height, output_activation_min,
+ output_activation_max, output_data, output_dims, im2col_data,
+ im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
+ stride_width, stride_height, 1, 1, pad_width, pad_height,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
+ output_dims, im2col_data, im2col_dims);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ uint8* im2col_data, const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
+}
+
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width, stride_height,
+ pad_width, pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data, output_dims,
+ im2col_data, im2col_dims, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims, int stride,
+ int pad_width, int pad_height, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
+ const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
+ Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride, stride, pad_width,
+ pad_height, output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims, im2col_data, im2col_dims, gemm_context);
+}
+
+inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FullyConnectedParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), weights_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data, const Dims<4>& weights_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
+ bias_dims, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, int16* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+inline void ShuffledFullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims,
+ const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
+ const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims,
+ uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), shuffled_weights_data,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(output_dims), output_data,
+ shuffled_input_workspace_data, gemm_context);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims, gemm_context);
+}
+
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+ const float* prev_activ_data,
+ const Dims<4>& prev_activ_dims, const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims, const float* prev_state_data,
+ const Dims<4>& prev_state_dims, float* output_state_data,
+ const Dims<4>& output_state_dims, float* output_activ_data,
+ const Dims<4>& output_activ_dims, float* concat_temp_data,
+ const Dims<4>& concat_temp_dims, float* activ_temp_data,
+ const Dims<4>& activ_temp_dims) {
+ tflite::LstmCellParams op_params;
+ // Float LSTM cell does not need parameters to be set: leave untouched.
+
+ LstmCell(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(prev_activ_dims), prev_activ_data,
+ DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(prev_state_dims), prev_state_data,
+ DimsToShape(output_state_dims), output_state_data,
+ DimsToShape(output_activ_dims), output_activ_data,
+ DimsToShape(concat_temp_dims), concat_temp_data,
+ DimsToShape(activ_temp_dims), activ_temp_data);
+}
+
+template <int StateIntegerBits>
+void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
+ const uint8* prev_activ_data_uint8,
+ const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
+ const Dims<4>& weights_dims, const int32* bias_data_int32,
+ const Dims<4>& bias_dims, const int16* prev_state_data_int16,
+ const Dims<4>& prev_state_dims, int16* output_state_data_int16,
+ const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
+ const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
+ const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
+ const Dims<4>& activ_temp_dims, int32 weights_zero_point,
+ int32 accum_multiplier, int accum_shift,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights_zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
+
+ LstmCell<StateIntegerBits>(
+ op_params, DimsToShape(input_dims), input_data_uint8,
+ DimsToShape(prev_activ_dims), prev_activ_data_uint8,
+ DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
+ DimsToShape(output_state_dims), output_state_data_int16,
+ DimsToShape(output_activ_dims), output_activ_data_uint8,
+ DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+ DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
+}
+
+template <typename T>
+void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Div(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Div(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+inline void Concatenation(int concat_dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ // For now we don't have a model with a Concatenation with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
+
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
+ op_params.axis = 3 - concat_dim;
+ op_params.inputs_count = inputs_count;
+
+ Concatenation(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Concatenation(int concat_dim, const uint8* const* input_data,
+ const Dims<4>* const* input_dims,
+ const int32* input_zeropoint,
+ const float* input_scale, int inputs_count,
+ uint8* output_data, const Dims<4>& output_dims,
+ const int32 output_zeropoint,
+ const float output_scale) {
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
+ op_params.axis = 3 - concat_dim;
+ op_params.input_zeropoint = input_zeropoint;
+ op_params.input_scale = input_scale;
+ op_params.inputs_count = inputs_count;
+ op_params.output_zeropoint = output_zeropoint;
+ op_params.output_scale = output_scale;
+
+ ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void DepthConcatenation(const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ // For now we don't have a model with a Concatenation with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
+
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
+ op_params.inputs_count = inputs_count;
+
+ DepthConcatenation(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename Scalar>
+void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
+ int axis, int outputs_count, Scalar* const* output_data,
+ const Dims<4>* const* output_dims) {
+ std::vector<RuntimeShape> output_shapes(outputs_count);
+ std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
+ for (int i = 0; i < outputs_count; ++i) {
+ ShapeFromDims(*output_dims[i], &output_shapes[i]);
+ output_shapes_indirect[i] = &output_shapes[i];
+ }
+ tflite::SplitParams op_params;
+ op_params.axis = 3 - axis;
+ op_params.num_split = outputs_count;
+
+ Split(op_params, DimsToShape(input_dims), input_data,
+ output_shapes_indirect.data(), output_data);
+}
+
+template <FusedActivationFunctionType Ac, typename Scalar>
+void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
+ int outputs_count, Scalar* const* output_data,
+ const Dims<4>* const* output_dims) {
+ TFLITE_DCHECK_GE(outputs_count, 1);
+ for (int i = 0; i < outputs_count; i++) {
+ /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
+ /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
+ /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
+ }
+ // For now we don't have a model with a Split with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
+
+ TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
+ output_data, output_dims);
+}
+
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+ float beta, float* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.beta = beta;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ // No params currently used for float LogSoftmax.
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ LogisticParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+ int input_left_shift, int16* output_data,
+ const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
+ int32 zero_point, double scale, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DequantizationParams op_params;
+ op_params.zero_point = zero_point;
+ op_params.scale = scale;
+
+ Dequantize(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
+ float rmin, float rmax, int num_bits, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FakeQuantParams op_params;
+ op_params.num_bits = num_bits;
+ op_params.minmax.min = rmin;
+ op_params.minmax.max = rmax;
+
+ FakeQuant(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Gather(const T* input_data, const Dims<4>& input_dims,
+ int input_rank, const int32* coords_data,
+ const Dims<4>& coords_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::GatherParams op_params;
+ op_params.input_rank = input_rank;
+
+ Gather(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline uint32 LegacyReverseBits32(uint32 n) {
+ n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
+ n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
+ n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
+ return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
+ ((n & 0xFF000000) >> 24));
+}
+
+inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
+ TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+ TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+ std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
+ std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
+ std::reverse(p->strides, p->strides + p->strides_count);
+
+ p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
+ (32 - p->start_indices_count);
+ p->ellipsis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
+ (32 - p->start_indices_count);
+ p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
+ (32 - p->start_indices_count);
+ p->new_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
+ (32 - p->start_indices_count);
+ p->shrink_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
+ (32 - p->start_indices_count);
+}
+
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+ int begin_mask, int end_mask, int shrink_axis_mask,
+ const std::vector<int>& start_indices,
+ const std::vector<int>& stop_indices,
+ const std::vector<int>& strides, T* output_data,
+ const Dims<4>& output_dims) {
+ TFLITE_DCHECK_EQ(start_indices.size(), 4);
+ auto op_params = strided_slice::BuildStridedSliceParams(
+ begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
+ strides);
+ StridedSliceReverseIndices(&op_params);
+
+ StridedSlice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Mean(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& reduction_indices, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::MeanParams op_params;
+ op_params.axis_count = reduction_indices.size();
+ for (int i = 0; i < op_params.axis_count; ++i) {
+ op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
+ }
+
+ Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void Transpose(const T* input, const Dims<4>& input_dims, T* output,
+ const Dims<4>& output_dims, const int* permuted_axes) {
+ TransposeParams params;
+ params.perm_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ params.perm[i] = 3 - permuted_axes[3 - i];
+ }
+ Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
+ output);
+}
+
+template <typename T, ComparisonFn<T> F>
+inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+ // No parameters needed.
+ ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, ComparisonFn<int32> F>
+inline void Comparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, bool* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ComparisonParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.input2_shift = kReverseShift * input2_shift;
+
+ ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, ComparisonFn<T> F>
+inline void BroadcastComparison(const T* input1_data,
+ const Dims<4>& input1_dims,
+ const T* input2_data,
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+ // No parameters needed.
+ BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
+ input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T, ComparisonFn<int32> F>
+inline void BroadcastComparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 input2_multiplier, int input2_shift,
+ bool* output_data, const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.input2_shift = kReverseShift * input2_shift;
+
+ BroadcastComparison4DSlowWithScaling<T, F>(
+ op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+#define TFLITE_LEGACY_COMPARISON_OP(name) \
+ template <typename T> \
+ inline void name(const T* input1_data, const Dims<4>& input1_dims, \
+ const T* input2_data, const Dims<4>& input2_dims, \
+ bool* output_data, const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name); \
+ Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
+ Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, input1_shift, \
+ input2_data, input2_dims, input2_offset, \
+ input2_multiplier, input2_shift, output_data, \
+ output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
+ const Dims<4>& input2_dims, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
+ BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
+ BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, \
+ input1_shift, input2_data, input2_dims, \
+ input2_offset, input2_multiplier, \
+ input2_shift, output_data, output_dims); \
+ }
+TFLITE_LEGACY_COMPARISON_OP(Equal);
+TFLITE_LEGACY_COMPARISON_OP(NotEqual);
+TFLITE_LEGACY_COMPARISON_OP(Greater);
+TFLITE_LEGACY_COMPARISON_OP(GreaterEqual);
+TFLITE_LEGACY_COMPARISON_OP(Less);
+TFLITE_LEGACY_COMPARISON_OP(LessEqual);
+#undef TFLITE_LEGACY_COMPARISON_OP
+
+template <typename D, typename T>
+inline void Select(const D* input_condition_data,
+ const Dims<4>& input_condition_dims, const T* input_x_data,
+ const Dims<4>& input_x_dims, const T* input_y_data,
+ const Dims<4>& input_y_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ Select(DimsToShape(input_condition_dims), input_condition_data,
+ DimsToShape(input_x_dims), input_x_data, DimsToShape(input_y_dims),
+ input_y_data, DimsToShape(output_dims), output_data);
+}
+
+template <typename D, typename T>
+inline void RankOneSelect(const D* input_condition_data,
+ const Dims<4>& input_condition_dims,
+ const T* input_x_data, const Dims<4>& input_x_dims,
+ const T* input_y_data, const Dims<4>& input_y_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ RankOneSelect(DimsToShape(input_condition_dims), input_condition_data,
+ DimsToShape(input_x_dims), input_x_data,
+ DimsToShape(input_y_dims), input_y_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, typename TI>
+inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
+ const T* values, T default_value, T* output_data,
+ const Dims<4>& output_dims, bool value_is_scalar) {
+ SparseToDense(indices, values, default_value, value_is_scalar,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename Scalar>
+void Pack(int dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::PackParams op_params;
+ op_params.axis = 3 - dim;
+ op_params.inputs_count = inputs_count;
+
+ Pack(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename Scalar>
+void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
+ int dimensions, int outputs_count, Scalar* const* output_datas,
+ const Dims<4>& output_dims) {
+ tflite::UnpackParams op_params;
+ op_params.axis = 3 - axis;
+ op_params.num_split = outputs_count;
+
+ Unpack(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_datas);
+}
+
+template <typename Scalar>
+void Pack(int dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, const int32* input_zeropoint,
+ const float* input_scale, int inputs_count, Scalar* output_data,
+ const Dims<4>& output_dims, const int32 output_zeropoint,
+ const float output_scale) {
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::PackParams op_params;
+ op_params.axis = 3 - dim;
+ op_params.input_zeropoint = input_zeropoint;
+ op_params.input_scale = input_scale;
+ op_params.inputs_count = inputs_count;
+ op_params.output_zeropoint = output_zeropoint;
+ op_params.output_scale = output_scale;
+
+ PackWithScaling(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
float* output_data, const RuntimeShape& output_shape) {
@@ -342,7 +1408,6 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims), output_data);
}
-// Legacy.
// Transitional version that will be moved shortly to legacy_reference_ops, as
// part of RuntimeShape revisions.
inline void BroadcastMul4DSlow(const uint8* input1_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 7a5535489a..59f17ae854 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -231,83 +231,6 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = dilation_width_factor;
- op_params.dilation_height_factor = dilation_height_factor;
- op_params.float_activation_min = output_activation_min;
- op_params.float_activation_max = output_activation_max;
-
- Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
- filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
- output_data, DimsToShape(im2col_dims), im2col_data);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float* output_data, const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
- stride_width, stride_height, dilation_width_factor,
- dilation_height_factor, pad_width, pad_height, output_activation_min,
- output_activation_max, output_data, output_dims, im2col_data,
- im2col_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
- stride_width, stride_height, 1, 1, pad_width, pad_height,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
- bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
- output_dims, im2col_data, im2col_dims);
-}
-
inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape,
@@ -391,111 +314,6 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims,
- uint8* im2col_data, const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = dilation_width_factor;
- op_params.dilation_height_factor = dilation_height_factor;
- op_params.input_offset = input_offset;
- op_params.weights_offset = filter_offset;
- op_params.output_offset = output_offset;
- op_params.output_multiplier = output_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.output_shift = kReverseShift * output_shift;
- op_params.quantized_activation_min = output_activation_min;
- op_params.quantized_activation_max = output_activation_max;
-
- Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
- filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
- output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
- pad_width, pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims, gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride_width, stride_height,
- pad_width, pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data, output_dims,
- im2col_data, im2col_dims, gemm_context);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int stride,
- int pad_width, int pad_height, int32 output_offset,
- int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
- const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) {
- Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, stride, pad_width,
- pad_height, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data,
- output_dims, im2col_data, im2col_dims, gemm_context);
-}
-
template <typename T>
inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
const RuntimeShape& unextended_input_shape,
@@ -1385,21 +1203,6 @@ void BroadcastDiv4DSlow(const ArithmeticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <typename T>
-void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
-
- BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T>
inline void Div(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data,
@@ -1418,21 +1221,6 @@ inline void Div(const ArithmeticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <typename T>
-inline void Div(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- tflite::ArithmeticParams op_params;
- SetActivationParams(output_activation_min, output_activation_max, &op_params);
-
- Div(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
- output_data);
-}
-
inline void SubNonBroadcast(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const float* input1_data,
@@ -1772,35 +1560,10 @@ inline void Concatenation(const ConcatenationParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <FusedActivationFunctionType Ac, typename Scalar>
-inline void Concatenation(int concat_dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- // For now we don't have a model with a Concatenation with fused activation.
- TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
-
- std::vector<RuntimeShape> input_shapes(inputs_count);
- std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
- for (int i = 0; i < inputs_count; ++i) {
- ShapeFromDims(*input_dims[i], &input_shapes[i]);
- input_shapes_indirect[i] = &input_shapes[i];
- }
- tflite::ConcatenationParams op_params;
- op_params.axis = 3 - concat_dim;
- op_params.inputs_count = inputs_count;
-
- Concatenation(op_params, input_shapes_indirect.data(), input_data,
- DimsToShape(output_dims), output_data);
-}
-
// TODO(prabhumk): This is the same as the optimized implementation.
// TODO(prabhumk): The quantized implementation of concatentation isn't fully
// quantized as it takes scale as a floating point value. This should be fixed
// when optimizng this routine further.
-
-// template <>
inline void ConcatenationWithScaling(const ConcatenationParams& params,
const RuntimeShape* const* input_shapes,
const uint8* const* input_data,
@@ -1813,15 +1576,13 @@ inline void ConcatenationWithScaling(const ConcatenationParams& params,
const int32 output_zeropoint = params.output_zeropoint;
const float output_scale = params.output_scale;
- // The arguments input_zeropoint and input_scale are expected to be an array
- // that have the quantization parameters for all the inputs to the concat
- // operator.
- TFLITE_DCHECK_GT(inputs_count, 1);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int concat_dimensions = output_shape.DimensionsCount();
+ TFLITE_DCHECK_LT(axis, concat_dimensions);
+
int64_t concat_size = 0;
for (int i = 0; i < inputs_count; i++) {
- TFLITE_DCHECK_EQ(input_shapes[i]->DimensionsCount(), 4);
- for (int j = 0; j < 4; j++) {
+ TFLITE_DCHECK_EQ(input_shapes[i]->DimensionsCount(), concat_dimensions);
+ for (int j = 0; j < concat_dimensions; j++) {
if (j != axis) {
MatchingDim(*input_shapes[i], j, output_shape, j);
}
@@ -1836,9 +1597,10 @@ inline void ConcatenationWithScaling(const ConcatenationParams& params,
// For all input arrays,
// FlatSize() = outer_size * Dims(axis) * base_inner_size;
int64_t base_inner_size = 1;
- for (int i = axis + 1; i < 4; ++i) {
+ for (int i = axis + 1; i < concat_dimensions; ++i) {
base_inner_size *= output_shape.Dims(i);
}
+
const float inverse_output_scale = 1.f / output_scale;
uint8* output_ptr = output_data;
for (int k = 0; k < outer_size; k++) {
@@ -1864,65 +1626,52 @@ inline void ConcatenationWithScaling(const ConcatenationParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-inline void Concatenation(int concat_dim, const uint8* const* input_data,
- const Dims<4>* const* input_dims,
- const int32* input_zeropoint,
- const float* input_scale, int inputs_count,
- uint8* output_data, const Dims<4>& output_dims,
- const int32 output_zeropoint,
- const float output_scale) {
- std::vector<RuntimeShape> input_shapes(inputs_count);
- std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
- for (int i = 0; i < inputs_count; ++i) {
- ShapeFromDims(*input_dims[i], &input_shapes[i]);
- input_shapes_indirect[i] = &input_shapes[i];
- }
- tflite::ConcatenationParams op_params;
- op_params.axis = 3 - concat_dim;
- op_params.input_zeropoint = input_zeropoint;
- op_params.input_scale = input_scale;
- op_params.inputs_count = inputs_count;
- op_params.output_zeropoint = output_zeropoint;
- op_params.output_scale = output_scale;
-
- ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename Scalar>
-void Pack(int dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+void Pack(const PackParams& params, const RuntimeShape* const* input_shapes,
+ const Scalar* const* input_data, const RuntimeShape& output_shape,
+ Scalar* output_data) {
+ const int dimensions = output_shape.DimensionsCount();
+ int axis = params.axis;
+ int inputs_count = params.inputs_count;
+
int outer_size = 1;
- for (int i = dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
+ for (int i = 0; i < axis; i++) {
+ outer_size *= output_shape.Dims(i);
}
- Scalar* output_ptr = output_data;
- const int copy_size = FlatSize(**input_dims) / outer_size;
- for (int k = 0; k < outer_size; k++) {
- for (int i = 0; i < inputs_count; ++i) {
- memcpy(output_ptr, input_data[i] + k * copy_size,
- copy_size * sizeof(Scalar));
- output_ptr += copy_size;
+ int copy_size = 1;
+ for (int i = params.axis + 1; i < dimensions; i++) {
+ copy_size *= output_shape.Dims(i);
+ }
+ TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
+
+ for (int i = 0; i < inputs_count; ++i) {
+ for (int k = 0; k < outer_size; k++) {
+ const Scalar* input_ptr = input_data[i] + copy_size * k;
+ int loc = k * inputs_count * copy_size + i * copy_size;
+ memcpy(output_data + loc, input_ptr, copy_size * sizeof(Scalar));
}
}
}
template <typename Scalar>
-void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
- int dimensions, int outputs_count, Scalar* const* output_datas,
- const Dims<4>& output_dims) {
+void Unpack(const UnpackParams& params, const RuntimeShape& input_shape,
+ const Scalar* input_data, const RuntimeShape& output_shape,
+ Scalar* const* output_datas) {
+ const int dimensions = input_shape.DimensionsCount();
+ const int outputs_count = params.num_split;
+
int outer_size = 1;
- for (int i = dimensions - axis; i < 4; i++) {
- outer_size *= input_dims.sizes[i];
+ for (int i = 0; i < params.axis; i++) {
+ outer_size *= input_shape.Dims(i);
+ }
+ int copy_size = 1;
+ for (int i = params.axis + 1; i < dimensions; i++) {
+ copy_size *= input_shape.Dims(i);
}
+ TFLITE_DCHECK_EQ(output_shape.FlatSize(), copy_size * outer_size);
- const int copy_size = FlatSize(input_dims) / outer_size / outputs_count;
- for (int k = 0; k < outer_size; k++) {
- for (int i = 0; i < outputs_count; ++i) {
+ for (int i = 0; i < outputs_count; ++i) {
+ for (int k = 0; k < outer_size; k++) {
Scalar* output_ptr = output_datas[i] + copy_size * k;
int loc = k * outputs_count * copy_size + i * copy_size;
memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
@@ -1931,18 +1680,29 @@ void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
}
template <typename Scalar>
-void Pack(int dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, const int32* input_zeropoint,
- const float* input_scale, int inputs_count, Scalar* output_data,
- const Dims<4>& output_dims, const int32 output_zeropoint,
- const float output_scale) {
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+void PackWithScaling(const PackParams& params,
+ const RuntimeShape* const* input_shapes,
+ const uint8* const* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int dimensions = output_shape.DimensionsCount();
+ int axis = params.axis;
+ const int32* input_zeropoint = params.input_zeropoint;
+ const float* input_scale = params.input_scale;
+ int inputs_count = params.inputs_count;
+ const int32 output_zeropoint = params.output_zeropoint;
+ const float output_scale = params.output_scale;
+
int outer_size = 1;
- for (int i = dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
+ for (int i = 0; i < axis; i++) {
+ outer_size *= output_shape.Dims(i);
+ }
+ int copy_size = 1;
+ for (int i = axis + 1; i < dimensions; i++) {
+ copy_size *= output_shape.Dims(i);
}
+ TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
+
Scalar* output_ptr = output_data;
- const int copy_size = FlatSize(**input_dims) / outer_size;
const float inverse_output_scale = 1.f / output_scale;
for (int k = 0; k < outer_size; k++) {
for (int i = 0; i < inputs_count; ++i) {
@@ -1968,12 +1728,15 @@ void Pack(int dim, const Scalar* const* input_data,
}
}
-template <FusedActivationFunctionType Ac, typename Scalar>
-void DepthConcatenation(const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- Concatenation<Ac, Scalar>(0, input_data, input_dims, inputs_count,
- output_data, output_dims);
+template <typename Scalar>
+void DepthConcatenation(const ConcatenationParams& params,
+ const RuntimeShape* const* input_shapes,
+ const Scalar* const* input_data,
+ const RuntimeShape& output_shape, Scalar* output_data) {
+ auto params_copy = params;
+ params_copy.axis = 3;
+ Concatenation(params_copy, input_shapes, input_data, output_shape,
+ output_data);
}
inline void LstmCell(
@@ -2093,31 +1856,6 @@ inline void LstmCell(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
- const float* prev_activ_data,
- const Dims<4>& prev_activ_dims, const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims, const float* prev_state_data,
- const Dims<4>& prev_state_dims, float* output_state_data,
- const Dims<4>& output_state_dims, float* output_activ_data,
- const Dims<4>& output_activ_dims, float* concat_temp_data,
- const Dims<4>& concat_temp_dims, float* activ_temp_data,
- const Dims<4>& activ_temp_dims) {
- tflite::LstmCellParams op_params;
- // Float LSTM cell does not need parameters to be set: leave untouched.
-
- LstmCell(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(prev_activ_dims), prev_activ_data,
- DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
- bias_data, DimsToShape(prev_state_dims), prev_state_data,
- DimsToShape(output_state_dims), output_state_data,
- DimsToShape(output_activ_dims), output_activ_data,
- DimsToShape(concat_temp_dims), concat_temp_data,
- DimsToShape(activ_temp_dims), activ_temp_data);
-}
-
// Quantized LSTM cell implementation.
// The quantization of the input, output arrays is as follows:
// - The input activations are quantized as uint8 on the interval
@@ -2392,37 +2130,6 @@ inline void LstmCell(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <int StateIntegerBits>
-void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
- const uint8* prev_activ_data_uint8,
- const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
- const Dims<4>& weights_dims, const int32* bias_data_int32,
- const Dims<4>& bias_dims, const int16* prev_state_data_int16,
- const Dims<4>& prev_state_dims, int16* output_state_data_int16,
- const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
- const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
- const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
- const Dims<4>& activ_temp_dims, int32 weights_zero_point,
- int32 accum_multiplier, int accum_shift,
- gemmlowp::GemmContext* gemm_context) {
- tflite::LstmCellParams op_params;
- op_params.weights_zero_point = weights_zero_point;
- op_params.accum_multiplier = accum_multiplier;
- op_params.accum_shift = accum_shift;
-
- LstmCell<StateIntegerBits>(
- op_params, DimsToShape(input_dims), input_data_uint8,
- DimsToShape(prev_activ_dims), prev_activ_data_uint8,
- DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
- bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
- DimsToShape(output_state_dims), output_state_data_int16,
- DimsToShape(output_activ_dims), output_activ_data_uint8,
- DimsToShape(concat_temp_dims), concat_temp_data_uint8,
- DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
-}
-
template <typename Scalar>
void Split(const SplitParams& params, const RuntimeShape& input_shape,
const Scalar* input_data, const RuntimeShape* const* output_shapes,
@@ -2465,45 +2172,6 @@ void Split(const SplitParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <typename Scalar>
-void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
- int axis, int outputs_count, Scalar* const* output_data,
- const Dims<4>* const* output_dims) {
- std::vector<RuntimeShape> output_shapes(outputs_count);
- std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
- for (int i = 0; i < outputs_count; ++i) {
- ShapeFromDims(*output_dims[i], &output_shapes[i]);
- output_shapes_indirect[i] = &output_shapes[i];
- }
- tflite::SplitParams op_params;
- op_params.axis = 3 - axis;
- op_params.num_split = outputs_count;
-
- Split(op_params, DimsToShape(input_dims), input_data,
- output_shapes_indirect.data(), output_data);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <FusedActivationFunctionType Ac, typename Scalar>
-void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
- int outputs_count, Scalar* const* output_data,
- const Dims<4>* const* output_dims) {
- TFLITE_DCHECK_GE(outputs_count, 1);
- for (int i = 0; i < outputs_count; i++) {
- /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
- /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
- /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
- }
- // For now we don't have a model with a Split with fused activation.
- TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
-
- TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
- output_data, output_dims);
-}
-
inline int NodeOffset(int b, int h, int w, int height, int width) {
return (b * height + h) * width + w;
}
@@ -2834,15 +2502,6 @@ inline void LogSoftmax(const SoftmaxParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy
-inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
- SoftmaxParams params;
- // No params currently used for float LogSoftmax.
- LogSoftmax(params, input_shape, input_data, output_shape, output_data);
-}
-
// Although currently the name of this function says that it cannot handle
// values less than 1, in practice it can handle as low as 1/x_max, where
// x_max is the largest representable input. In other words, the output range
@@ -3047,22 +2706,6 @@ inline void LogSoftmax(const SoftmaxParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_multiplier, int32 input_left_shift,
- int32 reverse_scaling_divisor,
- int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const RuntimeShape& output_shape) {
- SoftmaxParams params;
- params.input_multiplier = input_multiplier;
- params.input_left_shift = input_left_shift;
- params.reverse_scaling_divisor = reverse_scaling_divisor;
- params.reverse_scaling_right_shift = reverse_scaling_right_shift;
- params.diff_min = diff_min;
- LogSoftmax(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3124,20 +2767,6 @@ inline void Logistic(const LogisticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
- LogisticParams params;
- params.input_zero_point = input_zero_point;
- params.input_range_radius = input_range_radius;
- params.input_multiplier = input_multiplier;
- params.input_left_shift = input_left_shift;
- Logistic(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Logistic(const LogisticParams& params,
const RuntimeShape& input_shape, const int16* input_data,
const RuntimeShape& output_shape, int16* output_data) {
@@ -3157,15 +2786,6 @@ inline void Logistic(const LogisticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
- const RuntimeShape& output_shape, int16* output_data) {
- LogisticParams params;
- // No params currently needed by int16 Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3229,20 +2849,6 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
- TanhParams params;
- params.input_zero_point = input_zero_point;
- params.input_range_radius = input_range_radius;
- params.input_multiplier = input_multiplier;
- params.input_left_shift = input_left_shift;
- Tanh(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
const int16* input_data, const RuntimeShape& output_shape,
int16* output_data) {
@@ -3277,16 +2883,6 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
- int input_left_shift, int16* output_data,
- const RuntimeShape& output_shape) {
- TanhParams params;
- params.input_left_shift = input_left_shift;
- Tanh(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Dequantize(const tflite::DequantizationParams& op_params,
const RuntimeShape& input_shape, const uint8* input_data,
const RuntimeShape& output_shape, float* output_data) {
@@ -3301,19 +2897,6 @@ inline void Dequantize(const tflite::DequantizationParams& op_params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
- int32 zero_point, double scale, float* output_data,
- const Dims<4>& output_dims) {
- tflite::DequantizationParams op_params;
- op_params.zero_point = zero_point;
- op_params.scale = scale;
-
- Dequantize(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
inline void FakeQuant(const tflite::FakeQuantParams& op_params,
const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
@@ -3337,20 +2920,6 @@ inline void FakeQuant(const tflite::FakeQuantParams& op_params,
output_data, flat_size);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
- float rmin, float rmax, int num_bits, float* output_data,
- const Dims<4>& output_dims) {
- tflite::FakeQuantParams op_params;
- op_params.num_bits = num_bits;
- op_params.minmax.min = rmin;
- op_params.minmax.max = rmax;
-
- FakeQuant(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename SrcT, typename DstT>
inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
const RuntimeShape& output_shape, DstT* output_data) {
@@ -3374,15 +2943,21 @@ inline void Floor(const RuntimeShape& input_shape, const float* input_data,
template <typename T>
inline void Gather(const tflite::GatherParams& op_params,
- const RuntimeShape& input_shape, const T* input_data,
- const RuntimeShape& coords_shape, const int32* coords_data,
- const RuntimeShape& output_shape, T* output_data) {
- // Enable these checks when moving legacy ops to legacy_reference_ops.
- //
- // TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1);
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data, const RuntimeShape& coords_shape,
+ const int32* coords_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
const int input_rank = op_params.input_rank;
const int gather_dimensions = output_shape.DimensionsCount();
- TFLITE_DCHECK_LE(input_shape.DimensionsCount(), gather_dimensions);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), gather_dimensions);
const int axis = gather_dimensions - input_rank;
TFLITE_DCHECK_LT(axis, gather_dimensions);
TFLITE_DCHECK_GE(axis, 0);
@@ -3404,23 +2979,6 @@ inline void Gather(const tflite::GatherParams& op_params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4> version.
-// When moving legacy ops to legacy_reference_ops, replace content with looser
-// implementation.
-template <typename T>
-inline void Gather(const T* input_data, const Dims<4>& input_dims,
- int input_rank, const int32* coords_data,
- const Dims<4>& coords_dims, T* output_data,
- const Dims<4>& output_dims) {
- tflite::GatherParams op_params;
- op_params.input_rank = input_rank;
-
- Gather(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
- output_data);
-}
-
template <typename T>
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
const RuntimeShape& unextended_input_shape,
@@ -3750,58 +3308,6 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline uint32 LegacyReverseBits32(uint32 n) {
- n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
- n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
- n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
- return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
- ((n & 0xFF000000) >> 24));
-}
-
-inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
- TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
- TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
-
- std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
- std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
- std::reverse(p->strides, p->strides + p->strides_count);
-
- p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
- (32 - p->start_indices_count);
- p->ellipsis_mask =
- LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
- (32 - p->start_indices_count);
- p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
- (32 - p->start_indices_count);
- p->new_axis_mask =
- LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
- (32 - p->start_indices_count);
- p->shrink_axis_mask =
- LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
- (32 - p->start_indices_count);
-}
-
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T>
-inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
- int begin_mask, int end_mask, int shrink_axis_mask,
- const std::vector<int>& start_indices,
- const std::vector<int>& stop_indices,
- const std::vector<int>& strides, T* output_data,
- const Dims<4>& output_dims) {
- TFLITE_DCHECK_EQ(start_indices.size(), 4);
- auto op_params = strided_slice::BuildStridedSliceParams(
- begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
- strides);
- StridedSliceReverseIndices(&op_params);
-
- StridedSlice(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
@@ -4067,22 +3573,6 @@ inline void Mean(const tflite::MeanParams& op_params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy Dims<4>.
-template <typename T>
-inline void Mean(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& reduction_indices, T* output_data,
- const Dims<4>& output_dims) {
- tflite::MeanParams op_params;
- op_params.axis_count = reduction_indices.size();
- for (int i = 0; i < op_params.axis_count; ++i) {
- op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
- }
-
- Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
-}
-
// Computes the mean of elements across dimensions given in axis.
// It does so in two stages, first calculates the sum of elements along the axis
// then divides it by the number of element in axis for quantized values.
@@ -4340,20 +3830,6 @@ void Transpose(const TransposeParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T>
-void Transpose(const T* input, const Dims<4>& input_dims, T* output,
- const Dims<4>& output_dims, const int* permuted_axes) {
- TransposeParams params;
- params.perm_count = 4;
- for (int i = 0; i < 4; ++i) {
- params.perm[i] = 3 - permuted_axes[3 - i];
- }
- Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
- output);
-}
-
inline void TransposeConv(
const ConvParams& params, const RuntimeShape& input_shape,
const float* input_data, const RuntimeShape& filter_shape,
@@ -4427,27 +3903,6 @@ inline void TransposeConv(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- tflite::ConvParams op_params;
- // Padding type is ignored, but still set.
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
-
- TransposeConv(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
- output_data, DimsToShape(im2col_dims), im2col_data);
-}
-
template <typename T>
inline bool EqualFn(T lhs, T rhs) {
return lhs == rhs;
@@ -4501,19 +3956,6 @@ inline void Comparison(const ComparisonParams& op_params,
input2_data, output_shape, output_data);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T, ComparisonFn<T> F>
-inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- bool* output_data, const Dims<4>& output_dims) {
- ComparisonParams op_params;
- // No parameters needed.
- ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T, ComparisonFn<int32> F>
inline void ComparisonWithScaling(
const ComparisonParams& op_params, const RuntimeShape& input1_shape,
@@ -4544,32 +3986,6 @@ inline void ComparisonWithScaling(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T, ComparisonFn<int32> F>
-inline void Comparison(int left_shift, const T* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const T* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, bool* output_data,
- const Dims<4>& output_dims) {
- tflite::ComparisonParams op_params;
- op_params.left_shift = left_shift;
- op_params.input1_offset = input1_offset;
- op_params.input1_multiplier = input1_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.input1_shift = kReverseShift * input1_shift;
- op_params.input2_offset = input2_offset;
- op_params.input2_multiplier = input2_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.input2_shift = kReverseShift * input2_shift;
-
- ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T, ComparisonFn<T> F>
inline void BroadcastComparison4DSlowImpl(
const ComparisonParams& op_params,
@@ -4613,22 +4029,6 @@ inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
output_shape, output_data);
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T, ComparisonFn<T> F>
-inline void BroadcastComparison(const T* input1_data,
- const Dims<4>& input1_dims,
- const T* input2_data,
- const Dims<4>& input2_dims, bool* output_data,
- const Dims<4>& output_dims) {
- ComparisonParams op_params;
- // No parameters needed.
- BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
- input1_data, DimsToShape(input2_dims),
- input2_data, DimsToShape(output_dims),
- output_data);
-}
-
template <typename T, ComparisonFn<int32> F>
inline void BroadcastComparison4DSlowWithScaling(
const ComparisonParams& op_params,
@@ -4679,80 +4079,7 @@ inline void BroadcastComparison4DSlowWithScaling(
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T, ComparisonFn<int32> F>
-inline void BroadcastComparison(int left_shift, const T* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const T* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 input2_multiplier, int input2_shift,
- bool* output_data, const Dims<4>& output_dims) {
- ComparisonParams op_params;
-
- op_params.left_shift = left_shift;
- op_params.input1_offset = input1_offset;
- op_params.input1_multiplier = input1_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.input1_shift = kReverseShift * input1_shift;
- op_params.input2_offset = input2_offset;
- op_params.input2_multiplier = input2_multiplier;
- // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
- op_params.input2_shift = kReverseShift * input2_shift;
-
- BroadcastComparison4DSlowWithScaling<T, F>(
- op_params, DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
- output_data);
-}
-
#define TFLITE_COMPARISON_OP(name) \
- template <typename T> \
- inline void name(const T* input1_data, const Dims<4>& input1_dims, \
- const T* input2_data, const Dims<4>& input2_dims, \
- bool* output_data, const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label(#name); \
- Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
- input2_dims, output_data, output_dims); \
- } \
- template <typename T> \
- inline void name( \
- int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
- int32 input1_offset, int32 input1_multiplier, int input1_shift, \
- const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
- int32 input2_multiplier, int input2_shift, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
- Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
- input1_offset, input1_multiplier, input1_shift, \
- input2_data, input2_dims, input2_offset, \
- input2_multiplier, input2_shift, output_data, \
- output_dims); \
- } \
- template <typename T> \
- inline void Broadcast##name( \
- const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
- const Dims<4>& input2_dims, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
- BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
- input2_dims, output_data, output_dims); \
- } \
- template <typename T> \
- inline void Broadcast##name( \
- int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
- int32 input1_offset, int32 input1_multiplier, int input1_shift, \
- const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
- int32 input2_multiplier, int input2_shift, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
- BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
- input1_offset, input1_multiplier, \
- input1_shift, input2_data, input2_dims, \
- input2_offset, input2_multiplier, \
- input2_shift, output_data, output_dims); \
- } \
inline void name(const ComparisonParams& op_params, \
const RuntimeShape& input1_shape, const float* input1_data, \
const RuntimeShape& input2_shape, const float* input2_data, \
@@ -4762,22 +4089,44 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
input2_data, output_shape, output_data); \
} \
template <typename T> \
+ inline void name##NoScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label(#name "NoScaling"); \
+ ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data, \
+ input2_shape, input2_data, output_shape, \
+ output_data); \
+ } \
+ template <typename T> \
inline void name##WithScaling( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const T* input1_data, const RuntimeShape& input2_shape, \
const T* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
- gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
+ gemmlowp::ScopedProfilingLabel label(#name "WithScaling/8bit"); \
ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \
input2_shape, input2_data, \
output_shape, output_data); \
} \
+ template <typename T> \
+ inline void Broadcast4DSlow##name##NoScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "NoScaling"); \
+ BroadcastComparison4DSlowImpl<T, name##Fn>( \
+ op_params, input1_shape, input1_data, input2_shape, input2_data, \
+ output_shape, output_data); \
+ } \
inline void Broadcast4DSlow##name( \
const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
const float* input1_data, const RuntimeShape& input2_shape, \
const float* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
+ gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name); \
BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \
input2_shape, input2_data, \
output_shape, output_data); \
@@ -4788,7 +4137,7 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
const T* input1_data, const RuntimeShape& input2_shape, \
const T* input2_data, const RuntimeShape& output_shape, \
bool* output_data) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
+ gemmlowp::ScopedProfilingLabel label("Broadcast4DSlow" #name "/8bit"); \
BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
op_params, input1_shape, input1_data, input2_shape, input2_data, \
output_shape, output_data); \
@@ -4815,19 +4164,6 @@ void Select(const RuntimeShape& input_condition_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename D, typename T>
-inline void Select(const D* input_condition_data,
- const Dims<4>& input_condition_dims, const T* input_x_data,
- const Dims<4>& input_x_dims, const T* input_y_data,
- const Dims<4>& input_y_dims, T* output_data,
- const Dims<4>& output_dims) {
- Select(DimsToShape(input_condition_dims), input_condition_data,
- DimsToShape(input_x_dims), input_x_data, DimsToShape(input_y_dims),
- input_y_data, DimsToShape(output_dims), output_data);
-}
-
template <typename D, typename T>
void RankOneSelect(const RuntimeShape& input_condition_shape,
const D* input_condition_data,
@@ -4849,20 +4185,6 @@ void RankOneSelect(const RuntimeShape& input_condition_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename D, typename T>
-inline void RankOneSelect(const D* input_condition_data,
- const Dims<4>& input_condition_dims,
- const T* input_x_data, const Dims<4>& input_x_dims,
- const T* input_y_data, const Dims<4>& input_y_dims,
- T* output_data, const Dims<4>& output_dims) {
- RankOneSelect(DimsToShape(input_condition_dims), input_condition_data,
- DimsToShape(input_x_dims), input_x_data,
- DimsToShape(input_y_dims), input_y_data,
- DimsToShape(output_dims), output_data);
-}
-
// For easy implementation, the indices is always a vector of size-4 vectors.
template <typename T, typename TI>
inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
@@ -4904,16 +4226,6 @@ inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-template <typename T, typename TI>
-inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
- const T* values, T default_value, T* output_data,
- const Dims<4>& output_dims, bool value_is_scalar) {
- SparseToDense(indices, values, default_value, value_is_scalar,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T>
inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
const RuntimeShape& input2_shape, const T* input2_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/softmax.h b/tensorflow/contrib/lite/kernels/internal/reference/softmax.h
index 006174e8db..7d44296134 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/softmax.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/softmax.h
@@ -57,16 +57,6 @@ inline void Softmax(const SoftmaxParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
- float beta, float* output_data,
- const RuntimeShape& output_shape) {
- SoftmaxParams params;
- params.beta = beta;
- Softmax(params, input_shape, input_data, output_shape, output_data);
-}
-
inline void Softmax(const SoftmaxParams& params,
const RuntimeShape& input_shape, const uint8* input_data,
const RuntimeShape& output_shape, uint8* output_data) {
@@ -151,19 +141,6 @@ inline void Softmax(const SoftmaxParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy
-inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_beta_multiplier, int32 input_beta_left_shift,
- int diff_min, uint8* output_data,
- const RuntimeShape& output_shape) {
- SoftmaxParams params;
- params.input_multiplier = input_beta_multiplier;
- params.input_left_shift = input_beta_left_shift;
- params.diff_min = diff_min;
- Softmax(params, input_shape, input_data, output_shape, output_data);
-}
-
// Performs softmax along the input of size (input_size * batch_size).
inline void Softmax(const float* in, const int input_size, const int batch_size,
const float beta, float* out) {
diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
index ca94e7740e..831fb3c243 100644
--- a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
@@ -43,11 +43,15 @@ void RunSoftmaxFloatReference(const uint8* input_data,
// Reference data generated via Dequant of input into float, and then applying
// float Softmax.
- reference_ops::Dequantize(
- input_data, ToRuntimeDims(shape_common), input_offset, input_scale,
- reference_dequant_data.data(), ToRuntimeDims(shape_common));
- optimized_ops::Softmax(reference_dequant_data.data(), shape_common, beta,
- reference_output_float_data.data(), shape_common);
+ DequantizationParams dq_params;
+ dq_params.zero_point = input_offset;
+ dq_params.scale = input_scale;
+ reference_ops::Dequantize(dq_params, shape_common, input_data, shape_common,
+ reference_dequant_data.data());
+ SoftmaxParams sm_params;
+ sm_params.beta = beta;
+ optimized_ops::Softmax(sm_params, shape_common, reference_dequant_data.data(),
+ shape_common, reference_output_float_data.data());
// Work with quantized scaling for Softmax, under which 256 represents 1, but
// we limit this to 255.
for (int i = 0; i < ref_buffer_size; i++) {
@@ -116,12 +120,14 @@ void RunOneSoftmaxTest(const uint8* input_data,
const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
input_beta_left_shift);
- optimized_ops::Softmax(input_data, shape_common, input_beta_multiplier,
- input_beta_left_shift, diff_min,
- optimized_softmax_output.data(), shape_common);
- reference_ops::Softmax(input_data, shape_common, input_beta_multiplier,
- input_beta_left_shift, diff_min,
- reference_quant_softmax_output.data(), shape_common);
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ optimized_ops::Softmax(params, shape_common, input_data, shape_common,
+ optimized_softmax_output.data());
+ reference_ops::Softmax(params, shape_common, input_data, shape_common,
+ reference_quant_softmax_output.data());
CheckOutputData(optimized_softmax_output.data(),
reference_float_softmax_output.data(), shape_common,
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index 13106456df..689cea03e7 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -37,10 +37,6 @@ inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
: nullptr;
}
-inline Dims<4> GetTensorDims(std::vector<int32_t> data) {
- return GetTensorDims(data.data(), data.size());
-}
-
inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
return RuntimeShape(data.size(), data.data());
}
@@ -56,20 +52,20 @@ class VectorOfTensors {
int num_tensors = tensor_list.size;
all_data_.reserve(num_tensors);
- all_dims_.reserve(num_tensors);
- all_dims_ptr_.reserve(num_tensors);
+ all_shape_.reserve(num_tensors);
+ all_shape_ptr_.reserve(num_tensors);
for (int i = 0; i < num_tensors; ++i) {
TfLiteTensor* t = &context.tensors[tensor_list.data[i]];
all_data_.push_back(GetTensorData<T>(t));
- all_dims_.push_back(GetTensorDims(t));
+ all_shape_.push_back(GetTensorShape(t));
}
// Taking the pointer from inside a std::vector is only OK if the vector is
- // never modified, so we populate all_dims in the previous loop and then we
+ // never modified, so we populate all_shape in the previous loop and then we
// are free to grab iterators here.
for (int i = 0; i < num_tensors; ++i) {
- all_dims_ptr_.push_back(&all_dims_[i]);
+ all_shape_ptr_.push_back(&all_shape_[i]);
}
}
// Return a pointer to the data pointers of all tensors in the list. For
@@ -78,16 +74,16 @@ class VectorOfTensors {
// f[0][1] is the second element of the first tensor.
T* const* data() const { return all_data_.data(); }
- // Return a pointer the dim pointers of all tensors in the list. For
+ // Return a pointer the shape pointers of all tensors in the list. For
// example:
- // const Dims<4>* const* d = v.dims();
+ // const RuntimeShape* const* d = v.dims();
// dims[1] are the dimensions of the second tensor in the list.
- const Dims<4>* const* dims() const { return all_dims_ptr_.data(); }
+ const RuntimeShape* const* shapes() const { return all_shape_ptr_.data(); }
private:
std::vector<T*> all_data_;
- std::vector<Dims<4>> all_dims_;
- std::vector<Dims<4>*> all_dims_ptr_;
+ std::vector<RuntimeShape> all_shape_;
+ std::vector<RuntimeShape*> all_shape_ptr_;
};
// A list of quantized tensors in a format that can be used by kernels like
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
index 77e22a08b4..9f5b33d217 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
@@ -86,39 +86,6 @@ inline const bool* GetTensorData(const TfLiteTensor* tensor) {
return tensor != nullptr ? tensor->data.b : nullptr;
}
-inline int RemapDim(int max_dimensions, int d) {
- return max_dimensions - d - 1;
-}
-
-// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
-// even if the original tensors were not 4D. We should consider rewriting them
-// to take a more generic 'shape' object.
-inline Dims<4> GetTensorDims(const int data[], const int size) {
- Dims<4> d;
- for (int i = 0; i < 4; ++i) {
- int src = size - i - 1;
- if (src >= 0) {
- d.sizes[i] = data[src];
- } else {
- d.sizes[i] = 1;
- }
- }
- d.strides[0] = 1;
- for (int i = 1; i < 4; i++) {
- d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
- }
- return d;
-}
-
-inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
- if (tensor == nullptr) {
- return Dims<4>();
- }
-
- auto* dims = tensor->dims;
- return GetTensorDims(dims->data, dims->size);
-}
-
inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
if (tensor == nullptr) {
return RuntimeShape();
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
index bf2068d320..2ed73ba82d 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
@@ -21,28 +21,32 @@ namespace {
using ::testing::ElementsAre;
-TEST(TensorTest, GetTensorDims4D) {
- Dims<4> d = GetTensorDims({2, 3, 4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 2));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+TEST(TensorTest, GetTensorShape4D) {
+ RuntimeShape d = GetTensorShape({2, 3, 4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(2, 3, 4, 5));
}
-TEST(TensorTest, GetTensorDims3D) {
- Dims<4> d = GetTensorDims({3, 4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+TEST(TensorTest, GetTensorShape3D) {
+ RuntimeShape d = GetTensorShape({3, 4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(3, 4, 5));
}
-TEST(TensorTest, GetTensorDims2D) {
- Dims<4> d = GetTensorDims({4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 1, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 20));
+TEST(TensorTest, GetTensorShape2D) {
+ RuntimeShape d = GetTensorShape({4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(4, 5));
}
-TEST(TensorTest, GetTensorDims1D) {
- Dims<4> d = GetTensorDims({5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 1, 1, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 5, 5));
+TEST(TensorTest, GetTensorShape1D) {
+ RuntimeShape d = GetTensorShape({5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(5));
}
} // namespace
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index a3a5994c9c..b39347758a 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -875,6 +875,15 @@ struct MeanParams {
int16 axis[4];
};
+struct PackParams {
+ int8 axis;
+ const int32* input_zeropoint;
+ const float* input_scale;
+ uint16 inputs_count;
+ int32 output_zeropoint;
+ float output_scale;
+};
+
struct PadParams {
int8 left_padding_count;
int32 left_padding[4];
@@ -975,6 +984,11 @@ struct TransposeParams {
int32 perm[4];
};
+struct UnpackParams {
+ uint16 num_split;
+ int16 axis;
+};
+
template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
params->float_activation_min = min;
diff --git a/tensorflow/contrib/lite/kernels/log_softmax_test.cc b/tensorflow/contrib/lite/kernels/log_softmax_test.cc
index 9a8d35e82c..1acc966cdc 100644
--- a/tensorflow/contrib/lite/kernels/log_softmax_test.cc
+++ b/tensorflow/contrib/lite/kernels/log_softmax_test.cc
@@ -91,8 +91,9 @@ TEST(LogSoftmaxOpTest, CompareWithTFmini) {
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
- tflite::reference_ops::LogSoftmax(input_buffer, input_shape,
- output_buffer.get(), input_shape);
+ SoftmaxParams params;
+ tflite::reference_ops::LogSoftmax(params, input_shape, input_buffer,
+ input_shape, output_buffer.get());
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index aaa3ce966e..5b996d00bc 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -893,18 +893,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
activation_out->type == kTfLiteFloat32 &&
concat_temp->type == kTfLiteFloat32 &&
activation_temp->type == kTfLiteFloat32) {
+ tflite::LstmCellParams op_params;
+ // Float LSTM cell does not need parameters to be set: leave untouched.
optimized_ops::LstmCell(
+ op_params,
// Inputs.
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(prev_activation), GetTensorDims(prev_activation),
- GetTensorData<float>(weights), GetTensorDims(weights),
- GetTensorData<float>(bias), GetTensorDims(bias),
- GetTensorData<float>(prev_state), GetTensorDims(prev_state),
+ GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(prev_activation), GetTensorData<float>(prev_activation),
+ GetTensorShape(weights), GetTensorData<float>(weights),
+ GetTensorShape(bias), GetTensorData<float>(bias),
+ GetTensorShape(prev_state), GetTensorData<float>(prev_state),
// Outputs.
- GetTensorData<float>(state_out), GetTensorDims(state_out),
- GetTensorData<float>(activation_out), GetTensorDims(activation_out),
- GetTensorData<float>(concat_temp), GetTensorDims(concat_temp),
- GetTensorData<float>(activation_temp), GetTensorDims(activation_temp));
+ GetTensorShape(state_out), GetTensorData<float>(state_out),
+ GetTensorShape(activation_out), GetTensorData<float>(activation_out),
+ GetTensorShape(concat_temp), GetTensorData<float>(concat_temp),
+ GetTensorShape(activation_temp), GetTensorData<float>(activation_temp));
} else if (input->type == kTfLiteUInt8 &&
prev_activation->type == kTfLiteUInt8 &&
weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 &&
@@ -934,20 +937,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
int accum_shift;
tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier,
&accum_shift);
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights->params.zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
optimized_ops::LstmCell<4>(
+ op_params,
// Inputs.
- GetTensorData<uint8_t>(input), GetTensorDims(input),
- GetTensorData<uint8_t>(prev_activation), GetTensorDims(prev_activation),
- GetTensorData<uint8_t>(weights), GetTensorDims(weights),
- GetTensorData<int32_t>(bias), GetTensorDims(bias),
- GetTensorData<int16_t>(prev_state), GetTensorDims(prev_state),
+ GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(prev_activation),
+ GetTensorData<uint8_t>(prev_activation), GetTensorShape(weights),
+ GetTensorData<uint8_t>(weights), GetTensorShape(bias),
+ GetTensorData<int32_t>(bias), GetTensorShape(prev_state),
+ GetTensorData<int16_t>(prev_state),
// Outputs.
- GetTensorData<int16_t>(state_out), GetTensorDims(state_out),
- GetTensorData<uint8_t>(activation_out), GetTensorDims(activation_out),
- GetTensorData<uint8_t>(concat_temp), GetTensorDims(concat_temp),
- GetTensorData<int16_t>(activation_temp), GetTensorDims(activation_temp),
- weights->params.zero_point, accum_multiplier, accum_shift,
- gemm_context);
+ GetTensorShape(state_out), GetTensorData<int16_t>(state_out),
+ GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out),
+ GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp),
+ GetTensorShape(activation_temp),
+ GetTensorData<int16_t>(activation_temp), gemm_context);
} else {
context->ReportError(context,
"Unsupported combination of data types for LstmCell");
diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc
index 4cb98fdd19..c368582ef7 100644
--- a/tensorflow/contrib/lite/kernels/pack.cc
+++ b/tensorflow/contrib/lite/kernels/pack.cc
@@ -85,9 +85,12 @@ template <typename T>
void PackImpl(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* output,
int values_count, int axis) {
VectorOfTensors<T> all_inputs(*context, *node->inputs);
- reference_ops::Pack<T>(RemapDim(NumDimensions(output), axis),
- all_inputs.data(), all_inputs.dims(), values_count,
- GetTensorData<T>(output), GetTensorDims(output));
+ tflite::PackParams op_params;
+ op_params.axis = axis;
+ op_params.inputs_count = values_count;
+
+ reference_ops::Pack<T>(op_params, all_inputs.shapes(), all_inputs.data(),
+ GetTensorShape(output), GetTensorData<T>(output));
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 2f4b663a28..9402105fa7 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -125,7 +125,7 @@ TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
context,
"Regular TensorFlow ops are not supported by this interpreter. Make sure "
- "you invoke the Eager delegate before inference.");
+ "you invoke the Flex delegate before inference.");
return kTfLiteError;
}
@@ -136,13 +136,13 @@ const TfLiteRegistration* BuiltinOpResolver::FindOp(tflite::BuiltinOperator op,
const TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op,
int version) const {
- // Return the NULL Op for all ops whose name start with "Eager", allowing
+ // Return the NULL Op for all ops whose name start with "Flex", allowing
// the interpreter to delegate their execution.
- if (IsEagerOp(op)) {
+ if (IsFlexOp(op)) {
static TfLiteRegistration null_op{
nullptr, nullptr, &UnsupportedTensorFlowOp,
nullptr, nullptr, BuiltinOperator_CUSTOM,
- "Eager", 1};
+ "Flex", 1};
return &null_op;
}
return MutableOpResolver::FindOp(op, version);
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
index 3959502d91..4780a86ee5 100644
--- a/tensorflow/contrib/lite/kernels/select.cc
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -70,12 +70,12 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
bool is_rank_one = !HaveSameShapes(input_condition, input_x);
-#define TF_LITE_SELECT(type, op) \
- reference_ops::op(GetTensorData<bool>(input_condition), \
- GetTensorDims(input_condition), \
- GetTensorData<type>(input_x), GetTensorDims(input_x), \
- GetTensorData<type>(input_y), GetTensorDims(input_y), \
- GetTensorData<type>(output), GetTensorDims(output));
+#define TF_LITE_SELECT(type, op) \
+ reference_ops::op(GetTensorShape(input_condition), \
+ GetTensorData<bool>(input_condition), \
+ GetTensorShape(input_x), GetTensorData<type>(input_x), \
+ GetTensorShape(input_y), GetTensorData<type>(input_y), \
+ GetTensorShape(output), GetTensorData<type>(output));
#define TF_LITE_SWITCH(type, op) \
switch (type) { \
diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc
index 727822f6be..bd66980226 100644
--- a/tensorflow/contrib/lite/kernels/softmax_test.cc
+++ b/tensorflow/contrib/lite/kernels/softmax_test.cc
@@ -93,8 +93,10 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) {
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
- tflite::reference_ops::Softmax(input_buffer, input_shape, beta,
- output_buffer.get(), input_shape);
+ SoftmaxParams params;
+ params.beta = beta;
+ tflite::reference_ops::Softmax(params, input_shape, input_buffer, input_shape,
+ output_buffer.get());
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
@@ -120,8 +122,10 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) {
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
- tflite::reference_ops::Softmax(input_buffer, input_shape, beta,
- output_buffer.get(), input_shape);
+ SoftmaxParams params;
+ params.beta = beta;
+ tflite::reference_ops::Softmax(params, input_shape, input_buffer, input_shape,
+ output_buffer.get());
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
index 178568e07c..349fa0bd28 100644
--- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
@@ -210,8 +210,9 @@ TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) {
&indices_vector));
reference_ops::SparseToDense(indices_vector, GetTensorData<T>(values),
*GetTensorData<T>(default_value),
- GetTensorData<T>(output), GetTensorDims(output),
- value_is_scalar);
+ value_is_scalar, GetTensorShape(output),
+ GetTensorData<T>(output));
+
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc
index 719e2dc606..dab887bf9c 100644
--- a/tensorflow/contrib/lite/kernels/split.cc
+++ b/tensorflow/contrib/lite/kernels/split.cc
@@ -109,25 +109,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (axis_value < 0) {
axis_value += NumDimensions(op_context.input);
}
- axis_value = RemapDim(NumDimensions(op_context.input), axis_value);
// TODO(ahentz): Our usage of VectorOfTensors could be optimized by
// calculating it in Prepare, unless we defer shape calculation.
// TODO(ahentz): We can improve the optimized_ops version to handle other
// cases too.
-#define TF_LITE_SPLIT(scalar) \
- VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \
- if (axis_value == NumDimensions(op_context.input)) { \
- optimized_ops::TensorFlowSplit<FusedActivationFunctionType::kNone, \
- scalar>( \
- GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), NumOutputs(node), all_outputs.data(), \
- all_outputs.dims()); \
- } else { \
- reference_ops::TensorFlowSplit<scalar>( \
- GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), axis_value, NumOutputs(node), \
- all_outputs.data(), all_outputs.dims()); \
+#define TF_LITE_SPLIT(scalar) \
+ VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \
+ tflite::SplitParams op_params; \
+ op_params.num_split = NumOutputs(node); \
+ op_params.axis = axis_value; \
+ if (axis_value == 0) { \
+ optimized_ops::Split(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ all_outputs.shapes(), all_outputs.data()); \
+ } else { \
+ reference_ops::Split(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ all_outputs.shapes(), all_outputs.data()); \
}
switch (op_context.input->type) {
case kTfLiteFloat32: {
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index 87ffcc4110..06b36dd196 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -57,17 +57,6 @@ struct StridedSliceContext {
int dims;
};
-// Reverse order of bits in the mask to match the expected order in kernel
-inline int ReverseMaskBits(int mask, int num_dimensions) {
- int out = 0;
- for (int dim = 0; dim < num_dimensions; dim++) {
- out <<= 1;
- out += (mask & 1);
- mask >>= 1;
- }
- return out;
-}
-
// This Op only supports 1-4D cases and since we use the reference 4D
// implementation, the 1-3D tensors are mapped to 4D.
const int kMaxDim = 4;
@@ -198,30 +187,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
std::vector<int32_t> stops;
std::vector<int32_t> strides;
- for (int idx = op_context.dims - 1; idx >= 0; --idx) {
- starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]);
- stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]);
- strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
- }
-
for (int i = op_context.dims; i < kMaxDim; i++) {
starts.emplace_back(0);
stops.emplace_back(1);
strides.emplace_back(1);
}
- int begin_mask =
- ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
- int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims);
- int shrink_axis_mask =
- ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims);
-
-#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
- kernel_type::StridedSlice( \
- GetTensorData<data_type>(op_context.input), \
- GetTensorDims(op_context.input), begin_mask, end_mask, shrink_axis_mask, \
- starts, stops, strides, GetTensorData<data_type>(op_context.output), \
- GetTensorDims(op_context.output))
+ for (int idx = 0; idx < op_context.dims; ++idx) {
+ starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]);
+ stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]);
+ strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
+ }
+
+ int begin_mask = op_context.params->begin_mask << (4 - op_context.dims);
+ int end_mask = op_context.params->end_mask << (4 - op_context.dims);
+ int shrink_axis_mask = op_context.params->shrink_axis_mask
+ << (4 - op_context.dims);
+ TF_LITE_ENSURE_EQ(context, starts.size(), 4);
+ auto op_params = ::tflite::strided_slice::BuildStridedSliceParams(
+ begin_mask, end_mask, shrink_axis_mask, starts, stops, strides);
+
+#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
+ kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<data_type>(op_context.input), \
+ GetTensorShape(op_context.output), \
+ GetTensorData<data_type>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32:
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index 0fdb0a3935..05a7c23ba1 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -122,7 +122,7 @@ void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
- interpreter_->ResetVariableTensorsToZero();
+ interpreter_->ResetVariableTensors();
}
void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc
index 95359962e0..e42a30420b 100644
--- a/tensorflow/contrib/lite/kernels/transpose.cc
+++ b/tensorflow/contrib/lite/kernels/transpose.cc
@@ -92,26 +92,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
}
- // Reverse the permuted axes and convert to 4D due to the way Dims are
- // constructed in GetTensorDims.
const int* perm_data = GetTensorData<int32_t>(op_context.perm);
const int size = op_context.perm->dims->data[0];
- const int kOutputDimensionNum = 4;
- int reversed_perm[kOutputDimensionNum];
-
- for (int output_k = 0, input_k = size - 1; output_k < size;
- ++output_k, --input_k) {
- reversed_perm[output_k] = size - perm_data[input_k] - 1;
- }
- for (int k = size; k < kOutputDimensionNum; ++k) {
- reversed_perm[k] = k;
+ TransposeParams params;
+ params.perm_count = size;
+ for (int i = 0; i < size; ++i) {
+ params.perm[i] = perm_data[i];
}
#define TF_LITE_TRANSPOSE(type, scalar) \
- type::Transpose(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), reversed_perm)
+ type::Transpose(params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32:
diff --git a/tensorflow/contrib/lite/kernels/transpose_test.cc b/tensorflow/contrib/lite/kernels/transpose_test.cc
index 337bc144b9..79ef0a7c56 100644
--- a/tensorflow/contrib/lite/kernels/transpose_test.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_test.cc
@@ -51,21 +51,21 @@ void RunTestPermutation(const std::vector<int>& shape,
reversed_perms[k] = k;
}
- // Make input and output dims (i.e. reversed shape and dest_shape).
- Dims<4> input_dims = GetTensorDims(shape);
- Dims<4> output_dims;
- for (int i = 0; i < 4; i++) {
- output_dims.sizes[i] = input_dims.sizes[reversed_perms[i]];
+ // Make input and output shapes.
+ const RuntimeShape input_shape = GetTensorShape(shape);
+ RuntimeShape output_shape(perms.size());
+ for (int i = 0; i < perms.size(); i++) {
+ output_shape.SetDim(i, input_shape.Dims(perms[i]));
}
- output_dims.strides[0] = 1;
- for (int k = 1; k < 4; k++) {
- output_dims.strides[k] =
- output_dims.strides[k - 1] * output_dims.sizes[k - 1];
+
+ TransposeParams params;
+ params.perm_count = perms.size();
+ for (int i = 0; i < perms.size(); ++i) {
+ params.perm[i] = perms[i];
}
- reference_ops::Transpose<float>(input.data(), input_dims,
- input_transposed->data(), output_dims,
- reversed_perms);
+ reference_ops::Transpose<float>(params, input_shape, input.data(),
+ output_shape, input_transposed->data());
}
TEST(TransposeTest, TestRefOps1D) {
diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc
index 9ff06f8331..a7d3a9bc76 100644
--- a/tensorflow/contrib/lite/kernels/unpack.cc
+++ b/tensorflow/contrib/lite/kernels/unpack.cc
@@ -88,10 +88,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <typename T>
void UnpackImpl(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input, int output_count, int axis) {
+ tflite::UnpackParams op_params;
+ op_params.axis = axis;
+ op_params.num_split = output_count;
VectorOfTensors<T> all_outputs(*context, *node->outputs);
- reference_ops::Unpack<T>(axis, GetTensorData<T>(input), GetTensorDims(input),
- NumDimensions(input), output_count,
- all_outputs.data(), **all_outputs.dims());
+ reference_ops::Unpack<T>(op_params, GetTensorShape(input),
+ GetTensorData<T>(input), **all_outputs.shapes(),
+ all_outputs.data());
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index ea2817beec..d50c345194 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -27,8 +27,8 @@ limitations under the License.
#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
#endif
-#if defined(TFLITE_EXTENDED)
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#if defined(TFLITE_FLEX)
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#endif
#include "tensorflow/contrib/lite/version.h"
@@ -450,8 +450,8 @@ TfLiteStatus InterpreterBuilder::operator()(
}
(**interpreter).SetVariables(std::move(variables));
-#if defined(TFLITE_EXTENDED)
- if (auto delegate = EagerDelegate::Create()) {
+#if defined(TFLITE_FLEX)
+ if (auto delegate = FlexDelegate::Create()) {
(**interpreter)
.ModifyGraphWithDelegate(std::move(delegate),
/*allow_dynamic_tensors=*/true);
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
index f18a2ca07a..2e5033dab1 100644
--- a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
@@ -20,6 +20,7 @@ filegroup(
android_binary(
name = "SmartReplyDemo",
srcs = glob(["java/**/*.java"]),
+ aapt_version = "aapt",
assets = [":assets"],
assets_dir = "",
custom_package = "com.example.android.smartreply",
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 57e1290e07..916788f215 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -144,7 +144,7 @@ py_library(
name = "convert_saved_model",
srcs = ["convert_saved_model.py"],
srcs_version = "PY2AND3",
- visibility = ["//visibility:public"],
+ visibility = ["//tensorflow/contrib/lite:__subpackages__"],
deps = [
":convert",
"//tensorflow/contrib/saved_model:saved_model_py",
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 1f48a826d4..613a1530f7 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -67,12 +67,12 @@ class ConverterMode(enum.Enum):
# Convert model using TOCO such that only unsupported operations are
# represented as TensorFlow ops.
# WARNING: Experimental interface, subject to change.
- TOCO_EXTENDED = "TOCO_EXTENDED"
+ TOCO_FLEX = "TOCO_FLEX"
# Convert model using TOCO such that all operations are represented as
# TensorFlow ops.
# WARNING: Experimental interface, subject to change.
- TOCO_EXTENDED_ALL = "TOCO_EXTENDED_ALL"
+ TOCO_FLEX_ALL = "TOCO_FLEX_ALL"
def __str__(self):
return self.value
@@ -240,11 +240,11 @@ def build_toco_convert_protos(input_tensors,
if dump_graphviz_dir:
toco.dump_graphviz_dir = dump_graphviz_dir
toco.dump_graphviz_include_video = dump_graphviz_video
- if converter_mode == ConverterMode.TOCO_EXTENDED:
- toco.allow_eager_ops = True
- elif converter_mode == ConverterMode.TOCO_EXTENDED_ALL:
- toco.allow_eager_ops = True
- toco.force_eager_ops = True
+ if converter_mode == ConverterMode.TOCO_FLEX:
+ toco.allow_flex_ops = True
+ elif converter_mode == ConverterMode.TOCO_FLEX_ALL:
+ toco.allow_flex_ops = True
+ toco.force_flex_ops = True
model = _model_flags_pb2.ModelFlags()
model.change_concat_input_ranges = change_concat_input_ranges
@@ -343,13 +343,14 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
return data
-@deprecation.deprecated(None, "Use `lite.TocoConverter` instead.")
+@deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.")
def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
- """"Convert a model using TOCO.
+ """Convert a model using TOCO.
Typically this function is used to convert from TensorFlow GraphDef to TFLite.
Conversion can be customized by providing arguments that are forwarded to
- `build_toco_convert_protos` (see documentation for details).
+ `build_toco_convert_protos` (see documentation for details). This function has
+ been deprecated. Please use `lite.TFLiteConverter` instead.
Args:
input_data: Input data (i.e. often `sess.graph_def`),
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index 1553464b9f..d18b60d0ea 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -44,7 +44,7 @@ def _log_tensor_details(tensor_info):
dtype)
-def _get_meta_graph_def(saved_model_dir, tag_set):
+def get_meta_graph_def(saved_model_dir, tag_set):
"""Validate saved_model and extract MetaGraphDef.
Args:
@@ -61,7 +61,7 @@ def _get_meta_graph_def(saved_model_dir, tag_set):
return loader.load(sess, tag_set, saved_model_dir)
-def _get_signature_def(meta_graph, signature_key):
+def get_signature_def(meta_graph, signature_key):
"""Get the signature def from meta_graph with given signature_key.
Args:
@@ -86,7 +86,7 @@ def _get_signature_def(meta_graph, signature_key):
return signature_def_map[signature_key]
-def _get_inputs_outputs(signature_def):
+def get_inputs_outputs(signature_def):
"""Get inputs and outputs from SignatureDef.
Args:
@@ -236,9 +236,9 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
input_arrays or output_arrays are not valid.
"""
# Read SignatureDef.
- meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
- signature_def = _get_signature_def(meta_graph, signature_key)
- inputs, outputs = _get_inputs_outputs(signature_def)
+ meta_graph = get_meta_graph_def(saved_model_dir, tag_set)
+ signature_def = get_signature_def(meta_graph, signature_key)
+ inputs, outputs = get_inputs_outputs(signature_def)
# Check SavedModel for assets directory.
collection_def = meta_graph.collection_def
diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py
index 1be61fe053..5700bf7892 100644
--- a/tensorflow/contrib/lite/python/interpreter.py
+++ b/tensorflow/contrib/lite/python/interpreter.py
@@ -253,5 +253,5 @@ class Interpreter(object):
self._ensure_safe()
self._interpreter.Invoke()
- def reset_all_variables_to_zero(self):
- return self._interpreter.ResetVariableTensorsToZero()
+ def reset_all_variables(self):
+ return self._interpreter.ResetVariableTensors()
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 9ab05f3068..418f19a179 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -466,9 +466,9 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
error_msg);
}
-PyObject* InterpreterWrapper::ResetVariableTensorsToZero() {
+PyObject* InterpreterWrapper::ResetVariableTensors() {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
- TFLITE_PY_CHECK(interpreter_->ResetVariableTensorsToZero());
+ TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
Py_RETURN_NONE;
}
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
index 641dd93db5..f5ca81e62a 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -65,7 +65,7 @@ class InterpreterWrapper {
PyObject* TensorQuantization(int i) const;
PyObject* SetTensor(int i, PyObject* value);
PyObject* GetTensor(int i) const;
- PyObject* ResetVariableTensorsToZero();
+ PyObject* ResetVariableTensors();
// Returns a reference to tensor index i as a numpy array. The base_object
// should be the interpreter object providing the memory.
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 2be24455d8..09365f101f 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -17,6 +17,7 @@
EXPERIMENTAL: APIs here are unstable and likely to change without notice.
@@TocoConverter
+@@TFLiteConverter
@@toco_convert
@@toco_convert_protos
@@Interpreter
@@ -62,9 +63,10 @@ from tensorflow.python.framework.importer import import_graph_def as _import_gra
from tensorflow.python.lib.io import file_io as _file_io
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
+from tensorflow.python.util import deprecation as _deprecation
-class TocoConverter(object):
+class TFLiteConverter(object):
"""Convert a TensorFlow model into `output_format` using TOCO.
This is used to convert from a TensorFlow GraphDef or SavedModel into either a
@@ -121,22 +123,22 @@ class TocoConverter(object):
```python
# Converting a GraphDef from session.
- converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors)
+ converter = lite.TFLiteConverter.from_session(sess, in_tensors, out_tensors)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
# Converting a GraphDef from file.
- converter = lite.TocoConverter.from_frozen_graph(
+ converter = lite.TFLiteConverter.from_frozen_graph(
graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
# Converting a SavedModel.
- converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
# Converting a tf.keras model.
- converter = lite.TocoConverter.from_keras_model_file(keras_model)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_model)
tflite_model = converter.convert()
```
"""
@@ -147,10 +149,9 @@ class TocoConverter(object):
output_tensors,
input_arrays_with_shape=None,
output_arrays=None):
- """Constructor for TocoConverter.
+ """Constructor for TFLiteConverter.
Args:
-
graph_def: Frozen TensorFlow GraphDef.
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
@@ -158,8 +159,8 @@ class TocoConverter(object):
input_arrays_with_shape: Tuple of strings representing input tensor names
and list of integers representing input shapes
(e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
- into TensorFlow and when `input_tensors` and `output_tensors` are None.
- (default None)
+ into TensorFlow and when `input_tensors` and `output_tensors` are
+ None. (default None)
output_arrays: List of output tensors to freeze graph with. Use only when
graph cannot be loaded into TensorFlow and when `input_tensors` and
`output_tensors` are None. (default None)
@@ -195,7 +196,7 @@ class TocoConverter(object):
@classmethod
def from_session(cls, sess, input_tensors, output_tensors):
- """Creates a TocoConverter class from a TensorFlow Session.
+ """Creates a TFLiteConverter class from a TensorFlow Session.
Args:
sess: TensorFlow Session.
@@ -204,7 +205,7 @@ class TocoConverter(object):
output_tensors: List of output tensors (only .name is used from this).
Returns:
- TocoConverter class.
+ TFLiteConverter class.
"""
graph_def = _freeze_graph(sess, output_tensors)
return cls(graph_def, input_tensors, output_tensors)
@@ -215,7 +216,7 @@ class TocoConverter(object):
input_arrays,
output_arrays,
input_shapes=None):
- """Creates a TocoConverter class from a file containing a frozen GraphDef.
+ """Creates a TFLiteConverter class from a file containing a frozen GraphDef.
Args:
graph_def_file: Full filepath of file containing frozen GraphDef.
@@ -224,10 +225,10 @@ class TocoConverter(object):
input_shapes: Dict of strings representing input tensor names to list of
integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" :
- None}). (default None)
+ None}). (default None)
Returns:
- TocoConverter class.
+ TFLiteConverter class.
Raises:
IOError:
@@ -310,7 +311,7 @@ class TocoConverter(object):
output_arrays=None,
tag_set=None,
signature_key=None):
- """Creates a TocoConverter class from a SavedModel.
+ """Creates a TFLiteConverter class from a SavedModel.
Args:
saved_model_dir: SavedModel directory to convert.
@@ -319,7 +320,7 @@ class TocoConverter(object):
input_shapes: Dict of strings representing input tensor names to list of
integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" :
- None}). (default None)
+ None}). (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided. (default None)
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
@@ -328,7 +329,7 @@ class TocoConverter(object):
(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)
Returns:
- TocoConverter class.
+ TFLiteConverter class.
"""
if tag_set is None:
tag_set = set([_tag_constants.SERVING])
@@ -346,7 +347,7 @@ class TocoConverter(object):
input_arrays=None,
input_shapes=None,
output_arrays=None):
- """Creates a TocoConverter class from a tf.keras model file.
+ """Creates a TFLiteConverter class from a tf.keras model file.
Args:
model_file: Full filepath of HDF5 file containing the tf.keras model.
@@ -355,12 +356,12 @@ class TocoConverter(object):
input_shapes: Dict of strings representing input tensor names to list of
integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" :
- None}). (default None)
+ None}). (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided. (default None)
Returns:
- TocoConverter class.
+ TFLiteConverter class.
"""
_keras.backend.clear_session()
_keras.backend.set_learning_phase(False)
@@ -502,6 +503,59 @@ class TocoConverter(object):
tensor.set_shape(shape)
+class TocoConverter(object):
+ """Convert a TensorFlow model into `output_format` using TOCO.
+
+ This class has been deprecated. Please use `lite.TFLiteConverter` instead.
+ """
+
+ @classmethod
+ @_deprecation.deprecated(None,
+ "Use `lite.TFLiteConverter.from_session` instead.")
+ def from_session(cls, sess, input_tensors, output_tensors):
+ """Creates a TocoConverter class from a TensorFlow Session."""
+ return TFLiteConverter.from_session(sess, input_tensors, output_tensors)
+
+ @classmethod
+ @_deprecation.deprecated(
+ None, "Use `lite.TFLiteConverter.from_frozen_graph` instead.")
+ def from_frozen_graph(cls,
+ graph_def_file,
+ input_arrays,
+ output_arrays,
+ input_shapes=None):
+ """Creates a TocoConverter class from a file containing a frozen graph."""
+ return TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays,
+ output_arrays, input_shapes)
+
+ @classmethod
+ @_deprecation.deprecated(
+ None, "Use `lite.TFLiteConverter.from_saved_model` instead.")
+ def from_saved_model(cls,
+ saved_model_dir,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None,
+ tag_set=None,
+ signature_key=None):
+ """Creates a TocoConverter class from a SavedModel."""
+ return TFLiteConverter.from_saved_model(saved_model_dir, input_arrays,
+ input_shapes, output_arrays,
+ tag_set, signature_key)
+
+ @classmethod
+ @_deprecation.deprecated(
+ None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.")
+ def from_keras_model_file(cls,
+ model_file,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None):
+ """Creates a TocoConverter class from a tf.keras model file."""
+ return TFLiteConverter.from_keras_model_file(model_file, input_arrays,
+ input_shapes, output_arrays)
+
+
def _is_frozen_graph(sess):
"""Determines if the graph is frozen.
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index f112ed5cdd..d243a494f6 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -50,18 +50,18 @@ class FromConstructor(test_util.TensorFlowTestCase):
# `output_arrays` is not defined.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter(
+ lite.TFLiteConverter(
None, None, [], input_arrays_with_shape=[('input', [3, 9])])
self.assertEqual(message, str(error.exception))
# `input_arrays_with_shape` is not defined.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter(None, [], None, output_arrays=['output'])
+ lite.TFLiteConverter(None, [], None, output_arrays=['output'])
self.assertEqual(message, str(error.exception))
# Tests valid constructors using a dummy value for the GraphDef.
def testValidConstructor(self):
- converter = lite.TocoConverter(
+ converter = lite.TFLiteConverter(
None,
None,
None,
@@ -76,7 +76,7 @@ class FromConstructor(test_util.TensorFlowTestCase):
'The batch size cannot be set for this model. Please use '
'input_shapes parameter.', str(error.exception))
- converter = lite.TocoConverter(None, ['input_tensor'], ['output_tensor'])
+ converter = lite.TFLiteConverter(None, ['input_tensor'], ['output_tensor'])
self.assertTrue(converter._has_valid_tensors())
@@ -89,7 +89,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -121,7 +122,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(
+ converter = lite.TFLiteConverter.from_session(
sess, [in_tensor_1, in_tensor_2], [out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {
@@ -166,7 +167,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(
+ converter = lite.TFLiteConverter.from_session(
sess, [in_tensor_1, in_tensor_2], [out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev
@@ -182,7 +183,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Test invalid shape. None after 1st dimension.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
with self.assertRaises(ValueError) as error:
converter.convert()
self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
@@ -195,7 +197,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Test invalid shape. None after 1st dimension.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
with self.assertRaises(ValueError) as error:
converter.convert()
self.assertEqual(
@@ -210,7 +213,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -242,7 +246,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess.run(_global_variables_initializer())
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -272,7 +277,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
converter.output_format = lite_constants.GRAPHVIZ_DOT
graphviz_output = converter.convert()
self.assertTrue(graphviz_output)
@@ -285,7 +291,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
graphviz_dir = self.get_temp_dir()
converter.dump_graphviz_dir = graphviz_dir
tflite_model = converter.convert()
@@ -299,7 +306,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(num_items_graphviz)
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
graphviz_dir = self.get_temp_dir()
converter.dump_graphviz_dir = graphviz_dir
converter.dump_graphviz_video = True
@@ -317,7 +325,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
converter.inference_input_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
@@ -347,7 +356,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
converter.default_ranges_stats = (0, 6) # min, max
@@ -387,13 +397,13 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert float model.
- float_converter = lite.TocoConverter.from_session(sess, [in_tensor_1],
- [out_tensor])
+ float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1],
+ [out_tensor])
float_tflite = float_converter.convert()
self.assertTrue(float_tflite)
# Convert quantized weights model.
- quantized_converter = lite.TocoConverter.from_session(
+ quantized_converter = lite.TFLiteConverter.from_session(
sess, [in_tensor_1], [out_tensor])
quantized_converter.post_training_quantize = True
quantized_tflite = quantized_converter.convert()
@@ -402,15 +412,16 @@ class FromSessionTest(test_util.TensorFlowTestCase):
# Ensure that the quantized weights tflite model is smaller.
self.assertTrue(len(quantized_tflite) < len(float_tflite))
- def testExtendedMode(self):
+ def testFlexMode(self):
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
- converter.converter_mode = lite.ConverterMode.TOCO_EXTENDED_ALL
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
+ converter.converter_mode = lite.ConverterMode.TOCO_FLEX_ALL
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -421,9 +432,25 @@ class FromSessionTest(test_util.TensorFlowTestCase):
interpreter.allocate_tensors()
self.assertIn(
'Regular TensorFlow ops are not supported by this interpreter. Make '
- 'sure you invoke the Eager delegate before inference.',
+ 'sure you invoke the Flex delegate before inference.',
str(error.exception))
+ def testFloatTocoConverter(self):
+ """Tests deprecated test TocoConverter."""
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure the interpreter is able to load.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
@@ -439,8 +466,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
sess.close()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
- ['Placeholder'], ['add'])
+ converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
+ ['Placeholder'], ['add'])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -474,7 +501,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
sess.close()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_frozen_graph(
+ converter = lite.TFLiteConverter.from_frozen_graph(
graph_def_file, ['Placeholder'], ['add'],
input_shapes={'Placeholder': [1, 16, 16, 3]})
tflite_model = converter.convert()
@@ -503,8 +530,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Ensure the graph with variables cannot be converted.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
- ['add'])
+ lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
+ ['add'])
self.assertEqual('Please freeze the graph using freeze_graph.py.',
str(error.exception))
@@ -520,8 +547,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
sess.close()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
- ['Placeholder'], ['add'])
+ converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
+ ['Placeholder'], ['add'])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -545,8 +572,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
def testInvalidFileNotFound(self):
with self.assertRaises(IOError) as error:
- lite.TocoConverter.from_frozen_graph('invalid_file', ['Placeholder'],
- ['add'])
+ lite.TFLiteConverter.from_frozen_graph('invalid_file', ['Placeholder'],
+ ['add'])
self.assertEqual('File \'invalid_file\' does not exist.',
str(error.exception))
@@ -558,8 +585,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Attempts to convert the invalid model.
with self.assertRaises(IOError) as error:
- lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
- ['add'])
+ lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
+ ['add'])
self.assertEqual(
'Unable to parse input file \'{}\'.'.format(graph_def_file),
str(error.exception))
@@ -580,7 +607,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Tests the object detection model that cannot be loaded in TensorFlow.
self._initObjectDetectionArgs()
- converter = lite.TocoConverter.from_frozen_graph(
+ converter = lite.TFLiteConverter.from_frozen_graph(
self._graph_def_file, self._input_arrays, self._output_arrays,
self._input_shapes)
converter.allow_custom_ops = True
@@ -621,7 +648,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Missing `input_shapes`.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_frozen_graph(
+ lite.TFLiteConverter.from_frozen_graph(
self._graph_def_file, self._input_arrays, self._output_arrays)
self.assertEqual('input_shapes must be defined for this model.',
str(error.exception))
@@ -632,7 +659,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# `input_shapes` does not contain the names in `input_arrays`.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_frozen_graph(
+ lite.TFLiteConverter.from_frozen_graph(
self._graph_def_file,
self._input_arrays,
self._output_arrays,
@@ -641,6 +668,27 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
'input_shapes must contain a value for each item in input_array.',
str(error.exception))
+ def testFloatTocoConverter(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+ sess.close()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
+ ['Placeholder'], ['add'])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure the model is able to load.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
class FromSavedModelTest(test_util.TensorFlowTestCase):
@@ -663,7 +711,7 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -693,7 +741,7 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
"""Test a SavedModel, with None in input tensor's shape."""
saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])
- converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -724,7 +772,7 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
"""Test a SavedModel ordering of input arrays."""
saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
- converter = lite.TocoConverter.from_saved_model(
+ converter = lite.TFLiteConverter.from_saved_model(
saved_model_dir, input_arrays=['inputB', 'inputA'])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -757,7 +805,7 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
# Check case where input shape is given.
- converter = lite.TocoConverter.from_saved_model(
+ converter = lite.TFLiteConverter.from_saved_model(
saved_model_dir,
input_arrays=['inputA'],
input_shapes={'inputA': [1, 16, 16, 3]})
@@ -766,12 +814,25 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
self.assertTrue(tflite_model)
# Check case where input shape is None.
- converter = lite.TocoConverter.from_saved_model(
+ converter = lite.TFLiteConverter.from_saved_model(
saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})
tflite_model = converter.convert()
self.assertTrue(tflite_model)
+ def testSimpleModelTocoConverter(self):
+ """Test a SavedModel with deprecated TocoConverter."""
+ saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure the model is able to load.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
class FromKerasFile(test_util.TensorFlowTestCase):
@@ -805,7 +866,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
"""Test a Sequential tf.keras model with default inputs."""
keras_file = self._getSequentialModel()
- converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -845,13 +906,13 @@ class FromKerasFile(test_util.TensorFlowTestCase):
# Invalid input array raises error.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_keras_model_file(
+ lite.TFLiteConverter.from_keras_model_file(
keras_file, input_arrays=['invalid-input'])
self.assertEqual("Invalid tensors 'invalid-input' were found.",
str(error.exception))
# Valid input array.
- converter = lite.TocoConverter.from_keras_model_file(
+ converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, input_arrays=['dense_input'])
tflite_model = converter.convert()
os.remove(keras_file)
@@ -863,13 +924,13 @@ class FromKerasFile(test_util.TensorFlowTestCase):
# Passing in shape of invalid input array has no impact as long as all input
# arrays have a shape.
- converter = lite.TocoConverter.from_keras_model_file(
+ converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, input_shapes={'invalid-input': [2, 3]})
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Passing in shape of valid input array.
- converter = lite.TocoConverter.from_keras_model_file(
+ converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, input_shapes={'dense_input': [2, 3]})
tflite_model = converter.convert()
os.remove(keras_file)
@@ -890,13 +951,13 @@ class FromKerasFile(test_util.TensorFlowTestCase):
# Invalid output array raises error.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_keras_model_file(
+ lite.TFLiteConverter.from_keras_model_file(
keras_file, output_arrays=['invalid-output'])
self.assertEqual("Invalid tensors 'invalid-output' were found.",
str(error.exception))
# Valid output array.
- converter = lite.TocoConverter.from_keras_model_file(
+ converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, output_arrays=['time_distributed/Reshape_1'])
tflite_model = converter.convert()
os.remove(keras_file)
@@ -926,7 +987,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
os.close(fd)
# Convert to TFLite model.
- converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -991,7 +1052,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
os.close(fd)
# Convert to TFLite model.
- converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -1052,7 +1113,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
os.close(fd)
# Convert to TFLite model.
- converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -1086,6 +1147,18 @@ class FromKerasFile(test_util.TensorFlowTestCase):
np.testing.assert_almost_equal(tflite_result, keras_result, 5)
os.remove(keras_file)
+ def testSequentialModelTocoConverter(self):
+ """Test a Sequential tf.keras model with deprecated TocoConverter."""
+ keras_file = self._getSequentialModel()
+
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure the model is able to load.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index c0ff7f37f9..d6d9052a4e 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -40,13 +40,13 @@ def _parse_set(values):
def _get_toco_converter(flags):
- """Makes a TocoConverter object based on the flags provided.
+ """Makes a TFLiteConverter object based on the flags provided.
Args:
flags: argparse.Namespace object containing TFLite flags.
Returns:
- TocoConverter object.
+ TFLiteConverter object.
Raises:
ValueError: Invalid flags.
@@ -68,17 +68,17 @@ def _get_toco_converter(flags):
"output_arrays": output_arrays
}
- # Create TocoConverter.
+ # Create TFLiteConverter.
if flags.graph_def_file:
- converter_fn = lite.TocoConverter.from_frozen_graph
+ converter_fn = lite.TFLiteConverter.from_frozen_graph
converter_kwargs["graph_def_file"] = flags.graph_def_file
elif flags.saved_model_dir:
- converter_fn = lite.TocoConverter.from_saved_model
+ converter_fn = lite.TFLiteConverter.from_saved_model
converter_kwargs["saved_model_dir"] = flags.saved_model_dir
converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
converter_kwargs["signature_key"] = flags.saved_model_signature_key
elif flags.keras_model_file:
- converter_fn = lite.TocoConverter.from_keras_model_file
+ converter_fn = lite.TFLiteConverter.from_keras_model_file
converter_kwargs["model_file"] = flags.keras_model_file
else:
raise ValueError("--graph_def_file, --saved_model_dir, or "
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index c7a59cabc5..23ac8484de 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -264,8 +264,8 @@ enum TensorType {
TensorType_MAX = TensorType_COMPLEX64
};
-inline TensorType (&EnumValuesTensorType())[9] {
- static TensorType values[] = {
+inline const TensorType (&EnumValuesTensorType())[9] {
+ static const TensorType values[] = {
TensorType_FLOAT32,
TensorType_FLOAT16,
TensorType_INT32,
@@ -279,8 +279,8 @@ inline TensorType (&EnumValuesTensorType())[9] {
return values;
}
-inline const char **EnumNamesTensorType() {
- static const char *names[] = {
+inline const char * const *EnumNamesTensorType() {
+ static const char * const names[] = {
"FLOAT32",
"FLOAT16",
"INT32",
@@ -399,8 +399,8 @@ enum BuiltinOperator {
BuiltinOperator_MAX = BuiltinOperator_FILL
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[94] {
- static BuiltinOperator values[] = {
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[94] {
+ static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
BuiltinOperator_CONCATENATION,
@@ -499,8 +499,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[94] {
return values;
}
-inline const char **EnumNamesBuiltinOperator() {
- static const char *names[] = {
+inline const char * const *EnumNamesBuiltinOperator() {
+ static const char * const names[] = {
"ADD",
"AVERAGE_POOL_2D",
"CONCATENATION",
@@ -680,8 +680,8 @@ enum BuiltinOptions {
BuiltinOptions_MAX = BuiltinOptions_FillOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[69] {
- static BuiltinOptions values[] = {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[69] {
+ static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
BuiltinOptions_DepthwiseConv2DOptions,
@@ -755,8 +755,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[69] {
return values;
}
-inline const char **EnumNamesBuiltinOptions() {
- static const char *names[] = {
+inline const char * const *EnumNamesBuiltinOptions() {
+ static const char * const names[] = {
"NONE",
"Conv2DOptions",
"DepthwiseConv2DOptions",
@@ -1699,16 +1699,16 @@ enum Padding {
Padding_MAX = Padding_VALID
};
-inline Padding (&EnumValuesPadding())[2] {
- static Padding values[] = {
+inline const Padding (&EnumValuesPadding())[2] {
+ static const Padding values[] = {
Padding_SAME,
Padding_VALID
};
return values;
}
-inline const char **EnumNamesPadding() {
- static const char *names[] = {
+inline const char * const *EnumNamesPadding() {
+ static const char * const names[] = {
"SAME",
"VALID",
nullptr
@@ -1732,8 +1732,8 @@ enum ActivationFunctionType {
ActivationFunctionType_MAX = ActivationFunctionType_SIGN_BIT
};
-inline ActivationFunctionType (&EnumValuesActivationFunctionType())[6] {
- static ActivationFunctionType values[] = {
+inline const ActivationFunctionType (&EnumValuesActivationFunctionType())[6] {
+ static const ActivationFunctionType values[] = {
ActivationFunctionType_NONE,
ActivationFunctionType_RELU,
ActivationFunctionType_RELU_N1_TO_1,
@@ -1744,8 +1744,8 @@ inline ActivationFunctionType (&EnumValuesActivationFunctionType())[6] {
return values;
}
-inline const char **EnumNamesActivationFunctionType() {
- static const char *names[] = {
+inline const char * const *EnumNamesActivationFunctionType() {
+ static const char * const names[] = {
"NONE",
"RELU",
"RELU_N1_TO_1",
@@ -1770,8 +1770,8 @@ enum LSHProjectionType {
LSHProjectionType_MAX = LSHProjectionType_DENSE
};
-inline LSHProjectionType (&EnumValuesLSHProjectionType())[3] {
- static LSHProjectionType values[] = {
+inline const LSHProjectionType (&EnumValuesLSHProjectionType())[3] {
+ static const LSHProjectionType values[] = {
LSHProjectionType_UNKNOWN,
LSHProjectionType_SPARSE,
LSHProjectionType_DENSE
@@ -1779,8 +1779,8 @@ inline LSHProjectionType (&EnumValuesLSHProjectionType())[3] {
return values;
}
-inline const char **EnumNamesLSHProjectionType() {
- static const char *names[] = {
+inline const char * const *EnumNamesLSHProjectionType() {
+ static const char * const names[] = {
"UNKNOWN",
"SPARSE",
"DENSE",
@@ -1801,16 +1801,16 @@ enum FullyConnectedOptionsWeightsFormat {
FullyConnectedOptionsWeightsFormat_MAX = FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8
};
-inline FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOptionsWeightsFormat())[2] {
- static FullyConnectedOptionsWeightsFormat values[] = {
+inline const FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOptionsWeightsFormat())[2] {
+ static const FullyConnectedOptionsWeightsFormat values[] = {
FullyConnectedOptionsWeightsFormat_DEFAULT,
FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8
};
return values;
}
-inline const char **EnumNamesFullyConnectedOptionsWeightsFormat() {
- static const char *names[] = {
+inline const char * const *EnumNamesFullyConnectedOptionsWeightsFormat() {
+ static const char * const names[] = {
"DEFAULT",
"SHUFFLED4x16INT8",
nullptr
@@ -1830,16 +1830,16 @@ enum LSTMKernelType {
LSTMKernelType_MAX = LSTMKernelType_BASIC
};
-inline LSTMKernelType (&EnumValuesLSTMKernelType())[2] {
- static LSTMKernelType values[] = {
+inline const LSTMKernelType (&EnumValuesLSTMKernelType())[2] {
+ static const LSTMKernelType values[] = {
LSTMKernelType_FULL,
LSTMKernelType_BASIC
};
return values;
}
-inline const char **EnumNamesLSTMKernelType() {
- static const char *names[] = {
+inline const char * const *EnumNamesLSTMKernelType() {
+ static const char * const names[] = {
"FULL",
"BASIC",
nullptr
@@ -1860,8 +1860,8 @@ enum CombinerType {
CombinerType_MAX = CombinerType_SQRTN
};
-inline CombinerType (&EnumValuesCombinerType())[3] {
- static CombinerType values[] = {
+inline const CombinerType (&EnumValuesCombinerType())[3] {
+ static const CombinerType values[] = {
CombinerType_SUM,
CombinerType_MEAN,
CombinerType_SQRTN
@@ -1869,8 +1869,8 @@ inline CombinerType (&EnumValuesCombinerType())[3] {
return values;
}
-inline const char **EnumNamesCombinerType() {
- static const char *names[] = {
+inline const char * const *EnumNamesCombinerType() {
+ static const char * const names[] = {
"SUM",
"MEAN",
"SQRTN",
@@ -1890,15 +1890,15 @@ enum CustomOptionsFormat {
CustomOptionsFormat_MAX = CustomOptionsFormat_FLEXBUFFERS
};
-inline CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1] {
- static CustomOptionsFormat values[] = {
+inline const CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1] {
+ static const CustomOptionsFormat values[] = {
CustomOptionsFormat_FLEXBUFFERS
};
return values;
}
-inline const char **EnumNamesCustomOptionsFormat() {
- static const char *names[] = {
+inline const char * const *EnumNamesCustomOptionsFormat() {
+ static const char * const names[] = {
"FLEXBUFFERS",
nullptr
};
@@ -1943,13 +1943,13 @@ struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_MIN) &&
- verifier.Verify(min()) &&
+ verifier.VerifyVector(min()) &&
VerifyOffset(verifier, VT_MAX) &&
- verifier.Verify(max()) &&
+ verifier.VerifyVector(max()) &&
VerifyOffset(verifier, VT_SCALE) &&
- verifier.Verify(scale()) &&
+ verifier.VerifyVector(scale()) &&
VerifyOffset(verifier, VT_ZERO_POINT) &&
- verifier.Verify(zero_point()) &&
+ verifier.VerifyVector(zero_point()) &&
verifier.EndTable();
}
QuantizationParametersT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2060,11 +2060,11 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_SHAPE) &&
- verifier.Verify(shape()) &&
+ verifier.VerifyVector(shape()) &&
VerifyField<int8_t>(verifier, VT_TYPE) &&
VerifyField<uint32_t>(verifier, VT_BUFFER) &&
VerifyOffset(verifier, VT_NAME) &&
- verifier.Verify(name()) &&
+ verifier.VerifyString(name()) &&
VerifyOffset(verifier, VT_QUANTIZATION) &&
verifier.VerifyTable(quantization()) &&
VerifyField<uint8_t>(verifier, VT_IS_VARIABLE) &&
@@ -2530,9 +2530,9 @@ struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Ta
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_NUM_CHANNELS) &&
VerifyOffset(verifier, VT_NUM_COLUMNS_PER_CHANNEL) &&
- verifier.Verify(num_columns_per_channel()) &&
+ verifier.VerifyVector(num_columns_per_channel()) &&
VerifyOffset(verifier, VT_EMBEDDING_DIM_PER_CHANNEL) &&
- verifier.Verify(embedding_dim_per_channel()) &&
+ verifier.VerifyVector(embedding_dim_per_channel()) &&
verifier.EndTable();
}
ConcatEmbeddingsOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -3630,7 +3630,7 @@ struct ReshapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_NEW_SHAPE) &&
- verifier.Verify(new_shape()) &&
+ verifier.VerifyVector(new_shape()) &&
verifier.EndTable();
}
ReshapeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -4294,7 +4294,7 @@ struct SqueezeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_SQUEEZE_DIMS) &&
- verifier.Verify(squeeze_dims()) &&
+ verifier.VerifyVector(squeeze_dims()) &&
verifier.EndTable();
}
SqueezeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -6041,7 +6041,7 @@ struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_BUILTIN_CODE) &&
VerifyOffset(verifier, VT_CUSTOM_CODE) &&
- verifier.Verify(custom_code()) &&
+ verifier.VerifyString(custom_code()) &&
VerifyField<int32_t>(verifier, VT_VERSION) &&
verifier.EndTable();
}
@@ -6360,17 +6360,17 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyField<uint32_t>(verifier, VT_OPCODE_INDEX) &&
VerifyOffset(verifier, VT_INPUTS) &&
- verifier.Verify(inputs()) &&
+ verifier.VerifyVector(inputs()) &&
VerifyOffset(verifier, VT_OUTPUTS) &&
- verifier.Verify(outputs()) &&
+ verifier.VerifyVector(outputs()) &&
VerifyField<uint8_t>(verifier, VT_BUILTIN_OPTIONS_TYPE) &&
VerifyOffset(verifier, VT_BUILTIN_OPTIONS) &&
VerifyBuiltinOptions(verifier, builtin_options(), builtin_options_type()) &&
VerifyOffset(verifier, VT_CUSTOM_OPTIONS) &&
- verifier.Verify(custom_options()) &&
+ verifier.VerifyVector(custom_options()) &&
VerifyField<int8_t>(verifier, VT_CUSTOM_OPTIONS_FORMAT) &&
VerifyOffset(verifier, VT_MUTATING_VARIABLE_INPUTS) &&
- verifier.Verify(mutating_variable_inputs()) &&
+ verifier.VerifyVector(mutating_variable_inputs()) &&
verifier.EndTable();
}
OperatorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -6773,17 +6773,17 @@ struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_TENSORS) &&
- verifier.Verify(tensors()) &&
+ verifier.VerifyVector(tensors()) &&
verifier.VerifyVectorOfTables(tensors()) &&
VerifyOffset(verifier, VT_INPUTS) &&
- verifier.Verify(inputs()) &&
+ verifier.VerifyVector(inputs()) &&
VerifyOffset(verifier, VT_OUTPUTS) &&
- verifier.Verify(outputs()) &&
+ verifier.VerifyVector(outputs()) &&
VerifyOffset(verifier, VT_OPERATORS) &&
- verifier.Verify(operators()) &&
+ verifier.VerifyVector(operators()) &&
verifier.VerifyVectorOfTables(operators()) &&
VerifyOffset(verifier, VT_NAME) &&
- verifier.Verify(name()) &&
+ verifier.VerifyString(name()) &&
verifier.EndTable();
}
SubGraphT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -6873,7 +6873,7 @@ struct Buffer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_DATA) &&
- verifier.Verify(data()) &&
+ verifier.VerifyVector(data()) &&
verifier.EndTable();
}
BufferT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -6962,18 +6962,18 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyField<uint32_t>(verifier, VT_VERSION) &&
VerifyOffset(verifier, VT_OPERATOR_CODES) &&
- verifier.Verify(operator_codes()) &&
+ verifier.VerifyVector(operator_codes()) &&
verifier.VerifyVectorOfTables(operator_codes()) &&
VerifyOffset(verifier, VT_SUBGRAPHS) &&
- verifier.Verify(subgraphs()) &&
+ verifier.VerifyVector(subgraphs()) &&
verifier.VerifyVectorOfTables(subgraphs()) &&
VerifyOffset(verifier, VT_DESCRIPTION) &&
- verifier.Verify(description()) &&
+ verifier.VerifyString(description()) &&
VerifyOffset(verifier, VT_BUFFERS) &&
- verifier.Verify(buffers()) &&
+ verifier.VerifyVector(buffers()) &&
verifier.VerifyVectorOfTables(buffers()) &&
VerifyOffset(verifier, VT_METADATA_BUFFER) &&
- verifier.Verify(metadata_buffer()) &&
+ verifier.VerifyVector(metadata_buffer()) &&
verifier.EndTable();
}
ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -10628,6 +10628,10 @@ inline const tflite::Model *GetModel(const void *buf) {
return flatbuffers::GetRoot<tflite::Model>(buf);
}
+inline const tflite::Model *GetSizePrefixedModel(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<tflite::Model>(buf);
+}
+
inline const char *ModelIdentifier() {
return "TFL3";
}
@@ -10642,6 +10646,11 @@ inline bool VerifyModelBuffer(
return verifier.VerifyBuffer<tflite::Model>(ModelIdentifier());
}
+inline bool VerifySizePrefixedModelBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<tflite::Model>(ModelIdentifier());
+}
+
inline const char *ModelExtension() {
return "tflite";
}
@@ -10652,6 +10661,12 @@ inline void FinishModelBuffer(
fbb.Finish(root, ModelIdentifier());
}
+inline void FinishSizePrefixedModelBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<tflite::Model> root) {
+ fbb.FinishSizePrefixed(root, ModelIdentifier());
+}
+
inline std::unique_ptr<ModelT> UnPackModel(
const void *buf,
const flatbuffers::resolver_function_t *res = nullptr) {
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index a4736bfee9..f0bfec2338 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -13,6 +13,7 @@ load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite"
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
+ "py_test",
)
[gen_zip_test(
@@ -163,7 +164,7 @@ cc_library(
":test_runner",
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
- "//tensorflow/contrib/lite/delegates/eager:delegate",
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
"//tensorflow/contrib/lite/kernels:builtin_ops",
],
)
@@ -362,4 +363,32 @@ cc_binary(
],
)
+py_binary(
+ name = "model_coverage_lib",
+ srcs = ["//tensorflow/contrib/lite/testing:model_coverage/model_coverage_lib.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ visibility = ["//tensorflow/contrib/lite:__subpackages__"],
+ deps = [
+ "//tensorflow/contrib/lite/python:lite",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_test(
+ name = "model_coverage_lib_test",
+ srcs = ["//tensorflow/contrib/lite/testing:model_coverage/model_coverage_lib_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ "notap",
+ ],
+ deps = [
+ ":model_coverage_lib",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 014c80b5ef..18036fac6f 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -81,9 +81,9 @@ parser.add_argument(
action="store_true",
help="Include intermediate graphdefs in the output zip files.")
parser.add_argument(
- "--run_with_extended",
+ "--run_with_flex",
action="store_true",
- help="Whether the TFLite Extended converter is being used.")
+ help="Whether the TFLite Flex converter is being used.")
RANDOM_SEED = 342
TEST_INPUT_DEPTH = 3
@@ -339,11 +339,11 @@ def toco_convert(graph_def_str, input_tensors, output_tensors,
graphdef_file.flush()
# TODO(aselle): Switch this to subprocess at some point.
- if "pb2lite" in bin_path and FLAGS.run_with_extended:
+ if "pb2lite" in bin_path and FLAGS.run_with_flex:
opts = ("--input_arrays={0} --output_arrays={1}".format(
",".join(input_arrays), ",".join(output_tensors)))
- elif FLAGS.run_with_extended:
- opts += " --allow_eager_ops --force_eager_ops"
+ elif FLAGS.run_with_flex:
+ opts += " --allow_flex_ops --force_flex_ops"
cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" %
(bin_path, graphdef_file.name, output_file.name, opts,
stdout_file.name))
@@ -3333,7 +3333,7 @@ def main(unused_args):
# list of valid conversion modes is defined in
# generated_test_conversion_modes() in build_def.bzl.
test_function = ("make_%s_tests" % (out.replace(".zip", "").replace(
- "pb2lite", "").replace("toco-extended", "").rstrip("_")))
+ "pb2lite", "").replace("toco-flex", "").rstrip("_")))
if test_function not in globals():
raise RuntimeError("Can't find a test function to create %r. Tried %r" %
(out, test_function))
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
new file mode 100644
index 0000000000..5ca57d083d
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
@@ -0,0 +1,249 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions to test TFLite models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.lite.python import convert_saved_model as _convert_saved_model
+from tensorflow.contrib.lite.python import lite as _lite
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.python import keras as _keras
+from tensorflow.python.client import session as _session
+from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
+from tensorflow.python.lib.io import file_io as _file_io
+from tensorflow.python.saved_model import signature_constants as _signature_constants
+from tensorflow.python.saved_model import tag_constants as _tag_constants
+
+
+def _convert(converter, **kwargs):
+ """Converts the model.
+
+ Args:
+ converter: TocoConverter object.
+ **kwargs: Additional arguments to be passed into the converter. Supported
+ flags are {"converter_mode", "post_training_quant"}.
+
+ Returns:
+ The converted TFLite model in serialized format.
+ """
+ if "converter_mode" in kwargs:
+ converter.converter_mode = kwargs["converter_mode"]
+ if "post_training_quantize" in kwargs:
+ converter.post_training_quantize = kwargs["post_training_quantize"]
+ return converter.convert()
+
+
+def _generate_random_input_data(tflite_model, seed=None):
+ """Generates input data based on the input tensors in the TFLite model.
+
+ Args:
+ tflite_model: Serialized TensorFlow Lite model.
+ seed: Integer seed for the random generator. (default None)
+
+ Returns:
+ List of np.ndarray.
+ """
+ interpreter = _lite.Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+ input_details = interpreter.get_input_details()
+
+ if seed:
+ np.random.seed(seed=seed)
+ return [
+ np.array(
+ np.random.random_sample(input_tensor["shape"]),
+ dtype=input_tensor["dtype"]) for input_tensor in input_details
+ ]
+
+
+def _evaluate_tflite_model(tflite_model, input_data):
+ """Returns evaluation of input data on TFLite model.
+
+ Args:
+ tflite_model: Serialized TensorFlow Lite model.
+ input_data: List of np.ndarray.
+
+ Returns:
+ List of np.ndarray.
+ """
+ interpreter = _lite.Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ for input_tensor, tensor_data in zip(input_details, input_data):
+ interpreter.set_tensor(input_tensor["index"], tensor_data)
+
+ interpreter.invoke()
+ output_data = [
+ interpreter.get_tensor(output_tensor["index"])
+ for output_tensor in output_details
+ ]
+ return output_data
+
+
+def evaluate_frozen_graph(filename, input_arrays, output_arrays):
+ """Returns a function that evaluates the frozen graph on input data.
+
+ Args:
+ filename: Full filepath of file containing frozen GraphDef.
+ input_arrays: List of input tensors to freeze graph with.
+ output_arrays: List of output tensors to freeze graph with.
+
+ Returns:
+ Lambda function ([np.ndarray data] : [np.ndarray result]).
+ """
+ with _session.Session().as_default() as sess:
+ with _file_io.FileIO(filename, "rb") as f:
+ file_content = f.read()
+
+ graph_def = _graph_pb2.GraphDef()
+ graph_def.ParseFromString(file_content)
+ _import_graph_def(graph_def, name="")
+
+ inputs = _convert_saved_model.get_tensors_from_tensor_names(
+ sess.graph, input_arrays)
+ outputs = _convert_saved_model.get_tensors_from_tensor_names(
+ sess.graph, output_arrays)
+
+ return lambda input_data: sess.run(outputs, dict(zip(inputs, input_data)))
+
+
+def evaluate_saved_model(directory, tag_set, signature_key):
+ """Returns a function that evaluates the SavedModel on input data.
+
+ Args:
+ directory: SavedModel directory to convert.
+ tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze. All tags in the tag set must be present.
+ signature_key: Key identifying SignatureDef containing inputs and outputs.
+
+ Returns:
+ Lambda function ([np.ndarray data] : [np.ndarray result]).
+ """
+ with _session.Session().as_default() as sess:
+ if tag_set is None:
+ tag_set = set([_tag_constants.SERVING])
+ if signature_key is None:
+ signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+
+ meta_graph = _convert_saved_model.get_meta_graph_def(directory, tag_set)
+ signature_def = _convert_saved_model.get_signature_def(
+ meta_graph, signature_key)
+ inputs, outputs = _convert_saved_model.get_inputs_outputs(signature_def)
+
+ return lambda input_data: sess.run(outputs, dict(zip(inputs, input_data)))
+
+
+def evaluate_keras_model(filename):
+ """Returns a function that evaluates the tf.keras model on input data.
+
+ Args:
+ filename: Full filepath of HDF5 file containing the tf.keras model.
+
+ Returns:
+ Lambda function ([np.ndarray data] : [np.ndarray result]).
+ """
+ keras_model = _keras.models.load_model(filename)
+ return lambda input_data: [keras_model.predict(input_data)]
+
+
+# TODO(nupurgarg): Make this function a parameter to test_frozen_graph (and
+# related functions) in order to make it easy to use different data generators.
+def compare_models_random_data(tflite_model, tf_eval_func, tolerance=5):
+ """Compares TensorFlow and TFLite models with random data.
+
+ Args:
+ tflite_model: Serialized TensorFlow Lite model.
+ tf_eval_func: Lambda function that takes in input data and outputs the
+ results of the TensorFlow model ([np.ndarray data] : [np.ndarray result]).
+ tolerance: Decimal place to check accuracy to.
+ """
+ input_data = _generate_random_input_data(tflite_model)
+ tf_results = tf_eval_func(input_data)
+ tflite_results = _evaluate_tflite_model(tflite_model, input_data)
+ for tf_result, tflite_result in zip(tf_results, tflite_results):
+ np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
+
+
+def test_frozen_graph(filename,
+ input_arrays,
+ output_arrays,
+ input_shapes=None,
+ **kwargs):
+ """Validates the TensorFlow frozen graph converts to a TFLite model.
+
+ Converts the TensorFlow frozen graph to TFLite and checks the accuracy of the
+ model on random data.
+
+ Args:
+ filename: Full filepath of file containing frozen GraphDef.
+ input_arrays: List of input tensors to freeze graph with.
+ output_arrays: List of output tensors to freeze graph with.
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ Automatically determined when input shapes is None (e.g., {"foo" : None}).
+ (default None)
+ **kwargs: Additional arguments to be passed into the converter.
+ """
+ converter = _lite.TocoConverter.from_frozen_graph(filename, input_arrays,
+ output_arrays, input_shapes)
+ tflite_model = _convert(converter, **kwargs)
+
+ tf_eval_func = evaluate_frozen_graph(filename, input_arrays, output_arrays)
+ compare_models_random_data(tflite_model, tf_eval_func)
+
+
+def test_saved_model(directory, tag_set=None, signature_key=None, **kwargs):
+ """Validates the TensorFlow SavedModel converts to a TFLite model.
+
+ Converts the TensorFlow SavedModel to TFLite and checks the accuracy of the
+ model on random data.
+
+ Args:
+ directory: SavedModel directory to convert.
+ tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze. All tags in the tag set must be present.
+ signature_key: Key identifying SignatureDef containing inputs and outputs.
+ **kwargs: Additional arguments to be passed into the converter.
+ """
+ converter = _lite.TocoConverter.from_saved_model(directory, tag_set,
+ signature_key)
+ tflite_model = _convert(converter, **kwargs)
+
+ tf_eval_func = evaluate_saved_model(directory, tag_set, signature_key)
+ compare_models_random_data(tflite_model, tf_eval_func)
+
+
+def test_keras_model(filename, **kwargs):
+ """Validates the tf.keras model converts to a TFLite model.
+
+ Converts the tf.keras model to TFLite and checks the accuracy of the model on
+ random data.
+
+ Args:
+ filename: Full filepath of HDF5 file containing the tf.keras model.
+ **kwargs: Additional arguments to be passed into the converter.
+ """
+ converter = _lite.TocoConverter.from_keras_model_file(filename)
+ tflite_model = _convert(converter, **kwargs)
+
+ tf_eval_func = evaluate_keras_model(filename)
+ compare_models_random_data(tflite_model, tf_eval_func)
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py
new file mode 100644
index 0000000000..1498f86c6f
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py
@@ -0,0 +1,130 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for model_coverage_lib.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.testing.model_coverage import model_coverage_lib as model_coverage
+from tensorflow.python import keras
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import saved_model
+from tensorflow.python.training.training_util import write_graph
+
+
+class EvaluateFrozenGraph(test.TestCase):
+
+ def _saveFrozenGraph(self, sess):
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+ return graph_def_file
+
+ def testFloat(self):
+ with session.Session().as_default() as sess:
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ filename = self._saveFrozenGraph(sess)
+
+ model_coverage.test_frozen_graph(filename, ['Placeholder'], ['add'])
+
+ def testMultipleOutputs(self):
+ with session.Session().as_default() as sess:
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16], dtype=dtypes.float32, name='inputB')
+
+ weight = constant_op.constant(-1.0, shape=[16, 16])
+ bias = constant_op.constant(-1.0, shape=[16])
+ layer = math_ops.matmul(in_tensor_1, weight) + bias
+ _ = math_ops.reduce_mean(math_ops.square(layer - in_tensor_2))
+ filename = self._saveFrozenGraph(sess)
+
+ model_coverage.test_frozen_graph(filename, ['inputA', 'inputB'],
+ ['add', 'Mean'])
+
+
+class EvaluateSavedModel(test.TestCase):
+
+ def testFloat(self):
+ saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
+ with session.Session().as_default() as sess:
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+ out_tensor = in_tensor_1 + in_tensor_2
+
+ inputs = {'x': in_tensor_1, 'y': in_tensor_2}
+ outputs = {'z': out_tensor}
+ saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
+ model_coverage.test_saved_model(saved_model_dir)
+
+
+class EvaluateKerasModel(test.TestCase):
+
+ def _getSingleInputKerasModel(self):
+ """Returns single input Sequential tf.keras model."""
+ keras.backend.clear_session()
+
+ xs = [-1, 0, 1, 2, 3, 4]
+ ys = [-3, -1, 1, 3, 5, 7]
+
+ model = keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
+ model.compile(optimizer='sgd', loss='mean_squared_error')
+ model.train_on_batch(xs, ys)
+ return model
+
+ def _saveKerasModel(self, model):
+ try:
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
+ return keras_file
+
+ def testFloat(self):
+ model = self._getSingleInputKerasModel()
+ keras_file = self._saveKerasModel(model)
+
+ model_coverage.test_keras_model(keras_file)
+
+ def testPostTrainingQuantize(self):
+ model = self._getSingleInputKerasModel()
+ keras_file = self._saveKerasModel(model)
+
+ model_coverage.test_keras_model(keras_file, post_training_quantize=True)
+
+ def testConverterMode(self):
+ model = self._getSingleInputKerasModel()
+ keras_file = self._saveKerasModel(model)
+
+ model_coverage.test_keras_model(
+ keras_file, converter_mode=lite.ConverterMode.TOCO_FLEX)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
index 3874bc31d7..ad889a2f19 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
@@ -57,7 +57,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
"[optional] Number of full runs in each pass."),
tensorflow::Flag("delegate", &values.delegate,
"[optional] Delegate to use for executing ops. Must be "
- "`{\"\", EAGER}`"),
+ "`{\"\", FLEX}`"),
};
bool no_inputs = *argc == 1;
@@ -70,7 +70,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
values.input_layer_shape.empty() || values.output_layer.empty()) {
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
return {};
- } else if (!(values.delegate == "" || values.delegate == "EAGER")) {
+ } else if (!(values.delegate == "" || values.delegate == "FLEX")) {
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
return {};
}
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.h b/tensorflow/contrib/lite/testing/tflite_diff_util.h
index f67992139f..28b14bd143 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_util.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_util.h
@@ -45,7 +45,7 @@ struct DiffOptions {
// second pass does multiple inferences back to back.
int num_runs_per_pass;
// Path to the delegate library to be loaded in order to execute ops. Must be
- // `{"", EAGER}`.
+ // `{"", FLEX}`.
string delegate;
};
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 1836eb53b9..ef49e6f8bc 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <iostream>
#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#include "tensorflow/contrib/lite/testing/split.h"
namespace tflite {
@@ -138,8 +138,8 @@ class TfLiteDriver::Expectation {
TfLiteDriver::TfLiteDriver(bool use_nnapi, const string& delegate_name)
: use_nnapi_(use_nnapi) {
- if (delegate_name == "EAGER") {
- delegate_ = EagerDelegate::Create();
+ if (delegate_name == "FLEX") {
+ delegate_ = FlexDelegate::Create();
}
}
@@ -301,7 +301,7 @@ bool TfLiteDriver::CheckResults() {
}
void TfLiteDriver::ResetLSTMStateTensors() {
- interpreter_->ResetVariableTensorsToZero();
+ interpreter_->ResetVariableTensors();
}
} // namespace testing
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h
index aed35f877d..dc2a4e5877 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.h
+++ b/tensorflow/contrib/lite/testing/tflite_driver.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <map>
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
@@ -53,7 +53,7 @@ class TfLiteDriver : public TestRunner {
class Expectation;
- std::unique_ptr<EagerDelegate> delegate_;
+ std::unique_ptr<FlexDelegate> delegate_;
bool use_nnapi_ = false;
std::unique_ptr<FlatBufferModel> model_;
std::unique_ptr<Interpreter> interpreter_;
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index f14dbc258b..2699ac76e1 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -248,9 +248,9 @@ struct ParsedTocoFlags {
Arg<int64> dedupe_array_min_size_bytes = Arg<int64>(64);
Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true);
// WARNING: Experimental interface, subject to change
- Arg<bool> allow_eager_ops = Arg<bool>(false);
+ Arg<bool> allow_flex_ops = Arg<bool>(false);
// WARNING: Experimental interface, subject to change
- Arg<bool> force_eager_ops = Arg<bool>(false);
+ Arg<bool> force_flex_ops = Arg<bool>(false);
};
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index 910fa4c8de..8c31c3dca8 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -39,13 +39,18 @@ The API for converting TensorFlow models to TensorFlow Lite as of TensorFlow 1.9
is `tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is
`tf.contrib.lite.Interpreter`.
+**NOTE**: As of TensorFlow 1.12, the API for converting TensorFlow models to
+TFLite will be renamed to `TFLiteConverter`. `TFLiteConverter` is semantically
+identically to `TocoConverter`. The API is available at
+`tf.contrib.lite.TFLiteConverter` as of the Sept 26 `tf-nightly`.
+
`TocoConverter` provides class methods based on the original format of the
model. `TocoConverter.from_session()` is available for GraphDefs.
`TocoConverter.from_saved_model()` is available for SavedModels.
`TocoConverter.from_keras_model_file()` is available for `tf.Keras` files.
-Example usages for simple float-point models are shown in [Basic
-Examples](#basic). Examples usages for more complex models is shown in [Complex
-Examples](#complex).
+Example usages for simple float-point models are shown in
+[Basic Examples](#basic). Examples usages for more complex models is shown in
+[Complex Examples](#complex).
**NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python
interpreter when the conversion fails. This will be remedied as soon as
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index f943da6d85..d056a8add7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -659,11 +659,16 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
}
}
auto& output_array = model->GetArray(op->outputs[0]);
- // Use 0 input as basis for output dimensions.
- const auto& first_input_array = model->GetArray(op->inputs[0]);
- output_array.copy_shape(first_input_array.shape());
- // Negative axis means the count starts at the back of the dims().
- if (op->axis < 0) op->axis += first_input_array.shape().dims().size();
+ // Use first non-empty input as basis for output dimensions.
+ for (const auto& input_name : op->inputs) {
+ const auto& input_array = model->GetArray(input_name);
+ if (input_array.shape().dimensions_count() > 0) {
+ output_array.copy_shape(input_array.shape());
+ // Negative axis means the count starts at the back of the dims().
+ if (op->axis < 0) op->axis += input_array.shape().dims().size();
+ break;
+ }
+ }
// Determine the concat size, and enfore that all inputs have
// the same dimensions count.
int concat_size = 0;
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index e02d000e7e..5eaf6e27fc 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -2123,9 +2123,9 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
Model* model = new Model;
internal::ConverterMapType converter_map;
- // This is used for the TFLite "Full Eager Mode" conversion. All the ops are
+ // This is used for the TFLite "Full Flex Mode" conversion. All the ops are
// imported as `TensorFlowUnsupportedOperator`, and later all these ops are
- // converted to TFLite Eager ops.
+ // converted to TFLite Flex ops.
if (!tf_import_flags.import_all_ops_as_unsupported) {
converter_map = internal::GetTensorFlowNodeConverterMap();
}
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h
index 7db23f2d44..c5ff96956a 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.h
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.h
@@ -30,7 +30,7 @@ struct TensorFlowImportFlags {
// Do not recognize any op and import all ops as
// `TensorFlowUnsupportedOperator`. This is used to populated with the
- // `force_eager_ops` flag.
+ // `force_flex_ops` flag.
bool import_all_ops_as_unsupported = false;
};
diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD
index 33c5b16462..cf97ba7084 100644
--- a/tensorflow/contrib/lite/toco/python/BUILD
+++ b/tensorflow/contrib/lite/toco/python/BUILD
@@ -4,6 +4,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load("//tensorflow:tensorflow.bzl", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
cc_library(
name = "toco_python_api",
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index fee10b1dff..0c9fac249c 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -50,16 +50,16 @@ namespace {
details::OperatorKey GetOperatorKey(
const ::toco::Operator& op,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
- bool allow_eager_ops) {
+ bool allow_flex_ops) {
string custom_code;
if (op.type == OperatorType::kUnsupported) {
const TensorFlowUnsupportedOperator& unsupported_op =
static_cast<const TensorFlowUnsupportedOperator&>(op);
- // TODO(b/113715895): When `allow_eager_ops` is on, for now there's no way
+ // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
// to populate a regular custom op. We need to find a way to fix this.
- if (allow_eager_ops) {
- custom_code = string(::tflite::kEagerCustomCodePrefix) +
+ if (allow_flex_ops) {
+ custom_code = string(::tflite::kFlexCustomCodePrefix) +
unsupported_op.tensorflow_op;
} else {
custom_code = unsupported_op.tensorflow_op;
@@ -101,11 +101,11 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
void LoadOperatorsMap(
const Model& model, OperatorsMap* operators_map,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
- bool allow_eager_ops) {
+ bool allow_flex_ops) {
// First find a list of unique operator types.
std::set<OperatorKey> keys;
for (const auto& op : model.operators) {
- keys.insert(GetOperatorKey(*op, ops_by_type, allow_eager_ops));
+ keys.insert(GetOperatorKey(*op, ops_by_type, allow_flex_ops));
}
// Now assign indices to them and fill in the map.
int index = 0;
@@ -216,7 +216,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
for (const auto& op : model.operators) {
const details::OperatorKey operator_key =
- GetOperatorKey(*op, ops_by_type, params.allow_eager_ops);
+ GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
int op_index = operators_map.at(operator_key);
int op_version = operator_key.version;
@@ -281,7 +281,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
}
int op_index = operators_map.at(
- GetOperatorKey(*op, ops_by_type, params.allow_eager_ops));
+ GetOperatorKey(*op, ops_by_type, params.allow_flex_ops));
auto tflite_op_it = ops_by_type.find(op->type);
BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
@@ -334,7 +334,7 @@ Offset<Vector<Offset<Buffer>>> ExportBuffers(
void Export(const Model& model, string* output_file_contents,
const ExportParams& params) {
- const auto ops_by_type = BuildOperatorByTypeMap(params.allow_eager_ops);
+ const auto ops_by_type = BuildOperatorByTypeMap(params.allow_flex_ops);
Export(model, output_file_contents, params, ops_by_type);
}
@@ -349,7 +349,7 @@ void Export(
details::OperatorsMap operators_map;
details::LoadOperatorsMap(model, &operators_map, ops_by_type,
- params.allow_eager_ops);
+ params.allow_flex_ops);
std::vector<const Array*> buffers_to_write;
Array empty_array;
@@ -388,7 +388,7 @@ void Export(
"the standard TensorFlow Lite runtime. If you have a custom "
"implementation for them you can disable this error with "
"--allow_custom_ops, or by setting allow_custom_ops=True "
- "when calling tf.contrib.lite.TocoConverter(). Here is a list "
+ "when calling tf.contrib.lite.TFLiteConverter(). Here is a list "
"of operators for which you will need custom implementations: "
<< absl::StrJoin(error_summary_final, ", ") << ".";
}
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index b070a38768..29d6de4049 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -26,7 +26,7 @@ namespace tflite {
// The parameters for exporting a TFLite model.
struct ExportParams {
bool allow_custom_ops = false;
- bool allow_eager_ops = false;
+ bool allow_flex_ops = false;
bool quantize_weights = false;
};
@@ -121,7 +121,7 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
void LoadOperatorsMap(
const Model& model, OperatorsMap* operators_map,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
- bool allow_eager_ops);
+ bool allow_flex_ops);
} // namespace details
} // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index 8d4d197c46..93882a91a7 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -105,7 +105,7 @@ TEST_F(ExportTest, LoadOperatorsMap) {
details::OperatorsMap operators;
const auto ops_by_type = BuildOperatorByTypeMap();
- // TODO(ycling): Add a test for allow_eager_ops.
+ // TODO(ycling): Add a test for allow_flex_ops.
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index ca2a6a19b3..9addbb81e7 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1160,8 +1160,8 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
class TensorFlowUnsupported : public BaseOperator {
public:
TensorFlowUnsupported(const string& name, OperatorType type,
- bool allow_eager_ops)
- : BaseOperator(name, type), allow_eager_ops_(allow_eager_ops) {}
+ bool allow_flex_ops)
+ : BaseOperator(name, type), allow_flex_ops_(allow_flex_ops) {}
Options Serialize(const Operator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
@@ -1177,9 +1177,9 @@ class TensorFlowUnsupported : public BaseOperator {
std::unique_ptr<Operator> Deserialize(
const BuiltinOptions* builtin_options,
const CustomOptions* custom_options) const override {
- // Deserializing Eager ops doesn't work now.
+ // Deserializing Flex ops doesn't work now.
// TODO(ycling): Revisit and decide if we should fix the flow for importing
- // TFLite models with Eager ops.
+ // TFLite models with Flex ops.
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
if (custom_options) {
auto flexbuffer_map =
@@ -1200,13 +1200,13 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>();
}
- if (allow_eager_ops_) {
+ if (allow_flex_ops_) {
fbb->Vector([&]() {
fbb->String(node_def.op());
fbb->String(op.tensorflow_node_def);
});
fbb->Finish();
- LOG(INFO) << "Writing eager op: " << node_def.op();
+ LOG(INFO) << "Writing flex op: " << node_def.op();
return std::unique_ptr<flexbuffers::Builder>(fbb.release());
}
@@ -1316,13 +1316,13 @@ class TensorFlowUnsupported : public BaseOperator {
}
private:
- const bool allow_eager_ops_;
+ const bool allow_flex_ops_;
};
namespace {
// Build a vector containing all the known operators.
std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
- bool allow_eager_ops = false) {
+ bool allow_flex_ops = false) {
std::vector<std::unique_ptr<BaseOperator>> ops;
using tensorflow::MakeUnique;
// Builtin Operators.
@@ -1434,7 +1434,7 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
"CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
ops.push_back(MakeUnique<TensorFlowUnsupported>(
- "TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported, allow_eager_ops));
+ "TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported, allow_flex_ops));
// There operators are supported by Toco, but not by TF Lite, and has no
// attributes.
@@ -1512,11 +1512,11 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
} // namespace
std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
- bool allow_eager_ops) {
+ bool allow_flex_ops) {
std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
std::vector<std::unique_ptr<BaseOperator>> ops =
- BuildOperatorList(allow_eager_ops);
+ BuildOperatorList(allow_flex_ops);
for (auto& op : ops) {
result[op->type()] = std::move(op);
}
@@ -1525,11 +1525,11 @@ std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
}
std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
- bool allow_eager_ops) {
+ bool allow_flex_ops) {
std::map<string, std::unique_ptr<BaseOperator>> result;
std::vector<std::unique_ptr<BaseOperator>> ops =
- BuildOperatorList(allow_eager_ops);
+ BuildOperatorList(allow_flex_ops);
for (auto& op : ops) {
result[op->name()] = std::move(op);
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index 702fb28ea6..13d9f6c49a 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -26,15 +26,15 @@ namespace tflite {
class BaseOperator;
// Return a map contained all know TF Lite Operators, keyed by their names.
-// TODO(ycling): The pattern to propagate parameters (e.g. allow_eager_ops)
+// TODO(ycling): The pattern to propagate parameters (e.g. allow_flex_ops)
// is ugly here. Consider refactoring.
std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
- bool allow_eager_ops = false);
+ bool allow_flex_ops = false);
// Return a map contained all know TF Lite Operators, keyed by the type of
// their tf.mini counterparts.
std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
- bool allow_eager_ops = false);
+ bool allow_flex_ops = false);
// These are the flatbuffer types for custom and builtin options.
using CustomOptions = flatbuffers::Vector<uint8_t>;
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index b6aebc0470..cff79776bc 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -167,11 +167,11 @@ bool ParseTocoFlagsFromCommandLineFlags(
"converted float model. Model size will be reduced and there will "
"be latency improvements (at the cost of accuracy)."),
// WARNING: Experimental interface, subject to change
- Flag("allow_eager_ops", parsed_flags.allow_eager_ops.bind(),
- parsed_flags.allow_eager_ops.default_value(), ""),
+ Flag("allow_flex_ops", parsed_flags.allow_flex_ops.bind(),
+ parsed_flags.allow_flex_ops.default_value(), ""),
// WARNING: Experimental interface, subject to change
- Flag("force_eager_ops", parsed_flags.force_eager_ops.bind(),
- parsed_flags.force_eager_ops.default_value(), "")};
+ Flag("force_flex_ops", parsed_flags.force_flex_ops.bind(),
+ parsed_flags.force_flex_ops.default_value(), "")};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
if (asked_for_help) {
@@ -266,15 +266,15 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
- READ_TOCO_FLAG(allow_eager_ops, FlagRequirement::kNone);
- READ_TOCO_FLAG(force_eager_ops, FlagRequirement::kNone);
+ READ_TOCO_FLAG(allow_flex_ops, FlagRequirement::kNone);
+ READ_TOCO_FLAG(force_flex_ops, FlagRequirement::kNone);
- if (parsed_toco_flags.force_eager_ops.value() &&
- !parsed_toco_flags.allow_eager_ops.value()) {
- // TODO(ycling): Consider to enforce `allow_eager_ops` when
- // `force_eager_ops` is true.
- LOG(WARNING) << "--force_eager_ops should always be used with "
- "--allow_eager_ops.";
+ if (parsed_toco_flags.force_flex_ops.value() &&
+ !parsed_toco_flags.allow_flex_ops.value()) {
+ // TODO(ycling): Consider to enforce `allow_flex_ops` when
+ // `force_flex_ops` is true.
+ LOG(WARNING) << "--force_flex_ops should always be used with "
+ "--allow_flex_ops.";
}
// Deprecated flag handling.
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index 53d60fed05..ca3e64485e 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -190,16 +190,16 @@ message TocoFlags {
// (at the cost of accuracy).
optional bool post_training_quantize = 26 [default = false];
- // When enabled, unsupported ops will be converted to TFLite Eager ops.
+ // When enabled, unsupported ops will be converted to TFLite Flex ops.
// TODO(ycling): Consider to rename the following 2 flags and don't call it
- // "Eager".
- // `allow_eager_ops` should always be used with `allow_custom_ops`.
+ // "Flex".
+ // `allow_flex_ops` should always be used with `allow_custom_ops`.
// WARNING: Experimental interface, subject to change
- optional bool allow_eager_ops = 27 [default = false];
+ optional bool allow_flex_ops = 27 [default = false];
- // When enabled, all TensorFlow ops will be converted to TFLite Eager
- // ops directly. This will force `allow_eager_ops` to true.
- // `force_eager_ops` should always be used with `allow_eager_ops`.
+ // When enabled, all TensorFlow ops will be converted to TFLite Flex
+ // ops directly. This will force `allow_flex_ops` to true.
+ // `force_flex_ops` should always be used with `allow_flex_ops`.
// WARNING: Experimental interface, subject to change
- optional bool force_eager_ops = 28 [default = false];
+ optional bool force_flex_ops = 28 [default = false];
}
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index a08b02485f..106494f354 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -198,7 +198,7 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
: (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
tf_import_flags.import_all_ops_as_unsupported =
- toco_flags.force_eager_ops();
+ toco_flags.force_flex_ops();
model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
input_file_contents);
@@ -409,9 +409,9 @@ void Export(const TocoFlags& toco_flags, const Model& model,
case TFLITE: {
toco::tflite::ExportParams params;
- // Always allow custom ops when eager ops are allowed.
- if (toco_flags.force_eager_ops() || toco_flags.allow_eager_ops()) {
- params.allow_eager_ops = true;
+ // Always allow custom ops when flex ops are allowed.
+ if (toco_flags.force_flex_ops() || toco_flags.allow_flex_ops()) {
+ params.allow_flex_ops = true;
params.allow_custom_ops = true;
} else if (allow_custom_ops) {
params.allow_custom_ops = true;
diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD
index dc97d22401..502e181139 100644
--- a/tensorflow/contrib/lite/tools/benchmark/BUILD
+++ b/tensorflow/contrib/lite/tools/benchmark/BUILD
@@ -36,11 +36,11 @@ cc_binary(
)
cc_binary(
- name = "benchmark_model_plus_eager",
+ name = "benchmark_model_plus_flex",
srcs = [
"benchmark_main.cc",
],
- copts = common_copts + ["-DTFLITE_EXTENDED"],
+ copts = common_copts + ["-DTFLITE_FLEX"],
linkopts = tflite_linkopts() + select({
"//tensorflow:android": [
"-pie", # Android 5.0 and later supports only PIE
@@ -49,7 +49,7 @@ cc_binary(
"//conditions:default": [],
}),
deps = [
- ":benchmark_tflite_model_plus_eager_lib",
+ ":benchmark_tflite_model_plus_flex_lib",
":logging",
],
)
@@ -111,19 +111,19 @@ cc_library(
)
cc_library(
- name = "benchmark_tflite_model_plus_eager_lib",
+ name = "benchmark_tflite_model_plus_flex_lib",
srcs = [
"benchmark_tflite_model.cc",
"logging.h",
],
hdrs = ["benchmark_tflite_model.h"],
- copts = common_copts + ["-DTFLITE_EXTENDED"],
+ copts = common_copts + ["-DTFLITE_FLEX"],
deps = [
":benchmark_model_lib",
":logging",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
- "//tensorflow/contrib/lite/delegates/eager:delegate",
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/profiling:profile_summarizer",
],
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index ef4f0fa80d..463d5993f4 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -23,9 +23,9 @@ limitations under the License.
#include <unordered_set>
#include <vector>
-#ifdef TFLITE_EXTENDED
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
-#endif // TFLITE_EXTENDED
+#ifdef TFLITE_FLEX
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
+#endif // TFLITE_FLEX
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/op_resolver.h"
@@ -305,14 +305,14 @@ void BenchmarkTfLiteModel::Init() {
interpreter->UseNNAPI(use_nnapi);
-#ifdef TFLITE_EXTENDED
- TFLITE_LOG(INFO) << "Instantiating Eager Delegate";
- delegate_ = EagerDelegate::Create();
+#ifdef TFLITE_FLEX
+ TFLITE_LOG(INFO) << "Instantiating Flex Delegate";
+ delegate_ = FlexDelegate::Create();
if (delegate_) {
interpreter->ModifyGraphWithDelegate(delegate_.get(),
/*allow_dynamic_tensors=*/true);
}
-#endif // TFLITE_EXTENDED
+#endif // TFLITE_FLEX
auto interpreter_inputs = interpreter->inputs();
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index 8541512bc8..b091e18a29 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -20,9 +20,9 @@ limitations under the License.
#include <string>
#include <vector>
-#ifdef TFLITE_EXTENDED
-#include "tensorflow/contrib/lite/delegates/eager/delegate.h"
-#endif // TFLITE_EXTENDED
+#ifdef TFLITE_FLEX
+#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
+#endif // TFLITE_FLEX
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h"
@@ -73,9 +73,9 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
void PrepareInputsAndOutputs() override;
private:
-#ifdef TFLITE_EXTENDED
- std::unique_ptr<EagerDelegate> delegate_;
-#endif // TFLITE_EXTENDED
+#ifdef TFLITE_FLEX
+ std::unique_ptr<FlexDelegate> delegate_;
+#endif // TFLITE_FLEX
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
std::vector<InputLayerInfo> inputs;
diff --git a/tensorflow/contrib/lite/tools/make/download_dependencies.sh b/tensorflow/contrib/lite/tools/make/download_dependencies.sh
index 29afa45133..3570f9a38d 100755
--- a/tensorflow/contrib/lite/tools/make/download_dependencies.sh
+++ b/tensorflow/contrib/lite/tools/make/download_dependencies.sh
@@ -35,7 +35,7 @@ GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.g
ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)"
NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip"
FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz"
-FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/v1.8.0.zip"
+FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz"
FFT2D_URL="https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz"
# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,
diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/contrib/lite/util.cc
index 7950653da9..6aa35b5227 100644
--- a/tensorflow/contrib/lite/util.cc
+++ b/tensorflow/contrib/lite/util.cc
@@ -18,9 +18,9 @@ limitations under the License.
namespace tflite {
-bool IsEagerOp(const char* custom_name) {
- return custom_name && strncmp(custom_name, kEagerCustomCodePrefix,
- strlen(kEagerCustomCodePrefix)) == 0;
+bool IsFlexOp(const char* custom_name) {
+ return custom_name && strncmp(custom_name, kFlexCustomCodePrefix,
+ strlen(kFlexCustomCodePrefix)) == 0;
}
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input) {
diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h
index 6d81f844f8..31292a6f81 100644
--- a/tensorflow/contrib/lite/util.h
+++ b/tensorflow/contrib/lite/util.h
@@ -26,15 +26,15 @@ limitations under the License.
namespace tflite {
-// The prefix of Eager op custom code.
+// The prefix of Flex op custom code.
// This will be matched agains the `custom_code` field in `OperatorCode`
// Flatbuffer Table.
// WARNING: This is an experimental API and subject to change.
-constexpr char kEagerCustomCodePrefix[] = "Eager";
+constexpr char kFlexCustomCodePrefix[] = "Flex";
// Checks whether the prefix of the custom name indicates the operation is an
-// Eager operation.
-bool IsEagerOp(const char* custom_name);
+// Flex operation.
+bool IsFlexOp(const char* custom_name);
// Converts a `std::vector` to a `TfLiteIntArray`. The caller takes ownership
// of the returned pointer.
diff --git a/tensorflow/contrib/lite/util_test.cc b/tensorflow/contrib/lite/util_test.cc
index c5c1709f1d..25f3aded71 100644
--- a/tensorflow/contrib/lite/util_test.cc
+++ b/tensorflow/contrib/lite/util_test.cc
@@ -41,14 +41,14 @@ TEST(ConvertVectorToTfLiteIntArray, TestWithEmptyVector) {
TfLiteIntArrayFree(output);
}
-TEST(UtilTest, IsEagerOp) {
- EXPECT_TRUE(IsEagerOp("Eager"));
- EXPECT_TRUE(IsEagerOp("EagerOp"));
- EXPECT_FALSE(IsEagerOp("eager"));
- EXPECT_FALSE(IsEagerOp("Eage"));
- EXPECT_FALSE(IsEagerOp("OpEager"));
- EXPECT_FALSE(IsEagerOp(nullptr));
- EXPECT_FALSE(IsEagerOp(""));
+TEST(UtilTest, IsFlexOp) {
+ EXPECT_TRUE(IsFlexOp("Flex"));
+ EXPECT_TRUE(IsFlexOp("FlexOp"));
+ EXPECT_FALSE(IsFlexOp("flex"));
+ EXPECT_FALSE(IsFlexOp("Fle"));
+ EXPECT_FALSE(IsFlexOp("OpFlex"));
+ EXPECT_FALSE(IsFlexOp(nullptr));
+ EXPECT_FALSE(IsFlexOp(""));
}
} // namespace
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index d962a5e12d..36125c198e 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -133,7 +133,8 @@ $(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*benchmark*.cc) \
$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*benchmark*.cc) \
$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*benchmark*.cc) \
$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*/*benchmark*.cc) \
-tensorflow/contrib/makefile/downloads/absl/absl/synchronization/internal/mutex_nonprod.cc
+tensorflow/contrib/makefile/downloads/absl/absl/synchronization/internal/mutex_nonprod.cc \
+tensorflow/contrib/makefile/downloads/absl/absl/hash/internal/print_hash_of.cc
ABSL_CC_SRCS := $(filter-out $(ABSL_CC_EXCLUDE_SRCS), $(ABSL_CC_ALL_SRCS))
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 08de54b8e1..91af933cff 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -91,6 +91,8 @@ tensorflow/core/kernels/cwise_op_square.cc
tensorflow/core/kernels/cwise_op_squared_difference.cc
tensorflow/core/kernels/cwise_op_sub.cc
tensorflow/core/kernels/cwise_op_tanh.cc
+tensorflow/core/kernels/cwise_op_xdivy.cc
+tensorflow/core/kernels/cwise_op_xlogy.cc
tensorflow/core/kernels/cwise_ops_common.cc
tensorflow/core/kernels/data_format_ops.cc
tensorflow/core/kernels/decode_bmp_op.cc
@@ -253,6 +255,7 @@ tensorflow/core/kernels/strided_slice_op_inst_5.cc
tensorflow/core/kernels/strided_slice_op_inst_6.cc
tensorflow/core/kernels/strided_slice_op_inst_7.cc
tensorflow/core/kernels/string_join_op.cc
+tensorflow/core/kernels/string_util.cc
tensorflow/core/kernels/tensor_array.cc
tensorflow/core/kernels/tensor_array_ops.cc
tensorflow/core/kernels/tile_functor_cpu.cc
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 955b83b44d..fc64f343ab 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2069,11 +2069,11 @@ class StreamingDynamicAUCTest(test.TestCase):
num_batches = 100
labels = np.array([])
predictions = np.array([])
- tf_labels = variables.Variable(
+ tf_labels = variables.VariableV1(
array_ops.ones(batch_size, dtypes_lib.int32),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.int32)
- tf_predictions = variables.Variable(
+ tf_predictions = variables.VariableV1(
array_ops.ones(batch_size),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.float32)
@@ -2133,15 +2133,15 @@ class StreamingDynamicAUCTest(test.TestCase):
labels = np.array([])
predictions = np.array([])
weights = np.array([])
- tf_labels = variables.Variable(
+ tf_labels = variables.VariableV1(
array_ops.ones(batch_size, dtypes_lib.int32),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.int32)
- tf_predictions = variables.Variable(
+ tf_predictions = variables.VariableV1(
array_ops.ones(batch_size),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.float32)
- tf_weights = variables.Variable(
+ tf_weights = variables.VariableV1(
array_ops.ones(batch_size),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.float32)
@@ -2311,10 +2311,11 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
num_batches = 100
labels = np.array([])
predictions = np.array([])
- tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32),
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- dtype=dtypes_lib.int32)
- tf_predictions = variables.Variable(
+ tf_labels = variables.VariableV1(
+ array_ops.ones(batch_size, dtypes_lib.int32),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.int32)
+ tf_predictions = variables.VariableV1(
array_ops.ones(batch_size),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.float32)
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index a81abac2fa..67e58ff15d 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -247,7 +247,8 @@ class Pruning(object):
# Stores the tensorflow sparsity variable.
# Built using self._setup_sparsity() or provided externally
- self._sparsity = sparsity if sparsity else self._setup_sparsity()
+ self._sparsity = (sparsity
+ if sparsity is not None else self._setup_sparsity())
# List of tensorflow assignments ops for new masks and thresholds
self._assign_ops = []
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
index cd3d8e76bb..1b6da5ce2b 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_test.py
@@ -45,7 +45,7 @@ class PruningHParamsTest(test.TestCase):
# Add global step variable to the graph
self.global_step = training_util.get_or_create_global_step()
# Add sparsity
- self.sparsity = variables.Variable(0.5, name="sparsity")
+ self.sparsity = variables.VariableV1(0.5, name="sparsity")
# Parse hparams
self.pruning_hparams = pruning.get_pruning_hparams().parse(
self.TEST_HPARAMS)
@@ -88,7 +88,7 @@ class PruningTest(test.TestCase):
width = 10
height = 20
with self.cached_session():
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_normal([width, height], stddev=1), name="weights")
masked_weights = pruning.apply_mask(weights,
variable_scope.get_variable_scope())
@@ -99,10 +99,10 @@ class PruningTest(test.TestCase):
def testUpdateSingleMask(self):
with self.cached_session() as session:
- weights = variables.Variable(
+ weights = variables.VariableV1(
math_ops.linspace(1.0, 100.0, 100), name="weights")
masked_weights = pruning.apply_mask(weights)
- sparsity = variables.Variable(0.5, name="sparsity")
+ sparsity = variables.VariableV1(0.5, name="sparsity")
p = pruning.Pruning(sparsity=sparsity)
p._spec.threshold_decay = 0.0
mask_update_op = p.mask_update_op()
@@ -115,8 +115,8 @@ class PruningTest(test.TestCase):
def _blockMasking(self, hparams, weights, expected_mask):
- threshold = variables.Variable(0.0, name="threshold")
- sparsity = variables.Variable(0.5, name="sparsity")
+ threshold = variables.VariableV1(0.0, name="threshold")
+ sparsity = variables.VariableV1(0.5, name="sparsity")
test_spec = ",".join(hparams)
pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
@@ -169,7 +169,7 @@ class PruningTest(test.TestCase):
partitioner = partitioned_variables.variable_axis_size_partitioner(40)
with self.cached_session() as session:
with variable_scope.variable_scope("", partitioner=partitioner):
- sparsity = variables.Variable(0.5, name="Sparsity")
+ sparsity = variables.VariableV1(0.5, name="Sparsity")
weights = variable_scope.get_variable(
"weights", initializer=math_ops.linspace(1.0, 100.0, 100))
masked_weights = pruning.apply_mask(
@@ -190,10 +190,10 @@ class PruningTest(test.TestCase):
]
test_spec = ",".join(param_list)
pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
- weights = variables.Variable(
+ weights = variables.VariableV1(
math_ops.linspace(1.0, 100.0, 100), name="weights")
masked_weights = pruning.apply_mask(weights)
- sparsity = variables.Variable(0.00, name="sparsity")
+ sparsity = variables.VariableV1(0.00, name="sparsity")
# Set up pruning
p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
p._spec.threshold_decay = 0.0
@@ -222,11 +222,11 @@ class PruningTest(test.TestCase):
pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
with variable_scope.variable_scope("layer1"):
- w1 = variables.Variable(
+ w1 = variables.VariableV1(
math_ops.linspace(1.0, 100.0, 100), name="weights")
_ = pruning.apply_mask(w1)
with variable_scope.variable_scope("layer2"):
- w2 = variables.Variable(
+ w2 = variables.VariableV1(
math_ops.linspace(1.0, 100.0, 100), name="weights")
_ = pruning.apply_mask(w2)
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index f4ac70eb1a..6a67c6295d 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -377,6 +377,11 @@ py_test(
size = "large",
srcs = ["python/training/shampoo_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "noasan", # b/116875897
+ "nomsan",
+ "notsan",
+ ],
deps = [
":opt_py",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/opt/python/training/addsign_test.py b/tensorflow/contrib/opt/python/training/addsign_test.py
index 628a735e72..6150fa117f 100644
--- a/tensorflow/contrib/opt/python/training/addsign_test.py
+++ b/tensorflow/contrib/opt/python/training/addsign_test.py
@@ -80,9 +80,9 @@ class AddSignTest(test.TestCase):
global_step = resource_variable_ops.ResourceVariable(
0, trainable=False)
else:
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
- global_step = variables.Variable(
+ var0 = variables.VariableV1(var0_np)
+ var1 = variables.VariableV1(var1_np)
+ global_step = variables.VariableV1(
0, trainable=False)
grads0 = constant_op.constant(grads0_np)
grads1 = constant_op.constant(grads1_np)
@@ -183,9 +183,9 @@ class AddSignTest(test.TestCase):
global_step = resource_variable_ops.ResourceVariable(
0, trainable=False)
else:
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
- global_step = variables.Variable(
+ var0 = variables.VariableV1(var0_np)
+ var1 = variables.VariableV1(var1_np)
+ global_step = variables.VariableV1(
0, trainable=False)
grads0_np_indices = np.array([0, 1], dtype=np.int32)
grads0 = ops.IndexedSlices(
diff --git a/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py b/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py
index 53232082e1..0a69096768 100644
--- a/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/drop_stale_gradient_optimizer_test.py
@@ -61,8 +61,8 @@ def _get_workers(num_workers, staleness):
graph = ops.Graph()
with graph.as_default():
global_step = training_util.create_global_step()
- var_0 = variables.Variable(0.0, name='v0')
- var_1 = variables.Variable(1.0, name='v1')
+ var_0 = variables.VariableV1(0.0, name='v0')
+ var_1 = variables.VariableV1(1.0, name='v1')
compute_gradients_queue = data_flow_ops.FIFOQueue(
-1, global_step.dtype.base_dtype, shapes=(),
name='compute_gradients_queue', shared_name='compute_gradients_queue')
diff --git a/tensorflow/contrib/opt/python/training/external_optimizer_test.py b/tensorflow/contrib/opt/python/training/external_optimizer_test.py
index 9997103016..70c5f8ff19 100644
--- a/tensorflow/contrib/opt/python/training/external_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/external_optimizer_test.py
@@ -69,9 +69,9 @@ class TestCase(test.TestCase):
class ExternalOptimizerInterfaceTest(TestCase):
def test_optimize(self):
- scalar = variables.Variable(random_ops.random_normal([]), 'scalar')
- vector = variables.Variable(random_ops.random_normal([2]), 'vector')
- matrix = variables.Variable(random_ops.random_normal([2, 3]), 'matrix')
+ scalar = variables.VariableV1(random_ops.random_normal([]), 'scalar')
+ vector = variables.VariableV1(random_ops.random_normal([2]), 'vector')
+ matrix = variables.VariableV1(random_ops.random_normal([2, 3]), 'matrix')
minimum_location = constant_op.constant(np.arange(9), dtype=dtypes.float32)
@@ -96,7 +96,7 @@ class ExternalOptimizerInterfaceTest(TestCase):
def test_callbacks(self):
vector_val = np.array([7., -2.], dtype=np.float32)
- vector = variables.Variable(vector_val, 'vector')
+ vector = variables.VariableV1(vector_val, 'vector')
minimum_location_val = np.arange(2)
minimum_location = constant_op.constant(
@@ -160,7 +160,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
rtol=1e-5,
atol=1e-5,
dimension=5):
- x = variables.Variable(array_ops.zeros(dimension))
+ x = variables.VariableV1(array_ops.zeros(dimension))
optimizer = external_optimizer.ScipyOptimizerInterface(
self._objective(x), method=method, options=options)
@@ -173,7 +173,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
def test_unconstrained(self):
dimension = 5
- x = variables.Variable(array_ops.zeros(dimension))
+ x = variables.VariableV1(array_ops.zeros(dimension))
optimizer = external_optimizer.ScipyOptimizerInterface(self._objective(x))
with self.cached_session() as sess:
@@ -230,7 +230,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
def test_nonlinear_programming(self):
vector_initial_value = [7., 7.]
- vector = variables.Variable(vector_initial_value, 'vector')
+ vector = variables.VariableV1(vector_initial_value, 'vector')
# Make norm as small as possible.
loss = math_ops.reduce_sum(math_ops.square(vector))
@@ -249,7 +249,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
def test_scalar_bounds(self):
vector_initial_value = [7., 7.]
- vector = variables.Variable(vector_initial_value, 'vector')
+ vector = variables.VariableV1(vector_initial_value, 'vector')
# Make norm as small as possible.
loss = math_ops.reduce_sum(math_ops.square(vector))
@@ -267,7 +267,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
def test_vector_bounds(self):
vector_initial_value = [7., 7.]
- vector = variables.Variable(vector_initial_value, 'vector')
+ vector = variables.VariableV1(vector_initial_value, 'vector')
# Make norm as small as possible.
loss = math_ops.reduce_sum(math_ops.square(vector))
@@ -287,7 +287,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
# after running optimizer.minimize().
# Bug reference: b/64065260
vector_initial_value = [7., 7.]
- vector = variables.Variable(vector_initial_value, 'vector')
+ vector = variables.VariableV1(vector_initial_value, 'vector')
loss = math_ops.reduce_sum(math_ops.square(vector))
optimizer = external_optimizer.ScipyOptimizerInterface(
@@ -301,7 +301,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
def test_callbacks(self):
vector_val = np.array([7., -2.], dtype=np.float32)
- vector = variables.Variable(vector_val, 'vector')
+ vector = variables.VariableV1(vector_val, 'vector')
minimum_location_val = np.arange(2)
minimum_location = constant_op.constant(
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
index b1fc50a21f..a25455e95d 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
@@ -110,10 +110,11 @@ def _get_workers(num_workers, steps, workers):
class ModelAverageOptimizerTest(test.TestCase):
+
def _run(self, train_op, sess):
sess.run(train_op)
- def test1Workers2Period(self):
+ def disabled_test1Workers2Period(self):
num_workers = 2
steps = 2
num_ps = 1
diff --git a/tensorflow/contrib/opt/python/training/powersign_test.py b/tensorflow/contrib/opt/python/training/powersign_test.py
index 0bcf5d230a..1cf9901dc0 100644
--- a/tensorflow/contrib/opt/python/training/powersign_test.py
+++ b/tensorflow/contrib/opt/python/training/powersign_test.py
@@ -81,9 +81,9 @@ class PowerSignTest(test.TestCase):
global_step = resource_variable_ops.ResourceVariable(
0, trainable=False)
else:
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
- global_step = variables.Variable(
+ var0 = variables.VariableV1(var0_np)
+ var1 = variables.VariableV1(var1_np)
+ global_step = variables.VariableV1(
0, trainable=False)
grads0 = constant_op.constant(grads0_np)
grads1 = constant_op.constant(grads1_np)
@@ -188,9 +188,9 @@ class PowerSignTest(test.TestCase):
global_step = resource_variable_ops.ResourceVariable(
0, trainable=False)
else:
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
- global_step = variables.Variable(
+ var0 = variables.VariableV1(var0_np)
+ var1 = variables.VariableV1(var1_np)
+ global_step = variables.VariableV1(
0, trainable=False)
grads0_np_indices = np.array([0, 1], dtype=np.int32)
grads0 = ops.IndexedSlices(
diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py
index 05bcf2cfa3..a2fd8fbd87 100644
--- a/tensorflow/contrib/opt/python/training/shampoo_test.py
+++ b/tensorflow/contrib/opt/python/training/shampoo_test.py
@@ -54,9 +54,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -105,9 +105,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size[0], size[1])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -164,9 +164,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size[0], size[1], size[2])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -254,9 +254,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -310,9 +310,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size[0], size[1])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -383,9 +383,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(sample_size_2, size[1])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = ops.IndexedSlices(
constant_op.constant(grad_np, dtype=dtypes.float32),
@@ -463,9 +463,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(sample_size, size[1], size[2])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = ops.IndexedSlices(
constant_op.constant(grad_np, dtype=dtypes.float32),
@@ -533,9 +533,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
gbar_weight = 0.1
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -628,9 +628,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g3 = np.zeros_like(mat_g3_a)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = array_ops.placeholder(dtypes.float32, shape=size)
@@ -705,9 +705,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g3 = np.zeros_like(mat_g3_a)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = array_ops.placeholder(dtypes.float32, shape=size)
diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD
index 72ea777ca7..d50b52b8ff 100644
--- a/tensorflow/contrib/predictor/BUILD
+++ b/tensorflow/contrib/predictor/BUILD
@@ -27,7 +27,7 @@ py_library(
":contrib_estimator_predictor",
":core_estimator_predictor",
":saved_model_predictor",
- "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:estimator_py",
],
)
@@ -89,7 +89,6 @@ py_library(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
- "//tensorflow/python/estimator",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/saved_model:signature_constants",
],
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index e5790a6e13..7575b1b6cd 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -418,10 +418,11 @@ def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor,
transpose_b=layer_op.get_attr('transpose_b'),
name=new_layer_name)
elif layer_op.type == 'DepthwiseConv2dNative':
+ # We don't copy dilation rate because we reuse the input SpaceToBatch
+ # and create our own BatchToSpace operation below.
conv = nn.depthwise_conv2d(
input_tensor,
weight_tensor,
- rate=layer_op.get_attr('dilations'),
strides=layer_op.get_attr('strides'),
padding=layer_op.get_attr('padding'),
name=new_layer_name)
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 5e63d33db8..afb9de8370 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -461,8 +461,8 @@ class _LayerMatch(object):
return self._bias_add_op
-def _FollowedByFakeQuant(tensor):
- """Returns True if the tensor is followed by a FakeQuant."""
+def _GetFollowingFakeQuantOp(tensor):
+ """Returns the following FakeQuant op if it exists else None."""
fake_quant_ops = set([
'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
'FakeQuantWithMinMaxVarsPerChannel'
@@ -472,11 +472,11 @@ def _FollowedByFakeQuant(tensor):
while consumers:
c = consumers.pop()
if c.type in fake_quant_ops:
- return True
+ return c
elif c.type in pass_through_ops:
for output in c.outputs:
consumers.extend(output.consumers())
- return False
+ return None
def _InsertQuantOp(context,
@@ -559,44 +559,77 @@ def _InsertQuantOp(context,
# Prevent ops from being quantized multiple times. Bypass ops can sometimes
# overlap between multiple matches, so we need to ensure that we don't
# add duplicate FakeQuant operations.
- if _FollowedByFakeQuant(inputs):
- return
-
- if moving_avg:
- quant = (
- quant_ops.MovingAvgQuantize(
- inputs,
- init_min=init_min,
- init_max=init_max,
- ema_decay=ema_decay,
- is_training=is_training,
- num_bits=bits,
- narrow_range=narrow_range,
- vars_collection=vars_collection,
- name_prefix=name_prefix))
+ fake_quant_op = _GetFollowingFakeQuantOp(inputs)
+
+ # If we find that we are attempting to insert a fake quant op following
+ # a fake quant, we skip inserting a fake quant op
+
+ if fake_quant_op is None:
+ if moving_avg:
+ quant = (
+ quant_ops.MovingAvgQuantize(
+ inputs,
+ init_min=init_min,
+ init_max=init_max,
+ ema_decay=ema_decay,
+ is_training=is_training,
+ num_bits=bits,
+ narrow_range=narrow_range,
+ vars_collection=vars_collection,
+ name_prefix=name_prefix))
+ else:
+ quant = (
+ quant_ops.LastValueQuantize(
+ inputs,
+ init_min=init_min,
+ init_max=init_max,
+ is_training=is_training,
+ num_bits=bits,
+ narrow_range=narrow_range,
+ vars_collection=vars_collection,
+ name_prefix=name_prefix))
+
+ if quant_delay and quant_delay > 0:
+ activate_quant = math_ops.greater_equal(
+ common.CreateOrGetQuantizationStep(),
+ quant_delay,
+ name=name_prefix + '/activate_quant')
+ quant = control_flow_ops.cond(
+ activate_quant,
+ lambda: quant,
+ lambda: inputs,
+ name=name_prefix + '/delayed_quant')
else:
- quant = (
- quant_ops.LastValueQuantize(
- inputs,
- init_min=init_min,
- init_max=init_max,
- is_training=is_training,
- num_bits=bits,
- narrow_range=narrow_range,
- vars_collection=vars_collection,
- name_prefix=name_prefix))
-
- if quant_delay and quant_delay > 0:
- activate_quant = math_ops.greater_equal(
- common.CreateOrGetQuantizationStep(),
- quant_delay,
- name=name_prefix + '/activate_quant')
- quant = control_flow_ops.cond(
- activate_quant,
- lambda: quant,
- lambda: inputs,
- name=name_prefix + '/delayed_quant')
-
+ # If a fake quant op is present already, make sure that
+ # any downstream use of the tensor reroutes to the appropriate quantized
+ # tensor. If there is no quant_delay, this is simply the output of the
+ # fake quant op. If there is a quant delay, we reroute to the output
+ # of the delayed quant operation, which inserts quantization only after
+ # a specified quant_delay
+
+ quant = fake_quant_op.outputs[0]
+ if quant_delay and quant_delay > 0:
+ name_prefix = '/'.join(quant.name.split('/')[:-1])
+ quant = quant.graph.get_tensor_by_name(name_prefix +
+ '/delayed_quant/Merge:0')
+ pruned_consumer_set = set()
+ for consumer in consumers:
+ fake_quant_dest_op = _GetFollowingFakeQuantOp(consumer.outputs[0])
+ if (fake_quant_dest_op is None or
+ fake_quant_dest_op.name != fake_quant_op.name):
+ pruned_consumer_set.add(consumer)
+ consumers = pruned_consumer_set
+
+ # If we have
+ # input->pass_through->fake_quant
+ # there is nothing to reroute.
+ #
+ # If we have
+ # input-> pass_through->fake_quant
+ # |-> consumer
+ # Then we reroute such that:
+ # input-> pass_through->fake_quant
+ # |-> consumer
if consumers:
tensors_modified_count = common.RerouteTensor(
quant, inputs, can_modify=consumers)
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index e80d2183a6..a9fc6c3c61 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import template
from tensorflow.python.platform import googletest
@@ -306,6 +307,42 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
# No ops should be inserted or removed.
self.assertEqual(op_names_before_rewrite, op_names_after_rewrite)
+ def testWithSharedWeights(self):
+
+ self._RunTestOverAllRewrites(self._TestWithSharedWeights)
+ self._RunTestOverTrainingRewrites(self._TestRewriteWithSharedWeights)
+
+ def _TestRewriteWithSharedWeights(self, rewrite_fn, quant_delay=1):
+ self._TestWithSharedWeights(rewrite_fn, quant_delay)
+
+ def _TestWithSharedWeights(self, rewrite_fn, quant_delay=None):
+ with ops.Graph().as_default() as g:
+ conv = template.make_template('shared_weights_conv', self._ConvLayer)
+ conv()
+ conv()
+ if quant_delay is None:
+ rewrite_fn()
+ else:
+ rewrite_fn(quant_delay=quant_delay)
+
+ conv_ops = [op for op in g.get_operations() if op.type == 'Conv2D']
+ weights_quants = [
+ op for op in g.get_operations()
+ if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars'
+ ]
+ # Check that the shared weights variable is not quantized multiple times
+ self.assertTrue(len(weights_quants) == 1)
+ weights_quant_tensor = weights_quants[0].outputs[0]
+ if quant_delay:
+ delayed_weights_quants = [
+ op for op in g.get_operations()
+ if 'weights_quant' in op.name and op.type == 'Merge'
+ ]
+ self.assertTrue(len(delayed_weights_quants) == 1)
+ weights_quant_tensor = delayed_weights_quants[0].outputs[0]
+ # Check that the Conv2D operations get the quantized weights
+ self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops))
+
def _ConvLayer(
self, input_tensor=None, scope='test', pre_activation_bypass=False,
post_activation_bypass=False):
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index bf699db3ed..f31ad53d3c 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -163,8 +163,8 @@ class TestStateSaverWithCounters(TestStateSaver):
def __init__(self, batch_size, state_size):
super(TestStateSaverWithCounters, self).__init__(batch_size, state_size)
- self._num_state_calls = variables_lib.Variable(0)
- self._num_save_state_calls = variables_lib.Variable(0)
+ self._num_state_calls = variables_lib.VariableV1(0)
+ self._num_save_state_calls = variables_lib.VariableV1(0)
def state(self, name):
with ops_lib.control_dependencies(
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index 4ca5274b2e..291ff83791 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -92,10 +92,7 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:saver",
"//tensorflow/python:util",
- "//tensorflow/python/estimator",
- "//tensorflow/python/estimator:export",
- "//tensorflow/python/estimator:keras",
- "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/keras:engine",
"//tensorflow/python/saved_model",
],
@@ -111,6 +108,7 @@ py_test(
":keras_saved_model",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
+ "//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/keras",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
diff --git a/tensorflow/contrib/session_bundle/exporter_test.py b/tensorflow/contrib/session_bundle/exporter_test.py
index 86df425da0..68419ffea0 100644
--- a/tensorflow/contrib/session_bundle/exporter_test.py
+++ b/tensorflow/contrib/session_bundle/exporter_test.py
@@ -64,10 +64,10 @@ class SaveRestoreShardedTest(test.TestCase):
# v2 is an unsaved variable derived from v0 and v1. It is used to
# exercise the ability to run an init op when restoring a graph.
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(10, name="v0")
+ v0 = variables.VariableV1(10, name="v0")
with sess.graph.device("/cpu:1"):
- v1 = variables.Variable(20, name="v1")
- v2 = variables.Variable(1, name="v2", trainable=False, collections=[])
+ v1 = variables.VariableV1(20, name="v1")
+ v2 = variables.VariableV1(1, name="v2", trainable=False, collections=[])
assign_v2 = state_ops.assign(v2, math_ops.add(v0, v1))
init_op = control_flow_ops.group(assign_v2, name="init_op")
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 00c855daa3..398ac314f4 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -518,7 +518,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":client_lib",
- "//tensorflow/contrib/estimator:head",
+ "//tensorflow/contrib/estimator:estimator_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index 0042d37acd..6e3bfbb9bd 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -446,6 +446,10 @@ class TensorForestEstimator(estimator.Estimator):
Returns:
A `TensorForestEstimator` instance.
"""
+ # Override default number of trainers if config is provided.
+ if num_trainers == 1 and config is not None:
+ num_trainers = max(1, config.num_worker_replicas)
+
super(TensorForestEstimator, self).__init__(
model_fn=get_model_fn(
params.fill(),
@@ -564,6 +568,10 @@ class MultiForestMultiHeadEstimator(estimator.Estimator):
local_eval=False):
"""See TensorForestEstimator.__init__."""
model_fns = []
+ # Override default number of trainers if config is provided.
+ if num_trainers == 1 and config is not None:
+ num_trainers = max(1, config.num_worker_replicas)
+
for i in range(len(params_list)):
params = params_list[i].fill()
model_fns.append(
@@ -709,6 +717,11 @@ class CoreTensorForestEstimator(core_estimator.Estimator):
Returns:
A `TensorForestEstimator` instance.
"""
+ # Override default number of trainers if config is provided.
+ if num_trainers == 1 and config is not None:
+ num_trainers = max(1, config.num_worker_replicas)
+ if trainer_id == 0 and config is not None:
+ trainer_id = config.global_id_in_cluster
super(CoreTensorForestEstimator, self).__init__(
model_fn=get_model_fn(
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
index 1c4e18dbda..0b02bdcb50 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
@@ -27,7 +27,7 @@ from tensorflow.python.platform import googletest
class ScatterAddNdimTest(test_util.TensorFlowTestCase):
def test1dim(self):
- input_data = variables.Variable(
+ input_data = variables.VariableV1(
[1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.])
indices = [[1], [10]]
updates = [100., 200.]
@@ -40,8 +40,8 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
input_data.eval())
def test3dim(self):
- input_data = variables.Variable([[[1., 2., 3.], [4., 5., 6.]],
- [[7., 8., 9.], [10., 11., 12.]]])
+ input_data = variables.VariableV1([[[1., 2., 3.], [4., 5., 6.]],
+ [[7., 8., 9.], [10., 11., 12.]]])
indices = [[0, 0, 1], [1, 1, 2]]
updates = [100., 200.]
@@ -53,7 +53,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
def testNoUpdates(self):
init_val = [[[1., 2., 3.], [4., 5., 6.]], [[7., 8., 9.], [10., 11., 12.]]]
- input_data = variables.Variable(init_val)
+ input_data = variables.VariableV1(init_val)
indices = []
updates = []
@@ -64,7 +64,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
def testBadInput(self):
init_val = [[[1., 2., 3.], [4., 5., 6.]], [[7., 8., 9.], [10., 11., 12.]]]
- input_data = variables.Variable(init_val)
+ input_data = variables.VariableV1(init_val)
indices = [[0, 0, 1], [1, 1, 2]]
updates = [100.]
with self.cached_session():
@@ -75,8 +75,8 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
self.assertAllEqual(init_val, input_data.eval())
def testIncompleteIndices(self):
- input_data = variables.Variable([[[1., 2., 3.], [4., 5., 6.]],
- [[7., 8., 9.], [10., 11., 12.]]])
+ input_data = variables.VariableV1([[[1., 2., 3.], [4., 5., 6.]],
+ [[7., 8., 9.], [10., 11., 12.]]])
indices = [[0, 0], [1, 1]]
updates = [[100., 200., 300.], [400., 500., 600.]]
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
index f3a1ef0d47..52cb0bd9f9 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert_test.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
@@ -94,7 +94,7 @@ class TrtConvertTest(test_util.TensorFlowTestCase):
with g.device("/GPU:0"):
inp = array_ops.placeholder(
dtype=dtypes.float32, shape=[None, 1, 1], name="input")
- var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
+ var = variables.VariableV1([[[1.0]]], dtype=dtypes.float32, name="v1")
add = inp + var.value()
mul = inp * add
add = mul + add
diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD
index 21c0c30c19..57797214d1 100644
--- a/tensorflow/contrib/timeseries/examples/BUILD
+++ b/tensorflow/contrib/timeseries/examples/BUILD
@@ -1,4 +1,5 @@
load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
package(
default_visibility = ["//tensorflow:internal"],
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index c230919168..cb1f707028 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -159,7 +159,12 @@ py_test(
],
shard_count = 4,
srcs_version = "PY2AND3",
- tags = ["no_pip_gpu"], # b/63391119
+ tags = [
+ "no_pip_gpu", # b/63391119
+ "noasan", # b/116875897
+ "nomsan",
+ "notsan",
+ ],
deps = [
":estimators",
":feature_keys",
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index 647455ae42..04d17bc123 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -104,7 +104,7 @@ class EvaluationMetricsTests(test.TestCase):
"ticker":
array_ops.reshape(
math_ops.cast(
- variables.Variable(
+ variables.VariableV1(
name="ticker",
initial_value=0,
dtype=dtypes.int64,
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 4e0b61227e..0c4bdab191 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -36,6 +36,27 @@ cc_library(
)
py_library(
+ name = "async_checkpoint",
+ srcs = ["python/tpu/async_checkpoint.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:summary_ops_v2",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_library(
name = "tpu_estimator",
srcs = [
"python/tpu/error_handling.py",
@@ -46,6 +67,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":async_checkpoint",
":tpu_lib",
"//tensorflow/compiler/xla/experimental/xla_sharding",
"//tensorflow/compiler/xla/python_api:xla_shape",
@@ -81,6 +103,9 @@ tf_gen_op_libs(
],
deps = [
"//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
+ "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils",
+ "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils",
+ "//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
],
@@ -100,12 +125,17 @@ tf_custom_op_library(
],
deps = [
"//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
+ "//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils",
+ "//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils",
"//tensorflow/core:lib_proto_parsing",
],
)
tf_gen_op_wrapper_py(
name = "tpu_ops",
+ hidden = [
+ "SendTPUEmbeddingGradients",
+ ],
deps = [
":cross_replica_ops_op_lib",
":heartbeat_ops_op_lib",
@@ -225,7 +255,10 @@ py_library(
":tpu_py",
"//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
+ "//tensorflow/contrib/tpu/proto:optimization_parameters_proto_py",
"//tensorflow/contrib/tpu/proto:topology_proto_py",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_output_layout_proto_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index 3c0456dc2f..6ce6b779a2 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -55,6 +55,9 @@
@@TPUDistributionStrategy
@@keras_to_tpu_model
+
+@@AsyncCheckpointSaverHook
+@@TPUInMemoryEvalHook
"""
from __future__ import absolute_import
@@ -64,6 +67,7 @@ from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.contrib.tpu.python import profiler
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
+from tensorflow.contrib.tpu.python.tpu.async_checkpoint import *
from tensorflow.contrib.tpu.python.tpu.bfloat16 import *
from tensorflow.contrib.tpu.python.tpu.device_assignment import *
from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
index 18b98939b8..5c27d59f82 100644
--- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
@@ -14,10 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tpu/proto/tpu_embedding_configuration.pb.h"
+#include "tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h"
+#include "tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
@@ -53,230 +58,354 @@ namespace tensorflow {
// saving a checkpoint, the model must Retrieve the parameters back into the
// host CPU memory.
-REGISTER_OP("TPUEmbeddingLoadGradientDescentParameters")
- .Input("parameters: float32")
- .Attr("tpu_embedding_config: string")
- .Attr("table_id: int >= 0")
- .Attr("num_hosts: int >= 1")
- .Attr("host_id: int >= 0")
- .SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Load an embedding table shard into TPU memory for use with GradientDescent.
-
-TPU embeddings use dedicated per-optimizer Ops for loading and retrieving
-trainable variables and optimizer state from TPU memory. This op enables
-functionality equivalent to GradientDescentOptimizer.
-
-parameters: The shard of the embedding table resident on the host executing this
- op. For single-TPU models, this is the entire embedding table.
-tpu_embedding_config: Serialized TPUEmbeddingConfiguration proto.
-table_id: The id of the table specified in the tpu_embedding_config.
-num_hosts: The number of CPU hosts in the distributed training job.
-host_id: Which CPU host in the distributed training job will execute this op.
-)doc");
+namespace {
-namespace tpu_embedding_config_util {
+void RegisterPerTableLoadAndRetrieveOps();
-Status GradientDescentShapes(shape_inference::InferenceContext *c) {
- string config_string;
- TF_RETURN_IF_ERROR(c->GetAttr("tpu_embedding_config", &config_string));
- tpu::TPUEmbeddingConfiguration config;
- if (!config.ParseFromString(config_string)) {
- return errors::InvalidArgument("Malformed tpu_embedding_config.");
+class RegisterPerTableLoadAndRetrieveOpsOnConstruction {
+ public:
+ RegisterPerTableLoadAndRetrieveOpsOnConstruction() {
+ RegisterPerTableLoadAndRetrieveOps();
}
-
- int table_id;
- TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
- int64 num_tables = config.table_descriptor_size();
- if (table_id >= num_tables) {
- return errors::InvalidArgument("Table id >= num_tables");
+};
+
+// Object whose constructor does registrations.
+RegisterPerTableLoadAndRetrieveOpsOnConstruction
+ register_per_table_load_and_retrieve_ops_var;
+
+Status RegisterPerTableLoadOpsForAlgorithmBody(
+ tpu::OptimizationAlgorithm alg, bool is_debug_op,
+ OpRegistrationData* op_reg_data) {
+ tpu::GradientAccumulationSupport grad_accum_support;
+ TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
+
+ std::vector<tpu::StateVariableSpecification> state_variable_specs;
+ TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
+ alg,
+ grad_accum_support == tpu::GradientAccumulationSupport::kSupported &&
+ is_debug_op,
+ &state_variable_specs));
+ auto* op_def = &op_reg_data->op_def;
+ op_def->set_name(
+ strings::StrCat("LoadTPUEmbedding", GetOptimizationAlgorithmName(alg),
+ "Parameters", (is_debug_op ? "GradAccumDebug" : "")));
+ // It is important for the order of the inputs to the op defined here
+ // to match the order in input_names because the indexes are used in
+ // the combining transformation.
+ for (const auto& parameter : state_variable_specs) {
+ if (parameter.has_user_defined() || is_debug_op) {
+ auto* arg = op_def->add_input_arg();
+ arg->set_name(parameter.name());
+ arg->set_description(
+ strings::StrCat("Value of ", parameter.name(), " used in the ",
+ GetOptimizationAlgorithmFriendlyName(alg),
+ " optimization algorithm."));
+ arg->set_type(DT_FLOAT);
+ }
}
- int64 width = config.table_descriptor(table_id).dimension();
- int64 num_rows = config.table_descriptor(table_id).vocabulary_size();
-
- TF_RETURN_IF_ERROR(c->set_output("parameters", {c->Matrix(num_rows, width)}));
+ {
+ auto* table_id_attr = op_def->add_attr();
+ table_id_attr->set_name("table_id");
+ table_id_attr->set_type("int");
+ table_id_attr->set_has_minimum(true);
+ table_id_attr->set_minimum(-1);
+ table_id_attr->mutable_default_value()->set_i(-1);
+ }
+ {
+ auto* table_name_attr = op_def->add_attr();
+ table_name_attr->set_name("table_name");
+ table_name_attr->set_type("string");
+ table_name_attr->mutable_default_value()->set_s("");
+ }
+ {
+ auto* num_shards_attr = op_def->add_attr();
+ num_shards_attr->set_name("num_shards");
+ num_shards_attr->set_type("int");
+ }
+ {
+ auto* shard_id_attr = op_def->add_attr();
+ shard_id_attr->set_name("shard_id");
+ shard_id_attr->set_type("int");
+ }
+ op_def->set_summary("Load embedding parameters for a single table.");
+ string parameter_descriptions;
+ for (const auto& parameter : state_variable_specs) {
+ if (parameter.has_user_defined() || is_debug_op) {
+ strings::Appendf(&parameter_descriptions,
+ R"(
+%s: A tensor containing the initial embedding table %s to use in embedding
+lookups using the %s optimization algorithm.)",
+ parameter.name().c_str(), parameter.name().c_str(),
+ GetOptimizationAlgorithmFriendlyName(alg).c_str());
+ }
+ }
+ op_def->set_description(strings::Printf(R"doc(
+An op that loads optimization parameters into HBM for embedding. Must be
+preceded by a ConfigureTPUEmbeddingHost op that sets up the correct
+embedding table configuration. For example, this op is used to install
+parameters that are loaded from a checkpoint before a training loop is
+executed.
+%s
+table_name: Name of this table; must match a name in the
+ EmbeddingLayerConfiguration proto (overrides table_id).
+num_shards: Number of shards into which the embedding tables are divided.
+shard_id: Identifier of shard for this operation.
+table_id: Index of this table in the EmbeddingLayerConfiguration proto
+ (deprecated).
+)doc",
+ parameter_descriptions.c_str()));
+ op_def->set_is_commutative(false);
+ op_def->set_is_aggregate(false);
+ op_def->set_is_stateful(true);
+ auto shape_inference_function =
+ [state_variable_specs,
+ is_debug_op](shape_inference::InferenceContext* c) -> Status {
+ int table_id;
+ TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
+ string table_name;
+ TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
+ // Exactly one must be non-default.
+ if ((table_id >= 0) == (!table_name.empty())) {
+ return errors::InvalidArgument(
+ "exactly one of table_id or table_name must be non-default");
+ }
+ int num_shards;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
+ int shard_id;
+ TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
+ const int user_param_count =
+ std::count_if(state_variable_specs.begin(), state_variable_specs.end(),
+ [&](const tpu::StateVariableSpecification& sv) {
+ return sv.has_user_defined() || is_debug_op;
+ });
+ std::vector<shape_inference::ShapeHandle> inputs(user_param_count);
+ int input_index = 0;
+ for (int i = 0; i < state_variable_specs.size(); ++i) {
+ if (state_variable_specs[i].has_user_defined() || is_debug_op) {
+ std::vector<shape_inference::ShapeHandle> input_temp;
+ TF_RETURN_IF_ERROR(
+ c->input(state_variable_specs[i].name(), &input_temp));
+ if (input_temp.size() != 1) {
+ return errors::InvalidArgument("each input to be rank 1");
+ }
+ inputs[input_index] = input_temp[0];
+ ++input_index;
+ }
+ }
+ // Verify shapes have rank 2 and are compatible when they are
+ // required to be valid.
+ shape_inference::ShapeHandle parameter_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(inputs[0], 2, &parameter_shape));
+ for (int j = 1; j < user_param_count; ++j) {
+ shape_inference::ShapeHandle accumulator_j_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(inputs[j], 2, &accumulator_j_shape));
+ shape_inference::ShapeHandle merged;
+ TF_RETURN_IF_ERROR(
+ c->Merge(parameter_shape, accumulator_j_shape, &merged));
+ }
+ return Status::OK();
+ };
+ op_reg_data->shape_inference_fn = shape_inference_function;
return Status::OK();
}
-} // namespace tpu_embedding_config_util
-
-REGISTER_OP("TPUEmbeddingRetrieveGradientDescentParameters")
- .Output("parameters: float32")
- .Attr("tpu_embedding_config: string")
- .Attr("table_id: int")
- .Attr("num_hosts: int")
- .Attr("host_id: int")
- .SetIsStateful()
- .SetShapeFn(tpu_embedding_config_util::GradientDescentShapes)
- .Doc(R"doc(
-Retrieve an embedding table shard from TPU memory.
-
-TPU embeddings use dedicated per-optimizer Ops for loading and retrieving
-trainable variables and optimizer state from TPU memory. This op enables
-functionality equivalent to GradientDescentOptimizer.
-
-tpu_embedding_config: Serialized TPUEmbeddingConfiguration proto.
-table_id: The id of the table specified in tpu_embedding_config.
-num_hosts: The number of CPU hosts in the distributed training job.
-host_id: Which CPU host in the distributed training job will execute this op.
-)doc");
-
-REGISTER_OP("TPUEmbeddingLoadAdagradParameters")
- .Input("parameters: float32")
- .Input("accumulators: float32")
- .Attr("tpu_embedding_config: string")
- .Attr("table_id: int >= 0")
- .Attr("num_hosts: int >= 1")
- .Attr("host_id: int >= 0")
- .SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Load an embedding table shard into TensorNode memories for use with Adagrad.
-
-TPU embeddings use dedicated per-optimizer Ops for loading and retrieving
-trainable variables and optimizer state from TPU memory. This op enables
-functionality equivalent to AdagradOptimizer.
-
-parameters: The shard of the embedding table resident on the host executing this
- op. For single-TPU models, this is the entire embedding table.
-accumulators: Shard of the Adagrad accumulators resident on the host executing
- this op.
-tpu_embedding_config: Serialized TPUEmbeddingConfiguration proto.
-table_id: The id of the table specified in the embedding_config.
-num_hosts: The number of CPU hosts in the distributed training job.
-host_id: Which CPU host in the distributed training job will execute this op.
-)doc");
-
-namespace tpu_embedding_config_util {
-
-Status AdagradShapes(shape_inference::InferenceContext *c) {
- string config_string;
- TF_RETURN_IF_ERROR(c->GetAttr("tpu_embedding_config", &config_string));
- tpu::TPUEmbeddingConfiguration config;
- if (!config.ParseFromString(config_string)) {
- return errors::InvalidArgument("Malformed tpu_embedding_config.");
+Status RegisterPerTableRetrieveOpsForAlgorithmBody(
+ tpu::OptimizationAlgorithm alg, bool is_debug_op,
+ OpRegistrationData* op_reg_data) {
+ tpu::GradientAccumulationSupport grad_accum_support;
+ TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
+
+ std::vector<tpu::StateVariableSpecification> state_variable_specs;
+ TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(
+ alg,
+ grad_accum_support == tpu::GradientAccumulationSupport::kSupported &&
+ is_debug_op,
+ &state_variable_specs));
+
+ auto* op_def = &op_reg_data->op_def;
+ op_def->set_name(strings::StrCat(
+ "RetrieveTPUEmbedding", tpu::GetOptimizationAlgorithmName(alg),
+ "Parameters", (is_debug_op ? "GradAccumDebug" : "")));
+ // It is important for the order of the outputs of the op defined here
+ // to match the order in output_names because the indexes are used in
+ // the combining transformation.
+ for (const auto& parameter : state_variable_specs) {
+ if (parameter.has_user_defined() || is_debug_op) {
+ auto* arg = op_def->add_output_arg();
+ arg->set_name(parameter.name());
+ arg->set_description(
+ strings::StrCat("Parameter ", parameter.name(), " updated by the ",
+ tpu::GetOptimizationAlgorithmFriendlyName(alg),
+ " optimization algorithm."));
+ arg->set_type(DT_FLOAT);
+ }
}
-
- int table_id;
- TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
- int64 num_tables = config.table_descriptor_size();
- if (table_id >= num_tables) {
- return errors::InvalidArgument("Table id >= num_tables");
+ {
+ auto* table_id_attr = op_def->add_attr();
+ table_id_attr->set_name("table_id");
+ table_id_attr->set_type("int");
+ table_id_attr->set_has_minimum(true);
+ table_id_attr->set_minimum(-1);
+ table_id_attr->mutable_default_value()->set_i(-1);
}
- int64 width = config.table_descriptor(table_id).dimension();
- int64 num_rows = config.table_descriptor(table_id).vocabulary_size();
-
- TF_RETURN_IF_ERROR(c->set_output("parameters", {c->Matrix(num_rows, width)}));
- TF_RETURN_IF_ERROR(
- c->set_output("accumulators", {c->Matrix(num_rows, width)}));
+ {
+ auto* table_name_attr = op_def->add_attr();
+ table_name_attr->set_name("table_name");
+ table_name_attr->set_type("string");
+ table_name_attr->mutable_default_value()->set_s("");
+ }
+ {
+ auto* num_shards_attr = op_def->add_attr();
+ num_shards_attr->set_name("num_shards");
+ num_shards_attr->set_type("int");
+ }
+ {
+ auto* shard_id_attr = op_def->add_attr();
+ shard_id_attr->set_name("shard_id");
+ shard_id_attr->set_type("int");
+ }
+ op_def->set_summary("Retrieve embedding parameters for a single table.");
+ string parameter_descriptions;
+ for (const auto& param : state_variable_specs) {
+ if (param.has_user_defined() || is_debug_op) {
+ strings::Appendf(&parameter_descriptions,
+ R"(
+%s: A tensor containing the embedding table %s to store with the
+parameters from embedding updates using the %s optimization algorithm.)",
+ param.name().c_str(), param.name().c_str(),
+ tpu::GetOptimizationAlgorithmFriendlyName(alg).c_str());
+ }
+ }
+ op_def->set_description(strings::Printf(R"doc(
+An op that retrieves optimization parameters from embedding to host
+memory. Must be preceded by a ConfigureTPUEmbeddingHost op that sets up
+the correct embedding table configuration. For example, this op is
+used to retrieve updated parameters before saving a checkpoint.
+%s
+table_name: Name of this table; must match a name in the
+ EmbeddingLayerConfiguration proto (overrides table_id).
+num_shards: Number of shards into which the embedding tables are divided.
+shard_id: Identifier of shard for this operation.
+table_id: Index of this table in the EmbeddingLayerConfiguration proto
+ (deprecated).
+)doc",
+ parameter_descriptions.c_str()));
+ op_def->set_is_commutative(false);
+ op_def->set_is_aggregate(false);
+ op_def->set_is_stateful(true);
+ auto shape_inference_function =
+ [state_variable_specs,
+ is_debug_op](shape_inference::InferenceContext* c) -> Status {
+ int table_id;
+ TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
+ string table_name;
+ TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
+ // Exactly one must be non-default.
+ if ((table_id >= 0) == (!table_name.empty())) {
+ return errors::InvalidArgument(
+ "exactly one of table_id or table_name must be non-default");
+ }
+ int num_shards;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
+ int shard_id;
+ TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
+ for (int j = 0; j < state_variable_specs.size(); ++j) {
+ if (state_variable_specs[j].has_user_defined() || is_debug_op) {
+ auto shape = c->MakeShape(
+ std::vector<shape_inference::DimensionHandle>(2, c->UnknownDim()));
+ TF_RETURN_IF_ERROR(
+ c->set_output(state_variable_specs[j].name(),
+ std::vector<shape_inference::ShapeHandle>(1, shape)));
+ }
+ }
+ return Status::OK();
+ };
+ op_reg_data->shape_inference_fn = shape_inference_function;
return Status::OK();
}
-} // namespace tpu_embedding_config_util
-
-REGISTER_OP("TPUEmbeddingRetrieveAdagradParameters")
- .Output("parameters: float32")
- .Output("accumulators: float32")
- .Attr("tpu_embedding_config: string")
- .Attr("table_id: int >= 0")
- .Attr("num_hosts: int >= 1")
- .Attr("host_id: int >= 0")
- .SetIsStateful()
- .SetShapeFn(tpu_embedding_config_util::AdagradShapes)
- .Doc(R"doc(
-Retrieve an embedding table shard from TPU memory.
-
-TPU embeddings use dedicated per-optimizer Ops for loading and retrieving
-trainable variables and optimizer state from TPU memory. This op enables
-functionality equivalent to AdagradOptimizer.
-
-tpu_embedding_config: Serialized TPUEmbeddingConfiguration proto.
-table_id: The id of the table specified in the embedding_config_json.
-num_hosts: The number of CPU hosts in the distributed training job.
-host_id: Which CPU host in the distributed training job will execute this op.
-)doc");
-
-REGISTER_OP("TPUEmbeddingEnqueueSparseBatch")
- .Input("sample_indices: num_tables * int32")
- .Input("embedding_indices: num_tables * int32")
- .Input("aggregation_weights: num_tables * float32")
- .Attr("num_tables: int")
- .Attr("device_ordinal: int = -1")
- .SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-An op that feeds a batch of embedding indices and weights to the TPU.
-
-Embedding lookups are equivalent to sparse-dense matrix multiplications: the
-sparse matrix contains nonzeros in column j in order to retrieve row j from the
-embedding table.
-
-The three Tensor list arguments (sample_indices, embedding_indices, and
-aggregation_weights) represent these sparse matrices in COO format. The Tensor
-lists each have one entry for each embedding table specified in the model.
-For the kth embedding table, the three Tensors at position k in the list
-specify a COO-format sparse matrix. For the kth table, the row indices,
-column indices, and nonzero values of the COO sparse matrix are specified by
-sample_indices[k], embedding_indices[k], and aggregation_weights[k],
-respectively. Entries must be sorted by row index, then by column index.
-
-There should be at most one TPUEmbeddingEnqueueSparseBatch op in a signle
-training step per TPU shard.
-
-sample_indices: A list of rank 1 Tensors specifying row indices of the COO
- sparse matrix representing the embedding lookups for each table.
-embedding_indices: A list of rank 1 Tensors specifying column indices of the
- COO sparse matrix representing the embedding lookups for each table.
-aggregation_weights: A list of rank 1 Tensors specifying the nonzero values
- of the COO sparse matrix representing the embedding lookups for each table.
-device_ordinal: The TPU device to use. This should be -1 when the Op
- is running on a TPU device, and >= 0 when the Op is running on the CPU
- device.
-)doc");
-
-namespace tpu_embedding_config_util {
-
-Status ActivationShapes(shape_inference::InferenceContext *c) {
- string config_string;
- TF_RETURN_IF_ERROR(c->GetAttr("tpu_embedding_config", &config_string));
- tpu::TPUEmbeddingConfiguration config;
- if (!config.ParseFromString(config_string)) {
- return errors::InvalidArgument("Malformed tpu_embedding_config.");
+void RegisterPerTableLoadAndRetrieveOps() {
+ // Load ops
+ for (tpu::OptimizationAlgorithm alg : tpu::GetOptimizationAlgorithms()) {
+ OpRegistry::Global()->Register(
+ [alg](OpRegistrationData* op_reg_data) -> Status {
+ return RegisterPerTableLoadOpsForAlgorithmBody(alg, false,
+ op_reg_data);
+ });
+ tpu::GradientAccumulationSupport grad_accum_support;
+ TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
+ if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) {
+ // TODO(gkurian): Condition this on being used internally within Google.
+ OpRegistry::Global()->Register(
+ [alg](OpRegistrationData* op_reg_data) -> Status {
+ return RegisterPerTableLoadOpsForAlgorithmBody(alg, true,
+ op_reg_data);
+ });
+ }
}
- int64 batch_size = config.batch_size_per_tensor_core();
- int64 num_tables = config.table_descriptor_size();
- for (int table_id = 0; table_id < num_tables; ++table_id) {
- int64 width = config.table_descriptor(table_id).dimension();
- int64 num_features = config.table_descriptor(table_id).vocabulary_size();
- c->set_output(table_id, c->Matrix(batch_size * num_features, width));
+ // Retrieve ops
+ for (tpu::OptimizationAlgorithm alg : tpu::GetOptimizationAlgorithms()) {
+ OpRegistry::Global()->Register(
+ [alg](OpRegistrationData* op_reg_data) -> Status {
+ return RegisterPerTableRetrieveOpsForAlgorithmBody(alg, false,
+ op_reg_data);
+ });
+ tpu::GradientAccumulationSupport grad_accum_support;
+ TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
+ if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) {
+ // TODO(gkurian): Condition this on being used internally within Google.
+ OpRegistry::Global()->Register(
+ [alg](OpRegistrationData* op_reg_data) -> Status {
+ return RegisterPerTableRetrieveOpsForAlgorithmBody(alg, true,
+ op_reg_data);
+ });
+ }
}
- return Status::OK();
}
-} // namespace tpu_embedding_config_util
+} // namespace
-REGISTER_OP("TPUEmbeddingReceiveActivations")
- .Output("outputs: num_tables * float")
- .Attr("num_tables: int >= 1")
- .Attr("tpu_embedding_config: string")
+REGISTER_OP("RecvTPUEmbeddingActivations")
+ .Output("outputs: num_outputs * float")
+ .Attr("num_outputs: int >= 1")
+ .Attr("config: string")
.SetIsStateful()
- .SetShapeFn(tpu_embedding_config_util::ActivationShapes)
+ .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
+ string config_string;
+ TF_RETURN_IF_ERROR(c->GetAttr("config", &config_string));
+ tpu::TPUEmbeddingConfiguration config;
+ if (!config.ParseFromString(config_string)) {
+ return errors::InvalidArgument("Malformed tpu_embedding_config.");
+ }
+ tpu::AddDefaultEmbeddingOutputLayoutIfNeeded(&config);
+ std::vector<TensorShapeProto> output_shapes;
+ TF_RETURN_IF_ERROR(ComputeOutputTensorShapes(config, &output_shapes));
+ if (c->num_outputs() != output_shapes.size()) {
+ return errors::InvalidArgument("num outputs != size of output shapes");
+ }
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ shape_inference::ShapeHandle output_shape;
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeProto(output_shapes[i], &output_shape));
+ c->set_output(i, output_shape);
+ }
+ return Status::OK();
+ })
.Doc(R"doc(
An op that receives embedding activations on the TPU.
The TPU system performs the embedding lookups and aggregations specified by
-the arguments to TPUEmbeddingEnqueueSparseBatch. The results of these
-aggregations are visible to the Tensorflow Graph as the outputs of a
-TPUEmbeddingDequeueActivations Op. This op returns a list containing one
-Tensor of activations per table specified in the model. There can be at most
-one ReceieveActivations op in the TPU graph.
+the arguments to TPUEmbeddingEnqueue(Integer/Sparse/SparseTensor)Batch. The
+results of these aggregations are visible to the Tensorflow Graph as the
+outputs of a RecvTPUEmbeddingActivations op. This op returns a list containing
+one Tensor of activations per table specified in the model. There can be at
+most one RecvTPUEmbeddingActivations op in the TPU graph.
outputs: A TensorList of embedding activations containing one Tensor per
embedding table in the model.
-num_tables: The number of output activation tensors, equal to the number of
+num_outputs: The number of output activation tensors, equal to the number of
embedding tables in the model.
-tpu_embedding_config: Serialized TPUEmbeddingConfiguration proto.
+config: Serialized TPUEmbeddingConfiguration proto.
)doc");
REGISTER_OP("TPUEmbeddingActivations")
@@ -306,12 +435,27 @@ lookup_id: Identifier of the set of embedding indices which produced these
activations.
)doc");
-REGISTER_OP("TPUEmbeddingSendGradients")
- .Input("gradients: num_tables * float32")
- .Attr("num_tables: int >= 1")
- .Attr("tpu_embedding_config: string")
+REGISTER_OP("SendTPUEmbeddingGradients")
+ .Input("inputs: N * float32")
+ .Input("learning_rates: NN * float32")
+ .Attr("N: int >= 1")
+ .Attr("NN: int >= 0 = 0")
+ .Attr("config: string")
.SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape)
+ .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
+ int nn;
+ TF_RETURN_IF_ERROR(c->GetAttr("NN", &nn));
+ std::vector<shape_inference::ShapeHandle> learning_rates;
+ TF_RETURN_IF_ERROR(c->input("learning_rates", &learning_rates));
+ for (int i = 0; i < nn; ++i) {
+ // Verify that each learning_rates element is scalar
+ shape_inference::ShapeHandle learning_rates_shape;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(learning_rates[i], 0, &learning_rates_shape));
+ }
+
+ return Status::OK();
+ })
.Doc(R"doc(
An op that performs gradient updates of embedding tables.
@@ -321,8 +465,120 @@ with respect to the embedding activations. The embedding tables are updated
from these gradients via the optimizer specified in the configuration given
to tpu.initialize_system.
-gradients: A TensorList of gradients with which to update embedding tables.
-tpu_embedding_config: Serialized TPUEmbeddingConfiguration proto.
+inputs: A TensorList of gradients with which to update embedding tables.
+ It contains one tensor per embedding table in the model.
+learning_rates: A list of float32 scalars, one for each embedding table,
+ containing the learning rates for each table when dynamic learning rate is
+ enabled through the OptimizationParameters in TPUEmbeddingConfiguration.
+ When the learning rate is constant, the list should be empty.
+config: Serialized TPUEmbeddingConfiguration proto.
+)doc");
+
+REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch")
+ .Input("batch: N * int32")
+ .Attr("N: int")
+ .Attr("device_ordinal: int = -1")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+An op that enqueues a list of input batch tensors to TPUEmbedding.
+
+batch: A list of 1D tensors, one for each embedding table, containing the
+ indices into the tables.
+device_ordinal: The TPU device to use. Should be >= 0 and less than the number
+ of TPU cores in the task on which the node is placed.
+)doc");
+
+REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
+ .Input("sample_indices: N * int32")
+ .Input("embedding_indices: N * int32")
+ .Input("aggregation_weights: N * float32")
+ .Attr("N: int")
+ .Attr("device_ordinal: int = -1")
+ .Attr("combiners: list(string) = []")
+ .SetIsStateful()
+ .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
+ std::vector<string> combiners;
+ TF_RETURN_IF_ERROR(c->GetAttr("combiners", &combiners));
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ if (!combiners.empty() && combiners.size() != n) {
+ return errors::InvalidArgument("Invalid length of combiners. Have ",
+ combiners.size(), " but expected 0 or ",
+ n);
+ }
+
+ return Status::OK();
+ })
+ .Doc(R"doc(
+An op that enqueues TPUEmbedding input indices from a SparseTensor.
+
+This Op eases the porting of code that uses embedding_lookup_sparse(),
+although some Python preprocessing of the SparseTensor arguments to
+embedding_lookup_sparse() is required to produce the arguments to this Op,
+since only a single EnqueueTPUEmbeddingSparseBatch Op is allowed per training
+step.
+
+The tensors at corresponding positions in the three input lists
+must have the same shape, i.e. rank 1 with dim_size() equal to the total
+number of lookups into the table described by the corresponding table_id.
+
+sample_indices: A list of Rank 1 Tensors specifying the training example and
+ feature to which the corresponding embedding_indices and aggregation_weights
+ values belong. sample_indices[i] must equal b * nf + f, where nf is the
+ number of features from the corresponding table, f is in [0, nf), and
+ b is in [0, batch size).
+embedding_indices: A list of Rank 1 Tensors, indices into the embedding tables.
+aggregation_weights: A list of Rank 1 Tensors containing per sample -- i.e. per
+ (training example, feature) -- aggregation weights.
+device_ordinal: The TPU device to use. Should be >= 0 and less than the number
+ of TPU cores in the task on which the node is placed.
+combiners: A list of string scalars, one for each embedding table that specify
+ how to normalize the embedding activations after weighted summation.
+ Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have
+ the sum of the weights be 0 for 'mean' or the sum of the squared weights be
+ 0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for
+ all tables.
+)doc");
+
+REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
+ .Input("sample_indices: N * int32")
+ .Input("embedding_indices: N * int32")
+ .Input("aggregation_weights: N * float32")
+ .Attr("N: int")
+ .Attr("device_ordinal: int = -1")
+ .Attr("combiners: list(string) = []")
+ .Attr("table_ids: list(int)")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+This Op eases the porting of code that uses tf.nn.embedding_lookup_sparse().
+
+sample_indices[i], embedding_indices[i] and aggregation_weights[i] correspond
+to ith feature. table_ids[i] indicates which embedding table to look up ith
+feature.
+
+The tensors at corresponding positions in the three input lists (sample_indices,
+embedding_indices and aggregation_weights) must have the same shape, i.e. rank 1
+with dim_size() equal to the total number of lookups into the table described by
+the corresponding feature.
+
+sample_indices: A list of Rank 1 Tensors, corresponds to sp_ids.indices[:,0] in
+ embedding_lookup_sparse().
+embedding_indices: A list of Rank 1 Tensors, corresponds to sp_ids.values
+ in embedding_lookup_sparse().
+aggregation_weights: A list of Rank 1 Tensors, corresponds to sp_weights.values
+ in embedding_lookup_sparse().
+device_ordinal: The TPU device to use. Should be >= 0 and less than the number
+ of TPU cores in the task on which the node is placed.
+combiners: A list of string scalars, one for each embedding table that specify
+ how to normalize the embedding activations after weighted summation.
+ Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have
+ the sum of the weights be 0 for 'mean' or the sum of the squared weights be
+ 0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for
+ all tables.
+table_ids: A list of int. table_ids[i] indicates which embedding table to look
+ up ith feature in the list.
)doc");
} // namespace tensorflow
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index b498599962..8e6e9aa0cd 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -156,8 +156,7 @@ bool NewSession(const string& service_addr,
channel_args));
NewProfileSessionResponse new_session_response;
TF_QCHECK_OK(FromGrpcStatus(
- stub->NewSession(&context, new_session_request, &new_session_response)))
- << new_session_response.error_message();
+ stub->NewSession(&context, new_session_request, &new_session_response)));
std::cout << "Profile session succeed for host(s):"
<< str_util::Join(hostnames, ",") << std::endl;
diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto
index 68cf510e71..292108f949 100644
--- a/tensorflow/contrib/tpu/profiler/op_profile.proto
+++ b/tensorflow/contrib/tpu/profiler/op_profile.proto
@@ -18,13 +18,15 @@ message Profile {
message Node {
string name = 1; // Semantics depend on contents.
Metrics metrics = 2; // May be omitted e.g. for fused instructions.
- repeated Node children = 3;
+ repeated Node children = 3; // Subjected to pruning.
// Details about what this node represents.
oneof contents {
InstructionCategory category = 4;
XLAInstruction xla = 5;
}
+
+ int32 num_children = 6; // Total number of children before pruning.
// A category of XLA instructions.
// name is a descriptive string, like "data formatting".
message InstructionCategory {
@@ -64,8 +66,8 @@ message Metrics {
// - it does not reveal the peak core FLOPS of the hardware
double flops = 2;
- // The VMEM bandwidth used to load operands from HBM, as a fraction of
- // thereotical VMEM bandwidth on the specific hardware.
+ // The memory bandwidth used to load operands, as a fraction of
+ // thereotical memory bandwidth on the specific hardware.
double memory_bandwidth = 3;
double raw_time = 11; // Elapsed core-time in picoseconds.
diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
index fc1320501b..a43f45554f 100644
--- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
@@ -22,13 +22,22 @@ message LearningRate {
}
}
+// Each optimizer's parameter proto has a link to its documentation and CPU
+// implementation (if available) for user reference.
+
+// https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L151
message AdagradParameters {
float initial_accumulator = 1;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L423
message StochasticGradientDescentParameters {
}
+// https://www.tensorflow.org/api_docs/python/tf/train/FtrlOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L192
message FtrlParameters {
float l1 = 1;
float l2 = 2;
@@ -41,21 +50,38 @@ message FtrlParameters {
// learning rate feature instead, setting the learning rate to:
// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
+//
+// https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
// https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L54
+//
+// Note that the code by default implements the lazy version of Adam
+// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer)
+// unless the use_non_lazy_adam parameter is set, in which case it implements
+// the normal version of Adam that updates all parameters in the embedding
+// table, even for entries that are not used in the current minibatch
+// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If
+// use_non_lazy_adam is enabled, use_gradient_accumulation is also required in
+// order to get correct results; a warning will be printed otherwise (which may
+// change to an error in the future).
message AdamParameters {
float beta1 = 3;
float beta2 = 4;
float epsilon = 5;
float initial_m = 6;
float initial_v = 7;
+ bool use_non_lazy_adam = 8;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L271
message MomentumParameters {
float momentum = 1;
bool use_nesterov = 2;
float initial_accum = 3;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L356
message RmsPropParameters {
float rho = 1;
float momentum = 2;
@@ -64,6 +90,8 @@ message RmsPropParameters {
float initial_mom = 5;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L372
message CenteredRmsPropParameters {
float rho = 1;
float momentum = 2;
@@ -73,6 +101,7 @@ message CenteredRmsPropParameters {
float initial_mg = 6;
}
+// Variant of algorithm in http://proceedings.mlr.press/v44/shamir15.pdf
message MdlAdagradLightParameters {
float l2 = 1;
float lr_power = 2;
@@ -91,6 +120,8 @@ message MdlAdagradLightParameters {
float initial_benefit = 15;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L68
message AdadeltaParameters {
float rho = 1;
float epsilon = 2;
@@ -98,6 +129,8 @@ message AdadeltaParameters {
float initial_update = 4;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L164
message ProximalAdagradParameters {
float l1 = 1;
float l2 = 2;
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index a1aee69691..e2e4acadab 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -200,6 +200,33 @@ if platform.system() != "Windows":
return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
# pylint: enable=redefined-outer-name
+ # pylint: disable=protected-access
+ def send_tpu_embedding_gradients(inputs,
+ config,
+ learning_rates=None,
+ name=None):
+ """A placeholder op for feeding per-sample gradients to the embedding layer.
+
+ Args:
+ inputs: A TensorList of gradients with which to update embedding tables.
+ Contains one tensor per embedding table in the model.
+ config: Serialized TPUEmbeddingConfiguration proto.
+ learning_rates: A TensorList of float32 scalars, one for each embedding
+ table, containing the learning rates for each table when dynamic
+ learning rate is enabled through the OptimizationParameters in
+ TPUEmbeddingConfiguration. When the learning rate is constant, the list
+ should be empty (optional).
+ name: A name for the operation (optional).
+
+ Returns:
+ A SendTPUEmbeddingGradients operation.
+ """
+ if learning_rates is None:
+ learning_rates = []
+ return gen_tpu_ops._send_tpu_embedding_gradients(
+ inputs=inputs, learning_rates=learning_rates, config=config, name=name)
+
+
else:
# We have already built the appropriate libraries into the binary via CMake
# if we have built contrib, so we don't need this
diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
new file mode 100644
index 0000000000..20b7ba0997
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
@@ -0,0 +1,202 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+"""Hook for asynchronous checkpointing.
+
+This hook dispatches checkpoint writing operations in a separate thread to
+allow execution to continue on the main thread.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import threading
+import time
+
+from tensorflow.core.util.event_pb2 import SessionLog
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import training_util
+from tensorflow.python.training.session_run_hook import SessionRunArgs
+from tensorflow.python.training.summary_io import SummaryWriterCache
+
+
+class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
+ """Saves checkpoints every N steps or seconds."""
+
+ def __init__(self,
+ checkpoint_dir,
+ save_secs=None,
+ save_steps=None,
+ saver=None,
+ checkpoint_basename="model.ckpt",
+ scaffold=None,
+ listeners=None):
+ """Initializes a `CheckpointSaverHook`.
+
+ Args:
+ checkpoint_dir: `str`, base directory for the checkpoint files.
+ save_secs: `int`, save every N secs.
+ save_steps: `int`, save every N steps.
+ saver: `Saver` object, used for saving.
+ checkpoint_basename: `str`, base name for the checkpoint files.
+ scaffold: `Scaffold`, use to get saver object.
+ listeners: List of `CheckpointSaverListener` subclass instances. Used for
+ callbacks that run immediately before or after this hook saves the
+ checkpoint.
+
+ Raises:
+ ValueError: One of `save_steps` or `save_secs` should be set.
+ ValueError: At most one of `saver` or `scaffold` should be set.
+ """
+ logging.info("Create AsyncCheckpointSaverHook.")
+ if saver is not None and scaffold is not None:
+ raise ValueError("You cannot provide both saver and scaffold.")
+ self._saver = saver
+ self._save_thread = None
+ self._checkpoint_dir = checkpoint_dir
+ self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
+ self._scaffold = scaffold
+ self._timer = basic_session_run_hooks.SecondOrStepTimer(
+ every_secs=save_secs, every_steps=save_steps)
+ self._listeners = listeners or []
+ self._steps_per_run = 1
+ self._summary_writer = None
+ self._global_step_tensor = None
+
+ def _set_steps_per_run(self, steps_per_run):
+ self._steps_per_run = steps_per_run
+
+ def begin(self):
+ self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
+ self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
+ if self._global_step_tensor is None:
+ raise RuntimeError(
+ "Global step should be created to use CheckpointSaverHook.")
+ for l in self._listeners:
+ l.begin()
+
+ def after_create_session(self, session, coord):
+ global_step = session.run(self._global_step_tensor)
+
+ # We do write graph and saver_def at the first call of before_run.
+ # We cannot do this in begin, since we let other hooks to change graph and
+ # add variables in begin. Graph is finalized after all begin calls.
+ training_util.write_graph(
+ ops.get_default_graph().as_graph_def(add_shapes=True),
+ self._checkpoint_dir, "graph.pbtxt")
+ saver_def = self._get_saver().saver_def if self._get_saver() else None
+ graph = ops.get_default_graph()
+ meta_graph_def = meta_graph.create_meta_graph_def(
+ graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
+ self._summary_writer.add_graph(graph)
+ self._summary_writer.add_meta_graph(meta_graph_def)
+ # The checkpoint saved here is the state at step "global_step".
+ self._save(session, global_step)
+ self._timer.update_last_triggered_step(global_step)
+
+ def before_run(self, run_context): # pylint: disable=unused-argument
+ return SessionRunArgs(self._global_step_tensor)
+
+ def after_run(self, run_context, run_values):
+ stale_global_step = run_values.results
+ if self._timer.should_trigger_for_step(stale_global_step +
+ self._steps_per_run):
+ # get the real value after train op.
+ global_step = run_context.session.run(self._global_step_tensor)
+ if self._timer.should_trigger_for_step(global_step):
+ self._timer.update_last_triggered_step(global_step)
+ if self._save(run_context.session, global_step):
+ run_context.request_stop()
+
+ def end(self, session):
+ if self._save_thread:
+ logging.info("Waiting for any pending checkpoints to finish.")
+ self._save_thread.join()
+
+ last_step = session.run(self._global_step_tensor)
+
+ # Save the last checkpoint synchronously if needed.
+ if last_step != self._timer.last_triggered_step():
+ self._save(session, last_step, asynchronous=False)
+
+ for l in self._listeners:
+ l.end(session, last_step)
+
+ def _save(self, session, step, asynchronous=True):
+ """Saves the latest checkpoint, returns should_stop."""
+
+ # Skip saving on step 0
+ if step == 0:
+ return
+
+ def _save_fn():
+ """Run the saver process."""
+ logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
+
+ start_time = time.time()
+ for l in self._listeners:
+ l.before_save(session, step)
+
+ self._get_saver().save(session, self._save_path, global_step=step)
+ self._summary_writer.add_session_log(
+ SessionLog(
+ status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
+ step)
+ end_time = time.time()
+ logging.info("Checkpoint actual writing time: (%.3f sec)",
+ end_time - start_time)
+ logging.info("Checkpoint finished for %d into %s.", step, self._save_path)
+
+ for l in self._listeners:
+ l.before_save(session, step)
+
+ if not asynchronous:
+ _save_fn()
+ return
+
+ if self._save_thread is not None:
+ self._save_thread.join(timeout=0.1)
+ if self._save_thread.is_alive():
+ logging.info("Saver thread still in progress, skipping checkpoint.")
+ return
+
+ self._save_thread = threading.Thread(target=_save_fn)
+ self._save_thread.start()
+
+ def _get_saver(self):
+ if self._saver is not None:
+ return self._saver
+ elif self._scaffold is not None:
+ return self._scaffold.saver
+
+ # Get saver from the SAVERS collection if present.
+ collection_key = ops.GraphKeys.SAVERS
+ savers = ops.get_collection(collection_key)
+ if not savers:
+ raise RuntimeError(
+ "No items in collection {}. Please add a saver to the collection "
+ "or provide a saver or scaffold.".format(collection_key))
+ elif len(savers) > 1:
+ raise RuntimeError(
+ "More than one item in collection {}. "
+ "Please indicate which one to use by passing it to the constructor."
+ .format(collection_key))
+
+ self._saver = savers[0]
+ return savers[0]
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index bf445256b6..696656e840 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -25,10 +25,9 @@ flattened = tf.keras.layers.Flatten()(c1)
logits = tf.keras.layers.Dense(10, activation='softmax')(flattened)
model = tf.keras.Model(inputs=[image], outputs=[logits])
-strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8)
-model = keras_support.tpu_model(model,
- strategy=strategy,
- tpu_name_or_address=tpu_name)
+resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
+strategy = keras_support.TPUDistributionStrategy(resolver)
+model = keras_support.tpu_model(model, strategy=strategy)
# Only TF optimizers are currently supported.
model.compile(optimizer=tf.train.AdamOptimizer(), ...)
@@ -47,12 +46,12 @@ from __future__ import print_function
import abc
import collections
-import contextlib
import re
import sys
import time
import numpy as np
+import six
from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
from tensorflow.contrib.framework.python.framework import experimental
@@ -69,6 +68,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -90,34 +90,34 @@ from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
-_SESSIONS = {}
-
-
-def tpu_session(cluster_resolver):
+def setup_tpu_session(cluster_resolver):
"""Construct or return a `tf.Session` connected to the given cluster."""
- global _SESSIONS
master = cluster_resolver.master()
- if master not in _SESSIONS:
- cluster_spec = cluster_resolver.cluster_spec()
- config = config_pb2.ConfigProto(isolate_session_state=True)
- if cluster_spec:
- config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
- logging.info('Connecting to: %s', master)
- graph = ops.Graph()
- session = tf_session.Session(graph=graph, target=master, config=config)
- with graph.as_default():
- session.run(tpu.initialize_system())
+ # Use the existing session if we're already connected to this TPU
+ if (K.get_session()._target == master and
+ getattr(K.get_session(), '_tpu_initialized', None)):
+ return
+
+ cluster_spec = cluster_resolver.cluster_spec()
+ config = config_pb2.ConfigProto(isolate_session_state=True)
+ if cluster_spec:
+ config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
- _SESSIONS[master] = session
- return _SESSIONS[master]
+ logging.info('Initialize')
+ tpu_session = tf_session.Session(target=master, config=config)
+ tpu_session.run(tpu.initialize_system())
+ tpu_session._tpu_initialized = True
+ # N.B. We have to call `K.set_session()` AND set our session as the
+ # TF default. `K.get_session()` surprisingly does not return the value
+ # supplied by K.set_session otherwise.
+ K.set_session(tpu_session)
-def reset_tpu_sessions():
- _SESSIONS.clear()
try:
from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
@@ -134,9 +134,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
tpu_system_metadata = (
tpu_system_metadata_lib._query_tpu_system_metadata(
- master,
- cluster_def=cluster_def,
- query_topology=False))
+ master, cluster_def=cluster_def, query_topology=False))
return tpu_system_metadata
@@ -157,6 +155,8 @@ class TPUDistributionStrategy(object):
replication, typically using all avaiable TPU cores. If overwrites as
`True`, force the model replication using single core, i.e., no
replication.
+ Raises:
+ Exception: No TPU Found on the given worker.
"""
if tpu_cluster_resolver is None:
@@ -172,7 +172,8 @@ class TPUDistributionStrategy(object):
for device in metadata.devices:
if 'TPU:0' in device.name:
self._worker_name = worker_re.search(device.name).group(1)
- break
+ return
+ raise Exception('No TPU found on given worker.')
def _make_assignment_for_model(self, cpu_model):
"""Makes a `TPUAssignment` for the passed in `cpu_model`."""
@@ -183,8 +184,7 @@ class TPUDistributionStrategy(object):
'Degrading to a single core.')
num_cores = 1
- return TPUAssignment(
- worker_name=self._worker_name, num_cores=num_cores)
+ return TPUAssignment(worker_name=self._worker_name, num_cores=num_cores)
class TPUAssignment(object):
@@ -230,6 +230,39 @@ class TPUEmbedding(embeddings.Embedding):
return math_ops.tensordot(inputs, self.embeddings, 1)
+def _cross_replica_concat(tensor, core_id, num_cores, name):
+ """Concatenate `tensor` across cores.
+
+ Args:
+ tensor: The tensor to be concatenated. Must be [int32 and float32].
+ core_id: Tensor indicating the current TPU core.
+ num_cores: Python int. The total number of TPU cores in the system.
+ name: The string name to print for debugging.
+
+ Returns:
+ The same concatenated Tensor on each core.
+ """
+
+ input_dtype = tensor.dtype
+ if input_dtype not in [dtypes.float32, dtypes.int32]:
+ raise TypeError('For model replication, only (float32 and int32) is '
+ 'supported for model outputs and targets. Got {} for '
+ '{}.'.format(input_dtype, name))
+
+ batch_size = tensor.shape[0]
+ mask = math_ops.to_float(math_ops.equal(range(num_cores), core_id))
+ mask = array_ops.reshape(mask, [num_cores] + [1] * tensor.shape.ndims)
+ result = mask * math_ops.to_float(tensor)
+ local_tensor_with_holes = array_ops.reshape(result,
+ [-1] + result.shape.as_list()[2:])
+ concat_tensor = tpu_ops.cross_replica_sum(local_tensor_with_holes)
+ concat_tensor.set_shape((num_cores * batch_size,) + tuple(tensor.shape[1:]))
+
+ if concat_tensor != input_dtype:
+ concat_tensor = math_ops.cast(concat_tensor, input_dtype)
+ return concat_tensor
+
+
class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
"""An optimizer that averages gradients across TPU shards."""
@@ -247,9 +280,9 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
super(KerasCrossShardOptimizer, self).__init__()
self._name = name
self._opt = opt
+ logging.info('KerasCrossShard: %s %s', self._opt, self._opt.weights)
def get_updates(self, loss, params):
- logging.info('Get updates: %s', loss)
self._opt.get_gradients = self.get_gradients
return self._opt.get_updates(loss, params)
@@ -258,17 +291,15 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params)
return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
- def set_weights(self, weights):
- # TODO(power): Figure out whether we really need this given there is no
- # caller for this API yet.
- self._opt.set_weights()
-
def get_weights(self):
return self._opt.get_weights()
- @property
- def lr(self):
- return self._opt.lr
+ def get_config(self):
+ return self._opt.get_config()
+
+ # Defer remaining operations to the underlying optimizer
+ def __getattr__(self, key):
+ return getattr(self._opt, key)
class TPUModelOp(
@@ -294,14 +325,22 @@ def _replicated_optimizer(opt):
return KerasCrossShardOptimizer(opt)
-def clone_metrics(metrics):
- """Returns a copy of metrics. A copy is created for stateful metrics."""
- if metrics is None:
- return None
- return [
- m.__class__.from_config(m.get_config())
- if isinstance(m, metrics_module.Metric) else m for m in metrics
- ]
+def _clone_optimizer(optimizer, config=None):
+ """Returns a cloned optimizer with the provided optimizer.config or config."""
+ if not isinstance(optimizer, keras_optimizers.Optimizer):
+ # In the first call to tpu_model(model), Keras may not have wrapped the TF
+ # optimizer in the TFOptimizer helper, e.g., the given model isn't compiled
+ # or optimizer isn't set, and later generated tpu_model compiles with a TF
+ # optimizer.
+ return optimizer
+
+ if isinstance(optimizer, keras_optimizers.TFOptimizer):
+ return keras_optimizers.TFOptimizer(optimizer.optimizer)
+
+ if config is None:
+ config = optimizer.get_config()
+ logging.info('Cloning %s %s', optimizer.__class__.__name__, config)
+ return optimizer.__class__.from_config(config)
class TPURewriteContext(object):
@@ -392,6 +431,7 @@ class TPURewriteContext(object):
return (r, q)
else:
raise ValueError('Invalid shape passed to qr: %s' % input_shape)
+
gen_linalg_ops.qr = qr
ops.name_scope = _name_scope
@@ -407,9 +447,9 @@ class TPURewriteContext(object):
gen_linalg_ops.qr = self._default_qr
-class SizedInfeed(collections.namedtuple('SizedInfeed',
- ['sharded_infeed_tensors',
- 'infeed_ops'])):
+class SizedInfeed(
+ collections.namedtuple('SizedInfeed',
+ ['sharded_infeed_tensors', 'infeed_ops'])):
"""Represents an instantiation of the infeed ops for a concrete input shape.
sharded_infeed_tensors: A data structure of Tensors used to represent the
@@ -595,12 +635,13 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
infeed_tensors, [spec.shape for spec in input_specs],
name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
device_ordinal=shard_id))
- return SizedInfeed(infeed_ops=infeed_op,
- sharded_infeed_tensors=shard_infeed_tensors)
+ return SizedInfeed(
+ infeed_ops=infeed_op, sharded_infeed_tensors=shard_infeed_tensors)
class TPUDatasetInfeedManager(TPUInfeedManager):
"""Manages infeed for a `tf.data.Dataset` into a TPU computation.
+
"""
class DatasetInfeedInstance(TPUInfeedInstance):
@@ -618,18 +659,17 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
return {}
# pylint: disable=redefined-outer-name
- def __init__(self, dataset, tpu_assignment, tpu_session):
+ def __init__(self, dataset, tpu_assignment, mode):
"""Constructs a TPUDatasetInfeedManager.
- Must be called within a `KerasTPUModel.tpu_session` context!
-
Args:
dataset: A `tf.data.Dataset` to infeed.
tpu_assignment: The `TPUAssignment` used to configure the
Keras TPU model.
- tpu_session: The `tf.Session` object used for running the TPU model.
+ mode: ModeKeys enum.
"""
self._verify_dataset_shape(dataset)
+
self._dataset = dataset
self._tpu_assignment = tpu_assignment
dummy_x_shape = dataset.output_shapes[0].as_list()
@@ -637,7 +677,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
dummy_y_shape = dataset.output_shapes[1].as_list()
dummy_y_shape[0] *= tpu_assignment.num_towers
self._iterator = dataset.make_initializable_iterator()
- tpu_session.run(self._iterator.initializer)
+ K.get_session().run(self._iterator.initializer)
self._get_next_ops = []
ctrl_deps = []
@@ -650,10 +690,10 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
- self._dummy_x = np.zeros(dummy_x_shape,
- dtype=dataset.output_types[0].as_numpy_dtype)
- self._dummy_y = np.zeros(dummy_y_shape,
- dtype=dataset.output_types[1].as_numpy_dtype)
+ self._dummy_x = np.zeros(
+ dummy_x_shape, dtype=dataset.output_types[0].as_numpy_dtype)
+ self._dummy_y = np.zeros(
+ dummy_y_shape, dtype=dataset.output_types[1].as_numpy_dtype)
input_specs = []
if isinstance(self._iterator.output_shapes, tuple):
@@ -669,6 +709,10 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
self._iterator.output_types)
input_specs.append(spec)
+ # Pre-process the inputs and get_next_ops before caching.
+ input_specs, self._get_next_ops = (
+ _inject_tpu_inputs_for_dataset(
+ tpu_assignment, mode, input_specs, self._get_next_ops))
self._infeed_instance = self.DatasetInfeedInstance(input_specs)
def _verify_dataset_shape(self, dataset):
@@ -680,9 +724,8 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
raise ValueError('The dataset must return a tuple of tf.Tensors, '
'instead it returns: %s' % dataset.output_classes)
if len(dataset.output_classes) != 2:
- raise ValueError(
- 'The dataset must return a 2-element tuple, got '
- '%s output classes instead.' % (dataset.output_classes,))
+ raise ValueError('The dataset must return a 2-element tuple, got '
+ '%s output classes instead.' % (dataset.output_classes,))
for i, cls in enumerate(dataset.output_classes):
if cls != ops.Tensor:
raise ValueError('The dataset returned a non-Tensor type (%s) at '
@@ -691,8 +734,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
if not shape:
raise ValueError('The dataset returns a scalar tensor in '
'tuple index %d. Did you forget to batch? '
- '(Output shapes: %s).' % (i,
- dataset.output_shapes))
+ '(Output shapes: %s).' % (i, dataset.output_shapes))
for j, dim in enumerate(shape):
if dim.value is None:
if j == 0:
@@ -732,8 +774,72 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
[spec.shape for spec in input_specs],
name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
device_ordinal=shard_id))
- return SizedInfeed(infeed_ops=infeed_ops,
- sharded_infeed_tensors=shard_infeed_tensors)
+ return SizedInfeed(
+ infeed_ops=infeed_ops, sharded_infeed_tensors=shard_infeed_tensors)
+
+
+def _inject_tpu_inputs_for_dataset(tpu_assignment, mode,
+ input_specs, get_next_ops):
+ """Append core information to the set of dataset inputs."""
+ # This is used during compilation to identify the current TPU core and enable
+ # concatenation operations across cores.
+ if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
+ return input_specs, get_next_ops
+
+ # Dataset inputs operate on per core basis.
+ per_core_batch_size = input_specs[0].shape.as_list()[0]
+
+ # Insert, at head, the tensor for core_id.
+ assert len(get_next_ops) == tpu_assignment.num_towers
+ for i in range(tpu_assignment.num_towers):
+ core_id_constant = constant_op.constant(
+ np.array([i] * per_core_batch_size).astype('int32'),
+ dtype=dtypes.int32,
+ name='cord_id_constant')
+ get_next_ops[i] = [core_id_constant] + list(get_next_ops[i])
+
+ # Insert the input spec at head also.
+ input_specs = [tensor_spec.TensorSpec([per_core_batch_size], dtypes.int32)
+ ] + input_specs
+
+ return input_specs, get_next_ops
+
+
+def _inject_tpu_inputs_for_infeed(tpu_assignment, mode,
+ core_id_place_holder, input_tensors, inputs):
+ """Append core information to the set of inputs."""
+ # This is used during compilation to identify the current TPU core and enable
+ # concatenation operations across cores.
+ if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
+ return input_tensors, inputs
+
+ # Puts a place holder in input spec.
+ input_tensors = [core_id_place_holder] + input_tensors
+
+ # Now fill the core id. For `num_cores` = 2, `batch_size` = 8, we fill the
+ # core id inputs as [0, 0, 0, 0, 1, 1, 1, 1], so each core sees its core id
+ # (duplicated).
+ num_cores = tpu_assignment.num_towers
+ per_core_batch_size = inputs[0].shape[0] // num_cores
+ core_ids = np.arange(num_cores).repeat(per_core_batch_size)
+ inputs = [core_ids] + inputs
+ return input_tensors, inputs
+
+
+def _read_tpu_coreid_from_infeed(mode, infeed_tensors):
+ """Popping out the core ids from infeed."""
+ if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
+ return None, infeed_tensors
+
+ if len(infeed_tensors) <= 1:
+ raise RuntimeError(
+ 'The infeed tensors on TPU core has only {} tensors. '
+ 'This is not expected. Please report a bug.\nTensors: {}'.format(
+ len(infeed_tensors), infeed_tensors))
+
+ core_id = infeed_tensors[0][0] # Pop out the scalar version.
+ rest = infeed_tensors[1:]
+ return core_id, rest
class TPUFunction(object):
@@ -754,12 +860,11 @@ class TPUFunction(object):
self._tpu_assignment = tpu_assignment
self._compilation_cache = {}
self._cloned_model = None
-
- # Copy optimizer configuration. This is done prior to `_specialize_model`
- # as the configuration may require evaluating variables in the CPU session.
- self._optimizer_config = None
- if not isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
- self._optimizer_config = self.model.optimizer.get_config()
+ self._cloned_optimizer = None
+ # Create a placeholder for the TPU core ID. Cache the placeholder to avoid
+ # modifying the graph for every batch.
+ self._core_id_place_holder = array_ops.placeholder(
+ dtype=dtypes.int32, shape=[1], name='core_id')
def _specialize_model(self, input_specs, infeed_manager):
"""Specialize `self.model` (a Keras model) for the given input shapes."""
@@ -786,6 +891,10 @@ class TPUFunction(object):
shapes=[spec.shape for spec in input_specs],
name='infeed-%s' % self.execution_mode)
+ core_id, infeed_tensors = (
+ _read_tpu_coreid_from_infeed(
+ mode=self.execution_mode, infeed_tensors=infeed_tensors))
+
assert len(infeed_tensors) == len(infeed_layers), (
'Infeed inputs did not match model: %s vs %s' % (infeed_layers,
infeed_tensors))
@@ -801,35 +910,65 @@ class TPUFunction(object):
tpu_targets.append(tensor)
# Clone our CPU model, running within the TPU device context.
+ #
+ # We use the id of the original model as a key to avoid weight collisions
+ # (if a user re-runs the same model multiple times, in e.g. Colab).
with TPURewriteContext(tpu_input_map):
- with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
+ with variable_scope.variable_scope('tpu_%s' % id(self.model)):
with keras_tpu_variables.replicated_scope(
self._tpu_assignment.num_towers):
- self._cloned_model = models.clone_model(self.model)
+ if not self._cloned_optimizer:
+ self._cloned_optimizer = _clone_optimizer(
+ self.model.cpu_optimizer)
- # Create a copy of the optimizer for this graph.
- if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
- cloned_optimizer = keras_optimizers.TFOptimizer(
- self.model.optimizer.optimizer)
- else:
- logging.info('Cloning %s %s', self.model.optimizer.__class__.__name__,
- self._optimizer_config)
- cloned_optimizer = self.model.optimizer.__class__.from_config(
- self._optimizer_config)
+ self._cloned_model = models.clone_model(self.model)
- if is_training or is_test:
- self._cloned_model.compile(
- optimizer=_replicated_optimizer(cloned_optimizer),
- loss=self.model.loss,
- loss_weights=self.model.loss_weights,
- metrics=clone_metrics(self.model.metrics),
- weighted_metrics=clone_metrics(self.model.weighted_metrics),
- target_tensors=tpu_targets,
- )
+ # When running on more than one core, concatenate outputs at the end
+ # of processing. In backprop stage, the gradients will be
+ # calculdated according to the local inputs as gradient of
+ # cross-replica-concat being zero for any outputs other than those
+ # from mlocal core so the loss calculation is identical.
+ num_towers = self.model._tpu_assignment.num_towers
+ if num_towers > 1 and (is_training or is_test):
+ new_outputs = [
+ _cross_replica_concat(
+ o, core_id, num_towers,
+ name='model output ({})'.format(o.name))
+ for o in self._cloned_model.outputs
+ ]
+ self._cloned_model.outputs = new_outputs
+ tpu_targets = [
+ _cross_replica_concat(
+ tensor,
+ core_id,
+ num_towers,
+ name='model target ({})'.format(tensor.name))
+ for tensor in tpu_targets
+ ]
+
+ if is_training or is_test:
+ self._cloned_model.compile(
+ optimizer=_replicated_optimizer(self._cloned_optimizer),
+ loss=self.model.loss,
+ loss_weights=self.model.loss_weights,
+ metrics=metrics_module.clone_metrics(self.model.metrics),
+ weighted_metrics=metrics_module.clone_metrics(
+ self.model.weighted_metrics),
+ target_tensors=tpu_targets,
+ )
# Compute our outfeed depending on the execution mode
if is_training:
- self._cloned_model._make_train_function()
+ if not isinstance(self._cloned_optimizer, keras_optimizers.TFOptimizer):
+ # For Keras optimizer, we try to place the variable weights on the TPU
+ # device. Keras creates optimizer variables (e.g. momentum values for
+ # the Momentum optimizer) when _make_train_function is invoked.
+ with keras_tpu_variables.replicated_variable_for_optimizer(
+ self._tpu_assignment.num_towers):
+ self._cloned_model._make_train_function()
+ else:
+ self._cloned_model._make_train_function()
+
self._outfeed_spec = [
tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
for tensor in self._cloned_model.train_function.outputs
@@ -934,6 +1073,7 @@ class TPUFunction(object):
for x, mgr in self.model._numpy_to_infeed_manager_list:
if inputs[0] is x:
return mgr
+
return TPUNumpyInfeedManager(self.model._tpu_assignment)
def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager):
@@ -958,13 +1098,14 @@ class TPUFunction(object):
# unique input shape.
shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs])
if shape_key not in self._compilation_cache:
- with self.model.tpu_session():
- logging.info('New input shapes; (re-)compiling: mode=%s, %s',
- self.execution_mode, input_specs)
- new_tpu_model_ops = self._specialize_model(input_specs,
- infeed_manager)
- self._compilation_cache[shape_key] = new_tpu_model_ops
- self._test_model_compiles(new_tpu_model_ops)
+ logging.info(
+ 'New input shapes; (re-)compiling: mode=%s '
+ '(# of cores %d), %s', self.execution_mode,
+ self._tpu_assignment.num_towers, input_specs)
+ new_tpu_model_ops = self._specialize_model(input_specs,
+ infeed_manager)
+ self._compilation_cache[shape_key] = new_tpu_model_ops
+ self._test_model_compiles(new_tpu_model_ops)
return self._compilation_cache[shape_key]
@@ -999,6 +1140,10 @@ class TPUFunction(object):
input_tensors = self.model._feed_inputs
inputs = inputs[:len(input_tensors)]
+ input_tensors, inputs = (
+ _inject_tpu_inputs_for_infeed(
+ self._tpu_assignment, self.execution_mode,
+ self._core_id_place_holder, input_tensors, inputs))
return input_tensors, inputs
def _process_outputs(self, outfeed_outputs):
@@ -1059,11 +1204,10 @@ class TPUFunction(object):
# Initialize our TPU weights on the first compile.
self.model._initialize_weights(self._cloned_model)
- with self.model.tpu_session() as session:
- _, _, outfeed_outputs = session.run([
- tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
- tpu_model_ops.outfeed_op
- ], infeed_dict)
+ _, _, outfeed_outputs = K.get_session().run([
+ tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
+ tpu_model_ops.outfeed_op
+ ], infeed_dict)
return self._process_outputs(outfeed_outputs)
def pipeline_run(self, cur_step_inputs, next_step_inputs):
@@ -1095,8 +1239,8 @@ class TPUFunction(object):
next_step_infeed_manager = self._lookup_infeed_manager(next_step_inputs)
cur_step_infeed_manager = self._lookup_infeed_manager(cur_step_inputs)
- if (next_step_infeed_manager is not None
- and cur_step_infeed_manager is not None):
+ if (next_step_infeed_manager is not None and
+ cur_step_infeed_manager is not None):
assert type(next_step_infeed_manager) is type(cur_step_infeed_manager)
next_input_tensors, next_step_inputs = (
@@ -1121,14 +1265,12 @@ class TPUFunction(object):
infeed_dict = None
if cur_infeed_instance and cur_input_tensors and cur_step_infeed_manager:
- cur_input_specs = cur_infeed_instance.make_input_specs(
- cur_input_tensors)
+ cur_input_specs = cur_infeed_instance.make_input_specs(cur_input_tensors)
cur_tpu_model_ops = self._tpu_model_ops_for_input_specs(
cur_input_specs, cur_step_infeed_manager)
- if (next_infeed_instance
- and next_input_tensors
- and next_step_infeed_manager):
+ if (next_infeed_instance and next_input_tensors and
+ next_step_infeed_manager):
next_input_specs = next_infeed_instance.make_input_specs(
next_input_tensors)
next_tpu_model_ops = self._tpu_model_ops_for_input_specs(
@@ -1139,26 +1281,24 @@ class TPUFunction(object):
self.model._initialize_weights(self._cloned_model)
if next_tpu_model_ops and cur_tpu_model_ops:
- with self.model.tpu_session() as session:
- _, _, outfeed_outputs = session.run([
- next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
- cur_tpu_model_ops.outfeed_op
- ], infeed_dict)
+ _, _, outfeed_outputs = K.get_session().run([
+ next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
+ cur_tpu_model_ops.outfeed_op
+ ], infeed_dict)
return self._process_outputs(outfeed_outputs)
+
if cur_tpu_model_ops:
- with self.model.tpu_session() as session:
- _, outfeed_outputs = session.run([
- cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
+ _, outfeed_outputs = K.get_session().run(
+ [cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
return self._process_outputs(outfeed_outputs)
+
if next_tpu_model_ops:
- with self.model.tpu_session() as session:
- session.run(next_tpu_model_ops.infeed_op, infeed_dict)
+ K.get_session().run(next_tpu_model_ops.infeed_op, infeed_dict)
return None
raise RuntimeError('Internal error: both current & next tpu_model_ops '
'were None')
-
class KerasTPUModel(models.Model):
"""TPU compatible Keras model wrapper."""
@@ -1185,8 +1325,6 @@ class KerasTPUModel(models.Model):
self._tpu_model = None
self._tpu_weights_initialized = False
- self._session = tpu_session(cluster_resolver)
-
# If the input CPU model has already been compiled, compile our TPU model
# immediately.
if self._cpu_model.optimizer:
@@ -1223,15 +1361,16 @@ class KerasTPUModel(models.Model):
if target_tensors:
raise ValueError('target_tensors is not supported for TPU execution.')
+ self._cpu_model.compile(
+ _clone_optimizer(optimizer), loss,
+ metrics_module.clone_metrics(metrics), loss_weights, sample_weight_mode,
+ metrics_module.clone_metrics(weighted_metrics), target_tensors,
+ **kwargs)
+
super(KerasTPUModel, self).compile(optimizer, loss, metrics, loss_weights,
sample_weight_mode, weighted_metrics,
target_tensors, **kwargs)
- if not self._cpu_model.optimizer:
- self._cpu_model.compile(optimizer, loss, metrics, loss_weights,
- sample_weight_mode, weighted_metrics,
- target_tensors, **kwargs)
-
def fit(self,
x=None,
y=None,
@@ -1264,8 +1403,8 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(x):
- with self.tpu_session() as sess,\
- ops.device('/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
+ with ops.device('/job:%s/device:CPU:0' %
+ self._tpu_assignment.worker_name):
dataset = x()
if steps_per_epoch is None:
raise ValueError('When using tf.data as input to a model, you '
@@ -1273,8 +1412,8 @@ class KerasTPUModel(models.Model):
if y is not None:
raise ValueError('When using tf.data as input to a model, y must be '
'None')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
- sess)
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
x = infeed_manager.dummy_x
@@ -1290,26 +1429,24 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(validation_data):
- with self.tpu_session() as sess:
- dataset = validation_data()
- if validation_steps is None:
- raise ValueError('When using tf.data as validation for a model, you '
- 'should specify the validation_steps argument.')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
- sess)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- val_x = infeed_manager.dummy_x
- val_y = infeed_manager.dummy_y
- infeed_managers.append((val_x, infeed_manager))
- validation_data = (val_x, val_y)
+ dataset = validation_data()
+ if validation_steps is None:
+ raise ValueError('When using tf.data as validation for a model, you '
+ 'should specify the validation_steps argument.')
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ val_x = infeed_manager.dummy_x
+ val_y = infeed_manager.dummy_y
+ infeed_managers.append((val_x, infeed_manager))
+ validation_data = (val_x, val_y)
self._numpy_to_infeed_manager_list = infeed_managers
try:
if not kwargs.get('_pipeline', True):
- logging.info(
- 'Running non-pipelined training loop (`_pipeline=%s`).',
- kwargs['_pipeline'])
+ logging.info('Running non-pipelined training loop (`_pipeline=%s`).',
+ kwargs['_pipeline'])
kwargs.pop('_pipeline')
return super(KerasTPUModel, self).fit(
x,
@@ -1365,50 +1502,32 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(x):
- with self.tpu_session() as sess:
- dataset = x()
- if steps is None:
- raise ValueError('When using tf.data as input to a model, you '
- 'should specify the steps argument.')
- if y is not None:
- raise ValueError('When using tf.data as input to a model, y must be '
- 'None')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
- sess)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- x = infeed_manager.dummy_x
- y = infeed_manager.dummy_y
- infeed_managers.append((x, infeed_manager))
+ dataset = x()
+ if steps is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must be '
+ 'None')
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
self._numpy_to_infeed_manager_list = infeed_managers
try:
- return super(KerasTPUModel, self).evaluate(
- x,
- y,
- batch_size,
- verbose,
- sample_weight,
- steps)
+ return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose,
+ sample_weight, steps)
finally:
self._numpy_to_infeed_manager_list = []
- def _pipeline_fit(self,
- x,
- y,
- batch_size,
- epochs,
- verbose,
- callbacks,
- validation_split,
- validation_data,
- shuffle,
- class_weight,
- sample_weight,
- initial_epoch,
- steps_per_epoch,
- validation_steps,
- **kwargs):
+ def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks,
+ validation_split, validation_data, shuffle, class_weight,
+ sample_weight, initial_epoch, steps_per_epoch,
+ validation_steps, **kwargs):
# Similar to super.fit(...), but modified to support software pipelining.
# Backwards compatibility
@@ -1436,13 +1555,8 @@ class KerasTPUModel(models.Model):
# Prepare validation data
val_x, val_y, val_sample_weights = self._prepare_validation_data(
- validation_data,
- validation_split,
- validation_steps,
- x,
- y,
- sample_weights,
- batch_size)
+ validation_data, validation_split, validation_steps, x, y,
+ sample_weights, batch_size)
return self._pipeline_fit_loop(
x,
y,
@@ -1615,8 +1729,8 @@ class KerasTPUModel(models.Model):
for i in indices_for_conversion_to_dense:
ins_batch[i] = ins_batch[i].toarray()
- outs = f.pipeline_run(cur_step_inputs=ins_last_batch,
- next_step_inputs=ins_batch)
+ outs = f.pipeline_run(
+ cur_step_inputs=ins_last_batch, next_step_inputs=ins_batch)
ins_last_batch = ins_batch
if batch_index == 0:
@@ -1688,8 +1802,8 @@ class KerasTPUModel(models.Model):
next_step_inputs = ins
else:
next_step_inputs = None
- outs = f.pipeline_run(cur_step_inputs=ins,
- next_step_inputs=next_step_inputs)
+ outs = f.pipeline_run(
+ cur_step_inputs=ins, next_step_inputs=next_step_inputs)
except errors.OutOfRangeError:
logging.warning('Your dataset iterator ran out of data; '
'interrupting training. Make sure that your '
@@ -1709,25 +1823,21 @@ class KerasTPUModel(models.Model):
break
if do_validation:
- val_outs = training_arrays.test_loop(self,
- val_inputs,
- val_targets,
- sample_weights=val_sample_weights,
- steps=validation_steps,
- verbose=0)
+ val_outs = training_arrays.test_loop(
+ self,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ steps=validation_steps,
+ verbose=0)
if not isinstance(val_outs, list):
val_outs = [val_outs]
# Same labels assumed.
for l, o in zip(self.metrics_names, val_outs):
epoch_logs['val_' + l] = o
- def _prepare_validation_data(self,
- validation_data,
- validation_split,
- validation_steps,
- x,
- y,
- sample_weights,
+ def _prepare_validation_data(self, validation_data, validation_split,
+ validation_steps, x, y, sample_weights,
batch_size):
"""Prepares the validation dataset.
@@ -1785,8 +1895,10 @@ class KerasTPUModel(models.Model):
x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
- sample_weights, val_sample_weights = (slice_arrays(
- sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
+ sample_weights, val_sample_weights = (
+ slice_arrays(sample_weights, 0, split_at),
+ slice_arrays(sample_weights, split_at)
+ )
elif validation_steps:
val_x = []
val_y = []
@@ -1798,11 +1910,20 @@ class KerasTPUModel(models.Model):
return val_x, val_y, val_sample_weights
+ @property
+ def optimizer(self):
+ if self._tpu_model:
+ return self._tpu_model.optimizer
+ return self._cpu_model.optimizer
+
+ @optimizer.setter
+ def optimizer(self, optimizer):
+ self._optimizer = optimizer
+
def _make_train_function(self):
if not self.train_function:
self.train_function = TPUFunction(
- self,
- model_fn_lib.ModeKeys.TRAIN,
+ self, model_fn_lib.ModeKeys.TRAIN,
tpu_assignment=self._tpu_assignment)
return self.train_function
@@ -1837,18 +1958,48 @@ class KerasTPUModel(models.Model):
self._tpu_weights_initialized = True
weights = self._cpu_model.get_weights()
- with self.tpu_session():
- logging.info('Setting weights on TPU model.')
- cloned_model.set_weights(weights)
+
+ if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
+ cpu_optimizer_config = {}
+ else:
+ cpu_optimizer_config = self.cpu_optimizer.get_config()
+
+ logging.info('Setting weights on TPU model.')
+ cloned_model.set_weights(weights)
+ for k, v in six.iteritems(cpu_optimizer_config):
+ opt_var = getattr(self._tpu_model.optimizer, k)
+ if isinstance(opt_var, variables.Variable):
+ logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var))
+ K.get_session().run(opt_var.assign(v))
+ else:
+ logging.warning('Cannot update non-variable config: %s', k)
+
+ @property
+ def cpu_optimizer(self):
+ return self._cpu_model.optimizer
def sync_to_cpu(self):
"""Copy weights from the CPU, returning a synchronized CPU model."""
- if self._tpu_weights_initialized:
- with self.tpu_session():
- logging.info('Copying TPU weights to the CPU')
- tpu_weights = self._tpu_model.get_weights()
+ if not self._tpu_weights_initialized:
+ return self._cpu_model
+
+ logging.info('Copying TPU weights to the CPU')
+ tpu_weights = self._tpu_model.get_weights()
- self._cpu_model.set_weights(tpu_weights)
+ # TFOptimizers have no configurable options
+ if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
+ tpu_optimizer_config = {}
+ else:
+ tpu_optimizer_config = self._tpu_model.optimizer.get_config()
+
+ self._cpu_model.set_weights(tpu_weights)
+ for k, v in six.iteritems(tpu_optimizer_config):
+ logging.info('TPU -> CPU %s: %s', k, v)
+ opt_var = getattr(self.cpu_optimizer, k)
+ if isinstance(opt_var, variables.Variable):
+ K.get_session().run(opt_var.assign(v))
+ else:
+ logging.warning('Cannot update non-variable config: %s', k)
return self._cpu_model
@@ -1869,26 +2020,6 @@ class KerasTPUModel(models.Model):
self._cpu_model.set_weights(weights)
self._tpu_weights_initialized = False
- @contextlib.contextmanager
- def tpu_session(self):
- """Yields a TPU session and sets it as the default Keras session."""
- with self._session.graph.as_default():
- default_session = K.get_session()
- # N.B. We have to call `K.set_session()` AND set our session as the
- # TF default. `K.get_session()` surprisingly does not return the value
- # supplied by K.set_session otherwise.
- K.set_session(self._session)
- with self._session.as_default():
- yield self._session
- K.set_session(default_session)
-
- def shutdown(self):
- # TODO(b/111364423): Actually shut down the system.
- logging.info('Skipping shutting down TPU system.')
- # with self.tpu_session() as session:
- # session.run(tpu.shutdown_system())
- self._session.close()
-
# pylint: disable=bad-continuation
def _validate_shapes(model):
@@ -1929,7 +2060,9 @@ Output shape: %(output_shape)s
@experimental
def tpu_model(model, strategy=None):
- """Copy `model` along with weights to the TPU. Returns a TPU model.
+ """Copy `model` along with weights to the TPU.
+
+ Returns a TPU model.
Usage:
```
@@ -1944,21 +2077,16 @@ def tpu_model(model, strategy=None):
model.compile(
optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
...)
- model.shutdown()
```
Args:
- model: A `KerasTPUModel`.
+ model: A `tf.keras.Model` instance.
strategy: `TPUDistributionStrategy`. The strategy to use for replicating
- model across multiple TPU cores.
+ model across multiple TPU cores.
Returns:
A new `KerasTPUModel` instance.
"""
- # Force initialization of the CPU model.
- model.get_weights()
- model.reset_states()
-
_validate_shapes(model)
# TODO(xiejw): Validate TPU model. TPUModel only?
# TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
@@ -1972,4 +2100,34 @@ def tpu_model(model, strategy=None):
'`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. '
'Got: {}'.format(type(strategy)))
- return KerasTPUModel(cpu_model=model, strategy=strategy)
+ # If the model has already been initialized, grab the optimizer configuration
+ # and model weights before entering the TPU session.
+ if model.optimizer:
+ if (isinstance(model.optimizer, keras_optimizers.Optimizer) and not
+ isinstance(model.optimizer, keras_optimizers.TFOptimizer)):
+ optimizer_config = model.optimizer.get_config()
+ else:
+ optimizer_config = None
+ model_weights = model.get_weights()
+ else:
+ model_weights = None
+
+ setup_tpu_session(strategy._tpu_cluster_resolver)
+
+ # Force initialization of the CPU model in the TPU session.
+ cpu_model = models.clone_model(model)
+ if model.optimizer:
+ cpu_model.compile(
+ _clone_optimizer(model.optimizer, optimizer_config),
+ model.loss,
+ metrics_module.clone_metrics(model.metrics),
+ model.loss_weights,
+ model.sample_weight_mode,
+ metrics_module.clone_metrics(model.weighted_metrics),
+ )
+
+ if model_weights:
+ cpu_model.set_weights(model_weights)
+ cpu_model.reset_states()
+
+ return KerasTPUModel(cpu_model=cpu_model, strategy=strategy)
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
index 170977d8ab..004b1012e5 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -25,10 +25,15 @@ from __future__ import print_function
import contextlib
+import numpy as np
+
from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
@@ -73,7 +78,7 @@ class ReplicatedVariable(object):
if tpu_context is None:
return self._primary_var.handle
- return tpu_context.get_replicated_var_handle(self)
+ return tpu_context.get_replicated_var_handle(self._name, self._vars)
@contextlib.contextmanager
def _assign_dependencies(self):
@@ -285,3 +290,51 @@ def replicated_scope(num_replicas):
return variable_scope.variable_scope(
"", custom_getter=_replicated_variable_getter)
+
+
+@contextlib.contextmanager
+def replicated_variable_for_optimizer(num_replicas):
+ """Context manager for optimizer weights. Overrides K.variable."""
+ if num_replicas == 1:
+ yield
+ return
+
+ try:
+ old_v = backend.variable
+
+ def opt_variable(value, dtype=None, name=None, constraint=None):
+ """Instantiates a variable and returns it."""
+ if dtype is None:
+ dtype = backend.floatx()
+
+ variables = []
+ for i in range(num_replicas):
+ # Keras holds the variables in optimizer class instance , so the name
+ # does not matter here. ResourceVariable constructor will find a unique
+ # name (including name=None) for each replica.
+ with ops.device("device:TPU:{}".format(i)):
+ v = resource_variable_ops.ResourceVariable(
+ value,
+ dtype=dtypes_module.as_dtype(dtype),
+ name=name,
+ constraint=constraint)
+ variables.append(v)
+ name = "replicate_{}_{}".format("variable" if name is None else name,
+ ops.uid())
+ v = ReplicatedVariable(name, variables)
+
+ # pylint: disable=protected-access
+
+ if isinstance(value, np.ndarray):
+ v._keras_shape = value.shape
+ elif hasattr(value, "shape"):
+ v._keras_shape = backend.int_shape(value)
+ v._uses_learning_phase = False
+ backend.track_variable(v)
+ return v
+
+ backend.variable = opt_variable
+ yield
+
+ finally:
+ backend.variable = old_v
diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py
index 3e91e2df32..05264f5a46 100644
--- a/tensorflow/contrib/tpu/python/tpu/session_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/session_support.py
@@ -41,6 +41,29 @@ class CoordinatorShutdownException(Exception):
pass
+def _make_heartbeat_op(session, device, request_ph):
+ """Return a heartbeat op or None if heartbeats are not supported by device."""
+ try:
+ # Test if we can connect in a isolated graph + session
+ with ops.Graph().as_default():
+ with session_lib.Session(target=session.sess_str) as temp_session:
+ with ops.device(device):
+ heartbeat_op = tpu_ops.worker_heartbeat('')
+ options = config_pb2.RunOptions(timeout_in_ms=5000)
+ temp_session.run(heartbeat_op, options=options)
+ except errors.InvalidArgumentError as _:
+ logging.warning('Error running heartbeat on %s', device)
+ return None
+ except errors.DeadlineExceededError as _:
+ logging.warning('Timeout connecting to %s when testing heartbeat', device)
+ return None
+
+ # If we successfully connected and pinged the worker, go ahead and construct
+ # the operation.
+ with ops.device(device):
+ return tpu_ops.worker_heartbeat(request_ph)
+
+
class WorkerHeartbeatManager(object):
"""Manages the status/heartbeat monitor for a set of workers."""
@@ -72,30 +95,27 @@ class WorkerHeartbeatManager(object):
name='worker_heartbeat_request', dtype=dtypes.string)
heartbeat_ops = []
+ kept_devices = []
for device in devices:
- with ops.device(device):
- heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder))
+ heartbeat_op = _make_heartbeat_op(session, device, request_placeholder)
+ if heartbeat_op is not None:
+ kept_devices.append(device)
+ heartbeat_ops.append(heartbeat_op)
+ else:
+ logging.warning('Heartbeat support not available for %s', device)
- return WorkerHeartbeatManager(session, devices, heartbeat_ops,
+ return WorkerHeartbeatManager(session, kept_devices, heartbeat_ops,
request_placeholder)
- def heartbeat_supported(self):
- """Returns True if heartbeat operations are supported on all workers."""
- try:
- # Send ping to verify worker has heartbeat support.
- self.ping()
- return True
- except errors.InvalidArgumentError as _:
- return False
+ def num_workers(self):
+ return len(self._devices)
def configure(self, message):
"""Configure heartbeat manager for all devices.
Args:
message: `event_pb2.WorkerHeartbeatRequest`
-
Returns: `None`
-
"""
logging.info('Configuring worker heartbeat: %s',
text_format.MessageToString(message))
@@ -155,7 +175,7 @@ class WorkerHeartbeatManager(object):
def all_worker_devices(session):
"""Return a list of devices for each worker in the system."""
devices = session.list_devices()
- return [device.name for device in devices if 'CPU' in device.name]
+ return [device.name for device in devices if ':CPU:' in device.name]
class WatchdogManager(threading.Thread):
@@ -184,7 +204,6 @@ class WatchdogManager(threading.Thread):
"""Initialize a watchdog manager.
Args:
-
session: Session connected to worker devices. A cloned session and graph
will be created for managing worker pings.
devices: Set of devices to monitor. If none, all workers will be
@@ -277,16 +296,14 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
target=training_session.sess_str, graph=self._graph)
self._workers = WorkerHeartbeatManager.from_devices(
self._session, all_worker_devices(self._session))
- self._heartbeat_supported = self._workers.heartbeat_supported()
+ self._heartbeat_supported = self._workers.num_workers() > 0
if self._heartbeat_supported:
self._workers.configure(
event_pb2.WorkerHeartbeatRequest(
shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
else:
logging.warn(
- 'Worker heartbeats not supported by all workers. No failure '
- 'handling will be enabled.'
- )
+ 'No workers support hearbeats. Failure handling will be disabled.')
def saver(self):
if self._saver:
@@ -303,8 +320,7 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook):
logging.error(
'Multiple savers in the SAVERS collection. On-demand checkpointing '
'will be disabled. Pass an explicit `saver` to the constructor to '
- 'override this behavior.'
- )
+ 'override this behavior.')
return None
return savers[0]
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 712b02ff0d..11aaa1c66a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -155,19 +155,20 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._pivot = pivot
self._replicated_vars = {}
- def get_replicated_var_handle(self, var):
+ def get_replicated_var_handle(self, name, vars_):
"""Returns a variable handle for replicated TPU variable 'var'.
This is a method used by an experimental replicated variable implementation
and is not intended as a public API.
Args:
- var: The replicated TPU variable.
+ name: The common name of the variable.
+ vars_: The replicated TPU variables.
Returns:
The handle of the TPU replicated input node.
"""
- handle = self._replicated_vars.get(var)
+ handle = self._replicated_vars.get(name)
if handle is not None:
return handle
@@ -183,10 +184,10 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
saved_context = graph._get_control_flow_context()
graph._set_control_flow_context(self.outer_context)
handle = tpu_ops.tpu_replicated_input(
- [v.handle for v in var._vars], name=var.name + "/handle")
+ [v.handle for v in vars_], name=name + "/handle")
graph._set_control_flow_context(saved_context)
# pylint: enable=protected-access
- self._replicated_vars[var] = handle
+ self._replicated_vars[name] = handle
return handle
def report_unsupported_operations(self):
@@ -661,6 +662,10 @@ def split_compile_and_replicate(computation,
# be less confusing to clients if they knowingly choose to use resource
# variables.
# Partitioned variables is not supported (b/112311320).
+ vscope = variable_scope.get_variable_scope()
+ saved_use_resource = vscope.use_resource
+ saved_custom_getter = vscope.custom_getter
+
def custom_getter(getter, name, *args, **kwargs):
"""Variables on TPU have a few restrictions."""
partitioner = kwargs["partitioner"]
@@ -671,12 +676,10 @@ def split_compile_and_replicate(computation,
"`partitioner` that is {} for variable {}. "
"Setting `partitioner` to `None`."
.format(partitioner, name))
- return getter(name, *args, **kwargs)
-
- vscope = variable_scope.get_variable_scope()
-
- saved_use_resource = vscope.use_resource
- saved_custom_getter = vscope.custom_getter
+ if saved_custom_getter is None:
+ return getter(name, *args, **kwargs)
+ else:
+ return saved_custom_getter(getter, name, *args, **kwargs)
vscope.set_use_resource(True)
vscope.set_custom_getter(custom_getter)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index b1a8a16d1e..7cfb6c38fa 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -118,6 +118,11 @@ class TPUContext(object):
return self._internal_ctx.num_hosts
@property
+ def current_host(self):
+ """The current host index for the TPU system."""
+ return self._invocation_index
+
+ @property
def num_of_replicas_per_host(self):
"""The number of replicas for each host."""
if self._internal_ctx.model_parallelism_enabled:
@@ -698,7 +703,7 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size,
config.tpu_config.num_cores_per_replica is None):
logging.warning(
'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
- 'Please fix as soon as possible (leaving num_shards as None.')
+ 'Please fix as soon as possible (leaving num_shards as None.)')
return _OneCoreTPUContext(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 23c54511ca..3aa5b6efa1 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -231,7 +231,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
`metric_fn` runs on CPU to generate metrics and `tensors` represents the
`Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`.
To be precise, TPU evaluation expects a slightly different signature from the
- @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
+ `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
`tensors` usually specify the model logits, which are transferred back from
@@ -254,7 +254,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
sending tensors from TPU to CPU. To reduce the overhead, try reducing the
size of the tensors. The `tensors` are concatenated along their major (batch)
dimension, and so must be >= rank 1. The `host_call` is useful for writing
- summaries with @{tf.contrib.summary.create_file_writer}.
+ summaries with `tf.contrib.summary.create_file_writer`.
"""
def __new__(cls,
@@ -404,12 +404,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
self._feed_error = None
self._finished = False
+ self._should_initialize_tpu = True
def begin(self):
logging.info('TPU job name %s', self._master_job)
self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
- self._init_ops = [tpu.initialize_system(job=self._master_job)]
- self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
+ if self._should_initialize_tpu:
+ self._init_ops = [tpu.initialize_system(job=self._master_job)]
+ self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
+ else:
+ self._init_ops = []
+ self._finalize_ops = []
summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
self._init_ops.extend(summary_writer_init_ops)
@@ -421,10 +426,10 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
def _run_infeed(self, queue_ctx, session):
logging.info('Starting infeed thread controller.')
if self._initial_infeed_sleep_secs:
- logging.info('%s thread sleeping for %d seconds.', self._name,
+ logging.info('Infeed thread sleeping for %d seconds.',
self._initial_infeed_sleep_secs)
time.sleep(self._initial_infeed_sleep_secs)
- logging.info('%s thread starting after sleep', self._name)
+ logging.info('Infeed thread starting after sleep')
with self._rendezvous.catch_errors(source='infeed', session=session):
if self._run_infeed_loop_on_coordinator:
diff --git a/tensorflow/contrib/tpu/utils/BUILD b/tensorflow/contrib/tpu/utils/BUILD
new file mode 100644
index 0000000000..c27b737287
--- /dev/null
+++ b/tensorflow/contrib/tpu/utils/BUILD
@@ -0,0 +1,30 @@
+# Description: Utilities for TPU Operations
+
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "tpu_embedding_optimization_parameters_utils",
+ srcs = ["tpu_embedding_optimization_parameters_utils.cc"],
+ hdrs = ["tpu_embedding_optimization_parameters_utils.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/tpu/proto:optimization_parameters_proto_cc",
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:lib_proto_parsing",
+ "@com_google_absl//absl/base",
+ ],
+)
+
+cc_library(
+ name = "tpu_embedding_output_layout_utils",
+ srcs = ["tpu_embedding_output_layout_utils.cc"],
+ hdrs = ["tpu_embedding_output_layout_utils.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_output_layout_proto_cc",
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc
new file mode 100644
index 0000000000..76cb5531cd
--- /dev/null
+++ b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc
@@ -0,0 +1,255 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace tpu {
+
+string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) {
+ switch (alg) {
+ case OptimizationAlgorithm::kAdagrad:
+ return "Adagrad";
+ case OptimizationAlgorithm::kStochasticGradientDescent:
+ return "StochasticGradientDescent";
+ case OptimizationAlgorithm::kFtrl:
+ return "FTRL";
+ case OptimizationAlgorithm::kAdam:
+ return "ADAM";
+ case OptimizationAlgorithm::kMomentum:
+ return "Momentum";
+ case OptimizationAlgorithm::kRmsProp:
+ return "RMSProp";
+ case OptimizationAlgorithm::kCenteredRmsProp:
+ return "CenteredRMSProp";
+ case OptimizationAlgorithm::kMdlAdagradLight:
+ return "MDLAdagradLight";
+ case OptimizationAlgorithm::kAdadelta:
+ return "Adadelta";
+ case OptimizationAlgorithm::kProximalAdagrad:
+ return "ProximalAdagrad";
+ case OptimizationAlgorithm::PARAMETERS_NOT_SET:
+ return "*** Not set ***";
+ }
+}
+
+string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
+ switch (alg) {
+ case OptimizationAlgorithm::kAdagrad:
+ return "Adagrad";
+ case OptimizationAlgorithm::kStochasticGradientDescent:
+ return "stochastic gradient descent";
+ case OptimizationAlgorithm::kFtrl:
+ return "FTRL";
+ case OptimizationAlgorithm::kAdam:
+ return "ADAM";
+ case OptimizationAlgorithm::kMomentum:
+ return "Momentum";
+ case OptimizationAlgorithm::kRmsProp:
+ return "RMSProp";
+ case OptimizationAlgorithm::kCenteredRmsProp:
+ return "centered RMSProp";
+ case OptimizationAlgorithm::kMdlAdagradLight:
+ return "MDL Adagrad Light";
+ case OptimizationAlgorithm::kAdadelta:
+ return "Adadelta";
+ case OptimizationAlgorithm::kProximalAdagrad:
+ return "proximal Adagrad";
+ case OptimizationAlgorithm::PARAMETERS_NOT_SET:
+ return "unknown (not specified)";
+ }
+}
+
+// Returns the number of optimization parameter vectors used by the optimization
+// algorithm, excluding the weights themselves and assuming no gradient
+// accumulation.
+Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int* count) {
+ switch (alg) {
+ case OptimizationAlgorithm::kAdagrad:
+ *count = 1;
+ return Status::OK();
+ case OptimizationAlgorithm::kStochasticGradientDescent:
+ *count = 0;
+ return Status::OK();
+ case OptimizationAlgorithm::kFtrl:
+ *count = 2;
+ return Status::OK();
+ case OptimizationAlgorithm::kAdam:
+ *count = 2;
+ return Status::OK();
+ case OptimizationAlgorithm::kMomentum:
+ *count = 1;
+ return Status::OK();
+ case OptimizationAlgorithm::kRmsProp:
+ *count = 2;
+ return Status::OK();
+ case OptimizationAlgorithm::kCenteredRmsProp:
+ *count = 3;
+ return Status::OK();
+ case OptimizationAlgorithm::kMdlAdagradLight:
+ *count = 3;
+ return Status::OK();
+ case OptimizationAlgorithm::kAdadelta:
+ *count = 2;
+ return Status::OK();
+ case OptimizationAlgorithm::kProximalAdagrad:
+ *count = 1;
+ return Status::OK();
+ case OptimizationAlgorithm::PARAMETERS_NOT_SET:
+ return errors::InvalidArgument("No optimization algorithm specified");
+ }
+}
+
+Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+ GradientAccumulationSupport* support) {
+ switch (alg) {
+ case OptimizationAlgorithm::kAdagrad:
+ *support = GradientAccumulationSupport::kSupported;
+ return Status::OK();
+ case OptimizationAlgorithm::kStochasticGradientDescent:
+ *support = GradientAccumulationSupport::kUnnecessary;
+ return Status::OK();
+ default: {
+ int auxiliary_parameter_count;
+ TF_RETURN_IF_ERROR(
+ GetBaseAuxiliaryParameterCount(alg, &auxiliary_parameter_count));
+ *support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
+ ? GradientAccumulationSupport::kSupported
+ : GradientAccumulationSupport::kNotSupported;
+ return Status::OK();
+ }
+ }
+}
+namespace {
+// Make a normal state variable specification.
+StateVariableSpecification MakeStandardStateVariableSpecification(
+ const string& name) {
+ StateVariableSpecification result;
+ result.set_name(name);
+ result.mutable_user_defined();
+ return result;
+}
+} // namespace
+
+Status GetOptimizationAlgorithmStateVariables(
+ OptimizationAlgorithm alg, bool use_gradient_accumulation,
+ std::vector<StateVariableSpecification>* state_variables) {
+ // The first parameter set is always the weights themselves.
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("parameters"));
+ // The order of the returned parameters needs to match the offsets used by
+ // the algorithm implementations in test_util.cc and
+ // address_handler_program_creator.cc.
+ switch (alg) {
+ case OptimizationAlgorithm::kAdagrad: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("accumulators"));
+ break;
+ }
+ case OptimizationAlgorithm::kStochasticGradientDescent: {
+ // None.
+ break;
+ }
+ case OptimizationAlgorithm::kFtrl: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("accumulators"));
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("linears"));
+ break;
+ }
+ case OptimizationAlgorithm::kAdam: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("momenta"));
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("velocities"));
+ break;
+ }
+ case OptimizationAlgorithm::kMomentum: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("momenta"));
+ break;
+ }
+ case OptimizationAlgorithm::kRmsProp: {
+ state_variables->push_back(MakeStandardStateVariableSpecification("ms"));
+ state_variables->push_back(MakeStandardStateVariableSpecification("mom"));
+ break;
+ }
+ case OptimizationAlgorithm::kCenteredRmsProp: {
+ state_variables->push_back(MakeStandardStateVariableSpecification("ms"));
+ state_variables->push_back(MakeStandardStateVariableSpecification("mom"));
+ state_variables->push_back(MakeStandardStateVariableSpecification("mg"));
+ break;
+ }
+ case OptimizationAlgorithm::kMdlAdagradLight: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("accumulators"));
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("weights"));
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("benefits"));
+ break;
+ }
+ case OptimizationAlgorithm::kAdadelta: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("accumulators"));
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("updates"));
+ break;
+ }
+ case OptimizationAlgorithm::kProximalAdagrad: {
+ state_variables->push_back(
+ MakeStandardStateVariableSpecification("accumulators"));
+ break;
+ }
+ case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
+ return errors::InvalidArgument("No optimization algorithm specified");
+ }
+ }
+ // This needs to be last so that the save/restore ops do not need to know
+ // about gradient accumulation.
+ if (use_gradient_accumulation) {
+ StateVariableSpecification gradient_acc;
+ gradient_acc.set_name("gradient_accumulators");
+ gradient_acc.mutable_fill_with_constant()->set_initial_value(
+ kGradientAccumulatorInitialValue);
+ state_variables->push_back(std::move(gradient_acc));
+ }
+ if (state_variables->size() > kMaxAuxiliaryParameterCount + 1) {
+ return errors::InvalidArgument(
+ "Optimization algorithm", GetOptimizationAlgorithmName(alg),
+ "does not support gradient accumulation because it "
+ "already has too many other accumulators");
+ }
+ return Status::OK();
+} // namespace tpu
+
+std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
+ return {
+ OptimizationAlgorithm::kAdagrad,
+ OptimizationAlgorithm::kStochasticGradientDescent,
+ OptimizationAlgorithm::kFtrl,
+ OptimizationAlgorithm::kAdam,
+ OptimizationAlgorithm::kMomentum,
+ OptimizationAlgorithm::kRmsProp,
+ OptimizationAlgorithm::kCenteredRmsProp,
+ OptimizationAlgorithm::kMdlAdagradLight,
+ OptimizationAlgorithm::kAdadelta,
+ OptimizationAlgorithm::kProximalAdagrad,
+ };
+}
+
+} // namespace tpu
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h
new file mode 100644
index 0000000000..81d50264ed
--- /dev/null
+++ b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_
+#define TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_
+
+#include <string>
+#include "absl/base/casts.h"
+#include "tensorflow/contrib/tpu/proto/optimization_parameters.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace tpu {
+
+using OptimizationAlgorithm = OptimizationParameters::ParametersCase;
+
+// Returns the name of the optimization algorithm.
+string GetOptimizationAlgorithmName(OptimizationAlgorithm alg);
+
+// Returns a user-friendly name for the optimization algorithm.
+string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg);
+
+// Returns all supported optimization algorithms.
+std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms();
+
+enum class GradientAccumulationSupport {
+ // Accumulation cannot be used with this optimizer.
+ kNotSupported,
+
+ // Accumulation is unnecessary because optimizer application is commutative.
+ kUnnecessary,
+
+ // Accumulation is allowed and changes optimizer behavior.
+ kSupported,
+};
+
+// Returns the number of optimization parameter vectors used by the optimization
+// algorithm, excluding the weights themselves and assuming no gradient
+// accumulation.
+Status GetBaseAuxiliaryParameterCount(OptimizationAlgorithm alg, int *count);
+
+// Returns whether (and how) an optimization algorithm supports gradient
+// accumulation.
+Status GetGradientAccumulationSupport(OptimizationAlgorithm alg,
+ GradientAccumulationSupport *support);
+
+// Returns the parameter specifications for the optimization algorithm (the main
+// parameters first, followed by any auxiliary parameters such as Adagrad
+// accumulators).
+Status GetOptimizationAlgorithmStateVariables(
+ OptimizationAlgorithm alg, bool use_gradient_accumulation,
+ std::vector<StateVariableSpecification> *state_variables);
+
+// Maximum value of auxiliar_parameter_count for any optimization algorithm.
+static constexpr int kMaxAuxiliaryParameterCount = 3;
+
+// Fill value for gradient accumulators. This is a denormal so that it will be
+// flushed to zero on the current TPU platforms and needs to continue to have
+// the following properties in the future:
+//
+// 1. Does not have the same bit pattern as a zero and can be distinguished from
+// it using integer operations.
+// 2. Treated as zero by floating-point arithmetic operations (at least addition
+// and subtraction).
+// 3. Cannot be produced by any floating-point arithmetic operation, including
+// those involving itself.
+//
+// It does not need to compare equal or not equal to zero in floating point. We
+// need to use a non-zero value here because some optimization algorithms are
+// not no-ops on zero gradients, so we need to distinguish an accumulated
+// gradient of zero from one that has been cleared after its gradients have
+// already been applied to the parameters and accumulators.
+const float kGradientAccumulatorInitialValue = absl::bit_cast<float, uint32>(1);
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_
diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.cc b/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.cc
new file mode 100644
index 0000000000..8480ec4b8b
--- /dev/null
+++ b/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.cc
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h"
+#include "tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace tpu {
+
+void AddDefaultEmbeddingOutputLayoutIfNeeded(
+ TPUEmbeddingConfiguration* config) {
+ if (config->has_output_layout()) {
+ // Model or previous step has already filled this in.
+ return;
+ }
+
+ TPUEmbeddingOutputLayout* layout = config->mutable_output_layout();
+ // Create output tensors.
+ for (const auto& table : config->table_descriptor()) {
+ TPUEmbeddingOutputLayout::EmbeddingOutputTensor* output =
+ layout->add_output();
+ TPUEmbeddingOutputLayout::TwoDOutputTensor* two_d = output->mutable_two_d();
+ two_d->set_dim1_size(table.dimension());
+ two_d->set_dim0_size_per_sample(table.num_features());
+ }
+
+ // Create table output locations.
+ for (int table_id = 0; table_id < config->table_descriptor_size();
+ ++table_id) {
+ TPUEmbeddingOutputLayout::TableDescriptor* output_table =
+ layout->add_table();
+ const auto& table = config->table_descriptor(table_id);
+ for (int feature_index = 0; feature_index < table.num_features();
+ ++feature_index) {
+ TPUEmbeddingOutputLayout::FeatureDescriptor* output_feature =
+ output_table->add_feature();
+ TPUEmbeddingOutputLayout::OutputLocation* output_location =
+ output_feature->add_output_location();
+ output_location->set_tensor_index(table_id);
+ output_location->set_dim0_offset(feature_index);
+ output_location->set_dim1_offset(0);
+ }
+ }
+}
+
+Status ComputeOutputTensorShapes(const TPUEmbeddingConfiguration& config,
+ std::vector<TensorShapeProto>* shapes) {
+ if (!config.has_output_layout()) {
+ return errors::InvalidArgument(
+ "TPUEmbeddingConfiguration is missing output layout.");
+ }
+ const TPUEmbeddingOutputLayout& layout = config.output_layout();
+ int batch_size = config.batch_size_per_tensor_core();
+
+ for (int i = 0; i < layout.output_size(); ++i) {
+ const auto& output = layout.output(i);
+ TensorShapeProto shape;
+ switch (output.output_format_case()) {
+ case TPUEmbeddingOutputLayout::EmbeddingOutputTensor::OutputFormatCase::
+ kTwoD: {
+ auto* dim0 = shape.add_dim();
+ dim0->set_size(output.two_d().dim0_size_per_sample() * batch_size);
+ auto* dim1 = shape.add_dim();
+ dim1->set_size(output.two_d().dim1_size());
+ break;
+ }
+ case TPUEmbeddingOutputLayout::EmbeddingOutputTensor::OutputFormatCase::
+ OUTPUT_FORMAT_NOT_SET: {
+ return errors::InvalidArgument(
+ "Output layout in TPUEmbeddingConfiguration has unset embedding "
+ "output tensor format.");
+ }
+ default: {
+ return errors::InvalidArgument(
+ "Output layout in TPUEmbeddingConfiguration has invalid or "
+ "unhandled embedding output tensor format.");
+ }
+ }
+ shapes->push_back(shape);
+ }
+ return Status::OK();
+}
+
+} // namespace tpu
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h b/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h
new file mode 100644
index 0000000000..c10fbeeff2
--- /dev/null
+++ b/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OUTPUT_LAYOUT_UTILS_H_
+#define TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OUTPUT_LAYOUT_UTILS_H_
+
+#include "tensorflow/contrib/tpu/proto/tpu_embedding_configuration.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// Creates a default output layout for compatibility if none was provided by the
+// model.
+void AddDefaultEmbeddingOutputLayoutIfNeeded(TPUEmbeddingConfiguration* config);
+
+// Computes the shape of the output tensors from an output layout.
+Status ComputeOutputTensorShapes(
+ const TPUEmbeddingConfiguration& config,
+ std::vector<tensorflow::TensorShapeProto>* shapes);
+
+} // namespace tpu
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OUTPUT_LAYOUT_UTILS_H_
diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD
index ddf8365d61..b565ebd073 100644
--- a/tensorflow/contrib/training/BUILD
+++ b/tensorflow/contrib/training/BUILD
@@ -313,6 +313,5 @@ tf_proto_library(
name = "protos_all",
srcs = glob(["**/*.proto"]),
cc_api_version = 2,
- java_api_version = 2,
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/contrib/training/python/training/device_setter_test.py b/tensorflow/contrib/training/python/training/device_setter_test.py
index 20746d911c..3bb2dce83d 100644
--- a/tensorflow/contrib/training/python/training/device_setter_test.py
+++ b/tensorflow/contrib/training/python/training/device_setter_test.py
@@ -98,10 +98,10 @@ class GreedyLoadBalancingStrategyTest(test.TestCase):
cluster=_CLUSTER_SPEC,
ps_strategy=device_setter_lib.GreedyLoadBalancingStrategy(
2, device_setter_lib.byte_size_load_fn))):
- u = variables.Variable(array_ops.zeros([2, 2]))
- v = variables.Variable(array_ops.zeros([2, 1]))
- w = variables.Variable(array_ops.zeros([2, 2]))
- x = variables.Variable(array_ops.zeros([1, 3]))
+ u = variables.VariableV1(array_ops.zeros([2, 2]))
+ v = variables.VariableV1(array_ops.zeros([2, 1]))
+ w = variables.VariableV1(array_ops.zeros([2, 2]))
+ x = variables.VariableV1(array_ops.zeros([1, 3]))
a = v + w
self.assertDeviceEqual("/job:ps/task:0", u.device)
self.assertDeviceEqual("/job:ps/task:0", u.initializer.device)
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
index f46d03209c..8896a95327 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
@@ -29,7 +29,7 @@ from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import nest as tf_nest
-class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset):
+class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that prepends a queue to another `Dataset`.
A vector of handles to the queue is returned as the first component of
@@ -39,7 +39,7 @@ class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset):
def __init__(self, input_dataset, batch_size, padded_shapes, padding_values):
"""Initialize `PrependFromQueueAndPaddedBatchDataset`."""
- super(_PrependFromQueueAndPaddedBatchDataset, self).__init__()
+ super(_PrependFromQueueAndPaddedBatchDataset, self).__init__(input_dataset)
if sparse.any_sparse(input_dataset.output_classes):
raise TypeError(
"Batching of padded sparse tensors is not currently supported")
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 59b7dd04e9..57819cec70 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -144,10 +144,12 @@ load(
)
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
+ "if_dynamic_kernels",
"if_static",
"tf_cuda_tests_tags",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library")
load(
"//third_party/mkl:build_defs.bzl",
@@ -237,7 +239,6 @@ tf_proto_library(
srcs = [],
cc_api_version = 2,
default_header = True,
- java_api_version = 2,
js_api_version = 2,
protodeps = [
":protos_all_proto",
@@ -706,14 +707,11 @@ cc_library(
cc_library(
name = "feature_util",
srcs = ["example/feature_util.cc"],
- hdrs = [
- "example/feature_util.h",
- "platform/types.h",
- ],
+ hdrs = ["example/feature_util.h"],
visibility = ["//visibility:public"],
deps = [
":core_stringpiece",
- ":platform_protobuf",
+ ":lib_proto_parsing",
":protos_all_cc",
],
)
@@ -1040,6 +1038,7 @@ tf_gen_op_libs(
"dataset_ops",
"decode_proto_ops",
"encode_proto_ops",
+ "experimental_dataset_ops",
"function_ops",
"functional_ops",
"image_ops",
@@ -1056,7 +1055,6 @@ tf_gen_op_libs(
"random_grad",
"random_ops",
"remote_fused_graph_ops",
- "resource_variable_ops",
"rpc_ops",
"scoped_allocator_ops",
"sdca_ops",
@@ -1098,6 +1096,14 @@ tf_gen_op_libs(
deps = ["//tensorflow/core/kernels:debug_ops"],
)
+tf_gen_op_libs(
+ is_external = False,
+ op_lib_names = [
+ "resource_variable_ops",
+ ],
+ deps = [":lib"],
+)
+
# And one for all user ops
cc_library(
name = "user_ops_op_lib",
@@ -1163,6 +1169,7 @@ cc_library(
":dataset_ops_op_lib",
":decode_proto_ops_op_lib",
":encode_proto_ops_op_lib",
+ ":experimental_dataset_ops_op_lib",
":function_ops_op_lib",
":functional_ops_op_lib",
":image_ops_op_lib",
@@ -1292,8 +1299,8 @@ cc_library(
# This includes implementations of all kernels built into TensorFlow.
cc_library(
- name = "all_kernels",
- visibility = ["//visibility:public"],
+ name = "all_kernels_statically_linked",
+ visibility = ["//visibility:private"],
deps = [
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:audio",
@@ -1362,6 +1369,7 @@ cc_library(
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
+ "//tensorflow/core/kernels:mkl_slice_op",
"//tensorflow/core/kernels:mkl_softmax_op",
"//tensorflow/core/kernels:mkl_transpose_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
@@ -1372,6 +1380,15 @@ cc_library(
]),
)
+cc_library(
+ name = "all_kernels",
+ visibility = ["//visibility:public"],
+ deps = if_dynamic_kernels(
+ [],
+ otherwise = [":all_kernels_statically_linked"],
+ ),
+)
+
tf_cuda_library(
name = "tensorflow_opensource",
copts = tf_copts(),
@@ -2367,7 +2384,6 @@ tf_proto_library(
srcs = ERROR_CODES_PROTO_SRCS,
cc_api_version = 2,
default_header = True,
- java_api_version = 2,
js_api_version = 2,
provide_cc_alias = True,
)
@@ -2388,7 +2404,6 @@ tf_proto_library(
srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS,
cc_api_version = 2,
default_header = True,
- java_api_version = 2,
js_api_version = 2,
protodeps = [
":error_codes_proto",
@@ -2468,6 +2483,8 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"framework/op_segment.h",
"framework/rendezvous.h", # only needed for tests
"framework/resource_var.h",
+ "framework/run_handler.h",
+ "framework/run_handler_util.h",
"framework/tensor_reference.h",
"framework/tracking_allocator.h", # only needed for tests
"framework/unique_tensor_references.h",
@@ -2544,6 +2561,7 @@ tf_cuda_library(
"**/*test*",
"**/*main.cc",
"example/example_parser_configuration.*",
+ "example/feature_util.cc",
"util/reporter.cc",
"framework/fake_input.*",
"framework/op_gen_lib.*",
@@ -2573,6 +2591,7 @@ tf_cuda_library(
],
}),
deps = [
+ ":feature_util",
":lib",
":lib_internal",
":protos_all_proto_text",
@@ -2952,6 +2971,7 @@ tf_cuda_library(
":core_cpu_internal",
":device_tracer",
":framework",
+ ":framework_internal",
":graph",
":lib",
":lib_internal",
@@ -2989,7 +3009,7 @@ tf_cuda_library(
"platform/device_tracer.h",
],
copts = tf_copts(),
- cuda_deps = tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps(),
+ cuda_deps = if_cuda_is_configured(tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps()),
visibility = ["//visibility:private"],
deps = [
":core_cpu_internal",
@@ -3811,6 +3831,7 @@ tf_cc_test_mkl(
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
+ "//tensorflow/core/kernels:mkl_slice_op",
"//tensorflow/core/kernels:mkl_softmax_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
]),
@@ -4098,6 +4119,19 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "framework_run_handler_util_test",
+ size = "small",
+ srcs = ["framework/run_handler_util_test.cc"],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":framework_internal",
+ ":lib",
+ ":test",
+ ":test_main",
+ ],
+)
+
tf_cuda_cc_test(
name = "common_runtime_direct_session_test",
size = "small",
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt
new file mode 100644
index 0000000000..fa8fc96bb2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalAssertNextDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt
new file mode 100644
index 0000000000..5fd88e7a0c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalCSVDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt
new file mode 100644
index 0000000000..ac1f9719fe
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt
@@ -0,0 +1,21 @@
+op {
+ graph_op_name: "ExperimentalDirectedInterleaveDataset"
+ in_arg {
+ name: "selector_input_dataset"
+ description: <<END
+A dataset of scalar `DT_INT64` elements that determines which of the
+`N` data inputs should produce the next output element.
+END
+ }
+ in_arg {
+ name: "data_input_datasets"
+ description: <<END
+`N` datasets with the same type that will be interleaved according to
+the values of `selector_input_dataset`.
+END
+ }
+ summary: <<END
+A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt
new file mode 100644
index 0000000000..66511eff60
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt
@@ -0,0 +1,58 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResource"
+ in_arg {
+ name: "string_arg"
+ description: <<END
+String argument to the function call.
+END
+ }
+ in_arg {
+ name: "target_device"
+ description: <<END
+Target device to execute the function on.
+END
+ }
+ out_arg {
+ name: "resource"
+ description: <<END
+Handle to the resource created.
+END
+ }
+ attr {
+ name: "shared_name"
+ description: <<END
+If non-empty, this resource will be shared under the given name across
+multiple sessions.
+END
+ }
+ attr {
+ name: "container"
+ description: <<END
+If non-empty, this resource is placed in the given container.
+Otherwise, a default container is used.
+END
+ }
+ attr {
+ name: "f"
+ description: <<END
+Function to be executed.
+END
+ }
+ attr {
+ name: "buffer_size"
+ description: <<END
+Size of the buffer.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ summary: <<END
+Creates a resource that fills up a buffer by making function calls.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt
new file mode 100644
index 0000000000..bf4b66b22b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt
@@ -0,0 +1,25 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResourceGetNext"
+ in_arg {
+ name: "function_buffer_resource"
+ description: <<END
+The FunctionBufferingResource handle.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A list of return values.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ summary: <<END
+Gets the next element from a FunctionBufferingResource.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt
new file mode 100644
index 0000000000..729718ddb3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResourceReset"
+ in_arg {
+ name: "function_buffer_resource"
+ description: <<END
+The FunctionBufferingResource handle.
+END
+ }
+ summary: <<END
+Resets the FunctionBufferingResource.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt
new file mode 100644
index 0000000000..fe266c111f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIdentityIndexedDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt
new file mode 100644
index 0000000000..d42546516d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalIgnoreErrorsDataset"
+ summary: <<END
+Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt
new file mode 100644
index 0000000000..e285f87e10
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIndexedDatasetGet"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt
new file mode 100644
index 0000000000..60c32473b5
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIndexedDatasetMaterialize"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt
new file mode 100644
index 0000000000..b72b229e9a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalIteratorGetDevice"
+ summary: <<END
+Returns the name of the device on which `resource` has been placed.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt
new file mode 100644
index 0000000000..b38b23a51d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalLMDBDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt
new file mode 100644
index 0000000000..9676b9d284
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalMaterializedIndexDatasetHandle"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt
new file mode 100644
index 0000000000..d73b5bfda3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalThreadPoolDataset"
+ in_arg {
+ name: "thread_pool"
+ description: <<END
+A resource produced by the ThreadPoolHandle op.
+END
+ }
+ summary: <<END
+Creates a dataset that uses a custom thread pool to compute `input_dataset`.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt
new file mode 100644
index 0000000000..48bf93406c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt
@@ -0,0 +1,35 @@
+op {
+ graph_op_name: "ExperimentalThreadPoolHandle"
+ out_arg {
+ name: "handle"
+ description: <<END
+A resource that can be consumed by one or more ExperimentalThreadPoolDataset
+ops.
+END
+ }
+ attr {
+ name: "num_threads"
+ description: <<END
+The number of threads in the thread pool.
+END
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ description: <<END
+The maximum degree of parallelism to use within operations that execute on this
+threadpool.
+END
+ }
+ attr {
+ name: "display_name"
+ description: <<END
+A human-readable name for the threads that may be visible in some
+visualizations.
+threadpool.
+END
+ }
+ summary: <<END
+Creates a dataset that uses a custom thread pool to compute `input_dataset`.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt
new file mode 100644
index 0000000000..68ed797a0c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalUniqueDataset"
+ summary: <<END
+Creates a dataset that contains the unique elements of `input_dataset`.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
index 40d7d371ca..7142a0e3f2 100644
--- a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
@@ -9,7 +9,7 @@ The lower regularized incomplete Gamma function is defined as:
where
-\\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\)
+\\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\)
is the lower incomplete Gamma function.
diff --git a/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt
new file mode 100644
index 0000000000..08414b3e68
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt
@@ -0,0 +1,26 @@
+op {
+ visibility: HIDDEN
+ graph_op_name: "ReduceDataset"
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ in_arg {
+ name: "initial_state"
+ description: <<END
+A nested structure of tensors, representing the initial state of the
+transformation.
+END
+ }
+ attr {
+ name: "f"
+ description: <<END
+A function that maps `(old_state, input_element)` to `new_state`. It must take
+two arguments and return a nested structures of tensors. The structure of
+`new_state` must match the structure of `initial_state`.
+END
+ }
+ summary: "Reduces the input dataset to a singleton using a reduce function."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt
index cc21ddc815..7d2fbcd00b 100644
--- a/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt
@@ -1,5 +1,15 @@
op {
graph_op_name: "StringLength"
+ attr {
+ name: "unit"
+ description: <<END
+The unit that is counted to compute string length. One of: `"BYTE"` (for
+the number of bytes in each string) or `"UTF8_CHAR"` (for the number of UTF-8
+encoded Unicode code points in each string). Results are undefined
+if `unit=UTF8_CHAR` and the `input` strings do not contain structurally
+valid UTF-8.
+END
+ }
in_arg {
name: "input"
description: <<END
diff --git a/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt
new file mode 100644
index 0000000000..7898fe8d6b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt
@@ -0,0 +1,28 @@
+op {
+ graph_op_name: "UnicodeScript"
+ endpoint {
+ name: "UnicodeScript"
+ }
+ in_arg {
+ name: "input"
+ description: <<END
+A Tensor of int32 Unicode code points.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A Tensor of int32 script codes corresponding to each input code point.
+END
+ }
+ summary: <<END
+Determine the script codes of a given tensor of Unicode integer code points.
+END
+ description: <<END
+This operation converts Unicode code points to script codes corresponding to
+each code point. Script codes correspond to International Components for
+Unicode (ICU) UScriptCode values. See http://icu-project.org/apiref/icu4c/uscript_8h.html.
+Returns -1 (USCRIPT_INVALID_CODE) for invalid codepoints. Output shape will
+match input shape.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt b/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt
new file mode 100644
index 0000000000..ca107abc6b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "Xdivy"
+ summary: "Returns 0 if x == 0, and x / y otherwise, elementwise."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt b/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt
new file mode 100644
index 0000000000..da625f7836
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "Xlogy"
+ summary: "Returns 0 if x == 0, and x * log(y) otherwise, elementwise."
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
index 9552fc92e3..e395e333bf 100644
--- a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "BatchToSpaceND"
endpoint {
- name: "manip.batch_to_space_nd"
+ name: "batch_to_space_nd"
}
endpoint {
- name: "batch_to_space_nd"
+ name: "manip.batch_to_space_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
index 71257c8855..598f23bde3 100644
--- a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "GatherNd"
endpoint {
- name: "manip.gather_nd"
+ name: "gather_nd"
}
endpoint {
- name: "gather_nd"
+ name: "manip.gather_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
index c469665b66..b3d596de7a 100644
--- a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "Reshape"
endpoint {
- name: "manip.reshape"
+ name: "reshape"
}
endpoint {
- name: "reshape"
+ name: "manip.reshape"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
index 77f595927b..51478b7c34 100644
--- a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "ReverseV2"
endpoint {
- name: "manip.reverse"
+ name: "reverse"
}
endpoint {
- name: "reverse"
+ name: "manip.reverse"
deprecated: true
}
endpoint {
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
index a65a19b542..85888da45a 100644
--- a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "ScatterNd"
endpoint {
- name: "manip.scatter_nd"
+ name: "scatter_nd"
}
endpoint {
- name: "scatter_nd"
+ name: "manip.scatter_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
index af323a6cf3..146b97f444 100644
--- a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "SpaceToBatchND"
endpoint {
- name: "manip.space_to_batch_nd"
+ name: "space_to_batch_nd"
}
endpoint {
- name: "space_to_batch_nd"
+ name: "manip.space_to_batch_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt
index 01c02e1f70..df012414e3 100644
--- a/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt
@@ -1,6 +1,4 @@
op {
graph_op_name: "StringLength"
- endpoint {
- name: "strings.length"
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
index c34061c941..1d8695f1fd 100644
--- a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "Tile"
endpoint {
- name: "manip.tile"
+ name: "tile"
}
endpoint {
- name: "tile"
+ name: "manip.tile"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt
new file mode 100644
index 0000000000..a884a46143
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "UnicodeScript"
+ endpoint {
+ name: "strings.unicode_script"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt b/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt
new file mode 100644
index 0000000000..984442ba2b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "Xdivy"
+ endpoint {
+ name: "math.xdivy"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt b/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt
new file mode 100644
index 0000000000..b4a5299256
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "Xlogy"
+ endpoint {
+ name: "math.xlogy"
+ }
+}
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 99cb9ac6a0..419867ff58 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -470,19 +470,19 @@ bool ReplaceTensorWithConstant(
const ConstantFoldNameGenerator& generate_new_name) {
// Be conservative when replacing a tensor with a constant, when not
// running on CPU.
- // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
+ // 1) Do not replace another constant.
+ // 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
// constraint, do not replace it.
- // 2) If the destination tensor is an int32 tensor, but has DEVICE_MEMORY
- // constraint, do not replace it.
- // 3) If the constant op created does not have a kernel implementation
- // for the device, do not use it.
- // 4) If the size of the constant in bytes is too large (>
+ // 3) If the size of the constant in bytes is too large (>
// max_constant_in_bytes), do not replace it. This prevents the size of the
// Graph from growing too large.
+ // 4) If the constant op created does not have a kernel implementation
+ // for the device, do not use it.
// TODO(keveman): Consider adding a new constant op that has a kernel
// implementation for all types, but with HostMemory constraint on it's
// output.
- // 5) Do not replace another constant.
+ // 5) If the constant op for the device has different output memory type
+ // from the original op output memory type, do not replace it.
if (tensor.first->IsConstant()) {
return false;
}
@@ -497,8 +497,7 @@ bool ReplaceTensorWithConstant(
return false;
}
bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
- if ((memory_type == HOST_MEMORY && !is_int32) ||
- (memory_type == DEVICE_MEMORY && is_int32)) {
+ if (memory_type == HOST_MEMORY && !is_int32) {
return false;
}
}
@@ -536,6 +535,23 @@ bool ReplaceTensorWithConstant(
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false;
}
+ if (partition_device && device_type != DEVICE_CPU) {
+ MemoryType original_output_memory_type;
+ if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
+ &original_output_memory_type)
+ .ok()) {
+ return false;
+ }
+ MemoryType const_output_memory_type;
+ if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
+ &const_output_memory_type)
+ .ok()) {
+ return false;
+ }
+ if (original_output_memory_type != const_output_memory_type) {
+ return false;
+ }
+ }
for (auto edge : edges_to_remove) {
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index d800a86199..6e2eb66b94 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -61,26 +61,33 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
status_cb->Unref();
};
auto copier = std::bind(
- [dst, recv_dev_context, out_allocator, status_cb](
- StatusCallback wrapped_done_,
- // Begin unbound arguments
- const Tensor& from, Tensor* to) {
- if (!DMAHelper::CanUseDMA(&from)) {
- Status err = errors::InvalidArgument(
- "During Variant Host->Device Copy: "
- "non-DMA-copy attempted of tensor type: ",
- DataTypeString(from.dtype()));
- status_cb->UpdateStatus(err);
- return err;
- }
- if (status_cb->ok()) {
+ [dst, recv_dev_context, out_allocator, status_cb, cpu_allocator,
+ edge_name](StatusCallback wrapped_done_,
+ // Begin unbound arguments
+ const Tensor& from, Tensor* to) {
+ if (from.dtype() == DT_VARIANT) {
status_cb->Ref();
- *to = Tensor(out_allocator, from.dtype(), from.shape());
- recv_dev_context->CopyCPUTensorToDevice(&from, dst, to,
- wrapped_done_);
+ CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name,
+ dst, to, recv_dev_context, wrapped_done_);
return Status::OK();
} else {
- return status_cb->status();
+ if (!DMAHelper::CanUseDMA(&from)) {
+ Status err = errors::InvalidArgument(
+ "During Variant Host->Device Copy: "
+ "non-DMA-copy attempted of tensor type: ",
+ DataTypeString(from.dtype()));
+ status_cb->UpdateStatus(err);
+ return err;
+ }
+ if (status_cb->ok()) {
+ status_cb->Ref();
+ *to = Tensor(out_allocator, from.dtype(), from.shape());
+ recv_dev_context->CopyCPUTensorToDevice(&from, dst, to,
+ wrapped_done_);
+ return Status::OK();
+ } else {
+ return status_cb->status();
+ }
}
},
std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
@@ -119,26 +126,33 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
status_cb->Unref();
};
auto copier = std::bind(
- [edge_name, src, send_dev_context, out_allocator, status_cb](
- StatusCallback wrapped_done_,
- // Begin unbound arguments
- const Tensor& from, Tensor* to) {
- if (!DMAHelper::CanUseDMA(&from)) {
- Status err = errors::InvalidArgument(
- "During Variant Device->Host Copy: "
- "non-DMA-copy attempted of tensor type: ",
- DataTypeString(from.dtype()));
- status_cb->UpdateStatus(err);
- return err;
- }
- if (status_cb->ok()) {
+ [edge_name, src, send_dev_context, out_allocator, status_cb,
+ cpu_allocator](StatusCallback wrapped_done_,
+ // Begin unbound arguments
+ const Tensor& from, Tensor* to) {
+ if (from.dtype() == DT_VARIANT) {
status_cb->Ref();
- *to = Tensor(out_allocator, from.dtype(), from.shape());
- send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
- wrapped_done_);
+ CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name,
+ src, to, send_dev_context, wrapped_done_);
return Status::OK();
} else {
- return status_cb->status();
+ if (!DMAHelper::CanUseDMA(&from)) {
+ Status err = errors::InvalidArgument(
+ "During Variant Device->Host Copy: "
+ "non-DMA-copy attempted of tensor type: ",
+ DataTypeString(from.dtype()));
+ status_cb->UpdateStatus(err);
+ return err;
+ }
+ if (status_cb->ok()) {
+ status_cb->Ref();
+ *to = Tensor(out_allocator, from.dtype(), from.shape());
+ send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
+ wrapped_done_);
+ return Status::OK();
+ } else {
+ return status_cb->status();
+ }
}
},
std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index af5d5b17e7..458e133b68 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/run_handler.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -244,6 +245,21 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool,
#endif // __ANDROID__
}
+static RunHandlerPool* GetOrCreateRunHandlerPool(
+ const SessionOptions& options) {
+ static RunHandlerPool* pool =
+ new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options));
+ return pool;
+}
+
+bool DirectSession::ShouldUseRunHandlerPool() const {
+ if (options_.config.session_inter_op_thread_pool_size() > 0 ||
+ options_.config.use_per_session_threads()) {
+ return false;
+ }
+ return true;
+}
+
DirectSession::DirectSession(const SessionOptions& options,
const DeviceMgr* device_mgr,
DirectSessionFactory* const factory)
@@ -363,7 +379,7 @@ Status DirectSession::MaybeInitializeExecutionState(
Status DirectSession::Create(const GraphDef& graph) {
TF_RETURN_IF_ERROR(init_error_);
if (graph.node_size() > 0) {
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
if (graph_created_) {
return errors::AlreadyExists(
"A Graph has already been created for this session.");
@@ -375,7 +391,7 @@ Status DirectSession::Create(const GraphDef& graph) {
Status DirectSession::Extend(const GraphDef& graph) {
TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
return ExtendLocked(graph);
}
@@ -582,16 +598,37 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
}
}
- Executor::Args::Runner default_runner = [this,
- pool](Executor::Args::Closure c) {
- SchedClosure(pool, std::move(c));
- };
+ std::unique_ptr<RunHandler> handler;
+ if (ShouldUseRunHandlerPool() &&
+ run_options.experimental().use_run_handler_pool()) {
+ // Non-null only when a global inter-op pool is used.
+ VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
+ handler = GetOrCreateRunHandlerPool(options_)->Get();
+ }
+ auto* handler_ptr = handler.get();
+
+ Executor::Args::Runner default_runner = nullptr;
+
+ if (pool == nullptr) {
+ default_runner = [](Executor::Args::Closure c) { c(); };
+ } else if (handler_ptr != nullptr) {
+ default_runner = [handler_ptr](Executor::Args::Closure c) {
+ handler_ptr->ScheduleInterOpClosure(std::move(c));
+ };
+ } else {
+ default_runner = [this, pool](Executor::Args::Closure c) {
+ SchedClosure(pool, std::move(c));
+ };
+ }
+
for (const auto& item : executors_and_keys->items) {
- // TODO(zhengxq): support partial run.
- // TODO(zhengxq): if the device picks its own threadpool, we need to assign
+ // TODO(azaks): support partial run.
+ // TODO(azaks): if the device picks its own threadpool, we need to assign
// less threads to the main compute pool by default.
thread::ThreadPool* device_thread_pool =
item.device->tensorflow_device_thread_pool();
+ // TODO(crk): Investigate usage of RunHandlerPool when using device specific
+ // thread pool(s).
if (!device_thread_pool) {
args.runner = default_runner;
} else {
@@ -1172,7 +1209,7 @@ Status DirectSession::CreateExecutors(
int graph_def_version;
{
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
graph_def_version =
execution_state_->original_graph_def().versions().producer();
}
@@ -1400,7 +1437,7 @@ Status DirectSession::CreateGraphs(
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types,
DataTypeVector* output_types, int64* collective_graph_key) {
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
std::unique_ptr<ClientGraph> client_graph;
std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index c2cf3c7fd7..3a168bbe3f 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -215,7 +215,7 @@ class DirectSession : public Session {
// if not already initialized.
Status MaybeInitializeExecutionState(const GraphDef& graph,
bool* out_already_initialized)
- EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
+ EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);
// Retrieves an already existing set of executors to run 'inputs' and
// 'outputs', or creates and caches them for future use.
@@ -247,8 +247,11 @@ class DirectSession : public Session {
ExecutorsAndKeys* executors_and_keys,
RunMetadata* run_metadata);
+ // Returns whether inter-op execution uses a global pool.
+ bool ShouldUseRunHandlerPool() const;
+
::tensorflow::Status ExtendLocked(const GraphDef& graph)
- EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
+ EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);
::tensorflow::Status ResourceHandleToInputTensor(
const Tensor& resource_tensor, Tensor* retrieved_tensor);
@@ -289,7 +292,7 @@ class DirectSession : public Session {
}
::tensorflow::Status CheckGraphCreated(const char* method) {
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
if (!graph_created_) {
return errors::InvalidArgument(
"Session was not created with a graph before ", method, "!");
@@ -313,10 +316,8 @@ class DirectSession : public Session {
DeviceSet device_set_;
string session_handle_;
- bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
-
- mutex graph_def_lock_;
- GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
+ mutex graph_state_lock_;
+ bool graph_created_ GUARDED_BY(graph_state_lock_) = false;
// The thread-pools to use for running ops, with a bool indicating if the pool
// is owned.
@@ -367,11 +368,11 @@ class DirectSession : public Session {
// nodes can not be moved to a different device. Maps node names to
// device names.
std::unordered_map<string, string> stateful_placements_
- GUARDED_BY(graph_def_lock_);
+ GUARDED_BY(graph_state_lock_);
// Execution_state; used when placing the entire graph.
std::unique_ptr<GraphExecutionState> execution_state_
- GUARDED_BY(graph_def_lock_);
+ GUARDED_BY(graph_state_lock_);
// The function library, before any rewrites or optimizations have been
// performed. In particular, CreateGraphs() may need to modify the function
@@ -386,7 +387,7 @@ class DirectSession : public Session {
std::atomic<int64> edge_name_counter_ = {0};
std::atomic<int64> handle_name_counter_ = {0};
- // For generating step ids that are unique across all sessions.
+ // For generating step ids that are unique across this sessions.
static std::atomic_int_fast64_t step_id_counter_;
// Global timeout for all blocking operations in this session.
@@ -395,8 +396,6 @@ class DirectSession : public Session {
// Manages all the cost models for the graphs executed in this session.
CostModelManager cost_model_manager_;
- Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
-
// For testing collective graph key generation.
mutex collective_graph_key_lock_;
int64 collective_graph_key_ GUARDED_BY(collective_graph_key_lock_) = -1;
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 65e816c202..e3e431f800 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -625,6 +625,34 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) {
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
+TEST_F(DirectSessionMinusAXTest, UseRunHandlerPool) {
+ Initialize({3, 2, -1, 0});
+ auto session = CreateSession();
+ ASSERT_TRUE(session != nullptr);
+ TF_ASSERT_OK(session->Create(def_));
+ std::vector<std::pair<string, Tensor>> inputs;
+
+ // Request two targets: one fetch output and one non-fetched output.
+ std::vector<string> output_names = {y_ + ":0"};
+ std::vector<string> target_nodes = {y_neg_};
+ std::vector<Tensor> outputs;
+
+ // Prepares RunOptions and RunMetadata
+ RunOptions run_options;
+ run_options.mutable_experimental()->set_use_run_handler_pool(true);
+
+ Status s = session->Run(run_options, inputs, output_names, target_nodes,
+ &outputs, nullptr);
+ TF_ASSERT_OK(s);
+
+ ASSERT_EQ(1, outputs.size());
+ // The first output should be initialized and have the correct
+ // output.
+ auto mat = outputs[0].matrix<float>();
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ EXPECT_FLOAT_EQ(5.0, mat(0, 0));
+}
+
TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
GraphDef def;
Graph g(OpRegistry::Global());
diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
index 2ed4f69f90..2c63b8704e 100644
--- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
@@ -108,7 +108,7 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
EXPECT_EQ(2, shape.dim(0).size());
EXPECT_EQ(1, shape.dim(1).size());
if (node->name() == y->name()) {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// if MKL is used, it goes through various additional
// graph rewrite pass. In TF, everytime a graph pass
// happens, "constant" nodes are allocated
@@ -117,16 +117,16 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
// which increments the value of AllocationId.
// Thus AllocationId becomes more than TF if MKL
// is used. Now IDs for MKL are 8 more than TF.
- EXPECT_EQ(29, cm->AllocationId(node, 0));
-#else
EXPECT_EQ(21, cm->AllocationId(node, 0));
-#endif
- } else {
-#ifdef INTEL_MKL
- EXPECT_EQ(30, cm->AllocationId(node, 0));
#else
+ EXPECT_EQ(13, cm->AllocationId(node, 0));
+#endif // INTEL_MKL && ENABLE_MKL
+ } else {
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
EXPECT_EQ(22, cm->AllocationId(node, 0));
-#endif
+#else
+ EXPECT_EQ(14, cm->AllocationId(node, 0));
+#endif // INTEL_MKL && ENABLE_MKL
}
}
EXPECT_LE(0, cm->MaxExecutionTime(node));
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index 6cd4fd22ea..34bf73972f 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -97,12 +97,6 @@ class Executor {
typedef std::function<void()> Closure;
typedef std::function<void(Closure)> Runner;
Runner runner = nullptr;
-
- // A callback that is invoked each time a node has finished executing.
- typedef std::function<Status(const string& node_name, const int output_slot,
- const Tensor* tensor, const bool is_ref,
- OpKernelContext* ctx)>
- NodeOutputsCallback;
};
typedef std::function<void(const Status&)> DoneCallback;
virtual void RunAsync(const Args& args, DoneCallback done) = 0;
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index 96ecfb41d4..37a979a8f1 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -38,7 +38,8 @@ void GraphOptimizer::Optimize(
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
- const std::function<bool(const Node*)>& cse_consider_fn) {
+ const std::function<bool(const Node*)>& cse_consider_fn,
+ const std::function<bool(const Node*)>& cf_consider_fn) {
Graph* g = graph->get();
DumpGraph("Initial", g);
@@ -62,6 +63,7 @@ void GraphOptimizer::Optimize(
if (opts_.do_constant_folding()) {
ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map;
+ cf_opts.consider = cf_consider_fn;
if (opts_.max_folded_constant_in_bytes() > 0) {
cf_opts.max_constant_size_in_bytes =
opts_.max_folded_constant_in_bytes();
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 80246281cd..789cc56942 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -45,12 +45,15 @@ class GraphOptimizer {
//
// If cse_consider_fn is not null then only nodes for which cse_consider_fn
// returns true will be considered for CSE.
+ // If cf_consider_fn is not null then only nodes for which cf_consider_fn
+ // returns true will be considered for CF.
void Optimize(
FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
- const std::function<bool(const Node*)>& cse_consider_fn = nullptr);
+ const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
+ const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
const OptimizerOptions& options() { return opts_; }
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 538a70668a..429b19599b 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -251,6 +251,7 @@ class MklCPUAllocator : public Allocator {
// max_alloc_size from large_size_allocator would be the maximum
// size allocated by MklCPUAllocator.
stats->max_alloc_size = l_stats.max_alloc_size;
+ stats->bytes_limit = std::max(s_stats.bytes_limit, l_stats.bytes_limit);
}
void ClearStats() override {
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
index a67411cd2e..e08ab57638 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/common_runtime/mkl_cpu_allocator.h"
@@ -50,4 +50,4 @@ TEST(MKLBFCAllocatorTest, TestMaxLimit) {
} // namespace tensorflow
-#endif // INTEL_MKL
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index a81f8650bf..b1fe928ba7 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -41,6 +41,16 @@ limitations under the License.
// Set true for greater intelligibility of debug mode log messages.
#define READABLE_KEYS false
+// RingReduce algorithm exchanges chunks of tensor between devices. The chunk
+// size depends on the number of subdivisions specified in the algorithm. If
+// the user does not specify the number of subdivisions, we infer the number
+// dynamically so that the resulting chunk size does not exceed
+// kMaxChunkSizeBytes, empirically set at 4 MiB.
+constexpr size_t kMaxChunkSizeBytes = (4 * 1024 * 1024);
+// kMaxSubdivsPerDev is used to give an upper bound on the number of
+// subdivisions dynamically generated. A reasonable value would be a small
+// multiple of the number of NICs adjacent to each device.
+constexpr int kMaxSubdivsPerDevice = 2;
namespace tensorflow {
namespace {
@@ -92,7 +102,62 @@ RingReducer::RingReducer()
RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); }
+Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
+ if (col_params->instance.shape.num_elements() == 0) {
+ return errors::Internal("shape in CollectiveParams should be non-empty");
+ }
+ const int kAvgDevPerTask =
+ col_params->group.group_size / col_params->group.num_tasks;
+ const int kMaxNumSubdivs = kMaxSubdivsPerDevice * kAvgDevPerTask;
+ if (kMaxNumSubdivs <= 0) {
+ return errors::Internal("Unexpected kMaxNumSubdivs ", kMaxNumSubdivs,
+ " in RingReducer");
+ }
+ // NOTE(ayushd): If no subdiv_offsets have been specified, dynamically add
+ // as many offsets as needed so that the size of tensor chunks <=
+ // kMaxChunkSizeBytes. Empirically, chunks that are too small or too large
+ // lead to worse performance.
+ int num_subdivs = 0;
+ const size_t tensor_size = col_params->instance.shape.num_elements() *
+ DataTypeSize(col_params->instance.data_type);
+ size_t chunk_size;
+ do {
+ ++num_subdivs;
+ int num_chunks = col_params->group.group_size * num_subdivs;
+ chunk_size = tensor_size / num_chunks;
+ VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks
+ << " chunk_size " << chunk_size;
+ } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs);
+ if (num_subdivs <= 0) {
+ return errors::Internal("Unexpected num_subdivs ", num_subdivs,
+ " in RingReducer");
+ }
+
+ int subdiv_stride = kAvgDevPerTask / num_subdivs;
+ if (subdiv_stride == 0) subdiv_stride = 1;
+ col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs);
+ for (int sdi = 0; sdi < num_subdivs; ++sdi) {
+ int subdiv_offset = subdiv_stride * sdi;
+ if (sdi % 2 == 1) subdiv_offset *= -1;
+ col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset);
+ }
+
+ if (VLOG_IS_ON(2)) {
+ string subdiv_buf;
+ for (const int subdiv_offset :
+ col_params->instance.impl_details.subdiv_offsets) {
+ strings::StrAppend(&subdiv_buf, " ", subdiv_offset);
+ }
+ VLOG(2) << "Dynamically generated " << num_subdivs
+ << " subdiv_offsets:" << subdiv_buf << " tensor_size "
+ << tensor_size << " chunk_size " << chunk_size;
+ }
+
+ return Status::OK();
+}
+
Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
+ // TODO(b/113171733): change CHECKs to return errors.
CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE);
CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce");
const string& device_name =
@@ -123,12 +188,11 @@ Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
dev_per_task.push_back(dev_count);
CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
- // Generate a ring permutation for each requested offset.
if (col_params->instance.impl_details.subdiv_offsets.empty()) {
- return errors::Internal(
- "Subdiv offsets should be non-empty for ring reducer, size=",
- col_params->instance.impl_details.subdiv_offsets.size());
+ TF_RETURN_IF_ERROR(GenerateSubdivsInCollectiveParams(col_params));
}
+
+ // Generate a ring permutation for requested offset.
VLOG(2) << "Setting up perms for col_params " << col_params
<< " subdiv_permutations "
<< &col_params->instance.impl_details.subdiv_permutations;
@@ -646,7 +710,8 @@ bool RingReducer::RunAsyncParts() {
case RF_SEND:
--send_pending_count;
break;
- default: {} // Ignore any other actions
+ default: {
+ } // Ignore any other actions
}
}
}
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index 28df85399e..75aba43572 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -549,37 +549,38 @@ class RingReducerTest : public ::testing::Test {
int32 reduce_counter_ GUARDED_BY(mu_) = 0;
};
-TEST_F(RingReducerTest, InitializeParams) {
- static const int kNumDevsPerTask = 8;
- static const int kNumTasks = 3;
- static const int kNumDevs = kNumDevsPerTask * kNumTasks;
+CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
+ const int num_tasks) {
CollectiveParams cp;
- std::vector<string> device_names;
- std::vector<string> task_names;
+ const int kNumDevs = num_devs_per_task * num_tasks;
cp.group.group_key = 1;
cp.group.group_size = kNumDevs;
cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = kNumTasks;
+ cp.group.num_tasks = num_tasks;
cp.instance.instance_key = 3;
cp.instance.type = REDUCTION_COLLECTIVE;
cp.instance.data_type = DataType(DT_FLOAT);
- cp.instance.shape = TensorShape({5});
+ cp.instance.shape = TensorShape({kNumDevs});
cp.instance.impl_details.collective_name = "RingReduce";
cp.instance.impl_details.subdiv_offsets.push_back(0);
cp.is_source = false;
for (int i = 0; i < kNumDevs; ++i) {
- int task_id = i / kNumDevsPerTask;
- int dev_id = i % kNumDevsPerTask;
+ int task_id = i / num_devs_per_task;
+ int dev_id = i % num_devs_per_task;
string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
- task_names.push_back(task_name);
string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
- device_names.push_back(device_name);
cp.instance.task_names.push_back(task_name);
cp.instance.device_names.push_back(device_name);
}
+ return cp;
+}
- int test_rank = 0;
- cp.default_rank = test_rank;
+TEST_F(RingReducerTest, InitializeParams) {
+ const int kNumDevsPerTask = 8;
+ const int kNumTasks = 3;
+ CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
+
+ cp.default_rank = 0;
cp.instance.impl_details.subdiv_offsets = {0, 4};
RunSubdivPermsTest(&cp,
{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
@@ -588,8 +589,15 @@ TEST_F(RingReducerTest, InitializeParams) {
8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}},
{0, 4});
- test_rank = 3;
- cp.default_rank = test_rank;
+ cp.instance.impl_details.subdiv_offsets = {0, -4};
+ RunSubdivPermsTest(&cp,
+ {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+ {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8,
+ 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
+ {0, 3});
+
+ cp.default_rank = 3;
cp.instance.impl_details.subdiv_offsets = {3, -3};
RunSubdivPermsTest(&cp,
{{3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14,
@@ -599,6 +607,49 @@ TEST_F(RingReducerTest, InitializeParams) {
{0, 1});
}
+TEST_F(RingReducerTest, AutomaticSubdivs) {
+ const int kNumDevsPerTask = 8;
+ const int kNumTasks = 3;
+ const int kNumDevs = kNumDevsPerTask * kNumTasks;
+ CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
+
+ // Test automatic generation of subdiv offsets.
+ cp.default_rank = 0;
+ cp.instance.impl_details.subdiv_offsets.clear();
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
+ {0});
+
+ // Set shape so that with 2 subdivs chunk_size is 3 MiB. This should cause 2
+ // offsets, {0, -4}, to be generated.
+ {
+ int num_subdivs = 2;
+ int num_chunks = kNumDevs * num_subdivs;
+ size_t chunk_size = 3 * 1048576; // 3 MB
+ size_t tensor_size = chunk_size * num_chunks;
+ cp.instance.shape =
+ TensorShape({static_cast<int64>(tensor_size / DataTypeSize(DT_FLOAT))});
+ }
+ cp.instance.impl_details.subdiv_offsets.clear();
+ RunSubdivPermsTest(&cp,
+ {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+ {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8,
+ 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
+ {0, 3});
+}
+
+TEST_F(RingReducerTest, AutomaticSubdivUpperBound) {
+ const int kNumDevsPerTask = 1;
+ const int kNumTasks = 4;
+ CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
+
+ cp.default_rank = 0;
+ cp.instance.impl_details.subdiv_offsets.clear();
+ cp.instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)});
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0});
+}
+
// TODO(b/113171733): change to use TEST_P.
#define DEF_TEST(B, T, W, D, S, L, A) \
TEST_F(RingReducerTest, \
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index 0fbc20b34b..8587d1783a 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -113,8 +113,11 @@ class MklCPUAllocatorFactory : public AllocatorFactory {
}
};
+#ifdef ENABLE_MKL
REGISTER_MEM_ALLOCATOR("MklCPUAllocator", 200, MklCPUAllocatorFactory);
+#endif // ENABLE_MKL
+
} // namespace
-#endif
+#endif // INTEL_MKL
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index f7a2967d00..3361819e43 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -475,10 +475,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
delete step_container;
});
Executor::Args args;
- {
- mutex_lock l(mu_);
- args.step_id = ++next_id_;
- }
+ args.step_id = step_id;
args.rendezvous = rendezvous;
args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
args.cancellation_manager = cancellation_manager;
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 20a07d86a2..50403b4004 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1306,6 +1306,113 @@ Status RandomShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
+namespace {
+
+// This SliceHelper processes the output shape of the `slice`
+// when the tensor of `sizes` is available.
+template <typename T>
+Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
+ const Tensor* sizes_value,
+ std::vector<DimensionHandle>* dims) {
+ auto sizes_vec = sizes_value->vec<T>();
+ for (int i = 0; i < sizes_value->NumElements(); ++i) {
+ DimensionHandle dim = c->Dim(c->input(0), i);
+ if (sizes_vec(i) != -1) {
+ auto dim_val = c->Value(dim);
+ if (sizes_vec(i) < 0) {
+ return errors::InvalidArgument(
+ "Out of bounds slicing on dimension ", i, " of length ", dim_val,
+ ": sizes vector cannot be < -1, but was ", sizes_vec(i));
+ }
+
+ dims->emplace_back(c->MakeDim(sizes_vec(i)));
+ } else {
+ DimensionHandle result;
+ TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
+ dims->emplace_back(result);
+ }
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+Status SliceShape(InferenceContext* c) {
+ ShapeHandle input = c->input(0);
+ ShapeHandle begin_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
+ ShapeHandle sizes_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
+
+ // Merge to check compatibility of begin and sizes tensors.
+ TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
+
+ DimensionHandle ndims = c->Dim(begin_shape, 0);
+ if (c->ValueKnown(ndims)) {
+ TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
+ }
+
+ // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
+ // values, even though the `begin` value does not represent a shape.
+ ShapeHandle begin_value;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
+
+ // We check the tensor value here and will only use
+ // `MakeShapeFromShapeTensor` when `sizes_value` is null.
+ // The reason is that `sizes` might contain -1, which can't
+ // be represented (-1 in the ShapeHandle would mean "unknown").
+ const Tensor* sizes_value = c->input_tensor(2);
+
+ if (sizes_value != nullptr) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
+ std::vector<DimensionHandle> dims;
+ // If the begin and sizes tensors are available, then
+ // we can be precise about the shape of the output.
+ if (sizes_value->dtype() == DT_INT64) {
+ TF_RETURN_IF_ERROR(
+ SliceHelper<int64>(c, begin_value, sizes_value, &dims));
+ } else {
+ TF_RETURN_IF_ERROR(
+ SliceHelper<int32>(c, begin_value, sizes_value, &dims));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ } else {
+ // In case `sizes` is not available (`sizes_value` is null),
+ // we could try to use `MakeShapeFromShapeTensor` here.
+ // If sizes contain -1, we will simply consider it as `Unknown`.
+ // This is less than ideal but still an improvement of shape inference.
+ // The following is an example that returns [None, 1, None] with this
+ // code path:
+ // z = tf.zeros((1, 2, 3))
+ // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
+ // m.get_shape().as_list()
+ ShapeHandle sizes_value;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
+ if (c->RankKnown(sizes_value)) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
+ std::vector<DimensionHandle> dims;
+ dims.reserve(c->Rank(sizes_value));
+ for (int i = 0; i < c->Rank(sizes_value); ++i) {
+ dims.emplace_back(c->Dim(sizes_value, i));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ }
+ // We might know the rank of the input.
+ if (c->RankKnown(input)) {
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
+ return Status::OK();
+ } else {
+ return shape_inference::UnknownShape(c);
+ }
+ }
+
+ return Status::OK();
+}
+
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
ShapeHandle values_shape, ShapeHandle shape_shape) {
// Validate ranks.
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index e6f9f935f9..3a496e06ae 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -293,6 +293,9 @@ inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
// Shape function for random operations.
Status RandomShape(shape_inference::InferenceContext* c);
+// Shape function for Slice opertaions.
+Status SliceShape(shape_inference::InferenceContext* c);
+
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index a17959a448..20f957190b 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1101,6 +1101,14 @@ Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
return Status::OK();
}
+Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
+ mutex_lock l(mu_);
+ bool added;
+ TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
+ TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
+ return Status::OK();
+}
+
Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
const auto& i = function_defs_.find(func);
if (i == function_defs_.end()) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index e01eb7503d..4d6d68e214 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -331,6 +331,11 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
// a non-OK status if "func" was not found in the library, OK otherwise.
Status ReplaceFunction(const string& func, const FunctionDef& fdef);
+ // Replaces the gradient corresponding to `grad.function_name()`. Returns
+ // a non-OK status if "grad.function_name()" was not found in the library, OK
+ // otherwise.
+ Status ReplaceGradient(const GradientDef& grad);
+
// Adds the functions and gradients in 'other' to this function library.
// Duplicate functions and gradients are ignored.
// This operation is atomic.
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index d5c203d276..0445c242e9 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -93,7 +93,6 @@ FunctionDef IsZero() {
FunctionDef RandomUniform() {
const Tensor kZero = test::AsScalar<int64>(0);
- const Tensor kTen = test::AsScalar<int64>(10);
return FDH::Define(
// Name
@@ -108,19 +107,11 @@ FunctionDef RandomUniform() {
"Const",
{},
{{"value", kZero}, {"dtype", DT_INT64}}},
- {{"random_uniform/min"},
- "Const",
- {},
- {{"value", kZero}, {"dtype", DT_INT64}}},
- {{"random_uniform/max"},
- "Const",
- {},
- {{"value", kTen}, {"dtype", DT_INT64}}},
{{"random_uniform"},
- "RandomUniformInt",
- {},
- {{"T", DT_INT64},
- {"Tout", DT_INT64},
+ "RandomUniform",
+ {"random_uniform/shape"},
+ {{"T", DT_INT32},
+ {"Tout", DT_FLOAT},
{"seed", 87654321},
{"seed2", 42}}}});
}
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 187bfa2c88..0ff67554eb 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
#include <string>
-#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index 25f8de8dcc..81ed5f95f0 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -209,16 +209,16 @@ template <>
class OpDefBuilderWrapper<true> {
public:
OpDefBuilderWrapper(const char name[]) : builder_(name) {}
- OpDefBuilderWrapper<true>& Attr(StringPiece spec) {
- builder_.Attr(spec);
+ OpDefBuilderWrapper<true>& Attr(string spec) {
+ builder_.Attr(std::move(spec));
return *this;
}
- OpDefBuilderWrapper<true>& Input(StringPiece spec) {
- builder_.Input(spec);
+ OpDefBuilderWrapper<true>& Input(string spec) {
+ builder_.Input(std::move(spec));
return *this;
}
- OpDefBuilderWrapper<true>& Output(StringPiece spec) {
- builder_.Output(spec);
+ OpDefBuilderWrapper<true>& Output(string spec) {
+ builder_.Output(std::move(spec));
return *this;
}
OpDefBuilderWrapper<true>& SetIsCommutative() {
@@ -237,12 +237,12 @@ class OpDefBuilderWrapper<true> {
builder_.SetAllowsUninitializedInput();
return *this;
}
- OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) {
- builder_.Deprecated(version, explanation);
+ OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) {
+ builder_.Deprecated(version, std::move(explanation));
return *this;
}
- OpDefBuilderWrapper<true>& Doc(StringPiece text) {
- builder_.Doc(text);
+ OpDefBuilderWrapper<true>& Doc(string text) {
+ builder_.Doc(std::move(text));
return *this;
}
OpDefBuilderWrapper<true>& SetShapeFn(
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 34a7a43d38..8a9bb63182 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -526,32 +526,32 @@ void FinalizeDoc(const string& text, OpDef* op_def,
} // namespace
-OpDefBuilder::OpDefBuilder(StringPiece op_name) {
- op_def()->set_name(string(op_name)); // NOLINT
+OpDefBuilder::OpDefBuilder(string op_name) {
+ op_def()->set_name(std::move(op_name));
}
-OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) {
- attrs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Attr(string spec) {
+ attrs_.push_back(std::move(spec));
return *this;
}
-OpDefBuilder& OpDefBuilder::Input(StringPiece spec) {
- inputs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Input(string spec) {
+ inputs_.push_back(std::move(spec));
return *this;
}
-OpDefBuilder& OpDefBuilder::Output(StringPiece spec) {
- outputs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Output(string spec) {
+ outputs_.push_back(std::move(spec));
return *this;
}
#ifndef TF_LEAN_BINARY
-OpDefBuilder& OpDefBuilder::Doc(StringPiece text) {
+OpDefBuilder& OpDefBuilder::Doc(string text) {
if (!doc_.empty()) {
errors_.push_back(
strings::StrCat("Extra call to Doc() for Op ", op_def()->name()));
} else {
- doc_.assign(text.data(), text.size());
+ doc_ = std::move(text);
}
return *this;
}
@@ -577,14 +577,14 @@ OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
return *this;
}
-OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) {
+OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) {
if (op_def()->has_deprecation()) {
errors_.push_back(
strings::StrCat("Deprecated called twice for Op ", op_def()->name()));
} else {
OpDeprecation* deprecation = op_def()->mutable_deprecation();
deprecation->set_version(version);
- deprecation->set_explanation(string(explanation));
+ deprecation->set_explanation(std::move(explanation));
}
return *this;
}
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h
index 0b39d6e848..8077b20598 100644
--- a/tensorflow/core/framework/op_def_builder.h
+++ b/tensorflow/core/framework/op_def_builder.h
@@ -51,7 +51,7 @@ struct OpRegistrationData {
class OpDefBuilder {
public:
// Constructs an OpDef with just the name field set.
- explicit OpDefBuilder(StringPiece op_name);
+ explicit OpDefBuilder(string op_name);
// Adds an attr to this OpDefBuilder (and returns *this). The spec has
// format "<name>:<type>" or "<name>:<type>=<default>"
@@ -84,7 +84,7 @@ class OpDefBuilder {
// * Ability to restrict the type of the tensor like the existing
// restrictions for type attrs.
// Perhaps by linking the type of the tensor to a type attr?
- OpDefBuilder& Attr(StringPiece spec);
+ OpDefBuilder& Attr(string spec);
// Adds an input or output to this OpDefBuilder (and returns *this).
// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
@@ -101,8 +101,8 @@ class OpDefBuilder {
// in the spec?
// TODO(josh11b): SparseInput() and SparseOutput() matching the Python
// handling?
- OpDefBuilder& Input(StringPiece spec);
- OpDefBuilder& Output(StringPiece spec);
+ OpDefBuilder& Input(string spec);
+ OpDefBuilder& Output(string spec);
// Turns on the indicated boolean flag in this OpDefBuilder (and
// returns *this).
@@ -112,7 +112,7 @@ class OpDefBuilder {
OpDefBuilder& SetAllowsUninitializedInput();
// Deprecate the op at a certain GraphDef version.
- OpDefBuilder& Deprecated(int version, StringPiece explanation);
+ OpDefBuilder& Deprecated(int version, string explanation);
// Adds docs to this OpDefBuilder (and returns *this).
// Docs have the format:
@@ -128,9 +128,9 @@ class OpDefBuilder {
// to suppress the automatically-generated type documentation in
// generated output.
#ifndef TF_LEAN_BINARY
- OpDefBuilder& Doc(StringPiece text);
+ OpDefBuilder& Doc(string text);
#else
- OpDefBuilder& Doc(StringPiece text) { return *this; }
+ OpDefBuilder& Doc(string text) { return *this; }
#endif
// Sets the shape function to be used for shape inference.
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index ebdaaec153..508a8d3149 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -288,4 +288,13 @@ Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
return ctx->resource_manager()->Delete(p);
}
+Status ResourceHandlesShape(shape_inference::InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ for (int i = 0; i < n; ++i) {
+ c->set_output(i, c->Scalar());
+ }
+ return Status::OK();
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index d58deaa3fc..4a531648d9 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
+#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>
@@ -127,6 +128,14 @@ class ResourceMgr {
Status Lookup(const string& container, const string& name,
T** resource) const TF_MUST_USE_RESULT;
+ // Similar to Lookup, but looks up multiple resources at once, with only a
+ // single lock acquisition.
+ template <typename T>
+ Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
+ containers_and_names,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
+ resource) const TF_MUST_USE_RESULT;
+
// If "container" has a resource "name", returns it in
// "*resource". Otherwise, invokes creator() to create the resource.
// The caller takes the ownership of one ref on "*resource".
@@ -239,14 +248,31 @@ Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
ResourceHandle* handle);
// Create a resource pointed by a given resource handle.
+//
+// If successful, the caller transfers the ownership of one ref on `resource` to
+// `ctx->resource_mgr()`.
template <typename T>
Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
// Looks up a resource pointed by a given resource handle.
+//
+// If the lookup is successful, the caller takes the ownership of one ref on
+// `*value`, and must call its `Unref()` method when it has finished using it.
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
+// Looks up multiple resources pointed by a sequence of resource handles.
+template <typename T>
+Status LookupResources(
+ OpKernelContext* ctx, absl::Span<ResourceHandle const> p,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values);
+
// Looks up or creates a resource.
+//
+// If successful, the caller takes the ownership of one ref on `*value`, and
+// must call its `Unref()` method when it has finished using it. If the
+// `creator` is invoked, its reference on the created resource is transferred
+// to `ctx->resource_mgr()`.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator);
@@ -358,6 +384,26 @@ class ResourceHandleOp : public OpKernel {
std::atomic<bool> initialized_{false};
};
+// Utility op kernel to produce a handle to a resource of type T.
+template <typename T>
+class ResourceHandlesOp : public OpKernel {
+ public:
+ explicit ResourceHandlesOp(OpKernelConstruction* context);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ bool IsExpensive() override { return false; }
+
+ private:
+ std::vector<string> containers_;
+ std::vector<string> names_;
+ mutex mutex_;
+ std::vector<Tensor> resources_;
+ std::atomic<bool> initialized_{false};
+};
+
+Status ResourceHandlesShape(shape_inference::InferenceContext* c);
+
// Registers a kernel for an op which produces a handle to a resource of the
// specified type.
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type) \
@@ -390,6 +436,24 @@ Status ResourceMgr::Lookup(const string& container, const string& name,
}
template <typename T>
+Status ResourceMgr::LookupMany(
+ absl::Span<std::pair<const string*, const string*> const>
+ containers_and_names,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const {
+ CheckDeriveFromResourceBase<T>();
+ tf_shared_lock l(mu_);
+ resources->resize(containers_and_names.size());
+ for (size_t i = 0; i < containers_and_names.size(); ++i) {
+ T* resource;
+ TF_RETURN_IF_ERROR(LookupInternal(*containers_and_names[i].first,
+ *containers_and_names[i].second,
+ &resource));
+ (*resources)[i].reset(resource);
+ }
+ return Status::OK();
+}
+
+template <typename T>
Status ResourceMgr::LookupInternal(const string& container, const string& name,
T** resource) const {
ResourceBase* found = nullptr;
@@ -499,6 +563,19 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
}
template <typename T>
+Status LookupResources(
+ OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values) {
+ std::vector<std::pair<const string*, const string*>> containers_and_names(
+ p.size());
+ for (size_t i = 0; i < p.size(); ++i) {
+ TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i]));
+ containers_and_names[i] = {&p[i]->container(), &p[i]->name()};
+ }
+ return ctx->resource_manager()->LookupMany(containers_and_names, values);
+}
+
+template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
@@ -555,6 +632,46 @@ void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
ctx->set_output(0, resource_);
}
+template <typename T>
+ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ int n;
+ OP_REQUIRES_OK(context, context->GetAttr("N", &n));
+ OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_));
+ OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_));
+ OP_REQUIRES(
+ context, containers_.size() == n,
+ errors::InvalidArgument("Number of containers (", containers_.size(),
+ ") must be equal to N (", n, ")"));
+ OP_REQUIRES(context, names_.size() == n,
+ errors::InvalidArgument("Number of names (", containers_.size(),
+ ") must be equal to N (", n, ")"));
+ resources_.resize(n);
+}
+
+template <typename T>
+void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
+ if (!initialized_.load()) {
+ mutex_lock ml(mutex_);
+ // Checking again to see if another thread has initialized the resource.
+ if (!initialized_.load()) {
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ for (size_t i = 0; i < resources_.size(); ++i) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
+ &resources_[i], attr));
+ ResourceHandle h =
+ MakeResourceHandle<T>(ctx, containers_[i], names_[i]);
+ resources_[i].template scalar<ResourceHandle>()() = h;
+ }
+ initialized_.store(true);
+ }
+ }
+ for (size_t i = 0; i < resources_.size(); ++i) {
+ ctx->set_output(i, resources_[i]);
+ }
+}
+
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
diff --git a/tensorflow/core/framework/run_handler.cc b/tensorflow/core/framework/run_handler.cc
new file mode 100644
index 0000000000..0c4007eafc
--- /dev/null
+++ b/tensorflow/core/framework/run_handler.cc
@@ -0,0 +1,249 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/run_handler.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/run_handler_util.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+
+// Contains the concrete implementation of the RunHandler.
+// Externally visible RunHandler class simply forwards the work to this one.
+class RunHandler::Impl {
+ public:
+ explicit Impl(RunHandlerPool::Impl* pool_impl) : pool_impl_(pool_impl) {
+ Reset();
+ }
+
+ ~Impl() {}
+
+ void set_inter_op_scheduling_range(std::uint_fast32_t start,
+ std::uint_fast32_t limit) {
+ inter_op_scheduling_range_.store(EncodePartition(start, limit),
+ std::memory_order_release);
+ }
+
+ std::uint_fast32_t inter_op_scheduling_range() const {
+ return inter_op_scheduling_range_.load(std::memory_order_acquire);
+ }
+
+ // Stores now time (in microseconds) since unix epoch when the handler is
+ // requested via RunHandlerPool::Get().
+ uint64 start_time_us() const { return start_time_us_; }
+
+ void ScheduleInterOpClosure(std::function<void()> fn);
+
+ void Reset();
+
+ RunHandlerPool::Impl* pool_impl() { return pool_impl_; }
+
+ private:
+ // Encoding/decoding logic for storing [start, limit) into a single
+ // uint_fast32_t int. We assume that pool_num_threads < (1 << 16).
+ const int kMaxPartitionBits = 16;
+ const int kMaxThreads = 1 << kMaxPartitionBits;
+
+ std::uint_fast32_t EncodePartition(std::uint_fast32_t start,
+ std::uint_fast32_t limit) {
+ return (start << kMaxPartitionBits) | limit;
+ }
+
+ void DecodePartition(std::uint_fast32_t val, std::uint_fast32_t* start,
+ std::uint_fast32_t* limit) {
+ *limit = val & (kMaxThreads - 1);
+ val >>= kMaxPartitionBits;
+ *start = val;
+ }
+
+ std::atomic_uint_fast32_t inter_op_scheduling_range_;
+ RunHandlerPool::Impl* pool_impl_; // NOT OWNED.
+ uint64 start_time_us_;
+};
+
+// Contains shared state across all run handlers present in the pool. Also
+// responsible for pool management decisions.
+// This class is thread safe.
+class RunHandlerPool::Impl {
+ public:
+ explicit Impl(int num_inter_op_threads)
+ : max_handlers_(128),
+ inter_op_thread_pool_(new thread::ThreadPool(
+ Env::Default(), ThreadOptions(), "inter_op", num_inter_op_threads)),
+ iterations_(0) {
+ VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
+ for (int i = 0; i < max_handlers_; ++i) {
+ handlers_.emplace_back(new RunHandler::Impl(this));
+ free_handlers_.push_back(handlers_.back().get());
+ }
+ }
+
+ ~Impl() {
+ // Sanity check that all handlers have been returned back to the pool before
+ // destruction.
+ DCHECK_EQ(handlers_.size(), max_handlers_);
+ DCHECK_EQ(free_handlers_.size(), handlers_.size());
+ DCHECK_EQ(sorted_active_handlers_.size(), 0);
+ }
+
+ thread::ThreadPool* inter_op_thread_pool() const {
+ return inter_op_thread_pool_.get();
+ }
+
+ std::unique_ptr<RunHandler> Get() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ while (free_handlers_.empty()) {
+ one_handler_free_.wait(l);
+ }
+ // Remove the last entry from free_handlers_ and add to the end of
+ // sorted_active_handlers_.
+ auto* handler_impl = free_handlers_.back();
+ handler_impl->Reset();
+ // Sortedness isn't violated if we simply add at the end of the list, since
+ // handlers are expected to be obtained in increasing order of time.
+ sorted_active_handlers_.push_back(handler_impl);
+ DCHECK_LE(sorted_active_handlers_.size(), max_handlers_);
+ free_handlers_.pop_back();
+
+ RecomputePoolStatsLocked();
+ return WrapUnique<RunHandler>(new RunHandler(handler_impl));
+ }
+
+ void ReleaseHandler(RunHandler::Impl* handler) LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ DCHECK_GT(sorted_active_handlers_.size(), 0);
+
+ uint64 now = tensorflow::Env::Default()->NowMicros();
+ double elapsed = (now - handler->start_time_us()) / 1000.0;
+ time_hist_.Add(elapsed);
+
+ // Erase from and update sorted_active_handlers_. Add it to the end of
+ // free_handlers_.
+ auto iter = std::find(sorted_active_handlers_.begin(),
+ sorted_active_handlers_.end(), handler);
+ DCHECK(iter != sorted_active_handlers_.end())
+ << "Unexpected handler: " << handler
+ << " is being requested for release";
+
+ // Remove this handler from this list and add it to the list of free
+ // handlers.
+ sorted_active_handlers_.erase(iter);
+ free_handlers_.push_back(handler);
+ DCHECK_LE(free_handlers_.size(), max_handlers_);
+
+ RecomputePoolStatsLocked();
+ }
+ one_handler_free_.notify_one();
+ }
+
+ private:
+ void RecomputePoolStatsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Maximum number of handlers pre-created during pool construction time. The
+ // number has been chosen expecting each handler might at least want 1
+ // inter-op thread for execution (during compute intensive workloads like
+ // inference).
+ const int max_handlers_;
+
+ // Thread safe part.
+ const std::unique_ptr<thread::ThreadPool> inter_op_thread_pool_;
+
+ // Thread compatible part used only by lock under RunHandlerPool.
+ // Handlers are sorted by start time.
+ std::vector<RunHandler::Impl*> sorted_active_handlers_ GUARDED_BY(mu_);
+ std::vector<RunHandler::Impl*> free_handlers_ GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ GUARDED_BY(mu_);
+ // Histogram of elapsed runtime of every handler (in ms).
+ histogram::Histogram time_hist_ GUARDED_BY(mu_);
+ std::vector<std::uint_fast32_t> inter_op_start_ GUARDED_BY(mu_);
+ std::vector<std::uint_fast32_t> inter_op_limit_ GUARDED_BY(mu_);
+ int64 iterations_ GUARDED_BY(mu_);
+ condition_variable one_handler_free_;
+ mutex mu_;
+};
+
+void RunHandlerPool::Impl::RecomputePoolStatsLocked() {
+ int num_active_requests = sorted_active_handlers_.size();
+ if (num_active_requests == 0) return;
+
+ int num_threads = inter_op_thread_pool_->NumThreads();
+
+ inter_op_start_.resize(num_active_requests);
+ inter_op_limit_.resize(num_active_requests);
+
+ const int kMinThreadsPerRequest = 3;
+ ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
+ kMinThreadsPerRequest, &inter_op_start_,
+ &inter_op_limit_);
+
+ for (int i = 0; i < num_active_requests; ++i) {
+ sorted_active_handlers_[i]->set_inter_op_scheduling_range(
+ inter_op_start_[i], inter_op_limit_[i]);
+ }
+
+ if (iterations_++ % 5000 == 0 && VLOG_IS_ON(1)) {
+ VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
+ VLOG(1) << "Active session runs: " << num_active_requests;
+ uint64 now = tensorflow::Env::Default()->NowMicros();
+ string ranges_str = "";
+ string times_str = "";
+ for (int i = 0; i < num_active_requests; ++i) {
+ if (i > 0) {
+ times_str += " ";
+ ranges_str += " ";
+ }
+
+ times_str += strings::StrCat(
+ (now - sorted_active_handlers_[i]->start_time_us()) / 1000.0, " ms.");
+ ranges_str += strings::StrCat("[", inter_op_start_[i], ", ",
+ inter_op_limit_[i], ")");
+ }
+ VLOG(1) << "Elapsed times are: " << times_str;
+ VLOG(1) << "Ranges are: " << ranges_str;
+ }
+}
+
+void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
+ std::uint_fast32_t start = 0, limit = 0;
+ DecodePartition(inter_op_scheduling_range(), &start, &limit);
+ pool_impl_->inter_op_thread_pool()->Schedule(std::move(fn));
+}
+
+void RunHandler::Impl::Reset() {
+ set_inter_op_scheduling_range(
+ 0, pool_impl_->inter_op_thread_pool()->NumThreads());
+ start_time_us_ = tensorflow::Env::Default()->NowMicros();
+}
+
+RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
+ : impl_(new Impl(num_inter_op_threads)) {}
+
+RunHandlerPool::~RunHandlerPool() {}
+
+std::unique_ptr<RunHandler> RunHandlerPool::Get() { return impl_->Get(); }
+
+RunHandler::RunHandler(Impl* impl) : impl_(impl) {}
+
+void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
+ impl_->ScheduleInterOpClosure(std::move(fn));
+}
+
+RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/run_handler.h b/tensorflow/core/framework/run_handler.h
new file mode 100644
index 0000000000..72fa6301b4
--- /dev/null
+++ b/tensorflow/core/framework/run_handler.h
@@ -0,0 +1,95 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
+
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+
+class RunHandler;
+
+// RunHandlerPool is a fixed size pool of pre-allocated RunHandlers
+// that can be used for tracking inter-op work for a given Session::Run().
+// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes
+// 'active' when its unique_ptr is returned by Get() and is being used by a
+// client. It becomes 'inactive' once more when its unique_ptr gets destroyed.
+//
+// Expected usage:
+//
+// * Create a single RunHandlerPool (say run_handler_pool_).
+//
+// * When a Session::Run() is invoked, obtain a handler by:
+// auto handler = run_handler_pool_->Get();
+//
+// * Use handler for scheduling all inter-op work by:
+// handler->ScheduleInterOpClosure(closure);
+//
+// This class is thread safe.
+class RunHandlerPool {
+ public:
+ explicit RunHandlerPool(int num_inter_op_threads);
+ ~RunHandlerPool();
+
+ // Returns an inactive RunHandler from the pool.
+ //
+ // RunHandlers in RunHandlerPool are initially 'inactive'.
+ // A RunHandler becomes 'active' when its unique_ptr its returned by Get()
+ // and is being used by a client. It becomes 'inactive' once more when the
+ // unique_ptr is destroyed.
+ //
+ // Will block unless there is an inactive handler.
+ std::unique_ptr<RunHandler> Get();
+
+ private:
+ class Impl;
+ friend class RunHandler;
+
+ std::unique_ptr<Impl> impl_;
+};
+
+// RunHandler can be used to schedule inter-op closures to run on a global pool
+// shared across all Session::Run(s).
+//
+// It can only be created via RunHandlerPool::Get().
+//
+// This class can be used instead of directly scheduling closures on a global
+// pool since it maintains a global view across all sessions and optimizes pool
+// scheduling to improve (median and tail) latency.
+//
+// This class is thread safe.
+class RunHandler {
+ public:
+ void ScheduleInterOpClosure(std::function<void()> fn);
+
+ ~RunHandler();
+
+ private:
+ class Impl;
+ friend class RunHandlerPool::Impl;
+
+ explicit RunHandler(Impl* impl);
+
+ Impl* impl_; // NOT OWNED.
+};
+
+} // end namespace tensorflow.
+
+#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
diff --git a/tensorflow/core/framework/run_handler_util.cc b/tensorflow/core/framework/run_handler_util.cc
new file mode 100644
index 0000000000..3087998c69
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util.cc
@@ -0,0 +1,57 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/run_handler_util.h"
+
+#include <algorithm>
+#include <cmath>
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads,
+ int min_threads_per_request,
+ std::vector<std::uint_fast32_t>* start_vec,
+ std::vector<std::uint_fast32_t>* end_vec) {
+ // Each request is expected to have weight W[i] = num_active_requests - i.
+ // Therefore, total_weight = sum of all request weights.
+ float total_weight = 0.5f * num_active_requests * (num_active_requests + 1);
+ float demand_factor = static_cast<float>(num_threads) / total_weight;
+ float last_cumulative_weight = 0.0;
+ min_threads_per_request = std::max(1, min_threads_per_request);
+ for (int i = 0; i != num_active_requests; i++) {
+ float cumulative_weight =
+ static_cast<float>(i + 1) *
+ (num_active_requests - static_cast<float>(i) * 0.5f);
+ float weight = cumulative_weight - last_cumulative_weight;
+ // Quantize thread_demand by rounding up, and also satisfying
+ // `min_threads_per_request` constraint.
+ // Note: We subtract a small epsilon (0.00001) to prevent ceil(..) from
+ // rounding weights like 4.0 to 5.
+ int demand =
+ std::max(min_threads_per_request,
+ static_cast<int>(ceil(weight * demand_factor - 0.00001f)));
+ // For the quantized range [start, end); compute the floor of real start,
+ // and expand downwards from there with length `demand` and adjust for
+ // boundary conditions.
+ int start = last_cumulative_weight * demand_factor;
+ int end = std::min(num_threads, start + demand);
+ start = std::max(0, std::min(start, end - demand));
+ start_vec->at(i) = start;
+ end_vec->at(i) = end;
+ last_cumulative_weight = cumulative_weight;
+ }
+}
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/run_handler_util.h b/tensorflow/core/framework/run_handler_util.h
new file mode 100644
index 0000000000..c0c36aeccb
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util.h
@@ -0,0 +1,43 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
+
+#include <cstdint>
+#include <vector>
+
+namespace tensorflow {
+
+// Assign thread ranges to requests.
+// Requests are numbered 0...num_active_requests-1, and
+// threads are numbered 0...num_threads-1.
+// On return, the range start_vec->at(i)...end_vec->at(i)-1
+// indicates the subrange of the threads available to request i.
+// The ranges given to different requests may overlap.
+// Lower numbered requests will tend to be assigned more threads.
+// Thus, a client might associate older requests with lower
+// array indices so they receive access to more threads.
+// However, the routine ensures that each request is given access
+// to at least min(min_threads_per_request, num_threads) threads.
+// Every thread will be assigned to at least one request range,
+// assuming there is at least one request.
+void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads,
+ int min_threads_per_request,
+ std::vector<std::uint_fast32_t>* start_vec,
+ std::vector<std::uint_fast32_t>* end_vec);
+
+} // end namespace tensorflow
+#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
diff --git a/tensorflow/core/framework/run_handler_util_test.cc b/tensorflow/core/framework/run_handler_util_test.cc
new file mode 100644
index 0000000000..a1928c132b
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util_test.cc
@@ -0,0 +1,93 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/run_handler_util.h"
+
+#include <vector>
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+namespace tensorflow {
+namespace {
+
+void VerifyFunction(int num_active_requests, int num_threads,
+ int min_threads_per_request, bool print_stats = false) {
+ if (print_stats) {
+ LOG(INFO) << "Test case# num_active_requests: " << num_active_requests
+ << " num_threads: " << num_threads
+ << " min_threads: " << min_threads_per_request;
+ }
+ std::vector<std::uint_fast32_t> start(num_active_requests);
+ std::vector<std::uint_fast32_t> end(num_active_requests);
+
+ ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
+ min_threads_per_request, &start, &end);
+ string range_str = "";
+ for (int i = 0; i < num_active_requests; ++i) {
+ if (i > 0) range_str += " ";
+ range_str += strings::StrCat("[", start[i], ", ", end[i], ")");
+
+ ASSERT_GE(start[i], 0) << range_str;
+ ASSERT_LE(end[i], num_threads) << range_str;
+ if (i > 0) {
+ // Due to linearly decreasing demand, #threads(i - 1) >= #threads(i)
+ ASSERT_GE(end[i - 1] - start[i - 1], end[i] - start[i]) << range_str;
+ // No missing threads.
+ ASSERT_GE(end[i - 1], start[i]) << range_str;
+ }
+ // Each interval is at least of size 'min_threads_per_request'.
+ ASSERT_GE((end[i] - start[i]), min_threads_per_request) << range_str;
+ // Verify that assigned (quantized) threads is not overly estimated
+ // from real demand, when the demand is high (>=
+ // min_threads_per_request).
+ float entry_weight = num_active_requests - i;
+ float total_weight = 0.5f * num_active_requests * (num_active_requests + 1);
+ float thread_demand = (entry_weight * num_threads) / total_weight;
+ if (thread_demand > min_threads_per_request) {
+ // We expect some over-estimation of threads due to quantization,
+ // but we hope it's not more than 1 extra thread.
+ ASSERT_NEAR(end[i] - start[i], thread_demand, 1.0)
+ << "Ranges: " << range_str << " thread_demand: " << thread_demand
+ << " i: " << i;
+ }
+ }
+ ASSERT_EQ(end[num_active_requests - 1], num_threads);
+ ASSERT_EQ(start[0], 0);
+ if (print_stats) {
+ LOG(INFO) << "Assigned ranges: " << range_str;
+ }
+}
+
+TEST(RunHandlerUtilTest, TestComputeInterOpSchedulingRanges) {
+ const int kMinThreadsPerRequestBound = 12;
+ const int kMaxActiveRequests = 128;
+ const int kMaxThreads = 128;
+
+ for (int min_threads_per_request = 1;
+ min_threads_per_request <= kMinThreadsPerRequestBound;
+ ++min_threads_per_request) {
+ for (int num_active_requests = 1; num_active_requests <= kMaxActiveRequests;
+ ++num_active_requests) {
+ for (int num_threads = min_threads_per_request;
+ num_threads <= kMaxThreads; ++num_threads) {
+ VerifyFunction(num_active_requests, num_threads,
+ min_threads_per_request);
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 3df677675e..1dea6da911 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -813,7 +813,7 @@ Tensor Tensor::Slice(int64 start, int64 limit) const {
}
Tensor Tensor::SubSlice(int64 index) const {
- CHECK_GE(dims(), 2); // Crash ok.
+ CHECK_GE(dims(), 1); // Crash ok.
CHECK_LE(0, index); // Crash ok.
int64 dim0_size = shape_.dim_size(0);
CHECK_LE(index, dim0_size); // Crash ok.
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 8a0c70fef2..d0f9eb56e2 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -219,7 +219,7 @@ class Tensor {
/// must check the returned tensor's alignment before calling certain
/// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
///
- /// REQUIRES: `dims()` >= 2
+ /// REQUIRES: `dims()` >= 1
/// REQUIRES: `0 <= dim0_start < dim_size(0)`
Tensor SubSlice(int64 index) const;
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 0bfa53e6c5..c596604143 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -1246,6 +1246,9 @@ TEST(Tensor, SubSlice_Basic) {
EXPECT_EQ(&tx(5, j, k), &ty(j, k));
}
}
+ Tensor z = y.SubSlice(3).SubSlice(31);
+ auto tz = z.unaligned_flat<float>();
+ EXPECT_EQ(*tz.data(), 5.0);
}
{
// Test unaligned access via a SubSlice.
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index f5b0105862..06d3fefef1 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -977,7 +977,9 @@ std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif // ENABLE_MKL
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
@@ -2448,6 +2450,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.tanh = "Tanh";
csinfo_.tanh_grad = "TanhGrad";
csinfo_.reshape = "Reshape";
+ csinfo_.slice = "Slice";
csinfo_.softmax = "Softmax";
csinfo_.split = "Split";
// Element-wise ops. Ensure you also add any new ops to IsOpElementWise
@@ -2555,6 +2558,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.reshape,
mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsReshape, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.slice,
+ mkl_op_registry::GetMklOpName(csinfo_.slice),
+ CopyAttrsSlice, AlwaysRewrite});
rinfo_.push_back({csinfo_.softmax,
mkl_op_registry::GetMklOpName(csinfo_.softmax),
CopyAttrsDataType, AlwaysRewrite});
@@ -2674,6 +2680,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string tanh;
string tanh_grad;
string reshape;
+ string slice;
string softmax;
string split;
string squared_difference;
@@ -3132,6 +3139,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb);
// Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
@@ -3150,7 +3158,9 @@ MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif // ENABLE_MKL
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
@@ -3735,6 +3745,19 @@ void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
nb->Attr("Tshape", Tshape);
}
+void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ DataType Index;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index));
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("Index", Index);
+}
+
void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index e8bac847e5..77640e287c 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/graph/mkl_layout_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
@@ -3510,6 +3510,26 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
"B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}
+TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Int32Input'}"
+ "node { name: 'C' op: 'Int32Input'}"
+ "node { name: 'D' op: 'Slice'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'Index' value { type: DT_INT32 } }"
+ " input: ['A', 'B', 'C'] }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'D'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Int32Input);C(Int32Input);"
+ "D(_MklSlice);DMT/_0(Const);DMT/_1(Const);DMT/"
+ "_2(Const);E(Zeta)|A->D;A->E;"
+ "A:control->DMT/_0:control;A:control->DMT/"
+ "_1:control;A:control->DMT/_2:control;"
+ "B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+
/////////////////////////////////////////////////////////////////////
// Post-rewrite fixup pass test
@@ -3586,4 +3606,4 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
} // namespace tensorflow
-#endif /* INTEL_MKL */
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index b67a321fc1..8c5ffd71a3 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -133,7 +133,9 @@ class MklToTfConversionPass : public GraphOptimizationPass {
// complete picture of inputs and outputs of the nodes in the graphs.
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
+#endif // ENABLE_MKL
Status MklToTfConversionPass::InsertConversionNodeOnEdge(
std::unique_ptr<Graph>* g, Edge* e) {
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index ebcb6de551..319437a801 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
@@ -304,4 +304,4 @@ BENCHMARK(BM_RunMklToTfConversionPass)->Arg(1000)->Arg(10000);
} // namespace
} // namespace tensorflow
-#endif /* INTEL_MKL */
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index b8d8243174..0b8cb5e919 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -29,21 +29,24 @@ int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
return output_arg_id;
}
+ // Default is 1 port per output arg.
+ int n = 1;
+
const auto& output_arg = op.output_arg(output_arg_id);
if (!output_arg.number_attr().empty()) {
- const int n = node.attr().at(output_arg.number_attr()).i();
- if (n < 0) {
- // This should never happen.
- DCHECK_GE(n, 0);
- return -1;
- }
- if (port_id < n) {
- return output_arg_id;
- }
- port_id -= n;
- } else {
- --port_id;
+ n = node.attr().at(output_arg.number_attr()).i();
+ } else if (!output_arg.type_list_attr().empty()) {
+ n = node.attr().at(output_arg.type_list_attr()).list().type_size();
+ }
+
+ if (n < 0) {
+ // This should never happen.
+ DCHECK_GE(n, 0);
+ return -1;
+ } else if (port_id < n) {
+ return output_arg_id;
}
+ port_id -= n;
}
return -1;
@@ -69,7 +72,7 @@ void GraphView::AddUniqueNodeOrDie(NodeDef* node) {
void GraphView::AddFanouts(NodeDef* node) {
for (int i = 0; i < node->input_size(); ++i) {
OutputPort fanin;
- string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
+ const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
fanin.node = nodes_[fanin_name];
InputPort input;
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index 30512d9d47..3d7d2faf7c 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/cc/ops/parsing_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
@@ -79,6 +80,34 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
}
}
+TEST_F(GraphViewTest, ParseSingleExample) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const<string>(s.WithOpName("a"), "", {});
+ Output b = ops::Const<int64>(s.WithOpName("b"), 1, {1, 1});
+ ops::ParseSingleExample c(s.WithOpName("c"), a, {b, b}, 2, {"w", "x"},
+ {"y", "z"}, {DT_INT64, DT_INT64}, {{1}, {1}});
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& c_node_def = *graph_view.GetNode("c");
+
+ const OpDef* c_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(c_node_def.op(), &c_op_def).ok());
+
+ EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 1));
+ EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 2));
+ EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 3));
+ EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 4));
+ EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 5));
+ EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 6));
+ EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 7));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 8));
+}
+
TEST_F(GraphViewTest, BasicGraph) {
TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"});
GrapplerItem item;
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index bbc0fedd22..2c490f3966 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -38,6 +38,7 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) {
restore_op = other.restore_op;
save_restore_loc_tensor = other.save_restore_loc_tensor;
queue_runners = other.queue_runners;
+ allowed_optimizations = other.allowed_optimizations;
graph.Swap(graph_def);
}
diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h
index 939e5fa046..a0748abfe6 100644
--- a/tensorflow/core/grappler/grappler_item.h
+++ b/tensorflow/core/grappler/grappler_item.h
@@ -77,6 +77,15 @@ struct GrapplerItem {
// Return a set of node names that must be preserved. This includes feed and
// fetch nodes, keep_ops, init_ops.
std::unordered_set<string> NodesToPreserve() const;
+
+ // Restrict types of optimizations that are allowed for this GrapplerItem.
+ struct AllowedOptimizations {
+ // Is it allowed to add nodes to the graph that do not have registered
+ // gradient function.
+ bool non_differentiable_rewrites = true;
+ };
+
+ AllowedOptimizations allowed_optimizations;
};
// Return the transitive fanin of a set of terminal nodes.
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 029515ad3c..369046666d 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -192,9 +192,13 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
const string feed_name = NodeName(feed_node);
new_item->feed.emplace_back(feed_name, Tensor());
}
+ for (const auto& fetch_node : cfg.fetch_nodes) {
+ new_item->fetch.emplace_back(NodeName(fetch_node));
+ }
- // Attempt to detect the fetch node(s).
- if (meta_graph.collection_def().count("train_op") > 0) {
+ // Attempt to detect the fetch node(s) if they were not set explicitly.
+ if (new_item->fetch.empty() &&
+ meta_graph.collection_def().count("train_op") > 0) {
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
if (nodes.has_node_list()) {
for (const auto& node : nodes.node_list().value()) {
diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h
index aafd2fdcda..1698587f8c 100644
--- a/tensorflow/core/grappler/grappler_item_builder.h
+++ b/tensorflow/core/grappler/grappler_item_builder.h
@@ -49,6 +49,8 @@ struct ItemConfig {
bool prune_graph = false;
// Override feed nodes list.
std::set<string> feed_nodes;
+ // Override fetch nodes list.
+ std::set<string> fetch_nodes;
};
// Factory method for creating a GrapplerItem from a MetaGraphDef.
diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc
index 4b90bf3038..d00981f174 100644
--- a/tensorflow/core/grappler/grappler_item_builder_test.cc
+++ b/tensorflow/core/grappler/grappler_item_builder_test.cc
@@ -313,6 +313,29 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithUnknownDimInSignatureInput) {
EXPECT_EQ(item2->feed[0].second.NumElements(), 1);
}
+TEST_F(GrapplerItemBuilderTest, ExplicitFeedAndFetch) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), 0);
+ auto y = ops::Const(s.WithOpName("y"), 1);
+ auto z = ops::Add(s.WithOpName("z"), x, y);
+
+ MetaGraphDef meta_graph;
+ TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def()));
+
+ ItemConfig config;
+ config.feed_nodes.insert("x");
+ config.fetch_nodes.insert("z");
+
+ std::unique_ptr<GrapplerItem> item =
+ GrapplerItemFromMetaGraphDef("0", meta_graph, config);
+ ASSERT_TRUE(item != nullptr);
+
+ EXPECT_EQ(item->feed.size(), 1);
+ EXPECT_EQ(item->fetch.size(), 1);
+ EXPECT_EQ(item->feed[0].first, "x");
+ EXPECT_EQ(item->fetch[0], "z");
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 3521669b63..9f0d9dbf28 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -425,6 +425,10 @@ bool IsSwitch(const NodeDef& node) {
return op == "Switch" || op == "RefSwitch";
}
+bool IsSymbolicGradient(const NodeDef& node) {
+ return node.op() == "SymbolicGradient";
+}
+
bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 25ab6b65ac..7f86a5f295 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -149,6 +149,7 @@ bool IsStridedSliceGrad(const NodeDef& node);
bool IsSub(const NodeDef& node);
bool IsSum(const NodeDef& node);
bool IsSwitch(const NodeDef& node);
+bool IsSymbolicGradient(const NodeDef& node);
bool IsTanhGrad(const NodeDef& node);
bool IsTile(const NodeDef& node);
bool IsTranspose(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 960d1addb3..c708f84948 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -525,6 +525,7 @@ cc_library(
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/utils:colocation",
@@ -541,6 +542,7 @@ tf_cuda_cc_test(
":custom_graph_optimizer_registry",
":meta_optimizer",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index ab97dcdb99..7d5014ee0a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -276,7 +276,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
for (int i = 0; i < output->input_size(); ++i) {
auto input = output->input(i);
- string name = ParseNodeName(input, &position);
+ StringPiece name = ParseNodeNameAsStringPiece(input, &position);
if (name == node.name() && /*control input*/ position < 0) {
return true;
}
@@ -1568,7 +1568,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
for (NodeDef* output : outputs) {
if (IsControlInput(output->input(0))) continue;
int port;
- const string node_name = ParseNodeName(output->input(0), &port);
+ const StringPiece node_name =
+ ParseNodeNameAsStringPiece(output->input(0), &port);
if (node_name == node.name()) {
tails->insert(ChainLink(output, port));
} else {
@@ -1618,7 +1619,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
} else {
for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
int port;
- const string node_name = ParseNodeName(new_tail->input(0), &port);
+ const StringPiece node_name =
+ ParseNodeNameAsStringPiece(new_tail->input(0), &port);
if (node_name != tail->name()) {
return Status::OK();
}
@@ -2929,8 +2931,8 @@ uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const {
for (const auto& input : node.input()) {
int pos;
- string node_name = ParseNodeName(input, &pos);
- h = Hash64CombineUnordered(Hash64(node_name), h);
+ const StringPiece node_name = ParseNodeNameAsStringPiece(input, &pos);
+ h = Hash64CombineUnordered(Hash64(node_name.data(), node_name.size()), h);
h = Hash64CombineUnordered(std::hash<int>()(pos), h);
}
for (const auto& attr : node.attr()) {
@@ -3043,10 +3045,11 @@ void ArithmeticOptimizer::DedupComputations() {
}
std::set<int> duplicates;
// Populate feed_inplace_op;
- std::unordered_map<string, bool> feeds_inplace_op;
+ std::unordered_set<NodeDef*> feeds_inplace_op;
for (int i = 0; i < optimized_graph_->node_size(); ++i) {
- feeds_inplace_op[optimized_graph_->node(i).name()] =
- FeedsInPlaceOp(graph_view, optimized_graph_->node(i));
+ if (FeedsInPlaceOp(graph_view, optimized_graph_->node(i))) {
+ feeds_inplace_op.insert(optimized_graph_->mutable_node(i));
+ }
}
do {
stop = true;
@@ -3056,9 +3059,8 @@ void ArithmeticOptimizer::DedupComputations() {
continue;
}
NodeDef* node = optimized_graph_->mutable_node(i);
- const string& node_name = node->name();
- if (node_name.empty()) continue;
- if (feeds_inplace_op[node_name] || !CanDedup(*node)) {
+ if (!CanDedup(*node) ||
+ feeds_inplace_op.find(node) != feeds_inplace_op.end()) {
continue;
}
NodeDef* rep = nodes.FindOrAddRepresentative(node);
@@ -3069,7 +3071,7 @@ void ArithmeticOptimizer::DedupComputations() {
// races. For example: If we dedup nodes initializing two independent
// inplace accumulations, they will write to the same buffer, clobbering
// each other's results.
- if (feeds_inplace_op[rep->name()]) {
+ if (feeds_inplace_op.find(rep) != feeds_inplace_op.end()) {
continue;
}
VLOG(3) << "Remove duplicated node: node=" << node->name()
@@ -3078,7 +3080,8 @@ void ArithmeticOptimizer::DedupComputations() {
for (NodeDef* fanout : fanouts) {
for (int i = 0; i < fanout->input_size(); ++i) {
string* fanout_input = fanout->mutable_input(i);
- const int position = NodePositionIfSameNode(*fanout_input, node_name);
+ const int position =
+ NodePositionIfSameNode(*fanout_input, node->name());
// Update name in-place.
if (position < -1) {
continue;
@@ -3246,6 +3249,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
optimized_graph_ = &optimized_item.graph;
node_map_.reset(new NodeMap(optimized_graph_));
+ // Disable restricted graph rewrites.
+ options_.unary_ops_composition &=
+ item.allowed_optimizations.non_differentiable_rewrites;
+
if (options_.dedup_computations) {
DedupComputations();
}
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index cf305cebe1..5a3abbb545 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -22,6 +22,7 @@ cc_library(
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -31,6 +32,7 @@ tf_cc_test(
visibility = ["//visibility:public"],
deps = [
":filter_fusion",
+ ":graph_test_utils",
":graph_utils",
"//tensorflow/core:framework",
"//tensorflow/core:test",
@@ -87,11 +89,12 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ ":graph_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
- "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -121,10 +124,10 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
- "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
] + tf_protos_all(),
)
@@ -146,6 +149,62 @@ tf_cc_test(
)
cc_library(
+ name = "graph_test_utils",
+ testonly = 1,
+ srcs = ["graph_test_utils.cc"],
+ hdrs = [
+ "graph_test_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core:testlib",
+ ] + tf_protos_all(),
+)
+
+cc_library(
+ name = "hoist_random_uniform",
+ srcs = ["hoist_random_uniform.cc"],
+ hdrs = [
+ "hoist_random_uniform.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":graph_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "hoist_random_uniform_test",
+ srcs = ["hoist_random_uniform_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_test_utils",
+ ":graph_utils",
+ ":hoist_random_uniform",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ ] + tf_protos_all(),
+)
+
+cc_library(
name = "latency_all_edges",
srcs = ["latency_all_edges.cc"],
hdrs = [
@@ -256,7 +315,7 @@ cc_library(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
- "//tensorflow/core:ptr_util",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -265,6 +324,7 @@ tf_cc_test(
srcs = ["map_and_filter_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_and_filter_fusion",
"//tensorflow/core:framework",
@@ -294,6 +354,7 @@ cc_library(
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -302,6 +363,7 @@ tf_cc_test(
srcs = ["map_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_fusion",
"//tensorflow/core:framework",
@@ -339,6 +401,7 @@ tf_cc_test(
srcs = ["map_parallelization_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_parallelization",
"//tensorflow/core:framework",
@@ -422,6 +485,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":filter_fusion",
+ ":hoist_random_uniform",
":latency_all_edges",
":map_and_batch_fusion",
":map_and_filter_fusion",
@@ -459,6 +523,7 @@ cc_library(
":function_utils",
":graph_utils",
"@com_google_absl//absl/strings",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -474,6 +539,7 @@ tf_cc_test(
srcs = ["vectorization_utils_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_utils",
":function_utils",
":vectorization_utils",
"//tensorflow/core:framework",
@@ -483,7 +549,10 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ # For ops we need registered
+ "//tensorflow/core/kernels/data:dataset_ops",
"//tensorflow/core/kernels:cast_op",
+ "//tensorflow/core/kernels:logging_ops",
"//tensorflow/tools/graph_transforms:transform_utils",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
index c71aa6e804..1ad495bbad 100644
--- a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
@@ -43,19 +43,14 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node,
fused_node.set_op("FilterDataset");
fused_node.add_input(first_filter_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = first_filter_node.attr().at("predicate");
*attr.mutable_func()->mutable_name() = fused_function.signature().name();
(*fused_node.mutable_attr())["predicate"] = std::move(attr);
- copy_attribute("Targuments", first_filter_node, &fused_node);
+ graph_utils::CopyAttribute("Targuments", first_filter_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, second_filter_node, &fused_node);
+ graph_utils::CopyAttribute(key, second_filter_node, &fused_node);
return fused_node;
}
@@ -120,8 +115,8 @@ Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// functions, or make sure that optimization passes run after filter
// fusion.
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_predicate));
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
nodes_to_delete.insert(first_filter_node->name());
nodes_to_delete.insert(second_filter_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
index 12b1924efd..c8becc5cc0 100644
--- a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -28,14 +28,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "FilterDataset", {string(input_node_name)},
- {{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeFilterNode;
TEST(FilterFusionTest, FuseTwoFilterIntoOne) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc
index e95ea1a4c1..311df15bc2 100644
--- a/tensorflow/core/grappler/optimizers/data/function_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc
@@ -14,31 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace grappler {
namespace function_utils {
-namespace {
-
-template <typename Predicate, typename Collection>
-std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
- const Collection& collection) {
- std::vector<int> indices = {};
- unsigned idx = 0;
- for (auto&& element : collection) {
- if (predicate(element)) {
- indices.push_back(idx);
- }
- idx++;
- }
- return indices;
-}
-
-} // namespace
FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
const string& output, int position)
@@ -152,32 +137,27 @@ bool ContainsFunctionOutputWithName(StringPiece name,
}
int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
function.signature().input_arg());
- return indices.empty() ? -1 : indices.front();
}
int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
function.signature().output_arg());
- return indices.empty() ? -1 : indices.front();
}
int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&name](const NodeDef& node) { return node.name() == name; },
function.node_def());
- return indices.empty() ? -1 : indices.front();
}
int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&op](const NodeDef& node) { return node.op() == op; },
function.node_def());
-
- return indices.empty() ? -1 : indices.front();
}
void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
new file mode 100644
index 0000000000..b2eec7220e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_tests_utils {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "MapDataset", {string(input_node_name)},
+ {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<DataType>{}}});
+}
+
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "FilterDataset", {string(input_node_name)},
+ {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<TensorShape>{}}});
+}
+
+} // end namespace graph_tests_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
new file mode 100644
index 0000000000..ca0fde997d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_tests_utils {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name = "XTimesTwo");
+
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name = "IsZero");
+
+} // end namespace graph_tests_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 2dd9ee822e..3eaaf8fbef 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -201,25 +201,22 @@ bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
int FindGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&name](const FunctionDef& function) {
return function.signature().name() == name;
},
library.function());
- return indices.empty() ? -1 : indices.front();
}
int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&name](const NodeDef& node) { return node.name() == name; },
graph.node());
- return indices.empty() ? -1 : indices.front();
}
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
- return indices.empty() ? -1 : indices.front();
}
std::vector<int> FindAllGraphNodesWithOp(const string& op,
@@ -260,6 +257,21 @@ void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
}
function->mutable_signature()->set_name(std::move(name));
}
+
+void CopyAttribute(const string& attribute_name, const NodeDef& from,
+ NodeDef* to_node) {
+ (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
+}
+
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
+ const NodeDef& second, NodeDef* to_node) {
+ CopyAttribute(attribute_name, first, to_node);
+ (*to_node->mutable_attr())
+ .at(attribute_name)
+ .mutable_list()
+ ->MergeFrom(second.attr().at(attribute_name).list());
+}
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index b117482db2..3af34f6904 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -31,6 +31,21 @@ namespace tensorflow {
namespace grappler {
namespace graph_utils {
+// Returns the index of the first element in collection that fulfills predicate.
+// If no such element exists, returns -1.
+template <typename Predicate, typename Collection>
+int GetFirstElementIndexWithPredicate(const Predicate& predicate,
+ const Collection& collection) {
+ unsigned idx = 0;
+ for (auto&& element : collection) {
+ if (predicate(element)) {
+ return idx;
+ }
+ idx++;
+ }
+ return -1;
+}
+
// Adds a node to the graph.
NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<string>& inputs,
@@ -101,11 +116,21 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
// is unique across the graph.
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
-// Sets the node name using the `prefix` name as a prefix while guaranteeing the
-// name is unique across the graph.
+// Sets the function name using the `prefix` name as a prefix while guaranteeing
+// the name is unique across the function library.
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function);
+// Copies attribute having name `attribute_name` from node `from` to node
+// `to_node`.
+void CopyAttribute(const string& attribute_name, const NodeDef& from,
+ NodeDef* to_node);
+
+// Concatenates list attribute having name `attribute_name` from `first` and
+// `second` node, setting it to `to_node`.
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
+ const NodeDef& second, NodeDef* to_node);
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index 6877c207c4..db986542b2 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -24,6 +24,18 @@ namespace grappler {
namespace graph_utils {
namespace {
+TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) {
+ std::vector<int> vec({1, 2, 3, 4, 5, 6});
+ auto result = GetFirstElementIndexWithPredicate(
+ [](int elem) { return elem % 3 == 0; }, vec);
+
+ EXPECT_EQ(result, 2);
+
+ result = GetFirstElementIndexWithPredicate(
+ [](int elem) { return elem % 7 == 0; }, vec);
+ EXPECT_EQ(result, -1);
+}
+
TEST(GraphUtilsTest, AddScalarConstNodeBool) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
new file mode 100644
index 0000000000..ce0b2db039
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
@@ -0,0 +1,289 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node,
+ const FunctionDef& stateless_function,
+ MutableGraphView* graph) {
+ NodeDef stateless_map;
+ graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(),
+ &stateless_map);
+
+ stateless_map.set_op("MapDataset");
+ stateless_map.add_input(zip_node.name());
+ // Add placeholders.
+ for (int i = 1; i < map_node.input_size(); i++)
+ stateless_map.add_input(map_node.input(i));
+
+ auto attr = map_node.attr().at("f");
+ *attr.mutable_func()->mutable_name() = stateless_function.signature().name();
+ *attr.mutable_func()->mutable_attr() = stateless_function.attr();
+ (*stateless_map.mutable_attr())["f"] = std::move(attr);
+
+ graph_utils::CopyAttribute("Targuments", map_node, &stateless_map);
+ for (auto key : {"output_shapes", "output_types"})
+ graph_utils::CopyAttribute(key, map_node, &stateless_map);
+
+ if (const auto* attr =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"))
+ (*stateless_map.mutable_attr())["use_inter_op_parallelism"] = *attr;
+
+ return stateless_map;
+}
+
+NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
+ MutableGraphView* graph) {
+ NodeDef random_dataset;
+ random_dataset.set_op("RandomDataset");
+ graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(),
+ &random_dataset);
+
+ const auto* seed = graph_utils::AddScalarConstNode<int64>(
+ random_uniform_node.attr().at("seed").i(), graph);
+ const auto* seed2 = graph_utils::AddScalarConstNode<int64>(
+ random_uniform_node.attr().at("seed2").i(), graph);
+
+ random_dataset.add_input(seed->name());
+ random_dataset.add_input(seed2->name());
+
+ (*random_dataset.mutable_attr())["output_shapes"].mutable_list()->add_shape();
+ (*random_dataset.mutable_attr())["output_types"].mutable_list()->add_type(
+ DT_INT64);
+
+ return random_dataset;
+}
+
+NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
+ NodeDef batch_dataset;
+ batch_dataset.set_op("BatchDatasetV2");
+ graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(),
+ &batch_dataset);
+ const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph);
+ const auto* drop_reminder = graph_utils::AddScalarConstNode(false, graph);
+ batch_dataset.add_input(random_dataset.name());
+ batch_dataset.add_input(batch_size->name());
+ batch_dataset.add_input(drop_reminder->name());
+
+ (*batch_dataset.mutable_attr())["output_shapes"]
+ .mutable_list()
+ ->add_shape()
+ ->mutable_dim()
+ ->Add()
+ ->set_size(-1);
+ (*batch_dataset.mutable_attr())["output_types"].mutable_list()->add_type(
+ DT_INT64);
+
+ return batch_dataset;
+}
+
+NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
+ MutableGraphView* graph) {
+ NodeDef zip_node;
+ graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(),
+ &zip_node);
+
+ zip_node.set_op("ZipDataset");
+ zip_node.add_input(first_node.name());
+ zip_node.add_input(second_node.name());
+
+ for (auto key : {"output_shapes", "output_types"})
+ graph_utils::ConcatAttributeList(key, first_node, second_node, &zip_node);
+
+ (*zip_node.mutable_attr())["N"].set_i(2);
+
+ return zip_node;
+}
+
+// We need to insert our argument before the placeholders, which are the last
+// arguments.
+OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) {
+ int new_argument_idx = signature->input_arg_size() - num_placeholders;
+ signature->add_input_arg();
+ for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) {
+ signature->mutable_input_arg()->SwapElements(i - 1, i);
+ }
+ auto* seed_arg = signature->mutable_input_arg(new_argument_idx);
+ seed_arg->set_name(strings::StrCat("seed_arg", new_argument_idx));
+ seed_arg->set_type(DT_INT64);
+
+ return seed_arg;
+}
+
+// Make function that uses `StatelessRandomUniform` instead of `RandomUniform`
+// to make it less statefull. The function can still be stateful, but in when
+// other stateful ops are e.g. `Assert`, then it will be parallelizable.
+const FunctionDef* MakeLessStatefulFunction(const FunctionDef& map_function,
+ bool is_stateful,
+ int num_placeholders,
+ FunctionDefLibrary* library) {
+ FunctionDef* stateless_function = library->add_function();
+ *stateless_function = map_function;
+ if (is_stateful)
+ stateless_function->mutable_signature()->set_is_stateful(is_stateful);
+ graph_utils::SetUniqueGraphFunctionName("stateless_function", library,
+ stateless_function);
+
+ auto* seed_arg = InsertSeedArgument(stateless_function->mutable_signature(),
+ num_placeholders);
+
+ auto* const random_uniform = stateless_function->mutable_node_def(
+ function_utils::FindFunctionNodeWithOp("RandomUniform",
+ *stateless_function));
+
+ // Replace RandomUniform node with StatelessRandomUniform.
+ random_uniform->set_op("StatelessRandomUniform");
+ random_uniform->add_input(seed_arg->name());
+ (*random_uniform->mutable_attr())["Tseed"].set_type(DT_INT64);
+ random_uniform->mutable_attr()->erase("seed");
+ random_uniform->mutable_attr()->erase("seed2");
+
+ return stateless_function;
+}
+// This function returns true if function is stateful and has single
+// RandomUniform op and no other stateful ops except Assert.
+// `is_stateful_after_hoisting` is set to true if RandomUniform is the only
+// stateful op and hoisting can be performed.
+bool CanHoistRandomUniform(const FunctionDef& map_function,
+ const FunctionLibraryDefinition& library,
+ bool* is_stateful_after_hoisting,
+ const NodeDef** random_uniform_op) {
+ if (!map_function.signature().is_stateful()) return false;
+ *is_stateful_after_hoisting = true;
+
+ bool have_other_stateful_ops = false;
+
+ for (const auto& node : map_function.node_def()) {
+ const OpDef* op_def;
+ TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def));
+ // Skip stateless nodes and assert, as it does not actually have a state.
+ if (!op_def->is_stateful()) continue;
+
+ if (op_def->name() == "Assert") {
+ have_other_stateful_ops = true;
+ continue;
+ }
+
+ // TODO(prazek): For now we only handle RandomUniform, we should handle
+ // RandomUniformInt as well.
+ if (op_def->name() != "RandomUniform") return false;
+
+ // TODO(prazek): For now we can only hoist single RandomUniform.
+ if (*random_uniform_op != nullptr) return false;
+
+ *random_uniform_op = &node;
+ }
+
+ if (!have_other_stateful_ops) *is_stateful_after_hoisting = false;
+
+ // Have we found single RandomUniform?
+ return *random_uniform_op != nullptr;
+}
+
+int NumberOfPlaceholders(const NodeDef& map_node) {
+ // First input of MapDataset is the argument to the function. Rest of the
+ // inputs are placeholders.
+ return map_node.input_size() - 1;
+}
+
+} // namespace
+
+Status HoistRandomUniform::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ item.graph.library());
+
+ auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
+ // TODO(prazek): we could also handle ParallelMapDataset and
+ // MapAndBatchDataset.
+ if (node.op() == "MapDataset") return &node;
+ return nullptr;
+ };
+
+ for (const NodeDef& node : item.graph.node()) {
+ const NodeDef* map_node = get_map_node(node);
+ if (!map_node) continue;
+
+ const auto& fun = map_node->attr().at("f");
+ const FunctionDef* func = function_library.Find(fun.func().name());
+
+ const NodeDef* random_uniform_op = nullptr;
+ bool is_stateful_after_hoisting = true;
+ if (!CanHoistRandomUniform(*func, function_library,
+ &is_stateful_after_hoisting, &random_uniform_op))
+ continue;
+ const auto* random_seed_dataset =
+ graph.AddNode(MakeRandomDataset(*random_uniform_op, &graph));
+
+ const auto* batch_dataset =
+ graph.AddNode(MakeBatchTwo(*random_seed_dataset, &graph));
+
+ const NodeDef& parent_node = *graph_utils::GetInputNode(*map_node, graph);
+
+ const auto* zip_node =
+ graph.AddNode(MakeZipNode(parent_node, *batch_dataset, &graph));
+
+ const auto* stateless_func = MakeLessStatefulFunction(
+ *func, is_stateful_after_hoisting, NumberOfPlaceholders(*map_node),
+ output->mutable_library());
+
+ const auto* stateless_map = graph.AddNode(
+ MakeStatelessMap(*map_node, *zip_node, *stateless_func, &graph));
+
+ graph.ReplaceInput(*map_node, *stateless_map);
+
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
+ nodes_to_delete.insert(map_node->name());
+ }
+
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void HoistRandomUniform::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(HoistRandomUniform, "hoist_random_uniform");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
new file mode 100644
index 0000000000..d1bcf6782d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimization hoists instances of `random_uniform` out of a function
+// with the aim of making it stateless. It creates a new function that takes a
+// random seed as an extra argument and uses `stateless_random_uniform` instead
+// of `random_uniform` to make it stateless.
+// It also creates RandomDataset(seed).batch(2), which is zipped with old input
+// to the map. The batching in RandomDataset is because we need 2 seeds for
+// `stateless_random_uniform`.
+// TODO(prazek): for now only `RandomUniform` is handled, but we could handle
+// `RandomUniformInt` similarly.
+class HoistRandomUniform : public CustomGraphOptimizer {
+ public:
+ HoistRandomUniform() = default;
+ ~HoistRandomUniform() override = default;
+
+ string name() const override { return "hoist_random_uniform"; };
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
new file mode 100644
index 0000000000..455459e3f6
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(HoistRandomUniform, SimpleHoisting) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"},
+ {{"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<DataType>{}}}),
+ graph_tests_utils::MakeMapNode("map1", "range", "RandomUniform"),
+ NDef("cache", "CacheDataset", {"map1", "filename"}, {})},
+ // FunctionLib
+ {
+ test::function::RandomUniform(),
+ });
+
+ HoistRandomUniform optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ const int new_map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output);
+ const int zip_dataset_id =
+ graph_utils::FindGraphNodeWithOp("ZipDataset", output);
+ const int random_dataset_id =
+ graph_utils::FindGraphNodeWithOp("RandomDataset", output);
+ const int batch_random_id =
+ graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output);
+ ASSERT_NE(random_dataset_id, -1);
+ ASSERT_NE(zip_dataset_id, -1);
+ ASSERT_NE(new_map_id, -1);
+ ASSERT_NE(batch_random_id, -1);
+
+ const auto& new_map = output.node(new_map_id);
+ const auto& zip = output.node(zip_dataset_id);
+ const auto& random = output.node(random_dataset_id);
+ const auto& batch = output.node(batch_random_id);
+
+ ASSERT_EQ(new_map.input_size(), 1);
+ EXPECT_EQ(new_map.input(0), zip.name());
+
+ ASSERT_EQ(zip.input_size(), 2);
+ EXPECT_EQ(zip.input(0), "range");
+ EXPECT_EQ(zip.input(1), batch.name());
+
+ ASSERT_EQ(batch.input_size(), 3);
+ EXPECT_EQ(batch.input(0), random.name());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 63945b8b9e..e66766eb23 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -80,11 +80,12 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
// Set `f` and `Targuments` attributes.
for (auto key : {"f", "Targuments"}) {
- (*new_node.mutable_attr())[key] = map_node.attr().at(key);
+ graph_utils::CopyAttribute(key, map_node, &new_node);
}
+
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
- (*new_node.mutable_attr())[key] = batch_node.attr().at(key);
+ graph_utils::CopyAttribute(key, batch_node, &new_node);
}
return new_node;
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
index f1844a141c..c4868eacbb 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -41,19 +42,18 @@ NodeDef MakeFusedNode(const NodeDef& map_node,
fused_node.set_op("MapDataset");
fused_node.add_input(map_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = map_node.attr().at("f");
attr.mutable_func()->set_name(fused_function.signature().name());
(*fused_node.mutable_attr())["f"] = std::move(attr);
- copy_attribute("Targuments", map_node, &fused_node);
+ graph_utils::CopyAttribute("Targuments", map_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, map_node, &fused_node);
+ graph_utils::CopyAttribute(key, map_node, &fused_node);
+
+ if (const auto* attr =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"))
+ (*fused_node.mutable_attr())["use_inter_op_parallelism"] = *attr;
// Add the predicate output attributes.
(*fused_node.mutable_attr())["output_types"]
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
index f029a093fa..6e6da37d7c 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -27,24 +28,8 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
namespace {
-
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
-
-NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "FilterDataset", {string(input_node_name)},
- {{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeFilterNode;
+using graph_tests_utils::MakeMapNode;
TEST(MapAndFilterFusionTest, FuseMapAndFilter) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
index a78ecb09f7..bd943342e8 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -40,24 +41,31 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node,
NodeDef fused_node;
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(),
&fused_node);
-
fused_node.set_op("MapDataset");
fused_node.add_input(parent_map_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = parent_map_node.attr().at("f");
*attr.mutable_func()->mutable_name() = fused_function.signature().name();
(*fused_node.mutable_attr())["f"] = std::move(attr);
- copy_attribute("Targuments", parent_map_node, &fused_node);
-
+ graph_utils::CopyAttribute("Targuments", parent_map_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, map_node, &fused_node);
+ graph_utils::CopyAttribute(key, map_node, &fused_node);
+ auto value_or_false = [](const AttrValue* attr) {
+ if (!attr) return false;
+ return attr->b();
+ };
+
+ const auto* first_parallelism =
+ gtl::FindOrNull(parent_map_node.attr(), "use_inter_op_parallelism");
+ const auto* second_parallelism =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism");
+ // Some graphs cannot execute with use_inter_op_parallelism=False, so we need
+ // to set it to true if one of the ops have it set to true.
+ if (value_or_false(first_parallelism) || value_or_false(second_parallelism)) {
+ (*fused_node.mutable_attr())["use_inter_op_parallelism"].set_b(true);
+ }
return fused_node;
}
@@ -123,8 +131,8 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// fusion.
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function));
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
nodes_to_delete.insert(parent_map_node->name());
nodes_to_delete.insert(map_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
index b25dfbd0b8..8889f9dddd 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -28,14 +29,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeMapNode;
TEST(MapFusionTest, FuseTwoMapNodesIntoOne) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
index 305325e434..782c9f48b7 100644
--- a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
@@ -84,9 +84,6 @@ Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item,
auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph));
graph.ReplaceInput(*map_node, *parallel_map);
-
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
nodes_to_delete.insert(map_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
index b2a5d9b6af..9fdfe8af30 100644
--- a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -28,16 +28,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
- StringPiece function_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
-
+using graph_tests_utils::MakeMapNode;
const char stateless_fun_name[] = "XTimesTwo";
const char stateful_fun_name[] = "RandomUniform";
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index 7a2f1910da..9328a7ca99 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -35,10 +35,6 @@ namespace tensorflow {
namespace grappler {
namespace {
-void CopyAttribute(const string& attr_name, const NodeDef& from, NodeDef* to) {
- (*to->mutable_attr())[attr_name] = from.attr().at(attr_name);
-}
-
// Returns a FunctionDef containing a MapDefun op that wraps the original
// function.
FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
@@ -61,7 +57,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
for (const string& k : {"f", "output_types", "output_shapes"}) {
// Function, output types and (unbatched) shapes are the same as the
// original map node.
- CopyAttribute(k, map_node, map_defun_node);
+ graph_utils::CopyAttribute(k, map_node, map_defun_node);
}
// Get types of input arguments from original map function
@@ -90,21 +86,19 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
// efficient vectorization with VectorizeMapDefun.
FunctionDef* vectorized_func =
CreateMapDefunWrapper(map_node, orig_func, library);
- NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0);
- DCHECK_EQ(map_defun_node->op(), "MapDefun");
-
- // Create a copy of the original function so that we can mutate it, and
- // attach that to the map defun node.
- FunctionDef* map_defun_fn = library->add_function();
- *map_defun_fn = orig_func;
- graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library,
- map_defun_fn);
- (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name(
- map_defun_fn->signature().name());
-
- vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn,
- map_defun_node);
- return vectorized_func;
+ const NodeDef& map_defun_node = vectorized_func->node_def(0);
+ DCHECK_EQ(map_defun_node.op(), "MapDefun");
+
+ // TODO(b/116285210): Unreferenced functions should get cleaned up later
+ FunctionDef* result;
+ Status s = vectorization_utils::VectorizeMapDefun(
+ *vectorized_func, map_defun_node, library, &result);
+
+ if (!s.ok()) {
+ LOG(ERROR) << "VectorizeMapDefun failed: " << s;
+ return vectorized_func;
+ }
+ return result;
}
bool IsOutputShapesFullyDefined(const NodeDef& node) {
@@ -195,13 +189,16 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node,
}
// Set attrs
- CopyAttribute("Targuments", old_map_node, &map_node);
+ graph_utils::CopyAttribute("Targuments", old_map_node, &map_node);
auto& func_attr = (*map_node.mutable_attr())["f"];
func_attr.mutable_func()->set_name(vectorized_func.signature().name());
for (auto key : {"output_shapes", "output_types"}) {
- CopyAttribute(key, old_batch_node, &map_node);
+ graph_utils::CopyAttribute(key, old_batch_node, &map_node);
}
+
+ (*map_node.mutable_attr())["use_inter_op_parallelism"].set_b(true);
+
return map_node;
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
index ed1bd6bc97..f4faf41549 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
@@ -30,72 +30,51 @@ namespace {
using test::function::GDef;
using test::function::NDef;
-void MakeTensorShapeProtoHelper(const gtl::ArraySlice<int> dims,
- TensorShapeProto* t) {
- for (size_t i = 0; i < dims.size(); ++i) {
- auto* d = t->add_dim();
- d->set_size(dims[i]);
- }
-}
-
-AttrValue MakeShapeListAttr(
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& shapes) {
- AttrValue shapes_attr;
- for (size_t i = 0; i < shapes.size(); ++i) {
- MakeTensorShapeProtoHelper(shapes[i],
- shapes_attr.mutable_list()->add_shape());
- }
-
- return shapes_attr;
-}
-
-NodeDef MakeMapNodeHelper(
- StringPiece name, StringPiece input_node_name, StringPiece function_name,
- StringPiece map_op_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
+NodeDef MakeMapNodeHelper(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name, StringPiece map_op_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
return test::function::NDef(
name, map_op_name, {string(input_node_name)},
{{"f", FunctionDefHelper::FunctionRef(string(function_name))},
{"Targuments", {}},
- {"output_shapes", MakeShapeListAttr(output_shapes)},
+ {"output_shapes", output_shapes},
{"output_types", output_types}});
}
-NodeDef MakeMapNode(
- StringPiece name, StringPiece input_node_name, StringPiece function_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
return MakeMapNodeHelper(name, input_node_name, function_name, "MapDataset",
output_shapes, output_types);
}
-NodeDef MakeBatchNode(
- StringPiece name, StringPiece input_node_name,
- StringPiece input_batch_size_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
- return NDef(name, "BatchDataset",
- {string(input_node_name), string(input_batch_size_name)},
- {{"output_types", output_types},
- {"output_shapes", MakeShapeListAttr(output_shapes)}});
+NodeDef MakeBatchNode(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDataset",
+ {string(input_node_name), string(input_batch_size_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
}
-NodeDef MakeBatchV2Node(
- StringPiece name, StringPiece input_node_name,
- StringPiece input_batch_size_name, StringPiece input_drop_remainder_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
- return NDef(name, "BatchDatasetV2",
- {string(input_node_name), string(input_batch_size_name),
- string(input_drop_remainder_name)},
- {{"output_types", output_types},
- {"output_shapes", MakeShapeListAttr(output_shapes)}});
+NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ StringPiece input_drop_remainder_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDatasetV2",
+ {string(input_node_name), string(input_batch_size_name),
+ string(input_drop_remainder_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
}
-NodeDef MakeRangeNode(StringPiece name, const gtl::ArraySlice<string>& inputs) {
+NodeDef MakeRangeNode(StringPiece name, gtl::ArraySlice<string> inputs) {
return NDef(name, "RangeDataset", inputs,
- {{"output_shapes", MakeShapeListAttr({{}})},
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})},
{"output_types", gtl::ArraySlice<DataType>({DT_INT64})}});
}
@@ -184,7 +163,7 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
item.graph = GDef(
{NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
NDef("input", "InputDataset", {},
- {{"output_shapes", MakeShapeListAttr({{}})}}),
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})}}),
MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
// FunctionLib
@@ -196,6 +175,37 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
}
+TEST(MapVectorizationTest, VectorizeWithFullyDefinedFunction) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ MakeRangeNode("range", {"start", "stop", "step"}),
+ MakeMapNode("map", "range", "Func", {{}}, {DT_INT32}),
+ MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
+ // FunctionLib
+ {FunctionDefHelper::Create(
+ "Func", {"x: int64", "y: int64"}, {"res: int64", "res2: int64"}, {},
+ {{{"o"}, "Mul", {"x", "x"}, {{"T", DT_INT64}}}},
+ {{"res", "o:z"}, {"res2", "o:z"}})});
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
+ 1);
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(),
+ 1);
+ const NodeDef& map_node =
+ output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
+ const NodeDef& batch_node =
+ output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output));
+ EXPECT_EQ(map_node.input(0), batch_node.name());
+ EXPECT_EQ(batch_node.input(0), "range");
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
index cb0ff670e8..99c4afa634 100644
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
@@ -64,7 +64,7 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
- (*new_node.mutable_attr())[key] = repeat_node.attr().at(key);
+ graph_utils::CopyAttribute(key, repeat_node, &new_node);
}
return new_node;
};
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
index 1462cb234d..37aa24b947 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -9,13 +9,14 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
VECTORIZER_DEPS = [
":vectorizer_registry",
- "//tensorflow/core/grappler/optimizers/data:function_utils",
+ "//tensorflow/core/grappler/optimizers/data:graph_utils",
] + tf_protos_all()
cc_library(
name = "vectorizer",
hdrs = ["vectorizer.h"],
deps = [
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
index c1739737a0..3af6bab409 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
namespace tensorflow {
@@ -23,26 +23,21 @@ namespace vectorization_utils {
class CastVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
- if (inputs.size() != 1) {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) override {
+ Status s;
+ if (node.num_inputs() != 1) {
return errors::Internal("Cast op should only have one input.");
}
- // Add new Cast node
- NodeDef* new_cast_node = outer_scope->add_node_def();
- *new_cast_node = node;
- new_cast_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", node.name()), outer_scope,
- new_cast_node);
- new_cast_node->set_input(0, inputs[0]);
-
- // Add the output mapping to conversion map
- (*conversion_map)[strings::StrCat(node.name(), ":y:0")] =
- strings::StrCat(new_cast_node->name(), ":y:0");
+ // Add new Cast node with the same op and attrs as the original node
+ auto new_cast_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
+ // Add input and output mappings
+ input_ports->push_back({new_cast_node, 0});
+ output_ports->push_back({new_cast_node, 0});
return Status::OK();
}
};
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
index 776d3179c5..74ce520ce1 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
namespace tensorflow {
@@ -23,31 +23,29 @@ namespace vectorization_utils {
class UnpackVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
- if (inputs.size() != 1) {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) override {
+ Status s;
+ if (node.num_inputs() != 1) {
return errors::Internal("Unpack op should only have one input.");
}
- // Add new Unpack node
- NodeDef* new_unpack_node = outer_scope->add_node_def();
- *new_unpack_node = node;
- new_unpack_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", node.name()), outer_scope,
- new_unpack_node);
+ // Add new Unpack node with the same op and attrs as the original node
+ auto new_unpack_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
// Increment "axis" attr by 1:
- (*new_unpack_node->mutable_attr())["axis"].set_i(
- node.attr().at("axis").i() + 1);
- new_unpack_node->set_input(0, inputs[0]);
+ int new_axis = node.def().attr().at("axis").i() + 1;
+ new_unpack_node->AddAttr("axis", new_axis);
- // Add the output mappings to conversion map
- int num = new_unpack_node->attr().at("num").i();
+ // Add the input mappings
+ input_ports->push_back({new_unpack_node, 0});
+
+ // Add the output mappings
+ int num = node.def().attr().at("num").i();
for (int i = 0; i < num; ++i) {
- (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] =
- strings::StrCat(new_unpack_node->name(), ":output:", i);
+ output_ports->push_back({new_unpack_node, i});
}
return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
index d341dbba7d..56eb88c95e 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -17,30 +17,33 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
+// Describes a tensor with its operation Node and output position
+typedef std::pair<Node*, int> Port;
+
// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
// for an example.
class Vectorizer {
public:
virtual ~Vectorizer() {}
- // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope`
+ // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope`
// that produce the same vector output(s) as executing `node`'s op
- // on elements of the vector inputs, and adding mappings to `conversion_map`
- // from old output tensor names to new (vectorized) output tensor names.
- // The new node(s) collectively have the same number of inputs and outputs as
- // the node being converted, and use the tensor names in `inputs` as their
- // inputs.
- virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) = 0;
+ // on elements of the vector inputs. The new Node(s) collectively have the
+ // same number of input and output ports as the node being converted.
+ // Adds mappings for the new nodes' input and output ports to `inputs` and
+ // `outputs` respectively, where the i'th Port in inputs/outputs
+ // corresponds to the i'th input/output port of the node to be converted.
+ virtual Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) = 0;
};
} // namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
index 86e303564b..663ceba027 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -24,9 +24,9 @@ namespace vectorization_utils {
class TestVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* inputs,
+ std::vector<Port>* outputs) override {
return Status::OK();
}
};
@@ -39,10 +39,12 @@ TEST(TestVectorizer, TestTestVectorizer) {
auto vectorizer = VectorizerRegistry::Global()->Get("test_op");
EXPECT_NE(vectorizer, nullptr);
- FunctionDef function;
- NodeDef node;
- std::map<string, string> conversion_map;
- EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok());
+ Graph g(OpRegistry::Global());
+ NodeDef node_def;
+ Status s;
+ Node* node = g.AddNode(node_def, &s);
+ std::vector<Port> inputs, outputs;
+ EXPECT_TRUE(vectorizer->Vectorize(*node, &g, &inputs, &outputs).ok());
}
} // namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index cb56b65985..cea667f668 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -14,13 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+#include <memory>
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
@@ -36,255 +40,346 @@ namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
-using function_utils::FunctionDefTensorDesc;
-
namespace {
-void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
- const string& output_retval, const DataType t) {
- // Set to unknown shape
- TensorShapeProto tensor_shape_proto;
- PartialTensorShape().AsProto(&tensor_shape_proto);
+// Describes a tensor with its operation Node and output position
+typedef std::pair<Node*, int> TensorDesc;
- function_utils::AddFunctionOutputWithUniqueName(
- "vectorized_out", output_retval, map_defun_fn, t);
+const char* const kRetValOp = "_Retval";
- *(*map_defun_node->mutable_attr())["output_shapes"]
- .mutable_list()
- ->add_shape() = tensor_shape_proto;
- (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t);
+void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
+ Graph* graph) {
+ // NOTE: We need two for loops here because we can't mutate the set of output
+ // edges as we iterate over them.
+ std::vector<const Edge*> edges_to_replace;
+ for (auto edge : old_src.first->out_edges()) {
+ if (edge->src_output() == old_src.second) {
+ edges_to_replace.push_back(edge);
+ }
+ }
+ for (auto edge : edges_to_replace) {
+ graph->AddEdge(new_src.first, new_src.second, edge->dst(),
+ edge->dst_input());
+ graph->RemoveEdge(edge);
+ }
}
-void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node, int output_position) {
- DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
- << "Trying to remove output that doesn't exist. Output number: "
- << output_position;
+Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
+ const TensorDesc& output) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DataType type = output.first->output_type(output.second);
+ int index = map_defun_fn->ret_nodes.size();
- int num_later_outputs =
- map_defun_fn->signature().output_arg_size() - output_position - 1;
+ NodeDef ret_node_def;
+ ret_node_def.set_name("map_out");
+ ret_node_def.set_op(kRetValOp);
+ AddNodeAttr("T", type, &ret_node_def);
+ AddNodeAttr("index", index, &ret_node_def);
- // Remove from map_defun_fn's ret dict and output args
- map_defun_fn->mutable_ret()->erase(
- map_defun_fn->signature().output_arg(output_position).name());
- map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
- output_position, 1);
+ Status s;
+ Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s);
+ TF_RETURN_IF_ERROR(s);
- // Renumber outputs that come after
- for (int i = 0; i < num_later_outputs; ++i) {
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i + 1),
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i),
- outer_scope);
- }
- map_defun_node->mutable_attr()
- ->at("output_shapes")
- .mutable_list()
- ->mutable_shape()
- ->DeleteSubrange(output_position, 1);
- map_defun_node->mutable_attr()
- ->at("output_types")
- .mutable_list()
- ->mutable_type()
- ->ExtractSubrange(output_position, 1, nullptr);
+ map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
+ map_defun_fn->ret_nodes.push_back(ret_node);
+ map_defun_fn->ret_types.push_back(type);
+
+ return s;
}
-int FindOutputToConvert(const FunctionDef& function,
- const std::set<string>& unconvertible,
- FunctionDefTensorDesc* f) {
- for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) {
- const string& ret_key = function.signature().output_arg(i).name();
- *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
+ FunctionBody* map_defun_fn, Node* map_defun_node) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
+ << "Trying to remove output that doesn't exist. Output number: "
+ << output_position;
+
+ int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1;
- if (unconvertible.find(f->node_name) == unconvertible.end()) {
- return i;
- }
+ // Modify map_defun_fn's signature and remove the output node from its graph
+ map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]);
+ map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() +
+ output_position);
+ map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
+ output_position);
+
+ // Renumber the nodes and edges that come after
+ for (int i = 0; i < num_later_outputs; ++i) {
+ ReplaceEdgeSources({map_defun_node, output_position + i + 1},
+ {map_defun_node, output_position + i}, outer_scope);
+ // Each ret node has an "index" attr that has to be updated
+ map_defun_fn->ret_nodes[output_position + i]->AddAttr("index",
+ output_position + i);
}
- return -1;
}
// Helper class that vectorizes the body of a MapDefun node, adding new
// operations to the graph that collectively compute the same value as what
// running the MapDefun function on slices of the input would produce.
-// Each instance of the class encapsulates all the data necessary to vectorize a
-// MapDefun op in place.
+// This class transforms the input FunctionDefs into their corresponding
+// Graph objects and works on the graphs directly, then converts them back
+// to FunctionDefs when GetResult is called.
class Vectorization {
public:
- Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node)
- : outer_scope_(outer_scope),
- map_defun_fn_(map_defun_fn),
- map_defun_node_(map_defun_node) {}
+ explicit Vectorization(FunctionDefLibrary* lib)
+ : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {}
- // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
- // the outer_scope_, until there are no convertible outputs remaining.
- // This method is idempotent.
- void Vectorize();
+ // Adds the vectorized function and new map_defun_fn to lib, and points
+ // vectorized_function to the former. Returns an error status if
+ // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere
+ // along the way.
+ Status Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDef** result);
private:
- // Vectorizes the map defun function's output at output_position
- Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
- // Given a descriptor of the original output tensor, gets a string
- // corresponding to the converted output tensor.
- Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
- string* converted);
- Status AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc);
+ // Converts FunctionDefs to Graphs.
+ Status Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node);
+
+ // Converts Graphs back to FunctionDefs and adds them to `lib_`.
+ Status GetResult(FunctionDef** vectorized_function);
+
+ // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in
+ // `outer_scope_`, until there are no convertible outputs remaining.
+ void VectorizeHelper();
+
+ // Vectorizes map_defun_fn's output at output_position.
+ Status ConvertOutput(int output_position);
// Adds mappings from node's outputs tensors to converted output tensors,
// creating the necessary new node(s). Generally, the steps to convert an op
// are:
- // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_,
- // and modify map_defun_node_ attrs accordingly
- // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+ // 1) Create new node(s) in `outer_scope_` that act on batched input tensors.
// These operations collectively compute the same value as what running
// the original operation on slices of the input tensors would produce.
// For example, a Cast op in MapDefun translates to a Cast op in
- // outer_scope_, since the vectorized version of Cast is itself.
- // 3) Set inputs of new node(s) to the corresponding converted inputs (that
- // are now outputs of map_defun_node_)
- // 4) For each output of the old node, add the mapping of output strings to
- // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0")
- Status AddConversionMappingFromOp(const NodeDef& node,
- const FunctionDefTensorDesc& output_desc);
-
- // Maps a tensor name to the name of the corresponding vectorized tensor. For
- // example, "Cast:y:0" -> "Vectorize/Cast:y:0"
- std::map<string, string> conversion_map_;
- // Unconvertible node names
- std::set<string> unconvertible_;
-
- FunctionDef* outer_scope_;
- FunctionDef* map_defun_fn_;
- NodeDef* map_defun_node_;
+ // `outer_scope_`, since the vectorized version of Cast is itself.
+ // 2) Promote the inputs of the op inputs to outputs of the
+ // `map_defun_node_` and `map_defun_fn_`.
+ // 3) Add edges between the promoted inputs (that are now outputs of
+ // `map_defun_node`) and the inputs ports of the new node(s).
+ // 4) For each output of the old node, add the mapping of output tensors to
+ // the conversion map.
+ Status AddConversionMapping(Node* op_node);
+
+ // Maps a tensor to the corresponding vectorized tensor. For example,
+ // {"Cast" Node*, 0} -> {"Vectorize/Cast" Node*, 0}
+ std::map<TensorDesc, TensorDesc> conversion_map_;
+
+ // Unconvertible ret nodes
+ std::set<Node*> unconvertible_;
+
+ FunctionDefLibrary* lib_; // Not owned
+ FunctionLibraryDefinition lib_def_;
+ // Note that FunctionBody has a pointer to a Graph object that corresponds
+ // to the function's subgraph, with additional kArgOp and kRetValOp nodes
+ // that denote that function arguments and return values. These nodes have the
+ // attrs "T" for the type, and "index" for the argument / retval index
+ // respectively. FunctionBody also keeps track of arg/ret_nodes and
+ // arg/ret_types, that should be ordered according to argument/output indices.
+ std::unique_ptr<Graph> outer_scope_;
+ std::unique_ptr<FunctionBody> map_defun_fn_;
+ Node* map_defun_node_ = nullptr; // Owned by `outer_scope`
+ Status status_;
};
-Status Vectorization::AddConversionMappingFromOp(
- const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
- for (const string& input_name : node.input()) {
- if (IsControlInput(input_name)) {
+Status Vectorization::AddConversionMapping(Node* op_node) {
+ for (auto edge : op_node->in_edges()) {
+ if (edge->IsControlEdge()) {
return errors::InvalidArgument(
"Vectorizing outputs with control inputs is currently not "
"supported.");
}
}
- // TODO(rachelim): Have some mechanism for registering converters and some
- // uniform, simpler way to represent them.
-
- DataTypeVector types;
- const OpDef* op_def = nullptr;
- TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
- TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
-
- std::vector<string> promoted_inputs;
- promoted_inputs.reserve(node.input_size());
- for (int i = 0; i < node.input_size(); ++i) {
- promoted_inputs.push_back(strings::StrCat(
- map_defun_node_->name(),
- ":output:", map_defun_fn_->signature().output_arg_size() + i));
- }
-
- auto vectorizer = VectorizerRegistry::Global()->Get(node.op());
+ auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string());
if (vectorizer == nullptr) {
return errors::Unimplemented("No vectorizer registered for op: ",
- node.op());
+ op_node->type_string());
+ }
+ std::vector<Port> input_ports, output_ports;
+ input_ports.reserve(op_node->num_inputs());
+ output_ports.reserve(op_node->num_outputs());
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ &input_ports, &output_ports));
+
+ std::vector<const Edge*> input_edges;
+ TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
+
+ if (op_node->num_outputs() != output_ports.size() ||
+ op_node->num_inputs() != input_ports.size() ||
+ input_edges.size() != input_ports.size()) {
+ return errors::Internal("Vectorizer inputs/outputs don't match.");
}
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_,
- &conversion_map_));
+ // Promote the inputs of the op to MapDefun outputs and connect the edges
+ // accordingly.
+ for (size_t i = 0; i < op_node->num_inputs(); ++i) {
+ auto edge = input_edges[i];
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1,
+ input_ports[i].first, input_ports[i].second);
+ }
- // If we get here, the conversion was successful, so we promote the inputs
- // of the ops to MapDefun outputs.
- for (int i = 0; i < types.size(); ++i) {
- AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+ // Add output mappings.
+ for (size_t i = 0; i < op_node->num_outputs(); ++i) {
+ conversion_map_.insert({{op_node, i}, std::move(output_ports[i])});
}
return Status::OK();
}
-Status Vectorization::AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc) {
- int input_index = function_utils::FindFunctionInputWithName(
- output_desc.node_name, *map_defun_fn_);
- if (input_index == -1) {
- return errors::Internal("Cannot convert non-existent input.");
+Status Vectorization::ConvertOutput(int output_position) {
+ // ret_edge->src() is the actual op that generated the retval, and
+ // ret_edge->dst() is the retval node whose op is "_Retval"
+ const Edge* ret_edge;
+ TF_RETURN_IF_ERROR(
+ map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge));
+
+ TensorDesc output({ret_edge->src(), ret_edge->src_output()});
+ TensorDesc converted_output;
+ if (auto found = gtl::FindOrNull(conversion_map_, output)) {
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ converted_output = *found;
+ } else {
+ TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
+ converted_output = conversion_map_.at(output);
}
- conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
+ ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
+ outer_scope_.get());
+ RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(),
+ map_defun_node_);
+
return Status::OK();
}
-Status Vectorization::ConvertOutputHelper(
- const FunctionDefTensorDesc& output_desc, string* converted) {
- // It's possible the output already has a mapping, if it comes from a node
- // that has already been converted.
- if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
- *converted = *found;
- return Status::OK();
+Status Vectorization::Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node,
+ FunctionDef** result) {
+ TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node));
+ VectorizeHelper();
+ return GetResult(result);
+}
+
+void Vectorization::VectorizeHelper() {
+ while (true) {
+ int output_position = graph_utils::GetFirstElementIndexWithPredicate(
+ [this](Node* n) {
+ return this->unconvertible_.find(n) == this->unconvertible_.end();
+ },
+ map_defun_fn_->ret_nodes);
+
+ // No outputs left to convert
+ if (output_position == -1) break;
+
+ Status s = ConvertOutput(output_position);
+ if (!s.ok()) {
+ Node* output_node = map_defun_fn_->ret_nodes.at(output_position);
+ VLOG(2) << "Could not convert the output at node: "
+ << output_node->DebugString() << "\nError: " << s;
+ unconvertible_.insert(output_node);
+ }
}
- int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
- *map_defun_fn_);
- if (index == -1) { // The output comes from an input
- TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
+ // If we've converted all the outputs of the MapDefun function, we no longer
+ // need the MapDefun node and can delete it.
+ if (map_defun_fn_->ret_nodes.empty()) {
+ outer_scope_->RemoveNode(map_defun_node_);
} else {
- TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
- map_defun_fn_->node_def(index), output_desc));
+ // Update MapDefun node attrs accordingly
+ DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
+ map_defun_node_->AddAttr(
+ "output_shapes",
+ std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
+ map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
}
- *converted = conversion_map_.at(output_desc.full_str);
- return Status::OK();
}
+Status Vectorization::Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node) {
+ // Convert outer_scope and map_defun_fn to FunctionBodys so we can
+ // work on Graphs directly.
+ const FunctionDef* map_defun_fn =
+ lib_def_.Find(map_defun_node.attr().at("f").func().name());
+
+ if (map_defun_fn == nullptr) {
+ return errors::NotFound("Could not find function with name ",
+ map_defun_node.attr().at("f").func().name(),
+ " in function library.");
+ }
-Status Vectorization::ConvertOutput(int output_position,
- const FunctionDefTensorDesc& output_desc) {
- string converted_output_name;
- TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+ auto get_func_sig = [this](const string& op, const OpDef** sig) {
+ return this->lib_def_.LookUpOpDef(op, sig);
+ };
+
+ FunctionBody* outer_fn;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_,
+ get_func_sig, &outer_fn));
+ // We don't need outer_fn, just the graph
+ outer_scope_.reset(outer_fn->graph);
+ outer_fn->graph = nullptr;
+ delete outer_fn;
+
+ FunctionBody* tmp;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_,
+ get_func_sig, &tmp));
+ map_defun_fn_.reset(tmp);
+
+ // Find the MapDefun node in outer_scope_
+ int node_id = graph_utils::GetFirstElementIndexWithPredicate(
+ [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); },
+ outer_scope_->nodes());
+ if (node_id == -1) {
+ return errors::NotFound("Could not find node with name ",
+ map_defun_node.name(), " in outer_scope.");
+ }
+ map_defun_node_ = outer_scope_->FindNodeId(node_id);
+
+ // Add mappings from map_defun_fn_ arg nodes to map_defun_node_ input nodes to
+ // the conversion map
+ for (auto arg_node : map_defun_fn_->arg_nodes) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(
+ arg_node->attrs().Find("index")->i(), &input_node));
- // Remove the old output and make everything that referenced it point
- // to the new string
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node_->name(), ":output:", output_position),
- converted_output_name, outer_scope_);
- RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
- output_position);
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0}});
+ }
return Status::OK();
}
-void Vectorization::Vectorize() {
- while (true) {
- FunctionDefTensorDesc desc;
- int output_position =
- FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
- if (output_position == -1) break;
+Status Vectorization::GetResult(FunctionDef** vectorized_function) {
+ TF_RETURN_IF_ERROR(status_);
- if (!ConvertOutput(output_position, desc).ok()) {
- unconvertible_.insert(desc.node_name);
- }
- }
+ if (!map_defun_fn_->ret_nodes.empty()) {
+ FunctionDef* map_defun_fn = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn));
- // If we've converted all the outputs of the MapDefun function, we no longer
- // need the MapDefun node and can delete it.
- if (map_defun_fn_->signature().output_arg_size() == 0) {
- outer_scope_->mutable_node_def()->DeleteSubrange(
- function_utils::FindFunctionNodeWithName(map_defun_node_->name(),
- *outer_scope_),
- 1);
+ AttrValue func_attr;
+ func_attr.mutable_func()->set_name(map_defun_fn->signature().name());
+ map_defun_node_->AddAttr("f", func_attr);
}
- if (!unconvertible_.empty()) {
- VLOG(2) << "The following nodes could not be converted: ["
- << absl::StrJoin(unconvertible_, ", ") << "].";
- }
+ *vectorized_function = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_,
+ *vectorized_function);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *outer_scope_, (*vectorized_function)->signature().name(),
+ *vectorized_function));
+ return Status::OK();
}
+
} // namespace
-void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node) {
- Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result) {
+ *result = nullptr;
+ return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result);
}
} // end namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
index bb405faa77..bd7d390900 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -24,22 +24,28 @@ namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
-// Given a function, `map_defun_fn`, that is mapped across some input vector
-// elements via a MapDefun operation, `VectorizeMapDefun` attempts to
-// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the
-// `outer_scope`; that is, replacing `map_defun_fn` operations with new
-// `outer_scope` operations that produce the same vector output(s) as executing
-// the `map_defun_fn` operations on elements of vector input(s) would. If all
-// `map_defun_fn` operations are successfully lifted, `map_defun_node` is
-// eliminated from `outer_scope` altogether. However, if some operations cannot
-// be lifted, and this vectorization only succeeds partially, `map_defun_node`
-// remains to be used for operations that were not lifted.
+// Given a MapDefun node (`map_defun_node`) in a FunctionDef (`outer_scope`)
+// that maps a function in lib across some input vector elements,
+// `VectorizeMapDefun` attempts to create a vectorized version of `outer_scope`
+// by "lifting" operations from the MapDefun function to the new function
+// (`result`); that is, replacing operations in the MapDefun function with
+// operations that produce the same vector output(s) as executing the original
+// operations on elements of vector input(s) would. If all operations in the
+// MapDefun function are successfully lifted, `result` has no MapDefun node
+// altogether. However, if some operations cannot be lifted, and this
+// vectorization only succeeds partially, a MapDefun node remains in `result` to
+// be used for operations that were not lifted, and the modified MapDefun
+// function is added to `lib`. The newly vectorized function `result` is also
+// added to `lib`.
+//
+// Returns Status::OK() if the vectorization is completely or partially
+// successful. Otherwise, returns an error, and sets `result` to nullptr.
//
// Example:
// If the input to the `VectorizeMapDefun` function is a MapDefun
// whose `map_defun_fn` performs the Cast operation, the vectorization will
// eliminate the MapDefun. This is because the Cast operation supports
-// any tensor shape and can thus be lifted to the `outer_scope`.
+// any tensor shape and can thus be lifted to `result`.
//
// Before:
//
@@ -68,7 +74,7 @@ namespace vectorization_utils {
//
// After:
//
-// outer_scope +------+
+// result +------+
// +---------------+ Arg0 +---------+
// | +---+--+ |
// | | |
@@ -80,8 +86,9 @@ namespace vectorization_utils {
// +---------------+ Ret0 +---------+
// +------+
//
-void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node);
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result);
} // end namespace vectorization_utils
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index e129fa9237..1ff62217dd 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
@@ -60,6 +61,11 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
return node;
}
+string GetRetval(const FunctionDef& function_def, int index) {
+ return function_def.ret().at(
+ function_def.signature().output_arg(index).name());
+}
+
// TODO(rachelim): Use FunctionDefHelper::Create instead
FunctionDef CreateFunction(
StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
@@ -85,7 +91,6 @@ FunctionDef CreateFunction(
return func;
}
-TEST(FunctionDefInputDescTest, ConstructedCorrectly) {}
// Before:
//
@@ -133,10 +138,15 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- EXPECT_EQ(outer.ret().at("mapdefun"), "ret0");
- EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1");
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_EQ(GetRetval(*vectorized, 0), "ret0");
+ EXPECT_EQ(GetRetval(*vectorized, 1), "ret1");
}
// Before:
@@ -149,12 +159,12 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// | +-----------+ Arg0 +---+ Arg1 +----+ |
// | | +---+--+ +---+--+ | |
// | | | | | |
-// | | +------+ | +---v--+ | |
-// | | |Const | | | Op0 | | |
-// | | +---v--+ | +---+--+ | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
// | | | | | | |
// | | | +---v--+ +---v--+ | |
-// | | +---| XOp1 | | XOp2 | | |
+// | | +---| XOp1 | | Cast | | |
// | | +---+--+ +---+--+ | |
// | | | | | |
// | | MapDefun +---v--+ +---v--+ | |
@@ -165,23 +175,50 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// +---------------+ Ret0 +---+ Ret1 +--------+
// +------+ +------+
//
-// where XOp1 and XOp2 are not convertible.
+// where XOp1 is not convertible.
//
// After:
//
-// No change because the ops are not convertible.
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ | |
+// | +-----------+ Arg0 +-+ | |
+// | | +---+--+ | | |
+// | | | | | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
+// | | | | | | |
+// | | | +---v--+ | +---v--+ |
+// | | +---| XOp1 | | | Cast | |
+// | | +---+--+ | +---+--+ |
+// | | | | | |
+// | | MapDefun +---v--+ | | |
+// | +-----------+ Ret0 +-+ | |
+// | +---+--+ | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
//
TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
{{"ret0", DT_INT32}, {"ret1", DT_INT32}},
- {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}});
+ {{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}});
+ // TODO(rachelim): If we ever write a converter for MatMul, we have to
+ // change this test.
NodeDef* x_op1 =
- function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner);
+ function_utils::AddNode("MatMul", "MatMul", {"arg0", "arg0"}, {}, &inner);
CHECK_NOTNULL(x_op1);
+ graph_transforms::SetNodeAttr("T", DT_INT32, x_op1);
- NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner);
- CHECK_NOTNULL(x_op2);
+ NodeDef* cast_node =
+ AddCastNode("Cast", {"arg1"}, DT_INT32, DT_INT32, false, &inner);
+ CHECK_NOTNULL(cast_node);
FunctionDef outer = CreateFunction(
"outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
@@ -193,12 +230,22 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
- // They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
- EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto map_defun_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized));
+ // The Cast node should be converted just fine.
+ EXPECT_EQ(GetRetval(*vectorized, 1), "Cast:y:0");
+
+ // The inner function should only have one retval.
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+ EXPECT_EQ(map_defun_fn->signature().output_arg_size(), 1);
}
// Before:
@@ -257,14 +304,19 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -330,16 +382,21 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -411,21 +468,26 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), "x");
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -486,7 +548,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{"ret1", "MyUnstack:output:1"},
{"ret2", "MyUnstack:output:2"}});
NodeDef* cast_op =
- AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT32, false, &inner);
CHECK_NOTNULL(cast_op);
NodeDef* unstack_op =
AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner);
@@ -505,25 +567,30 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 2);
+ EXPECT_EQ(vectorized->node_def_size(), 2);
}
// Before:
@@ -561,9 +628,11 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}},
{{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
- // The attrs aren't relevant
- NodeDef* print_op =
- function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner);
+ NodeDef* print_op = function_utils::AddNode(
+ "Print", "Print", {"arg0", "arg0"}, {/*attrs*/}, &inner);
+ graph_transforms::SetNodeAttr("T", DT_INT32, print_op);
+ graph_transforms::SetNodeAttr("U", gtl::ArraySlice<DataType>({DT_INT32}),
+ print_op);
CHECK_NOTNULL(print_op);
NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
false, &inner);
@@ -578,11 +647,27 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
// They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+ // We check this somewhat manually as the names of nodes may have changed
+ EXPECT_EQ(vectorized->node_def_size(), 1);
+ const NodeDef& map_defun_node = vectorized->node_def(0);
+ EXPECT_EQ(map_defun_node.op(), "MapDefun");
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+
+ const NodeDef& print_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Print", *map_defun_fn));
+ const NodeDef& cast_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *map_defun_fn));
+ string control_input = strings::StrCat("^", print_node.name());
+ EXPECT_TRUE(cast_node.input(0) == control_input ||
+ cast_node.input(1) == control_input);
}
// TODO(rachelim): More test cases when we get around to implementing them:
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc
index 9701a038d0..800160e649 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc
@@ -38,7 +38,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
// be optimized away by dependency optimizer.
for (string& inp : *node.mutable_input()) {
if (!IsControlInput(inp)) {
- inp = AsControlDependency(inp);
+ inp = AsControlDependency(NodeName(inp));
}
}
} else if (IsCheckNumerics(node) || IsPrint(node)) {
@@ -54,7 +54,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
// input.
for (size_t i = 1; i < node.input_size(); ++i) {
if (!IsControlInput(node.input(i))) {
- *node.mutable_input(i) = AsControlDependency(node.input(i));
+ *node.mutable_input(i) = AsControlDependency(NodeName(node.input(i)));
}
}
}
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
index 96ceee791f..affd2d51c2 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
@@ -43,6 +43,35 @@ TEST_F(DebugStripperTest, OutputEqualToInput) {
CompareGraphs(item.graph, output);
}
+TEST_F(DebugStripperTest, StripAssertOnTwoOutputs) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape({6}));
+ auto split =
+ ops::Split(s.WithOpName("split"), /*axis=*/0, input, /*num_split=*/2);
+ Output x = split[0];
+ Output y = split[1];
+ Output ge = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y);
+ auto assert = ops::Assert(s.WithOpName("Assert"), ge, {x, y});
+ Output add = ops::Add(
+ s.WithOpName("add").WithControlDependencies({assert.operation}), x, y);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ DebugStripper optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ for (const NodeDef& node : output.node()) {
+ for (const string& input : node.input()) {
+ if (IsControlInput(input)) {
+ EXPECT_EQ(input.find(':'), -1);
+ }
+ }
+ }
+}
+
TEST_F(DebugStripperTest, StripAssertFromGraph) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index c59645e5f2..a5f851fb1a 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -115,6 +116,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
Status MetaOptimizer::InitializeOptimizers(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
+ if (cfg_.disable_meta_optimizer()) {
+ return Status::OK();
+ }
if (!cfg_.disable_model_pruning()) {
optimizers->push_back(MakeUnique<ModelPruner>());
}
@@ -172,11 +176,12 @@ Status MetaOptimizer::InitializeOptimizers(
optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
}
- return InitializeCustomGraphOptimizers(optimizers);
+ return InitializeCustomGraphOptimizers(std::set<string>(), optimizers);
}
Status MetaOptimizer::InitializeOptimizersByName(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
+ std::set<string> initialized_custom_optimizers;
for (const string& optimizer_name : cfg_.optimizers()) {
auto optimizer = MakeNewOptimizer(optimizer_name);
if (optimizer) {
@@ -190,18 +195,26 @@ Status MetaOptimizer::InitializeOptimizersByName(
if (custom_optimizer) {
VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
- TF_RETURN_IF_ERROR(custom_optimizer->Init());
+ TF_RETURN_IF_ERROR(custom_optimizer->Init(
+ GetCustomGraphOptimizerConfig(optimizer_name)));
optimizers->push_back(std::move(custom_optimizer));
+ initialized_custom_optimizers.insert(optimizer_name);
} else {
VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
}
}
- return InitializeCustomGraphOptimizers(optimizers);
+ return InitializeCustomGraphOptimizers(initialized_custom_optimizers,
+ optimizers);
}
Status MetaOptimizer::InitializeCustomGraphOptimizers(
+ const std::set<string>& pre_initialized_optimizers,
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
for (const auto& optimizer_config : cfg_.custom_optimizers()) {
+ if (pre_initialized_optimizers.find(optimizer_config.name()) !=
+ pre_initialized_optimizers.end()) {
+ continue;
+ }
// Initialize the ExperimentalImplementationSelector here instead of
// CustomizeOptimizer registry, due the static link issue in TensorRT for
// double registry.
@@ -237,6 +250,16 @@ Status MetaOptimizer::InitializeCustomGraphOptimizers(
return Status::OK();
}
+const RewriterConfig::CustomGraphOptimizer*
+MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const {
+ for (const auto& config : cfg_.custom_optimizers()) {
+ if (config.name() == name) {
+ return &config;
+ }
+ }
+ return nullptr;
+}
+
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
@@ -391,6 +414,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
FunctionLibraryDefinition flib(OpRegistry::Global(),
optimized_graph->library());
+ // Find functions for which we might need to compute a gradient at runtime.
+ gtl::FlatSet<string> differentiable_functions;
+ for (const NodeDef& node : optimized_graph->node()) {
+ if (IsSymbolicGradient(node)) {
+ const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
+ if (f_attr) differentiable_functions.insert(f_attr->func().name());
+ }
+ }
+
// Optimize each function only once.
std::unordered_set<string> optimized_funcs;
bool optimize_function_library = true;
@@ -406,6 +438,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Skip parametrized functions (function type or body is defined only at
// function call time by caller node attributes).
+ // They should be specialized to their instantiation type parameters by
+ // the function optimizer, before we can optimize function body.
if (IsParametrized(func)) continue;
VLOG(3) << "Optimize function: function=" << func_name;
@@ -420,6 +454,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
func, flib, item.graph.versions().producer(), &func_item));
+ // If we need to compute the gradient of optimized function at runtime, we
+ // can't perform non-differentiable rewrites.
+ if (differentiable_functions.find(func_name) !=
+ differentiable_functions.end()) {
+ func_item.allowed_optimizations.non_differentiable_rewrites = false;
+ }
+
// Optimize function body graph.
GraphDef optimized_func_graph;
TF_RETURN_IF_ERROR(
@@ -470,6 +511,9 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
}
bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
+ if (cfg.disable_meta_optimizer()) {
+ return false;
+ }
return !cfg.disable_model_pruning() ||
cfg.layout_optimizer() != RewriterConfig::OFF ||
cfg.function_optimization() != RewriterConfig::OFF ||
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 831c5e37c0..99a0a33ffa 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -54,7 +54,11 @@ class MetaOptimizer : public GraphOptimizer {
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
// Initialize active optimizers from RewriterConfig.custom_optimizers.
Status InitializeCustomGraphOptimizers(
+ const std::set<string>& pre_initialized_optimizers,
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
+ // Returns the config for a custom graph optimizer. Null if none was found.
+ const RewriterConfig::CustomGraphOptimizer* GetCustomGraphOptimizerConfig(
+ const string& name) const;
// Run optimization pass over a single GrapplerItem. Meta optimizer might run
// multiple such passes: 1) for the main graph 2) for the function library
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index e74e0f7501..3f3f43382f 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -71,6 +72,59 @@ class TestGraphOptimizer : public TestOptimizer {
REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer);
+class TestOptimizerWithParams : public TestOptimizer {
+ public:
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ CHECK(config != nullptr);
+ return Status::OK();
+ }
+};
+
+REGISTER_GRAPH_OPTIMIZER(TestOptimizerWithParams);
+
+// Record various properties of the GrapplerItems passed for optimization.
+class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer {
+ public:
+ static void SetAllowedOptimizations(
+ gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ allowed_optimizations) {
+ allowed_optimizations_ = allowed_optimizations;
+ }
+ static void ResetAllowedOptimizations() { allowed_optimizations_ = nullptr; }
+
+ GrapplerItemPropertiesAccumulator() {}
+ string name() const override {
+ return "grappler_item_properties_accumulator";
+ }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override {
+ *optimized_graph = item.graph;
+ if (allowed_optimizations_) {
+ allowed_optimizations_->insert({item.id, item.allowed_optimizations});
+ }
+ return Status::OK();
+ }
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ static gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ allowed_optimizations_;
+};
+
+gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ GrapplerItemPropertiesAccumulator::allowed_optimizations_;
+
+REGISTER_GRAPH_OPTIMIZER(GrapplerItemPropertiesAccumulator);
+
class MetaOptimizerTest : public GrapplerTest {};
TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
@@ -90,6 +144,25 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
EXPECT_TRUE(TestOptimizer::IsOptimized());
}
+TEST_F(MetaOptimizerTest, RunsCustomOptimizerWithParams) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ TestOptimizer::SetOptimized(false);
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("TestOptimizerWithParams");
+ auto* custom_config = rewriter_config.add_custom_optimizers();
+ custom_config->set_name("TestOptimizerWithParams");
+ (*custom_config->mutable_parameter_map())["foo"] = AttrValue();
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestOptimizer::IsOptimized());
+}
+
TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
GrapplerItem item;
@@ -305,6 +378,89 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
test::ExpectTensorEqual<int>(tensors_expected[1], tensors[1]);
}
+TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) {
+ using test::function::NDef;
+ using FDH = FunctionDefHelper;
+
+ // We will record what type of optimizations meta optimizer allows for each
+ // GrapplerItem (main graph and graphs for each function).
+ gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>
+ allowed_optimizations;
+ GrapplerItemPropertiesAccumulator::SetAllowedOptimizations(
+ &allowed_optimizations);
+
+ // Just record properties of optimized Grappler items.
+ RewriterConfig rewriter_config;
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+ rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator");
+ rewriter_config.set_min_graph_nodes(-1);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+
+ // Define simple function library with two identical mul functions.
+ FunctionDef mul_func_1 = FunctionDefHelper::Create(
+ "MyMul1", {"x:float", "y:float"}, {"z:float"}, {},
+ {{{"mul"}, "Mul", {"x", "y"}, {}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "mul:z:0"}});
+
+ FunctionDef mul_func_2 = FunctionDefHelper::Create(
+ "MyMul2", {"x:float", "y:float"}, {"z:float"}, {},
+ {{{"mul"}, "Mul", {"x", "y"}, {}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "mul:z:0"}});
+
+ // Tensorflow graph:
+ //
+ // x0 = tf.Placeholder(tf.float);
+ // x1 = tf.Placeholder(tf.float);
+ // dy = tf.Placeholder(tf.float);
+ //
+ // mul_1 = MyMul1(x0, x1);
+ // mul_2 = MyMul2(x0, x1);
+ // dx = SymbolicGradient({x0, x1, dy}, f=MyMul2)
+ GrapplerItem item;
+ item.id = "main";
+ item.graph = test::function::GDef(
+ {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ // Calls into function library
+ NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice),
+ NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice),
+ // Symbolic gradient of a MyMul2
+ NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
+ {{"f", FDH::FunctionRef("MyMul2", {})},
+ {"Tin", DataTypeSlice{DT_FLOAT}},
+ {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
+ kDevice)},
+ // FunctionLib
+ {mul_func_1, mul_func_2});
+ item.fetch = {"mul_1", "mul_2", "dx"};
+
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ // Our custom optimizer must be called for the main graph and for the two
+ // functions.
+ ASSERT_EQ(allowed_optimizations.size(), 3);
+
+ auto allowed_optimizations_main =
+ gtl::FindOrNull(allowed_optimizations, "main");
+ ASSERT_NE(allowed_optimizations_main, nullptr);
+ EXPECT_TRUE(allowed_optimizations_main->non_differentiable_rewrites);
+
+ auto allowed_optimizations_my_mul_1 =
+ gtl::FindOrNull(allowed_optimizations, "MyMul1");
+ ASSERT_NE(allowed_optimizations_my_mul_1, nullptr);
+ EXPECT_TRUE(allowed_optimizations_my_mul_1->non_differentiable_rewrites);
+
+ auto allowed_optimizations_my_mul_2 =
+ gtl::FindOrNull(allowed_optimizations, "MyMul2");
+ ASSERT_NE(allowed_optimizations_my_mul_2, nullptr);
+ EXPECT_FALSE(allowed_optimizations_my_mul_2->non_differentiable_rewrites);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 98c27300a9..89eb76046e 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -71,6 +71,7 @@ bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) {
if (output_arg_id < 0) {
LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
<< node.DebugString() << "\n"
+ << fanin.node->DebugString() << "\n"
<< fanin_odef->DebugString();
return false;
}
@@ -158,7 +159,7 @@ string TryFindHostDevice(const gtl::FlatSet<string>& devices,
}
bool IsTPUGraphDef(const GraphDef& def) {
- for (auto node : def.node()) {
+ for (const auto& node : def.node()) {
if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
node.op() == "TPUPartitionedCall") {
return true;
@@ -168,7 +169,13 @@ bool IsTPUGraphDef(const GraphDef& def) {
}
// All the nodes that should be blacklisted and not swapped.
-bool IsBlacklisted(const NodeDef& node) { return IsCollective(node); }
+bool IsBlacklisted(const NodeDef& node) {
+ return
+ // Collective ops should not be swapped.
+ IsCollective(node) ||
+ // NoOp breaks perf regression tests (probably due to group dependencies).
+ IsNoOp(node);
+}
} // end namespace internal
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
@@ -197,6 +204,10 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Topologically sort the graph, so that we traverse the nodes in order. This
// will help us discover producer->consumer chains of Host ops.
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
+
+ // All the Const nodes, and their original devices in topological order.
+ std::vector<std::pair<NodeDef*, string>> const_nodes;
+
for (auto& node : *optimized_graph->mutable_node()) {
// Check if node already on CPU.
if (str_util::StrContains(node.device(), DEVICE_CPU)) {
@@ -230,10 +241,28 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
continue;
}
+ if (IsConstant(node)) {
+ const_nodes.emplace_back(&node, node.device());
+ }
// Try and swap the device to Host.
node.set_device(
internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
}
+
+ // Traverse all `const_nodes`, and map them back to GPU greedily.
+ for (auto& it : const_nodes) {
+ NodeDef* node = it.first;
+ const string& device = it.second;
+
+ // Check all the consumers of this node, if any of them are on the original
+ // device, swap this node back onto the original device.
+ for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
+ if (fanout.node->device() == device) {
+ node->set_device(device);
+ break;
+ }
+ }
+ }
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 339ddfd1b5..173cb3fe3c 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -128,6 +128,38 @@ TEST_F(PinToHostOptimizerTest, TopologicalSort) {
EXPECT_EQ(found, 4);
}
+TEST_F(PinToHostOptimizerTest, NoSwap) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ // `b` should be too big to swap, consequently `c` should not be swapped.
+ // PinToHostOptimizer should then detect that `a` should not be swapped.
+ Output a = ops::Const(s.WithOpName("a"), 1, {1, 1});
+ Output b = ops::Const(s.WithOpName("b"), 1, {1, 1024 * 1024});
+ Output c = ops::MatMul(s.WithOpName("c"), a, b);
+
+ GrapplerItem item;
+ item.fetch = {"a", "b", "c"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ EXPECT_TRUE(node.device().empty());
+ ++found;
+ }
+ EXPECT_EQ(found, 3);
+}
+
TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 008a289cfd..9ada8b7ff9 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -168,11 +168,12 @@ void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) {
Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
GraphDef* optimized_graph) {
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ bool inferred_properties = false;
GraphView graph(const_cast<GraphDef*>(&item.graph));
// During inference, most of the inputs to FusedBatchNorm are constant, and we
// can therefore replace the op with a much cheaper set of primitives.
+ optimized_graph->mutable_node()->Reserve(item.graph.node_size());
for (const NodeDef& node : item.graph.node()) {
if (node.op() == "FusedBatchNorm" || node.op() == "FusedBatchNormV2") {
bool optimizable = (node.attr().count("T") == 0 ||
@@ -181,6 +182,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
!node.attr().at("is_training").b());
if (optimizable) {
int const_inputs = 0;
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& props = properties.GetInputProperties(node.name());
for (const auto& prop : props) {
if (prop.has_value()) {
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index 4542d17ccc..6ccb1cd783 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -33,7 +33,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
*optimized_graph = item.graph;
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ bool inferred_properties = false;
GraphView graph(optimized_graph);
// The product of all the dimensions in a tensor shape can be expressed more
@@ -55,6 +55,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
const GraphView::OutputPort reduce_indices =
graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1));
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& prop =
properties.GetOutputProperties(reduce_indices.node->name());
if (prop.size() < reduce_indices.port_id) {
@@ -92,6 +97,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
continue;
}
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& prop1 = properties.GetInputProperties(input1.node->name());
const auto& prop2 = properties.GetInputProperties(input2.node->name());
if (prop1.size() != 1 || prop2.size() != 1) {
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 0424c9e8a4..5867d01324 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
+#include <iterator>
#include <memory>
#include <queue>
#include <vector>
@@ -155,44 +156,6 @@ bool IsControlInput(const string& name) {
return !name.empty() && name[0] == '^';
}
-string NodeName(const string& name) {
- int position;
- return ParseNodeName(name, &position);
-}
-
-int NodePosition(const string& name) {
- int position;
- ParseNodeNameAsStringPiece(name, &position);
- return position;
-}
-
-int NodePositionIfSameNode(const string& input_name, const string& node_name) {
- const bool is_ctrl = input_name[0] == '^';
- auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
- auto node_it = node_name.begin();
- if (std::distance(input_it, input_name.end()) < node_name.size()) {
- return -2;
- }
- while (node_it != node_name.end()) {
- if (*input_it++ != *node_it++) {
- return -2;
- }
- }
- if (input_it == input_name.end()) {
- return is_ctrl ? -1 : 0;
- } else if (*input_it++ == ':') {
- StringPiece remaining(&(*input_it),
- std::distance(input_it, input_name.end()));
- int position;
- if (!strings::safe_strto32(remaining, &position)) {
- return -2;
- }
- return is_ctrl ? -1 : position;
- } else {
- return -2;
- }
-}
-
string AddPrefixToNodeName(const string& name, const string& prefix,
const string& delimiter) {
if (!name.empty()) {
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 296ee1678e..95126d470c 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -29,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
namespace grappler {
@@ -102,40 +101,92 @@ bool IsControlInput(const string& name);
// True iff 'name1' and 'name2' refer to the same input.
bool IsSameInput(const string& name1, const string& name2);
+// Returns the trailing position number (or zero if no number is present) if
+// NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
+// Returns -2 if NodeName(input_name) is not equal to node_name.
+// Note: This function is used very heavily, and this hand-optimized
+// version is 3-4x faster than the version using Scanner, which it replaced.
+// This is worth the reduction in readability.
+inline int NodePositionIfSameNode(const string& input_name,
+ const string& node_name) {
+ if (input_name.empty()) return -2;
+ const bool is_ctrl = input_name[0] == '^';
+ auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
+ auto node_it = node_name.begin();
+ if (node_name.empty() ||
+ std::distance(input_it, input_name.end()) < node_name.size()) {
+ return -2;
+ }
+ while (node_it != node_name.end()) {
+ if (*input_it++ != *node_it++) {
+ return -2;
+ }
+ }
+ if (input_it == input_name.end()) {
+ return is_ctrl ? -1 : 0;
+ } else if (*input_it++ == ':') {
+ StringPiece remaining(&(*input_it),
+ std::distance(input_it, input_name.end()));
+ int position;
+ if (!strings::safe_strto32(remaining, &position)) {
+ return -2;
+ }
+ return is_ctrl ? -1 : position;
+ } else {
+ return -2;
+ }
+}
+
// Return the node name corresponding to 'name' if name is valid, or the empty
// string otherwise.
-string NodeName(const string& name);
+inline StringPiece NodeNameAsStringPiece(const string& name) {
+ static const string empty;
+ if (name.empty()) return StringPiece(empty);
+ const auto begin_it = name[0] == '^' ? name.begin() + 1 : name.begin();
+ auto end_it = begin_it;
+ while (end_it != name.end() && *end_it != ':') {
+ ++end_it;
+ }
+ if (end_it != name.end() && *end_it != ':') {
+ return StringPiece(empty);
+ }
+ return StringPiece(&(*begin_it), std::distance(begin_it, end_it));
+}
-// Get the trailing position number ":{digits}" (if any) of a node name.
-// Returns -1 for control inputs.
-int NodePosition(const string& name);
+// Return the node name corresponding to 'name' if name is valid, or the empty
+// string otherwise.
+inline string NodeName(const string& name) {
+ return string(NodeNameAsStringPiece(name));
+}
+// Returns the node name and position in a single call.
inline StringPiece ParseNodeNameAsStringPiece(const string& name,
int* position) {
- // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any)
- // to get a node name.
- strings::Scanner scan(name);
- scan.ZeroOrOneLiteral("^")
- .RestartCapture()
- .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
- .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
- StringPiece capture;
- StringPiece remaining;
- if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) {
+ static const string empty;
+ if (name.empty()) {
*position = 0;
- static const string empty;
return StringPiece(empty);
- } else {
- if (name[0] == '^') {
- *position = -1;
- } else if (remaining.empty()) {
- *position = 0;
- } else {
- // Skip the first ':' character.
- CHECK(strings::safe_strto32(remaining.substr(1), position));
+ }
+ const bool is_ctrl = name[0] == '^';
+ const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin();
+ *position = is_ctrl ? -1 : 0;
+ auto end_it = begin_it;
+ while (end_it != name.end() && *end_it != ':') {
+ ++end_it;
+ }
+ const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it));
+ if (end_it != name.end()) {
+ if (*end_it != ':') {
+ return StringPiece(empty);
+ } else if (!is_ctrl) {
+ ++end_it;
+ StringPiece remaining(&(*end_it), std::distance(end_it, name.end()));
+ if (!strings::safe_strto32(remaining, position)) {
+ return StringPiece(empty);
+ }
}
- return capture;
}
+ return node_name;
}
// Returns the node name and position in a single call.
@@ -143,10 +194,11 @@ inline string ParseNodeName(const string& name, int* position) {
return string(ParseNodeNameAsStringPiece(name, position));
}
-// Returns NodePosition(input_name) if NodeName(input_name) == node_name.
-// Otherwise returns -2;
-// REQUIRES: inputs_name.size() > 0 && node_name.size() > 0.
-int NodePositionIfSameNode(const string& input_name, const string& node_name);
+inline int NodePosition(const string& name) {
+ int position;
+ ParseNodeNameAsStringPiece(name, &position);
+ return position;
+}
// Add a prefix to a node name with a custom delimiter.
string AddPrefixToNodeName(const string& name, const string& prefix,
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index a428aea7f5..6861fb423c 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -41,7 +41,8 @@ Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration,
tensorflow::NameRangeMap outputs_range_map;
TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
node, registration.op_def, nullptr, &outputs_range_map));
- connectivity->RegisterFunctionBodyOutputs(node.name(), outputs_range_map);
+ connectivity->RegisterFunctionBodyOutputs(node.name(),
+ std::move(outputs_range_map));
return Status::OK();
}
@@ -75,20 +76,22 @@ Status ResolveFunctionBodyNodeAttrPlaceholders(
} // namespace
void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
- const InputArgExpansion& input_arg_expansion) {
- const auto& input_name = input_arg_expansion.input_name;
+ InputArgExpansion input_arg_expansion) {
+ string input_name = input_arg_expansion.input_name;
const auto& placeholders = input_arg_expansion.placeholders;
- input_arg_expansions_.emplace(input_name, input_arg_expansion);
+
for (int i = 0; i < placeholders.size(); ++i) {
const string& placeholder = input_arg_expansion.placeholders[i];
- input_arg_placeholders_.emplace(
- placeholder, InputArgPlaceholder{input_name, /*position=*/i});
+ input_arg_placeholders_.insert(
+ {placeholder, InputArgPlaceholder{input_name, /*position=*/i}});
}
+ input_arg_expansions_.insert(
+ {std::move(input_name), std::move(input_arg_expansion)});
}
void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs(
- const string& node_name, const tensorflow::NameRangeMap& outputs) {
- function_body_outputs_[node_name] = outputs;
+ const string& node_name, tensorflow::NameRangeMap&& outputs) {
+ function_body_outputs_[node_name] = std::move(outputs);
}
Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
@@ -174,11 +177,12 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
const auto& output_range = output->second;
if (position == -1) {
+ graph_def_inputs->reserve(graph_def_inputs->size() +
+ output_range.second - output_range.first);
// If position is not defined expand node output range
for (int i = output_range.first; i < output_range.second; ++i) {
- i == 0 ? graph_def_inputs->push_back(node_name)
- : graph_def_inputs->push_back(
- strings::StrCat(node_name, ":", i));
+ graph_def_inputs->push_back(
+ i == 0 ? node_name : strings::StrCat(node_name, ":", i));
}
} else {
if (position > (output_range.second - output_range.first)) {
@@ -187,9 +191,8 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
" position: ", position, " (out of range)");
}
int pos = output_range.first + position;
- pos == 0 ? graph_def_inputs->push_back(node_name)
- : graph_def_inputs->push_back(
- strings::StrCat(node_name, ":", pos));
+ graph_def_inputs->push_back(
+ pos == 0 ? node_name : strings::StrCat(node_name, ":", pos));
}
return Status::OK();
@@ -211,8 +214,8 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs(
}
function_body_node->clear_input();
- for (const string& expanded_input : expanded_inputs)
- function_body_node->add_input(expanded_input);
+ for (string& expanded_input : expanded_inputs)
+ function_body_node->add_input(std::move(expanded_input));
return Status::OK();
}
@@ -323,7 +326,7 @@ GrapplerFunctionItem::GrapplerFunctionItem(
// Fill the feed nodes with input placeholders.
for (const InputArgExpansion& input_arg : input_arg_expansions_) {
for (const string& placeholder : input_arg.placeholders) {
- feed.emplace_back(placeholder, Tensor());
+ feed.push_back({placeholder, Tensor()});
input_arg_placeholders_.insert(placeholder);
}
}
@@ -460,7 +463,7 @@ Status InstantiationBodyParameters(
auto it = func_instantiation_attr.find(placeholder);
if (it != func_instantiation_attr.end()) {
- body_parameters->emplace(placeholder, it->second);
+ body_parameters->insert({placeholder, it->second});
} else {
return errors::InvalidArgument("Can't resolve placeholder: ",
placeholder);
@@ -498,10 +501,6 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
// GraphDef input format (name[:position])
GrapplerFunctionConnectivity connectivity;
- std::vector<InputArgExpansion> inputs;
- std::vector<OutputArgExpansion> outputs;
- std::vector<string> keep_nodes;
-
// Function body shares the library with the graph that instantiated it.
GraphDef function_body;
*function_body.mutable_library() = flib.ToProto();
@@ -518,6 +517,9 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
}
}
+ std::vector<InputArgExpansion> inputs;
+ inputs.reserve(signature.input_arg_size());
+
// For each input argument create a placeholder in function body.
for (const OpDef::ArgDef& input : signature.input_arg()) {
if (!input.type_list_attr().empty() || !input.number_attr().empty()) {
@@ -542,9 +544,10 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
/*is_ref*/ input.is_ref(),
/*placeholders=*/{input.name()}};
connectivity.RegisterInputArgExpansion(input_expansion);
- inputs.push_back(input_expansion);
+ inputs.push_back(std::move(input_expansion));
}
+ std::vector<string> keep_nodes;
// Add all function nodes to the function body
for (const NodeDef& func_def_node : func.node_def()) {
NodeDef* new_node = function_body.add_node();
@@ -572,6 +575,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node));
}
+ std::vector<OutputArgExpansion> outputs;
+ outputs.reserve(signature.output_arg_size());
// Add function outputs
for (const OpDef::ArgDef& out : signature.output_arg()) {
std::vector<string> output_tensors;
@@ -589,8 +594,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
OutputArgExpansion output{/*output_name=*/out.name(),
/*data_type=*/output_data_type,
/*is_ref=*/out.is_ref(),
- /*output_tensors=*/output_tensors};
- outputs.push_back(output);
+ /*output_tensors=*/std::move(output_tensors)};
+ outputs.push_back(std::move(output));
}
bool is_stateful = signature.is_stateful();
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index 733caf325f..ef944ced09 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include <unordered_map>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -70,9 +71,9 @@ struct OutputArgExpansion {
// and fold it back when doing backward conversion.
class GrapplerFunctionConnectivity {
public:
- void RegisterInputArgExpansion(const InputArgExpansion& input_arg_expansion);
+ void RegisterInputArgExpansion(InputArgExpansion input_arg_expansion);
void RegisterFunctionBodyOutputs(const string& node_name,
- const tensorflow::NameRangeMap& outputs);
+ tensorflow::NameRangeMap&& outputs);
// Expand input encoded in FunctionDef format (name[:output][:position]) into
// multiple inputs in GraphDef format (name[:position]).
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index 8ff5f20c6d..9b6c1f690b 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -149,7 +149,9 @@ TEST_F(UtilsTest, NodePosition) {
}
TEST_F(UtilsTest, NodePositionIfSameNode) {
- EXPECT_EQ(0, NodePositionIfSameNode("abc", "abc"));
+ EXPECT_EQ(-2, NodePositionIfSameNode(":123", ""));
+ EXPECT_EQ(-2, NodePositionIfSameNode(":", ""));
+ EXPECT_EQ(-2, NodePositionIfSameNode("", ""));
EXPECT_EQ(123, NodePositionIfSameNode("abc:123", "abc"));
EXPECT_EQ(-1, NodePositionIfSameNode("^abc", "abc"));
EXPECT_EQ(-1, NodePositionIfSameNode("^abc:123", "abc"));
@@ -369,6 +371,25 @@ BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl);
BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0);
BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end);
+#define BM_ParseNodeNameAsStringPiece(I, NAME) \
+ static void BM_ParseNodeNameAsStringPiece_##NAME(int iters) { \
+ string input = I; \
+ for (int i = 0; i < iters; ++i) { \
+ int position; \
+ const StringPiece name = ParseNodeNameAsStringPiece(input, &position); \
+ CHECK_GE(position, -1); \
+ CHECK(!name.empty()); \
+ } \
+ } \
+ BENCHMARK(BM_ParseNodeNameAsStringPiece_##NAME)
+
+BM_ParseNodeNameAsStringPiece("foo", foo);
+BM_ParseNodeNameAsStringPiece("foo/bar/baz", foo_bar_baz);
+BM_ParseNodeNameAsStringPiece("^foo/bar/baz", foo_bar_baz_ctrl);
+BM_ParseNodeNameAsStringPiece("foo:123", foo123);
+BM_ParseNodeNameAsStringPiece("foo/bar/baz:123", foo_bar_baz_123);
+BM_ParseNodeNameAsStringPiece("^foo/bar/baz:123", foo_bar_baz_123_ctrl);
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ab69925d04..9439ab332c 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1197,8 +1197,10 @@ tf_cc_test(
tf_cc_test(
name = "example_parsing_ops_test",
- size = "large",
+ size = "medium",
srcs = ["example_parsing_ops_test.cc"],
+ shard_count = 4,
+ tags = ["optonly"],
deps = [
":example_parsing_ops",
":ops_testutil",
@@ -2028,8 +2030,8 @@ tf_kernel_library(
":variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:resource_variable_ops_op_lib",
- "//third_party/eigen3",
],
)
@@ -4049,11 +4051,6 @@ cc_library(
)
SPARSE_DEPS = [
- ":bounds_check",
- ":cwise_op",
- ":fill_functor",
- ":scatter_functor",
- "//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:sparse_ops_op_lib",
@@ -4086,7 +4083,9 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_cross_op",
prefix = "sparse_cross_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4098,13 +4097,19 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_dense_binary_op_shared",
prefix = "sparse_dense_binary_op_shared",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":cwise_op",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_sparse_binary_op_shared",
prefix = "sparse_sparse_binary_op_shared",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":cwise_op",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4136,7 +4141,9 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_softmax",
prefix = "sparse_softmax",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4148,25 +4155,37 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_tensor_dense_add_op",
prefix = "sparse_tensor_dense_add_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":scatter_functor",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_tensor_dense_matmul_op",
prefix = "sparse_tensor_dense_matmul_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":bounds_check",
+ ":fill_functor",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_to_dense_op",
prefix = "sparse_to_dense_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_xent_op",
prefix = "sparse_xent_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":bounds_check",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4431,11 +4450,20 @@ cc_library(
":string_strip_op",
":string_to_hash_bucket_op",
":substr_op",
+ ":unicode_script_op",
],
)
+cc_library(
+ name = "string_util",
+ srcs = ["string_util.cc"],
+ hdrs = ["string_util.h"],
+ deps = ["//tensorflow/core:lib"],
+)
+
STRING_DEPS = [
":bounds_check",
+ ":string_util",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -5166,6 +5194,7 @@ filegroup(
"spacetobatch_functor.h",
"spacetodepth_op.h",
"spectrogram.h",
+ "string_util.h",
"tensor_array.h",
"tile_functor.h",
"tile_ops_cpu_impl.h",
@@ -5245,6 +5274,8 @@ filegroup(
"cwise_op_squared_difference.cc",
"cwise_op_sub.cc",
"cwise_op_tanh.cc",
+ "cwise_op_xlogy.cc",
+ "cwise_op_xdivy.cc",
"data_format_ops.cc",
"decode_wav_op.cc",
"deep_conv2d.cc",
@@ -5334,6 +5365,7 @@ filegroup(
"spectrogram_op.cc",
"stack_ops.cc",
"string_join_op.cc",
+ "string_util.cc",
"summary_op.cc",
"tensor_array.cc",
"tensor_array_ops.cc",
@@ -5459,6 +5491,7 @@ filegroup(
"batch_kernels.*",
"regex_full_match_op.cc",
"regex_replace_op.cc",
+ "unicode_script_op.cc",
# Ops that are inherently incompatible with Android (e.g. tied to x86 platform).
"mkl_*",
"xsmm_*",
@@ -6404,6 +6437,12 @@ tf_mkl_kernel_library(
)
tf_mkl_kernel_library(
+ name = "mkl_slice_op",
+ prefix = "mkl_slice_op",
+ deps = ARRAY_DEPS + mkl_deps(),
+)
+
+tf_mkl_kernel_library(
name = "mkl_identity_op",
prefix = "mkl_identity_op",
deps = ARRAY_DEPS + mkl_deps(),
@@ -6547,6 +6586,16 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "unicode_script_op",
+ srcs = ["unicode_script_op.cc"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:string_ops_op_lib",
+ "@icu//:common",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc
index 54c45bfe63..f48bd0c318 100644
--- a/tensorflow/core/kernels/batch_matmul_op_complex.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc
@@ -17,14 +17,18 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own complex64/128 kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
#if GOOGLE_CUDA
TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU);
-#endif
+#endif // GOOGLE_CUDA
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
index 584b507c70..25ae795d8e 100644
--- a/tensorflow/core/kernels/batch_matmul_op_real.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -21,10 +21,15 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own float and double kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
+
TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index 792eb74e31..039b0db144 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -1,7 +1,7 @@
# Description: Utilities.
package(
- default_visibility = ["//tensorflow:internal"],
+ default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
@@ -12,7 +12,11 @@ cc_library(
name = "periodic_function_dynamic",
srcs = ["periodic_function.cc"],
hdrs = ["periodic_function.h"],
- visibility = ["//visibility:public"],
+ visibility = [
+ "//learning/serving:__subpackages__",
+ "//tensorflow:internal",
+ "//tensorflow_serving:__subpackages__",
+ ],
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:protos_all_cc",
@@ -21,7 +25,11 @@ cc_library(
cc_library(
name = "periodic_function",
- visibility = ["//visibility:public"],
+ visibility = [
+ "//learning/serving:__subpackages__",
+ "//tensorflow:internal",
+ "//tensorflow_serving:__subpackages__",
+ ],
deps = [
":periodic_function_dynamic",
"//tensorflow/core:lib",
@@ -190,7 +198,11 @@ cc_library(
testonly = 1,
srcs = ["fake_clock_env.cc"],
hdrs = ["fake_clock_env.h"],
- visibility = ["//visibility:public"],
+ visibility = [
+ "//learning/serving:__subpackages__",
+ "//tensorflow:internal",
+ "//tensorflow_serving:__subpackages__",
+ ],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow",
diff --git a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
index c9664f0c1c..1ab72af059 100644
--- a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
+++ b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
@@ -11,6 +11,7 @@ message Node {
oneof node {
Leaf leaf = 1;
BucketizedSplit bucketized_split = 2;
+ CategoricalSplit categorical_split = 3;
}
NodeMetadata metadata = 777;
}
@@ -57,6 +58,18 @@ message BucketizedSplit {
int32 right_id = 4;
}
+message CategoricalSplit {
+ // Categorical feature column and split describing the rule feature value ==
+ // value.
+ int32 feature_id = 1;
+ int32 value = 2;
+
+ // Node children indexing into a contiguous
+ // vector of nodes starting from the root.
+ int32 left_id = 3;
+ int32 right_id = 4;
+}
+
// Tree describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
index cc90bb2f45..2798722536 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.cc
+++ b/tensorflow/core/kernels/boosted_trees/resources.cc
@@ -60,14 +60,26 @@ int32 BoostedTreesEnsembleResource::next_node(
DCHECK_LT(tree_id, tree_ensemble_->trees_size());
DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- const auto& split = node.bucketized_split();
- if (bucketized_features[split.feature_id()](index_in_batch) <=
- split.threshold()) {
- return split.left_id();
- } else {
- return split.right_id();
+
+ switch (node.node_case()) {
+ case boosted_trees::Node::kBucketizedSplit: {
+ const auto& split = node.bucketized_split();
+ return (bucketized_features[split.feature_id()](index_in_batch) <=
+ split.threshold())
+ ? split.left_id()
+ : split.right_id();
+ }
+ case boosted_trees::Node::kCategoricalSplit: {
+ const auto& split = node.categorical_split();
+ return (bucketized_features[split.feature_id()](index_in_batch) ==
+ split.value())
+ ? split.left_id()
+ : split.right_id();
+ }
+ default:
+ DCHECK(false) << "Node type " << node.node_case() << " not supported.";
}
+ return -1;
}
float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index e0da91125b..82e2913b64 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -143,6 +143,7 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
c->forward_input_or_allocate_output(
{0}, 0, c->input(0).shape(), &output),
done);
+ col_params_.instance.shape = c->input(0).shape();
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
auto actual_done = [c, col_exec, done](const Status& s) {
@@ -171,7 +172,7 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
- OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
col_params_.is_source = true;
col_params_.instance.impl_details.subdiv_offsets = {0};
@@ -195,13 +196,14 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
if (c->mutable_output(0) == nullptr) {
// Allocate the output tensor, trying to reuse the input.
Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(
- c, c->forward_input_or_allocate_output({0}, 0, shape_, &output),
- done);
+ OP_REQUIRES_OK_ASYNC(c,
+ c->forward_input_or_allocate_output(
+ {0}, 0, col_params_.instance.shape, &output),
+ done);
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
OP_REQUIRES_ASYNC(
- c, shape_.IsSameSize(c->input(0).shape()),
+ c, col_params_.instance.shape.IsSameSize(c->input(0).shape()),
errors::Internal("Declared shape of op ", col_params_.name,
" does not match shape of input"),
done);
@@ -214,8 +216,6 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
}
private:
- TensorShape shape_;
-
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
};
@@ -234,7 +234,7 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
- OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
col_params_.is_source = false;
col_params_.instance.impl_details.subdiv_offsets = {0};
@@ -258,7 +258,8 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
if (c->mutable_output(0) == nullptr) {
// No input, so must allocate output.
Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done);
+ OP_REQUIRES_OK_ASYNC(
+ c, c->allocate_output(0, col_params_.instance.shape, &output), done);
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
@@ -270,8 +271,6 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
}
private:
- TensorShape shape_;
-
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
};
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 717a9f40a9..78856c4a99 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -264,150 +264,198 @@ class LaunchXsmmConvOp<CPUDevice, float> {
};
#endif
+#define TF_REQUIRES(EXP, STATUS) \
+ do { \
+ if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \
+ } while (false)
+
+Status InitConv2DParameters(const OpKernelConstruction* context,
+ Conv2DParameters* params) {
+ TF_RETURN_IF_ERROR(context->GetAttr("dilations", &params->dilations));
+ TF_RETURN_IF_ERROR(context->GetAttr("strides", &params->strides));
+ TF_RETURN_IF_ERROR(context->GetAttr("padding", &params->padding));
+ string data_format_string;
+ TF_RETURN_IF_ERROR(context->GetAttr("data_format", &data_format_string));
+ TF_REQUIRES(FormatFromString(data_format_string, &params->data_format),
+ errors::InvalidArgument("Invalid data format"));
+
+ const auto& strides = params->strides;
+ const auto& dilations = params->dilations;
+ const auto& data_format = params->data_format;
+
+ TF_REQUIRES(dilations.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ TF_REQUIRES(strides.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ const int64 stride_n = GetTensorDim(strides, data_format, 'N');
+ const int64 stride_c = GetTensorDim(strides, data_format, 'C');
+ const int64 stride_h = GetTensorDim(strides, data_format, 'H');
+ const int64 stride_w = GetTensorDim(strides, data_format, 'W');
+ TF_REQUIRES(
+ stride_n == 1 && stride_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ TF_REQUIRES(stride_h > 0 && stride_w > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
+
+ const int64 dilation_n = GetTensorDim(dilations, data_format, 'N');
+ const int64 dilation_c = GetTensorDim(dilations, data_format, 'C');
+ const int64 dilation_h = GetTensorDim(dilations, data_format, 'H');
+ const int64 dilation_w = GetTensorDim(dilations, data_format, 'W');
+ TF_REQUIRES(
+ dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ TF_REQUIRES(
+ dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ return Status::OK();
+}
+
+Status ComputeConv2DDimension(const Conv2DParameters& params,
+ const Tensor& input, const Tensor& filter,
+ Conv2DDimensions* dimensions) {
+ // Check that 2D convolution input and filter have exactly 4 dimensions.
+ TF_REQUIRES(input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().DebugString()));
+ TF_REQUIRES(filter.dims() == 4,
+ errors::InvalidArgument("filter must be 4-dimensional: ",
+ filter.shape().DebugString()));
+ for (int i = 0; i < 3; i++) {
+ TF_REQUIRES(
+ FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
+ errors::InvalidArgument("filter too large"));
+ }
+
+ // The last dimension for input is in_depth. Check that it is the same as the
+ // filter's in_depth or it is evenly divisible by filter's in_depth.
+ const int64 in_depth_raw = GetTensorDim(input, params.data_format, 'C');
+ const int64 patch_depth_raw = filter.dim_size(2);
+ TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input depth too large"));
+ TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Patch depth too large"));
+ const int in_depth = static_cast<int>(in_depth_raw);
+ const int patch_depth = static_cast<int>(patch_depth_raw);
+ TF_REQUIRES(in_depth % patch_depth == 0,
+ errors::InvalidArgument(
+ "input depth must be evenly divisible by filter depth: ",
+ in_depth, " vs ", patch_depth));
+
+ // The last dimension for filter is out_depth.
+ const int out_depth = static_cast<int>(filter.dim_size(3));
+
+ // The second dimension for input is rows/height.
+ // The first dimension for filter is rows/height.
+ const int64 input_rows_raw = GetTensorDim(input, params.data_format, 'H');
+ TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input rows too large"));
+ const int input_rows = static_cast<int>(input_rows_raw);
+ const int filter_rows = static_cast<int>(filter.dim_size(0));
+
+ // The third dimension for input is columns/width.
+ // The second dimension for filter is columns/width.
+ const int64 input_cols_raw = GetTensorDim(input, params.data_format, 'W');
+ TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input cols too large"));
+ const int input_cols = static_cast<int>(input_cols_raw);
+ const int filter_cols = static_cast<int>(filter.dim_size(1));
+
+ // The first dimension for input is batch.
+ const int64 batch_raw = GetTensorDim(input, params.data_format, 'N');
+ TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("batch is too large"));
+ const int batch = static_cast<int>(batch_raw);
+
+ // Take the stride and dilation from the second and third dimensions only (we
+ // do not support striding or dilation on the batch or depth dimension).
+ const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H');
+ const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W');
+ const int dilation_rows =
+ GetTensorDim(params.dilations, params.data_format, 'H');
+ const int dilation_cols =
+ GetTensorDim(params.dilations, params.data_format, 'W');
+
+ // Compute windowed output sizes for rows and columns.
+ int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
+ input_rows, filter_rows, dilation_rows, stride_rows, params.padding,
+ &out_rows, &pad_rows));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
+ input_cols, filter_cols, dilation_cols, stride_cols, params.padding,
+ &out_cols, &pad_cols));
+
+ dimensions->batch = batch;
+ dimensions->input_rows = input_rows;
+ dimensions->input_cols = input_cols;
+ dimensions->in_depth = in_depth;
+ dimensions->filter_rows = filter_rows;
+ dimensions->filter_cols = filter_cols;
+ dimensions->patch_depth = patch_depth;
+ dimensions->out_depth = out_depth;
+ dimensions->stride_rows = stride_rows;
+ dimensions->stride_cols = stride_cols;
+ dimensions->dilation_rows = dilation_rows;
+ dimensions->dilation_cols = dilation_cols;
+ dimensions->out_rows = out_rows;
+ dimensions->out_cols = out_cols;
+ dimensions->pad_rows = pad_rows;
+ dimensions->pad_cols = pad_cols;
+
+ return Status::OK();
+}
+
+#undef TF_REQUIRES
+
template <typename Device, typename T>
class Conv2DOp : public BinaryOp<T> {
public:
explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
- OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
- string data_format;
- OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
- OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
+
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- OP_REQUIRES(context, strides_.size() == 4,
- errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
- const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
- const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
- const int64 stride_h = GetTensorDim(strides_, data_format_, 'H');
- const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
- OP_REQUIRES(
- context, stride_n == 1 && stride_c == 1,
- errors::InvalidArgument("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
- OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
- errors::InvalidArgument(
- "Row and column strides should be larger than 0."));
-
- const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
- OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
void Compute(OpKernelContext* context) override {
// Input tensor is of the following dimensions:
// [ batch, in_rows, in_cols, in_depth ]
-
const Tensor& input = context->input(0);
// Input filter is of the following dimensions:
// [ filter_rows, filter_cols, in_depth, out_depth]
const Tensor& filter = context->input(1);
- // For 2D convolution, there should be 4 dimensions.
- OP_REQUIRES(context, input.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input.shape().DebugString()));
- OP_REQUIRES(context, filter.dims() == 4,
- errors::InvalidArgument("filter must be 4-dimensional: ",
- filter.shape().DebugString()));
-
- for (int i = 0; i < 3; i++) {
- OP_REQUIRES(
- context,
- FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
- errors::InvalidArgument("filter too large"));
- }
+ Conv2DDimensions dimensions;
+ OP_REQUIRES_OK(context,
+ ComputeConv2DDimension(params_, input, filter, &dimensions));
- // The last dimension for input is in_depth. It must be the same as the
- // filter's in_depth or be evenly divisible by filter's in_depth.
- const int64 in_depth = GetTensorDim(input, data_format_, 'C');
- const int64 patch_depth = filter.dim_size(2);
- OP_REQUIRES(context, in_depth % patch_depth == 0,
- errors::InvalidArgument(
- "input depth must be evenly divisible by filter depth: ",
- in_depth, " vs ", patch_depth));
-
- // The last dimension for filter is out_depth.
- const int out_depth = static_cast<int>(filter.dim_size(3));
-
- // The second dimension for input is rows/height.
- // The first dimension for filter is rows/height.
- const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input rows too large"));
- const int input_rows = static_cast<int>(input_rows_raw);
- const int filter_rows = static_cast<int>(filter.dim_size(0));
-
- // The third dimension for input is columns/width.
- // The second dimension for filter is columns/width.
- const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input cols too large"));
- const int input_cols = static_cast<int>(input_cols_raw);
- const int filter_cols = static_cast<int>(filter.dim_size(1));
-
- // The first dimension for input is batch.
- const int64 batch_raw = GetTensorDim(input, data_format_, 'N');
- OP_REQUIRES(context,
- FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("batch is too large"));
- const int batch = static_cast<int>(batch_raw);
-
- // For now we take the stride and dilation from the second and third
- // dimensions only (we do not support striding or dilation on the batch or
- // depth dimension).
- const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
- const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
-
- const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
- const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
-
- int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
- OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
- input_rows, filter_rows, dilation_rows,
- stride_rows, padding_, &out_rows, &pad_rows));
- OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
- input_cols, filter_cols, dilation_cols,
- stride_cols, padding_, &out_cols, &pad_cols));
- TensorShape out_shape =
- ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
+ TensorShape out_shape = ShapeFromFormat(
+ params_.data_format, dimensions.batch, dimensions.out_rows,
+ dimensions.out_cols, dimensions.out_depth);
// Output tensor is of the following dimensions:
// [ in_batch, out_rows, out_cols, out_depth ]
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
- VLOG(2) << "Conv2D: in_depth = " << in_depth
- << ", patch_depth = " << patch_depth
- << ", input_cols = " << input_cols
- << ", filter_cols = " << filter_cols
- << ", input_rows = " << input_rows
- << ", filter_rows = " << filter_rows
- << ", stride_rows = " << stride_rows
- << ", stride_cols = " << stride_cols
- << ", dilation_rows = " << dilation_rows
- << ", dilation_cols = " << dilation_cols
- << ", out_depth = " << out_depth;
+ VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth
+ << ", patch_depth = " << dimensions.patch_depth
+ << ", input_cols = " << dimensions.input_cols
+ << ", filter_cols = " << dimensions.filter_cols
+ << ", input_rows = " << dimensions.input_rows
+ << ", filter_rows = " << dimensions.filter_rows
+ << ", stride_rows = " << dimensions.stride_rows
+ << ", stride_cols = " << dimensions.stride_cols
+ << ", dilation_rows = " << dimensions.dilation_rows
+ << ", dilation_cols = " << dimensions.dilation_cols
+ << ", out_depth = " << dimensions.out_depth;
// If there is nothing to compute, return.
if (out_shape.num_elements() == 0) {
@@ -416,36 +464,41 @@ class Conv2DOp : public BinaryOp<T> {
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
if (LaunchXsmmConvOp<Device, T>::Run(
- context, input, filter, batch, input_rows, input_cols, in_depth,
- filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
- output, data_format_)) {
+ context, input, filter, dimensions.batch, dimensions.input_rows,
+ dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
+ dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols,
+ dimensions.out_rows, dimensions.out_cols, dimensions.out_depth,
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, output,
+ params_.data_format)) {
return;
}
#endif
if (LaunchDeepConvOp<Device, T>::Run(
- context, input, filter, batch, input_rows, input_cols, in_depth,
- filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
- output, data_format_)) {
+ context, input, filter, dimensions.batch, dimensions.input_rows,
+ dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
+ dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols,
+ dimensions.out_rows, dimensions.out_cols, dimensions.out_depth,
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, output,
+ params_.data_format)) {
return;
}
launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
- output, data_format_);
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, params_.padding,
+ output, params_.data_format);
}
private:
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
+ Conv2DParameters params_;
bool use_cudnn_;
- Padding padding_;
- TensorFormat data_format_;
- LaunchConv2DOp<Device, T> launcher_;
bool cudnn_use_autotune_;
+ LaunchConv2DOp<Device, T> launcher_;
+
TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
};
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
index adf4601b43..7ec878e0b2 100644
--- a/tensorflow/core/kernels/conv_ops.h
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -66,6 +66,50 @@ struct Im2ColBufferResource : public ResourceBase {
string DebugString() { return "Im2ColBufferResource"; }
};
+// Convolution parameters specified by Op attributes.
+struct Conv2DParameters {
+ std::vector<int32> dilations;
+ std::vector<int32> strides;
+ Padding padding;
+ TensorFormat data_format;
+};
+
+// Convolution dimensions inferred from parameters, input and filter tensors.
+struct Conv2DDimensions {
+ int batch;
+ int input_rows;
+ int input_cols;
+ int in_depth;
+
+ int filter_rows;
+ int filter_cols;
+ int patch_depth;
+ int out_depth;
+
+ int stride_rows;
+ int stride_cols;
+
+ int dilation_rows;
+ int dilation_cols;
+
+ int64 out_rows;
+ int64 out_cols;
+ int64 pad_rows;
+ int64 pad_cols;
+};
+
+// Initializes and validates Conv2D parameters configured by OpKernel
+// attributes.
+Status InitConv2DParameters(const OpKernelConstruction* context,
+ Conv2DParameters* params);
+
+// Computes and validates convolutions dimensions from Conv2D parameters. If
+// parameters are valid, dimensions will be updated with derived convolution
+// dimensions, otherwise error will be returned.
+Status ComputeConv2DDimension(const Conv2DParameters& params,
+ const Tensor& input, const Tensor& filter,
+ Conv2DDimensions* dimensions);
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_H_
diff --git a/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc
new file mode 100644
index 0000000000..e4b21a66c6
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY5(xdivy, Eigen::half, float, double, complex64, complex128);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc
new file mode 100644
index 0000000000..1e1b5a426e
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY5(xlogy, Eigen::half, float, double, complex64, complex128);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_xdivy.cc b/tensorflow/core/kernels/cwise_op_xdivy.cc
new file mode 100644
index 0000000000..6a6aec5e86
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_xdivy.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Xdivy", functor::xdivy, float, Eigen::half, double,
+ complex64, complex128);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Xdivy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
+ BinaryOp<SYCLDevice, functor::xdivy<TYPE>>);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+#undef REGISTER_SYCL_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER5(BinaryOp, GPU, "Xdivy", functor::xdivy, float, Eigen::half, double,
+ complex64, complex128);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_xlogy.cc b/tensorflow/core/kernels/cwise_op_xlogy.cc
new file mode 100644
index 0000000000..e71a9109b2
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_xlogy.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Xlogy", functor::xlogy, float, Eigen::half, double,
+ complex64, complex128);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Xlogy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
+ BinaryOp<SYCLDevice, functor::xlogy<TYPE>>);
+REGISTER_SYCL_KERNEL(Eigen::half);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+REGISTER_SYCL_KERNEL(complex64);
+REGISTER_SYCL_KERNEL(complex128);
+#undef REGISTER_SYCL_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER5(BinaryOp, GPU, "Xlogy", functor::xlogy, float, Eigen::half, double,
+ complex64, complex128);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 22eb66e979..66ba827a90 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -471,6 +471,45 @@ struct functor_traits<bitwise_xor_op<Scalar>> {
enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true };
};
+// TODO(srvasude): Add packet versions of this operation.
+template <typename Scalar>
+struct xlogy_op {
+ EIGEN_EMPTY_STRUCT_CTOR(xlogy_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
+ operator()(const Scalar& x, const Scalar& y) const {
+ if (x == Scalar(0.)) {
+ return Scalar(0.);
+ }
+ return x * numext::log(y);
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<xlogy_op<Scalar>> {
+ enum {
+ Cost = (sizeof(Scalar) == 4 ? 40 : 85) + Eigen::NumTraits<Scalar>::MulCost,
+ PacketAccess = false
+ };
+};
+
+template <typename Scalar>
+// TODO(srvasude): Add packet versions of this operation.
+struct xdivy_op {
+ EIGEN_EMPTY_STRUCT_CTOR(xdivy_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
+ operator()(const Scalar& x, const Scalar& y) const {
+ if (x == Scalar(0.)) {
+ return Scalar(0.);
+ }
+ return x / y;
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<xdivy_op<Scalar>> {
+ enum { Cost = Eigen::NumTraits<Scalar>::MulCost, PacketAccess = false };
+};
+
} // end namespace internal
} // end namespace Eigen
@@ -830,6 +869,12 @@ struct squared_difference
Eigen::internal::scalar_difference_op<T>>> {};
template <typename T>
+struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {};
+
+template <typename T>
+struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {};
+
+template <typename T>
struct less : base<T, Eigen::internal::less<T>, bool> {};
template <typename T>
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
index 980edffceb..8ad3b4d1fc 100644
--- a/tensorflow/core/kernels/cwise_ops_common.cc
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -20,9 +20,9 @@ namespace tensorflow {
BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out,
DataType in)
: OpKernel(ctx) {
-#ifndef INTEL_MKL
+#if !defined(INTEL_MKL) || !defined(ENABLE_MKL)
OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out}));
-#endif
+#endif // !INTEL_MKL || !ENABLE_MKL
}
void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) {
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 87efdff789..6333853cdf 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -765,6 +765,7 @@ tf_kernel_library(
":window_dataset_op",
":writer_ops",
":zip_dataset_op",
+ "//tensorflow/core/kernels/data/experimental:dataset_kernels",
],
)
diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index ec6cb37193..43406db3ed 100644
--- a/tensorflow/contrib/data/kernels/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -1,22 +1,26 @@
# Description:
-# Contains kernels for datasets and iterators.
+# Contains experimental kernels for datasets and iterators.
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_kernel_library",
+)
+
cc_library(
name = "indexed_dataset_headers",
hdrs = ["indexed_dataset.h"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:framework",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
)
-cc_library(
+tf_kernel_library(
name = "indexed_dataset",
srcs = [
"identity_indexed_dataset.cc",
@@ -24,103 +28,102 @@ cc_library(
],
deps = [
":indexed_dataset_headers",
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "prefetching_kernels",
srcs = ["prefetching_kernels.cc"],
deps = [
- "//tensorflow/core:core_cpu_headers_lib",
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "directed_interleave_dataset_op",
srcs = ["directed_interleave_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "csv_dataset_op",
srcs = ["csv_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "ignore_errors_dataset_op",
srcs = ["ignore_errors_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "lmdb_dataset_op",
srcs = ["lmdb_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
"@lmdb",
- "@protobuf_archive//:protobuf_headers",
],
)
-cc_library(
+tf_kernel_library(
name = "threadpool_dataset_op",
srcs = ["threadpool_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "unique_dataset_op",
srcs = ["unique_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "assert_next_dataset_op",
srcs = ["assert_next_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "dataset_kernels",
deps = [
":assert_next_dataset_op",
@@ -132,8 +135,5 @@ cc_library(
":prefetching_kernels",
":threadpool_dataset_op",
":unique_dataset_op",
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
)
diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
index c19a609780..3511cca0f5 100644
--- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
@@ -147,8 +147,9 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
-REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
- AssertNextDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalAssertNextDataset").Device(DEVICE_CPU),
+ AssertNextDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
index 21ec50fb6b..7451ca4cb1 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
@@ -852,7 +852,8 @@ class CSVDatasetOp : public DatasetOpKernel {
}; // class CSVDatasetOp
// Register the kernel implementation for CSVDataset.
-REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalCSVDataset").Device(DEVICE_CPU),
+ CSVDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
index a5321620bf..c47a9099c4 100644
--- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
@@ -272,8 +272,9 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
- DirectedInterleaveDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
+ DirectedInterleaveDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
index c3cb45dbf7..2141f118ca 100644
--- a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc
+++ b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -147,8 +147,9 @@ class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("IdentityIndexedDataset").Device(DEVICE_CPU),
- IdentityIndexedDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIdentityIndexedDataset").Device(DEVICE_CPU),
+ IdentityIndexedDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
index beec344534..b34377c642 100644
--- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
namespace data {
@@ -133,8 +132,9 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU),
- IgnoreErrorsDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIgnoreErrorsDataset").Device(DEVICE_CPU),
+ IgnoreErrorsDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
index ced8ab0d60..75ea462f40 100644
--- a/tensorflow/contrib/data/kernels/indexed_dataset.cc
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -361,12 +361,14 @@ class IndexedDatasetGet : public OpKernel {
};
REGISTER_KERNEL_BUILDER(
- Name("MaterializedIndexDatasetHandle").Device(DEVICE_CPU),
+ Name("ExperimentalMaterializedIndexDatasetHandle").Device(DEVICE_CPU),
MaterializedHandleOp);
-REGISTER_KERNEL_BUILDER(Name("IndexedDatasetMaterialize").Device(DEVICE_CPU),
- MaterializeDatasetOp);
-REGISTER_KERNEL_BUILDER(Name("IndexedDatasetGet").Device(DEVICE_CPU),
- IndexedDatasetGet);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetMaterialize").Device(DEVICE_CPU),
+ MaterializeDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetGet").Device(DEVICE_CPU),
+ IndexedDatasetGet);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.h b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
index 7aa2d3fdbc..27a8360cbc 100644
--- a/tensorflow/contrib/data/kernels/indexed_dataset.h
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
-#define TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -116,4 +116,4 @@ Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
} // namespace data
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
index d233c1f8ec..8a88d32f0c 100644
--- a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
@@ -210,7 +210,8 @@ class LMDBDatasetOp : public DatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalLMDBDataset").Device(DEVICE_CPU),
+ LMDBDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
index 96f1dd0059..2c6179d9f5 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
@@ -338,20 +338,20 @@ class FunctionBufferResourceHandleOp : public OpKernel {
DataTypeVector output_types_;
};
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
.Device(DEVICE_CPU)
.HostMemory("resource")
.HostMemory("string_arg")
.HostMemory("target_device"),
FunctionBufferResourceHandleOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
.Device(DEVICE_GPU)
.HostMemory("resource")
.HostMemory("string_arg")
.HostMemory("target_device"),
FunctionBufferResourceHandleOp);
#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
.Device(DEVICE_SYCL)
.HostMemory("resource")
.HostMemory("string_arg")
@@ -403,16 +403,16 @@ class FunctionBufferingResourceGetNextOp : public AsyncOpKernel {
}
};
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
.Device(DEVICE_CPU)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceGetNextOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
.Device(DEVICE_GPU)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceGetNextOp);
#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
.Device(DEVICE_SYCL)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceGetNextOp);
@@ -440,16 +440,16 @@ class FunctionBufferingResourceResetOp : public OpKernel {
}
};
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
.Device(DEVICE_CPU)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceResetOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
.Device(DEVICE_GPU)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceResetOp);
#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
.Device(DEVICE_SYCL)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceResetOp);
@@ -473,8 +473,9 @@ class IteratorGetDeviceOp : public OpKernel {
}
};
-REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU),
- IteratorGetDeviceOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIteratorGetDevice").Device(DEVICE_CPU),
+ IteratorGetDeviceOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
index 30fa97a636..c80493d3a1 100644
--- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
@@ -209,10 +209,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("ThreadPoolHandle").Device(DEVICE_CPU),
+REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
ThreadPoolHandleOp);
-REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU),
- ThreadPoolDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalThreadPoolDataset").Device(DEVICE_CPU),
+ ThreadPoolDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
index 57fc5697a4..cd612e0eb2 100644
--- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
@@ -199,8 +199,9 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
HANDLE_TYPE(DT_INT64);
HANDLE_TYPE(DT_STRING);
default:
- LOG(FATAL) << "UniqueDataset unhandled data type: "
- << DataTypeString(lhs.dtype());
+ DCHECK(false) << "UniqueDataset unhandled data type: "
+ << DataTypeString(lhs.dtype());
+ return false;
}
}
};
@@ -215,7 +216,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU),
+REGISTER_KERNEL_BUILDER(Name("ExperimentalUniqueDataset").Device(DEVICE_CPU),
UniqueDatasetOp);
} // namespace
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index 71a36314a0..b4367d5a11 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -86,8 +86,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
- TF_RETURN_IF_ERROR(
- dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
return Status::OK();
}
@@ -96,6 +94,12 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
bool* end_of_sequence) override {
mutex_lock l(mu_);
+ if (!initialized_) {
+ TF_RETURN_IF_ERROR(
+ dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
+ initialized_ = true;
+ }
+
if (finalized_) {
*end_of_sequence = true;
return Status::OK();
@@ -123,6 +127,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
private:
mutex mu_;
+ bool initialized_ GUARDED_BY(mu_) = false;
bool finalized_ GUARDED_BY(mu_) = false;
std::vector<Tensor> state_ GUARDED_BY(mu_);
};
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index d6ee42a7c6..e7244ee208 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -30,8 +30,7 @@ namespace {
class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
public:
explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
@@ -421,7 +420,6 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList key_func_;
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index 8b417bb1c2..14aefe5d54 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -31,8 +31,7 @@ namespace {
class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
public:
explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size_func", &window_size_func_));
@@ -507,7 +506,6 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList key_func_;
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index c0bc507ec0..7a833668ac 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -659,6 +659,115 @@ class ToSingleElementOp : public AsyncOpKernel {
BackgroundWorker background_worker_;
};
+class ReduceDatasetOp : public AsyncOpKernel {
+ public:
+ explicit ReduceDatasetOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ background_worker_(
+ ctx->env(),
+ strings::StrCat("reduce_thread_", SanitizeThreadSuffix(name()))) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &reduce_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+ &use_inter_op_parallelism_));
+ }
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ // The call to `iterator->GetNext()` may block and depend on an
+ // inter-op thread pool thread, so we issue the call from the
+ // owned thread pool.
+ background_worker_.Schedule([this, ctx, done]() {
+ DatasetBase* dataset;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
+ OpInputList inputs;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("initial_state", &inputs),
+ done);
+ std::vector<Tensor> state(inputs.begin(), inputs.end());
+
+ std::unique_ptr<CapturedFunction> captured_func;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ CapturedFunction::Create(reduce_func_, ctx, "other_arguments",
+ use_inter_op_parallelism_, &captured_func),
+ done);
+
+ IteratorContext iter_ctx(ctx);
+ OP_REQUIRES_OK_ASYNC(ctx, captured_func->Instantiate(&iter_ctx), done);
+
+ std::unique_ptr<IteratorBase> iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator),
+ done);
+
+ // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
+ // avoid destruction races.
+ IteratorBase* raw_iterator = iterator.release();
+ auto cleanup = gtl::MakeCleanup([raw_iterator, done] {
+ delete raw_iterator;
+ done();
+ });
+
+ // Iterate through the input dataset.
+ Status status;
+ while (true) {
+ std::vector<Tensor> next_input_element;
+ bool end_of_input;
+ status = raw_iterator->GetNext(&iter_ctx, &next_input_element,
+ &end_of_input);
+ if (!status.ok() || end_of_input) {
+ break;
+ }
+
+ // Run the reduce function to update the current state.
+ std::vector<Tensor> args;
+ args.reserve(state.size() + next_input_element.size());
+ std::copy(state.begin(), state.end(), std::back_inserter(args));
+ std::copy(next_input_element.begin(), next_input_element.end(),
+ std::back_inserter(args));
+
+ std::vector<Tensor> reduce_func_output;
+ status =
+ captured_func->Run(&iter_ctx, std::move(args), &reduce_func_output);
+ if (!status.ok()) {
+ break;
+ }
+ std::swap(reduce_func_output, state);
+ }
+
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ return;
+ }
+ for (int i = 0; i < state.size(); ++i) {
+ OP_REQUIRES_ASYNC(
+ ctx, state[i].dtype() == output_types_[i],
+ errors::InvalidArgument(
+ "The result does not match the expected type for component ", i,
+ ". Expected: ", DataTypeString(output_types_[i]),
+ ". Actual: ", DataTypeString(state[i].dtype()), "."),
+ done);
+ OP_REQUIRES_ASYNC(
+ ctx, output_shapes_[i].IsCompatibleWith(state[i].shape()),
+ errors::InvalidArgument(
+ "The result does not match the expected shape for component ",
+ i, ". Expected: ", output_shapes_[i].DebugString(),
+ ". Actual: ", state[i].shape().DebugString(), "."),
+ done);
+ ctx->set_output(i, state[i]);
+ }
+ });
+ }
+
+ private:
+ NameAttrList reduce_func_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ bool use_inter_op_parallelism_;
+ BackgroundWorker background_worker_;
+};
+
class OneShotIteratorOp : public AsyncOpKernel {
public:
explicit OneShotIteratorOp(OpKernelConstruction* ctx)
@@ -1146,6 +1255,8 @@ REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_GPU),
AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
ToSingleElementOp);
+REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU),
+ ReduceDatasetOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 2bbf4af664..b4c7f9e510 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -37,6 +37,8 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
index 5f143967d9..d909b9e9d3 100644
--- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -134,19 +134,17 @@ class MultiDeviceIterator : public ResourceBase {
void Reset() LOCKS_EXCLUDED(mu_) {
{
mutex_lock l(mu_);
- if (background_thread_finished_) {
- return;
- }
-
- cancelled_ = true;
- // Wake up the background thread.
- for (int i = 0; i < size_; ++i) {
- buffer_[i].cond_var.notify_all();
- }
+ if (!background_thread_finished_) {
+ cancelled_ = true;
+ // Wake up the background thread.
+ for (int i = 0; i < size_; ++i) {
+ buffer_[i].cond_var.notify_all();
+ }
- // Make sure background thread has finished first.
- while (!background_thread_finished_) {
- shutdown_cond_var_.wait(l);
+ // Make sure background thread has finished first.
+ while (!background_thread_finished_) {
+ shutdown_cond_var_.wait(l);
+ }
}
}
RunPendingCallbacks();
@@ -182,7 +180,7 @@ class MultiDeviceIterator : public ResourceBase {
buffer_[shard_num].cond_var.notify_all();
}
} else {
- if (background_thread_finished_) {
+ if (end_of_iterator_) {
produced_output = true;
elem.end_of_sequence = true;
} else {
@@ -219,8 +217,12 @@ class MultiDeviceIterator : public ResourceBase {
while (!buffer_[i].callbacks.empty()) {
if (buffer_[i].data.empty()) {
HostBufferElement elem;
- elem.status =
- errors::Cancelled("Cancelled and buffer not filled.");
+ if (end_of_iterator_) {
+ elem.end_of_sequence = true;
+ } else {
+ elem.status =
+ errors::Cancelled("Cancelled and buffer not filled.");
+ }
cancellation_elements.push_back(std::move(elem));
} else {
cancellation_elements.push_back(
@@ -293,6 +295,7 @@ class MultiDeviceIterator : public ResourceBase {
{
mutex_lock l(mu_);
background_thread_finished_ = true;
+ end_of_iterator_ = true;
shutdown_cond_var_.notify_all();
}
RunPendingCallbacks();
@@ -312,6 +315,7 @@ class MultiDeviceIterator : public ResourceBase {
std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
bool background_thread_finished_ GUARDED_BY(mu_) = false;
bool background_thread_started_ GUARDED_BY(mu_) = false;
+ bool end_of_iterator_ GUARDED_BY(mu_) = false;
bool cancelled_ GUARDED_BY(mu_) = false;
condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 2e6e0465f7..2bb38bf0b9 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -1084,6 +1084,9 @@ REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
// The above design choices were made with automated optimizations in mind,
// isolating the degree of parallelism as the single tunable knob of this
// implementation.
+//
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx)
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index ee20249bfe..da067a4e6f 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -27,6 +27,8 @@ namespace tensorflow {
namespace data {
namespace {
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class ParallelMapIterator : public DatasetBaseIterator {
public:
explicit ParallelMapIterator(
@@ -104,18 +106,17 @@ class ParallelMapIterator : public DatasetBaseIterator {
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
- std::shared_ptr<InvocationResult> result = invocation_results_[i];
- TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ const auto& result = *(invocation_results_[i]);
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("invocation_results[", i, "].size")),
- result->return_values.size()));
- for (size_t j = 0; j < result->return_values.size(); j++) {
- TF_RETURN_IF_ERROR(
- writer->WriteTensor(full_name(strings::StrCat(
- "invocation_results[", i, "][", j, "]")),
- result->return_values[j]));
+ result.return_values.size()));
+ for (size_t j = 0; j < result.return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
+ result.return_values[j]));
}
- if (result->end_of_input) {
+ if (result.end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")),
@@ -133,9 +134,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name("invocation_results.size"), &invocation_results_size));
for (size_t i = 0; i < invocation_results_size; i++) {
- std::shared_ptr<InvocationResult> result(new InvocationResult());
- invocation_results_.push_back(result);
- TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ invocation_results_.push_back(std::make_shared<InvocationResult>());
+ auto& result = *invocation_results_.back();
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
size_t num_return_values;
{
int64 size;
@@ -151,17 +152,16 @@ class ParallelMapIterator : public DatasetBaseIterator {
": ", size, " is not a valid value of type size_t."));
}
}
- result->return_values.reserve(num_return_values);
+ result.return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
- result->return_values.emplace_back();
- TF_RETURN_IF_ERROR(
- reader->ReadTensor(full_name(strings::StrCat(
- "invocation_results[", i, "][", j, "]")),
- &result->return_values.back()));
+ result.return_values.emplace_back();
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
+ &result.return_values.back()));
}
- result->end_of_input = reader->Contains(full_name(
+ result.end_of_input = reader->Contains(full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")));
- result->notification.Notify();
+ result.notification.Notify();
}
return Status::OK();
}
@@ -257,7 +257,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
while (!busy()) {
- invocation_results_.emplace_back(new InvocationResult());
+ invocation_results_.push_back(std::make_shared<InvocationResult>());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index dbe31f37b8..2a911aa368 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -32,8 +32,7 @@ namespace {
class ScanDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ScanDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tstate", &state_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -258,7 +257,6 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector state_types_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index f5314f7a75..c8abfb9eb5 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -34,16 +34,18 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
&stats_aggregator_resource));
core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
- *output = new Dataset(ctx, input, stats_aggregator_resource);
+ *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource);
}
private:
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const Tensor& resource_handle,
StatsAggregatorResource* stats_aggregator_resource)
: DatasetBase(DatasetContext(ctx)),
input_(input),
+ resource_handle_(resource_handle),
stats_aggregator_resource_(stats_aggregator_resource) {
input_->Ref();
stats_aggregator_resource_->Ref();
@@ -75,8 +77,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- return errors::Unimplemented("%s does not support serialization",
- DebugString());
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* resource_handle_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, resource_handle_node}, output));
+ return Status::OK();
}
private:
@@ -111,16 +118,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- return Status::OK();
+ return errors::Unimplemented(dataset()->DebugString(),
+ " does not support checkpointing");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- return Status::OK();
+ return errors::Unimplemented(dataset()->DebugString(),
+ " does not support checkpointing");
}
private:
@@ -129,6 +134,7 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
};
const DatasetBase* const input_;
+ const Tensor resource_handle_;
StatsAggregatorResource* stats_aggregator_resource_;
};
};
diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
index c90ad2cfeb..ada1235449 100644
--- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
+++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
@@ -31,9 +31,37 @@ class FuzzParseTensor : public FuzzSession {
}
void FuzzImpl(const uint8_t* data, size_t size) final {
+ // We need to be sure that we don't request too many elements (i.e., we
+ // don't make ASAN OOM). In theory, a tensor shape can have arbitrary large
+ // number of elements, up to the limit of the memory available to the OS.
+ // However, due to the tracing done in ASAN, after 2^32 bytes of requested
+ // memory we would get a crash in the fuzzer (see b/34190148). Hence, let's
+ // try parsing the proto here, check that the size (if valid) is below a
+ // maximum threshold (using 2^20 for convenience), and then run the
+ // remainder of the fuzzer testing. Of course, this duplicates some work
+ // but it's better than repeating the investigation whenever Autofuzz
+ // detects another similar OOM.
+ string as_string = string(reinterpret_cast<const char*>(data), size);
+ TensorProto proto;
+ if (!ParseProtoUnlimited(&proto, as_string)) {
+ LOG(WARNING) << "Unable to parse proto of tensor\n";
+ return;
+ }
+ if (!TensorShape::IsValid(proto.tensor_shape())) {
+ LOG(WARNING) << "Invalid tensor shape\n";
+ return;
+ }
+ TensorShape shape(proto.tensor_shape());
+ const int64 num_elements = shape.num_elements();
+ const int64 max_num_elements = 1 << 20;
+ if (num_elements > max_num_elements) {
+ LOG(WARNING) << "Requiring a tensor with too many elements\n";
+ return;
+ }
+
+ // Now we can do the actual fuzz implementation
Tensor input_tensor(tensorflow::DT_STRING, TensorShape({}));
- input_tensor.scalar<string>()() =
- string(reinterpret_cast<const char*>(data), size);
+ input_tensor.scalar<string>()() = as_string;
// TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
RunOneInput(input_tensor).IgnoreError();
}
diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
index 277ee2be02..1c78de253e 100644
--- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
+++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
@@ -114,7 +114,7 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
slice_size, Tindices, Tparams, Tout, &error_loc);
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// Eigen implementation below is not highly performant. gather_nd_generator
// does not seem to be called in parallel, leading to very poor performance.
// Additionally, since it uses scalar (Tscratch) to invoke 'generate', it
@@ -126,12 +126,12 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
const Eigen::array<Eigen::DenseIndex, 1> loc{i};
gather_nd_generator(loc);
}
-#else // INTEL_MKL
+#else // INTEL_MKL && ENABLE_MKL
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
.broadcast(broadcast_dims)
.generate(gather_nd_generator)
.sum();
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
// error_loc() returns -1 if there's no out-of-bounds index,
// otherwise it returns the location of an OOB index in Tindices.
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index 79967aab38..4ad390a411 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -578,7 +578,7 @@ struct MatMulFunctor<SYCLDevice, T> {
.Label("cublas"), \
MatMulOp<GPUDevice, T, true /* cublas */>)
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// MKL does not support half, bfloat16 and int32 types for
// matrix-multiplication, so register the kernel to use default Eigen based
@@ -606,9 +606,9 @@ TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU_EIGEN);
TF_CALL_complex128(REGISTER_CPU_EIGEN);
TF_CALL_double(REGISTER_CPU_EIGEN);
-#endif
+#endif // INTEL_MKL_DNN_ONLY
-#else // INTEL MKL
+#else // INTEL_MKL && ENABLE_MKL
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
@@ -616,7 +616,7 @@ TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
#if GOOGLE_CUDA
TF_CALL_float(REGISTER_GPU);
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 0841395dc3..bc135de11e 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -223,10 +223,12 @@ class BatchMatMulMkl : public OpKernel {
Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMulMkl<CPUDevice, TYPE>)
+#ifdef ENABLE_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
+#endif // ENABLE_MKL
} // end namespace tensorflow
#endif
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 52157ed5fb..f406ad2ab5 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -853,7 +853,7 @@ class MklConvCustomBackpropFilterOp
// MKL DNN allocates large buffers when a conv gradient filter primtive is
// created. So we don't cache conv backward primitives when the env
- // variable TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is set to true.
+ // variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true.
bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();
conv_bwd_filter = MklConvBwdFilterPrimitiveFactory<T>::Get(
convBwdFilterDims, do_not_cache);
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index c38c9cc27c..a501ce2c93 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -713,7 +713,7 @@ class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
TFPaddingToMklDnnPadding(this->padding_));
// We don't cache those primitves if the env variable
- // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true and if primitve descriptor
+ // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if primitve descriptor
// includes potentialy large buffers. MKL DNN allocates buffers
// in the following cases
// 1. Legacy CPU without AVX512/AVX2, or
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 184e0cb003..b332edad0a 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -901,7 +901,7 @@ class MklConvOp : public OpKernel {
// In some cases, primitve descriptor includes potentialy large buffers,
// we don't cache those primitves if the env variable
- // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true. MKL DNN allocates buffers
+ // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true. MKL DNN allocates buffers
// in the following cases
// 1. Legacy CPU without AVX512/AVX2, or
// 2. 1x1 convolution with stride != 1
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index 077d62ce32..f4788f4851 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -217,7 +217,7 @@ class MklMatMulOp : public OpKernel {
reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
reinterpret_cast<MKL_Complex16*>(c), ldc);
}
-#endif
+#endif // !INTEL_MKL_DNN_ONLY
};
#define REGISTER_CPU(T) \
@@ -225,6 +225,7 @@ class MklMatMulOp : public OpKernel {
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
+#ifdef ENABLE_MKL
// TODO(inteltf) Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);
@@ -233,7 +234,8 @@ TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif // !INTEL_MKL_DNN_ONLY
+#endif // ENABLE_MKL
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc
new file mode 100644
index 0000000000..d63e14adf6
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_slice_op.cc
@@ -0,0 +1,358 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/array_ops.cc.
+
+#ifdef INTEL_MKL
+#ifndef INTEL_MKL_ML_ONLY
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/prefetch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#include "mkldnn.hpp"
+#include "tensorflow/core/util/mkl_util.h"
+
+using mkldnn::stream;
+using mkldnn::view;
+
+namespace tensorflow {
+
+namespace {
+
+gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
+ gtl::InlinedVector<int64, 4> out;
+ if (tensor.dtype() == DT_INT32) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int32>()(i));
+ }
+ } else if (tensor.dtype() == DT_INT64) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int64>()(i));
+ }
+ } else {
+ // tensor must be either int32 or int64
+ DCHECK(false);
+ }
+ return out;
+}
+
+} // namespace
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+// A version of SharedValidation (slice_op.h) written for input that is in
+// either Mkl layout or Tensorflow layout.
+// A shared code to validate input shapes and check for identity, which is not dependent on the type of T.
+// We do this to reduce code size by not duplicating all this for all T (float, double, int32, etc.)
+static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size) {
+ const int kInputTensorIndex = 0;
+ const int kInputBeginIndex = 1;
+ const int kInputSizeIndex = 2;
+ const Tensor& input = MklGetInput(context, kInputTensorIndex);
+ const Tensor& begin_tensor = MklGetInput(context, kInputBeginIndex);
+ const Tensor& size_tensor = MklGetInput(context, kInputSizeIndex);
+
+ MklDnnShape input_mkl_shape, begin_mkl_shape, size_mkl_shape;
+ GetMklShape(context, kInputTensorIndex, &input_mkl_shape);
+ GetMklShape(context, kInputBeginIndex, &begin_mkl_shape);
+ GetMklShape(context, kInputSizeIndex, &size_mkl_shape);
+
+ // Begin and size tensors cannot be in MklDnn layout.
+ DCHECK_EQ(begin_mkl_shape.IsMklTensor(), false);
+ DCHECK_EQ(size_mkl_shape.IsMklTensor(), false);
+
+ TensorShape input_tf_shape = input_mkl_shape.IsMklTensor()
+ ? input_mkl_shape.GetTfShape()
+ : input.shape();
+ const int input_dims = input_tf_shape.dims();
+
+ OP_REQUIRES(
+ context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
+ context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
+ begin_tensor.NumElements() == input_dims &&
+ size_tensor.NumElements() == input_dims,
+ errors::InvalidArgument(
+ "Expected begin and size arguments to be 1-D tensors of size ",
+ input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
+ " and ", size_tensor.shape().DebugString(), " instead."));
+
+ *begin = IntTensorToInt64Vec(begin_tensor);
+ *size = IntTensorToInt64Vec(size_tensor);
+ for (int i = 0; i < input_dims; ++i) {
+ if ((*size)[i] == -1) {
+ // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
+ (*size)[i] = input_tf_shape.dim_size(i) - (*begin)[i];
+ }
+ }
+
+ *is_identity = true;
+ for (int i = 0; i < input_dims; ++i) {
+ int64 b = (*begin)[i];
+ int64 s = (*size)[i];
+ if (input_tf_shape.dim_size(i) == 0) {
+ OP_REQUIRES(
+ context, b == 0 && s == 0,
+ errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
+ ") and size[", i, "] == 0 ", "(got ", s,
+ ") when ", "input.dim_size(", i, ") == 0"));
+ } else {
+ OP_REQUIRES(context, 0 <= b && b <= input_tf_shape.dim_size(i),
+ errors::InvalidArgument("Expected begin[", i, "] in [0, ",
+ input_tf_shape.dim_size(i),
+ "], but got ", b));
+ OP_REQUIRES(context, 0 <= s && b + s <= input_tf_shape.dim_size(i),
+ errors::InvalidArgument("Expected size[", i, "] in [0, ",
+ input_tf_shape.dim_size(i) - b,
+ "], but ", "got ", s));
+ }
+ const bool take_all = (b == 0) && (s == input_tf_shape.dim_size(i));
+ (*is_identity) &= take_all;
+ }
+}
+
+// A version of SharedSliceCommonCases function written for input tensor
+// that may be in MklDnn layout or in Tensorflow layout.
+template <typename T>
+static void CheckCommonCasesForMklInputs(OpKernelContext* context,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size,
+ bool* done) {
+ bool is_identity = true;
+ *done = false;
+
+ ValidateMklInputs(context, &is_identity, begin, size);
+ if (!context->status().ok()) return;
+
+ const Tensor& input = MklGetInput(context, 0);
+ MklDnnShape input_mkl_shape;
+ GetMklShape(context, 0, &input_mkl_shape);
+
+ if (is_identity) {
+ VLOG(1) << "Slice identity";
+ context->set_output(0, input);
+ // Mkl metadata tensor in this case can just be forwarded from input to
+ // output.
+ AllocateOutputSetMklShape(context, 0, input_mkl_shape);
+ *done = true;
+ }
+}
+
+// MKL-DNN implementation of Slice
+template <typename Device, typename T>
+class MklDnnSliceOp : public OpKernel {
+ public:
+ explicit MklDnnSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ ~MklDnnSliceOp() {}
+
+ void Compute(OpKernelContext* context) override {
+ gtl::InlinedVector<int64, 4> begin;
+ gtl::InlinedVector<int64, 4> size;
+ bool done = false;
+
+ CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done);
+ if (!context->status().ok() || done == true) return;
+
+ // Though MKL-DNN supports more than 8 dimension and
+ // less than 12 dimension tensor.
+ // But we are mimicking functionality of Eigen Slice op for CPU.
+ if (begin.size() >= 8) {
+ OP_REQUIRES(
+ context, false,
+ errors::Unimplemented("MklDnnSliceOp : Unhandled input dimensions"));
+ }
+
+ ComputeMklDnnSlice(context, begin, size);
+ }
+
+ private:
+ // Slice op implemented using MKL-DNN APIs.
+ void ComputeMklDnnSlice(OpKernelContext* context,
+ const gtl::InlinedVector<int64, 4>& begin,
+ const gtl::InlinedVector<int64, 4>& size) {
+ try {
+ // MKL-DNN API usage below is guided by description at:
+ // https://github.com/01org/mkl-dnn/issues/69
+ //
+ // Relevant part of the description is copied below:
+ //
+ // Let's say you want to copy a part of memory into another buffer (and
+ // probably change the format). Then your steps are:
+ //
+ // 1. create memory primitive descriptor in_mem_pd and memory primitive
+ // in_mem_p for the entire source data.
+ // 2. create view primitive descriptor in_submem_pd based on in_mem_pd,
+ // initial offsets, and sub-sizes
+ // 3. create memory primitive descriptor out_mem_pd and memory primitive
+ // out_mem_p for the output (the logical sizes should match sub-sizes
+ // used in step 2, but the format might be arbitrary)
+ // 4. create reorder primitive descriptor reorder_pd based on in_submem_pd
+ // and out_mem_pd
+ // 5. create reorder primitive itself based on reorder_pd, in_mem_p, and
+ // out_mem_p.
+ //
+ // Please notice that there is no view primitive. There is only view
+ // primitive descriptor. And the reorder uses source memory as input but
+ // traverses it according to a view in_submem_pd.
+
+ auto cpu_engine = engine(engine::cpu, 0);
+ MklDnnData<T> src(&cpu_engine);
+ MklDnnData<T> output(&cpu_engine);
+
+ // Populate offsets and sizes in memory::dims format based on vector.
+ memory::dims begin_dims = {};
+ begin_dims.resize(begin.size());
+ for (size_t i = 0; i < begin.size(); ++i) begin_dims[i] = begin[i];
+ memory::dims size_dims = {};
+ bool empty = false;
+ size_dims.resize(size.size());
+ for (size_t i = 0; i < size.size(); ++i) {
+ size_dims[i] = size[i];
+ if (size_dims[i] == 0) empty = true;
+ }
+
+ Tensor* output_tensor = nullptr;
+ MklDnnShape output_mkl_shape;
+
+ // If no dimension is selected in slice, the result should be empty.
+ // Just return an empty output tensor, and a dummy Mkl-shape tensor.
+ if (empty) { // for empty dims
+ auto shape_to = MklDnnDimsToTFShape(size_dims);
+ AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
+ output_mkl_shape);
+ return;
+ }
+
+ // Step 1 (as per above description) - Create memory for user data.
+ // We use blocked format here to describe input tensor.
+ const Tensor& input_tensor = MklGetInput(context, 0);
+ MklDnnShape input_mkl_shape;
+ GetMklShape(context, 0, &input_mkl_shape);
+
+ if (input_mkl_shape.IsMklTensor()) {
+ auto input_mkl_format = input_mkl_shape.GetTfDataFormat();
+ auto input_tf_format = MklDnnDataFormatToTFDataFormat(input_mkl_format);
+ begin_dims = MklDnnDimsInNCHW(begin_dims, input_tf_format);
+ size_dims = MklDnnDimsInNCHW(size_dims, input_tf_format);
+ auto input_md = input_mkl_shape.GetMklLayout();
+ src.SetUsrMem(input_md, &input_tensor);
+ } else {
+ // Initialize input dimensions and strides to be used when input is not
+ // in MklDnn layout.
+ memory::dims input_dims, input_strides;
+ input_dims = TFShapeToMklDnnDims(input_tensor.shape());
+ input_strides = CalculateTFStrides(input_dims);
+ // Create input memory descriptor.
+ auto input_md =
+ MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
+ src.SetUsrMem(input_md, &input_tensor);
+ }
+
+ // Step 2 - create view primitive descriptor
+ auto view_pd =
+ view::primitive_desc(src.GetUsrMemPrimDesc(), size_dims, begin_dims)
+ .dst_primitive_desc();
+ auto output_strides = CalculateTFStrides(size_dims);
+ auto output_md =
+ MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
+ auto output_pd = memory::primitive_desc(output_md, cpu_engine);
+
+ // Step 3 - Create memory for output. If input is in MklDnn layout, then
+ // output is also in MklDnn layout. Otherwise, output is in Tensorflow
+ // layout.
+ AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
+ &output_tensor, &output_mkl_shape);
+ DCHECK(output_tensor);
+ DCHECK_EQ(input_mkl_shape.IsMklTensor(), output_mkl_shape.IsMklTensor());
+ output.SetUsrMem(output_md, output_tensor);
+
+ std::vector<primitive> net;
+ // Step 4 - create reorder primitive desc between view_pd and output_pd.
+ auto reorder_pd =
+ reorder::primitive_desc(view_pd, output.GetUsrMemPrimDesc());
+ // Step 5 - create reorder primitive itself.
+ net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *output.GetUsrMem()));
+ // Execute the reorder primitive.
+ stream(stream::kind::eager).submit(net).wait();
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+ string(e.message) + ", in file " + string(__FILE__) +
+ ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
+ private:
+ void AllocateOutputTensor(OpKernelContext* context,
+ const MklDnnShape& input_mkl_shape,
+ memory::primitive_desc* output_pd,
+ const memory::dims& output_dims,
+ Tensor** output_tensor,
+ MklDnnShape* output_mkl_shape) {
+ DCHECK(output_tensor);
+ DCHECK(output_mkl_shape);
+
+ TensorShape output_tf_shape;
+
+ if (input_mkl_shape.IsMklTensor()) {
+ // Since input tensor is in Mkl layout, output tensor will be in Mkl
+ // layout.
+
+ // Allocate shape of Mkl tensor.
+ output_mkl_shape->SetMklTensor(true);
+ output_mkl_shape->SetMklLayout(output_pd);
+ output_mkl_shape->SetElemType(MklDnnType<T>());
+ output_mkl_shape->SetTfLayout(input_mkl_shape.GetDimension(), output_dims,
+ input_mkl_shape.GetTfDataFormat());
+
+ output_tf_shape.AddDim((output_pd->get_size() / sizeof(T)) + 1);
+ } else {
+ // If input is not in Mkl layout, then output won't be in Mkl layout.
+ output_mkl_shape->SetMklTensor(false);
+ output_tf_shape = MklDnnDimsToTFShape(output_dims);
+ }
+
+ AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
+ *output_mkl_shape);
+ }
+};
+
+// MKL-DNN Slice registration
+#define REGISTER_MKL_SLICE(type) \
+ REGISTER_KERNEL_BUILDER(Name("_MklSlice") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("begin") \
+ .HostMemory("size") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklDnnSliceOp<CPUDevice, type>);
+
+TF_CALL_float(REGISTER_MKL_SLICE);
+#undef REGISTER_MKL_SLICE
+
+} // namespace tensorflow
+
+#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index fc1c9003aa..fdb4c84c46 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -97,7 +97,13 @@ class PartitionedCallOp : public AsyncOpKernel {
OP_REQUIRES_ASYNC(ctx, fbody != nullptr,
errors::Internal("Could not find handle ", handle),
done);
+ // We need to pass global op_registry as default_registry when creating
+ // graph. So that graph optimization passes can lookup all possible ops
+ // by name.
auto graph = tensorflow::MakeUnique<Graph>(fbody->graph->flib_def());
+ FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
+ TF_CHECK_OK(
+ graph.get()->AddFunctionLibrary(global_flib.ToProto()));
CopyGraph(*fbody->graph, graph.get());
OP_REQUIRES_OK_ASYNC(ctx, PinResourceArgs(graph.get(), args), done);
@@ -250,9 +256,11 @@ class PartitionedCallOp : public AsyncOpKernel {
VLOG(3) << "Partitioned function '" << func_.name() << "', yielding "
<< partitions.size() << " shards.";
- const FunctionLibraryDefinition* flib_def = &graph->flib_def();
for (const auto& partition : partitions) {
- std::unique_ptr<Graph> subgraph(new Graph(flib_def));
+ std::unique_ptr<Graph> subgraph(new Graph(graph->flib_def()));
+ FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
+ TF_CHECK_OK(
+ subgraph.get()->AddFunctionLibrary(global_flib.ToProto()));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = true;
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 26705a8d34..23d76986bf 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -51,7 +51,9 @@ limitations under the License.
#define EIGEN_USE_GPU
#endif
-#include "tensorflow/core/kernels/resource_variable_ops.h"
+#include <memory>
+#include <vector>
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -60,10 +62,12 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
+#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -72,6 +76,8 @@ limitations under the License.
namespace tensorflow {
REGISTER_RESOURCE_HANDLE_KERNEL(Var);
+REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
+ ResourceHandlesOp<Var>);
ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
@@ -101,13 +107,58 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) {
ctx->set_output(0, t);
}
+ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
+ int n;
+ OP_REQUIRES_OK(c, c->GetAttr("N", &n));
+ OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
+ OP_REQUIRES(c, n == dtypes_.size(),
+ errors::InvalidArgument(
+ "Mismatched number of arguments to ReadVariablesOp (", n,
+ " vs. ", dtypes_.size(), ")"));
+}
+
+void ReadVariablesOp::Compute(OpKernelContext* ctx) {
+ std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables(
+ dtypes_.size());
+ std::vector<const ResourceHandle*> handles(dtypes_.size());
+ for (size_t i = 0; i < dtypes_.size(); ++i) {
+ handles[i] = &HandleFromInput(ctx, i);
+ }
+ const auto status = LookupResources(ctx, handles, &variables);
+ OP_REQUIRES(ctx, status.ok(),
+ errors::FailedPrecondition(
+ "Error while reading resource variable. This could mean that "
+ "the variable was uninitialized. ",
+ status.ToString()));
+
+ for (size_t i = 0; i < dtypes_.size(); ++i) {
+ // We're acquiring a reference to the underlying buffer while
+ // holding a shared lock to guarantee ordering of reads and
+ // writes.
+ tf_shared_lock ml(*variables[i]->mu());
+ const Tensor& t = *variables[i]->tensor();
+ OP_REQUIRES(ctx, dtypes_[i] == t.dtype(),
+ errors::InvalidArgument(
+ "Trying to read variable ", handles[i]->name(),
+ " from Container: ", handles[i]->container(),
+ " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
+ " got ", DataTypeString(t.dtype())));
+ ctx->set_output(i, t);
+ }
+}
+
REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
ReadVariableOp);
+REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
+ ReadVariablesOp);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
ReadVariableOp);
+REGISTER_KERNEL_BUILDER(
+ Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"),
+ ReadVariablesOp);
#define REGISTER_GPU_KERNELS(type) \
namespace functor { \
@@ -122,11 +173,20 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("resource") \
.TypeConstraint<type>("dtype"), \
ResourceHandleOp<Var>)
-
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
+
+REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
+ .Device(DEVICE_GPU)
+ .HostMemory("resources")
+ .TypeConstraint("dtypes",
+ {DT_INT64, DT_COMPLEX64,
+ DT_COMPLEX128, DT_HALF, DT_FLOAT,
+ DT_DOUBLE, DT_BOOL, DT_VARIANT}),
+ ResourceHandlesOp<Var>);
+
#endif // GOOGLE_CUDA
template <typename T>
diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h
index 9b60106f13..cffb732c38 100644
--- a/tensorflow/core/kernels/resource_variable_ops.h
+++ b/tensorflow/core/kernels/resource_variable_ops.h
@@ -28,6 +28,16 @@ class ReadVariableOp : public OpKernel {
DataType dtype_;
};
+class ReadVariablesOp : public OpKernel {
+ public:
+ explicit ReadVariablesOp(OpKernelConstruction* c);
+ void Compute(OpKernelContext* ctx) override;
+ bool IsExpensive() override { return false; }
+
+ private:
+ DataTypeVector dtypes_;
+};
+
class DestroyResourceOp : public OpKernel {
public:
explicit DestroyResourceOp(OpKernelConstruction* ctx);
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 77594479cb..a006c69297 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -228,191 +228,6 @@ class SliceOp : public OpKernel {
}
};
-#ifdef INTEL_MKL
-template <typename Device, typename T>
-class MklSliceOp : public OpKernel {
- public:
- explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- TensorShape output_shape;
- gtl::InlinedVector<int64, 4> begin;
- gtl::InlinedVector<int64, 4> size;
- Tensor* result = nullptr;
- bool done = false;
- SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result,
- &done);
- if (!context->status().ok() || done == true) return;
-
- const Tensor& input = context->input(0);
- const int input_dims = input.dims();
-
- if (output_shape.num_elements() > 0) {
- if (std::is_same<Device, CPUDevice>::value && input_dims == 2 &&
- DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
- auto input = context->input(0).tensor<T, 2>();
- auto output = result->tensor<T, 2>();
- // TODO(agarwal): Consider multi-threading this loop for cases where
- // size[0] is very large.
- for (int i = 0; i < size[0]; ++i) {
- const int64 row = begin[0] + i;
- if (i + 1 < size[0]) {
- port::prefetch<port::PREFETCH_HINT_T0>(&output(i + 1, 0));
- port::prefetch<port::PREFETCH_HINT_T0>(&input(row + 1, begin[1]));
- }
- memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T));
- }
- return;
- }
-#define HANDLE_DIM(NDIM) \
- if (input_dims == NDIM) { \
- HandleCase<NDIM>(context, begin, size, result); \
- return; \
- }
-
- HANDLE_DIM(1);
- HANDLE_DIM(2);
- HANDLE_DIM(3);
- HANDLE_DIM(4);
- HANDLE_DIM(5);
- HANDLE_DIM(6);
- HANDLE_DIM(7);
-
-#undef HANDLE_DIM
-
- OP_REQUIRES(
- context, false,
- errors::Unimplemented("SliceOp : Unhandled input dimensions"));
- }
- }
-
- private:
- // Helper function for DoesSliceShapeDifferInOnly1D. Checks if the following
- // criteria matches for slice_dim: if indices for slice are 0 in all dims
- // except slice_dim and if sizes of all the dimensions of the slice are same
- // as the sizes of all the dimensions of the input except slice_dim, then
- // returns True. Otherwise, returns False.
- bool DoesSliceShapeDifferInOnly1DHelper(const TensorShape& input_shape,
- const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size,
- int slice_dim) {
- for (int dim = 0; dim < 4; dim++) {
- if (dim != slice_dim &&
- (begin[dim] != 0 || size[dim] != input_shape.dim_size(dim))) {
- return false;
- }
- }
- return true;
- }
-
- // Is 'input' tensor being sliced over a single dimension out of 4?
- //
- // This check is applicable in the context of Slice of a 4-D tensor in
- // NHWC or NCHW format over channel dimension.
- //
- // If indices for slice are 0 in all dims except one dimension and if sizes of
- // all dimensions of slice are same as sizes of all dimensions of inputs
- // except that dimension, then we are slicing over a single dimension.
- //
- // Returns True if Slicing over a single dimension, and sets slice_dim
- // to the number of the dimension that satisfies criteria.
- bool DoesSliceShapeDifferInOnly1D(const TensorShape& input_shape,
- const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size,
- int* slice_dim) {
- for (int dim = 0; dim < 4; dim++) {
- if (DoesSliceShapeDifferInOnly1DHelper(input_shape, begin, size, dim)) {
- *slice_dim = dim;
- return true;
- }
- }
- return false;
- }
-
- template <int NDIM>
- void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size, Tensor* result) {
- int slice_dim = -1;
- TensorShape in_shape = context->input(0).shape();
- // Special case for handling 4-D tensor slice when shape of the slice
- // differs from the input tensor in only 1 out of 4 dimensions.
- // This case arises in the context of Slice of 4-D tensor in NHWC or NCHW
- // format over channel dimension.
- if (NDIM == 4 &&
- DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
- size_t in_strides[4] = {
- (size_t)in_shape.dim_size(1) * in_shape.dim_size(2) *
- in_shape.dim_size(3),
- (size_t)in_shape.dim_size(2) * in_shape.dim_size(3),
- (size_t)in_shape.dim_size(3), (size_t)1};
-
- size_t out_strides[4] = {(size_t)size[1] * size[2] * size[3],
- (size_t)size[2] * size[3], (size_t)size[3],
- (size_t)1};
-
- T* in_buf = const_cast<T*>(
- const_cast<const T*>(context->input(0).flat<T>().data()));
- T* op_buf = result->flat<T>().data();
-
- if (slice_dim == 1) {
- /* data format = NCHW */
-
-#pragma omp parallel for
- for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
- T* ip = in_buf + (d0 * in_strides[0]);
- T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
-#pragma omp parallel for
- for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
- T* ip1 = ip + (d1 * in_strides[1]);
- T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
- // For NCHW, H and W will be contiguous. So we can copy
- // both with one memcpy.
- memcpy(static_cast<void*>(op1), static_cast<void*>(ip1),
- sizeof(T) * in_strides[1]);
- }
- }
- return;
- } else if (slice_dim == 3) {
- /* data_format = NHWC */
-
-#pragma omp parallel for
- for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
- T* ip = in_buf + (d0 * in_strides[0]);
- T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
-#pragma omp parallel for
- for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
- T* ip1 = ip + (d1 * in_strides[1]);
- T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
-#pragma omp parallel for
- for (ssize_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
- T* ip2 = ip1 + (d2 * in_strides[2]);
- T* ip3 = ip2 + begin[3];
- T* op2 = op1 + ((d2 - begin[2]) * out_strides[2]);
- T* op3 = op2;
- memcpy(static_cast<void*>(op3), static_cast<void*>(ip3),
- sizeof(T) * size[3]);
- }
- }
- }
- return;
- }
- // slice_dim is not 1 or 3, then we fallback to Eigen implementation.
- }
-
- Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
- Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
- for (int i = 0; i < NDIM; ++i) {
- indices[i] = begin[i];
- sizes[i] = size[i];
- }
-
- functor::Slice<Device, T, NDIM>()(
- context->eigen_device<Device>(), result->tensor<T, NDIM>(),
- context->input(0).tensor<T, NDIM>(), indices, sizes);
- }
-};
-#endif
-
// Forward declarations of the functor specializations for declared in the
// sharded source files.
namespace functor {
@@ -440,7 +255,6 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N);
#undef DECLARE_CPU_SPEC
} // namespace functor
-#ifndef INTEL_MKL
#define REGISTER_SLICE(type) \
REGISTER_KERNEL_BUILDER(Name("Slice") \
.Device(DEVICE_CPU) \
@@ -452,19 +266,6 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N);
TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
#undef REGISTER_SLICE
-#else
-#define REGISTER_SLICE(type) \
- REGISTER_KERNEL_BUILDER(Name("Slice") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .HostMemory("begin") \
- .HostMemory("size"), \
- MklSliceOp<CPUDevice, type>)
-
-TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
-TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
-#undef REGISTER_SLICE
-#endif // INTEL_MKL
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
diff --git a/tensorflow/core/kernels/string_length_op.cc b/tensorflow/core/kernels/string_length_op.cc
index a6829b29d9..435a7abdca 100644
--- a/tensorflow/core/kernels/string_length_op.cc
+++ b/tensorflow/core/kernels/string_length_op.cc
@@ -14,13 +14,18 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/string_util.h"
namespace tensorflow {
namespace {
class StringLengthOp : public OpKernel {
public:
- using OpKernel::OpKernel;
+ explicit StringLengthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string unit;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
+ OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
+ }
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
@@ -32,10 +37,22 @@ class StringLengthOp : public OpKernel {
auto src = input.flat<string>();
auto dst = output->flat<int32>();
- for (int n = 0; n < src.size(); ++n) {
- dst(n) = src(n).size();
+ switch (unit_) {
+ case CharUnit::BYTE:
+ for (int n = 0; n < src.size(); ++n) {
+ dst(n) = src(n).size();
+ }
+ break;
+ case CharUnit::UTF8_CHAR:
+ for (int n = 0; n < src.size(); ++n) {
+ dst(n) = UTF8StrLen(src(n));
+ }
+ break;
}
}
+
+ private:
+ CharUnit unit_ = CharUnit::BYTE;
};
REGISTER_KERNEL_BUILDER(Name("StringLength").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc
new file mode 100644
index 0000000000..3a9803a052
--- /dev/null
+++ b/tensorflow/core/kernels/string_util.cc
@@ -0,0 +1,63 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/string_util.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace {
+inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
+} // namespace
+
+namespace tensorflow {
+
+// Sets unit value based on str.
+Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding) {
+ if (str == "UTF8") {
+ *encoding = UnicodeEncoding::UTF8;
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid encoding \"", str, "\": Should be one of: BYTE"));
+ }
+ return Status::OK();
+}
+
+// Sets unit value based on str.
+Status ParseCharUnit(const string& str, CharUnit* unit) {
+ if (str == "BYTE") {
+ *unit = CharUnit::BYTE;
+ } else if (str == "UTF8_CHAR") {
+ *unit = CharUnit::UTF8_CHAR;
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid unit \"", str, "\": Should be one of: BYTE, UTF8_CHAR"));
+ }
+ return Status::OK();
+}
+
+// Return the number of Unicode characters in a UTF-8 string.
+// Result may be incorrect if the input string is not valid UTF-8.
+int32 UTF8StrLen(const string& string) {
+ const int32 byte_size = string.size();
+ const char* const end = string.data() + byte_size;
+ const char* ptr = string.data();
+ int32 skipped_count = 0;
+ while (ptr < end) {
+ skipped_count += IsTrailByte(*ptr++) ? 1 : 0;
+ }
+ const int32 result = byte_size - skipped_count;
+ return result;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h
new file mode 100644
index 0000000000..390cf57702
--- /dev/null
+++ b/tensorflow/core/kernels/string_util.h
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
+#define TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
+
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Enumeration for unicode encodings. Used by ops such as
+// tf.strings.unicode_encode and tf.strings.unicode_decode.
+// TODO(edloper): Add support for:
+// UTF16, UTF32, UTF16BE, UTF32BE, UTF16LE, UTF32LE
+enum class UnicodeEncoding { UTF8 };
+
+// Enumeration for character units. Used by string such as
+// tf.strings.length and tf.substr.
+// TODO(edloper): Add support for: UTF32_CHAR, etc.
+enum class CharUnit { BYTE, UTF8_CHAR };
+
+// Sets `encoding` based on `str`.
+Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding);
+
+// Sets `unit` value based on `str`.
+Status ParseCharUnit(const string& str, CharUnit* unit);
+
+// Returns the number of Unicode characters in a UTF-8 string.
+// Result may be incorrect if the input string is not valid UTF-8.
+int32 UTF8StrLen(const string& string);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc
index 765467bc1e..0e6c0ddccc 100644
--- a/tensorflow/core/kernels/tensor_array.cc
+++ b/tensorflow/core/kernels/tensor_array.cc
@@ -62,7 +62,8 @@ TF_CALL_complex128(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
}
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
-TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
+TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
+TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index e8dc4fad21..384a63e945 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -81,7 +81,8 @@ Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value);
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
-TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
+TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
+TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc
index 83b83fcdb9..4262a5404b 100644
--- a/tensorflow/core/kernels/training_op_helpers.cc
+++ b/tensorflow/core/kernels/training_op_helpers.cc
@@ -15,14 +15,16 @@ limitations under the License.
#include "tensorflow/core/kernels/training_op_helpers.h"
+#include "tensorflow/core/util/ptr_util.h"
+
namespace tensorflow {
-mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
+mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
+ Var** maybe_resource) {
+ *maybe_resource = nullptr;
if (ctx->input_dtype(input) == DT_RESOURCE) {
- Var* var;
- if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
- core::ScopedUnref scoped_unref(var);
- return var->mu();
+ if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
+ return (*maybe_resource)->mu();
} else {
ctx->CtxFailureWithWarning(
errors::Internal("Invalid variable reference."));
@@ -33,12 +35,13 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
}
// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
-// in address order to mitigate deadlock. Returns a vector of acquired mutexes.
-// Safe to pass duplicates - will only lock each distinct mutex once. If
-// do_lock is false, returns immediately. Note that this silently doesn't lock
-// mutexes for invalid variable references; in all usages this is followed by
-// GetInputTensor which will signal a failure.
-std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
+// in address order to mitigate deadlock. Returns a structure that, when
+// deleted, will release the acquired mutexes. Safe to pass duplicates - will
+// only lock each distinct mutex once. If do_lock is false, returns
+// immediately. Note that this silently doesn't lock mutexes for invalid
+// variable references; in all usages this is followed by GetInputTensor which
+// will signal a failure.
+VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
bool any_resource = false;
for (auto i : input_ids) {
@@ -47,14 +50,16 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
break;
}
}
- std::vector<mutex_lock> locks;
if (!do_lock && !any_resource) {
- return locks;
+ return VariableInputLockHolder({}, {});
}
+ std::vector<Var*> vars;
std::vector<mutex*> mutexes;
std::vector<int> acquire_order;
for (auto input : input_ids) {
- mutex* mutex = GetTrainingVariableMutex(ctx, input);
+ Var* var;
+ mutex* mutex = GetTrainingVariableMutex(ctx, input, &var);
+ if (var) vars.push_back(var);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(mutexes.size());
@@ -64,13 +69,19 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
+ std::unique_ptr<std::vector<mutex_lock>> locks =
+ MakeUnique<std::vector<mutex_lock>>();
+ locks->reserve(acquire_order.size());
+
for (auto input : acquire_order) {
- mutex* mu = GetTrainingVariableMutex(ctx, input);
+ Var* var;
+ mutex* mu = GetTrainingVariableMutex(ctx, input, &var);
+ core::ScopedUnref scoped_unref(var);
if (mu != nullptr) {
- locks.emplace_back(*mu);
+ locks->emplace_back(*mu);
}
}
- return locks;
+ return VariableInputLockHolder(std::move(vars), std::move(locks));
}
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h
index 071cb371a7..9f173a80f7 100644
--- a/tensorflow/core/kernels/training_op_helpers.h
+++ b/tensorflow/core/kernels/training_op_helpers.h
@@ -23,9 +23,42 @@ limitations under the License.
namespace tensorflow {
-mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input);
+// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
+//
+// If `input` corresponds to a `DT_RESOURCE`-type variable input,
+// `*maybe_resource` will be updated to contain the underlying resource, and the
+// caller will be responsible for calling `Unref()` on that resource.
+mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
+ Var** maybe_resource);
-std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
+// Utility structure that releases a sequence of borrowed mutexes when it is
+// deleted.
+struct VariableInputLockHolder {
+ public:
+ VariableInputLockHolder(std::vector<Var*> vars,
+ std::unique_ptr<std::vector<mutex_lock>> locks)
+ : vars_(std::move(vars)), locks_(std::move(locks)) {}
+
+ VariableInputLockHolder(VariableInputLockHolder&& other)
+ : vars_(std::move(other.vars_)), locks_(std::move(other.locks_)) {}
+
+ ~VariableInputLockHolder() {
+ // Release the locks before unreffing the Vars, because each lock
+ // is potentially borrowed from a Var in vars_.
+ locks_.reset();
+ for (Var* var : vars_) {
+ var->Unref();
+ }
+ }
+
+ private:
+ std::vector<Var*> vars_;
+ // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
+ // because a `std::vector<mutex_lock>` is not movable on all platforms.
+ std::unique_ptr<std::vector<mutex_lock>> locks_;
+};
+
+VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids);
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 9a07ded17d..acf162deec 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -561,7 +561,9 @@ class ApplyAdadeltaOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- mutex* mu = GetTrainingVariableMutex(ctx, 0);
+ Var* resource;
+ mutex* mu = GetTrainingVariableMutex(ctx, 0, &resource);
+ core::ScopedUnref scoped_unref(resource);
if (use_exclusive_lock_ && mu != nullptr) {
mutex_lock l1(*mu);
// Don't try to acquire a lock on the second ref as they share the same
@@ -710,7 +712,9 @@ class SparseApplyAdadeltaOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- mutex* mu = GetTrainingVariableMutex(ctx, 0);
+ Var* var;
+ mutex* mu = GetTrainingVariableMutex(ctx, 0, &var);
+ core::ScopedUnref scoped_unref(var);
// mu_accum is actually the same mutex as mu_var since currently we use a
// global mutex.
//
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 0f0f65c5a3..48e392c070 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -218,7 +218,7 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
perm, out);
}
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
@@ -230,11 +230,8 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
.TypeConstraint<T>("T") \
.HostMemory("perm"), \
MklConjugateTransposeCpuOp);
-TF_CALL_ALL_TYPES(REGISTER);
-#undef REGISTER
-
-#else // INTEL_MKL
+#else // INTEL_MKL && ENABLE_MKL
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
@@ -246,9 +243,10 @@ TF_CALL_ALL_TYPES(REGISTER);
.TypeConstraint<T>("T") \
.HostMemory("perm"), \
ConjugateTransposeCpuOp);
+#endif // INTEL_MKL && ENABLE_MKL
+
TF_CALL_ALL_TYPES(REGISTER)
#undef REGISTER
-#endif // INTEL_MKL
#if GOOGLE_CUDA
Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
diff --git a/tensorflow/core/kernels/unicode_script_op.cc b/tensorflow/core/kernels/unicode_script_op.cc
new file mode 100644
index 0000000000..085e397eba
--- /dev/null
+++ b/tensorflow/core/kernels/unicode_script_op.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "unicode/errorcode.h" // TF:icu
+#include "unicode/uscript.h" // TF:icu
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+class UnicodeScriptOp : public OpKernel {
+ public:
+ explicit UnicodeScriptOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(context, context->input("input", &input_tensor));
+ const auto& input_flat = input_tensor->flat<int32>();
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<int32>();
+
+ icu::ErrorCode status;
+ for (int i = 0; i < input_flat.size(); i++) {
+ UScriptCode script_code = uscript_getScript(input_flat(i), status);
+ if (status.isSuccess()) {
+ output_flat(i) = script_code;
+ } else {
+ output_flat(i) = -1;
+ status.reset();
+ }
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("UnicodeScript").Device(DEVICE_CPU),
+ UnicodeScriptOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 442686c92a..c9f80df5e4 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1531,37 +1531,6 @@ REGISTER_OP("Size")
.Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ScalarShape);
-namespace {
-
-// This SliceHelper processes the output shape of the `slice`
-// when the tensor of `sizes` is available.
-template <typename T>
-Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
- const Tensor* sizes_value,
- std::vector<DimensionHandle>* dims) {
- auto sizes_vec = sizes_value->vec<T>();
- for (int i = 0; i < sizes_value->NumElements(); ++i) {
- DimensionHandle dim = c->Dim(c->input(0), i);
- if (sizes_vec(i) != -1) {
- auto dim_val = c->Value(dim);
- if (sizes_vec(i) < 0) {
- return errors::InvalidArgument(
- "Out of bounds slicing on dimension ", i, " of length ", dim_val,
- ": sizes vector cannot be < -1, but was ", sizes_vec(i));
- }
-
- dims->emplace_back(c->MakeDim(sizes_vec(i)));
- } else {
- DimensionHandle result;
- TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
- dims->emplace_back(result);
- }
- }
-
- return Status::OK();
-}
-} // namespace
-
// --------------------------------------------------------------------------
REGISTER_OP("Slice")
.Input("input: T")
@@ -1570,83 +1539,22 @@ REGISTER_OP("Slice")
.Output("output: T")
.Attr("T: type")
.Attr("Index: {int32,int64}")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle input = c->input(0);
- ShapeHandle begin_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
- ShapeHandle sizes_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
-
- // Merge to check compatibility of begin and sizes tensors.
- TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
-
- DimensionHandle ndims = c->Dim(begin_shape, 0);
- if (c->ValueKnown(ndims)) {
- TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
- }
-
- // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
- // values, even though the `begin` value does not represent a shape.
- ShapeHandle begin_value;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
-
- // We check the tensor value here and will only use
- // `MakeShapeFromShapeTensor` when `sizes_value` is null.
- // The reason is that `sizes`might contain -1, which can't
- // be represented (-1 in the ShapeHandle would mean "unknown".
- const Tensor* sizes_value = c->input_tensor(2);
-
- if (sizes_value != nullptr) {
- TF_RETURN_IF_ERROR(
- c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
- std::vector<DimensionHandle> dims;
- // If the begin and sizes tensors are available, then
- // we can be precise about the shape of the output.
- if (sizes_value->dtype() == DT_INT64) {
- TF_RETURN_IF_ERROR(
- SliceHelper<int64>(c, begin_value, sizes_value, &dims));
- } else {
- TF_RETURN_IF_ERROR(
- SliceHelper<int32>(c, begin_value, sizes_value, &dims));
- }
-
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
- } else {
- // In case `sizes` is not available (`sizes_value` is null),
- // we could try to use `MakeShapeFromShapeTensor` here.
- // If sizes contain -1, we will simply consider it as `Unknown`.
- // This is less than ideal but still an improvement of shape inference.
- // The following is an example that returns [None, 1, None] with this
- // code path:
- // z = tf.zeros((1, 2, 3))
- // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
- // m.get_shape().as_list()
- ShapeHandle sizes_value;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
- if (c->RankKnown(sizes_value)) {
- TF_RETURN_IF_ERROR(
- c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
- std::vector<DimensionHandle> dims;
- dims.reserve(c->Rank(sizes_value));
- for (int i = 0; i < c->Rank(sizes_value); ++i) {
- dims.emplace_back(c->Dim(sizes_value, i));
- }
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
- }
-
- // We might know the rank of the input.
- if (c->RankKnown(input)) {
- c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
- return Status::OK();
- } else {
- return shape_inference::UnknownShape(c);
- }
- }
+ .SetShapeFn(shape_inference::SliceShape);
- return Status::OK();
- });
+#ifdef INTEL_MKL
+REGISTER_OP("_MklSlice")
+ .Input("input: T")
+ .Input("begin: Index")
+ .Input("size: Index")
+ .Input("mkl_input: uint8")
+ .Input("mkl_begin: uint8")
+ .Input("mkl_size: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: type")
+ .Attr("Index: {int32,int64}")
+ .SetShapeFn(shape_inference::SliceShape);
+#endif
REGISTER_OP("StridedSlice")
.Input("input: T")
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index 7c4184bff4..b8cf538554 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -180,6 +180,8 @@ REGISTER_OP("BoostedTreesMakeStatsSummary")
return Status::OK();
});
+// TODO(nponomareva): when/if creating the new op for unbucketized data, rename
+// bucketized_features to features.
REGISTER_OP("BoostedTreesPredict")
.Input("tree_ensemble_handle: resource")
.Input("bucketized_features: num_bucketized_features * int32")
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index b02ea64ac9..43c14d83b5 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -21532,6 +21532,421 @@ op {
}
}
op {
+ name: "ExperimentalAssertNextDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "transformations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalCSVDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "compression_type"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "buffer_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "header"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "field_delim"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "use_quote_delim"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "na_value"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "select_cols"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "record_defaults"
+ type_list_attr: "output_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalDirectedInterleaveDataset"
+ input_arg {
+ name: "selector_input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "data_input_datasets"
+ type: DT_VARIANT
+ number_attr: "N"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalFunctionBufferingResource"
+ input_arg {
+ name: "string_arg"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "target_device"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "buffer_size"
+ type: "int"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceGetNext"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceReset"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIdentityIndexedDataset"
+ input_arg {
+ name: "size"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalIndexedDatasetGet"
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "index"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIndexedDatasetMaterialize"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIteratorGetDevice"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "device"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalLMDBDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalMaterializedIndexDatasetHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "thread_pool"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "num_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ type: "int"
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "display_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalUniqueDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Expm1"
input_arg {
name: "x"
@@ -24105,6 +24520,85 @@ op {
}
}
op {
+ name: "FusedBatchNorm"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "offset"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "mean"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "variance"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "batch_mean"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "batch_variance"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_1"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_2"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "FusedBatchNormGrad"
input_arg {
name: "y_backprop"
@@ -24178,6 +24672,85 @@ op {
}
}
op {
+ name: "FusedBatchNormGrad"
+ input_arg {
+ name: "y_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reserve_space_1"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reserve_space_2"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "x_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "scale_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "offset_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_3"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_4"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "FusedBatchNormGradV2"
input_arg {
name: "y_backprop"
@@ -24345,6 +24918,96 @@ op {
}
}
op {
+ name: "FusedBatchNormGradV2"
+ input_arg {
+ name: "y_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "reserve_space_1"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "reserve_space_2"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "x_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "scale_backprop"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "offset_backprop"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_3"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_4"
+ type_attr: "U"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "U"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "FusedBatchNormV2"
input_arg {
name: "x"
@@ -24512,6 +25175,96 @@ op {
}
}
op {
+ name: "FusedBatchNormV2"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "offset"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "mean"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "variance"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "batch_mean"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "batch_variance"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_1"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_2"
+ type_attr: "U"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "U"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "FusedPadConv2D"
input_arg {
name: "input"
@@ -44518,6 +45271,59 @@ op {
is_stateful: true
}
op {
+ name: "ReduceDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "initial_state"
+ type_list_attr: "Tstate"
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Tstate"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "ReduceJoin"
input_arg {
name: "inputs"
@@ -60085,6 +60891,29 @@ op {
}
}
op {
+ name: "Softplus"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "SoftplusGrad"
input_arg {
name: "gradients"
@@ -60221,6 +61050,33 @@ op {
}
}
op {
+ name: "SoftplusGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Softsign"
input_arg {
name: "features"
@@ -60341,6 +61197,29 @@ op {
}
}
op {
+ name: "Softsign"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "SoftsignGrad"
input_arg {
name: "gradients"
@@ -60477,6 +61356,33 @@ op {
}
}
op {
+ name: "SoftsignGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "SpaceToBatch"
input_arg {
name: "input"
@@ -70498,6 +71404,30 @@ op {
}
}
op {
+ name: "StringLength"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
+}
+op {
name: "StringSplit"
input_arg {
name: "input"
@@ -74449,6 +75379,17 @@ op {
}
}
op {
+ name: "UnicodeScript"
+ input_arg {
+ name: "input"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+}
+op {
name: "UniformCandidateSampler"
input_arg {
name: "true_classes"
@@ -76159,6 +77100,62 @@ op {
is_stateful: true
}
op {
+ name: "Xdivy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
+ name: "Xlogy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "ZerosLike"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 1ada623cf5..71f4cc3c4c 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -756,6 +756,19 @@ REGISTER_OP("DatasetToSingleElement")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(IteratorGetNextShapeFn);
+REGISTER_OP("ReduceDataset")
+ .Input("input_dataset: variant")
+ .Input("initial_state: Tstate")
+ .Input("other_arguments: Targuments")
+ .Output("components: output_types")
+ .Attr("f: func")
+ .Attr("Tstate: list(type) >= 1")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Attr("use_inter_op_parallelism: bool = true")
+ .SetShapeFn(IteratorGetNextShapeFn);
+
REGISTER_OP("IteratorToStringHandle")
.Input("resource_handle: resource")
.Output("string_handle: string")
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc
index d1a771f005..f6bd5dce26 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/experimental_dataset_ops.cc
@@ -17,24 +17,16 @@ limitations under the License.
namespace tensorflow {
-REGISTER_OP("DirectedInterleaveDataset")
+REGISTER_OP("ExperimentalDirectedInterleaveDataset")
.Input("selector_input_dataset: variant")
.Input("data_input_datasets: N * variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("N: int >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
-
-selector_input_dataset: A dataset of scalar `DT_INT64` elements that determines
- which of the `N` data inputs should produce the next output element.
-data_input_datasets: `N` datasets with the same type that will be interleaved
- according to the values of `selector_input_dataset`.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("CSVDataset")
+REGISTER_OP("ExperimentalCSVDataset")
.Input("filenames: string")
.Input("compression_type: string")
.Input("buffer_size: int64")
@@ -76,35 +68,26 @@ REGISTER_OP("CSVDataset")
return shape_inference::ScalarShape(c);
});
-REGISTER_OP("IgnoreErrorsDataset")
+REGISTER_OP("ExperimentalIgnoreErrorsDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that contains the elements of `input_dataset` ignoring errors.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("UniqueDataset")
+REGISTER_OP("ExperimentalUniqueDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that contains the unique elements of `input_dataset`.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("IteratorGetDevice")
+REGISTER_OP("ExperimentalIteratorGetDevice")
.Input("resource: resource")
.Output("device: string")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Returns the name of the device on which `resource` has been placed.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("FunctionBufferingResource")
+REGISTER_OP("ExperimentalFunctionBufferingResource")
.Input("string_arg: string")
.Input("target_device: string")
.Output("resource: resource")
@@ -113,77 +96,36 @@ REGISTER_OP("FunctionBufferingResource")
.Attr("f: func")
.Attr("buffer_size: int")
.Attr("output_types: list(type)")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Creates a resource that fills up a buffer by making function calls.
-
-string_arg: String argument to the function call.
-target_device: Target device to execute the function on.
-resource: Handle to the resource created.
-f: Function to be executed.
-buffer_size: Size of the buffer.
-container: If non-empty, this resource is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this resource will be shared under the given name
- across multiple sessions.
-output_types: The type list for the return values.
-)doc");
-
-REGISTER_OP("FunctionBufferingResourceGetNext")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResourceGetNext")
.Input("function_buffer_resource: resource")
.Attr("output_types: list(type)")
.Output("output: output_types")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Gets the next element from a FunctionBufferingResource.
+ .SetShapeFn(shape_inference::UnknownShape);
-function_buffer_resource: The FunctionBufferingResource handle.
-output: A list of return values.
-output_types: The type list for the return values.
-)doc");
-
-REGISTER_OP("FunctionBufferingResourceReset")
+REGISTER_OP("ExperimentalFunctionBufferingResourceReset")
.Input("function_buffer_resource: resource")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Resets the FunctionBufferingResource.
-
-function_buffer_resource: The FunctionBufferingResource handle.
-)doc");
+ .SetShapeFn(shape_inference::UnknownShape);
-REGISTER_OP("ThreadPoolDataset")
+REGISTER_OP("ExperimentalThreadPoolDataset")
.Input("input_dataset: variant")
.Input("thread_pool: resource")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that uses a custom thread pool to compute `input_dataset`.
-
-handle: A resource produced by the ThreadPoolHandle op.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("ThreadPoolHandle")
+REGISTER_OP("ExperimentalThreadPoolHandle")
.Output("handle: resource")
.SetShapeFn(shape_inference::ScalarShape)
.Attr("num_threads: int")
.Attr("max_intra_op_parallelism: int = 1")
.Attr("display_name: string")
.Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Doc(R"doc(
-Creates a custom thread pool with the given number of threads.
-
-handle: A resource that can be consumed by one or more ThreadPoolDataset ops.
-num_threads: The number of threads in the thread pool.
-max_intra_op_parallelism: The maximum degree of parallelism to use within
- operations that execute on this threadpool.
-display_name: A human-readable name for the threads that may be visible in
- some visualizations.
-)doc");
-
-REGISTER_OP("AssertNextDataset")
+ .Attr("shared_name: string = ''");
+
+REGISTER_OP("ExperimentalAssertNextDataset")
.Input("input_dataset: variant")
.Input("transformations: string")
.Output("handle: variant")
@@ -196,7 +138,7 @@ REGISTER_OP("AssertNextDataset")
return shape_inference::ScalarShape(c);
});
-REGISTER_OP("LMDBDataset")
+REGISTER_OP("ExperimentalLMDBDataset")
.Input("filenames: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
@@ -205,4 +147,61 @@ REGISTER_OP("LMDBDataset")
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ExperimentalIdentityIndexedDataset")
+ .Input("size: uint64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(
+ shape_inference::ScalarShape); // TODO(saeta): check input shapes.
+
+///////////////////////////////////////////////////////////////////////////////
+// IndexedDataset Internals
+///////////////////////////////////////////////////////////////////////////////
+
+// Creates the handle.
+REGISTER_OP("ExperimentalMaterializedIndexDatasetHandle")
+ .Output("handle: resource")
+ .Attr("container: string")
+ .Attr("shared_name: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+// Actually materialize the materialize handle.
+REGISTER_OP("ExperimentalIndexedDatasetMaterialize")
+ .Input("dataset: variant")
+ .Input("materialized: resource")
+ .SetShapeFn(shape_inference::NoOutputs);
+
+namespace {
+
+Status GetShapeFn(shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+REGISTER_OP("ExperimentalIndexedDatasetGet")
+ .Input("materialized: resource")
+ .Input("index: uint64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(GetShapeFn);
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 07f876cb90..55dcc50325 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -549,6 +549,40 @@ Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Pow", PowGrad);
+Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"zeros"}, "ZerosLike", {"x"}},
+ {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
+ {{"is_zero_cast"}, "Cast", {"is_x_zero"},
+ {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
+ {{"safe_logy"}, "Xlogy", {"is_zero_cast", "y"}},
+ {{"xlogygrad"}, "Xdivy", {"x", "y"}},
+ {{"gx"}, "Mul", {"safe_logy", "dz"}},
+ {{"gy"}, "Mul", {"xlogygrad", "dz"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Xlogy", XlogyGrad);
+
+Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"zeros"}, "ZerosLike", {"x"}},
+ {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
+ {{"is_zero_cast"}, "Cast", {"is_x_zero"},
+ {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
+ {{"safe_divy"}, "Xdivy", {"is_zero_cast", "y"}},
+ {{"y2"}, "Square", {"y"}},
+ {{"negy2"}, "Neg", {"y2"}},
+ {{"xdivygrad"}, "Xdivy", {"x", "negy2"}},
+ {{"gx"}, "Mul", {"safe_divy", "dz"}},
+ {{"gy"}, "Mul", {"xdivygrad", "dz"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Xdivy", XdivyGrad);
+
Status MaximumMinimumGradHelper(const string& comparator,
const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index 5ee79809ac..9fc6b34147 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -909,6 +909,46 @@ TEST_F(MathGradTest, ComplexPow) {
}
#endif // TENSORFLOW_USE_SYCL
+TEST_F(MathGradTest, Xlogy) {
+ auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f},
+ TensorShape({2, 3}));
+ auto y = test::AsTensor<float>({.5f, 2.f}, TensorShape({2, 1}));
+ Tensor dx;
+ Tensor dy;
+ auto g = [](float x, float y) -> float { return x == 0. ? 0. : std::log(y); };
+ auto h = [](float x, float y) -> float { return x == 0. ? 0. : x / y; };
+ SymGrad("Xlogy", x, y, &dx, &dy);
+ test::ExpectClose(
+ dx, test::AsTensor<float>({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f),
+ g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)},
+ TensorShape({2, 3})));
+ test::ExpectClose(
+ dy, test::AsTensor<float>({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f),
+ h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)},
+ TensorShape({2, 1})));
+}
+
+TEST_F(MathGradTest, Xdivy) {
+ auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f},
+ TensorShape({2, 3}));
+ auto y = test::AsTensor<float>({.5f, 2.f}, TensorShape({2, 1}));
+ Tensor dx;
+ Tensor dy;
+ auto g = [](float x, float y) -> float { return x == 0. ? 0. : 1 / y; };
+ auto h = [](float x, float y) -> float {
+ return x == 0. ? 0. : -x / (y * y);
+ };
+ SymGrad("Xdivy", x, y, &dx, &dy);
+ test::ExpectClose(
+ dx, test::AsTensor<float>({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f),
+ g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)},
+ TensorShape({2, 3})));
+ test::ExpectClose(
+ dy, test::AsTensor<float>({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f),
+ h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)},
+ TensorShape({2, 1})));
+}
+
TEST_F(MathGradTest, Maximum) {
auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
TensorShape({2, 3}));
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 717263a9b0..3eff728f03 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -429,6 +429,20 @@ Returns (x - y)(x - y) element-wise.
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
+REGISTER_OP("Xlogy")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {half, float, double, complex64, complex128}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
+REGISTER_OP("Xdivy")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {half, float, double, complex64, complex128}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
#undef BINARY_FEWER
#undef BINARY_MORE
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 2485fa4717..d1d81b27cc 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -178,7 +178,7 @@ REGISTER_OP("FusedBatchNorm")
.Output("reserve_space_2: T")
.Attr("T: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormShape);
@@ -196,7 +196,7 @@ REGISTER_OP("FusedBatchNormV2")
.Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormShape);
@@ -213,7 +213,7 @@ REGISTER_OP("FusedBatchNormGrad")
.Output("reserve_space_4: T")
.Attr("T: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormGradShape);
@@ -231,7 +231,7 @@ REGISTER_OP("FusedBatchNormGradV2")
.Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormGradShape);
@@ -1009,32 +1009,30 @@ REGISTER_OP("SeluGrad")
.Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
-// TODO(b/111515541): change T to {half, bfloat16, float, double}
REGISTER_OP("Softplus")
.Input("features: T")
.Output("activations: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("SoftplusGrad")
.Input("gradients: T")
.Input("features: T")
.Output("backprops: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
-// TODO(b/111515541): change T to {half, bfloat16, float, double}
REGISTER_OP("Softsign")
.Input("features: T")
.Output("activations: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("SoftsignGrad")
.Input("gradients: T")
.Input("features: T")
.Output("backprops: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
// --------------------------------------------------------------------------
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 4c5a472e9f..abee803889 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -10039,6 +10039,421 @@ op {
}
}
op {
+ name: "ExperimentalAssertNextDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "transformations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalCSVDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "compression_type"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "buffer_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "header"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "field_delim"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "use_quote_delim"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "na_value"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "select_cols"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "record_defaults"
+ type_list_attr: "output_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalDirectedInterleaveDataset"
+ input_arg {
+ name: "selector_input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "data_input_datasets"
+ type: DT_VARIANT
+ number_attr: "N"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalFunctionBufferingResource"
+ input_arg {
+ name: "string_arg"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "target_device"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "buffer_size"
+ type: "int"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceGetNext"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceReset"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIdentityIndexedDataset"
+ input_arg {
+ name: "size"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalIndexedDatasetGet"
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "index"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIndexedDatasetMaterialize"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIteratorGetDevice"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "device"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalLMDBDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalMaterializedIndexDatasetHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "thread_pool"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "num_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ type: "int"
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "display_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalUniqueDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Expm1"
input_arg {
name: "x"
@@ -11459,6 +11874,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -11532,6 +11953,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -11616,6 +12043,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -11700,6 +12133,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -22845,6 +23284,59 @@ op {
is_stateful: true
}
op {
+ name: "ReduceDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "initial_state"
+ type_list_attr: "Tstate"
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Tstate"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "ReduceJoin"
input_arg {
name: "inputs"
@@ -28714,18 +29206,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -28749,18 +29233,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -28780,18 +29256,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -28815,18 +29283,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -33043,6 +33503,19 @@ op {
name: "output"
type: DT_INT32
}
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
}
op {
name: "StringSplit"
@@ -35644,6 +36117,17 @@ op {
}
}
op {
+ name: "UnicodeScript"
+ input_arg {
+ name: "input"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+}
+op {
name: "UniformCandidateSampler"
input_arg {
name: "true_classes"
@@ -36756,6 +37240,62 @@ op {
is_stateful: true
}
op {
+ name: "Xdivy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
+ name: "Xlogy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "ZerosLike"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 26499540f1..adc9cd1486 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -19,6 +19,7 @@
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeAndType;
@@ -56,6 +57,36 @@ Status ReadVariableShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status ReadVariablesShapeFn(InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ DataTypeVector value_dtypes;
+ TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &value_dtypes));
+ if (n != value_dtypes.size()) {
+ return errors::InvalidArgument(
+ "Mismatched number of arguments to ReadVariablesOp");
+ }
+ for (int i = 0; i < n; ++i) {
+ ShapeAndType shape_and_type;
+ auto* handle_data = c->input_handle_shapes_and_types(i);
+ if (handle_data == nullptr || handle_data->empty()) {
+ shape_and_type.shape = c->UnknownShape();
+ shape_and_type.dtype = DT_INVALID;
+ } else {
+ shape_and_type = (*handle_data)[0];
+ if (shape_and_type.dtype != value_dtypes[i]) {
+ return errors::InvalidArgument(
+ "Trying to read variable with wrong dtype. "
+ "Expected ",
+ DataTypeString(shape_and_type.dtype), " got ",
+ DataTypeString(value_dtypes[i]));
+ }
+ }
+ c->set_output(i, shape_and_type.shape);
+ }
+ return Status::OK();
+}
+
} // namespace
REGISTER_OP("VarHandleOp")
@@ -79,12 +110,53 @@ REGISTER_OP("VarHandleOp")
return Status::OK();
});
+REGISTER_OP("_VarHandlesOp")
+ .Attr("containers: list(string)")
+ .Attr("shared_names: list(string)")
+ .Attr("N: int >= 0")
+ .Attr("dtypes: list(type)")
+ .Attr("shapes: list(shape)")
+ .Output("resources: N * resource")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ DataTypeVector dtypes;
+ TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes));
+ std::vector<PartialTensorShape> shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
+ if (dtypes.size() != n) {
+ return errors::InvalidArgument("Mismatched number of dtypes (n=", n,
+ ", num dtypes=", dtypes.size(), ")");
+ }
+ if (shapes.size() != n) {
+ return errors::InvalidArgument("Mismatched number of shapes (n=", n,
+ ", num shapes=", shapes.size(), ")");
+ }
+ for (int i = 0; i < n; ++i) {
+ c->set_output(i, c->Scalar());
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &s));
+ c->set_output_handle_shapes_and_types(
+ i, std::vector<ShapeAndType>{{s, dtypes[i]}});
+ }
+
+ return Status::OK();
+ });
+
REGISTER_OP("ReadVariableOp")
.Input("resource: resource")
.Output("value: dtype")
.Attr("dtype: type")
.SetShapeFn(ReadVariableShapeFn);
+REGISTER_OP("_ReadVariablesOp")
+ .Attr("N: int >= 0")
+ .Input("resources: N * resource")
+ .Output("values: dtypes")
+ .Attr("dtypes: list(type)")
+ .SetShapeFn(ReadVariablesShapeFn);
+
Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
*g = FunctionDefHelper::Define(
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 99159839d0..b4fbde54d9 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -203,6 +203,7 @@ REGISTER_OP("StringStrip")
REGISTER_OP("StringLength")
.Input("input: string")
.Output("output: int32")
+ .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("EncodeBase64")
@@ -243,4 +244,9 @@ REGISTER_OP("Substr")
return shape_inference::BroadcastBinaryOpShapeFn(c);
});
+REGISTER_OP("UnicodeScript")
+ .Input("input: int32")
+ .Output("output: int32")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
} // namespace tensorflow
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index bb841aeab7..3b14757945 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -641,54 +641,41 @@ def tf_additional_lib_deps():
def tf_additional_core_deps():
return select({
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
"//tensorflow/core/platform/cloud:gcs_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_hdfs_support_windows_override": [],
- "//tensorflow:with_hdfs_support_android_override": [],
- "//tensorflow:with_hdfs_support_ios_override": [],
- "//tensorflow:with_hdfs_support": [
- "//tensorflow/core/platform/hadoop:hadoop_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support_android_override": [],
- "//tensorflow:with_aws_support_ios_override": [],
- "//tensorflow:with_aws_support": [
"//tensorflow/core/platform/s3:s3_file_system",
+ "//tensorflow/core/platform/hadoop:hadoop_file_system",
],
- "//conditions:default": [],
})
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_op_deps():
return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
"//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
"//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
],
- "//conditions:default": [],
})
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_kernel_deps():
return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
"//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
"//tensorflow/contrib/cloud/kernels:gcs_config_ops",
],
- "//conditions:default": [],
})
def tf_lib_proto_parsing_deps():
diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD
index af034bdd7d..2bf371276e 100644
--- a/tensorflow/core/profiler/BUILD
+++ b/tensorflow/core/profiler/BUILD
@@ -40,7 +40,6 @@ tf_proto_library(
name = "protos_all",
srcs = glob(["**/*.proto"]),
cc_api_version = 2,
- java_api_version = 2,
protodeps = tf_additional_all_protos(),
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 85cd02350a..104ab039cb 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -453,6 +453,11 @@ message RunOptions {
// same group_key value (in a distributed computation where tasks
// run disjoint graphs).
int64 collective_graph_key = 1;
+ // If true, then operations (using the inter-op pool) across all
+ // session::run() calls will be centrally scheduled, optimizing for (median
+ // and tail) latency.
+ // Consider using this option for CPU-bound workloads like inference.
+ bool use_run_handler_pool = 2;
};
Experimental experimental = 8;
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index bb8f88336d..8e0448d536 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -77,6 +77,8 @@ message RewriterConfig {
Toggle scoped_allocator_optimization = 15;
// Force small ops onto the CPU (default is ON).
Toggle pin_to_host_optimization = 18;
+ // Disable the entire meta optimizer (off by default).
+ bool disable_meta_optimizer = 19;
// Controls how many times we run the optimizers in meta optimizer (default
// is once).
@@ -143,8 +145,8 @@ message RewriterConfig {
// not configurable (in contrast to memory optimization passes through the
// meta-optimizer) and act only on manual op annotations.
//
- // Custom registered optimizers will be run after the base optimizers, in
- // the order that they are specified.
+ // Custom optimizers (see custom_optimizers) that are not part of this
+ // schedule will be run after - in the order that they were specified.
repeated string optimizers = 100;
// Message to describe custom graph optimizer and its parameters
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index cf7ffd8149..04aaea4f89 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -2039,8 +2039,8 @@ class MklPrimitiveFactory {
/// Fuction to check whether primitive memory optimization is enabled
static inline bool IsPrimitiveMemOptEnabled() {
bool is_primitive_mem_opt_enabled = true;
- TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE", true,
- &is_primitive_mem_opt_enabled));
+ TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true,
+ &is_primitive_mem_opt_enabled));
return is_primitive_mem_opt_enabled;
}
@@ -2095,9 +2095,8 @@ static inline memory::format get_desired_format(int channel,
fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c;
} else if (port::TestCPUFeature(port::CPUFeature::AVX2) &&
(channel % 8) == 0) {
- fmt_desired = is_2d
- ? memory::format::nChw8c
- : memory::format::ncdhw; //not support avx2 for 3d yet.
+ fmt_desired = is_2d ? memory::format::nChw8c
+ : memory::format::ncdhw; // no avx2 support for 3d yet.
} else {
fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw;
}
@@ -2209,7 +2208,8 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
// utility function to determine if it is conv 1x1 and stride != 1
// for purpose of temporarily disabling primitive reuse
-inline bool IsConv1x1StrideNot1(memory::dims filter_dims, memory::dims strides) {
+inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
+ memory::dims strides) {
if (filter_dims.size() != 4 || strides.size() != 2) return false;
return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc
index c081ceae57..e01058dff6 100644
--- a/tensorflow/core/util/port.cc
+++ b/tensorflow/core/util/port.cc
@@ -38,10 +38,10 @@ bool CudaSupportsHalfMatMulAndConv() {
}
bool IsMklEnabled() {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
return true;
#else
return false;
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
}
} // end namespace tensorflow
diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD
index 648358606c..f40ec9b752 100644
--- a/tensorflow/core/util/tensor_bundle/BUILD
+++ b/tensorflow/core/util/tensor_bundle/BUILD
@@ -64,6 +64,11 @@ cc_library(
tf_cc_test(
name = "tensor_bundle_test",
srcs = ["tensor_bundle_test.cc"],
+ data = glob(["testdata/**"]),
+ tags = [
+ "nomsan",
+ "notsan",
+ ],
deps = [
":tensor_bundle",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index ea8a259d1a..2dcb57a1f9 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -64,27 +64,36 @@ namespace {
// Reads "num_elements" string elements from file[offset, offset+size) into the
// length-N "destination". Discards the original content of "destination".
//
-// Checksums the string lengths (as restored uint32, not varint32 bytes) and
-// string bytes, and stores it into "actual_crc32c".
+// Checksums the string lengths (as restored uint32 or uint64, not varint64
+// bytes) and string bytes, and stores it into "actual_crc32c".
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
size_t offset, size_t size, string* destination,
uint32* actual_crc32c) {
if (size == 0) return Status::OK();
CHECK_GT(size, 0);
- // Reads "num_elements" varint32's from "buffered_file".
+ // Reads "num_elements" varint64's from "buffered_file".
TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
- std::vector<uint32> string_lengths(num_elements);
+ std::vector<uint64> string_lengths(num_elements);
for (size_t i = 0; i < num_elements; ++i) {
- TF_RETURN_IF_ERROR(buffered_file->ReadVarint32(&string_lengths[i]));
+ TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
+ if (string_lengths[i] <= UINT32_MAX) {
+ // We need to do this because older checkpoints only used uint32s and we
+ // should still support them.
+ const uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
+ *actual_crc32c = crc32c::Extend(
+ *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
+ sizeof(uint32));
+ } else {
+ *actual_crc32c = crc32c::Extend(
+ *actual_crc32c, reinterpret_cast<const char*>(&string_lengths[i]),
+ sizeof(uint64));
+ }
}
if (offset + size < buffered_file->Tell()) {
return errors::DataLoss("String lengths longer than expected offset ",
offset + size);
}
- *actual_crc32c =
- crc32c::Value(reinterpret_cast<const char*>(string_lengths.data()),
- sizeof(uint32) * num_elements);
// Reads the length-checksum.
uint32 length_checksum = 0;
@@ -104,7 +113,7 @@ Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
// Reads the actual string bytes.
for (size_t i = 0; i < num_elements; ++i) {
- const uint32 string_length = string_lengths[i];
+ const uint64 string_length = string_lengths[i];
string* buffer = &destination[i];
buffer->resize(string_length);
@@ -218,8 +227,8 @@ Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
size_t* bytes_written, uint32* crc32c) {
// On-disk format:
- // [varint32 len0]..[varint32 lenL][4 byte cksum on lengths][string bytes]
- // Var "crc32c" checksums the string lengths (as uint32, not varint32 bytes),
+ // [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
+ // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
// the length-checksum, and all the string bytes.
DCHECK_EQ(val.dtype(), DT_STRING);
const string* strings = GetStringBackingBuffer(val);
@@ -230,12 +239,21 @@ Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
*crc32c = 0;
for (int64 i = 0; i < val.NumElements(); ++i) {
const string* elem = &strings[i];
- DCHECK_EQ(elem->size(), static_cast<uint32>(elem->size()));
- const uint32 elem_size = static_cast<uint32>(elem->size());
-
- core::PutVarint32(&lengths, elem_size);
- *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
- sizeof(uint32));
+ DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
+ const uint64 elem_size = static_cast<uint64>(elem->size());
+
+ core::PutVarint64(&lengths, elem_size);
+ if (elem_size <= UINT32_MAX) {
+ // We need to do this because older checkpoints only used uint32s and we
+ // should still support them.
+ const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
+ *crc32c = crc32c::Extend(*crc32c,
+ reinterpret_cast<const char*>(&elem_size_uint32),
+ sizeof(uint32));
+ } else {
+ *crc32c = crc32c::Extend(
+ *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
+ }
}
TF_RETURN_IF_ERROR(out->Append(lengths));
*bytes_written = lengths.size();
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index 59c42baa06..9567e4750b 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -39,6 +39,11 @@ string Prefix(const string& prefix) {
return strings::StrCat(testing::TmpDir(), "/", prefix);
}
+string TestdataPrefix(const string& prefix) {
+ return strings::StrCat(testing::TensorFlowSrcRoot(),
+ "/core/util/tensor_bundle/testdata/", prefix);
+}
+
template <typename T>
Tensor Constant(T v, TensorShape shape) {
Tensor ret(DataTypeToEnum<T>::value, shape);
@@ -458,7 +463,26 @@ TEST(TensorBundleTest, NonStandardShapes) {
TestNonStandardShapes<qint8>();
}
+TEST(TensorBundleTest, StringTensorsOldFormat) {
+ // Test string tensor bundle made with previous version of code that use
+ // varint32s to store string lengths (we now use varint64s).
+ BundleReader reader(Env::Default(), TestdataPrefix("old_string_tensors/foo"));
+ TF_ASSERT_OK(reader.status());
+ EXPECT_EQ(AllTensorKeys(&reader),
+ std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));
+
+ Expect<string>(&reader, "string_tensor", Tensor(DT_STRING, TensorShape({1})));
+ Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
+ Expect<string>(
+ &reader, "strs",
+ test::AsTensor<string>({"hello", "", "x01", string(1 << 10, 'c')}));
+ Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
+}
+
TEST(TensorBundleTest, StringTensors) {
+ constexpr size_t kLongLength = static_cast<size_t>(UINT32_MAX) + 1;
+ Tensor long_string_tensor(DT_STRING, TensorShape({1}));
+
{
BundleWriter writer(Env::Default(), Prefix("foo"));
TF_EXPECT_OK(writer.Add("string_tensor",
@@ -467,6 +491,12 @@ TEST(TensorBundleTest, StringTensors) {
TF_EXPECT_OK(writer.Add(
"strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')})));
+
+ // Requires a 64-bit length.
+ string* backing_string = long_string_tensor.flat<string>().data();
+ backing_string->assign(kLongLength, 'd');
+ TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));
+
// Mixes in some floats.
TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18)));
TF_ASSERT_OK(writer.Finish());
@@ -474,9 +504,9 @@ TEST(TensorBundleTest, StringTensors) {
{
BundleReader reader(Env::Default(), Prefix("foo"));
TF_ASSERT_OK(reader.status());
- EXPECT_EQ(
- AllTensorKeys(&reader),
- std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));
+ EXPECT_EQ(AllTensorKeys(&reader),
+ std::vector<string>({"floats", "long_scalar", "scalar",
+ "string_tensor", "strs"}));
Expect<string>(&reader, "string_tensor",
Tensor(DT_STRING, TensorShape({1})));
@@ -484,7 +514,35 @@ TEST(TensorBundleTest, StringTensors) {
Expect<string>(
&reader, "strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')}));
+
Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
+
+ // We don't use the Expect function so we can re-use the
+ // `long_string_tensor` buffer for reading out long_scalar to keep memory
+ // usage reasonable.
+ EXPECT_TRUE(reader.Contains("long_scalar"));
+ DataType dtype;
+ TensorShape shape;
+ TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape));
+ EXPECT_EQ(DT_STRING, dtype);
+ EXPECT_EQ(TensorShape({1}), shape);
+
+ // Zero-out the string so that we can be sure the new one is read in.
+ string* backing_string = long_string_tensor.flat<string>().data();
+ backing_string->assign("");
+
+ // Read long_scalar and check it contains kLongLength 'd's.
+ TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
+ ASSERT_EQ(backing_string, long_string_tensor.flat<string>().data());
+ EXPECT_EQ(kLongLength, backing_string->length());
+ for (char c : *backing_string) {
+ // Not using ASSERT_EQ('d', c) because this way is twice as fast due to
+ // compiler optimizations.
+ if (c != 'd') {
+ FAIL() << "long_scalar is not full of 'd's as expected.";
+ break;
+ }
+ }
}
}
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README
new file mode 100644
index 0000000000..428d3ef79e
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README
@@ -0,0 +1,3 @@
+This tensor bundle was generated from cl/214343133, before string tensor
+lengths were written as varint64s. This is here to check backwards
+compatibility between the new code and old checkpoints.
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001
new file mode 100644
index 0000000000..23b488e5fe
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001
Binary files differ
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index
new file mode 100644
index 0000000000..a22a69e6e1
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index
Binary files differ
diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD
index f327b645f5..f5f0d7c3c8 100644
--- a/tensorflow/examples/android/BUILD
+++ b/tensorflow/examples/android/BUILD
@@ -68,6 +68,7 @@ android_binary(
srcs = glob([
"src/**/*.java",
]),
+ aapt_version = "aapt",
# Package assets from assets dir as well as all model targets. Remove undesired models
# (and corresponding Activities in source) to reduce APK size.
assets = [
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 8b60e6fd25..b4d4db3e4d 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -2562,92 +2562,6 @@ func Fill(scope *Scope, dims tf.Output, value tf.Output) (output tf.Output) {
return op.Output(0)
}
-// EditDistanceAttr is an optional argument to EditDistance.
-type EditDistanceAttr func(optionalAttr)
-
-// EditDistanceNormalize sets the optional normalize attribute to value.
-//
-// value: boolean (if true, edit distances are normalized by length of truth).
-//
-// The output is:
-// If not specified, defaults to true
-func EditDistanceNormalize(value bool) EditDistanceAttr {
- return func(m optionalAttr) {
- m["normalize"] = value
- }
-}
-
-// Computes the (possibly normalized) Levenshtein Edit Distance.
-//
-// The inputs are variable-length sequences provided by SparseTensors
-// (hypothesis_indices, hypothesis_values, hypothesis_shape)
-// and
-// (truth_indices, truth_values, truth_shape).
-//
-// The inputs are:
-//
-// Arguments:
-// hypothesis_indices: The indices of the hypothesis list SparseTensor.
-// This is an N x R int64 matrix.
-// hypothesis_values: The values of the hypothesis list SparseTensor.
-// This is an N-length vector.
-// hypothesis_shape: The shape of the hypothesis list SparseTensor.
-// This is an R-length vector.
-// truth_indices: The indices of the truth list SparseTensor.
-// This is an M x R int64 matrix.
-// truth_values: The values of the truth list SparseTensor.
-// This is an M-length vector.
-// truth_shape: truth indices, vector.
-//
-// Returns A dense float tensor with rank R - 1.
-//
-// For the example input:
-//
-// // hypothesis represents a 2x1 matrix with variable-length values:
-// // (0,0) = ["a"]
-// // (1,0) = ["b"]
-// hypothesis_indices = [[0, 0, 0],
-// [1, 0, 0]]
-// hypothesis_values = ["a", "b"]
-// hypothesis_shape = [2, 1, 1]
-//
-// // truth represents a 2x2 matrix with variable-length values:
-// // (0,0) = []
-// // (0,1) = ["a"]
-// // (1,0) = ["b", "c"]
-// // (1,1) = ["a"]
-// truth_indices = [[0, 1, 0],
-// [1, 0, 0],
-// [1, 0, 1],
-// [1, 1, 0]]
-// truth_values = ["a", "b", "c", "a"]
-// truth_shape = [2, 2, 2]
-// normalize = true
-//
-// The output will be:
-//
-// // output is a 2x2 matrix with edit distances normalized by truth lengths.
-// output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis
-// [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis
-func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values tf.Output, hypothesis_shape tf.Output, truth_indices tf.Output, truth_values tf.Output, truth_shape tf.Output, optional ...EditDistanceAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "EditDistance",
- Input: []tf.Input{
- hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Reverses specific dimensions of a tensor.
//
// Given a `tensor`, and a `bool` tensor `dims` representing the dimensions
@@ -3828,27 +3742,6 @@ func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf
return op.Output(0)
}
-// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-//
-// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest
-// layer.
-func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesGetEnsembleStates",
- Input: []tf.Input{
- tree_ensemble_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
-}
-
// Creates a tree ensemble model and returns a handle to it.
//
// Arguments:
@@ -3890,169 +3783,613 @@ func IsBoostedTreesEnsembleInitialized(scope *Scope, tree_ensemble_handle tf.Out
return op.Output(0)
}
-// Computes the sum along sparse segments of a tensor.
+// BoostedTreesEnsembleResourceHandleOpAttr is an optional argument to BoostedTreesEnsembleResourceHandleOp.
+type BoostedTreesEnsembleResourceHandleOpAttr func(optionalAttr)
+
+// BoostedTreesEnsembleResourceHandleOpContainer sets the optional container attribute to value.
+// If not specified, defaults to ""
+func BoostedTreesEnsembleResourceHandleOpContainer(value string) BoostedTreesEnsembleResourceHandleOpAttr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// BoostedTreesEnsembleResourceHandleOpSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func BoostedTreesEnsembleResourceHandleOpSharedName(value string) BoostedTreesEnsembleResourceHandleOpAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// Creates a handle to a BoostedTreesEnsembleResource
+func BoostedTreesEnsembleResourceHandleOp(scope *Scope, optional ...BoostedTreesEnsembleResourceHandleOpAttr) (resource tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesEnsembleResourceHandleOp",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits.
+type ComputeAccidentalHitsAttr func(optionalAttr)
+
+// ComputeAccidentalHitsSeed sets the optional seed attribute to value.
//
-// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
-// for an explanation of segments.
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value.
//
-// Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
-// dimension, selecting a subset of dimension 0, specified by `indices`.
+// value: An second seed to avoid seed collision.
+// If not specified, defaults to 0
+func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Computes the ids of the positions in sampled_candidates that match true_labels.
//
-// For example:
+// When doing log-odds NCE, the result of this op should be passed through a
+// SparseToDense op, then added to the logits of the sampled candidates. This has
+// the effect of 'removing' the sampled labels that match the true labels by
+// making the classifier sure that they are sampled labels.
//
-// ```python
-// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+// Arguments:
+// true_classes: The true_classes output of UnpackSparseLabels.
+// sampled_candidates: The sampled_candidates output of CandidateSampler.
+// num_true: Number of true labels per context.
//
-// # Select two rows, one segment.
-// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
-// # => [[0 0 0 0]]
+// Returns A vector of indices corresponding to rows of true_candidates.A vector of IDs of positions in sampled_candidates that match a true_label
+// for the row with the corresponding index in indices.A vector of the same length as indices and ids, in which each element
+// is -FLOAT_MAX.
+func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_true": num_true}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ComputeAccidentalHits",
+ Input: []tf.Input{
+ true_classes, sampled_candidates,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// FixedUnigramCandidateSamplerAttr is an optional argument to FixedUnigramCandidateSampler.
+type FixedUnigramCandidateSamplerAttr func(optionalAttr)
+
+// FixedUnigramCandidateSamplerVocabFile sets the optional vocab_file attribute to value.
//
-// # Select two rows, two segment.
-// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
-// # => [[ 1 2 3 4]
-// # [-1 -2 -3 -4]]
+// value: Each valid line in this file (which should have a CSV-like format)
+// corresponds to a valid word ID. IDs are in sequential order, starting from
+// num_reserved_ids. The last entry in each line is expected to be a value
+// corresponding to the count or relative probability. Exactly one of vocab_file
+// and unigrams needs to be passed to this op.
+// If not specified, defaults to ""
+func FixedUnigramCandidateSamplerVocabFile(value string) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["vocab_file"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerDistortion sets the optional distortion attribute to value.
//
-// # Select all rows, two segments.
-// tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
-// # => [[0 0 0 0]
-// # [5 6 7 8]]
+// value: The distortion is used to skew the unigram probability distribution.
+// Each weight is first raised to the distortion's power before adding to the
+// internal unigram distribution. As a result, distortion = 1.0 gives regular
+// unigram sampling (as defined by the vocab file), and distortion = 0.0 gives
+// a uniform distribution.
+// If not specified, defaults to 1
+func FixedUnigramCandidateSamplerDistortion(value float32) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["distortion"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerNumReservedIds sets the optional num_reserved_ids attribute to value.
//
-// # Which is equivalent to:
-// tf.segment_sum(c, tf.constant([0, 0, 1]))
-// ```
+// value: Optionally some reserved IDs can be added in the range [0,
+// ..., num_reserved_ids) by the users. One use case is that a special unknown
+// word token is used as ID 0. These IDs will have a sampling probability of 0.
+// If not specified, defaults to 0
+func FixedUnigramCandidateSamplerNumReservedIds(value int64) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["num_reserved_ids"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerNumShards sets the optional num_shards attribute to value.
//
-// Arguments:
+// value: A sampler can be used to sample from a subset of the original range
+// in order to speed up the whole computation through parallelism. This parameter
+// (together with 'shard') indicates the number of partitions that are being
+// used in the overall computation.
+// If not specified, defaults to 1
//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+// REQUIRES: value >= 1
+func FixedUnigramCandidateSamplerNumShards(value int64) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["num_shards"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerShard sets the optional shard attribute to value.
//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `k`, the number of segments.
-func SparseSegmentSum(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+// value: A sampler can be used to sample from a subset of the original range
+// in order to speed up the whole computation through parallelism. This parameter
+// (together with 'num_shards') indicates the particular partition number of a
+// sampler op, when partitioning is being used.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func FixedUnigramCandidateSamplerShard(value int64) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["shard"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerUnigrams sets the optional unigrams attribute to value.
+//
+// value: A list of unigram counts or probabilities, one per ID in sequential
+// order. Exactly one of vocab_file and unigrams should be passed to this op.
+// If not specified, defaults to <>
+func FixedUnigramCandidateSamplerUnigrams(value []float32) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["unigrams"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func FixedUnigramCandidateSamplerSeed(value int64) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// FixedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value.
+//
+// value: An second seed to avoid seed collision.
+// If not specified, defaults to 0
+func FixedUnigramCandidateSamplerSeed2(value int64) FixedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a learned unigram distribution.
+//
+// A unigram sampler could use a fixed unigram distribution read from a
+// file or passed in as an in-memory array instead of building up the distribution
+// from data on the fly. There is also an option to skew the distribution by
+// applying a distortion power to the weights.
+//
+// The vocabulary file should be in CSV-like format, with the last field
+// being the weight associated with the word.
+//
+// For each batch, this op picks a single set of sampled candidate labels.
+//
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
+//
+// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to randomly sample.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+// range_max: The sampler will sample integers from the interval [0, range_max).
+//
+// Returns A vector of length num_sampled, in which each element is
+// the ID of a sampled candidate.A batch_size * num_true matrix, representing
+// the number of times each candidate is expected to occur in a batch
+// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
+// candidate representing the number of times the candidate is expected
+// to occur in a batch of sampled candidates. If unique=true, then this is a
+// probability.
+func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...FixedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "SparseSegmentSum",
+ Type: "FixedUnigramCandidateSampler",
Input: []tf.Input{
- data, indices, segment_ids,
+ true_classes,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ return op.Output(0), op.Output(1), op.Output(2)
}
-// Computes hyperbolic sine of x element-wise.
-func Sinh(scope *Scope, x tf.Output) (y tf.Output) {
+// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler.
+type LogUniformCandidateSamplerAttr func(optionalAttr)
+
+// LogUniformCandidateSamplerSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func LogUniformCandidateSamplerSeed(value int64) LogUniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// LogUniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
+//
+// value: An second seed to avoid seed collision.
+// If not specified, defaults to 0
+func LogUniformCandidateSamplerSeed2(value int64) LogUniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a log-uniform distribution.
+//
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
+//
+// For each batch, this op picks a single set of sampled candidate labels.
+//
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
+//
+// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to randomly sample.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+// range_max: The sampler will sample integers from the interval [0, range_max).
+//
+// Returns A vector of length num_sampled, in which each element is
+// the ID of a sampled candidate.A batch_size * num_true matrix, representing
+// the number of times each candidate is expected to occur in a batch
+// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
+// candidate representing the number of times the candidate is expected
+// to occur in a batch of sampled candidates. If unique=true, then this is a
+// probability.
+func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LogUniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "Sinh",
+ Type: "LogUniformCandidateSampler",
Input: []tf.Input{
- x,
+ true_classes,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ return op.Output(0), op.Output(1), op.Output(2)
}
-// Computes the minimum along segments of a tensor.
+// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler.
+type UniformCandidateSamplerAttr func(optionalAttr)
+
+// UniformCandidateSamplerSeed sets the optional seed attribute to value.
//
-// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
-// for an explanation of segments.
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
//
-// This operator is similar to the unsorted segment sum operator found
-// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
-// Instead of computing the sum over segments, it computes the minimum such that:
+// value: An second seed to avoid seed collision.
+// If not specified, defaults to 0
+func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a uniform distribution.
//
-// \\(output_i = \min_{j...} data_[j...]\\) where min is over tuples `j...` such
-// that `segment_ids[j...] == i`.
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
//
-// If the minimum is empty for a given segment ID `i`, it outputs the largest
-// possible value for the specific numeric type,
-// `output[i] = numeric_limits<T>::max()`.
+// For each batch, this op picks a single set of sampled candidate labels.
//
-// If the given segment ID `i` is negative, then the corresponding value is
-// dropped, and will not be included in the result.
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
//
// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to randomly sample.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+// range_max: The sampler will sample integers from the interval [0, range_max).
//
-// segment_ids: A tensor whose shape is a prefix of `data.shape`.
+// Returns A vector of length num_sampled, in which each element is
+// the ID of a sampled candidate.A batch_size * num_true matrix, representing
+// the number of times each candidate is expected to occur in a batch
+// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
+// candidate representing the number of times the candidate is expected
+// to occur in a batch of sampled candidates. If unique=true, then this is a
+// probability.
+func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "UniformCandidateSampler",
+ Input: []tf.Input{
+ true_classes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping.
+type GenerateVocabRemappingAttr func(optionalAttr)
+
+// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value.
//
+// value: Number of entries in the old vocab file to consider. If -1,
+// use the entire old vocabulary.
+// If not specified, defaults to -1
//
-// Returns Has same shape as data, except for the first `segment_ids.rank`
-// dimensions, which are replaced with a single dimension which has size
-// `num_segments`.
-func UnsortedSegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+// REQUIRES: value >= -1
+func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr {
+ return func(m optionalAttr) {
+ m["old_vocab_size"] = value
+ }
+}
+
+// Given a path to new and old vocabulary files, returns a remapping Tensor of
+//
+// length `num_new_vocab`, where `remapping[i]` contains the row number in the old
+// vocabulary that corresponds to row `i` in the new vocabulary (starting at line
+// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i`
+// in the new vocabulary is not in the old vocabulary. The old vocabulary is
+// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the
+// default value of -1.
+//
+// `num_vocab_offset` enables
+// use in the partitioned variable case, and should generally be set through
+// examining partitioning info. The format of the files should be a text file,
+// with each line containing a single entity within the vocabulary.
+//
+// For example, with `new_vocab_file` a text file containing each of the following
+// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3],
+// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be
+// `[0, -1, 2]`.
+//
+// The op also returns a count of how many entries in the new vocabulary
+// were present in the old vocabulary, which is used to calculate the number of
+// values to initialize in a weight matrix remapping
+//
+// This functionality can be used to remap both row vocabularies (typically,
+// features) and column vocabularies (typically, classes) from TensorFlow
+// checkpoints. Note that the partitioning logic relies on contiguous vocabularies
+// corresponding to div-partitioned variables. Moreover, the underlying remapping
+// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should
+// use the corresponding index_table_from_file() as the FeatureColumn framework
+// does (as opposed to tf.feature_to_id(), which uses a CuckooTable).
+//
+// Arguments:
+// new_vocab_file: Path to the new vocab file.
+// old_vocab_file: Path to the old vocab file.
+// new_vocab_offset: How many entries into the new vocab file to start reading.
+// num_new_vocab: Number of entries in the new vocab file to remap.
+//
+// Returns A Tensor of length num_new_vocab where the element at index i
+// is equal to the old ID that maps to the new ID i. This element is -1 for any
+// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab.
+func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "UnsortedSegmentMin",
+ Type: "GenerateVocabRemapping",
Input: []tf.Input{
- data, segment_ids, num_segments,
+ new_vocab_file, old_vocab_file,
},
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
+// Broadcasts a tensor value to one or more other devices.
+func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape}
+ opspec := tf.OpSpec{
+ Type: "CollectiveBcastSend",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// Computes rectified linear 6: `min(max(features, 0), 6)`.
-func Relu6(scope *Scope, features tf.Output) (activations tf.Output) {
+// Mutually reduces multiple tensors of identical type and shape.
+func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets}
opspec := tf.OpSpec{
- Type: "Relu6",
+ Type: "CollectiveReduce",
Input: []tf.Input{
- features,
+ input,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// Computes the sum along segments of a tensor.
+// AbortAttr is an optional argument to Abort.
+type AbortAttr func(optionalAttr)
+
+// AbortErrorMsg sets the optional error_msg attribute to value.
//
-// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
-// for an explanation of segments.
+// value: A string which is the message associated with the exception.
+// If not specified, defaults to ""
+func AbortErrorMsg(value string) AbortAttr {
+ return func(m optionalAttr) {
+ m["error_msg"] = value
+ }
+}
+
+// AbortExitWithoutError sets the optional exit_without_error attribute to value.
+// If not specified, defaults to false
+func AbortExitWithoutError(value bool) AbortAttr {
+ return func(m optionalAttr) {
+ m["exit_without_error"] = value
+ }
+}
+
+// Raise a exception to abort the process when called.
//
-// Computes a tensor such that
-// \\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
-// that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`
-// need not be sorted and need not cover all values in the full
-// range of valid values.
+// If exit_without_error is true, the process will exit normally,
+// otherwise it will exit with a SIGABORT signal.
//
-// If the sum is empty for a given segment ID `i`, `output[i] = 0`.
-// If the given segment ID `i` is negative, the value is dropped and will not be
-// added to the sum of the segment.
+// Returns nothing but an exception.
//
-// `num_segments` should equal the number of distinct segment IDs.
+// Returns the created operation.
+func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Abort",
+
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Forwards the input to the output.
//
-// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentSum.png" alt>
-// </div>
+// This operator represents the loop termination condition used by the
+// "pivot" switches of a loop.
//
// Arguments:
+// input: A boolean scalar, representing the branch predicate of the Switch op.
//
-// segment_ids: A tensor whose shape is a prefix of `data.shape`.
+// Returns The same tensor as `input`.
+func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "LoopCond",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns a tensor of zeros with the same shape and type as x.
//
+// Arguments:
+// x: a tensor of type T.
//
-// Returns Has same shape as data, except for the first `segment_ids.rank`
-// dimensions, which are replaced with a single dimension which has size
-// `num_segments`.
-func UnsortedSegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+// Returns a tensor of the same shape and type as x but filled with zeros.
+func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "UnsortedSegmentSum",
+ Type: "ZerosLike",
Input: []tf.Input{
- data, segment_ids, num_segments,
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns a copy of the input tensor.
+func Snapshot(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Snapshot",
+ Input: []tf.Input{
+ input,
},
}
op := scope.AddOperation(opspec)
@@ -4413,6 +4750,162 @@ func SlideDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output,
return op.Output(0)
}
+// EditDistanceAttr is an optional argument to EditDistance.
+type EditDistanceAttr func(optionalAttr)
+
+// EditDistanceNormalize sets the optional normalize attribute to value.
+//
+// value: boolean (if true, edit distances are normalized by length of truth).
+//
+// The output is:
+// If not specified, defaults to true
+func EditDistanceNormalize(value bool) EditDistanceAttr {
+ return func(m optionalAttr) {
+ m["normalize"] = value
+ }
+}
+
+// Computes the (possibly normalized) Levenshtein Edit Distance.
+//
+// The inputs are variable-length sequences provided by SparseTensors
+// (hypothesis_indices, hypothesis_values, hypothesis_shape)
+// and
+// (truth_indices, truth_values, truth_shape).
+//
+// The inputs are:
+//
+// Arguments:
+// hypothesis_indices: The indices of the hypothesis list SparseTensor.
+// This is an N x R int64 matrix.
+// hypothesis_values: The values of the hypothesis list SparseTensor.
+// This is an N-length vector.
+// hypothesis_shape: The shape of the hypothesis list SparseTensor.
+// This is an R-length vector.
+// truth_indices: The indices of the truth list SparseTensor.
+// This is an M x R int64 matrix.
+// truth_values: The values of the truth list SparseTensor.
+// This is an M-length vector.
+// truth_shape: truth indices, vector.
+//
+// Returns A dense float tensor with rank R - 1.
+//
+// For the example input:
+//
+// // hypothesis represents a 2x1 matrix with variable-length values:
+// // (0,0) = ["a"]
+// // (1,0) = ["b"]
+// hypothesis_indices = [[0, 0, 0],
+// [1, 0, 0]]
+// hypothesis_values = ["a", "b"]
+// hypothesis_shape = [2, 1, 1]
+//
+// // truth represents a 2x2 matrix with variable-length values:
+// // (0,0) = []
+// // (0,1) = ["a"]
+// // (1,0) = ["b", "c"]
+// // (1,1) = ["a"]
+// truth_indices = [[0, 1, 0],
+// [1, 0, 0],
+// [1, 0, 1],
+// [1, 1, 0]]
+// truth_values = ["a", "b", "c", "a"]
+// truth_shape = [2, 2, 2]
+// normalize = true
+//
+// The output will be:
+//
+// // output is a 2x2 matrix with edit distances normalized by truth lengths.
+// output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis
+// [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis
+func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values tf.Output, hypothesis_shape tf.Output, truth_indices tf.Output, truth_values tf.Output, truth_shape tf.Output, optional ...EditDistanceAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "EditDistance",
+ Input: []tf.Input{
+ hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// DepthwiseConv2dNativeBackpropInputAttr is an optional argument to DepthwiseConv2dNativeBackpropInput.
+type DepthwiseConv2dNativeBackpropInputAttr func(optionalAttr)
+
+// DepthwiseConv2dNativeBackpropInputDataFormat sets the optional data_format attribute to value.
+//
+// value: Specify the data format of the input and output data. With the
+// default format "NHWC", the data is stored in the order of:
+// [batch, height, width, channels].
+// Alternatively, the format could be "NCHW", the data storage order of:
+// [batch, channels, height, width].
+// If not specified, defaults to "NHWC"
+func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dNativeBackpropInputAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// DepthwiseConv2dNativeBackpropInputDilations sets the optional dilations attribute to value.
+//
+// value: 1-D tensor of length 4. The dilation factor for each dimension of
+// `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+// element on that dimension. The dimension order is determined by the value of
+// `data_format`, see above for details. Dilations in the batch and depth
+// dimensions must be 1.
+// If not specified, defaults to <i:1 i:1 i:1 i:1 >
+func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr {
+ return func(m optionalAttr) {
+ m["dilations"] = value
+ }
+}
+
+// Computes the gradients of depthwise convolution with respect to the input.
+//
+// Arguments:
+// input_sizes: An integer vector representing the shape of `input`, based
+// on `data_format`. For example, if `data_format` is 'NHWC' then
+// `input` is a 4-D `[batch, height, width, channels]` tensor.
+// filter: 4-D with shape
+// `[filter_height, filter_width, in_channels, depthwise_multiplier]`.
+// out_backprop: 4-D with shape based on `data_format`.
+// For example, if `data_format` is 'NHWC' then
+// out_backprop shape is `[batch, out_height, out_width, out_channels]`.
+// Gradients w.r.t. the output of the convolution.
+// strides: The stride of the sliding window for each dimension of the input
+// of the convolution.
+// padding: The type of padding algorithm to use.
+//
+// Returns 4-D with shape according to `data_format`. For example, if
+// `data_format` is 'NHWC', output shape is `[batch, in_height,
+// in_width, in_channels]`. Gradient w.r.t. the input of the
+// convolution.
+func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropInputAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DepthwiseConv2dNativeBackpropInput",
+ Input: []tf.Input{
+ input_sizes, filter, out_backprop,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// ApproximateEqualAttr is an optional argument to ApproximateEqual.
type ApproximateEqualAttr func(optionalAttr)
@@ -4581,33 +5074,90 @@ func SparseReduceSumSparse(scope *Scope, input_indices tf.Output, input_values t
return op.Output(0), op.Output(1), op.Output(2)
}
-// Returns x + y element-wise.
+// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler.
+type AllCandidateSamplerAttr func(optionalAttr)
+
+// AllCandidateSamplerSeed sets the optional seed attribute to value.
//
-// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value.
+//
+// value: An second seed to avoid seed collision.
+// If not specified, defaults to 0
+func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a learned unigram distribution.
+//
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
+//
+// For each batch, this op picks a single set of sampled candidate labels.
+//
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
+//
+// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to produce.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+//
+// Returns A vector of length num_sampled, in which each element is
+// the ID of a sampled candidate.A batch_size * num_true matrix, representing
+// the number of times each candidate is expected to occur in a batch
+// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
+// candidate representing the number of times the candidate is expected
+// to occur in a batch of sampled candidates. If unique=true, then this is a
+// probability.
+func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "AddV2",
+ Type: "AllCandidateSampler",
Input: []tf.Input{
- x, y,
+ true_classes,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ return op.Output(0), op.Output(1), op.Output(2)
}
-// Computes exponential of x element-wise. \\(y = e^x\\).
-func Exp(scope *Scope, x tf.Output) (y tf.Output) {
+// Returns x + y element-wise.
+//
+// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "Exp",
+ Type: "AddV2",
Input: []tf.Input{
- x,
+ x, y,
},
}
op := scope.AddOperation(opspec)
@@ -4768,104 +5318,6 @@ func Asin(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
-// Computes the maximum along segments of a tensor.
-//
-// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
-// for an explanation of segments.
-//
-// This operator is similar to the unsorted segment sum operator found
-// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
-// Instead of computing the sum over segments, it computes the maximum such that:
-//
-// \\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
-// that `segment_ids[j...] == i`.
-//
-// If the maximum is empty for a given segment ID `i`, it outputs the smallest
-// possible value for the specific numeric type,
-// `output[i] = numeric_limits<T>::lowest()`.
-//
-// If the given segment ID `i` is negative, then the corresponding value is
-// dropped, and will not be included in the result.
-//
-// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
-// </div>
-//
-// Arguments:
-//
-// segment_ids: A tensor whose shape is a prefix of `data.shape`.END
-// }
-// out_arg {
-// name: "output"
-// description: <<END
-// Has same shape as data, except for the first `segment_ids.rank`
-// dimensions, which are replaced with a single dimension which has size
-// `num_segments`.
-//
-func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "UnsortedSegmentMax",
- Input: []tf.Input{
- data, segment_ids, num_segments,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// NthElementAttr is an optional argument to NthElement.
-type NthElementAttr func(optionalAttr)
-
-// NthElementReverse sets the optional reverse attribute to value.
-//
-// value: When set to True, find the nth-largest value in the vector and vice
-// versa.
-// If not specified, defaults to false
-func NthElementReverse(value bool) NthElementAttr {
- return func(m optionalAttr) {
- m["reverse"] = value
- }
-}
-
-// Finds values of the `n`-th order statistic for the last dimension.
-//
-// If the input is a vector (rank-1), finds the entries which is the nth-smallest
-// value in the vector and outputs their values as scalar tensor.
-//
-// For matrices (resp. higher rank input), computes the entries which is the
-// nth-smallest value in each row (resp. vector along the last dimension). Thus,
-//
-// values.shape = input.shape[:-1]
-//
-// Arguments:
-// input: 1-D or higher with last dimension at least `n+1`.
-// n: 0-D. Position of sorted vector to select along the last dimension (along
-// each row for matrices). Valid range of n is `[0, input.shape[:-1])`
-//
-// Returns The `n`-th order statistic along each last dimensional slice.
-func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "NthElement",
- Input: []tf.Input{
- input, n,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes the sum along sparse segments of a tensor.
//
// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
@@ -5142,6 +5594,74 @@ func TensorSliceDataset(scope *Scope, components []tf.Output, output_shapes []tf
return op.Output(0)
}
+// Computes hyperbolic sine of x element-wise.
+func Sinh(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Sinh",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the sum along sparse segments of a tensor.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
+// dimension, selecting a subset of dimension 0, specified by `indices`.
+//
+// For example:
+//
+// ```python
+// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+//
+// # Select two rows, one segment.
+// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
+// # => [[0 0 0 0]]
+//
+// # Select two rows, two segment.
+// tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
+// # => [[ 1 2 3 4]
+// # [-1 -2 -3 -4]]
+//
+// # Select all rows, two segments.
+// tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
+// # => [[0 0 0 0]
+// # [5 6 7 8]]
+//
+// # Which is equivalent to:
+// tf.segment_sum(c, tf.constant([0, 0, 1]))
+// ```
+//
+// Arguments:
+//
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `k`, the number of segments.
+func SparseSegmentSum(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentSum",
+ Input: []tf.Input{
+ data, indices, segment_ids,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes natural logarithm of (1 + x) element-wise.
//
// I.e., \\(y = \log_e (1 + x)\\).
@@ -6701,6 +7221,63 @@ func MultiDeviceIteratorGetNextFromShard(scope *Scope, multi_device_iterator tf.
return components
}
+// Computes rectified linear 6: `min(max(features, 0), 6)`.
+func Relu6(scope *Scope, features tf.Output) (activations tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Relu6",
+ Input: []tf.Input{
+ features,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the minimum along segments of a tensor.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+// for an explanation of segments.
+//
+// This operator is similar to the unsorted segment sum operator found
+// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
+// Instead of computing the sum over segments, it computes the minimum such that:
+//
+// \\(output_i = \min_{j...} data_[j...]\\) where min is over tuples `j...` such
+// that `segment_ids[j...] == i`.
+//
+// If the minimum is empty for a given segment ID `i`, it outputs the largest
+// possible value for the specific numeric type,
+// `output[i] = numeric_limits<T>::max()`.
+//
+// If the given segment ID `i` is negative, then the corresponding value is
+// dropped, and will not be included in the result.
+//
+// Arguments:
+//
+// segment_ids: A tensor whose shape is a prefix of `data.shape`.
+//
+//
+// Returns Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
+func UnsortedSegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "UnsortedSegmentMin",
+ Input: []tf.Input{
+ data, segment_ids, num_segments,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes rectified linear gradients for a Relu operation.
//
// Arguments:
@@ -7674,6 +8251,44 @@ func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAt
return op.Output(0)
}
+// Bucketizes 'input' based on 'boundaries'.
+//
+// For example, if the inputs are
+// boundaries = [0, 10, 100]
+// input = [[-5, 10000]
+// [150, 10]
+// [5, 100]]
+//
+// then the output will be
+// output = [[0, 3]
+// [3, 2]
+// [1, 3]]
+//
+// Arguments:
+// input: Any shape of Tensor contains with int or float type.
+// boundaries: A sorted list of floats gives the boundary of the buckets.
+//
+// Returns Same shape with 'input', each value of input replaced with bucket index.
+//
+// @compatibility(numpy)
+// Equivalent to np.digitize.
+// @end_compatibility
+func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"boundaries": boundaries}
+ opspec := tf.OpSpec{
+ Type: "Bucketize",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// FusedBatchNormV2Attr is an optional argument to FusedBatchNormV2.
type FusedBatchNormV2Attr func(optionalAttr)
@@ -8790,6 +9405,119 @@ func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output
return op.Output(0)
}
+// Computes exponential of x element-wise. \\(y = e^x\\).
+func Exp(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Exp",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// NthElementAttr is an optional argument to NthElement.
+type NthElementAttr func(optionalAttr)
+
+// NthElementReverse sets the optional reverse attribute to value.
+//
+// value: When set to True, find the nth-largest value in the vector and vice
+// versa.
+// If not specified, defaults to false
+func NthElementReverse(value bool) NthElementAttr {
+ return func(m optionalAttr) {
+ m["reverse"] = value
+ }
+}
+
+// Finds values of the `n`-th order statistic for the last dimension.
+//
+// If the input is a vector (rank-1), finds the entries which is the nth-smallest
+// value in the vector and outputs their values as scalar tensor.
+//
+// For matrices (resp. higher rank input), computes the entries which is the
+// nth-smallest value in each row (resp. vector along the last dimension). Thus,
+//
+// values.shape = input.shape[:-1]
+//
+// Arguments:
+// input: 1-D or higher with last dimension at least `n+1`.
+// n: 0-D. Position of sorted vector to select along the last dimension (along
+// each row for matrices). Valid range of n is `[0, input.shape[:-1])`
+//
+// Returns The `n`-th order statistic along each last dimensional slice.
+func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "NthElement",
+ Input: []tf.Input{
+ input, n,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the maximum along segments of a tensor.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// This operator is similar to the unsorted segment sum operator found
+// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
+// Instead of computing the sum over segments, it computes the maximum such that:
+//
+// \\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
+// that `segment_ids[j...] == i`.
+//
+// If the maximum is empty for a given segment ID `i`, it outputs the smallest
+// possible value for the specific numeric type,
+// `output[i] = numeric_limits<T>::lowest()`.
+//
+// If the given segment ID `i` is negative, then the corresponding value is
+// dropped, and will not be included in the result.
+//
+// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
+// </div>
+//
+// Arguments:
+//
+// segment_ids: A tensor whose shape is a prefix of `data.shape`.END
+// }
+// out_arg {
+// name: "output"
+// description: <<END
+// Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
+//
+func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "UnsortedSegmentMax",
+ Input: []tf.Input{
+ data, segment_ids, num_segments,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Transforms a vector of brain.Example protos (as strings) into typed tensors.
//
// Arguments:
@@ -9491,56 +10219,62 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf
return op.Output(0)
}
-// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl.
-type ResourceApplyFtrlAttr func(optionalAttr)
+// FusedResizeAndPadConv2DAttr is an optional argument to FusedResizeAndPadConv2D.
+type FusedResizeAndPadConv2DAttr func(optionalAttr)
-// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value.
+// FusedResizeAndPadConv2DResizeAlignCorners sets the optional resize_align_corners attribute to value.
//
-// value: If `True`, updating of the var and accum tensors will be protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
+// value: If true, the centers of the 4 corner pixels of the input and output tensors are
+// aligned, preserving the values at the corner pixels. Defaults to false.
// If not specified, defaults to false
-func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr {
+func FusedResizeAndPadConv2DResizeAlignCorners(value bool) FusedResizeAndPadConv2DAttr {
return func(m optionalAttr) {
- m["use_locking"] = value
+ m["resize_align_corners"] = value
}
}
-// Update '*var' according to the Ftrl-proximal scheme.
+// Performs a resize and padding as a preprocess during a convolution.
//
-// accum_new = accum + grad * grad
-// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
-// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
-// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
-// accum = accum_new
+// It's often possible to do spatial transformations more efficiently as part of
+// the packing stage of a convolution, so this op allows for an optimized
+// implementation where these stages are fused together. This prevents the need to
+// write out the intermediate results as whole tensors, reducing memory pressure,
+// and we can get some latency gains by merging the transformation calculations.
+// The data_format attribute for Conv2D isn't supported by this op, and defaults to
+// 'NHWC' order.
+// Internally this op uses a single per-graph scratch buffer, which means that it
+// will block if multiple versions are being run in parallel. This is because this
+// operator is primarily an optimization to minimize memory usage.
//
// Arguments:
-// var_: Should be from a Variable().
-// accum: Should be from a Variable().
-// linear: Should be from a Variable().
-// grad: The gradient.
-// lr: Scaling factor. Must be a scalar.
-// l1: L1 regulariation. Must be a scalar.
-// l2: L2 regulariation. Must be a scalar.
-// lr_power: Scaling factor. Must be a scalar.
+// input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
+// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+// new size for the images.
+// paddings: A two-column matrix specifying the padding sizes. The number of
+// rows must be the same as the rank of `input`.
+// filter: 4-D with shape
+// `[filter_height, filter_width, in_channels, out_channels]`.
//
-// Returns the created operation.
-func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) {
+// strides: 1-D of length 4. The stride of the sliding window for each dimension
+// of `input`. Must be in the same order as the dimension specified with format.
+// padding: The type of padding algorithm to use.
+func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string, optional ...FusedResizeAndPadConv2DAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
+ attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
- Type: "ResourceApplyFtrl",
+ Type: "FusedResizeAndPadConv2D",
Input: []tf.Input{
- var_, accum, linear, grad, lr, l1, l2, lr_power,
+ input, size, paddings, filter,
},
Attrs: attrs,
}
- return scope.AddOperation(opspec)
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
// RandomUniformAttr is an optional argument to RandomUniform.
@@ -9597,6 +10331,58 @@ func RandomUniform(scope *Scope, shape tf.Output, dtype tf.DataType, optional ..
return op.Output(0)
}
+// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl.
+type ResourceApplyFtrlAttr func(optionalAttr)
+
+// ResourceApplyFtrlUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var and accum tensors will be protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceApplyFtrlUseLocking(value bool) ResourceApplyFtrlAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' according to the Ftrl-proximal scheme.
+//
+// accum_new = accum + grad * grad
+// linear += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
+// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
+// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
+// accum = accum_new
+//
+// Arguments:
+// var_: Should be from a Variable().
+// accum: Should be from a Variable().
+// linear: Should be from a Variable().
+// grad: The gradient.
+// lr: Scaling factor. Must be a scalar.
+// l1: L1 regulariation. Must be a scalar.
+// l2: L2 regulariation. Must be a scalar.
+// lr_power: Scaling factor. Must be a scalar.
+//
+// Returns the created operation.
+func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceApplyFtrl",
+ Input: []tf.Input{
+ var_, accum, linear, grad, lr, l1, l2, lr_power,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
// Encode audio data using the WAV file format.
//
// This operation will generate a string suitable to be saved out to create a .wav
@@ -9733,23 +10519,6 @@ func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...Ass
return scope.AddOperation(opspec)
}
-// Broadcasts a tensor value to one or more other devices.
-func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape}
- opspec := tf.OpSpec{
- Type: "CollectiveBcastSend",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Split a `SparseTensor` into `num_split` tensors along one dimension.
//
// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices
@@ -9873,6 +10642,118 @@ func ResourceSparseApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, li
return scope.AddOperation(opspec)
}
+// Calculates gains for each feature and returns the best possible split information for the feature.
+//
+// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature.
+//
+// It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split.
+//
+// In this manner, the output is the best split per features and per node, so that it needs to be combined later to produce the best split for each node (among all possible features).
+//
+// The length of output lists are all of the same length, `num_features`.
+// The output shapes are compatible in a way that the first dimension of all tensors of all lists are the same and equal to the number of possible split nodes for each feature.
+//
+// Arguments:
+// node_id_range: A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as like `for node_id in range(node_id_range[0], node_id_range[1])` (Note that the last index node_id_range[1] is exclusive).
+// stats_summary_list: A list of Rank 3 tensor (#shape=[max_splits, bucket, 2]) for accumulated stats summary (gradient/hessian) per node per buckets for each feature. The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used.
+// l1: l1 regularization factor on leaf weights, per instance based.
+// l2: l2 regularization factor on leaf weights, per instance based.
+// tree_complexity: adjustment to the gain, per leaf based.
+// min_node_weight: mininum avg of hessians in a node before required for the node to be considered for splitting.
+// max_splits: the number of nodes that can be split in the whole tree. Used as a dimension of output tensors.
+//
+// Returns An output list of Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has different size as each feature provides different possible nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes.A list of Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes.A list of Rank 2 tensors, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node.
+func BoostedTreesCalculateBestGainsPerFeature(scope *Scope, node_id_range tf.Output, stats_summary_list []tf.Output, l1 tf.Output, l2 tf.Output, tree_complexity tf.Output, min_node_weight tf.Output, max_splits int64) (node_ids_list []tf.Output, gains_list []tf.Output, thresholds_list []tf.Output, left_node_contribs_list []tf.Output, right_node_contribs_list []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"max_splits": max_splits}
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesCalculateBestGainsPerFeature",
+ Input: []tf.Input{
+ node_id_range, tf.OutputList(stats_summary_list), l1, l2, tree_complexity, min_node_weight,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if node_ids_list, idx, err = makeOutputList(op, idx, "node_ids_list"); err != nil {
+ scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
+ return
+ }
+ if gains_list, idx, err = makeOutputList(op, idx, "gains_list"); err != nil {
+ scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
+ return
+ }
+ if thresholds_list, idx, err = makeOutputList(op, idx, "thresholds_list"); err != nil {
+ scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
+ return
+ }
+ if left_node_contribs_list, idx, err = makeOutputList(op, idx, "left_node_contribs_list"); err != nil {
+ scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
+ return
+ }
+ if right_node_contribs_list, idx, err = makeOutputList(op, idx, "right_node_contribs_list"); err != nil {
+ scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
+ return
+ }
+ return node_ids_list, gains_list, thresholds_list, left_node_contribs_list, right_node_contribs_list
+}
+
+// EncodePngAttr is an optional argument to EncodePng.
+type EncodePngAttr func(optionalAttr)
+
+// EncodePngCompression sets the optional compression attribute to value.
+//
+// value: Compression level.
+// If not specified, defaults to -1
+func EncodePngCompression(value int64) EncodePngAttr {
+ return func(m optionalAttr) {
+ m["compression"] = value
+ }
+}
+
+// PNG-encode an image.
+//
+// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]`
+// where `channels` is:
+//
+// * 1: for grayscale.
+// * 2: for grayscale + alpha.
+// * 3: for RGB.
+// * 4: for RGBA.
+//
+// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder
+// default or a value from 0 to 9. 9 is the highest compression level, generating
+// the smallest output, but is slower.
+//
+// Arguments:
+// image: 3-D with shape `[height, width, channels]`.
+//
+// Returns 0-D. PNG-encoded image.
+func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "EncodePng",
+ Input: []tf.Input{
+ image,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// DataFormatVecPermuteAttr is an optional argument to DataFormatVecPermute.
type DataFormatVecPermuteAttr func(optionalAttr)
@@ -10109,6 +10990,112 @@ func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value
return op.Output(0)
}
+// This op consumes a lock created by `MutexLock`.
+//
+// This op exists to consume a tensor created by `MutexLock` (other than
+// direct control dependencies). It should be the only that consumes the tensor,
+// and will raise an error if it is not. Its only purpose is to keep the
+// mutex lock tensor alive until it is consumed by this op.
+//
+// **NOTE**: This operation must run on the same device as its input. This may
+// be enforced via the `colocate_with` mechanism.
+//
+// Arguments:
+// mutex_lock: A tensor returned by `MutexLock`.
+//
+// Returns the created operation.
+func ConsumeMutexLock(scope *Scope, mutex_lock tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ConsumeMutexLock",
+ Input: []tf.Input{
+ mutex_lock,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// ResourceScatterNdAddAttr is an optional argument to ResourceScatterNdAdd.
+type ResourceScatterNdAddAttr func(optionalAttr)
+
+// ResourceScatterNdAddUseLocking sets the optional use_locking attribute to value.
+//
+// value: An optional bool. Defaults to True. If True, the assignment will
+// be protected by a lock; otherwise the behavior is undefined,
+// but may exhibit less contention.
+// If not specified, defaults to true
+func ResourceScatterNdAddUseLocking(value bool) ResourceScatterNdAddAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Adds sparse `updates` to individual values or slices within a given
+//
+// variable according to `indices`.
+//
+// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+//
+// `indices` must be integer tensor, containing indices into `ref`.
+// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+//
+// The innermost dimension of `indices` (with length `K`) corresponds to
+// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+// dimension of `ref`.
+//
+// `updates` is `Tensor` of rank `Q-1+P-K` with shape:
+//
+// ```
+// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+// ```
+//
+// For example, say we want to update 4 scattered elements to a rank-1 tensor to
+// 8 elements. In Python, that update would look like this:
+//
+// ```python
+// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True)
+// indices = tf.constant([[4], [3], [1] ,[7]])
+// updates = tf.constant([9, 10, 11, 12])
+// update = tf.scatter_nd_add(ref, indices, updates)
+// with tf.Session() as sess:
+// print sess.run(update)
+// ```
+//
+// The resulting update to ref would look like this:
+//
+// [1, 12, 3, 14, 14, 6, 7, 20]
+//
+// See `tf.scatter_nd` for more details about how to make updates to
+// slices.
+//
+// Arguments:
+// ref: A resource handle. Must be from a VarHandleOp.
+// indices: A Tensor. Must be one of the following types: int32, int64.
+// A tensor of indices into ref.
+// updates: A Tensor. Must have the same type as ref. A tensor of
+// values to add to ref.
+//
+// Returns the created operation.
+func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdAddAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceScatterNdAdd",
+ Input: []tf.Input{
+ ref, indices, updates,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
// Updates the tree ensemble by either adding a layer to the last tree being grown
//
// or by starting a new tree.
@@ -10850,68 +11837,31 @@ func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, upd
return scope.AddOperation(opspec)
}
-// ResourceScatterNdAddAttr is an optional argument to ResourceScatterNdAdd.
-type ResourceScatterNdAddAttr func(optionalAttr)
+// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal.
+type StatelessRandomNormalAttr func(optionalAttr)
-// ResourceScatterNdAddUseLocking sets the optional use_locking attribute to value.
+// StatelessRandomNormalDtype sets the optional dtype attribute to value.
//
-// value: An optional bool. Defaults to True. If True, the assignment will
-// be protected by a lock; otherwise the behavior is undefined,
-// but may exhibit less contention.
-// If not specified, defaults to true
-func ResourceScatterNdAddUseLocking(value bool) ResourceScatterNdAddAttr {
+// value: The type of the output.
+// If not specified, defaults to DT_FLOAT
+func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr {
return func(m optionalAttr) {
- m["use_locking"] = value
+ m["dtype"] = value
}
}
-// Adds sparse `updates` to individual values or slices within a given
-//
-// variable according to `indices`.
-//
-// `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
-//
-// `indices` must be integer tensor, containing indices into `ref`.
-// It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
-//
-// The innermost dimension of `indices` (with length `K`) corresponds to
-// indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
-// dimension of `ref`.
-//
-// `updates` is `Tensor` of rank `Q-1+P-K` with shape:
-//
-// ```
-// [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
-// ```
-//
-// For example, say we want to update 4 scattered elements to a rank-1 tensor to
-// 8 elements. In Python, that update would look like this:
-//
-// ```python
-// ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8], use_resource=True)
-// indices = tf.constant([[4], [3], [1] ,[7]])
-// updates = tf.constant([9, 10, 11, 12])
-// update = tf.scatter_nd_add(ref, indices, updates)
-// with tf.Session() as sess:
-// print sess.run(update)
-// ```
-//
-// The resulting update to ref would look like this:
+// Outputs deterministic pseudorandom values from a normal distribution.
//
-// [1, 12, 3, 14, 14, 6, 7, 20]
+// The generated values will have mean 0 and standard deviation 1.
//
-// See `tf.scatter_nd` for more details about how to make updates to
-// slices.
+// The outputs are a deterministic function of `shape` and `seed`.
//
// Arguments:
-// ref: A resource handle. Must be from a VarHandleOp.
-// indices: A Tensor. Must be one of the following types: int32, int64.
-// A tensor of indices into ref.
-// updates: A Tensor. Must have the same type as ref. A tensor of
-// values to add to ref.
+// shape: The shape of the output tensor.
+// seed: 2 seeds (shape [2]).
//
-// Returns the created operation.
-func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, updates tf.Output, optional ...ResourceScatterNdAddAttr) (o *tf.Operation) {
+// Returns Random values with specified shape.
+func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
@@ -10920,57 +11870,93 @@ func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, update
a(attrs)
}
opspec := tf.OpSpec{
- Type: "ResourceScatterNdAdd",
+ Type: "StatelessRandomNormal",
Input: []tf.Input{
- ref, indices, updates,
+ shape, seed,
},
Attrs: attrs,
}
- return scope.AddOperation(opspec)
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
-// Mutually reduces multiple tensors of identical type and shape.
-func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) {
+// Creates a sequence of numbers.
+//
+// This operation creates a sequence of numbers that begins at `start` and
+// extends by increments of `delta` up to but not including `limit`.
+//
+// For example:
+//
+// ```
+// # 'start' is 3
+// # 'limit' is 18
+// # 'delta' is 3
+// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+// ```
+//
+// Arguments:
+// start: 0-D (scalar). First entry in the sequence.
+// limit: 0-D (scalar). Upper limit of sequence, exclusive.
+// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`.
+//
+// Returns 1-D.
+func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets}
opspec := tf.OpSpec{
- Type: "CollectiveReduce",
+ Type: "Range",
Input: []tf.Input{
- input,
+ start, limit, delta,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal.
-type StatelessRandomNormalAttr func(optionalAttr)
+// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum.
+type ResourceApplyMomentumAttr func(optionalAttr)
-// StatelessRandomNormalDtype sets the optional dtype attribute to value.
+// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value.
//
-// value: The type of the output.
-// If not specified, defaults to DT_FLOAT
-func StatelessRandomNormalDtype(value tf.DataType) StatelessRandomNormalAttr {
+// value: If `True`, updating of the var and accum tensors will be protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceApplyMomentumUseLocking(value bool) ResourceApplyMomentumAttr {
return func(m optionalAttr) {
- m["dtype"] = value
+ m["use_locking"] = value
}
}
-// Outputs deterministic pseudorandom values from a normal distribution.
+// ResourceApplyMomentumUseNesterov sets the optional use_nesterov attribute to value.
//
-// The generated values will have mean 0 and standard deviation 1.
+// value: If `True`, the tensor passed to compute grad will be
+// var - lr * momentum * accum, so in the end, the var you get is actually
+// var - lr * momentum * accum.
+// If not specified, defaults to false
+func ResourceApplyMomentumUseNesterov(value bool) ResourceApplyMomentumAttr {
+ return func(m optionalAttr) {
+ m["use_nesterov"] = value
+ }
+}
+
+// Update '*var' according to the momentum scheme. Set use_nesterov = True if you
//
-// The outputs are a deterministic function of `shape` and `seed`.
+// want to use Nesterov momentum.
+//
+// accum = accum * momentum + grad
+// var -= lr * accum
//
// Arguments:
-// shape: The shape of the output tensor.
-// seed: 2 seeds (shape [2]).
+// var_: Should be from a Variable().
+// accum: Should be from a Variable().
+// lr: Scaling factor. Must be a scalar.
+// grad: The gradient.
+// momentum: Momentum. Must be a scalar.
//
-// Returns Random values with specified shape.
-func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, optional ...StatelessRandomNormalAttr) (output tf.Output) {
+// Returns the created operation.
+func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyMomentumAttr) (o *tf.Operation) {
if scope.Err() != nil {
return
}
@@ -10979,12 +11965,54 @@ func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, option
a(attrs)
}
opspec := tf.OpSpec{
- Type: "StatelessRandomNormal",
+ Type: "ResourceApplyMomentum",
Input: []tf.Input{
- shape, seed,
+ var_, accum, lr, grad, momentum,
},
Attrs: attrs,
}
+ return scope.AddOperation(opspec)
+}
+
+// Exits the current frame to its parent frame.
+//
+// Exit makes its input `data` available to the parent frame.
+//
+// Arguments:
+// data: The tensor to be made available to the parent frame.
+//
+// Returns The same tensor as `data`.
+func Exit(scope *Scope, data tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Exit",
+ Input: []tf.Input{
+ data,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Produce a string tensor that encodes the state of a Reader.
+//
+// Not all Readers support being serialized, so this can produce an
+// Unimplemented error.
+//
+// Arguments:
+// reader_handle: Handle to a Reader.
+func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ReaderSerializeStateV2",
+ Input: []tf.Input{
+ reader_handle,
+ },
+ }
op := scope.AddOperation(opspec)
return op.Output(0)
}
@@ -11122,68 +12150,6 @@ func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (o
return op.Output(0)
}
-// StringSplitV2Attr is an optional argument to StringSplitV2.
-type StringSplitV2Attr func(optionalAttr)
-
-// StringSplitV2Maxsplit sets the optional maxsplit attribute to value.
-//
-// value: An `int`. If `maxsplit > 0`, limit of the split of the result.
-// If not specified, defaults to -1
-func StringSplitV2Maxsplit(value int64) StringSplitV2Attr {
- return func(m optionalAttr) {
- m["maxsplit"] = value
- }
-}
-
-// Split elements of `source` based on `sep` into a `SparseTensor`.
-//
-// Let N be the size of source (typically N will be the batch size). Split each
-// element of `source` based on `sep` and return a `SparseTensor`
-// containing the split tokens. Empty tokens are ignored.
-//
-// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c',
-// then the output will be
-// ```
-// st.indices = [0, 0;
-// 0, 1;
-// 1, 0;
-// 1, 1;
-// 1, 2]
-// st.shape = [2, 3]
-// st.values = ['hello', 'world', 'a', 'b', 'c']
-// ```
-//
-// If `sep` is given, consecutive delimiters are not grouped together and are
-// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
-// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
-// string, consecutive whitespace are regarded as a single separator, and the
-// result will contain no empty strings at the startor end if the string has
-// leading or trailing whitespace.
-//
-// Note that the above mentioned behavior matches python's str.split.
-//
-// Arguments:
-// input: `1-D` string `Tensor`, the strings to split.
-// sep: `0-D` string `Tensor`, the delimiter character.
-func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "StringSplitV2",
- Input: []tf.Input{
- input, sep,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// MaxPoolAttr is an optional argument to MaxPool.
type MaxPoolAttr func(optionalAttr)
@@ -11664,6 +12630,51 @@ func Conj(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
+// ProdAttr is an optional argument to Prod.
+type ProdAttr func(optionalAttr)
+
+// ProdKeepDims sets the optional keep_dims attribute to value.
+//
+// value: If true, retain reduced dimensions with length 1.
+// If not specified, defaults to false
+func ProdKeepDims(value bool) ProdAttr {
+ return func(m optionalAttr) {
+ m["keep_dims"] = value
+ }
+}
+
+// Computes the product of elements across dimensions of a tensor.
+//
+// Reduces `input` along the dimensions given in `axis`. Unless
+// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+// `axis`. If `keep_dims` is true, the reduced dimensions are
+// retained with length 1.
+//
+// Arguments:
+// input: The tensor to reduce.
+// axis: The dimensions to reduce. Must be in the range
+// `[-rank(input), rank(input))`.
+//
+// Returns The reduced tensor.
+func Prod(scope *Scope, input tf.Output, axis tf.Output, optional ...ProdAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Prod",
+ Input: []tf.Input{
+ input, axis,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// ResizeBilinearAttr is an optional argument to ResizeBilinear.
type ResizeBilinearAttr func(optionalAttr)
@@ -11708,21 +12719,6 @@ func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...
return op.Output(0)
}
-// Computes softsign: `features / (abs(features) + 1)`.
-func Softsign(scope *Scope, features tf.Output) (activations tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Softsign",
- Input: []tf.Input{
- features,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Creates a TensorList which, when stacked, has the value of `tensor`.
//
// Each tensor in the result list corresponds to one row of the input tensor.
@@ -11743,81 +12739,6 @@ func TensorListFromTensor(scope *Scope, tensor tf.Output, element_shape tf.Outpu
return op.Output(0)
}
-// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping.
-type GenerateVocabRemappingAttr func(optionalAttr)
-
-// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value.
-//
-// value: Number of entries in the old vocab file to consider. If -1,
-// use the entire old vocabulary.
-// If not specified, defaults to -1
-//
-// REQUIRES: value >= -1
-func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr {
- return func(m optionalAttr) {
- m["old_vocab_size"] = value
- }
-}
-
-// Given a path to new and old vocabulary files, returns a remapping Tensor of
-//
-// length `num_new_vocab`, where `remapping[i]` contains the row number in the old
-// vocabulary that corresponds to row `i` in the new vocabulary (starting at line
-// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i`
-// in the new vocabulary is not in the old vocabulary. The old vocabulary is
-// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the
-// default value of -1.
-//
-// `num_vocab_offset` enables
-// use in the partitioned variable case, and should generally be set through
-// examining partitioning info. The format of the files should be a text file,
-// with each line containing a single entity within the vocabulary.
-//
-// For example, with `new_vocab_file` a text file containing each of the following
-// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3],
-// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be
-// `[0, -1, 2]`.
-//
-// The op also returns a count of how many entries in the new vocabulary
-// were present in the old vocabulary, which is used to calculate the number of
-// values to initialize in a weight matrix remapping
-//
-// This functionality can be used to remap both row vocabularies (typically,
-// features) and column vocabularies (typically, classes) from TensorFlow
-// checkpoints. Note that the partitioning logic relies on contiguous vocabularies
-// corresponding to div-partitioned variables. Moreover, the underlying remapping
-// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should
-// use the corresponding index_table_from_file() as the FeatureColumn framework
-// does (as opposed to tf.feature_to_id(), which uses a CuckooTable).
-//
-// Arguments:
-// new_vocab_file: Path to the new vocab file.
-// old_vocab_file: Path to the old vocab file.
-// new_vocab_offset: How many entries into the new vocab file to start reading.
-// num_new_vocab: Number of entries in the new vocab file to remap.
-//
-// Returns A Tensor of length num_new_vocab where the element at index i
-// is equal to the old ID that maps to the new ID i. This element is -1 for any
-// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab.
-func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "GenerateVocabRemapping",
- Input: []tf.Input{
- new_vocab_file, old_vocab_file,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
// Assigns sparse updates to the variable referenced by `resource`.
//
// This operation computes
@@ -12024,65 +12945,6 @@ func StageClear(scope *Scope, dtypes []tf.DataType, optional ...StageClearAttr)
return scope.AddOperation(opspec)
}
-// ComputeAccidentalHitsAttr is an optional argument to ComputeAccidentalHits.
-type ComputeAccidentalHitsAttr func(optionalAttr)
-
-// ComputeAccidentalHitsSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func ComputeAccidentalHitsSeed(value int64) ComputeAccidentalHitsAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// ComputeAccidentalHitsSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func ComputeAccidentalHitsSeed2(value int64) ComputeAccidentalHitsAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Computes the ids of the positions in sampled_candidates that match true_labels.
-//
-// When doing log-odds NCE, the result of this op should be passed through a
-// SparseToDense op, then added to the logits of the sampled candidates. This has
-// the effect of 'removing' the sampled labels that match the true labels by
-// making the classifier sure that they are sampled labels.
-//
-// Arguments:
-// true_classes: The true_classes output of UnpackSparseLabels.
-// sampled_candidates: The sampled_candidates output of CandidateSampler.
-// num_true: Number of true labels per context.
-//
-// Returns A vector of indices corresponding to rows of true_candidates.A vector of IDs of positions in sampled_candidates that match a true_label
-// for the row with the corresponding index in indices.A vector of the same length as indices and ids, in which each element
-// is -FLOAT_MAX.
-func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candidates tf.Output, num_true int64, optional ...ComputeAccidentalHitsAttr) (indices tf.Output, ids tf.Output, weights tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ComputeAccidentalHits",
- Input: []tf.Input{
- true_classes, sampled_candidates,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// QuantizedRelu6Attr is an optional argument to QuantizedRelu6.
type QuantizedRelu6Attr func(optionalAttr)
@@ -12211,6 +13073,17 @@ func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...Fix
return op.Output(0)
}
+// StringLengthAttr is an optional argument to StringLength.
+type StringLengthAttr func(optionalAttr)
+
+// StringLengthUnit sets the optional unit attribute to value.
+// If not specified, defaults to "BYTE"
+func StringLengthUnit(value string) StringLengthAttr {
+ return func(m optionalAttr) {
+ m["unit"] = value
+ }
+}
+
// String lengths of `input`.
//
// Computes the length of each string given in the input tensor.
@@ -12220,15 +13093,20 @@ func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...Fix
//
// Returns Integer tensor that has the same shape as `input`. The output contains the
// element-wise string lengths of `input`.
-func StringLength(scope *Scope, input tf.Output) (output tf.Output) {
+func StringLength(scope *Scope, input tf.Output, optional ...StringLengthAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
Type: "StringLength",
Input: []tf.Input{
input,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
@@ -12863,6 +13741,27 @@ func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAtt
return op.Output(0)
}
+// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
+//
+// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble.
+//
+// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest
+// layer.
+func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesGetEnsembleStates",
+ Input: []tf.Input{
+ tree_ensemble_handle,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
+}
+
// ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign.
type ResourceApplyPowerSignAttr func(optionalAttr)
@@ -13645,78 +14544,6 @@ func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output,
return op.Output(0)
}
-// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler.
-type AllCandidateSamplerAttr func(optionalAttr)
-
-// AllCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a learned unigram distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to produce.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "AllCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// Adds two `SparseTensor` objects to produce another `SparseTensor`.
//
// The input `SparseTensor` objects' indices are assumed ordered in standard
@@ -14527,6 +15354,78 @@ func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (o
return op.Output(0)
}
+// Returns the last element of the input list as well as a list with all but that element.
+//
+// Fails if the list is empty.
+//
+// input_handle: the input list
+// tensor: the withdrawn last element of the list
+// element_dtype: the type of elements in the list
+// element_shape: the shape of the output tensor
+func TensorListPopBack(scope *Scope, input_handle tf.Output, element_dtype tf.DataType) (output_handle tf.Output, tensor tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"element_dtype": element_dtype}
+ opspec := tf.OpSpec{
+ Type: "TensorListPopBack",
+ Input: []tf.Input{
+ input_handle,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
+// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad.
+type MaxPoolGradGradAttr func(optionalAttr)
+
+// MaxPoolGradGradDataFormat sets the optional data_format attribute to value.
+//
+// value: Specify the data format of the input and output data. With the
+// default format "NHWC", the data is stored in the order of:
+// [batch, in_height, in_width, in_channels].
+// Alternatively, the format could be "NCHW", the data storage order of:
+// [batch, in_channels, in_height, in_width].
+// If not specified, defaults to "NHWC"
+func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// Computes second-order gradients of the maxpooling function.
+//
+// Arguments:
+// orig_input: The original input tensor.
+// orig_output: The original output tensor.
+// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`.
+// ksize: The size of the window for each dimension of the input tensor.
+// strides: The stride of the sliding window for each dimension of the
+// input tensor.
+// padding: The type of padding algorithm to use.
+//
+// Returns Gradients of gradients w.r.t. the input to `max_pool`.
+func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MaxPoolGradGrad",
+ Input: []tf.Input{
+ orig_input, orig_output, grad,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3.
type TensorArrayGatherV3Attr func(optionalAttr)
@@ -14573,33 +15472,6 @@ func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow
return op.Output(0)
}
-// This op consumes a lock created by `MutexLock`.
-//
-// This op exists to consume a tensor created by `MutexLock` (other than
-// direct control dependencies). It should be the only that consumes the tensor,
-// and will raise an error if it is not. Its only purpose is to keep the
-// mutex lock tensor alive until it is consumed by this op.
-//
-// **NOTE**: This operation must run on the same device as its input. This may
-// be enforced via the `colocate_with` mechanism.
-//
-// Arguments:
-// mutex_lock: A tensor returned by `MutexLock`.
-//
-// Returns the created operation.
-func ConsumeMutexLock(scope *Scope, mutex_lock tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ConsumeMutexLock",
- Input: []tf.Input{
- mutex_lock,
- },
- }
- return scope.AddOperation(opspec)
-}
-
// Returns x / y element-wise for integer types.
//
// Truncation designates that negative numbers will round fractional quantities
@@ -15670,79 +16542,6 @@ func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...Ra
return op.Output(0)
}
-// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler.
-type LogUniformCandidateSamplerAttr func(optionalAttr)
-
-// LogUniformCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func LogUniformCandidateSamplerSeed(value int64) LogUniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// LogUniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func LogUniformCandidateSamplerSeed2(value int64) LogUniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a log-uniform distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to randomly sample.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-// range_max: The sampler will sample integers from the interval [0, range_max).
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LogUniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "LogUniformCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// Returns the max of x and y (i.e. x > y ? x : y) element-wise.
//
// *NOTE*: `Maximum` supports broadcasting. More about broadcasting
@@ -16028,109 +16827,6 @@ func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) {
return op.Output(0)
}
-// ProdAttr is an optional argument to Prod.
-type ProdAttr func(optionalAttr)
-
-// ProdKeepDims sets the optional keep_dims attribute to value.
-//
-// value: If true, retain reduced dimensions with length 1.
-// If not specified, defaults to false
-func ProdKeepDims(value bool) ProdAttr {
- return func(m optionalAttr) {
- m["keep_dims"] = value
- }
-}
-
-// Computes the product of elements across dimensions of a tensor.
-//
-// Reduces `input` along the dimensions given in `axis`. Unless
-// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
-// `axis`. If `keep_dims` is true, the reduced dimensions are
-// retained with length 1.
-//
-// Arguments:
-// input: The tensor to reduce.
-// axis: The dimensions to reduce. Must be in the range
-// `[-rank(input), rank(input))`.
-//
-// Returns The reduced tensor.
-func Prod(scope *Scope, input tf.Output, axis tf.Output, optional ...ProdAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Prod",
- Input: []tf.Input{
- input, axis,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// FusedResizeAndPadConv2DAttr is an optional argument to FusedResizeAndPadConv2D.
-type FusedResizeAndPadConv2DAttr func(optionalAttr)
-
-// FusedResizeAndPadConv2DResizeAlignCorners sets the optional resize_align_corners attribute to value.
-//
-// value: If true, the centers of the 4 corner pixels of the input and output tensors are
-// aligned, preserving the values at the corner pixels. Defaults to false.
-// If not specified, defaults to false
-func FusedResizeAndPadConv2DResizeAlignCorners(value bool) FusedResizeAndPadConv2DAttr {
- return func(m optionalAttr) {
- m["resize_align_corners"] = value
- }
-}
-
-// Performs a resize and padding as a preprocess during a convolution.
-//
-// It's often possible to do spatial transformations more efficiently as part of
-// the packing stage of a convolution, so this op allows for an optimized
-// implementation where these stages are fused together. This prevents the need to
-// write out the intermediate results as whole tensors, reducing memory pressure,
-// and we can get some latency gains by merging the transformation calculations.
-// The data_format attribute for Conv2D isn't supported by this op, and defaults to
-// 'NHWC' order.
-// Internally this op uses a single per-graph scratch buffer, which means that it
-// will block if multiple versions are being run in parallel. This is because this
-// operator is primarily an optimization to minimize memory usage.
-//
-// Arguments:
-// input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
-// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
-// new size for the images.
-// paddings: A two-column matrix specifying the padding sizes. The number of
-// rows must be the same as the rank of `input`.
-// filter: 4-D with shape
-// `[filter_height, filter_width, in_channels, out_channels]`.
-//
-// strides: 1-D of length 4. The stride of the sliding window for each dimension
-// of `input`. Must be in the same order as the dimension specified with format.
-// padding: The type of padding algorithm to use.
-func FusedResizeAndPadConv2D(scope *Scope, input tf.Output, size tf.Output, paddings tf.Output, filter tf.Output, mode string, strides []int64, padding string, optional ...FusedResizeAndPadConv2DAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"mode": mode, "strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "FusedResizeAndPadConv2D",
- Input: []tf.Input{
- input, size, paddings, filter,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns a list of tensors with the same shapes and contents as the input
//
// tensors.
@@ -17600,175 +18296,6 @@ func FractionalAvgPoolGrad(scope *Scope, orig_input_tensor_shape tf.Output, out_
return op.Output(0)
}
-// BoostedTreesEnsembleResourceHandleOpAttr is an optional argument to BoostedTreesEnsembleResourceHandleOp.
-type BoostedTreesEnsembleResourceHandleOpAttr func(optionalAttr)
-
-// BoostedTreesEnsembleResourceHandleOpContainer sets the optional container attribute to value.
-// If not specified, defaults to ""
-func BoostedTreesEnsembleResourceHandleOpContainer(value string) BoostedTreesEnsembleResourceHandleOpAttr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// BoostedTreesEnsembleResourceHandleOpSharedName sets the optional shared_name attribute to value.
-// If not specified, defaults to ""
-func BoostedTreesEnsembleResourceHandleOpSharedName(value string) BoostedTreesEnsembleResourceHandleOpAttr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// Creates a handle to a BoostedTreesEnsembleResource
-func BoostedTreesEnsembleResourceHandleOp(scope *Scope, optional ...BoostedTreesEnsembleResourceHandleOpAttr) (resource tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesEnsembleResourceHandleOp",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum.
-type ResourceApplyMomentumAttr func(optionalAttr)
-
-// ResourceApplyMomentumUseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var and accum tensors will be protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceApplyMomentumUseLocking(value bool) ResourceApplyMomentumAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// ResourceApplyMomentumUseNesterov sets the optional use_nesterov attribute to value.
-//
-// value: If `True`, the tensor passed to compute grad will be
-// var - lr * momentum * accum, so in the end, the var you get is actually
-// var - lr * momentum * accum.
-// If not specified, defaults to false
-func ResourceApplyMomentumUseNesterov(value bool) ResourceApplyMomentumAttr {
- return func(m optionalAttr) {
- m["use_nesterov"] = value
- }
-}
-
-// Update '*var' according to the momentum scheme. Set use_nesterov = True if you
-//
-// want to use Nesterov momentum.
-//
-// accum = accum * momentum + grad
-// var -= lr * accum
-//
-// Arguments:
-// var_: Should be from a Variable().
-// accum: Should be from a Variable().
-// lr: Scaling factor. Must be a scalar.
-// grad: The gradient.
-// momentum: Momentum. Must be a scalar.
-//
-// Returns the created operation.
-func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, momentum tf.Output, optional ...ResourceApplyMomentumAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceApplyMomentum",
- Input: []tf.Input{
- var_, accum, lr, grad, momentum,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad.
-type MaxPoolGradGradAttr func(optionalAttr)
-
-// MaxPoolGradGradDataFormat sets the optional data_format attribute to value.
-//
-// value: Specify the data format of the input and output data. With the
-// default format "NHWC", the data is stored in the order of:
-// [batch, in_height, in_width, in_channels].
-// Alternatively, the format could be "NCHW", the data storage order of:
-// [batch, in_channels, in_height, in_width].
-// If not specified, defaults to "NHWC"
-func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// Computes second-order gradients of the maxpooling function.
-//
-// Arguments:
-// orig_input: The original input tensor.
-// orig_output: The original output tensor.
-// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`.
-// ksize: The size of the window for each dimension of the input tensor.
-// strides: The stride of the sliding window for each dimension of the
-// input tensor.
-// padding: The type of padding algorithm to use.
-//
-// Returns Gradients of gradients w.r.t. the input to `max_pool`.
-func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "MaxPoolGradGrad",
- Input: []tf.Input{
- orig_input, orig_output, grad,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns the last element of the input list as well as a list with all but that element.
-//
-// Fails if the list is empty.
-//
-// input_handle: the input list
-// tensor: the withdrawn last element of the list
-// element_dtype: the type of elements in the list
-// element_shape: the shape of the output tensor
-func TensorListPopBack(scope *Scope, input_handle tf.Output, element_dtype tf.DataType) (output_handle tf.Output, tensor tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"element_dtype": element_dtype}
- opspec := tf.OpSpec{
- Type: "TensorListPopBack",
- Input: []tf.Input{
- input_handle,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
// Returns element-wise integer closest to x.
//
// If the result is midway between two representable values,
@@ -19059,31 +19586,6 @@ func SparseDenseCwiseAdd(scope *Scope, sp_indices tf.Output, sp_values tf.Output
return op.Output(0)
}
-// Read an element from the TensorArray into output `value`.
-//
-// Arguments:
-// handle: The handle to a TensorArray.
-//
-// flow_in: A float scalar that enforces proper chaining of operations.
-// dtype: The type of the elem that is returned.
-//
-// Returns The tensor that is read from the TensorArray.
-func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype}
- opspec := tf.OpSpec{
- Type: "TensorArrayReadV3",
- Input: []tf.Input{
- handle, index, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// QuantizeV2Attr is an optional argument to QuantizeV2.
type QuantizeV2Attr func(optionalAttr)
@@ -20481,6 +20983,201 @@ func Sum(scope *Scope, input tf.Output, axis tf.Output, optional ...SumAttr) (ou
return op.Output(0)
}
+// EnterAttr is an optional argument to Enter.
+type EnterAttr func(optionalAttr)
+
+// EnterIsConstant sets the optional is_constant attribute to value.
+//
+// value: If true, the output is constant within the child frame.
+// If not specified, defaults to false
+func EnterIsConstant(value bool) EnterAttr {
+ return func(m optionalAttr) {
+ m["is_constant"] = value
+ }
+}
+
+// EnterParallelIterations sets the optional parallel_iterations attribute to value.
+//
+// value: The number of iterations allowed to run in parallel.
+// If not specified, defaults to 10
+func EnterParallelIterations(value int64) EnterAttr {
+ return func(m optionalAttr) {
+ m["parallel_iterations"] = value
+ }
+}
+
+// Creates or finds a child frame, and makes `data` available to the child frame.
+//
+// This op is used together with `Exit` to create loops in the graph.
+// The unique `frame_name` is used by the `Executor` to identify frames. If
+// `is_constant` is true, `output` is a constant in the child frame; otherwise
+// it may be changed in the child frame. At most `parallel_iterations` iterations
+// are run in parallel in the child frame.
+//
+// Arguments:
+// data: The tensor to be made available to the child frame.
+// frame_name: The name of the child frame.
+//
+// Returns The same tensor as `data`.
+func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"frame_name": frame_name}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Enter",
+ Input: []tf.Input{
+ data,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Add all input tensors element wise.
+//
+// Arguments:
+// inputs: Must all be the same size and shape.
+func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "AddN",
+ Input: []tf.Input{
+ tf.OutputList(inputs),
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// TryRpcAttr is an optional argument to TryRpc.
+type TryRpcAttr func(optionalAttr)
+
+// TryRpcProtocol sets the optional protocol attribute to value.
+//
+// value: RPC protocol to use. Empty string means use the default protocol.
+// Options include 'grpc'.
+// If not specified, defaults to ""
+func TryRpcProtocol(value string) TryRpcAttr {
+ return func(m optionalAttr) {
+ m["protocol"] = value
+ }
+}
+
+// TryRpcFailFast sets the optional fail_fast attribute to value.
+//
+// value: `boolean`. If `true` (default), then failures to connect
+// (i.e., the server does not immediately respond) cause an RPC failure.
+// If not specified, defaults to true
+func TryRpcFailFast(value bool) TryRpcAttr {
+ return func(m optionalAttr) {
+ m["fail_fast"] = value
+ }
+}
+
+// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value.
+//
+// value: `int`. If `0` (default), then the kernel will run the RPC
+// request and only time out if the RPC deadline passes or the session times out.
+// If this value is greater than `0`, then the op will raise an exception if
+// the RPC takes longer than `timeout_in_ms`.
+// If not specified, defaults to 0
+func TryRpcTimeoutInMs(value int64) TryRpcAttr {
+ return func(m optionalAttr) {
+ m["timeout_in_ms"] = value
+ }
+}
+
+// Perform batches of RPC requests.
+//
+// This op asynchronously performs either a single RPC request, or a batch
+// of requests. RPC requests are defined by three main parameters:
+//
+// - `address` (the host+port or BNS address of the request)
+// - `method` (the method name for the request)
+// - `request` (the serialized proto string, or vector of strings,
+// of the RPC request argument).
+//
+// For example, if you have an RPC service running on port localhost:2345,
+// and its interface is configured with the following proto declaration:
+//
+// ```
+// service MyService {
+// rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
+// }
+// };
+// ```
+//
+// then call this op with arguments:
+//
+// ```
+// address = "localhost:2345"
+// method = "MyService/MyMethod"
+// ```
+//
+// The `request` tensor is a string tensor representing serialized `MyRequestProto`
+// strings; and the output string tensor `response` will have the same shape
+// and contain (upon successful completion) corresponding serialized
+// `MyResponseProto` strings.
+//
+// For example, to send a single, empty, `MyRequestProto`, call
+// this op with `request = ""`. To send 5 **parallel** empty requests,
+// call this op with `request = ["", "", "", "", ""]`.
+//
+// More generally, one can create a batch of `MyRequestProto` serialized protos
+// from regular batched tensors using the `encode_proto` op, and convert
+// the response `MyResponseProto` serialized protos to batched tensors
+// using the `decode_proto` op.
+//
+// **NOTE** Working with serialized proto strings is faster than instantiating
+// actual proto objects in memory, so no performance degradation is expected
+// compared to writing custom kernels for this workflow.
+//
+// Unlike the standard `Rpc` op, if the connection fails or the remote worker
+// returns an error status, this op does **not** reraise the exception.
+// Instead, the `status_code` and `status_message` entry for the corresponding RPC
+// call is set with the error returned from the RPC call. The `response` tensor
+// will contain valid response values for those minibatch entries whose RPCs did
+// not fail; the rest of the entries will have empty strings.
+//
+// Arguments:
+// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
+// If this tensor has more than 1 element, then multiple parallel rpc requests
+// are sent. This argument broadcasts with `method` and `request`.
+// method: `0-D` or `1-D`. The method address on the RPC server.
+// If this tensor has more than 1 element, then multiple parallel rpc requests
+// are sent. This argument broadcasts with `address` and `request`.
+// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument.
+// If this tensor has more than 1 element, then multiple parallel rpc requests
+// are sent. This argument broadcasts with `address` and `method`.
+//
+// Returns Same shape as `request`. Serialized proto strings: the rpc responses.Same shape as `request`. Values correspond to tensorflow Status enum codes.Same shape as `request`. Values correspond to Status messages
+// returned from the RPC calls.
+func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "TryRpc",
+ Input: []tf.Input{
+ address, method, request,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
// Delete the tensor specified by its handle in the session.
//
// Arguments:
@@ -20958,76 +21655,6 @@ func RangeDataset(scope *Scope, start tf.Output, stop tf.Output, step tf.Output,
return op.Output(0)
}
-// DepthwiseConv2dNativeBackpropInputAttr is an optional argument to DepthwiseConv2dNativeBackpropInput.
-type DepthwiseConv2dNativeBackpropInputAttr func(optionalAttr)
-
-// DepthwiseConv2dNativeBackpropInputDataFormat sets the optional data_format attribute to value.
-//
-// value: Specify the data format of the input and output data. With the
-// default format "NHWC", the data is stored in the order of:
-// [batch, height, width, channels].
-// Alternatively, the format could be "NCHW", the data storage order of:
-// [batch, channels, height, width].
-// If not specified, defaults to "NHWC"
-func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dNativeBackpropInputAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// DepthwiseConv2dNativeBackpropInputDilations sets the optional dilations attribute to value.
-//
-// value: 1-D tensor of length 4. The dilation factor for each dimension of
-// `input`. If set to k > 1, there will be k-1 skipped cells between each filter
-// element on that dimension. The dimension order is determined by the value of
-// `data_format`, see above for details. Dilations in the batch and depth
-// dimensions must be 1.
-// If not specified, defaults to <i:1 i:1 i:1 i:1 >
-func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr {
- return func(m optionalAttr) {
- m["dilations"] = value
- }
-}
-
-// Computes the gradients of depthwise convolution with respect to the input.
-//
-// Arguments:
-// input_sizes: An integer vector representing the shape of `input`, based
-// on `data_format`. For example, if `data_format` is 'NHWC' then
-// `input` is a 4-D `[batch, height, width, channels]` tensor.
-// filter: 4-D with shape
-// `[filter_height, filter_width, in_channels, depthwise_multiplier]`.
-// out_backprop: 4-D with shape based on `data_format`.
-// For example, if `data_format` is 'NHWC' then
-// out_backprop shape is `[batch, out_height, out_width, out_channels]`.
-// Gradients w.r.t. the output of the convolution.
-// strides: The stride of the sliding window for each dimension of the input
-// of the convolution.
-// padding: The type of padding algorithm to use.
-//
-// Returns 4-D with shape according to `data_format`. For example, if
-// `data_format` is 'NHWC', output shape is `[batch, in_height,
-// in_width, in_channels]`. Gradient w.r.t. the input of the
-// convolution.
-func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, filter tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...DepthwiseConv2dNativeBackpropInputAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "DepthwiseConv2dNativeBackpropInput",
- Input: []tf.Input{
- input_sizes, filter, out_backprop,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Stops gradient computation.
//
// When executed in a graph, this op outputs its input tensor as-is.
@@ -21297,23 +21924,44 @@ func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Outp
return op.Output(0), op.Output(1), op.Output(2)
}
-// Forwards the input to the output.
+// Computes the sum along segments of a tensor.
//
-// This operator represents the loop termination condition used by the
-// "pivot" switches of a loop.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// Computes a tensor such that
+// \\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
+// that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`
+// need not be sorted and need not cover all values in the full
+// range of valid values.
+//
+// If the sum is empty for a given segment ID `i`, `output[i] = 0`.
+// If the given segment ID `i` is negative, the value is dropped and will not be
+// added to the sum of the segment.
+//
+// `num_segments` should equal the number of distinct segment IDs.
+//
+// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentSum.png" alt>
+// </div>
//
// Arguments:
-// input: A boolean scalar, representing the branch predicate of the Switch op.
//
-// Returns The same tensor as `input`.
-func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
+// segment_ids: A tensor whose shape is a prefix of `data.shape`.
+//
+//
+// Returns Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
+func UnsortedSegmentSum(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "LoopCond",
+ Type: "UnsortedSegmentSum",
Input: []tf.Input{
- input,
+ data, segment_ids, num_segments,
},
}
op := scope.AddOperation(opspec)
@@ -21947,40 +22595,6 @@ func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (ou
return op.Output(0)
}
-// Creates a sequence of numbers.
-//
-// This operation creates a sequence of numbers that begins at `start` and
-// extends by increments of `delta` up to but not including `limit`.
-//
-// For example:
-//
-// ```
-// # 'start' is 3
-// # 'limit' is 18
-// # 'delta' is 3
-// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
-// ```
-//
-// Arguments:
-// start: 0-D (scalar). First entry in the sequence.
-// limit: 0-D (scalar). Upper limit of sequence, exclusive.
-// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`.
-//
-// Returns 1-D.
-func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Range",
- Input: []tf.Input{
- start, limit, delta,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// DestroyResourceOpAttr is an optional argument to DestroyResourceOp.
type DestroyResourceOpAttr func(optionalAttr)
@@ -23045,156 +23659,6 @@ func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, def
return op.Output(0)
}
-// Bucketizes 'input' based on 'boundaries'.
-//
-// For example, if the inputs are
-// boundaries = [0, 10, 100]
-// input = [[-5, 10000]
-// [150, 10]
-// [5, 100]]
-//
-// then the output will be
-// output = [[0, 3]
-// [3, 2]
-// [1, 3]]
-//
-// Arguments:
-// input: Any shape of Tensor contains with int or float type.
-// boundaries: A sorted list of floats gives the boundary of the buckets.
-//
-// Returns Same shape with 'input', each value of input replaced with bucket index.
-//
-// @compatibility(numpy)
-// Equivalent to np.digitize.
-// @end_compatibility
-func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"boundaries": boundaries}
- opspec := tf.OpSpec{
- Type: "Bucketize",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Calculates gains for each feature and returns the best possible split information for the feature.
-//
-// The split information is the best threshold (bucket id), gains and left/right node contributions per node for each feature.
-//
-// It is possible that not all nodes can be split on each feature. Hence, the list of possible nodes can differ between the features. Therefore, we return `node_ids_list` for each feature, containing the list of nodes that this feature can be used to split.
-//
-// In this manner, the output is the best split per features and per node, so that it needs to be combined later to produce the best split for each node (among all possible features).
-//
-// The length of output lists are all of the same length, `num_features`.
-// The output shapes are compatible in a way that the first dimension of all tensors of all lists are the same and equal to the number of possible split nodes for each feature.
-//
-// Arguments:
-// node_id_range: A Rank 1 tensor (shape=[2]) to specify the range [first, last) of node ids to process within `stats_summary_list`. The nodes are iterated between the two nodes specified by the tensor, as like `for node_id in range(node_id_range[0], node_id_range[1])` (Note that the last index node_id_range[1] is exclusive).
-// stats_summary_list: A list of Rank 3 tensor (#shape=[max_splits, bucket, 2]) for accumulated stats summary (gradient/hessian) per node per buckets for each feature. The first dimension of the tensor is the maximum number of splits, and thus not all elements of it will be used, but only the indexes specified by node_ids will be used.
-// l1: l1 regularization factor on leaf weights, per instance based.
-// l2: l2 regularization factor on leaf weights, per instance based.
-// tree_complexity: adjustment to the gain, per leaf based.
-// min_node_weight: mininum avg of hessians in a node before required for the node to be considered for splitting.
-// max_splits: the number of nodes that can be split in the whole tree. Used as a dimension of output tensors.
-//
-// Returns An output list of Rank 1 tensors indicating possible split node ids for each feature. The length of the list is num_features, but each tensor has different size as each feature provides different possible nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the best gains for each feature to split for certain nodes. See above for details like shapes and sizes.An output list of Rank 1 tensors indicating the bucket id to compare with (as a threshold) for split in each node. See above for details like shapes and sizes.A list of Rank 2 tensors indicating the contribution of the left nodes when branching from parent nodes (given by the tensor element in the output node_ids_list) to the left direction by the given threshold for each feature. This value will be used to make the left node value by adding to the parent node value. Second dimension size is 1 for 1-dimensional logits, but would be larger for multi-class problems. See above for details like shapes and sizes.A list of Rank 2 tensors, with the same shape/conditions as left_node_contribs_list, but just that the value is for the right node.
-func BoostedTreesCalculateBestGainsPerFeature(scope *Scope, node_id_range tf.Output, stats_summary_list []tf.Output, l1 tf.Output, l2 tf.Output, tree_complexity tf.Output, min_node_weight tf.Output, max_splits int64) (node_ids_list []tf.Output, gains_list []tf.Output, thresholds_list []tf.Output, left_node_contribs_list []tf.Output, right_node_contribs_list []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"max_splits": max_splits}
- opspec := tf.OpSpec{
- Type: "BoostedTreesCalculateBestGainsPerFeature",
- Input: []tf.Input{
- node_id_range, tf.OutputList(stats_summary_list), l1, l2, tree_complexity, min_node_weight,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if node_ids_list, idx, err = makeOutputList(op, idx, "node_ids_list"); err != nil {
- scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
- return
- }
- if gains_list, idx, err = makeOutputList(op, idx, "gains_list"); err != nil {
- scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
- return
- }
- if thresholds_list, idx, err = makeOutputList(op, idx, "thresholds_list"); err != nil {
- scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
- return
- }
- if left_node_contribs_list, idx, err = makeOutputList(op, idx, "left_node_contribs_list"); err != nil {
- scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
- return
- }
- if right_node_contribs_list, idx, err = makeOutputList(op, idx, "right_node_contribs_list"); err != nil {
- scope.UpdateErr("BoostedTreesCalculateBestGainsPerFeature", err)
- return
- }
- return node_ids_list, gains_list, thresholds_list, left_node_contribs_list, right_node_contribs_list
-}
-
-// EncodePngAttr is an optional argument to EncodePng.
-type EncodePngAttr func(optionalAttr)
-
-// EncodePngCompression sets the optional compression attribute to value.
-//
-// value: Compression level.
-// If not specified, defaults to -1
-func EncodePngCompression(value int64) EncodePngAttr {
- return func(m optionalAttr) {
- m["compression"] = value
- }
-}
-
-// PNG-encode an image.
-//
-// `image` is a 3-D uint8 or uint16 Tensor of shape `[height, width, channels]`
-// where `channels` is:
-//
-// * 1: for grayscale.
-// * 2: for grayscale + alpha.
-// * 3: for RGB.
-// * 4: for RGBA.
-//
-// The ZLIB compression level, `compression`, can be -1 for the PNG-encoder
-// default or a value from 0 to 9. 9 is the highest compression level, generating
-// the smallest output, but is slower.
-//
-// Arguments:
-// image: 3-D with shape `[height, width, channels]`.
-//
-// Returns 0-D. PNG-encoded image.
-func EncodePng(scope *Scope, image tf.Output, optional ...EncodePngAttr) (contents tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "EncodePng",
- Input: []tf.Input{
- image,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Updates the table to associates keys with values.
//
// The tensor `keys` must be of the same type as the keys of the table.
@@ -23988,6 +24452,31 @@ func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr
return op.Output(0)
}
+// Read an element from the TensorArray into output `value`.
+//
+// Arguments:
+// handle: The handle to a TensorArray.
+//
+// flow_in: A float scalar that enforces proper chaining of operations.
+// dtype: The type of the elem that is returned.
+//
+// Returns The tensor that is read from the TensorArray.
+func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype}
+ opspec := tf.OpSpec{
+ Type: "TensorArrayReadV3",
+ Input: []tf.Input{
+ handle, index, flow_in,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the gradient for the tanh of `x` wrt its input.
//
// Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy`
@@ -26662,6 +27151,260 @@ func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) {
return op.Output(0)
}
+// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler.
+type LearnedUnigramCandidateSamplerAttr func(optionalAttr)
+
+// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value.
+//
+// value: An second seed to avoid seed collision.
+// If not specified, defaults to 0
+func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a learned unigram distribution.
+//
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
+//
+// For each batch, this op picks a single set of sampled candidate labels.
+//
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
+//
+// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to randomly sample.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+// range_max: The sampler will sample integers from the interval [0, range_max).
+//
+// Returns A vector of length num_sampled, in which each element is
+// the ID of a sampled candidate.A batch_size * num_true matrix, representing
+// the number of times each candidate is expected to occur in a batch
+// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
+// candidate representing the number of times the candidate is expected
+// to occur in a batch of sampled candidates. If unique=true, then this is a
+// probability.
+func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "LearnedUnigramCandidateSampler",
+ Input: []tf.Input{
+ true_classes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// SerializeSparseAttr is an optional argument to SerializeSparse.
+type SerializeSparseAttr func(optionalAttr)
+
+// SerializeSparseOutType sets the optional out_type attribute to value.
+//
+// value: The `dtype` to use for serialization; the supported types are `string`
+// (default) and `variant`.
+// If not specified, defaults to DT_STRING
+func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Serialize a `SparseTensor` into a `[3]` `Tensor` object.
+//
+// Arguments:
+// sparse_indices: 2-D. The `indices` of the `SparseTensor`.
+// sparse_values: 1-D. The `values` of the `SparseTensor`.
+// sparse_shape: 1-D. The `shape` of the `SparseTensor`.
+func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "SerializeSparse",
+ Input: []tf.Input{
+ sparse_indices, sparse_values, sparse_shape,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2.
+type RandomShuffleQueueV2Attr func(optionalAttr)
+
+// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value.
+//
+// value: The shape of each component in a value. The length of this attr must
+// be either 0 or the same as the length of component_types. If the length of
+// this attr is 0, the shapes of queue elements are not constrained, and
+// only one element may be dequeued at a time.
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["shapes"] = value
+ }
+}
+
+// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value.
+//
+// value: The upper bound on the number of elements in this queue.
+// Negative numbers mean no limit.
+// If not specified, defaults to -1
+func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["capacity"] = value
+ }
+}
+
+// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value.
+//
+// value: Dequeue will block unless there would be this
+// many elements after the dequeue or the queue is closed. This
+// ensures a minimum level of mixing of elements.
+// If not specified, defaults to 0
+func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["min_after_dequeue"] = value
+ }
+}
+
+// RandomShuffleQueueV2Seed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 is set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, a random seed is used.
+// If not specified, defaults to 0
+func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// RandomShuffleQueueV2Container sets the optional container attribute to value.
+//
+// value: If non-empty, this queue is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this queue will be shared under the given name
+// across multiple sessions.
+// If not specified, defaults to ""
+func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// A queue that randomizes the order of elements.
+//
+// Arguments:
+// component_types: The type of each component in a value.
+//
+// Returns The handle to the queue.
+func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"component_types": component_types}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomShuffleQueueV2",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Draw bounding boxes on a batch of images.
+//
+// Outputs a copy of `images` but draws on top of the pixels zero or more bounding
+// boxes specified by the locations in `boxes`. The coordinates of the each
+// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The
+// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
+// height of the underlying image.
+//
+// For example, if an image is 100 x 200 pixels (height x width) and the bounding
+// box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of
+// the bounding box will be `(40, 10)` to `(180, 50)` (in (x,y) coordinates).
+//
+// Parts of the bounding box may fall outside the image.
+//
+// Arguments:
+// images: 4-D with shape `[batch, height, width, depth]`. A batch of images.
+// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding
+// boxes.
+//
+// Returns 4-D with the same shape as `images`. The batch of input images with
+// bounding boxes drawn on the images.
+func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "DrawBoundingBoxes",
+ Input: []tf.Input{
+ images, boxes,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Gets the next output from the given iterator.
//
// This operation is a synchronous version IteratorGetNext. It should only be used
@@ -27420,178 +28163,6 @@ func FakeParam(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Outpu
return op.Output(0)
}
-// EncodeProtoAttr is an optional argument to EncodeProto.
-type EncodeProtoAttr func(optionalAttr)
-
-// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value.
-// If not specified, defaults to "local://"
-func EncodeProtoDescriptorSource(value string) EncodeProtoAttr {
- return func(m optionalAttr) {
- m["descriptor_source"] = value
- }
-}
-
-// The op serializes protobuf messages provided in the input tensors.
-//
-// The types of the tensors in `values` must match the schema for the
-// fields specified in `field_names`. All the tensors in `values` must
-// have a common shape prefix, *batch_shape*.
-//
-// The `sizes` tensor specifies repeat counts for each field. The repeat
-// count (last dimension) of a each tensor in `values` must be greater
-// than or equal to corresponding repeat count in `sizes`.
-//
-// A `message_type` name must be provided to give context for the field
-// names. The actual message descriptor can be looked up either in the
-// linked-in descriptor pool or a filename provided by the caller using
-// the `descriptor_source` attribute.
-//
-// The `descriptor_source` attribute selects a source of protocol
-// descriptors to consult when looking up `message_type`. This may be a
-// filename containing a serialized `FileDescriptorSet` message,
-// or the special value `local://`, in which case only descriptors linked
-// into the code will be searched; the filename can be on any filesystem
-// accessible to TensorFlow.
-//
-// You can build a `descriptor_source` file using the `--descriptor_set_out`
-// and `--include_imports` options to the protocol compiler `protoc`.
-//
-// The `local://` database only covers descriptors linked into the
-// code via C++ libraries, not Python imports. You can link in a proto descriptor
-// by creating a cc_library target with alwayslink=1.
-//
-// There are a few special cases in the value mapping:
-//
-// Submessage and group fields must be pre-serialized as TensorFlow strings.
-//
-// TensorFlow lacks support for unsigned int64s, so they must be
-// represented as `tf.int64` with the same twos-complement bit pattern
-// (the obvious way).
-//
-// Unsigned int32 values can be represented exactly with `tf.int64`, or
-// with sign wrapping if the input is of type `tf.int32`.
-//
-// Arguments:
-// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`.
-// values: List of tensors containing values for the corresponding field.
-// field_names: List of strings containing proto field names.
-// message_type: Name of the proto message type to decode.
-//
-// Returns Tensor of serialized protos with shape `batch_shape`.
-func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "EncodeProto",
- Input: []tf.Input{
- sizes, tf.OutputList(values),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Creates a TensorArray for storing the gradients of values in the given handle.
-//
-// If the given TensorArray gradient already exists, returns a reference to it.
-//
-// Locks the size of the original TensorArray by disabling its dynamic size flag.
-//
-// **A note about the input flow_in:**
-//
-// The handle flow_in forces the execution of the gradient lookup to occur
-// only after certain other operations have occurred. For example, when
-// the forward TensorArray is dynamically sized, writes to this TensorArray
-// may resize the object. The gradient TensorArray is statically sized based
-// on the size of the forward TensorArray when this operation executes.
-// Furthermore, the size of the forward TensorArray is frozen by this call.
-// As a result, the flow is used to ensure that the call to generate the gradient
-// TensorArray only happens after all writes are executed.
-//
-// In the case of dynamically sized TensorArrays, gradient computation should
-// only be performed on read operations that have themselves been chained via
-// flow to occur only after all writes have executed. That way the final size
-// of the forward TensorArray is known when this operation is called.
-//
-// **A note about the source attribute:**
-//
-// TensorArray gradient calls use an accumulator TensorArray object. If
-// multiple gradients are calculated and run in the same session, the multiple
-// gradient nodes may accidentally flow through the same accumulator TensorArray.
-// This double counts and generally breaks the TensorArray gradient flow.
-//
-// The solution is to identify which gradient call this particular
-// TensorArray gradient is being called in. This is performed by identifying
-// a unique string (e.g. "gradients", "gradients_1", ...) from the input
-// gradient Tensor's name. This string is used as a suffix when creating
-// the TensorArray gradient object here (the attribute `source`).
-//
-// The attribute `source` is added as a suffix to the forward TensorArray's
-// name when performing the creation / lookup, so that each separate gradient
-// calculation gets its own TensorArray accumulator.
-//
-// Arguments:
-// handle: The handle to the forward TensorArray.
-// flow_in: A float scalar that enforces proper chaining of operations.
-// source: The gradient source string, used to decide which gradient TensorArray
-// to return.
-func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"source": source}
- opspec := tf.OpSpec{
- Type: "TensorArrayGradV3",
- Input: []tf.Input{
- handle, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
-// Creates a dataset that splits a SparseTensor into elements row-wise.
-func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseTensorSliceDataset",
- Input: []tf.Input{
- indices, values, dense_shape,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns x / y element-wise for real types.
-//
-// If `x` and `y` are reals, this will return the floating-point division.
-//
-// *NOTE*: `Div` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "RealDiv",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Adds v into specified rows of x.
//
// Computes y = x; y[i, :] += v; return y.
@@ -27887,6 +28458,255 @@ func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...Sta
return op.Output(0)
}
+// StringSplitV2Attr is an optional argument to StringSplitV2.
+type StringSplitV2Attr func(optionalAttr)
+
+// StringSplitV2Maxsplit sets the optional maxsplit attribute to value.
+//
+// value: An `int`. If `maxsplit > 0`, limit of the split of the result.
+// If not specified, defaults to -1
+func StringSplitV2Maxsplit(value int64) StringSplitV2Attr {
+ return func(m optionalAttr) {
+ m["maxsplit"] = value
+ }
+}
+
+// Split elements of `source` based on `sep` into a `SparseTensor`.
+//
+// Let N be the size of source (typically N will be the batch size). Split each
+// element of `source` based on `sep` and return a `SparseTensor`
+// containing the split tokens. Empty tokens are ignored.
+//
+// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c',
+// then the output will be
+// ```
+// st.indices = [0, 0;
+// 0, 1;
+// 1, 0;
+// 1, 1;
+// 1, 2]
+// st.shape = [2, 3]
+// st.values = ['hello', 'world', 'a', 'b', 'c']
+// ```
+//
+// If `sep` is given, consecutive delimiters are not grouped together and are
+// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
+// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
+// string, consecutive whitespace are regarded as a single separator, and the
+// result will contain no empty strings at the startor end if the string has
+// leading or trailing whitespace.
+//
+// Note that the above mentioned behavior matches python's str.split.
+//
+// Arguments:
+// input: `1-D` string `Tensor`, the strings to split.
+// sep: `0-D` string `Tensor`, the delimiter character.
+func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StringSplitV2",
+ Input: []tf.Input{
+ input, sep,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// Computes softsign: `features / (abs(features) + 1)`.
+func Softsign(scope *Scope, features tf.Output) (activations tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Softsign",
+ Input: []tf.Input{
+ features,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// EncodeProtoAttr is an optional argument to EncodeProto.
+type EncodeProtoAttr func(optionalAttr)
+
+// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value.
+// If not specified, defaults to "local://"
+func EncodeProtoDescriptorSource(value string) EncodeProtoAttr {
+ return func(m optionalAttr) {
+ m["descriptor_source"] = value
+ }
+}
+
+// The op serializes protobuf messages provided in the input tensors.
+//
+// The types of the tensors in `values` must match the schema for the
+// fields specified in `field_names`. All the tensors in `values` must
+// have a common shape prefix, *batch_shape*.
+//
+// The `sizes` tensor specifies repeat counts for each field. The repeat
+// count (last dimension) of a each tensor in `values` must be greater
+// than or equal to corresponding repeat count in `sizes`.
+//
+// A `message_type` name must be provided to give context for the field
+// names. The actual message descriptor can be looked up either in the
+// linked-in descriptor pool or a filename provided by the caller using
+// the `descriptor_source` attribute.
+//
+// The `descriptor_source` attribute selects a source of protocol
+// descriptors to consult when looking up `message_type`. This may be a
+// filename containing a serialized `FileDescriptorSet` message,
+// or the special value `local://`, in which case only descriptors linked
+// into the code will be searched; the filename can be on any filesystem
+// accessible to TensorFlow.
+//
+// You can build a `descriptor_source` file using the `--descriptor_set_out`
+// and `--include_imports` options to the protocol compiler `protoc`.
+//
+// The `local://` database only covers descriptors linked into the
+// code via C++ libraries, not Python imports. You can link in a proto descriptor
+// by creating a cc_library target with alwayslink=1.
+//
+// There are a few special cases in the value mapping:
+//
+// Submessage and group fields must be pre-serialized as TensorFlow strings.
+//
+// TensorFlow lacks support for unsigned int64s, so they must be
+// represented as `tf.int64` with the same twos-complement bit pattern
+// (the obvious way).
+//
+// Unsigned int32 values can be represented exactly with `tf.int64`, or
+// with sign wrapping if the input is of type `tf.int32`.
+//
+// Arguments:
+// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`.
+// values: List of tensors containing values for the corresponding field.
+// field_names: List of strings containing proto field names.
+// message_type: Name of the proto message type to decode.
+//
+// Returns Tensor of serialized protos with shape `batch_shape`.
+func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "EncodeProto",
+ Input: []tf.Input{
+ sizes, tf.OutputList(values),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Creates a TensorArray for storing the gradients of values in the given handle.
+//
+// If the given TensorArray gradient already exists, returns a reference to it.
+//
+// Locks the size of the original TensorArray by disabling its dynamic size flag.
+//
+// **A note about the input flow_in:**
+//
+// The handle flow_in forces the execution of the gradient lookup to occur
+// only after certain other operations have occurred. For example, when
+// the forward TensorArray is dynamically sized, writes to this TensorArray
+// may resize the object. The gradient TensorArray is statically sized based
+// on the size of the forward TensorArray when this operation executes.
+// Furthermore, the size of the forward TensorArray is frozen by this call.
+// As a result, the flow is used to ensure that the call to generate the gradient
+// TensorArray only happens after all writes are executed.
+//
+// In the case of dynamically sized TensorArrays, gradient computation should
+// only be performed on read operations that have themselves been chained via
+// flow to occur only after all writes have executed. That way the final size
+// of the forward TensorArray is known when this operation is called.
+//
+// **A note about the source attribute:**
+//
+// TensorArray gradient calls use an accumulator TensorArray object. If
+// multiple gradients are calculated and run in the same session, the multiple
+// gradient nodes may accidentally flow through the same accumulator TensorArray.
+// This double counts and generally breaks the TensorArray gradient flow.
+//
+// The solution is to identify which gradient call this particular
+// TensorArray gradient is being called in. This is performed by identifying
+// a unique string (e.g. "gradients", "gradients_1", ...) from the input
+// gradient Tensor's name. This string is used as a suffix when creating
+// the TensorArray gradient object here (the attribute `source`).
+//
+// The attribute `source` is added as a suffix to the forward TensorArray's
+// name when performing the creation / lookup, so that each separate gradient
+// calculation gets its own TensorArray accumulator.
+//
+// Arguments:
+// handle: The handle to the forward TensorArray.
+// flow_in: A float scalar that enforces proper chaining of operations.
+// source: The gradient source string, used to decide which gradient TensorArray
+// to return.
+func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"source": source}
+ opspec := tf.OpSpec{
+ Type: "TensorArrayGradV3",
+ Input: []tf.Input{
+ handle, flow_in,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
+// Creates a dataset that splits a SparseTensor into elements row-wise.
+func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseTensorSliceDataset",
+ Input: []tf.Input{
+ indices, values, dense_shape,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns x / y element-wise for real types.
+//
+// If `x` and `y` are reals, this will return the floating-point division.
+//
+// *NOTE*: `Div` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "RealDiv",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Creates a dataset that concatenates `input_dataset` with `another_dataset`.
func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
if scope.Err() != nil {
@@ -30813,260 +31633,6 @@ func TensorArraySplitV3(scope *Scope, handle tf.Output, value tf.Output, lengths
return op.Output(0)
}
-// SerializeSparseAttr is an optional argument to SerializeSparse.
-type SerializeSparseAttr func(optionalAttr)
-
-// SerializeSparseOutType sets the optional out_type attribute to value.
-//
-// value: The `dtype` to use for serialization; the supported types are `string`
-// (default) and `variant`.
-// If not specified, defaults to DT_STRING
-func SerializeSparseOutType(value tf.DataType) SerializeSparseAttr {
- return func(m optionalAttr) {
- m["out_type"] = value
- }
-}
-
-// Serialize a `SparseTensor` into a `[3]` `Tensor` object.
-//
-// Arguments:
-// sparse_indices: 2-D. The `indices` of the `SparseTensor`.
-// sparse_values: 1-D. The `values` of the `SparseTensor`.
-// sparse_shape: 1-D. The `shape` of the `SparseTensor`.
-func SerializeSparse(scope *Scope, sparse_indices tf.Output, sparse_values tf.Output, sparse_shape tf.Output, optional ...SerializeSparseAttr) (serialized_sparse tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "SerializeSparse",
- Input: []tf.Input{
- sparse_indices, sparse_values, sparse_shape,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2.
-type RandomShuffleQueueV2Attr func(optionalAttr)
-
-// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value.
-//
-// value: The shape of each component in a value. The length of this attr must
-// be either 0 or the same as the length of component_types. If the length of
-// this attr is 0, the shapes of queue elements are not constrained, and
-// only one element may be dequeued at a time.
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["shapes"] = value
- }
-}
-
-// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value.
-//
-// value: The upper bound on the number of elements in this queue.
-// Negative numbers mean no limit.
-// If not specified, defaults to -1
-func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["capacity"] = value
- }
-}
-
-// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value.
-//
-// value: Dequeue will block unless there would be this
-// many elements after the dequeue or the queue is closed. This
-// ensures a minimum level of mixing of elements.
-// If not specified, defaults to 0
-func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["min_after_dequeue"] = value
- }
-}
-
-// RandomShuffleQueueV2Seed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 is set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, a random seed is used.
-// If not specified, defaults to 0
-func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// RandomShuffleQueueV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this queue is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this queue will be shared under the given name
-// across multiple sessions.
-// If not specified, defaults to ""
-func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// A queue that randomizes the order of elements.
-//
-// Arguments:
-// component_types: The type of each component in a value.
-//
-// Returns The handle to the queue.
-func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"component_types": component_types}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RandomShuffleQueueV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Draw bounding boxes on a batch of images.
-//
-// Outputs a copy of `images` but draws on top of the pixels zero or more bounding
-// boxes specified by the locations in `boxes`. The coordinates of the each
-// bounding box in `boxes` are encoded as `[y_min, x_min, y_max, x_max]`. The
-// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
-// height of the underlying image.
-//
-// For example, if an image is 100 x 200 pixels (height x width) and the bounding
-// box is `[0.1, 0.2, 0.5, 0.9]`, the upper-left and bottom-right coordinates of
-// the bounding box will be `(40, 10)` to `(180, 50)` (in (x,y) coordinates).
-//
-// Parts of the bounding box may fall outside the image.
-//
-// Arguments:
-// images: 4-D with shape `[batch, height, width, depth]`. A batch of images.
-// boxes: 3-D with shape `[batch, num_bounding_boxes, 4]` containing bounding
-// boxes.
-//
-// Returns 4-D with the same shape as `images`. The batch of input images with
-// bounding boxes drawn on the images.
-func DrawBoundingBoxes(scope *Scope, images tf.Output, boxes tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "DrawBoundingBoxes",
- Input: []tf.Input{
- images, boxes,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// LearnedUnigramCandidateSamplerAttr is an optional argument to LearnedUnigramCandidateSampler.
-type LearnedUnigramCandidateSamplerAttr func(optionalAttr)
-
-// LearnedUnigramCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func LearnedUnigramCandidateSamplerSeed(value int64) LearnedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// LearnedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func LearnedUnigramCandidateSamplerSeed2(value int64) LearnedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a learned unigram distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to randomly sample.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-// range_max: The sampler will sample integers from the interval [0, range_max).
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func LearnedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LearnedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "LearnedUnigramCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// Computes gradients for the scaled exponential linear (Selu) operation.
//
// Arguments:
@@ -32425,79 +32991,6 @@ func CudnnRNNParamsToCanonical(scope *Scope, num_layers tf.Output, num_units tf.
return weights, biases
}
-// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler.
-type UniformCandidateSamplerAttr func(optionalAttr)
-
-// UniformCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a uniform distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to randomly sample.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-// range_max: The sampler will sample integers from the interval [0, range_max).
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "UniformCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// CTCLossAttr is an optional argument to CTCLoss.
type CTCLossAttr func(optionalAttr)
@@ -32648,480 +33141,3 @@ func Switch(scope *Scope, data tf.Output, pred tf.Output) (output_false tf.Outpu
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1)
}
-
-// Add all input tensors element wise.
-//
-// Arguments:
-// inputs: Must all be the same size and shape.
-func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "AddN",
- Input: []tf.Input{
- tf.OutputList(inputs),
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// TryRpcAttr is an optional argument to TryRpc.
-type TryRpcAttr func(optionalAttr)
-
-// TryRpcProtocol sets the optional protocol attribute to value.
-//
-// value: RPC protocol to use. Empty string means use the default protocol.
-// Options include 'grpc'.
-// If not specified, defaults to ""
-func TryRpcProtocol(value string) TryRpcAttr {
- return func(m optionalAttr) {
- m["protocol"] = value
- }
-}
-
-// TryRpcFailFast sets the optional fail_fast attribute to value.
-//
-// value: `boolean`. If `true` (default), then failures to connect
-// (i.e., the server does not immediately respond) cause an RPC failure.
-// If not specified, defaults to true
-func TryRpcFailFast(value bool) TryRpcAttr {
- return func(m optionalAttr) {
- m["fail_fast"] = value
- }
-}
-
-// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value.
-//
-// value: `int`. If `0` (default), then the kernel will run the RPC
-// request and only time out if the RPC deadline passes or the session times out.
-// If this value is greater than `0`, then the op will raise an exception if
-// the RPC takes longer than `timeout_in_ms`.
-// If not specified, defaults to 0
-func TryRpcTimeoutInMs(value int64) TryRpcAttr {
- return func(m optionalAttr) {
- m["timeout_in_ms"] = value
- }
-}
-
-// Perform batches of RPC requests.
-//
-// This op asynchronously performs either a single RPC request, or a batch
-// of requests. RPC requests are defined by three main parameters:
-//
-// - `address` (the host+port or BNS address of the request)
-// - `method` (the method name for the request)
-// - `request` (the serialized proto string, or vector of strings,
-// of the RPC request argument).
-//
-// For example, if you have an RPC service running on port localhost:2345,
-// and its interface is configured with the following proto declaration:
-//
-// ```
-// service MyService {
-// rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
-// }
-// };
-// ```
-//
-// then call this op with arguments:
-//
-// ```
-// address = "localhost:2345"
-// method = "MyService/MyMethod"
-// ```
-//
-// The `request` tensor is a string tensor representing serialized `MyRequestProto`
-// strings; and the output string tensor `response` will have the same shape
-// and contain (upon successful completion) corresponding serialized
-// `MyResponseProto` strings.
-//
-// For example, to send a single, empty, `MyRequestProto`, call
-// this op with `request = ""`. To send 5 **parallel** empty requests,
-// call this op with `request = ["", "", "", "", ""]`.
-//
-// More generally, one can create a batch of `MyRequestProto` serialized protos
-// from regular batched tensors using the `encode_proto` op, and convert
-// the response `MyResponseProto` serialized protos to batched tensors
-// using the `decode_proto` op.
-//
-// **NOTE** Working with serialized proto strings is faster than instantiating
-// actual proto objects in memory, so no performance degradation is expected
-// compared to writing custom kernels for this workflow.
-//
-// Unlike the standard `Rpc` op, if the connection fails or the remote worker
-// returns an error status, this op does **not** reraise the exception.
-// Instead, the `status_code` and `status_message` entry for the corresponding RPC
-// call is set with the error returned from the RPC call. The `response` tensor
-// will contain valid response values for those minibatch entries whose RPCs did
-// not fail; the rest of the entries will have empty strings.
-//
-// Arguments:
-// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
-// If this tensor has more than 1 element, then multiple parallel rpc requests
-// are sent. This argument broadcasts with `method` and `request`.
-// method: `0-D` or `1-D`. The method address on the RPC server.
-// If this tensor has more than 1 element, then multiple parallel rpc requests
-// are sent. This argument broadcasts with `address` and `request`.
-// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument.
-// If this tensor has more than 1 element, then multiple parallel rpc requests
-// are sent. This argument broadcasts with `address` and `method`.
-//
-// Returns Same shape as `request`. Serialized proto strings: the rpc responses.Same shape as `request`. Values correspond to tensorflow Status enum codes.Same shape as `request`. Values correspond to Status messages
-// returned from the RPC calls.
-func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "TryRpc",
- Input: []tf.Input{
- address, method, request,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
-// EnterAttr is an optional argument to Enter.
-type EnterAttr func(optionalAttr)
-
-// EnterIsConstant sets the optional is_constant attribute to value.
-//
-// value: If true, the output is constant within the child frame.
-// If not specified, defaults to false
-func EnterIsConstant(value bool) EnterAttr {
- return func(m optionalAttr) {
- m["is_constant"] = value
- }
-}
-
-// EnterParallelIterations sets the optional parallel_iterations attribute to value.
-//
-// value: The number of iterations allowed to run in parallel.
-// If not specified, defaults to 10
-func EnterParallelIterations(value int64) EnterAttr {
- return func(m optionalAttr) {
- m["parallel_iterations"] = value
- }
-}
-
-// Creates or finds a child frame, and makes `data` available to the child frame.
-//
-// This op is used together with `Exit` to create loops in the graph.
-// The unique `frame_name` is used by the `Executor` to identify frames. If
-// `is_constant` is true, `output` is a constant in the child frame; otherwise
-// it may be changed in the child frame. At most `parallel_iterations` iterations
-// are run in parallel in the child frame.
-//
-// Arguments:
-// data: The tensor to be made available to the child frame.
-// frame_name: The name of the child frame.
-//
-// Returns The same tensor as `data`.
-func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"frame_name": frame_name}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Enter",
- Input: []tf.Input{
- data,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Produce a string tensor that encodes the state of a Reader.
-//
-// Not all Readers support being serialized, so this can produce an
-// Unimplemented error.
-//
-// Arguments:
-// reader_handle: Handle to a Reader.
-func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ReaderSerializeStateV2",
- Input: []tf.Input{
- reader_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Exits the current frame to its parent frame.
-//
-// Exit makes its input `data` available to the parent frame.
-//
-// Arguments:
-// data: The tensor to be made available to the parent frame.
-//
-// Returns The same tensor as `data`.
-func Exit(scope *Scope, data tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Exit",
- Input: []tf.Input{
- data,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns a copy of the input tensor.
-func Snapshot(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Snapshot",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns a tensor of zeros with the same shape and type as x.
-//
-// Arguments:
-// x: a tensor of type T.
-//
-// Returns a tensor of the same shape and type as x but filled with zeros.
-func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ZerosLike",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// AbortAttr is an optional argument to Abort.
-type AbortAttr func(optionalAttr)
-
-// AbortErrorMsg sets the optional error_msg attribute to value.
-//
-// value: A string which is the message associated with the exception.
-// If not specified, defaults to ""
-func AbortErrorMsg(value string) AbortAttr {
- return func(m optionalAttr) {
- m["error_msg"] = value
- }
-}
-
-// AbortExitWithoutError sets the optional exit_without_error attribute to value.
-// If not specified, defaults to false
-func AbortExitWithoutError(value bool) AbortAttr {
- return func(m optionalAttr) {
- m["exit_without_error"] = value
- }
-}
-
-// Raise a exception to abort the process when called.
-//
-// If exit_without_error is true, the process will exit normally,
-// otherwise it will exit with a SIGABORT signal.
-//
-// Returns nothing but an exception.
-//
-// Returns the created operation.
-func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Abort",
-
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// FixedUnigramCandidateSamplerAttr is an optional argument to FixedUnigramCandidateSampler.
-type FixedUnigramCandidateSamplerAttr func(optionalAttr)
-
-// FixedUnigramCandidateSamplerVocabFile sets the optional vocab_file attribute to value.
-//
-// value: Each valid line in this file (which should have a CSV-like format)
-// corresponds to a valid word ID. IDs are in sequential order, starting from
-// num_reserved_ids. The last entry in each line is expected to be a value
-// corresponding to the count or relative probability. Exactly one of vocab_file
-// and unigrams needs to be passed to this op.
-// If not specified, defaults to ""
-func FixedUnigramCandidateSamplerVocabFile(value string) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["vocab_file"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerDistortion sets the optional distortion attribute to value.
-//
-// value: The distortion is used to skew the unigram probability distribution.
-// Each weight is first raised to the distortion's power before adding to the
-// internal unigram distribution. As a result, distortion = 1.0 gives regular
-// unigram sampling (as defined by the vocab file), and distortion = 0.0 gives
-// a uniform distribution.
-// If not specified, defaults to 1
-func FixedUnigramCandidateSamplerDistortion(value float32) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["distortion"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerNumReservedIds sets the optional num_reserved_ids attribute to value.
-//
-// value: Optionally some reserved IDs can be added in the range [0,
-// ..., num_reserved_ids) by the users. One use case is that a special unknown
-// word token is used as ID 0. These IDs will have a sampling probability of 0.
-// If not specified, defaults to 0
-func FixedUnigramCandidateSamplerNumReservedIds(value int64) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["num_reserved_ids"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerNumShards sets the optional num_shards attribute to value.
-//
-// value: A sampler can be used to sample from a subset of the original range
-// in order to speed up the whole computation through parallelism. This parameter
-// (together with 'shard') indicates the number of partitions that are being
-// used in the overall computation.
-// If not specified, defaults to 1
-//
-// REQUIRES: value >= 1
-func FixedUnigramCandidateSamplerNumShards(value int64) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["num_shards"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerShard sets the optional shard attribute to value.
-//
-// value: A sampler can be used to sample from a subset of the original range
-// in order to speed up the whole computation through parallelism. This parameter
-// (together with 'num_shards') indicates the particular partition number of a
-// sampler op, when partitioning is being used.
-// If not specified, defaults to 0
-//
-// REQUIRES: value >= 0
-func FixedUnigramCandidateSamplerShard(value int64) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["shard"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerUnigrams sets the optional unigrams attribute to value.
-//
-// value: A list of unigram counts or probabilities, one per ID in sequential
-// order. Exactly one of vocab_file and unigrams should be passed to this op.
-// If not specified, defaults to <>
-func FixedUnigramCandidateSamplerUnigrams(value []float32) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["unigrams"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func FixedUnigramCandidateSamplerSeed(value int64) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// FixedUnigramCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func FixedUnigramCandidateSamplerSeed2(value int64) FixedUnigramCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a learned unigram distribution.
-//
-// A unigram sampler could use a fixed unigram distribution read from a
-// file or passed in as an in-memory array instead of building up the distribution
-// from data on the fly. There is also an option to skew the distribution by
-// applying a distortion power to the weights.
-//
-// The vocabulary file should be in CSV-like format, with the last field
-// being the weight associated with the word.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to randomly sample.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-// range_max: The sampler will sample integers from the interval [0, range_max).
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...FixedUnigramCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "FixedUnigramCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index 9fc6969c20..6b3e305e5d 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.11.0-rc2</version>
+ <version>1.11.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index 68712082e1..f130515934 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.11.0-rc2</version>
+ <version>1.11.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index f031173c99..67ecc2d597 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.11.0-rc2</version>
+ <version>1.11.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index 2cac27990e..8ba859da01 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.11.0-rc2</version>
+ <version>1.11.0</version>
<packaging>pom</packaging>
<url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index 8a93091276..dcd654d713 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.11.0-rc2</version>
+ <version>1.11.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>proto</artifactId>
diff --git a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
index 014bd8d212..45214f834c 100644
--- a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
+++ b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
@@ -6,7 +6,7 @@
<groupId>org.tensorflow</groupId>
<artifactId>spark-tensorflow-connector_2.11</artifactId>
<packaging>jar</packaging>
- <version>1.11.0-rc2</version>
+ <version>1.11.0</version>
<name>spark-tensorflow-connector</name>
<url>https://www.tensorflow.org</url>
<description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
diff --git a/tensorflow/java/maven/tensorflow-hadoop/pom.xml b/tensorflow/java/maven/tensorflow-hadoop/pom.xml
index d07c5fcd98..a8669ee72b 100644
--- a/tensorflow/java/maven/tensorflow-hadoop/pom.xml
+++ b/tensorflow/java/maven/tensorflow-hadoop/pom.xml
@@ -5,7 +5,7 @@
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-hadoop</artifactId>
<packaging>jar</packaging>
- <version>1.11.0-rc2</version>
+ <version>1.11.0</version>
<name>tensorflow-hadoop</name>
<url>https://www.tensorflow.org</url>
<description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index af0c68a4ed..67d628ba11 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.11.0-rc2</version>
+ <version>1.11.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>tensorflow</artifactId>
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 79f14466e6..9275ad767e 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -333,6 +333,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//third_party/python_runtime:headers",
+ "@com_google_absl//absl/memory",
],
)
@@ -1638,6 +1639,15 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "experimental_dataset_ops_gen",
+ visibility = [
+ "//learning/brain/python/ops:__pkg__",
+ "//tensorflow:__subpackages__",
+ "//tensorflow/python/kernel_tests:__pkg__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "image_ops_gen",
visibility = ["//learning/brain/python/ops:__pkg__"],
)
@@ -2007,6 +2017,7 @@ py_library(
":array_ops",
":cond_v2_impl",
":constant_op",
+ ":control_flow_ops",
":control_flow_util",
":framework_ops",
":function_def_to_graph",
diff --git a/tensorflow/python/autograph/converters/builtin_functions.py b/tensorflow/python/autograph/converters/builtin_functions.py
index b8b268d8ce..583c978395 100644
--- a/tensorflow/python/autograph/converters/builtin_functions.py
+++ b/tensorflow/python/autograph/converters/builtin_functions.py
@@ -48,8 +48,13 @@ class BuiltinFunctionTransformer(converter.Base):
node = self.generic_visit(node)
if anno.hasanno(node.func, 'live_val'):
live_val = anno.getanno(node.func, 'live_val')
- if live_val in py_builtins.SUPPORTED_BUILTINS:
- node = self._convert_builtin(live_val, node.args, as_expression=True)
+ try:
+ if live_val in py_builtins.SUPPORTED_BUILTINS:
+ node = self._convert_builtin(live_val, node.args, as_expression=True)
+ except TypeError:
+ # Not everything in Python is hashable. If it isn't then it's definitely
+ # not a supported built-in.
+ return node
return node
def visit_Print(self, node):
diff --git a/tensorflow/python/autograph/converters/builtin_functions_test.py b/tensorflow/python/autograph/converters/builtin_functions_test.py
index c87c304cdb..2ed14c14e7 100644
--- a/tensorflow/python/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/python/autograph/converters/builtin_functions_test.py
@@ -36,7 +36,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return len(a)
with self.converted(test_fn, builtin_functions, {'len': len}) as result:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
ops = result.test_fn(p)
self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3)
@@ -50,7 +50,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return print(a)
with self.converted(test_fn, builtin_functions, {'print': print}) as result:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
with self.assertPrints('a\n'):
sess.run(result.test_fn('a'))
@@ -63,12 +63,22 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return print(a, b, c)
with self.converted(test_fn, builtin_functions, {'print': print}) as result:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
with self.assertPrints('a 1 [2, 3]\n'):
sess.run(
result.test_fn(
constant_op.constant('a'), constant_op.constant(1), [2, 3]))
+ def test_conversion_robust_to_unhashable_callables(self):
+
+ def test_fn():
+ return foo() # pylint:disable=undefined-variable
+
+ with self.converted(test_fn, builtin_functions, {'foo': {
+ 'a': 'b'
+ }.keys}) as result:
+ self.assertListEqual(list(result.test_fn()), ['a'])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/autograph/converters/return_statements.py b/tensorflow/python/autograph/converters/return_statements.py
index 62da045d6a..496c99e3b5 100644
--- a/tensorflow/python/autograph/converters/return_statements.py
+++ b/tensorflow/python/autograph/converters/return_statements.py
@@ -212,6 +212,7 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
def __init__(self):
self.cant_return = False
+ self.function_level = 0
super(DetectReturnInUnsupportedControlFlow, self).__init__()
def visit_While(self, node):
@@ -229,6 +230,12 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
self.generic_visit(node)
self.cant_return = False
+ def visit_FunctionDef(self, node):
+ if not self.function_level:
+ self.function_level += 1
+ self.generic_visit(node)
+ self.function_level -= 1
+
def visit_Return(self, node):
if self.cant_return:
raise ValueError(
@@ -242,6 +249,7 @@ class DetectReturnInConditional(gast.NodeVisitor):
def __init__(self):
self.cant_return = False
+ self.function_level = 0
super(DetectReturnInConditional, self).__init__()
def visit_If(self, node):
@@ -249,6 +257,12 @@ class DetectReturnInConditional(gast.NodeVisitor):
self.generic_visit(node)
self.cant_return = False
+ def visit_FunctionDef(self, node):
+ if not self.function_level:
+ self.function_level += 1
+ self.generic_visit(node)
+ self.function_level -= 1
+
def visit_Return(self, node):
if self.cant_return:
raise ValueError(
diff --git a/tensorflow/python/autograph/converters/return_statements_test.py b/tensorflow/python/autograph/converters/return_statements_test.py
index 01dd03da0b..762fbc6f60 100644
--- a/tensorflow/python/autograph/converters/return_statements_test.py
+++ b/tensorflow/python/autograph/converters/return_statements_test.py
@@ -151,6 +151,18 @@ class SingleReturnTest(converter_testing.TestCase):
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
+ def test_nested_functions_in_control_flow(self):
+
+ def test_fn(x):
+
+ if x:
+ def inner_fn(y):
+ return y
+ inner_fn(x)
+
+ self.assertTransformedEquivalent(test_fn, 2)
+ self.assertTransformedEquivalent(test_fn, -2)
+
def test_loop(self):
def test_fn(x):
diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
index 1bf0515745..1af8fca599 100644
--- a/tensorflow/python/autograph/pyct/templates.py
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -123,6 +123,8 @@ class ReplaceTransformer(gast.NodeTransformer):
self._check_inner_children_have_context(e)
for e in node.values:
self._check_inner_children_have_context(e)
+ elif isinstance(node, gast.Index):
+ self._check_inner_children_have_context(node.value)
elif isinstance(node, gast.Subscript):
self._check_inner_children_have_context(node.value)
self._check_inner_children_have_context(node.slice)
diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py
index 078d9a149b..3032241846 100644
--- a/tensorflow/python/autograph/pyct/templates_test.py
+++ b/tensorflow/python/autograph/pyct/templates_test.py
@@ -158,6 +158,18 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)
+ def test_replace_index(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(
+ template, foo=parser.parse_expression('foo(a[b]).bar'))[0]
+ function_call_arg = node.body[0].targets[0].value.args[0]
+ self.assertIsInstance(function_call_arg.ctx, gast.Load)
+ self.assertIsInstance(function_call_arg.slice.value.ctx, gast.Load)
+
def test_replace_call_keyword(self):
template = """
def test_fn():
diff --git a/tensorflow/python/client/session_ref.cc b/tensorflow/python/client/session_ref.cc
index b2300df0b6..4d361612b7 100644
--- a/tensorflow/python/client/session_ref.cc
+++ b/tensorflow/python/client/session_ref.cc
@@ -93,23 +93,35 @@ class SessionLogger {
public:
SessionLogger() {
std::string log_name = getenv("TF_REPLAY_LOG_FILE");
+ LOG(INFO) << "Constructing new session logger for " << log_name;
TF_CHECK_OK(
Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
Env::Default()->DeleteFile(log_name).IgnoreError();
- TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
+ TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
}
- Status RecordCreateSession(Session* session) {
- LOG(INFO) << "Capturing devices for session.";
+ ~SessionLogger() {
+ log_writer_->Close().IgnoreError();
+ log_writer_.release();
+ log_file_->Close().IgnoreError();
+ }
+
+ Status RecordNewSession(Session* session) {
+ LOG(INFO) << "New session discovered. Capturing devices...";
ReplayOp op;
NewReplaySession* req = op.mutable_new_replay_session();
std::vector<DeviceAttributes> devices;
- TF_CHECK_OK(session->ListDevices(&devices));
- for (const DeviceAttributes& dev : devices) {
- *req->mutable_devices()->add_local_device() = dev;
+ Status status = session->ListDevices(&devices);
+ if (status.ok()) {
+ LOG(INFO) << "Found: " << devices.size() << " devices.";
+ for (const DeviceAttributes& dev : devices) {
+ *req->mutable_devices()->add_local_device() = dev;
+ }
+ } else {
+ LOG(WARNING) << "Failed to list devices on session. Continuing.";
}
req->set_session_handle(SessionToHandle(session));
@@ -226,7 +238,6 @@ class SessionLogger {
// N.B. RunOptions is not stored (it has no entry in CloseRequest)
Status RecordClose(Session* session, const RunOptions& run_options) {
- mutex_lock l(log_mutex_);
ReplayOp op;
CloseSessionRequest* req = op.mutable_close_session();
req->set_session_handle(SessionToHandle(session));
@@ -241,7 +252,6 @@ class SessionLogger {
Status RecordListDevices(Session* session,
std::vector<DeviceAttributes>* response) {
- mutex_lock l(log_mutex_);
ReplayOp op;
ListDevicesRequest* req = op.mutable_list_devices();
ListDevicesResponse* resp = op.mutable_list_devices_response();
@@ -258,7 +268,6 @@ class SessionLogger {
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
string* handle) {
- mutex_lock l(log_mutex_);
ReplayOp op;
PartialRunSetupRequest* req = op.mutable_partial_run_setup();
req->set_session_handle(SessionToHandle(session));
@@ -362,18 +371,19 @@ class SessionLogger {
private:
Status Flush(const ReplayOp& op) {
+ mutex_lock l(log_mutex_);
+
string buf;
op.SerializeToString(&buf);
TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));
- // Flushing the RecordWriter _does not_ flush the underlying file.
- TF_RETURN_IF_ERROR(log_writer_->Flush());
- return log_file_->Flush();
+ // TODO(b/116624106): Not all file-systems respect calls to `Sync()`
+ return log_file_->Sync();
}
- mutex log_mutex_;
- std::unique_ptr<io::RecordWriter> log_writer_;
std::unique_ptr<WritableFile> log_file_;
+ std::unique_ptr<io::RecordWriter> log_writer_;
+ mutex log_mutex_;
};
static SessionLogger* global_session_logger() {
@@ -384,7 +394,7 @@ static SessionLogger* global_session_logger() {
SessionRef::SessionRef(Session* session) : session_(session) {
if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
logger_ = global_session_logger();
- logger_->RecordCreateSession(this->session_.get()).IgnoreError();
+ logger_->RecordNewSession(this->session_.get()).IgnoreError();
} else {
logger_ = nullptr;
}
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index f576435136..347833ce8f 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -120,11 +120,17 @@ class SessionTest(test_util.TensorFlowTestCase):
inp = constant_op.constant(10.0, name='W1')
self.assertAllEqual(inp.eval(), 10.0)
- devices = sess.list_devices()
- self.assertEqual(2, len(devices))
- for device in devices:
- self.assertEqual('CPU', framework_device_lib.DeviceSpec.from_string(
- device.name).device_type)
+ num_cpu_devices = 0
+ num_gpu_devices = 0
+ for device in sess.list_devices():
+ device_type = framework_device_lib.DeviceSpec.from_string(
+ device.name).device_type
+ if device_type == 'CPU':
+ num_cpu_devices += 1
+ elif device_type == 'GPU':
+ num_gpu_devices += 1
+ self.assertEqual(2, num_cpu_devices)
+ self.assertEqual(0, num_gpu_devices)
def testPerSessionThreads(self):
with session.Session(
@@ -1022,7 +1028,7 @@ class SessionTest(test_util.TensorFlowTestCase):
with session.Session():
a = constant_op.constant(1.0, shape=[1, 2])
b = constant_op.constant(2.0, shape=[1, 2], name='b')
- v = variables.Variable(a, a.dtype)
+ v = variables.VariableV1(a, a.dtype)
assign_a_to_v = state_ops.assign(v, a)
assign_a_to_v.eval()
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 45f40cd183..bea5aa990f 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 24)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 10, 1)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 28ee3ebaa6..cadfe7f9e0 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -15,6 +15,7 @@ tf_py_test(
size = "small",
srcs = ["batch_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -31,10 +32,44 @@ tf_py_test(
)
tf_py_test(
+ name = "cache_dataset_op_test",
+ size = "small",
+ srcs = ["cache_dataset_op_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+)
+
+tf_py_test(
+ name = "concatenate_dataset_op_test",
+ size = "small",
+ srcs = ["concatenate_dataset_op_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_py_test(
name = "dataset_constructor_op_test",
size = "small",
srcs = ["dataset_constructor_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@@ -63,6 +98,7 @@ tf_py_test(
size = "medium",
srcs = ["dataset_from_generator_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -78,6 +114,7 @@ tf_py_test(
size = "small",
srcs = ["dataset_ops_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python/data/ops:dataset_ops",
@@ -89,6 +126,7 @@ tf_py_test(
size = "small",
srcs = ["filter_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -106,6 +144,7 @@ tf_py_test(
size = "small",
srcs = ["flat_map_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
@@ -123,6 +162,7 @@ tf_py_test(
size = "small",
srcs = ["list_files_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -133,10 +173,25 @@ tf_py_test(
)
tf_py_test(
+ name = "inputs_test",
+ size = "small",
+ srcs = ["inputs_test.py"],
+ additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+tf_py_test(
name = "interleave_dataset_op_test",
size = "small",
srcs = ["interleave_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -151,11 +206,80 @@ tf_py_test(
],
)
+cuda_py_test(
+ name = "iterator_ops_test",
+ size = "small",
+ srcs = ["iterator_ops_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/training/checkpointable:util",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:training",
+ "//tensorflow/python/compat:compat",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variables",
+ ],
+ grpc_enabled = True,
+)
+
+tf_py_test(
+ name = "iterator_ops_cluster_test",
+ size = "small",
+ srcs = ["iterator_ops_cluster_test.py"],
+ additional_deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:lookup_ops",
+ ],
+ grpc_enabled = True,
+ tags = [
+ "no_oss", # Test flaky due to port collisions.
+ "no_windows",
+ ],
+)
+
tf_py_test(
name = "map_dataset_op_test",
size = "small",
srcs = ["map_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -177,11 +301,54 @@ tf_py_test(
],
)
+cuda_py_test(
+ name = "multi_device_iterator_test",
+ size = "small",
+ srcs = ["multi_device_iterator_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ tags = [
+ "no_windows_gpu",
+ ],
+)
+
+cuda_py_test(
+ name = "optional_ops_test",
+ size = "small",
+ srcs = ["optional_ops_test.py"],
+ additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:optional_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:tensor_shape",
+ ],
+)
+
tf_py_test(
name = "prefetch_dataset_op_test",
size = "small",
srcs = ["prefetch_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -197,6 +364,7 @@ tf_py_test(
size = "small",
srcs = ["range_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dataset_ops_gen",
@@ -218,6 +386,7 @@ tf_py_test(
size = "small",
srcs = ["reader_dataset_ops_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -236,32 +405,35 @@ tf_py_test(
)
tf_py_test(
- name = "sequence_dataset_op_test",
+ name = "reduce_dataset_op_test",
size = "small",
- srcs = ["sequence_dataset_op_test.py"],
+ srcs = ["reduce_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
],
)
tf_py_test(
- name = "shuffle_dataset_op_test",
+ name = "sequence_dataset_op_test",
size = "small",
- srcs = ["shuffle_dataset_op_test.py"],
+ srcs = ["sequence_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
],
)
@@ -270,6 +442,7 @@ tf_py_test(
size = "small",
srcs = ["shard_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
@@ -277,171 +450,62 @@ tf_py_test(
)
tf_py_test(
- name = "cache_dataset_op_test",
+ name = "shuffle_dataset_op_test",
size = "small",
- srcs = ["cache_dataset_op_test.py"],
+ srcs = ["shuffle_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
)
-tf_py_test(
- name = "zip_dataset_op_test",
- size = "small",
- srcs = ["zip_dataset_op_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-tf_py_test(
- name = "concatenate_dataset_op_test",
- size = "small",
- srcs = ["concatenate_dataset_op_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
+py_library(
+ name = "test_base",
+ srcs = ["test_base.py"],
+ deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/util:nest",
],
)
-cuda_py_test(
- name = "iterator_ops_test",
+tf_py_test(
+ name = "window_dataset_op_test",
size = "small",
- srcs = ["iterator_ops_test.py"],
+ srcs = ["window_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/util:sparse",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/training/checkpointable:util",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:training",
- "//tensorflow/python/compat:compat",
- "//tensorflow/python:util",
- "//tensorflow/python:variables",
- ],
- grpc_enabled = True,
-)
-
-tf_py_test(
- name = "iterator_ops_cluster_test",
- size = "small",
- srcs = ["iterator_ops_cluster_test.py"],
- additional_deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:session",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:lookup_ops",
- ],
- grpc_enabled = True,
- tags = [
- "no_oss", # Test flaky due to port collisions.
- "no_windows",
- ],
-)
-
-cuda_py_test(
- name = "optional_ops_test",
- size = "small",
- srcs = ["optional_ops_test.py"],
- additional_deps = [
- "@absl_py//absl/testing:parameterized",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/ops:optional_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:tensor_shape",
- ],
-)
-
-cuda_py_test(
- name = "multi_device_iterator_test",
- size = "small",
- srcs = ["multi_device_iterator_test.py"],
- additional_deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:multi_device_iterator_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_test_lib",
- ],
- tags = [
- "no_windows_gpu",
],
)
tf_py_test(
- name = "window_dataset_op_test",
+ name = "zip_dataset_op_test",
size = "small",
- srcs = ["window_dataset_op_test.py"],
+ srcs = ["zip_dataset_op_test.py"],
additional_deps = [
- "@absl_py//absl/testing:parameterized",
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
],
)
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index c48708a2b9..9cb4daf284 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -24,6 +24,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -37,7 +38,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase, parameterized.TestCase):
+class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
('even', 28, 14, False),
@@ -115,11 +116,6 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testBatchSparse(self):
def _sparse(i):
@@ -227,7 +223,7 @@ def _random_seq_lens(count):
return np.random.randint(20, size=(count,)).astype(np.int32)
-class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
+class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
('default_padding', _random_seq_lens(32), 4, [-1], False),
diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
index d5f5b2fe05..63625fac03 100644
--- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
@@ -23,6 +23,7 @@ import tempfile
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -34,7 +35,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-class FileCacheDatasetTest(test.TestCase):
+class FileCacheDatasetTest(test_base.DatasetTestBase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
@@ -200,7 +201,7 @@ class FileCacheDatasetTest(test.TestCase):
self.assertAllEqual(elements, elements_itr2)
-class MemoryCacheDatasetTest(test.TestCase):
+class MemoryCacheDatasetTest(test_base.DatasetTestBase):
def testCacheDatasetPassthrough(self):
with ops.device("cpu:0"):
diff --git a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
index 5dfb84f28e..83af31f380 100644
--- a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
-class ConcatenateDatasetTest(test.TestCase):
+class ConcatenateDatasetTest(test_base.DatasetTestBase):
def testConcatenateDataset(self):
input_components = (
diff --git a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
index e43564a2eb..bc6b36285a 100644
--- a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -36,7 +37,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def testFromTensors(self):
"""Test a dataset that represents a single tuple of tensors."""
@@ -58,11 +59,6 @@ class DatasetConstructorTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testFromTensorsSparse(self):
"""Test a dataset that represents a single tuple of tensors."""
components = (sparse_tensor.SparseTensorValue(
diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
index cd0c1ddf1e..cb8cb9a77d 100644
--- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
@@ -22,6 +22,7 @@ import threading
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -30,7 +31,7 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def _testFromGenerator(self, generator, elem_sequence, num_repeats,
output_types=None):
diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
index 239aa85175..f115f9d9c7 100644
--- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
@@ -19,11 +19,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import graph_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
-class DatasetOpsTest(test.TestCase):
+class DatasetOpsTest(test_base.DatasetTestBase):
def testAsSerializedGraph(self):
dataset = dataset_ops.Dataset.range(10)
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 19944d389f..6b7afafa5d 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -22,6 +22,7 @@ import time
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,7 +34,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class FilterDatasetTest(test.TestCase):
+class FilterDatasetTest(test_base.DatasetTestBase):
def testFilterDataset(self):
components = (
@@ -129,11 +130,6 @@ class FilterDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSparse(self):
def _map_fn(i):
diff --git a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
index 1123cbff62..68038f9cfc 100644
--- a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
@@ -22,6 +22,7 @@ import random
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
-class FlatMapDatasetTest(test.TestCase):
+class FlatMapDatasetTest(test_base.DatasetTestBase):
# pylint: disable=g-long-lambda
def testFlatMapDataset(self):
diff --git a/tensorflow/python/data/kernel_tests/inputs_test.py b/tensorflow/python/data/kernel_tests/inputs_test.py
new file mode 100644
index 0000000000..d089b49bcc
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/inputs_test.py
@@ -0,0 +1,149 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.platform import test
+
+
+class InputsTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @staticmethod
+ def make_apply_fn(dataset):
+
+ def apply_fn(dataset):
+
+ def _apply_fn(dataset):
+ return dataset.cache()
+
+ return dataset.apply(_apply_fn)
+
+ return apply_fn
+
+ @staticmethod
+ def make_gen():
+
+ def gen():
+ yield 42
+
+ return gen
+
+ @staticmethod
+ def make_interleave_fn(dataset, num_parallel_calls=None):
+
+ def interleave_fn(dataset):
+ return dataset.interleave(
+ lambda x: dataset_ops.Dataset.range(0),
+ cycle_length=2,
+ num_parallel_calls=num_parallel_calls)
+
+ return interleave_fn
+
+ @parameterized.named_parameters(
+ ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)),
+ ("FromGenerator",
+ dataset_ops.Dataset.from_generator(make_gen.__func__(), dtypes.int32),
+ 1),
+ ("FromSparseTensorSlices",
+ dataset_ops.Dataset.from_sparse_tensor_slices(
+ sparse_tensor.SparseTensor(
+ indices=np.array([[0, 0], [1, 0], [2, 0]]),
+ values=np.array([0, 0, 0]),
+ dense_shape=np.array([3, 1])))),
+ ("FromTensors", dataset_ops.Dataset.from_tensors([42])),
+ ("FromTensorSlices", dataset_ops.Dataset.from_tensors([42])),
+ ("Range", dataset_ops.Dataset.range(10)),
+ ("TextLine", readers.TextLineDataset("")),
+ ("TFRecord", readers.TFRecordDataset(""), 1),
+ )
+ def testDatasetSourceInputs(self, dataset, num_inputs=0):
+ self.assertEqual(num_inputs, len(dataset._inputs()))
+
+ @parameterized.named_parameters(
+ ("Apply", make_apply_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)),
+ ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)),
+ ("Filter", lambda x: x.filter(lambda x: True),
+ dataset_ops.Dataset.range(0)),
+ ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Interleave", make_interleave_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)),
+ ("PaddedBatch", lambda x: x.padded_batch(10, []),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelInterleave",
+ make_interleave_fn.__func__(dataset_ops.Dataset.range(0), 2),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
+ dataset_ops.Dataset.range(0)),
+ ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)),
+ ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)),
+ ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)),
+ ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)),
+ ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)),
+ )
+ def testUnaryTransformationInputs(self, dataset_fn, input_dataset):
+ self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())
+
+ @parameterized.named_parameters(
+ ("Concatenate", lambda x, y: x.concatenate(y),
+ dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))
+ def testBinaryTransformationInputs(self, dataset_fn, input1, input2):
+ self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())
+
+ @parameterized.named_parameters(
+ ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))),
+ ("ZipNest", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0),
+ (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
+ ("ZipTuple", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))))
+ def testVariadicTransformationInputs(self, dataset_fn, input_datasets):
+ self.assertEqual(
+ nest.flatten(input_datasets),
+ dataset_fn(input_datasets)._inputs())
+
+ def testCollectInputs(self):
+ ds1 = dataset_ops.Dataset.range(0)
+ ds2 = ds1.concatenate(ds1)
+ ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))
+
+ inputs = []
+ queue = [ds3]
+ while queue:
+ ds = queue[0]
+ queue = queue[1:]
+ queue.extend(ds._inputs())
+ inputs.append(ds)
+
+ self.assertEqual(5, inputs.count(ds1))
+ self.assertEqual(2, inputs.count(ds2))
+ self.assertEqual(1, inputs.count(ds3))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
index e7e51df65e..92bb67b6ff 100644
--- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
@@ -22,6 +22,7 @@ import itertools
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
@@ -30,7 +31,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
+class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _interleave(self, lists, cycle_length, block_length):
num_open = 0
diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
index c4b338a58f..8eb13815d4 100644
--- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
@@ -22,6 +22,7 @@ from os import path
import shutil
import tempfile
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class ListFilesDatasetOpTest(test.TestCase):
+class ListFilesDatasetOpTest(test_base.DatasetTestBase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index ae04995436..230ae3f3fd 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -27,6 +27,7 @@ import numpy as np
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -47,7 +48,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
-class MapDatasetTest(test.TestCase, parameterized.TestCase):
+class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _buildMapDataset(self, components, count):
def _map_fn(x, y, z):
@@ -574,11 +575,6 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSparse(self):
def _sparse(i):
diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
index 056664b83b..1cf6dd1bea 100644
--- a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
+++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.framework import dtypes
@@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class MultiDeviceIteratorTest(test.TestCase):
+class MultiDeviceIteratorTest(test_base.DatasetTestBase):
def testNoGetNext(self):
dataset = dataset_ops.Dataset.range(10)
diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py
index 706a65fe55..604e3ad88e 100644
--- a/tensorflow/python/data/kernel_tests/optional_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import optional_ops
@@ -35,7 +36,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class OptionalTest(test.TestCase, parameterized.TestCase):
+class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFromValue(self):
diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
index cc97bac609..76e2697b29 100644
--- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from absl.testing import parameterized
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class PrefetchDatasetTest(test.TestCase, parameterized.TestCase):
+class PrefetchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.parameters((-1), (0), (5))
def testBufferSize(self, buffer_size):
diff --git a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
index 51e90785e7..b7e2a5f615 100644
--- a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import os
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
@@ -34,7 +35,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
-class RangeDatasetTest(test.TestCase):
+class RangeDatasetTest(test_base.DatasetTestBase):
def tearDown(self):
# Remove all checkpoint files.
diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
index aa3636364d..aef2dd1d9c 100644
--- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
@@ -21,6 +21,7 @@ import gzip
import os
import zlib
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
@@ -46,7 +47,7 @@ except ImportError:
psutil_import_succeeded = False
-class TextLineDatasetTest(test.TestCase):
+class TextLineDatasetTest(test_base.DatasetTestBase):
def _lineText(self, f, l):
return compat.as_bytes("%d: %d" % (f, l))
@@ -199,7 +200,7 @@ class TextLineDatasetTest(test.TestCase):
self.assertNotIn(filename, [open_file.path for open_file in open_files])
-class FixedLengthRecordReaderTest(test.TestCase):
+class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
def setUp(self):
super(FixedLengthRecordReaderTest, self).setUp()
@@ -621,7 +622,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
sess.run(get_next_op)
-class TFRecordDatasetTest(test.TestCase):
+class TFRecordDatasetTest(test_base.DatasetTestBase):
def setUp(self):
super(TFRecordDatasetTest, self).setUp()
diff --git a/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py
new file mode 100644
index 0000000000..11e07300b9
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py
@@ -0,0 +1,124 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ def testSum(self):
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce(np.int64(0), lambda x, y: x + y)
+ with self.cached_session() as sess:
+ self.assertEqual(((i + 1) * i) // 2, sess.run(result))
+
+ def testSumTuple(self):
+
+ def reduce_fn(state, value):
+ v1, v2 = value
+ return state + v1 + v2
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ ds = dataset_ops.Dataset.zip((ds, ds))
+ result = ds.reduce(np.int64(0), reduce_fn)
+ with self.cached_session() as sess:
+ self.assertEqual(((i + 1) * i), sess.run(result))
+
+ def testSumAndCount(self):
+
+ def reduce_fn(state, value):
+ s, c = state
+ return s + value, c + 1
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce((np.int64(0), np.int64(0)), reduce_fn)
+ with self.cached_session() as sess:
+ s, c = sess.run(result)
+ self.assertEqual(((i + 1) * i) // 2, s)
+ self.assertEqual(i, c)
+
+ def testSquareUsingPlaceholder(self):
+ delta = array_ops.placeholder(dtype=dtypes.int64)
+
+ def reduce_fn(state, _):
+ return state + delta
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce(np.int64(0), reduce_fn)
+ with self.cached_session() as sess:
+ square = sess.run(result, feed_dict={delta: i})
+ self.assertEqual(i * i, square)
+
+ def testSparse(self):
+
+ def reduce_fn(_, value):
+ return value
+
+ def make_sparse_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1))
+ result = ds.reduce(make_sparse_fn(0), reduce_fn)
+ with self.cached_session() as sess:
+ self.assertSparseValuesEqual(make_sparse_fn(i+1), sess.run(result))
+
+ def testNested(self):
+
+ def reduce_fn(state, value):
+ state["dense"] += value["dense"]
+ state["sparse"] = value["sparse"]
+ return state
+
+ def make_sparse_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ def map_fn(i):
+ return {"dense": math_ops.cast(i, dtype=dtypes.int64),
+ "sparse": make_sparse_fn(math_ops.cast(i, dtype=dtypes.int64))}
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn)
+ result = ds.reduce(map_fn(0), reduce_fn)
+ with self.cached_session() as sess:
+ result = sess.run(result)
+ self.assertEqual(((i + 1) * i) // 2, result["dense"])
+ self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
index 37e2333560..e86356dee7 100644
--- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class SequenceDatasetTest(test.TestCase):
+class SequenceDatasetTest(test_base.DatasetTestBase):
def testRepeatTensorDataset(self):
"""Test a dataset that repeats its input multiple times."""
diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
index 137f6341ce..b9f3c79da5 100644
--- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
@@ -17,12 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class ShardDatasetOpTest(test.TestCase):
+class ShardDatasetOpTest(test_base.DatasetTestBase):
def testSimpleCase(self):
dataset = dataset_ops.Dataset.range(10).shard(5, 2)
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
index f294840706..347af18576 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
@@ -21,6 +21,7 @@ import collections
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -30,7 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ShuffleDatasetTest(test.TestCase):
+class ShuffleDatasetTest(test_base.DatasetTestBase):
def testShuffleDataset(self):
components = (
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
new file mode 100644
index 0000000000..b730e10949
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -0,0 +1,109 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test utilities for tf.data functionality."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from tensorflow.python.data.util import nest
+from tensorflow.python.eager import context
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.platform import test
+
+
+class DatasetTestBase(test.TestCase):
+ """Base class for dataset tests."""
+
+ def assertSparseValuesEqual(self, a, b):
+ """Asserts that two SparseTensors/SparseTensorValues are equal."""
+ self.assertAllEqual(a.indices, b.indices)
+ self.assertAllEqual(a.values, b.values)
+ self.assertAllEqual(a.dense_shape, b.dense_shape)
+
+ def getNext(self, dataset):
+ """Returns a callable that returns the next element of the dataset.
+
+ Example use:
+ ```python
+ # In both graph and eager modes
+ dataset = ...
+ nxt = self.getNext(dataset)
+ result = self.evaluate(nxt())
+ ```
+
+ Args:
+ dataset: A dataset whose next element is returned
+
+ Returns:
+ A callable that returns the next element of `dataset`
+ """
+ it = dataset.make_one_shot_iterator()
+ if context.executing_eagerly():
+ return it.get_next
+ else:
+ nxt = it.get_next()
+ return lambda: nxt
+
+ def assertDatasetsEqual(self, dataset1, dataset2):
+ """Checks that datasets are equal. Supports both graph and eager mode."""
+ self.assertEqual(dataset1.output_types, dataset2.output_types)
+ self.assertEqual(dataset1.output_classes, dataset2.output_classes)
+
+ next1 = self.getNext(dataset1)
+ next2 = self.getNext(dataset2)
+ while True:
+ try:
+ op1 = self.evaluate(next1())
+ except errors.OutOfRangeError:
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(next2())
+ break
+ op2 = self.evaluate(next2())
+
+ op1 = nest.flatten(op1)
+ op2 = nest.flatten(op2)
+ assert len(op1) == len(op2)
+ for i in range(len(op1)):
+ if isinstance(
+ op1[i],
+ (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
+ self.assertSparseValuesEqual(op1[i], op2[i])
+ else:
+ self.assertAllEqual(op1[i], op2[i])
+
+ def assertDatasetsRaiseSameError(self,
+ dataset1,
+ dataset2,
+ exception_class,
+ replacements=None):
+ """Checks that datasets raise the same error on the first get_next call."""
+ next1 = self.getNext(dataset1)
+ next2 = self.getNext(dataset2)
+ try:
+ self.evaluate(next1())
+ raise ValueError(
+ 'Expected dataset to raise an error of type %s, but it did not.' %
+ repr(exception_class))
+ except exception_class as e:
+ expected_message = e.message
+ for old, new, count in replacements:
+ expected_message = expected_message.replace(old, new, count)
+ # Check that the first segment of the error messages are the same.
+ with self.assertRaisesRegexp(exception_class,
+ re.escape(expected_message)):
+ self.evaluate(next2())
diff --git a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
index fd4348426d..9d06781094 100644
--- a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -29,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class WindowDatasetTest(test.TestCase, parameterized.TestCase):
+class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("1", 20, 14, 7, 1),
@@ -150,11 +151,6 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
stride_t: stride
})
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testWindowSparse(self):
def _sparse(i):
diff --git a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
index 3106effbd3..9d76387a34 100644
--- a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ZipDatasetTest(test.TestCase):
+class ZipDatasetTest(test_base.DatasetTestBase):
def testZipDataset(self):
component_placeholders = [
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 7c20c049f5..6bba72a8e9 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -80,6 +80,12 @@ class Dataset(object):
"""
raise NotImplementedError("Dataset._as_variant_tensor")
+ @abc.abstractmethod
+ def _inputs(self):
+ """Returns a list of the input datasets of the dataset."""
+
+ raise NotImplementedError("Dataset._inputs")
+
def make_initializable_iterator(self, shared_name=None):
"""Creates an `Iterator` for enumerating the elements of this dataset.
@@ -1007,8 +1013,8 @@ class Dataset(object):
return ParallelMapDataset(self, map_func, num_parallel_calls)
def flat_map(self, map_func):
- """Maps `map_func` across this dataset and flattens the result.
-
+ """Maps `map_func` across this dataset and flattens the result.
+
Use `flat_map` if you want to make sure that the order of your dataset
stays the same. For example, to flatten a dataset of batches into a
dataset of their elements:
@@ -1017,15 +1023,15 @@ class Dataset(object):
# NOTE: The following examples use `{ ... }` to represent the
# contents of a dataset. '[...]' represents a tensor.
a = {[1,2,3,4,5], [6,7,8,9], [10]}
-
- a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
+
+ a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
{[1,2,3,4,5,6,7,8,9,10]}
```
-
- `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
- `flat_map` produces the same output as
+
+ `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
+ `flat_map` produces the same output as
`tf.data.Dataset.interleave(cycle_length=1)`
-
+
Args:
map_func: A function mapping a nested structure of tensors (having shapes
and types defined by `self.output_shapes` and `self.output_types`) to a
@@ -1157,6 +1163,7 @@ class Dataset(object):
dataset = transformation_func(self)
if not isinstance(dataset, Dataset):
raise TypeError("`transformation_func` must return a Dataset.")
+ dataset._input_datasets = [self] # pylint: disable=protected-access
return dataset
def window(self, size, shift=None, stride=1, drop_remainder=False):
@@ -1198,8 +1205,146 @@ class Dataset(object):
shift = size
return WindowDataset(self, size, shift, stride, drop_remainder)
+ def reduce(self, initial_state, reduce_func):
+ """Reduces the input dataset to a single element.
+
+ The transformation calls `reduce_func` successively on every element of
+ the input dataset until the dataset is exhausted, aggregating information in
+ its internal state. The `initial_state` argument is used for the initial
+ state and the final state is returned as the result.
+
+ For example:
+ - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1)`
+ produces `5`
+ - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y)`
+ produces `10`
+
+ Args:
+ initial_state: A nested structure of tensors, representing the initial
+ state of the transformation.
+ reduce_func: A function that maps `(old_state, input_element)` to
+ `new_state`. It must take two arguments and return a nested structure
+ of tensors. The structure of `new_state` must match the structure of
+ `initial_state`.
+
+ Returns:
+ A nested structure of `tf.Tensor` objects, corresponding to the final
+ state of the transformation.
+
+ """
+
+ with ops.name_scope("initial_state"):
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ initial_state = nest.pack_sequence_as(initial_state, [
+ sparse_tensor_lib.SparseTensor.from_value(t)
+ if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
+ t, name="component_%d" % i)
+ for i, t in enumerate(nest.flatten(initial_state))
+ ])
+
+ # Compute initial values for the state classes, shapes and types based on
+ # the initial state.
+ state_classes = sparse.get_classes(initial_state)
+ state_shapes = nest.pack_sequence_as(
+ initial_state, [t.get_shape() for t in nest.flatten(initial_state)])
+ state_types = nest.pack_sequence_as(
+ initial_state, [t.dtype for t in nest.flatten(initial_state)])
+
+ # Iteratively rerun the reduce function until reaching a fixed point on
+ # `self._state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ wrapped_func = StructuredFunctionWrapper(
+ reduce_func,
+ "reduce()",
+ input_classes=(state_classes, self.output_classes),
+ input_shapes=(state_shapes, self.output_shapes),
+ input_types=(state_types, self.output_types),
+ add_to_graph=False)
+
+ # Extract and validate class information from the returned values.
+ output_classes = wrapped_func.output_classes
+ for new_state_class, state_class in zip(
+ nest.flatten(output_classes), nest.flatten(state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." % (state_classes,
+ wrapped_func.output_classes))
+
+ # Extract and validate type information from the returned values.
+ output_types = wrapped_func.output_types
+ for new_state_type, state_type in zip(
+ nest.flatten(output_types), nest.flatten(state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." % (state_types,
+ wrapped_func.output_types))
+
+ # Extract shape information from the returned values.
+ output_shapes = wrapped_func.output_shapes
+ flat_state_shapes = nest.flatten(state_shapes)
+ flat_new_state_shapes = nest.flatten(output_shapes)
+ weakened_state_shapes = [
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ original_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ state_shapes = nest.pack_sequence_as(state_shapes,
+ weakened_state_shapes)
+
+ reduce_func = wrapped_func.function
+ reduce_func.add_to_graph(ops.get_default_graph())
+
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(
+ output_types,
+ gen_dataset_ops.reduce_dataset(
+ self._as_variant_tensor(), # pylint: disable=protected-access
+ nest.flatten(sparse.serialize_sparse_tensors(initial_state)),
+ reduce_func.captured_inputs,
+ f=reduce_func,
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)),
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)))),
+ output_types,
+ output_shapes,
+ output_classes)
+
+
+class DatasetSource(Dataset):
+ """Abstract class representing a dataset with no inputs."""
+
+ def _inputs(self):
+ return []
+
+
+class UnaryDataset(Dataset):
+ """Abstract class representing a dataset with one input."""
+
+ def __init__(self, input_dataset):
+ super(UnaryDataset, self).__init__()
+ self._input_dataset = input_dataset
+
+ def _inputs(self):
+ return [self._input_dataset]
-class TensorDataset(Dataset):
+
+class TensorDataset(DatasetSource):
"""A `Dataset` with a single element, viz. a nested structure of tensors."""
def __init__(self, tensors):
@@ -1239,7 +1384,7 @@ class TensorDataset(Dataset):
return self._output_types
-class TensorSliceDataset(Dataset):
+class TensorSliceDataset(DatasetSource):
"""A `Dataset` of slices from a nested structure of tensors."""
def __init__(self, tensors):
@@ -1283,7 +1428,7 @@ class TensorSliceDataset(Dataset):
return self._output_types
-class SparseTensorSliceDataset(Dataset):
+class SparseTensorSliceDataset(DatasetSource):
"""A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows."""
def __init__(self, sparse_tensor):
@@ -1384,6 +1529,9 @@ class _VariantDataset(Dataset):
def _as_variant_tensor(self):
return self._dataset_variant
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return self._structure.output_classes
@@ -1624,7 +1772,7 @@ def flat_structure(dataset):
}
-class _GeneratorDataset(Dataset):
+class _GeneratorDataset(DatasetSource):
"""A `Dataset` that generates elements by invoking a function."""
def __init__(self, init_args, init_func, next_func, finalize_func):
@@ -1725,6 +1873,9 @@ class ZipDataset(Dataset):
**flat_structure(self))
# pylint: enable=protected-access
+ def _inputs(self):
+ return nest.flatten(self._datasets)
+
@property
def output_classes(self):
return nest.pack_sequence_as(
@@ -1760,6 +1911,7 @@ class ConcatenateDataset(Dataset):
raise TypeError(
"Two datasets to concatenate have different classes %s and %s" %
(input_dataset.output_classes, dataset_to_concatenate.output_classes))
+ self._input_datasets = [input_dataset, dataset_to_concatenate]
def _as_variant_tensor(self):
# pylint: disable=protected-access
@@ -1769,6 +1921,9 @@ class ConcatenateDataset(Dataset):
**flat_structure(self))
# pylint: enable=protected-access
+ def _inputs(self):
+ return [self._input_dataset, self._dataset_to_concatenate]
+
@property
def output_classes(self):
return self._input_dataset.output_classes
@@ -1787,12 +1942,12 @@ class ConcatenateDataset(Dataset):
return self._input_dataset.output_types
-class RepeatDataset(Dataset):
+class RepeatDataset(UnaryDataset):
"""A `Dataset` that repeats its input several times."""
def __init__(self, input_dataset, count):
"""See `Dataset.repeat()` for details."""
- super(RepeatDataset, self).__init__()
+ super(RepeatDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if count is None:
self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
@@ -1819,7 +1974,7 @@ class RepeatDataset(Dataset):
return self._input_dataset.output_types
-class RangeDataset(Dataset):
+class RangeDataset(DatasetSource):
"""A `Dataset` of a step separated range of values."""
def __init__(self, *args):
@@ -1867,12 +2022,12 @@ class RangeDataset(Dataset):
return dtypes.int64
-class CacheDataset(Dataset):
+class CacheDataset(UnaryDataset):
"""A `Dataset` that caches elements of its input."""
def __init__(self, input_dataset, filename):
"""See `Dataset.cache()` for details."""
- super(CacheDataset, self).__init__()
+ super(CacheDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._filename = ops.convert_to_tensor(
filename, dtype=dtypes.string, name="filename")
@@ -1896,7 +2051,7 @@ class CacheDataset(Dataset):
return self._input_dataset.output_types
-class ShuffleDataset(Dataset):
+class ShuffleDataset(UnaryDataset):
"""A `Dataset` that randomly shuffles the elements of its input."""
def __init__(self,
@@ -1924,7 +2079,7 @@ class ShuffleDataset(Dataset):
Raises:
ValueError: if invalid arguments are provided.
"""
- super(ShuffleDataset, self).__init__()
+ super(ShuffleDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
@@ -1956,12 +2111,12 @@ class ShuffleDataset(Dataset):
return self._input_dataset.output_types
-class TakeDataset(Dataset):
+class TakeDataset(UnaryDataset):
"""A `Dataset` containing the first `count` elements from its input."""
def __init__(self, input_dataset, count):
"""See `Dataset.take()` for details."""
- super(TakeDataset, self).__init__()
+ super(TakeDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
@@ -1984,12 +2139,12 @@ class TakeDataset(Dataset):
return self._input_dataset.output_types
-class SkipDataset(Dataset):
+class SkipDataset(UnaryDataset):
"""A `Dataset` skipping the first `count` elements from its input."""
def __init__(self, input_dataset, count):
"""See `Dataset.skip()` for details."""
- super(SkipDataset, self).__init__()
+ super(SkipDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
@@ -2012,12 +2167,12 @@ class SkipDataset(Dataset):
return self._input_dataset.output_types
-class BatchDataset(Dataset):
+class BatchDataset(UnaryDataset):
"""A `Dataset` that batches contiguous elements from its input."""
def __init__(self, input_dataset, batch_size, drop_remainder):
"""See `Dataset.batch()` for details."""
- super(BatchDataset, self).__init__()
+ super(BatchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
@@ -2166,13 +2321,13 @@ def _default_padding(input_dataset):
return nest.map_structure(make_zero, input_dataset.output_types)
-class PaddedBatchDataset(Dataset):
+class PaddedBatchDataset(UnaryDataset):
"""A `Dataset` that batches and pads contiguous elements from its input."""
def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
drop_remainder):
"""See `Dataset.batch()` for details."""
- super(PaddedBatchDataset, self).__init__()
+ super(PaddedBatchDataset, self).__init__(input_dataset)
if sparse.any_sparse(input_dataset.output_classes):
# TODO(b/63669786): support batching of sparse tensors
raise TypeError(
@@ -2272,12 +2427,12 @@ def _warn_if_collections(transformation_name):
% transformation_name)
-class MapDataset(Dataset):
+class MapDataset(UnaryDataset):
"""A `Dataset` that maps a function over elements in its input."""
def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
"""See `Dataset.map()` for details."""
- super(MapDataset, self).__init__()
+ super(MapDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._use_inter_op_parallelism = use_inter_op_parallelism
@@ -2338,12 +2493,12 @@ class ParallelMapDataset(MapDataset):
# pylint: enable=protected-access
-class FlatMapDataset(Dataset):
+class FlatMapDataset(UnaryDataset):
"""A `Dataset` that maps a function over its input and flattens the result."""
def __init__(self, input_dataset, map_func):
"""See `Dataset.flat_map()` for details."""
- super(FlatMapDataset, self).__init__()
+ super(FlatMapDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = StructuredFunctionWrapper(
@@ -2434,12 +2589,12 @@ class ParallelInterleaveDataset(FlatMapDataset):
return "Dataset.interleave()"
-class FilterDataset(Dataset):
+class FilterDataset(UnaryDataset):
"""A `Dataset` that filters its input according to a predicate function."""
def __init__(self, input_dataset, predicate):
"""See `Dataset.filter()` for details."""
- super(FilterDataset, self).__init__()
+ super(FilterDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = StructuredFunctionWrapper(
predicate, "Dataset.filter()", input_dataset)
@@ -2469,12 +2624,12 @@ class FilterDataset(Dataset):
return self._input_dataset.output_types
-class PrefetchDataset(Dataset):
+class PrefetchDataset(UnaryDataset):
"""A `Dataset` that asynchronously prefetches its input."""
def __init__(self, input_dataset, buffer_size):
"""See `Dataset.prefetch()` for details."""
- super(PrefetchDataset, self).__init__()
+ super(PrefetchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if buffer_size is None:
buffer_size = -1 # This is the sentinel for auto-tuning.
@@ -2500,12 +2655,12 @@ class PrefetchDataset(Dataset):
return self._input_dataset.output_types
-class WindowDataset(Dataset):
+class WindowDataset(UnaryDataset):
"""A dataset that creates window datasets from the input elements."""
def __init__(self, input_dataset, size, shift, stride, drop_remainder):
"""See `window_dataset()` for more details."""
- super(WindowDataset, self).__init__()
+ super(WindowDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py
index c914a43956..b7d3aac206 100644
--- a/tensorflow/python/data/ops/multi_device_iterator_ops.py
+++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py
@@ -116,6 +116,10 @@ class _PerDeviceGenerator(dataset_ops.Dataset):
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
+ def _inputs(self):
+ # TODO(b/116506223): Determine which datasets should be used as inputs here.
+ return []
+
@property
def output_types(self):
return self._output_types
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index 066e09969c..b0f26631f9 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -61,6 +61,9 @@ class TextLineDataset(dataset_ops.Dataset):
return gen_dataset_ops.text_line_dataset(
self._filenames, self._compression_type, self._buffer_size)
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return ops.Tensor
@@ -105,6 +108,9 @@ class _TFRecordDataset(dataset_ops.Dataset):
return gen_dataset_ops.tf_record_dataset(
self._filenames, self._compression_type, self._buffer_size)
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return ops.Tensor
@@ -224,6 +230,9 @@ class TFRecordDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
return self._impl._as_variant_tensor() # pylint: disable=protected-access
+ def _inputs(self):
+ return self._impl._inputs() # pylint: disable=protected-access
+
@property
def output_classes(self):
return self._impl.output_classes
@@ -278,6 +287,9 @@ class FixedLengthRecordDataset(dataset_ops.Dataset):
self._filenames, self._header_bytes, self._record_bytes,
self._footer_bytes, self._buffer_size)
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return ops.Tensor
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 849d165bfa..e84482d2b2 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -18,6 +18,7 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
load("//tensorflow:tensorflow.bzl", "if_not_windows")
py_library(
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index 4630bda590..f197a9e4dc 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -599,11 +599,11 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
v_name = "simple_mul_add/v"
u_init = constant_op.constant(u_init_val, shape=[2, 2], name="u_init")
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
cls._u_line_number = line_number_above()
v_init = constant_op.constant(v_init_val, shape=[2, 1], name="v_init")
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
cls._v_line_number = line_number_above()
w = math_ops.matmul(u, v, name="simple_mul_add/matmul")
@@ -612,7 +612,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
x = math_ops.add(w, w, name="simple_mul_add/add")
cls._x_line_number = line_number_above()
- a = variables.Variable([1, 3, 3, 7], name="a")
+ a = variables.VariableV1([1, 3, 3, 7], name="a")
u.initializer.run()
v.initializer.run()
@@ -1371,7 +1371,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
# Verify the annotation of the line that creates u.
index = self._findSourceLine(out, self._u_line_number)
self.assertEqual(
- ["L%d u = variables.Variable(u_init, name=u_name)" %
+ ["L%d u = variables.VariableV1(u_init, name=u_name)" %
self._u_line_number,
" simple_mul_add/u",
" simple_mul_add/u/Assign",
@@ -1388,7 +1388,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
# Verify the annotation of the line that creates v.
index = self._findSourceLine(out, self._v_line_number)
self.assertEqual(
- ["L%d v = variables.Variable(v_init, name=v_name)" %
+ ["L%d v = variables.VariableV1(v_init, name=v_name)" %
self._v_line_number,
" simple_mul_add/v"],
out.lines[index : index + 2])
@@ -1425,7 +1425,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
# Verify the annotation of the line that creates u.
index = self._findSourceLine(out, self._u_line_number)
self.assertEqual(
- ["L%d u = variables.Variable(u_init, name=u_name)" %
+ ["L%d u = variables.VariableV1(u_init, name=u_name)" %
self._u_line_number,
" simple_mul_add/u/read:0",
" simple_mul_add/u:0"],
@@ -1447,7 +1447,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
index = self._findSourceLine(out, self._u_line_number)
self.assertEqual(
- ["L%d u = variables.Variable(u_init, name=u_name)" %
+ ["L%d u = variables.VariableV1(u_init, name=u_name)" %
self._u_line_number,
" simple_mul_add/u",
" simple_mul_add/u/Assign",
@@ -1470,7 +1470,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
index = self._findSourceLine(out, self._u_line_number)
self.assertEqual(
- ["L%d u = variables.Variable(u_init, name=u_name)" %
+ ["L%d u = variables.VariableV1(u_init, name=u_name)" %
self._u_line_number,
" simple_mul_add/u",
" (... Omitted 2 of 3 op(s) ...) +5"],
@@ -1580,7 +1580,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
"""List an input tree containing tensors from non-:0 output slot."""
with session.Session(config=no_rewrite_session_config()) as sess:
- x = variables.Variable([1, 3, 3, 7], name="x")
+ x = variables.VariableV1([1, 3, 3, 7], name="x")
_, idx = array_ops.unique(x, name="x_unique")
idx_times_two = math_ops.multiply(idx, 2, name="idx_times_two")
sess.run(x.initializer)
@@ -1684,7 +1684,7 @@ class AnalyzerCLIControlDepTest(test_util.TensorFlowTestCase):
with session.Session(config=no_rewrite_session_config()) as sess:
x_init_val = np.array([5.0, 3.0])
x_init = constant_op.constant(x_init_val, shape=[2])
- x = variables.Variable(x_init, name="control_deps/x")
+ x = variables.VariableV1(x_init, name="control_deps/x")
y = math_ops.add(x, x, name="control_deps/y")
y = control_flow_ops.with_dependencies(
diff --git a/tensorflow/python/debug/cli/stepper_cli_test.py b/tensorflow/python/debug/cli/stepper_cli_test.py
index ee8cabca0d..7b8a42c253 100644
--- a/tensorflow/python/debug/cli/stepper_cli_test.py
+++ b/tensorflow/python/debug/cli/stepper_cli_test.py
@@ -132,8 +132,8 @@ def _parse_updated(lines):
class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
def setUp(self):
- self.a = variables.Variable(10.0, name="a")
- self.b = variables.Variable(20.0, name="b")
+ self.a = variables.VariableV1(10.0, name="a")
+ self.b = variables.VariableV1(20.0, name="b")
self.c = math_ops.add(self.a, self.b, name="c") # Should be 30.0.
self.d = math_ops.subtract(self.a, self.c, name="d") # Should be -20.0.
diff --git a/tensorflow/python/debug/lib/debug_utils_test.py b/tensorflow/python/debug/lib/debug_utils_test.py
index 5b1875e092..23ab98444c 100644
--- a/tensorflow/python/debug/lib/debug_utils_test.py
+++ b/tensorflow/python/debug/lib/debug_utils_test.py
@@ -46,8 +46,8 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
cls._b_init = constant_op.constant(
cls._b_init_val, shape=[2, 1], name="b_init")
- cls._a = variables.Variable(cls._a_init, name="a1")
- cls._b = variables.Variable(cls._b_init, name="b")
+ cls._a = variables.VariableV1(cls._a_init, name="a1")
+ cls._b = variables.VariableV1(cls._b_init, name="b")
cls._c = constant_op.constant(cls._c_val, shape=[2, 1], name="c")
# Matrix product of a and b.
diff --git a/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py b/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
index 46a7be5808..74498c8ea3 100644
--- a/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
@@ -118,8 +118,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
"""
with ops.Graph().as_default() as graph:
with ops.device("/job:worker/task:0/cpu:0"):
- self.a = variables.Variable(10.0, name="a")
- self.b = variables.Variable(100.0, name="b")
+ self.a = variables.VariableV1(10.0, name="a")
+ self.b = variables.VariableV1(100.0, name="b")
self.inc_a = state_ops.assign_add(self.a, 2.0, name="inc_a")
self.dec_b = state_ops.assign_add(self.b, -5.0, name="dec_b")
self.p = math_ops.multiply(self.inc_a, self.dec_b, name="p")
diff --git a/tensorflow/python/debug/lib/grpc_large_data_test.py b/tensorflow/python/debug/lib/grpc_large_data_test.py
index 5bc477a9ba..ccc21bcf94 100644
--- a/tensorflow/python/debug/lib/grpc_large_data_test.py
+++ b/tensorflow/python/debug/lib/grpc_large_data_test.py
@@ -61,7 +61,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
with self.test_session(
use_gpu=True,
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- u = variables.Variable(42.0, name="original_u")
+ u = variables.VariableV1(42.0, name="original_u")
for _ in xrange(50 * 1000):
u = array_ops.identity(u)
sess.run(variables.global_variables_initializer())
@@ -94,7 +94,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
u_init = constant_op.constant(
u_init_val_array, dtype=dtypes.float32, name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds # Unused by this watch_fn.
@@ -117,7 +117,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""]
u_init = constant_op.constant(
u_init_val, dtype=dtypes.string, name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds
@@ -146,7 +146,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
u_init = constant_op.constant(
u_init_val_array, dtype=dtypes.string, name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds
@@ -167,7 +167,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
config=session_debug_testlib.no_rewrite_session_config()) as sess:
u_init = constant_op.constant(
[], dtype=dtypes.float32, shape=[0], name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds
@@ -189,7 +189,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
config=session_debug_testlib.no_rewrite_session_config()) as sess:
u_init = constant_op.constant(
[], dtype=dtypes.string, shape=[0], name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds
diff --git a/tensorflow/python/debug/lib/session_debug_file_test.py b/tensorflow/python/debug/lib/session_debug_file_test.py
index ba0f15b4e2..1874160dd6 100644
--- a/tensorflow/python/debug/lib/session_debug_file_test.py
+++ b/tensorflow/python/debug/lib/session_debug_file_test.py
@@ -58,9 +58,9 @@ class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase):
v_name = "diff_Watch/v"
u_init = constant_op.constant(u_init_val, shape=[2, 2])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v_init = constant_op.constant(v_init_val, shape=[2, 1])
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
w = math_ops.matmul(u, v, name="diff_Watch/matmul")
diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py
index 91f21cb1f3..bfc9a3a382 100644
--- a/tensorflow/python/debug/lib/session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py
@@ -148,8 +148,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
sess, "localhost:%d" % self._server_port, watch_fn="foo")
def testGrpcDebugWrapperSessionWithoutWatchFnWorks(self):
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -175,8 +175,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
del feeds, fetch_keys
return ["DebugIdentity", "DebugNumericSummary"], r".*/read", None
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -209,8 +209,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
op_type_regex_whitelist=None,
tolerate_debug_op_creation_failures=True)
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -241,8 +241,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))
def testTensorBoardDebugHookWorks(self):
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -286,8 +286,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
self._server.query_source_file_line(__file__, 1)
def testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks(self):
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -381,8 +381,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_1")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_1")
delta_1 = constant_op.constant(5.0, name="delta_1")
delta_2 = constant_op.constant(-5.0, name="delta_2")
inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
@@ -451,8 +451,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_1")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_1")
# These two nodes have names that match those in the
# toggle_watch_on_core_metadata argument used when calling
# start_server_on_separate_thread().
@@ -491,7 +491,7 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v = variables.Variable(50.0, name="v")
+ v = variables.VariableV1(50.0, name="v")
delta = constant_op.constant(5.0, name="delta")
inc_v = state_ops.assign_add(v, delta, name="inc_v")
@@ -534,8 +534,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testToggleBreakpointsWorks(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_2")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_2")
delta_1 = constant_op.constant(5.0, name="delta_1")
delta_2 = constant_op.constant(-5.0, name="delta_2")
inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
@@ -592,8 +592,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_2")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_2")
delta_1 = constant_op.constant(5.0, name="delta_1")
delta_2 = constant_op.constant(-5.0, name="delta_2")
inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
@@ -665,8 +665,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_2")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_2")
delta_1 = constant_op.constant(5.0, name="delta_1")
delta_2 = constant_op.constant(-5.0, name="delta_2")
inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
@@ -699,7 +699,7 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
with session.Session() as sess:
- v = variables.Variable(50.0, name="v")
+ v = variables.VariableV1(50.0, name="v")
delta = constant_op.constant(5.0, name="delta")
inc_v = state_ops.assign_add(v, delta, name="inc_v")
@@ -743,7 +743,7 @@ class DelayedDebugServerTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess:
a_init = constant_op.constant(42.0, name="a_init")
- a = variables.Variable(a_init, name="a")
+ a = variables.VariableV1(a_init, name="a")
def watch_fn(fetches, feeds):
del fetches, feeds
diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py
index 070d9c4cd7..25ef91b575 100644
--- a/tensorflow/python/debug/lib/session_debug_testlib.py
+++ b/tensorflow/python/debug/lib/session_debug_testlib.py
@@ -70,7 +70,7 @@ class _RNNCellForTest(rnn_cell_impl.RNNCell):
def __init__(self, input_output_size, state_size):
self._input_output_size = input_output_size
self._state_size = state_size
- self._w = variables.Variable(1.0, dtype=dtypes.float32, name="w")
+ self._w = variables.VariableV1(1.0, dtype=dtypes.float32, name="w")
@property
def output_size(self):
@@ -182,9 +182,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
w_name = "w"
u_init = constant_op.constant(u_init_val, shape=[2, 2])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v_init = constant_op.constant(v_init_val, shape=[2, 1])
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
w = math_ops.matmul(u, v, name=w_name)
@@ -221,8 +221,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testCopyNodesHaveCorrectDebugOpsAndURLsAttributeValues(self):
with session.Session() as sess:
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess.run(variables.global_variables_initializer())
@@ -324,8 +324,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
str1_name = "str1"
str2_name = "str2"
- str1 = variables.Variable(str1_init, name=str1_name)
- str2 = variables.Variable(str2_init, name=str2_name)
+ str1 = variables.VariableV1(str1_init, name=str1_name)
+ str2 = variables.VariableV1(str2_init, name=str2_name)
# Concatenate str1 and str2
str_concat = math_ops.add(str1, str2, name="str_concat")
@@ -387,9 +387,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
s_name = "%s/s" % op_namespace
u_init = constant_op.constant(u_init_val, shape=[2, 2])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
s_init = constant_op.constant(s_init_val)
- s = variables.Variable(s_init, name=s_name)
+ s = variables.VariableV1(s_init, name=s_name)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_urls = self._debug_urls()
@@ -439,7 +439,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
u_init_val = np.array(11.0)
u_init = constant_op.constant(u_init_val)
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
# "v" is the increment.
v_name = "testDumpToFileWhileLoop/v"
@@ -447,7 +447,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
v_init_val = np.array(2.0)
v_init = constant_op.constant(v_init_val)
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
u.initializer.run()
v.initializer.run()
@@ -605,8 +605,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugCondWatchingWholeGraphWorks(self):
with session.Session() as sess:
- x = variables.Variable(10.0, name="x")
- y = variables.Variable(20.0, name="y")
+ x = variables.VariableV1(10.0, name="x")
+ y = variables.VariableV1(20.0, name="y")
cond = control_flow_ops.cond(
x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
@@ -628,9 +628,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
z_name = "testFindNodesWithBadTensorValues/z"
u_init = constant_op.constant([2.0, 4.0])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v_init = constant_op.constant([2.0, 1.0])
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
# Expected output: [0.0, 3.0]
w = math_ops.subtract(u, v, name=w_name)
@@ -679,9 +679,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
z_name = "testFindInfOrNanWithOpNameExclusion/z"
u_init = constant_op.constant([2.0, 4.0])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v_init = constant_op.constant([2.0, 1.0])
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
# Expected output: [0.0, 3.0]
w = math_ops.subtract(u, v, name=w_name)
@@ -725,7 +725,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
w_name = "testDumpGraphStructureLookup/w"
u_init = constant_op.constant([2.0, 4.0])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v = math_ops.add(u, u, name=v_name)
w = math_ops.add(v, v, name=w_name)
@@ -859,9 +859,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testGraphPathFindingOnControlEdgesWorks(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- v1 = variables.Variable(1.0, name="v1")
- v2 = variables.Variable(2.0, name="v2")
- v3 = variables.Variable(3.0, name="v3")
+ v1 = variables.VariableV1(1.0, name="v1")
+ v2 = variables.VariableV1(2.0, name="v2")
+ v3 = variables.VariableV1(3.0, name="v3")
a = math_ops.add(v1, v2, name="a")
with ops.control_dependencies([a]):
c = math_ops.subtract(v3, v3, name="c")
@@ -875,8 +875,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testGraphPathFindingReverseRefEdgeWorks(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- v = variables.Variable(10.0, name="v")
- delta = variables.Variable(1.0, name="delta")
+ v = variables.VariableV1(10.0, name="v")
+ delta = variables.VariableV1(1.0, name="delta")
inc_v = state_ops.assign_add(v, delta, name="inc_v")
sess.run(variables.global_variables_initializer())
@@ -894,7 +894,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
w_name = "testDumpCausalityCheck/w"
u_init = constant_op.constant([2.0, 4.0])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v = math_ops.add(u, u, name=v_name)
w = math_ops.add(v, v, name=w_name)
@@ -980,7 +980,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
w_name = "oneOfTwoSlots/w"
y_name = "oneOfTwoSlots/y"
- x = variables.Variable([1, 3, 3, 7], dtype=dtypes.int32, name=x_name)
+ x = variables.VariableV1([1, 3, 3, 7], dtype=dtypes.int32, name=x_name)
sess.run(x.initializer)
unique_x, indices, _ = array_ops.unique_with_counts(x, name=u_name)
@@ -1039,9 +1039,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
with session.Session(config=no_rewrite_session_config()) as sess:
u_init = constant_op.constant(10.0)
- u = variables.Variable(u_init, name="gdo/u")
+ u = variables.VariableV1(u_init, name="gdo/u")
v_init = constant_op.constant(20.0)
- v = variables.Variable(v_init, name="gdo/v")
+ v = variables.VariableV1(v_init, name="gdo/v")
w = math_ops.multiply(u, v, name="gdo/w")
# gdo stands for GradientDescentOptimizer.
@@ -1085,7 +1085,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
with session.Session() as sess:
x_init = constant_op.constant([2, 2, 3, 5, 5])
- x = variables.Variable(x_init, name="unconnected/x")
+ x = variables.VariableV1(x_init, name="unconnected/x")
# The UniqueOp (tf.unique) has two output slots. Use only slot 0 in the
# graph. Let the debugger watch the unused slot 1.
@@ -1225,14 +1225,14 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryOnInitializedTensorGivesCorrectResult(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- a = variables.Variable(
+ a = variables.VariableV1(
[
np.nan, np.nan, 0.0, 0.0, 0.0, -1.0, -3.0, 3.0, 7.0, -np.inf,
-np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.nan, np.nan
],
dtype=np.float32,
name="numeric_summary/a")
- b = variables.Variable(
+ b = variables.VariableV1(
[0.0] * 18, dtype=np.float32, name="numeric_summary/b")
c = math_ops.add(a, b, name="numeric_summary/c")
@@ -1249,7 +1249,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryOnUninitializedTensorGivesCorrectResult(self):
with session.Session() as sess:
- a = variables.Variable(
+ a = variables.VariableV1(
[42], dtype=np.float32, name="numeric_summary_uninit/a")
_, dump = self._debug_run_and_get_dump(
@@ -1275,9 +1275,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryFailureIsToleratedWhenOrdered(self):
with session.Session() as sess:
- a = variables.Variable("1", name="a")
- b = variables.Variable("3", name="b")
- c = variables.Variable("2", name="c")
+ a = variables.VariableV1("1", name="a")
+ b = variables.VariableV1("3", name="b")
+ c = variables.VariableV1("2", name="c")
d = math_ops.add(a, b, name="d")
e = math_ops.add(d, c, name="e")
@@ -1313,9 +1313,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryInvalidAttributesStringAreCaught(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- a = variables.Variable(10.0, name="a")
- b = variables.Variable(0.0, name="b")
- c = variables.Variable(0.0, name="c")
+ a = variables.VariableV1(10.0, name="a")
+ b = variables.VariableV1(0.0, name="b")
+ c = variables.VariableV1(0.0, name="c")
x = math_ops.divide(a, b, name="x")
y = math_ops.multiply(x, c, name="y")
@@ -1361,9 +1361,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryMuteOnHealthyMutesOnlyHealthyTensorDumps(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- a = variables.Variable(10.0, name="a")
- b = variables.Variable(0.0, name="b")
- c = variables.Variable(0.0, name="c")
+ a = variables.VariableV1(10.0, name="a")
+ b = variables.VariableV1(0.0, name="b")
+ c = variables.VariableV1(0.0, name="c")
x = math_ops.divide(a, b, name="x")
y = math_ops.multiply(x, c, name="y")
@@ -1396,8 +1396,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryMuteOnHealthyAndCustomBoundsWork(self):
with session.Session() as sess:
- a = variables.Variable([10.0, 10.0], name="a")
- b = variables.Variable([10.0, 2.0], name="b")
+ a = variables.VariableV1([10.0, 10.0], name="a")
+ b = variables.VariableV1([10.0, 2.0], name="b")
x = math_ops.add(a, b, name="x") # [20.0, 12.0]
y = math_ops.divide(x, b, name="y") # [2.0, 6.0]
@@ -1436,9 +1436,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testLookUpNodePythonTracebackWorks(self):
with session.Session() as sess:
u_init = constant_op.constant(10.0)
- u = variables.Variable(u_init, name="traceback/u")
+ u = variables.VariableV1(u_init, name="traceback/u")
v_init = constant_op.constant(20.0)
- v = variables.Variable(v_init, name="traceback/v")
+ v = variables.VariableV1(v_init, name="traceback/v")
w = math_ops.multiply(u, v, name="traceback/w")
@@ -1487,7 +1487,7 @@ class DebugConcurrentRunCallsTest(test_util.TensorFlowTestCase):
self.skipTest("No testing concurrent runs on a single GPU.")
with session.Session() as sess:
- v = variables.Variable(30.0, name="v")
+ v = variables.VariableV1(30.0, name="v")
constants = []
for i in xrange(self._num_concurrent_runs):
constants.append(constant_op.constant(1.0, name="c%d" % i))
diff --git a/tensorflow/python/debug/lib/stepper_test.py b/tensorflow/python/debug/lib/stepper_test.py
index 9a3d0efabf..3839c67198 100644
--- a/tensorflow/python/debug/lib/stepper_test.py
+++ b/tensorflow/python/debug/lib/stepper_test.py
@@ -36,8 +36,8 @@ from tensorflow.python.training import gradient_descent
class StepperTest(test_util.TensorFlowTestCase):
def setUp(self):
- self.a = variables.Variable(2.0, name="a")
- self.b = variables.Variable(3.0, name="b")
+ self.a = variables.VariableV1(2.0, name="a")
+ self.b = variables.VariableV1(3.0, name="b")
self.c = math_ops.multiply(self.a, self.b, name="c") # Should be 6.0.
self.d = math_ops.multiply(self.a, self.a, name="d") # Should be 4.0.
@@ -49,7 +49,7 @@ class StepperTest(test_util.TensorFlowTestCase):
# The there nodes x, y and z form a graph with "cross-links" in. I.e., x
# and y are both direct inputs to z, but x is also a direct input to y.
- self.x = variables.Variable(2.0, name="x") # Should be 2.0
+ self.x = variables.VariableV1(2.0, name="x") # Should be 2.0
self.y = math_ops.negative(self.x, name="y") # Should be -2.0.
self.z = math_ops.multiply(self.x, self.y, name="z") # Should be -4.0.
@@ -580,7 +580,7 @@ class StepperTestWithPlaceHolders(test_util.TensorFlowTestCase):
class StepperAssignAddTest(test_util.TensorFlowTestCase):
def setUp(self):
- self.v = variables.Variable(10.0, name="v")
+ self.v = variables.VariableV1(10.0, name="v")
self.p = math_ops.add(self.v, self.v, name="p")
self.q = math_ops.multiply(self.p, self.p, name="q")
self.delta = constant_op.constant(2.0, name="delta")
@@ -711,9 +711,9 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
Construct a backward graph using the GradientDescentOptimizer.
"""
- self.a = variables.Variable(1.0, name="a")
- self.b = variables.Variable(2.0, name="b")
- self.c = variables.Variable(4.0, name="c")
+ self.a = variables.VariableV1(1.0, name="a")
+ self.b = variables.VariableV1(2.0, name="b")
+ self.c = variables.VariableV1(4.0, name="c")
self.d = math_ops.multiply(self.a, self.b, name="d")
self.e = math_ops.multiply(self.b, self.c, name="e")
self.f = math_ops.multiply(self.d, self.e, name="f")
diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
index 254201c393..11011a5c13 100644
--- a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
@@ -46,7 +46,7 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def setUp(self):
self.session_root = tempfile.mkdtemp()
- self.v = variables.Variable(10.0, dtype=dtypes.float32, name="v")
+ self.v = variables.VariableV1(10.0, dtype=dtypes.float32, name="v")
self.delta = constant_op.constant(1.0, dtype=dtypes.float32, name="delta")
self.eta = constant_op.constant(-1.4, dtype=dtypes.float32, name="eta")
self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
index 05c9eaa4d2..149a7497df 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
@@ -132,8 +132,8 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def setUp(self):
self._tmp_dir = tempfile.mktemp()
- self.v = variables.Variable(10.0, name="v")
- self.w = variables.Variable(21.0, name="w")
+ self.v = variables.VariableV1(10.0, name="v")
+ self.w = variables.VariableV1(21.0, name="w")
self.delta = constant_op.constant(1.0, name="delta")
self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
@@ -358,7 +358,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def testDebuggingMakeCallableTensorRunnerWorks(self):
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
- v = variables.Variable(42)
+ v = variables.VariableV1(42)
tensor_runner = wrapped_sess.make_callable(v)
self.sess.run(v.initializer)
@@ -382,7 +382,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def testDebuggingMakeCallableOperationRunnerWorks(self):
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
- v = variables.Variable(10.0)
+ v = variables.VariableV1(10.0)
inc_v = state_ops.assign_add(v, 1.0)
op_runner = wrapped_sess.make_callable(inc_v.op)
self.sess.run(v.initializer)
@@ -403,7 +403,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
def testDebuggingMakeCallableFromOptionsWithZeroFeedWorks(self):
- variable_1 = variables.Variable(
+ variable_1 = variables.VariableV1(
10.5, dtype=dtypes.float32, name="variable_1")
a = math_ops.add(variable_1, variable_1, "callable_a")
math_ops.add(a, a, "callable_b")
@@ -480,7 +480,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertItemsEqual(["callable_a", "callable_b"], node_names)
def testDebugMakeCallableFromOptionsWithCustomOptionsAndMetadataWorks(self):
- variable_1 = variables.Variable(
+ variable_1 = variables.VariableV1(
10.5, dtype=dtypes.float32, name="variable_1")
a = math_ops.add(variable_1, variable_1, "callable_a")
math_ops.add(a, a, "callable_b")
@@ -528,7 +528,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def testRuntimeErrorBeforeGraphExecutionIsRaised(self):
# Use an impossible device name to cause an error before graph execution.
with ops.device("/device:GPU:1337"):
- w = variables.Variable([1.0] * 10, name="w")
+ w = variables.VariableV1([1.0] * 10, name="w")
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"]], self.sess, dump_root=self._tmp_dir)
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index bd3562f1ff..b9b77d4a5b 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -126,7 +126,7 @@ class _WorkerContext(object):
replicated training.
task_id: an integer indicating id of the corresponding task. It can be
None if it is local training or in-graph replicated training.
- session_config: an optional @{tf.ConfigProto} object.
+ session_config: an optional `tf.ConfigProto` object.
rpc_layer: optional string specifying the RPC protocol for communication
with worker masters. If None or empty, hosts in the `cluster_spec` will
be used directly.
@@ -685,7 +685,7 @@ def run_distribute_coordinator(worker_fn,
in a cluster. If not set or empty, fall back to local training.
task_type: the current task type, optional if this is a client.
task_id: the current task id, optional if this is a client.
- session_config: an optional @{tf.ConfigProto} object which will be passed
+ session_config: an optional `tf.ConfigProto` object which will be passed
to `strategy`'s `configure` method and used to create a session.
rpc_layer: optional string, the protocol for RPC, e.g. "grpc".
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py
index 8daa34c885..0289689134 100644
--- a/tensorflow/python/distribute/estimator_training.py
+++ b/tensorflow/python/distribute/estimator_training.py
@@ -62,7 +62,7 @@ def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
# Sort task names in cluster by "chief"/"master", "evaluator", "worker"
# and "ps". More details can be found at the documentation of
- # @{tf.estimator.RunConfig.global_id_in_cluster}.
+ # `tf.estimator.RunConfig.global_id_in_cluster`.
task_type_ordered_list = []
if chief_task_type in cluster_spec.jobs:
task_type_ordered_list = [chief_task_type]
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 78f3198011..deac29111f 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -619,7 +619,7 @@ pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
def _handle_or_self(x):
"""If x is ResourceVariable, return its handle, else x."""
- if isinstance(x, resource_variable_ops.ResourceVariable):
+ if resource_variable_ops.is_resource_variable(x):
x = x.handle
return x
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index b28befeb62..dd3e1a3723 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1328,8 +1328,25 @@ def register(func, *args, **kwargs):
"Got type: %s" % type(func))
concrete_func = func.get_concrete_function(*args, **kwargs)
graph = ops.get_default_graph()
- concrete_func._inference_function.add_to_graph(graph) # pylint: disable=protected-access
- # TODO(scottzhu): support concrete_func._backward_graph_function in future.
+
+ # There are two situations for the actual call of a defun:
+ # 1. If none of the input args are resource variables or watch by any tape,
+ # it will run the _inference_function of concrete_func for forward pass, and
+ # the gradient will be generated by standard mechanism.
+ # 2. Otherwise, defun will create two functions, one for forward pass, and the
+ # backward pass will be created via tape.
+ # When registering the function, we put both cases into graph.
+ # pylint: disable=protected-access
+ concrete_func._inference_function.add_to_graph(graph)
+
+ if concrete_func._backward_graph_function is None:
+ concrete_func._construct_backprop_function()
+ forward_function = concrete_func._forward_function
+ backward_function = concrete_func._backward_graph_function._inference_function
+ forward_function.add_to_graph(graph)
+ backward_function.add_to_graph(graph)
+ # pylint: enable=protected-access
+
return concrete_func
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 59faf967c5..34a2648e26 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1669,12 +1669,23 @@ class FunctionTest(test.TestCase):
graph = ops.get_default_graph()
# pylint: disable=protected-access
- self.assertEqual(len(graph._functions), 2)
+ self.assertEqual(len(graph._functions), 6)
+ # two sets of functions, each of them are (inference, forward, backward)
functions = list(graph._functions.values())
- pre_register_matmul_func_name = functions[0].definition.signature.name
- self.assertRegexpMatches(pre_register_matmul_func_name, '.*matmul.*')
- pre_register_add_func_name = functions[1].definition.signature.name
- self.assertRegexpMatches(pre_register_add_func_name, '.*add.*')
+ captured_function_names = [
+ f.definition.signature.name for f in functions
+ ]
+ expected_func_name_regex = [
+ '.*inference.*matmul.*',
+ '.*forward.*matmul.*',
+ '.*inference.*backward.*matmul.*',
+ '.*inference.*add.*',
+ '.*forward.*add.*',
+ '.*inference.*backward.*add.*',
+ ]
+ for i in range(len(functions)):
+ self.assertRegexpMatches(captured_function_names[i],
+ expected_func_name_regex[i])
sq = defun_matmul(t, t)
double = add(t, t)
@@ -1682,12 +1693,11 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
# Make sure the pre registered function is used, and no other function
# is added.
- self.assertEqual(len(graph._functions), 2)
+ self.assertEqual(len(graph._functions), 6)
functions = list(graph._functions.values())
- called_func_name = functions[0].definition.signature.name
- self.assertEqual(pre_register_matmul_func_name, called_func_name)
- called_func_name = functions[1].definition.signature.name
- self.assertEqual(pre_register_add_func_name, called_func_name)
+ for i in range(len(functions)):
+ self.assertEquals(captured_function_names[i],
+ functions[i].definition.signature.name)
def testRegisterFunctionWithInputSignature(self):
def matmul(x, y):
@@ -1705,7 +1715,7 @@ class FunctionTest(test.TestCase):
graph = ops.get_default_graph()
# pylint: disable=protected-access
- self.assertEqual(len(graph._functions), 1)
+ self.assertEqual(len(graph._functions), 3)
# Test input param shape mismatch
t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -1728,7 +1738,7 @@ class FunctionTest(test.TestCase):
graph = ops.get_default_graph()
# Only one function is registered since the input param are in same type
# pylint: disable=protected-access
- self.assertEqual(len(graph._functions), 1)
+ self.assertEqual(len(graph._functions), 3)
def testCallingFunctionWithDifferentVariables(self):
@@ -1767,7 +1777,8 @@ class FunctionTest(test.TestCase):
'be Tensors;.*'):
graph_function('Not a Tensor.')
- def testSwapImplementationWithGrapplerPlugin(self):
+ # TODO(scottzhu): Revive the test once the grappler plugin is updated.
+ def disabled_testSwapImplementationWithGrapplerPlugin(self):
rewrites = rewriter_config_pb2.RewriterConfig()
# function_optimizer has to be turn off, otherwise it will delete the
# registered function if it does not get called.
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 7f2349954d..1c4c5951df 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -281,6 +281,7 @@ py_library(
":prediction_keys",
"//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
@@ -303,6 +304,7 @@ py_test(
":pandas_io",
":prediction_keys",
"//tensorflow:tensorflow_py_no_contrib",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
@@ -342,6 +344,7 @@ py_test(
":pandas_io",
":prediction_keys",
"//tensorflow:tensorflow_py_no_contrib",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 756d32d03f..0278990cfc 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -21,6 +21,9 @@ import abc
import collections
import functools
+import numpy as np
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.canned import boosted_trees_utils
@@ -40,6 +43,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.array_ops import identity as tf_identity
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import estimator_export
@@ -193,6 +197,43 @@ def _calculate_num_features(sorted_feature_columns):
return num_features
+def _generate_feature_name_mapping(sorted_feature_columns):
+ """Return a list of feature name for feature ids.
+
+ Args:
+ sorted_feature_columns: a list/set of tf.feature_column sorted by name.
+
+ Returns:
+ feature_name_mapping: a list of feature names indexed by the feature ids.
+
+ Raises:
+ ValueError: when unsupported features/columns are tried.
+ """
+ names = []
+ for column in sorted_feature_columns:
+ if isinstance(column, feature_column_lib._IndicatorColumn): # pylint:disable=protected-access
+ categorical_column = column.categorical_column
+ if isinstance(categorical_column,
+ feature_column_lib._VocabularyListCategoricalColumn): # pylint:disable=protected-access
+ for value in categorical_column.vocabulary_list:
+ names.append('{}:{}'.format(column.name, value))
+ elif isinstance(categorical_column,
+ feature_column_lib._BucketizedColumn): # pylint:disable=protected-access
+ boundaries = [-np.inf] + list(categorical_column.boundaries) + [np.inf]
+ for pair in zip(boundaries[:-1], boundaries[1:]):
+ names.append('{}:{}'.format(column.name, pair))
+ else:
+ for num in range(categorical_column._num_buckets): # pylint:disable=protected-access
+ names.append('{}:{}'.format(column.name, num))
+ elif isinstance(column, feature_column_lib._BucketizedColumn):
+ names.append(column.name)
+ else:
+ raise ValueError(
+ 'For now, only bucketized_column and indicator_column is supported '
+ 'but got: {}'.format(column))
+ return names
+
+
def _cache_transformed_features(features, sorted_feature_columns, batch_size):
"""Transform features and cache, then returns (cached_features, cache_op)."""
num_features = _calculate_num_features(sorted_feature_columns)
@@ -966,6 +1007,60 @@ def _create_regression_head(label_dimension, weight_column=None):
# pylint: enable=protected-access
+def _compute_feature_importances_per_tree(tree, num_features):
+ """Computes the importance of each feature in the tree."""
+ importances = np.zeros(num_features)
+
+ for node in tree.nodes:
+ node_type = node.WhichOneof('node')
+ if node_type == 'bucketized_split':
+ feature_id = node.bucketized_split.feature_id
+ importances[feature_id] += node.metadata.gain
+ elif node_type == 'leaf':
+ assert node.metadata.gain == 0
+ else:
+ raise ValueError('Unexpected split type %s', node_type)
+
+ return importances
+
+
+def _compute_feature_importances(tree_ensemble, num_features, normalize):
+ """Computes gain-based feature importances.
+
+ The higher the value, the more important the feature.
+
+ Args:
+ tree_ensemble: a trained tree ensemble, instance of proto
+ boosted_trees.TreeEnsemble.
+ num_features: The total number of feature ids.
+ normalize: If True, normalize the feature importances.
+
+ Returns:
+ sorted_feature_idx: A list of feature_id which is sorted
+ by its feature importance.
+ feature_importances: A list of corresponding feature importances.
+
+ Raises:
+ AssertionError: When normalize = True, if feature importances
+ contain negative value, or if normalization is not possible
+ (e.g. ensemble is empty or trees contain only a root node).
+ """
+ tree_importances = [_compute_feature_importances_per_tree(tree, num_features)
+ for tree in tree_ensemble.trees]
+ tree_importances = np.array(tree_importances)
+ tree_weights = np.array(tree_ensemble.tree_weights).reshape(-1, 1)
+ feature_importances = np.sum(tree_importances * tree_weights, axis=0)
+ if normalize:
+ assert np.all(feature_importances >= 0), ('feature_importances '
+ 'must be non-negative.')
+ normalizer = np.sum(feature_importances)
+ assert normalizer > 0, 'Trees are all empty or contain only a root node.'
+ feature_importances /= normalizer
+
+ sorted_feature_idx = np.argsort(feature_importances)[::-1]
+ return sorted_feature_idx, feature_importances[sorted_feature_idx]
+
+
def _bt_explanations_fn(features,
head,
sorted_feature_columns,
@@ -1053,9 +1148,41 @@ class _BoostedTreesBase(estimator.Estimator):
feature_columns, key=lambda tc: tc.name)
self._head = head
self._n_features = _calculate_num_features(self._sorted_feature_columns)
+ self._names_for_feature_id = np.array(
+ _generate_feature_name_mapping(self._sorted_feature_columns))
self._center_bias = center_bias
self._is_classification = is_classification
+ def experimental_feature_importances(self, normalize=False):
+ """Computes gain-based feature importances.
+
+ The higher the value, the more important the corresponding feature.
+
+ Args:
+ normalize: If True, normalize the feature importances.
+
+ Returns:
+ sorted_feature_names: 1-D array of feature name which is sorted
+ by its feature importance.
+ feature_importances: 1-D array of the corresponding feature importance.
+
+ Raises:
+ ValueError: When attempting to normalize on an empty ensemble
+ or an ensemble of trees which have no splits. Or when attempting
+ to normalize and feature importances have negative values.
+ """
+ reader = checkpoint_utils.load_checkpoint(self._model_dir)
+ serialized = reader.get_tensor('boosted_trees:0_serialized')
+ if not serialized:
+ raise ValueError('Found empty serialized string for TreeEnsemble.'
+ 'You should only call this method after training.')
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+
+ sorted_feature_id, importances = _compute_feature_importances(
+ ensemble_proto, self._n_features, normalize)
+ return self._names_for_feature_id[sorted_feature_id], importances
+
def experimental_predict_with_explanations(self,
input_fn,
predict_keys=None,
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index d4cb3e27d0..23687a738b 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -17,9 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
+from google.protobuf import text_format
import numpy as np
from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator import run_config
@@ -31,10 +35,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_boosted_trees_ops
+from tensorflow.python.ops import boosted_trees_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
NUM_FEATURES = 3
@@ -564,6 +570,535 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id)
self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold)
+ def testFeatureImportancesWithTrainedEnsemble(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+ # It will stop after 5 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+
+ feature_names_expected = ['f_0_bucketized',
+ 'f_2_bucketized',
+ 'f_1_bucketized']
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.833933, 0.606342, 0.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.579010, 0.420990, 0.0], importances)
+
+ def testFeatureImportancesOnEmptyEnsemble(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+
+ class BailOutWithoutTraining(session_run_hook.SessionRunHook):
+
+ def before_run(self, run_context):
+ raise StopIteration('to bail out.')
+
+ # The step-0 checkpoint will have only an empty ensemble.
+ est.train(input_fn,
+ steps=100, # must stop at 0 anyway.
+ hooks=[BailOutWithoutTraining()])
+
+ with self.assertRaisesRegexp(ValueError, 'empty serialized string'):
+ est.experimental_feature_importances(normalize=False)
+
+ with self.assertRaisesRegexp(ValueError, 'empty serialized string'):
+ est.experimental_feature_importances(normalize=True)
+
+ def _create_fake_checkpoint_with_tree_ensemble_proto(self,
+ est,
+ tree_ensemble_text):
+ with ops.Graph().as_default():
+ with ops.name_scope('boosted_trees') as name:
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+ tree_ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(tree_ensemble_text, tree_ensemble_proto)
+ stamp_token, _ = tree_ensemble.serialize()
+ restore_op = tree_ensemble.deserialize(
+ stamp_token, tree_ensemble_proto.SerializeToString())
+
+ with session.Session() as sess:
+ resources.initialize_resources(resources.shared_resources()).run()
+ restore_op.run()
+ saver = saver_lib.Saver()
+ save_path = os.path.join(est.model_dir, 'model.ckpt')
+ saver.save(sess, save_path)
+
+ def testFeatureImportancesOnNonEmptyEnsemble(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 3.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 7
+ right_id: 8
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ feature_names_expected = ['f_0_bucketized',
+ 'f_2_bucketized',
+ 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ # Gain sum for each features:
+ # = 1.0 * [3 + 1, 2, 2] + 1.0 * [1, 1, 0]
+ self.assertAllClose([5.0, 3.0, 2.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.5, 0.3, 0.2], importances)
+
+ def testFeatureImportancesWithTreeWeights(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=3,
+ max_depth=5)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 12.5
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 0.4
+ tree_weights: 0.6
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ feature_names_expected = ['f_0_bucketized',
+ 'f_2_bucketized',
+ 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ # Gain sum for each features:
+ # = 0.4 * [12.5, 0, 5] + 0.6 * [0, 5, 0] + 1.0 * [0, 0, 0]
+ self.assertAllClose([5.0, 3.0, 2.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.5, 0.3, 0.2], importances)
+
+ def testFeatureImportancesWithAllEmptyTree(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ # Reverse order because feature importances are sorted by np.argsort(f)[::-1]
+ feature_names_expected = ['f_2_bucketized',
+ 'f_1_bucketized',
+ 'f_0_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.0, 0.0, 0.0], importances)
+
+ with self.assertRaisesRegexp(AssertionError,
+ 'all empty or contain only a root node'):
+ est.experimental_feature_importances(normalize=True)
+
+ def testNegativeFeatureImportances(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+
+ # In order to generate a negative feature importances,
+ # We assign an invalid value -1 to tree_weights here.
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ tree_weights: -1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ # Github #21509 (nataliaponomareva):
+ # The gains stored in the splits can be negative
+ # if people are using complexity regularization.
+ feature_names_expected = ['f_2_bucketized',
+ 'f_0_bucketized',
+ 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.0, 0.0, -5.0], importances)
+
+ with self.assertRaisesRegexp(AssertionError, 'non-negative'):
+ est.experimental_feature_importances(normalize=True)
+
+ def testFeatureImportancesNamesForCategoricalColumn(self):
+ categorical = feature_column.categorical_column_with_vocabulary_list(
+ key='categorical', vocabulary_list=('bad', 'good', 'ok'))
+ feature_indicator = feature_column.indicator_column(categorical)
+ bucketized_col = feature_column.bucketized_column(
+ feature_column.numeric_column(
+ 'continuous', dtype=dtypes.float32),
+ BUCKET_BOUNDARIES)
+ bucketized_indicator = feature_column.indicator_column(bucketized_col)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=[feature_indicator,
+ bucketized_col,
+ bucketized_indicator],
+ n_batches_per_layer=1,
+ n_trees=2,
+ learning_rate=1.0,
+ max_depth=1)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 5
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -2.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 4.34
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ feature_names_expected = ['categorical_indicator:ok',
+ 'continuous_bucketized_indicator:(-2.0, 0.5)',
+ 'continuous_bucketized_indicator:(-inf, -2.0)',
+ 'categorical_indicator:bad',
+ # Reverse order because feature importances
+ # are sorted by np.argsort(f)[::-1]
+ 'continuous_bucketized_indicator:(12.0, inf)',
+ 'continuous_bucketized_indicator:(0.5, 12.0)',
+ 'continuous_bucketized',
+ 'categorical_indicator:good']
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ # Gain sum for each features:
+ # = 1.0 * [5, 0, 2, 0, 0, 0, 0, 0] + 1.0 * [0, 2, 0, 1, 0, 0, 0, 0]
+ self.assertAllClose([5.0, 2.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.5, 0.2, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0], importances)
+
+ def testFeatureImportancesNamesForUnsupportedColumn(self):
+ numeric_col = feature_column.numeric_column(
+ 'continuous', dtype=dtypes.float32)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'only bucketized_column and indicator_column'):
+ _ = boosted_trees.BoostedTreesRegressor(
+ feature_columns=[numeric_col],
+ n_batches_per_layer=1,
+ n_trees=2,
+ learning_rate=1.0,
+ max_depth=1)
+
def testTreeComplexityIsSetCorrectly(self):
input_fn = _make_train_input_fn(is_classification=True)
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index 1c0c4581c0..a6c2aaa7d9 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -24,7 +24,10 @@ from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import optimizers
-from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.engine import training
from tensorflow.python.layers import core as core_layers
from tensorflow.python.layers import normalization
from tensorflow.python.ops import init_ops
@@ -45,8 +48,14 @@ def _add_hidden_layer_summary(value, tag):
summary.histogram('%s/activation' % tag, value)
-def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,
- dropout, input_layer_partitioner, batch_norm):
+def _dnn_logit_fn_builder(units,
+ hidden_units,
+ feature_columns,
+ activation_fn,
+ dropout,
+ input_layer_partitioner,
+ batch_norm,
+ shared_state_manager=None):
"""Function builder for a dnn logit_fn.
Args:
@@ -60,6 +69,8 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,
coordinate.
input_layer_partitioner: Partitioner for input layer.
batch_norm: Whether to use batch normalization after each hidden layer.
+ shared_state_manager: A SharedEmbeddingStateManager object to hold the
+ shared state for SharedEmbeddingColumn's.
Returns:
A logit_fn (see below).
@@ -85,50 +96,129 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,
A `Tensor` representing the logits, or a list of `Tensor`'s representing
multiple logits in the MultiHead case.
"""
- is_training = mode == model_fn.ModeKeys.TRAIN
- with variable_scope.variable_scope(
- 'input_from_feature_columns',
- values=tuple(six.itervalues(features)),
- partitioner=input_layer_partitioner):
- net = feature_column_lib.input_layer(
- features=features, feature_columns=feature_columns)
+ dnn_model = _DNNModel(
+ units,
+ hidden_units,
+ feature_columns,
+ activation_fn,
+ dropout,
+ input_layer_partitioner,
+ batch_norm,
+ shared_state_manager,
+ name='dnn')
+ return dnn_model(features, mode)
+
+ return dnn_logit_fn
+
+
+def _get_previous_name_scope():
+ current_name_scope = ops.get_name_scope()
+ return current_name_scope.rsplit('/', 1)[0] + '/'
+
+
+class _DNNModel(training.Model):
+ """A DNN Model."""
+
+ def __init__(self,
+ units,
+ hidden_units,
+ feature_columns,
+ activation_fn,
+ dropout,
+ input_layer_partitioner,
+ batch_norm,
+ shared_state_manager,
+ name=None,
+ **kwargs):
+ super(_DNNModel, self).__init__(name=name, **kwargs)
+ if feature_column_v2.is_feature_column_v2(feature_columns):
+ self._input_layer = feature_column_v2.FeatureLayer(
+ feature_columns=feature_columns,
+ name='input_layer',
+ shared_state_manager=shared_state_manager)
+ else:
+ self._input_layer = feature_column.InputLayer(
+ feature_columns=feature_columns,
+ name='input_layer',
+ create_scope_now=False)
+
+ self._add_layer(self._input_layer, 'input_layer')
+
+ self._dropout = dropout
+ self._batch_norm = batch_norm
+
+ self._hidden_layers = []
+ self._dropout_layers = []
+ self._batch_norm_layers = []
+ self._hidden_layer_scope_names = []
for layer_id, num_hidden_units in enumerate(hidden_units):
with variable_scope.variable_scope(
- 'hiddenlayer_%d' % layer_id, values=(net,)) as hidden_layer_scope:
- net = core_layers.dense(
- net,
+ 'hiddenlayer_%d' % layer_id) as hidden_layer_scope:
+ hidden_layer = core_layers.Dense(
units=num_hidden_units,
activation=activation_fn,
kernel_initializer=init_ops.glorot_uniform_initializer(),
- name=hidden_layer_scope)
- if dropout is not None and is_training:
- net = core_layers.dropout(net, rate=dropout, training=True)
- if batch_norm:
- # TODO(hjm): In future, if this becomes popular, we can enable
- # customization of the batch normalization params by accepting a
- # list of `BatchNormalization` instances as `batch_norm`.
- net = normalization.batch_normalization(
- net,
+ name=hidden_layer_scope,
+ _scope=hidden_layer_scope)
+ self._add_layer(hidden_layer, hidden_layer_scope.name)
+ self._hidden_layer_scope_names.append(hidden_layer_scope.name)
+ self._hidden_layers.append(hidden_layer)
+ if self._dropout is not None:
+ dropout_layer = core_layers.Dropout(rate=self._dropout)
+ self._add_layer(dropout_layer, dropout_layer.name)
+ self._dropout_layers.append(dropout_layer)
+ if self._batch_norm:
+ batch_norm_layer = normalization.BatchNormalization(
# The default momentum 0.99 actually crashes on certain
# problem, so here we use 0.999, which is the default of
# tf.contrib.layers.batch_norm.
momentum=0.999,
- training=is_training,
- name='batchnorm_%d' % layer_id)
- _add_hidden_layer_summary(net, hidden_layer_scope.name)
-
- with variable_scope.variable_scope('logits', values=(net,)) as logits_scope:
- logits = core_layers.dense(
- net,
+ trainable=True,
+ name='batchnorm_%d' % layer_id,
+ _scope='batchnorm_%d' % layer_id)
+ self._add_layer(batch_norm_layer, batch_norm_layer.name)
+ self._batch_norm_layers.append(batch_norm_layer)
+
+ with variable_scope.variable_scope('logits') as logits_scope:
+ self._logits_layer = core_layers.Dense(
units=units,
activation=None,
kernel_initializer=init_ops.glorot_uniform_initializer(),
- name=logits_scope)
- _add_hidden_layer_summary(logits, logits_scope.name)
-
- return logits
+ name=logits_scope,
+ _scope=logits_scope)
+ self._add_layer(self._logits_layer, logits_scope.name)
+ self._logits_scope_name = logits_scope.name
+ self._input_layer_partitioner = input_layer_partitioner
- return dnn_logit_fn
+ def call(self, features, mode):
+ is_training = mode == model_fn.ModeKeys.TRAIN
+ # The Keras training.Model adds a name_scope with the name of the model
+ # which modifies the constructed graph. Hence we add another name_scope
+ # here which is the one before the training.Model one was applied.
+ # TODO(rohanj): Remove this in TF 2.0 (b/116728605)
+ with ops.name_scope(name=_get_previous_name_scope()):
+ # TODO(rohanj): Remove dependence on variable scope for partitioning.
+ with variable_scope.variable_scope(
+ 'input_from_feature_columns',
+ partitioner=self._input_layer_partitioner):
+ net = self._input_layer(features)
+ for i in range(len(self._hidden_layers)):
+ net = self._hidden_layers[i](net)
+ if self._dropout is not None and is_training:
+ net = self._dropout_layers[i](net, training=True)
+ if self._batch_norm:
+ net = self._batch_norm_layers[i](net, training=is_training)
+ _add_hidden_layer_summary(net, self._hidden_layer_scope_names[i])
+
+ logits = self._logits_layer(net)
+ _add_hidden_layer_summary(logits, self._logits_scope_name)
+ return logits
+
+ def _add_layer(self, layer, layer_name):
+ # "Magic" required for keras.Model classes to track all the variables in
+ # a list of layers.Layer objects.
+ # TODO(ashankar): Figure out API so user code doesn't have to do this.
+ setattr(self, layer_name, layer)
def _dnn_model_fn(features,
@@ -143,7 +233,8 @@ def _dnn_model_fn(features,
input_layer_partitioner=None,
config=None,
use_tpu=False,
- batch_norm=False):
+ batch_norm=False,
+ shared_state_manager=None):
"""Deep Neural Net model_fn.
Args:
@@ -167,6 +258,8 @@ def _dnn_model_fn(features,
use_tpu: Whether to make a DNN model able to run on TPU. Will make function
return a `_TPUEstimatorSpec` instance and disable variable partitioning.
batch_norm: Whether to use batch normalization after each hidden layer.
+ shared_state_manager: A SharedEmbeddingStateManager object to hold the
+ shared state for SharedEmbeddingColumn's.
Returns:
An `EstimatorSpec` instance.
@@ -202,7 +295,8 @@ def _dnn_model_fn(features,
activation_fn=activation_fn,
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
- batch_norm=batch_norm)
+ batch_norm=batch_norm,
+ shared_state_manager=shared_state_manager)
logits = logit_fn(features=features, mode=mode)
if use_tpu:
@@ -370,6 +464,10 @@ class DNNClassifier(estimator.Estimator):
"""
head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access
n_classes, weight_column, label_vocabulary, loss_reduction)
+
+ shared_state_manager = feature_column_v2.maybe_create_shared_state_manager(
+ feature_columns)
+
def _model_fn(features, labels, mode, config):
"""Call the defined shared _dnn_model_fn."""
return _dnn_model_fn(
@@ -384,7 +482,8 @@ class DNNClassifier(estimator.Estimator):
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
config=config,
- batch_norm=batch_norm)
+ batch_norm=batch_norm,
+ shared_state_manager=shared_state_manager)
super(DNNClassifier, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config,
@@ -532,6 +631,10 @@ class DNNRegressor(estimator.Estimator):
batch_norm: Whether to use batch normalization after each hidden layer.
"""
+ shared_state_manager = None
+ if feature_column_v2.is_feature_column_v2(feature_columns):
+ shared_state_manager = feature_column_v2.SharedEmbeddingStateManager()
+
def _model_fn(features, labels, mode, config):
"""Call the defined shared _dnn_model_fn."""
return _dnn_model_fn(
@@ -539,7 +642,8 @@ class DNNRegressor(estimator.Estimator):
labels=labels,
mode=mode,
head=head_lib._regression_head( # pylint: disable=protected-access
- label_dimension=label_dimension, weight_column=weight_column,
+ label_dimension=label_dimension,
+ weight_column=weight_column,
loss_reduction=loss_reduction),
hidden_units=hidden_units,
feature_columns=tuple(feature_columns or []),
@@ -548,7 +652,8 @@ class DNNRegressor(estimator.Estimator):
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
config=config,
- batch_norm=batch_norm)
+ batch_norm=batch_norm,
+ shared_state_manager=shared_state_manager)
super(DNNRegressor, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config,
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index 9799cf9e98..f712244c8d 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -27,6 +27,7 @@ from tensorflow.python.estimator.canned import dnn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import linear
from tensorflow.python.estimator.canned import optimizers
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn
@@ -142,6 +143,9 @@ def _dnn_linear_combined_model_fn(features,
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
+ shared_state_manager = feature_column_v2.maybe_create_shared_state_manager(
+ list(linear_feature_columns) + list(dnn_feature_columns))
+
# Build DNN Logits.
dnn_parent_scope = 'dnn'
@@ -169,8 +173,9 @@ def _dnn_linear_combined_model_fn(features,
feature_columns=dnn_feature_columns,
activation_fn=dnn_activation_fn,
dropout=dnn_dropout,
+ batch_norm=batch_norm,
input_layer_partitioner=input_layer_partitioner,
- batch_norm=batch_norm)
+ shared_state_manager=shared_state_manager)
dnn_logits = dnn_logit_fn(features=features, mode=mode)
linear_parent_scope = 'linear'
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
index d16318659b..ae968e717a 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import shutil
import tempfile
+from absl.testing import parameterized
import numpy as np
import six
@@ -35,6 +36,7 @@ from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.estimator.inputs import pandas_io
from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn
@@ -119,7 +121,16 @@ class LinearOnlyRegressorPartitionerTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorPartitionerV2Test(
+ linear_testing_utils.BaseLinearRegressorPartitionerTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorEvaluationTest(
@@ -128,7 +139,16 @@ class LinearOnlyRegressorEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorEvaluationV2Test(
+ linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorPredictTest(
@@ -137,7 +157,16 @@ class LinearOnlyRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorPredictV2Test(
+ linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorIntegrationTest(
@@ -146,7 +175,16 @@ class LinearOnlyRegressorIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorIntegrationV2Test(
+ linear_testing_utils.BaseLinearRegressorIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorTrainingTest(
@@ -155,7 +193,16 @@ class LinearOnlyRegressorTrainingTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorTrainingV2Test(
+ linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
def _linear_classifier_fn(feature_columns,
@@ -185,7 +232,18 @@ class LinearOnlyClassifierTrainingTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierTrainingV2Test(
+ linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearOnlyClassifierClassesEvaluationTest(
@@ -194,7 +252,18 @@ class LinearOnlyClassifierClassesEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierClassesEvaluationV2Test(
+ linear_testing_utils.BaseLinearClassifierEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearOnlyClassifierPredictTest(
@@ -203,7 +272,18 @@ class LinearOnlyClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierPredictV2Test(
+ linear_testing_utils.BaseLinearClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearOnlyClassifierIntegrationTest(
@@ -212,9 +292,21 @@ class LinearOnlyClassifierIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierIntegrationV2Test(
+ linear_testing_utils.BaseLinearClassifierIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
def setUp(self):
@@ -225,13 +317,15 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._model_dir)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- label_dimension, batch_size):
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, label_dimension, batch_size,
+ fc_impl):
linear_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
dnn_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
feature_columns = linear_feature_columns + dnn_feature_columns
est = dnn_linear_combined.DNNLinearCombinedRegressor(
linear_feature_columns=linear_feature_columns,
@@ -257,14 +351,14 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, label_dimension), predictions.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
label_dimension = 2
batch_size = 10
@@ -293,9 +387,10 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -326,9 +421,10 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
label_dimension = 2
batch_size = 10
@@ -376,7 +472,8 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
# A function to mimic dnn-classifier init reuse same tests.
@@ -407,7 +504,16 @@ class DNNOnlyClassifierEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNOnlyClassifierEvaluateV2Test(
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNOnlyClassifierPredictTest(
@@ -416,7 +522,16 @@ class DNNOnlyClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNOnlyClassifierPredictV2Test(
+ dnn_testing_utils.BaseDNNClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNOnlyClassifierTrainTest(
@@ -425,7 +540,16 @@ class DNNOnlyClassifierTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNOnlyClassifierTrainV2Test(dnn_testing_utils.BaseDNNClassifierTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
# A function to mimic dnn-regressor init reuse same tests.
@@ -454,7 +578,16 @@ class DNNOnlyRegressorEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNOnlyRegressorEvaluateV2Test(
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNOnlyRegressorPredictTest(
@@ -463,7 +596,16 @@ class DNNOnlyRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNOnlyRegressorPredictV2Test(
+ dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNOnlyRegressorTrainTest(
@@ -472,9 +614,19 @@ class DNNOnlyRegressorTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+class DNNOnlyRegressorTrainV2Test(dnn_testing_utils.BaseDNNRegressorTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
+
+
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
def setUp(self):
@@ -488,13 +640,14 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
def _as_label(self, data_in_float):
return np.rint(data_in_float).astype(np.int64)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- n_classes, batch_size):
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, n_classes, batch_size, fc_impl):
linear_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
dnn_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
feature_columns = linear_feature_columns + dnn_feature_columns
est = dnn_linear_combined.DNNLinearCombinedClassifier(
linear_feature_columns=linear_feature_columns,
@@ -520,14 +673,14 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
n_classes = 3
input_dimension = 2
@@ -559,9 +712,10 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -593,9 +747,10 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
input_dimension = 2
n_classes = 3
@@ -647,9 +802,11 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedTests(test.TestCase):
def setUp(self):
@@ -681,9 +838,9 @@ class DNNLinearCombinedTests(test.TestCase):
return optimizer_mock
- def test_train_op_calls_both_dnn_and_linear(self):
+ def test_train_op_calls_both_dnn_and_linear(self, fc_impl):
opt = gradient_descent.GradientDescentOptimizer(1.)
- x_column = feature_column.numeric_column('x')
+ x_column = fc_impl.numeric_column('x')
input_fn = numpy_io.numpy_input_fn(
x={'x': np.array([[0.], [1.]])},
y=np.array([[0.], [1.]]),
@@ -708,7 +865,7 @@ class DNNLinearCombinedTests(test.TestCase):
checkpoint_utils.load_variable(
self._model_dir, 'dnn_called'))
- def test_dnn_and_linear_logits_are_added(self):
+ def test_dnn_and_linear_logits_are_added(self, fc_impl):
with ops.Graph().as_default():
variables_lib.Variable([[1.0]], name='linear/linear_model/x/weights')
variables_lib.Variable([2.0], name='linear/linear_model/bias_weights')
@@ -719,7 +876,7 @@ class DNNLinearCombinedTests(test.TestCase):
variables_lib.Variable(1, name='global_step', dtype=dtypes.int64)
linear_testing_utils.save_variables_to_ckpt(self._model_dir)
- x_column = feature_column.numeric_column('x')
+ x_column = fc_impl.numeric_column('x')
est = dnn_linear_combined.DNNLinearCombinedRegressor(
linear_feature_columns=[x_column],
dnn_hidden_units=[1],
@@ -737,6 +894,7 @@ class DNNLinearCombinedTests(test.TestCase):
next(est.predict(input_fn=input_fn)))
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedWarmStartingTest(test.TestCase):
def setUp(self):
@@ -758,11 +916,11 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._ckpt_and_vocab_dir)
- def test_classifier_basic_warm_starting(self):
+ def test_classifier_basic_warm_starting(self, fc_impl):
"""Tests correctness of DNNLinearCombinedClassifier default warm-start."""
- age = feature_column.numeric_column('age')
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ age = fc_impl.numeric_column('age')
+ city = fc_impl.embedding_column(
+ fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -798,11 +956,11 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase):
dnn_lc_classifier.get_variable_value(variable_name),
warm_started_dnn_lc_classifier.get_variable_value(variable_name))
- def test_regressor_basic_warm_starting(self):
+ def test_regressor_basic_warm_starting(self, fc_impl):
"""Tests correctness of DNNLinearCombinedRegressor default warm-start."""
- age = feature_column.numeric_column('age')
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ age = fc_impl.numeric_column('age')
+ city = fc_impl.embedding_column(
+ fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -836,11 +994,11 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase):
dnn_lc_regressor.get_variable_value(variable_name),
warm_started_dnn_lc_regressor.get_variable_value(variable_name))
- def test_warm_starting_selective_variables(self):
+ def test_warm_starting_selective_variables(self, fc_impl):
"""Tests selecting variables to warm-start."""
- age = feature_column.numeric_column('age')
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ age = fc_impl.numeric_column('age')
+ city = fc_impl.embedding_column(
+ fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
diff --git a/tensorflow/python/estimator/canned/dnn_test.py b/tensorflow/python/estimator/canned/dnn_test.py
index fc90b7c35e..756696cea0 100644
--- a/tensorflow/python/estimator/canned/dnn_test.py
+++ b/tensorflow/python/estimator/canned/dnn_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import shutil
import tempfile
+from absl.testing import parameterized
import numpy as np
import six
@@ -33,6 +34,7 @@ from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.estimator.inputs import pandas_io
from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import data_flow_ops
@@ -62,15 +64,32 @@ class DNNModelFnTest(dnn_testing_utils.BaseDNNModelFnTest, test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- dnn_testing_utils.BaseDNNModelFnTest.__init__(self, dnn._dnn_model_fn)
+ dnn_testing_utils.BaseDNNModelFnTest.__init__(
+ self, dnn._dnn_model_fn, fc_impl=feature_column)
+
+
+class DNNModelFnV2Test(dnn_testing_utils.BaseDNNModelFnTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNModelFnTest.__init__(
+ self, dnn._dnn_model_fn, fc_impl=feature_column_v2)
class DNNLogitFnTest(dnn_testing_utils.BaseDNNLogitFnTest, test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- dnn_testing_utils.BaseDNNLogitFnTest.__init__(self,
- dnn._dnn_logit_fn_builder)
+ dnn_testing_utils.BaseDNNLogitFnTest.__init__(
+ self, dnn._dnn_logit_fn_builder, fc_impl=feature_column)
+
+
+class DNNLogitFnV2Test(dnn_testing_utils.BaseDNNLogitFnTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNLogitFnTest.__init__(
+ self, dnn._dnn_logit_fn_builder, fc_impl=feature_column_v2)
class DNNWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,
@@ -78,8 +97,17 @@ class DNNWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- dnn_testing_utils.BaseDNNWarmStartingTest.__init__(self, _dnn_classifier_fn,
- _dnn_regressor_fn)
+ dnn_testing_utils.BaseDNNWarmStartingTest.__init__(
+ self, _dnn_classifier_fn, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNWarmStartingV2Test(dnn_testing_utils.BaseDNNWarmStartingTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNWarmStartingTest.__init__(
+ self, _dnn_classifier_fn, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNClassifierEvaluateTest(
@@ -88,7 +116,16 @@ class DNNClassifierEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNClassifierEvaluateV2Test(
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNClassifierPredictTest(
@@ -97,7 +134,16 @@ class DNNClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNClassifierPredictV2Test(dnn_testing_utils.BaseDNNClassifierPredictTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNClassifierTrainTest(
@@ -106,7 +152,16 @@ class DNNClassifierTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNClassifierTrainV2Test(dnn_testing_utils.BaseDNNClassifierTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
def _dnn_regressor_fn(*args, **kwargs):
@@ -119,7 +174,16 @@ class DNNRegressorEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNRegressorEvaluateV2Test(dnn_testing_utils.BaseDNNRegressorEvaluateTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNRegressorPredictTest(
@@ -128,7 +192,16 @@ class DNNRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNRegressorPredictV2Test(dnn_testing_utils.BaseDNNRegressorPredictTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNRegressorTrainTest(
@@ -137,7 +210,16 @@ class DNNRegressorTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNRegressorTrainV2Test(dnn_testing_utils.BaseDNNRegressorTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
def _queue_parsed_features(feature_map):
@@ -156,7 +238,8 @@ def _queue_parsed_features(feature_map):
return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}
-class DNNRegressorIntegrationTest(test.TestCase):
+@parameterized.parameters((feature_column,), (feature_column_v2,))
+class DNNRegressorIntegrationTest(test.TestCase, parameterized.TestCase):
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -166,11 +249,11 @@ class DNNRegressorIntegrationTest(test.TestCase):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._model_dir)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- label_dimension, batch_size):
- feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, label_dimension, batch_size,
+ fc_impl):
+ feature_columns = [fc_impl.numeric_column('x', shape=(input_dimension,))]
+
est = dnn.DNNRegressor(
hidden_units=(2, 2),
feature_columns=feature_columns,
@@ -194,14 +277,14 @@ class DNNRegressorIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, label_dimension), predictions.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
label_dimension = 2
batch_size = 10
@@ -230,9 +313,10 @@ class DNNRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -263,9 +347,10 @@ class DNNRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
label_dimension = 2
batch_size = 10
@@ -313,9 +398,11 @@ class DNNRegressorIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNClassifierIntegrationTest(test.TestCase):
def setUp(self):
@@ -329,11 +416,10 @@ class DNNClassifierIntegrationTest(test.TestCase):
def _as_label(self, data_in_float):
return np.rint(data_in_float).astype(np.int64)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- n_classes, batch_size):
- feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, n_classes, batch_size, fc_impl):
+ feature_columns = [fc_impl.numeric_column('x', shape=(input_dimension,))]
+
est = dnn.DNNClassifier(
hidden_units=(2, 2),
feature_columns=feature_columns,
@@ -357,14 +443,14 @@ class DNNClassifierIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
n_classes = 3
input_dimension = 2
@@ -396,9 +482,10 @@ class DNNClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -430,9 +517,10 @@ class DNNClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
input_dimension = 2
n_classes = 3
@@ -484,7 +572,8 @@ class DNNClassifierIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
if __name__ == '__main__':
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index 11f1e93630..cd66d0a3bd 100644
--- a/tensorflow/python/estimator/canned/dnn_testing_utils.py
+++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py
@@ -104,6 +104,7 @@ def create_checkpoint(weights_and_biases,
weights_and_biases: Iterable of tuples of weight and bias values.
global_step: Initial global step to save in checkpoint.
model_dir: Directory into which checkpoint is saved.
+ batch_norm_vars: Variables used for batch normalization.
"""
weights, biases = zip(*weights_and_biases)
if batch_norm_vars:
@@ -244,8 +245,9 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None):
class BaseDNNModelFnTest(object):
"""Tests that _dnn_model_fn passes expected logits to mock head."""
- def __init__(self, dnn_model_fn):
+ def __init__(self, dnn_model_fn, fc_impl=feature_column):
self._dnn_model_fn = dnn_model_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -272,7 +274,7 @@ class BaseDNNModelFnTest(object):
head=head,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column(
+ self._fc_impl.numeric_column(
'age', shape=np.array(inputs).shape[1:])
],
optimizer=mock_optimizer(self, hidden_units))
@@ -462,8 +464,8 @@ class BaseDNNModelFnTest(object):
head=head,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column('age'),
- feature_column.numeric_column('height')
+ self._fc_impl.numeric_column('age'),
+ self._fc_impl.numeric_column('height')
],
optimizer=mock_optimizer(self, hidden_units))
with monitored_session.MonitoredTrainingSession(
@@ -499,7 +501,7 @@ class BaseDNNModelFnTest(object):
head=head,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column(
+ self._fc_impl.numeric_column(
'age', shape=np.array(inputs).shape[1:])
],
optimizer=mock_optimizer(self, hidden_units))
@@ -508,8 +510,9 @@ class BaseDNNModelFnTest(object):
class BaseDNNLogitFnTest(object):
"""Tests correctness of logits calculated from _dnn_logit_fn_builder."""
- def __init__(self, dnn_logit_fn_builder):
+ def __init__(self, dnn_logit_fn_builder, fc_impl=feature_column):
self._dnn_logit_fn_builder = dnn_logit_fn_builder
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -541,7 +544,7 @@ class BaseDNNLogitFnTest(object):
units=logits_dimension,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column(
+ self._fc_impl.numeric_column(
'age', shape=np.array(inputs).shape[1:])
],
activation_fn=nn.relu,
@@ -786,8 +789,8 @@ class BaseDNNLogitFnTest(object):
units=logits_dimension,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column('age'),
- feature_column.numeric_column('height')
+ self._fc_impl.numeric_column('age'),
+ self._fc_impl.numeric_column('height')
],
activation_fn=nn.relu,
dropout=None,
@@ -806,9 +809,13 @@ class BaseDNNLogitFnTest(object):
class BaseDNNWarmStartingTest(object):
- def __init__(self, _dnn_classifier_fn, _dnn_regressor_fn):
+ def __init__(self,
+ _dnn_classifier_fn,
+ _dnn_regressor_fn,
+ fc_impl=feature_column):
self._dnn_classifier_fn = _dnn_classifier_fn
self._dnn_regressor_fn = _dnn_regressor_fn
+ self._fc_impl = fc_impl
def setUp(self):
# Create a directory to save our old checkpoint and vocabularies to.
@@ -843,8 +850,8 @@ class BaseDNNWarmStartingTest(object):
def test_classifier_basic_warm_starting(self):
"""Tests correctness of DNNClassifier default warm-start."""
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ city = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -875,8 +882,8 @@ class BaseDNNWarmStartingTest(object):
def test_regressor_basic_warm_starting(self):
"""Tests correctness of DNNRegressor default warm-start."""
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ city = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -905,8 +912,8 @@ class BaseDNNWarmStartingTest(object):
def test_warm_starting_selective_variables(self):
"""Tests selecting variables to warm-start."""
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ city = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -958,8 +965,8 @@ class BaseDNNWarmStartingTest(object):
vocab_file = os.path.join(self._ckpt_and_vocab_dir, 'occupation_vocab')
with open(vocab_file, 'w') as f:
f.write('\n'.join(vocab_list))
- occupation = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_file(
+ occupation = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=vocab_file,
vocabulary_size=len(vocab_list)),
@@ -985,8 +992,8 @@ class BaseDNNWarmStartingTest(object):
'new_occupation_vocab')
with open(new_vocab_file, 'w') as f:
f.write('\n'.join(new_vocab_list))
- new_occupation = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_file(
+ new_occupation = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=new_vocab_file,
vocabulary_size=len(new_vocab_list)),
@@ -1051,8 +1058,8 @@ class BaseDNNWarmStartingTest(object):
def test_warm_starting_with_naming_change(self):
"""Tests warm-starting with a Tensor name remapping."""
- locality = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ locality = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'locality', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -1068,8 +1075,8 @@ class BaseDNNWarmStartingTest(object):
# Create a second DNNClassifier, warm-started from the first. Use a
# learning_rate = 0.0 optimizer to check values (use SGD so we don't have
# accumulator values that change).
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ city = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
warm_started_dnn_classifier = self._dnn_classifier_fn(
@@ -1101,8 +1108,9 @@ class BaseDNNWarmStartingTest(object):
class BaseDNNClassifierEvaluateTest(object):
- def __init__(self, dnn_classifier_fn):
+ def __init__(self, dnn_classifier_fn, fc_impl=feature_column):
self._dnn_classifier_fn = dnn_classifier_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1121,7 +1129,7 @@ class BaseDNNClassifierEvaluateTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age')],
+ feature_columns=[self._fc_impl.numeric_column('age')],
model_dir=self._model_dir)
def _input_fn():
# batch_size = 2, one false label, and one true.
@@ -1161,7 +1169,7 @@ class BaseDNNClassifierEvaluateTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age', shape=[2])],
+ feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],
n_classes=n_classes,
model_dir=self._model_dir)
def _input_fn():
@@ -1192,7 +1200,7 @@ class BaseDNNClassifierEvaluateTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age')],
+ feature_columns=[self._fc_impl.numeric_column('age')],
model_dir=self._model_dir)
def _input_fn():
# batch_size = 2, one false label, and one true.
@@ -1218,7 +1226,7 @@ class BaseDNNClassifierEvaluateTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age', shape=[2])],
+ feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],
n_classes=n_classes,
weight_column='w',
model_dir=self._model_dir)
@@ -1238,8 +1246,9 @@ class BaseDNNClassifierEvaluateTest(object):
class BaseDNNRegressorEvaluateTest(object):
- def __init__(self, dnn_regressor_fn):
+ def __init__(self, dnn_regressor_fn, fc_impl=feature_column):
self._dnn_regressor_fn = dnn_regressor_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1259,7 +1268,7 @@ class BaseDNNRegressorEvaluateTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age')],
+ feature_columns=[self._fc_impl.numeric_column('age')],
model_dir=self._model_dir)
def _input_fn():
return {'age': [[10.]]}, [[1.]]
@@ -1289,7 +1298,7 @@ class BaseDNNRegressorEvaluateTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age', shape=[2])],
+ feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],
label_dimension=label_dimension,
model_dir=self._model_dir)
def _input_fn():
@@ -1320,7 +1329,7 @@ class BaseDNNRegressorEvaluateTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age', shape=[2])],
+ feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],
label_dimension=label_dimension,
weight_column='w',
model_dir=self._model_dir)
@@ -1339,8 +1348,9 @@ class BaseDNNRegressorEvaluateTest(object):
class BaseDNNClassifierPredictTest(object):
- def __init__(self, dnn_classifier_fn):
+ def __init__(self, dnn_classifier_fn, fc_impl=feature_column):
self._dnn_classifier_fn = dnn_classifier_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1361,7 +1371,7 @@ class BaseDNNClassifierPredictTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
label_vocabulary=label_vocabulary,
- feature_columns=(feature_column.numeric_column('x'),),
+ feature_columns=(self._fc_impl.numeric_column('x'),),
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)
@@ -1405,7 +1415,7 @@ class BaseDNNClassifierPredictTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('x', shape=(2,)),),
+ feature_columns=(self._fc_impl.numeric_column('x', shape=(2,)),),
label_vocabulary=label_vocabulary,
n_classes=3,
model_dir=self._model_dir)
@@ -1453,8 +1463,9 @@ class BaseDNNClassifierPredictTest(object):
class BaseDNNRegressorPredictTest(object):
- def __init__(self, dnn_regressor_fn):
+ def __init__(self, dnn_regressor_fn, fc_impl=feature_column):
self._dnn_regressor_fn = dnn_regressor_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1475,7 +1486,7 @@ class BaseDNNRegressorPredictTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('x'),),
+ feature_columns=(self._fc_impl.numeric_column('x'),),
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)
@@ -1497,7 +1508,7 @@ class BaseDNNRegressorPredictTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('x', shape=(2,)),),
+ feature_columns=(self._fc_impl.numeric_column('x', shape=(2,)),),
label_dimension=3,
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
@@ -1594,8 +1605,9 @@ def _assert_simple_summary(testcase, expected_values, actual_summary):
class BaseDNNClassifierTrainTest(object):
- def __init__(self, dnn_classifier_fn):
+ def __init__(self, dnn_classifier_fn, fc_impl=feature_column):
self._dnn_classifier_fn = dnn_classifier_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1609,7 +1621,7 @@ class BaseDNNClassifierTrainTest(object):
hidden_units = (2, 2)
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
model_dir=self._model_dir)
# Train for a few steps, then validate final checkpoint.
@@ -1625,7 +1637,7 @@ class BaseDNNClassifierTrainTest(object):
n_classes = 3
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1643,7 +1655,7 @@ class BaseDNNClassifierTrainTest(object):
self, hidden_units=hidden_units)
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1682,7 +1694,7 @@ class BaseDNNClassifierTrainTest(object):
self, hidden_units=hidden_units, expected_loss=expected_loss)
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1728,7 +1740,7 @@ class BaseDNNClassifierTrainTest(object):
self, hidden_units=hidden_units, expected_loss=expected_loss)
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1759,7 +1771,7 @@ class BaseDNNClassifierTrainTest(object):
dnn_classifier = self._dnn_classifier_fn(
n_classes=n_classes,
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1793,8 +1805,9 @@ class BaseDNNClassifierTrainTest(object):
class BaseDNNRegressorTrainTest(object):
- def __init__(self, dnn_regressor_fn):
+ def __init__(self, dnn_regressor_fn, fc_impl=feature_column):
self._dnn_regressor_fn = dnn_regressor_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1808,7 +1821,7 @@ class BaseDNNRegressorTrainTest(object):
hidden_units = (2, 2)
dnn_regressor = self._dnn_regressor_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
model_dir=self._model_dir)
# Train for a few steps, then validate final checkpoint.
@@ -1824,7 +1837,7 @@ class BaseDNNRegressorTrainTest(object):
opt = mock_optimizer(self, hidden_units=hidden_units)
dnn_regressor = self._dnn_regressor_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1864,7 +1877,7 @@ class BaseDNNRegressorTrainTest(object):
self, hidden_units=hidden_units, expected_loss=expected_loss)
dnn_regressor = self._dnn_regressor_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1917,7 +1930,8 @@ class BaseDNNRegressorTrainTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column('age', shape=[input_dimension])],
+ self._fc_impl.numeric_column('age', shape=[input_dimension])
+ ],
label_dimension=label_dimension,
optimizer=opt,
model_dir=self._model_dir)
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index 115dd18518..8b96284bd3 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -25,14 +25,18 @@ import six
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import optimizers
-from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variable_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import ftrl
+from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import estimator_export
@@ -46,23 +50,42 @@ def _get_default_optimizer(feature_columns):
return ftrl.FtrlOptimizer(learning_rate=learning_rate)
-def _compute_fraction_of_zero(cols_to_vars):
- """Given a linear cols_to_vars dict, compute the fraction of zero weights.
+def _get_expanded_variable_list(var_list):
+ """Given a list of variables, expands them if they are partitioned.
Args:
- cols_to_vars: A dictionary mapping FeatureColumns to lists of tf.Variables
- like one returned from feature_column_lib.linear_model.
+ var_list: A list of variables.
+
+ Returns:
+ A list of variables where each partitioned variable is expanded to its
+ components.
+ """
+ returned_list = []
+ for variable in var_list:
+ if (isinstance(variable, variable_ops.Variable) or
+ resource_variable_ops.is_resource_variable(variable)):
+ returned_list.append(variable) # Single variable case.
+ else: # Must be a PartitionedVariable, so convert into a list.
+ returned_list.extend(list(variable))
+ return returned_list
+
+
+# TODO(rohanj): Consider making this a public utility method.
+def _compute_fraction_of_zero(variables):
+ """Given a linear variables list, compute the fraction of zero weights.
+
+ Args:
+ variables: A list or list of list of variables
Returns:
The fraction of zeros (sparsity) in the linear model.
"""
all_weight_vars = []
- for var_or_var_list in cols_to_vars.values():
+ for var_or_var_list in variables:
+ var_list = nest.flatten(var_or_var_list)
# Skip empty-lists associated with columns that created no Variables.
- if var_or_var_list:
- all_weight_vars += [
- array_ops.reshape(var, [-1]) for var in var_or_var_list
- ]
+ if var_list:
+ all_weight_vars += [array_ops.reshape(var, [-1]) for var in var_list]
return nn.zero_fraction(array_ops.concat(all_weight_vars, axis=0))
@@ -92,14 +115,36 @@ def _linear_logit_fn_builder(units, feature_columns, sparse_combiner='sum'):
Returns:
A `Tensor` representing the logits.
"""
- cols_to_vars = {}
- logits = feature_column_lib.linear_model(
- features=features,
- feature_columns=feature_columns,
- units=units,
- sparse_combiner=sparse_combiner,
- cols_to_vars=cols_to_vars)
- bias = cols_to_vars.pop('bias')
+ if feature_column_v2.is_feature_column_v2(feature_columns):
+ shared_state_manager = feature_column_v2.SharedEmbeddingStateManager()
+ linear_model = feature_column_v2.LinearModel(
+ feature_columns=feature_columns,
+ units=units,
+ sparse_combiner=sparse_combiner,
+ shared_state_manager=shared_state_manager)
+ logits = linear_model(features)
+ bias = linear_model.bias_variable
+
+ # We'd like to get all the non-bias variables associated with this
+ # LinearModel. This includes the shared embedding variables as well.
+ variables = linear_model.variables
+ variables.remove(bias)
+ variables.extend(shared_state_manager.variables)
+
+ # Expand (potential) Partitioned variables
+ bias = _get_expanded_variable_list([bias])
+ variables = _get_expanded_variable_list(variables)
+ else:
+ linear_model = feature_column._LinearModel( # pylint: disable=protected-access
+ feature_columns=feature_columns,
+ units=units,
+ sparse_combiner=sparse_combiner,
+ name='linear_model')
+ logits = linear_model(features)
+ cols_to_vars = linear_model.cols_to_vars()
+ bias = cols_to_vars.pop('bias')
+ variables = cols_to_vars.values()
+
if units > 1:
summary.histogram('bias', bias)
else:
@@ -107,7 +152,7 @@ def _linear_logit_fn_builder(units, feature_columns, sparse_combiner='sum'):
# so we should provide a scalar summary.
summary.scalar('bias', bias[0][0])
summary.scalar('fraction_of_zero_weights',
- _compute_fraction_of_zero(cols_to_vars))
+ _compute_fraction_of_zero(variables))
return logits
return linear_logit_fn
diff --git a/tensorflow/python/estimator/canned/linear_test.py b/tensorflow/python/estimator/canned/linear_test.py
index 59a230417d..3e6da5de22 100644
--- a/tensorflow/python/estimator/canned/linear_test.py
+++ b/tensorflow/python/estimator/canned/linear_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
from tensorflow.python.estimator.canned import linear
from tensorflow.python.estimator.canned import linear_testing_utils
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.platform import test
@@ -40,7 +42,16 @@ class LinearRegressorPartitionerTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorPartitionerV2Test(
+ linear_testing_utils.BaseLinearRegressorPartitionerTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorEvaluationTest(
@@ -49,7 +60,16 @@ class LinearRegressorEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorEvaluationV2Test(
+ linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorPredictTest(
@@ -58,7 +78,16 @@ class LinearRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorPredictV2Test(
+ linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorIntegrationTest(
@@ -67,7 +96,16 @@ class LinearRegressorIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorIntegrationV2Test(
+ linear_testing_utils.BaseLinearRegressorIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorTrainingTest(
@@ -76,19 +114,37 @@ class LinearRegressorTrainingTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
-# Tests for Linear Classifier.
+class LinearRegressorTrainingV2Test(
+ linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
+
+# Tests for Linear Classifier.
class LinearClassifierTrainingTest(
linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierTrainingV2Test(
+ linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearClassifierEvaluationTest(
@@ -97,7 +153,18 @@ class LinearClassifierEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierEvaluationV2Test(
+ linear_testing_utils.BaseLinearClassifierEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearClassifierPredictTest(
@@ -106,7 +173,18 @@ class LinearClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierPredictV2Test(
+ linear_testing_utils.BaseLinearClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearClassifierIntegrationTest(
@@ -115,7 +193,18 @@ class LinearClassifierIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierIntegrationV2Test(
+ linear_testing_utils.BaseLinearClassifierIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
# Tests for Linear logit_fn.
@@ -124,7 +213,17 @@ class LinearLogitFnTest(linear_testing_utils.BaseLinearLogitFnTest,
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- linear_testing_utils.BaseLinearLogitFnTest.__init__(self)
+ linear_testing_utils.BaseLinearLogitFnTest.__init__(
+ self, fc_lib=feature_column)
+
+
+class LinearLogitFnV2Test(linear_testing_utils.BaseLinearLogitFnTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearLogitFnTest.__init__(
+ self, fc_lib=feature_column_v2)
# Tests for warm-starting with Linear logit_fn.
@@ -134,7 +233,22 @@ class LinearWarmStartingTest(linear_testing_utils.BaseLinearWarmStartingTest,
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearWarmStartingTest.__init__(
- self, _linear_classifier_fn, _linear_regressor_fn)
+ self,
+ _linear_classifier_fn,
+ _linear_regressor_fn,
+ fc_lib=feature_column)
+
+
+class LinearWarmStartingV2Test(linear_testing_utils.BaseLinearWarmStartingTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearWarmStartingTest.__init__(
+ self,
+ _linear_classifier_fn,
+ _linear_regressor_fn,
+ fc_lib=feature_column_v2)
if __name__ == '__main__':
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index 65cdd50061..827352a70b 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -37,7 +37,8 @@ from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.estimator.inputs import pandas_io
-from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -152,8 +153,9 @@ class CheckPartitionerVarHook(session_run_hook.SessionRunHook):
class BaseLinearRegressorPartitionerTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -173,7 +175,7 @@ class BaseLinearRegressorPartitionerTest(object):
return [partitions, 1] if shape[0] == x_dim else [1]
regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.categorical_column_with_hash_bucket(
+ feature_columns=(self._fc_lib.categorical_column_with_hash_bucket(
'language', hash_bucket_size=x_dim),),
partitioner=_partitioner,
model_dir=self._model_dir)
@@ -209,9 +211,8 @@ class BaseLinearRegressorPartitionerTest(object):
'_get_replica_device_setter',
return_value=lambda _: '/cpu:0'):
linear_regressor = self._linear_regressor_fn(
- feature_columns=(
- feature_column_lib.categorical_column_with_hash_bucket(
- 'language', hash_bucket_size=x_dim),),
+ feature_columns=(self._fc_lib.categorical_column_with_hash_bucket(
+ 'language', hash_bucket_size=x_dim),),
config=FakeRunConfig(),
model_dir=self._model_dir)
@@ -232,8 +233,9 @@ class BaseLinearRegressorPartitionerTest(object):
# TODO(b/36813849): Add tests with dynamic shape inputs using placeholders.
class BaseLinearRegressorEvaluationTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -252,7 +254,7 @@ class BaseLinearRegressorEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir)
eval_metrics = linear_regressor.evaluate(
input_fn=lambda: ({'age': ((1,),)}, ((10.,),)), steps=1)
@@ -276,7 +278,7 @@ class BaseLinearRegressorEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir)
eval_metrics = linear_regressor.evaluate(
input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1)
@@ -308,7 +310,7 @@ class BaseLinearRegressorEvaluationTest(object):
return features, labels
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
weight_column='weights',
model_dir=self._model_dir)
eval_metrics = linear_regressor.evaluate(input_fn=_input_fn, steps=1)
@@ -336,8 +338,7 @@ class BaseLinearRegressorEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column(
- 'age', shape=(x_dim,)),),
+ feature_columns=(self._fc_lib.numeric_column('age', shape=(x_dim,)),),
label_dimension=label_dim,
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
@@ -374,8 +375,8 @@ class BaseLinearRegressorEvaluationTest(object):
batch_size = 2
feature_columns = [
- feature_column_lib.numeric_column('age'),
- feature_column_lib.numeric_column('height')
+ self._fc_lib.numeric_column('age'),
+ self._fc_lib.numeric_column('height')
]
input_fn = numpy_io.numpy_input_fn(
x={'age': np.array([20, 40]),
@@ -402,8 +403,9 @@ class BaseLinearRegressorEvaluationTest(object):
class BaseLinearRegressorPredictTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -422,7 +424,7 @@ class BaseLinearRegressorPredictTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('x'),),
+ feature_columns=(self._fc_lib.numeric_column('x'),),
model_dir=self._model_dir)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -441,7 +443,7 @@ class BaseLinearRegressorPredictTest(object):
batch_size = 2
label_dimension = 3
x_dim = 4
- feature_columns = (feature_column_lib.numeric_column('x', shape=(x_dim,)),)
+ feature_columns = (self._fc_lib.numeric_column('x', shape=(x_dim,)),)
with ops.Graph().as_default():
variables_lib.Variable( # shape=[x_dim, label_dimension]
[[1., 2., 3.], [2., 3., 4.], [3., 4., 5.], [4., 5., 6.]],
@@ -479,8 +481,8 @@ class BaseLinearRegressorPredictTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('x0'),
- feature_column_lib.numeric_column('x1')),
+ feature_columns=(self._fc_lib.numeric_column('x0'),
+ self._fc_lib.numeric_column('x1')),
model_dir=self._model_dir)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -515,9 +517,8 @@ class BaseLinearRegressorPredictTest(object):
dense_shape=[2, 2]),
})
- feature_columns = (
- feature_column_lib.categorical_column_with_vocabulary_list(
- 'language', vocabulary_list=['a', 'b', 'c']),)
+ feature_columns = (self._fc_lib.categorical_column_with_vocabulary_list(
+ 'language', vocabulary_list=['a', 'b', 'c']),)
# Check prediction for each sparse_combiner.
# With sparse_combiner = 'sum', we have
@@ -561,8 +562,9 @@ class BaseLinearRegressorPredictTest(object):
class BaseLinearRegressorIntegrationTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -575,7 +577,7 @@ class BaseLinearRegressorIntegrationTest(object):
def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
input_dimension, label_dimension, prediction_length):
feature_columns = [
- feature_column_lib.numeric_column('x', shape=(input_dimension,))
+ self._fc_lib.numeric_column('x', shape=(input_dimension,))
]
est = self._linear_regressor_fn(
feature_columns=feature_columns,
@@ -597,7 +599,7 @@ class BaseLinearRegressorIntegrationTest(object):
self.assertAllEqual((prediction_length, label_dimension), predictions.shape)
# EXPORT
- feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ feature_spec = self._fc_lib.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
@@ -729,8 +731,9 @@ class BaseLinearRegressorIntegrationTest(object):
class BaseLinearRegressorTrainingTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -808,7 +811,7 @@ class BaseLinearRegressorTrainingTest(object):
label = 5.
age = 17
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir)
# Train for a few steps, and validate final checkpoint.
@@ -820,7 +823,7 @@ class BaseLinearRegressorTrainingTest(object):
def testTrainWithOneDimLabel(self):
label_dimension = 1
batch_size = 20
- feature_columns = [feature_column_lib.numeric_column('age', shape=(1,))]
+ feature_columns = [self._fc_lib.numeric_column('age', shape=(1,))]
est = self._linear_regressor_fn(
feature_columns=feature_columns,
label_dimension=label_dimension,
@@ -840,7 +843,7 @@ class BaseLinearRegressorTrainingTest(object):
def testTrainWithOneDimWeight(self):
label_dimension = 1
batch_size = 20
- feature_columns = [feature_column_lib.numeric_column('age', shape=(1,))]
+ feature_columns = [self._fc_lib.numeric_column('age', shape=(1,))]
est = self._linear_regressor_fn(
feature_columns=feature_columns,
label_dimension=label_dimension,
@@ -867,7 +870,7 @@ class BaseLinearRegressorTrainingTest(object):
# loss = (logits - label)^2 = (0 - 5.)^2 = 25.
mock_optimizer = self._mock_optimizer(expected_loss=25.)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir,
optimizer=mock_optimizer)
self.assertEqual(0, mock_optimizer.minimize.call_count)
@@ -900,7 +903,7 @@ class BaseLinearRegressorTrainingTest(object):
# loss = (logits - label)^2 = (175 - 5)^2 = 28900
mock_optimizer = self._mock_optimizer(expected_loss=28900.)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir,
optimizer=mock_optimizer)
self.assertEqual(0, mock_optimizer.minimize.call_count)
@@ -935,7 +938,7 @@ class BaseLinearRegressorTrainingTest(object):
# loss = sum(logits - label)^2 = (175 - 5)^2 + (155 - 3)^2 = 52004
mock_optimizer = self._mock_optimizer(expected_loss=52004.)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir,
optimizer=mock_optimizer)
self.assertEqual(0, mock_optimizer.minimize.call_count)
@@ -954,8 +957,9 @@ class BaseLinearRegressorTrainingTest(object):
class BaseLinearClassifierTrainingTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1031,7 +1035,7 @@ class BaseLinearClassifierTrainingTest(object):
label = 0
age = 17
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1051,7 +1055,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
data_rank_1 = np.array([0, 1])
@@ -1078,7 +1082,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
data_rank_1 = np.array([0, 1])
@@ -1103,7 +1107,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
weight_column='w',
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1129,7 +1133,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
weight_column='w',
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1166,7 +1170,7 @@ class BaseLinearClassifierTrainingTest(object):
expected_loss=-1 * math.log(1.0/n_classes))
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1229,7 +1233,7 @@ class BaseLinearClassifierTrainingTest(object):
mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1277,7 +1281,7 @@ class BaseLinearClassifierTrainingTest(object):
mock_optimizer = self._mock_optimizer(expected_loss=1.1132617)
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1341,7 +1345,7 @@ class BaseLinearClassifierTrainingTest(object):
mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1368,8 +1372,9 @@ class BaseLinearClassifierTrainingTest(object):
class BaseLinearClassifierEvaluationTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1398,7 +1403,7 @@ class BaseLinearClassifierEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
eval_metrics = est.evaluate(
@@ -1464,7 +1469,7 @@ class BaseLinearClassifierEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
eval_metrics = est.evaluate(
@@ -1540,7 +1545,7 @@ class BaseLinearClassifierEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
weight_column='w',
model_dir=self._model_dir)
@@ -1605,8 +1610,9 @@ class BaseLinearClassifierEvaluationTest(object):
class BaseLinearClassifierPredictTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1634,7 +1640,7 @@ class BaseLinearClassifierPredictTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
label_vocabulary=label_vocabulary,
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1730,9 +1736,8 @@ class BaseLinearClassifierPredictTest(object):
dense_shape=[2, 2]),
})
- feature_columns = (
- feature_column_lib.categorical_column_with_vocabulary_list(
- 'language', vocabulary_list=['a', 'b', 'c']),)
+ feature_columns = (self._fc_lib.categorical_column_with_vocabulary_list(
+ 'language', vocabulary_list=['a', 'b', 'c']),)
# Check prediction for each sparse_combiner.
# With sparse_combiner = 'sum', we have
@@ -1776,8 +1781,9 @@ class BaseLinearClassifierPredictTest(object):
class BaseLinearClassifierIntegrationTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1789,7 +1795,7 @@ class BaseLinearClassifierIntegrationTest(object):
def _test_complete_flow(self, n_classes, train_input_fn, eval_input_fn,
predict_input_fn, input_dimension, prediction_length):
feature_columns = [
- feature_column_lib.numeric_column('x', shape=(input_dimension,))
+ self._fc_lib.numeric_column('x', shape=(input_dimension,))
]
est = self._linear_classifier_fn(
feature_columns=feature_columns,
@@ -1811,7 +1817,7 @@ class BaseLinearClassifierIntegrationTest(object):
self.assertAllEqual((prediction_length, 1), predictions.shape)
# EXPORT
- feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ feature_spec = self._fc_lib.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
@@ -1961,9 +1967,12 @@ class BaseLinearClassifierIntegrationTest(object):
class BaseLinearLogitFnTest(object):
+ def __init__(self, fc_lib=feature_column):
+ self._fc_lib = fc_lib
+
def test_basic_logit_correctness(self):
"""linear_logit_fn simply wraps feature_column_lib.linear_model."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
with ops.Graph().as_default():
logit_fn = linear._linear_logit_fn_builder(units=2, feature_columns=[age])
logits = logit_fn(features={'age': [[23.], [31.]]})
@@ -1983,12 +1992,14 @@ class BaseLinearLogitFnTest(object):
def test_compute_fraction_of_zero(self):
"""Tests the calculation of sparsity."""
- age = feature_column_lib.numeric_column('age')
- occupation = feature_column_lib.categorical_column_with_hash_bucket(
+ if self._fc_lib != feature_column:
+ return
+ age = feature_column.numeric_column('age')
+ occupation = feature_column.categorical_column_with_hash_bucket(
'occupation', hash_bucket_size=5)
with ops.Graph().as_default():
cols_to_vars = {}
- feature_column_lib.linear_model(
+ feature_column.linear_model(
features={
'age': [[23.], [31.]],
'occupation': [['doctor'], ['engineer']]
@@ -1997,7 +2008,42 @@ class BaseLinearLogitFnTest(object):
units=3,
cols_to_vars=cols_to_vars)
cols_to_vars.pop('bias')
- fraction_zero = linear._compute_fraction_of_zero(cols_to_vars)
+ fraction_zero = linear._compute_fraction_of_zero(cols_to_vars.values())
+ age_var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
+ 'linear_model/age')[0]
+ with tf_session.Session() as sess:
+ sess.run([variables_lib.global_variables_initializer()])
+ # Upon initialization, all variables will be zero.
+ self.assertAllClose(1, fraction_zero.eval())
+
+ sess.run(age_var.assign([[2.0, 0.0, -1.0]]))
+ # 1 of the 3 age weights are zero, and all of the 15 (5 hash buckets
+ # x 3-dim output) are zero.
+ self.assertAllClose(16. / 18., fraction_zero.eval())
+
+ def test_compute_fraction_of_zero_v2(self):
+ """Tests the calculation of sparsity."""
+ if self._fc_lib != feature_column_v2:
+ return
+
+ age = feature_column_v2.numeric_column('age')
+ occupation = feature_column_v2.categorical_column_with_hash_bucket(
+ 'occupation', hash_bucket_size=5)
+ shared_state_manager = feature_column_v2.SharedEmbeddingStateManager()
+ with ops.Graph().as_default():
+ model = feature_column_v2.LinearModel(
+ feature_columns=[age, occupation],
+ units=3,
+ shared_state_manager=shared_state_manager)
+ features = {
+ 'age': [[23.], [31.]],
+ 'occupation': [['doctor'], ['engineer']]
+ }
+ model(features)
+ variables = model.variables
+ variables.remove(model.bias_variable)
+ variables.extend(shared_state_manager.variables)
+ fraction_zero = linear._compute_fraction_of_zero(variables)
age_var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
'linear_model/age')[0]
with tf_session.Session() as sess:
@@ -2013,9 +2059,13 @@ class BaseLinearLogitFnTest(object):
class BaseLinearWarmStartingTest(object):
- def __init__(self, _linear_classifier_fn, _linear_regressor_fn):
+ def __init__(self,
+ _linear_classifier_fn,
+ _linear_regressor_fn,
+ fc_lib=feature_column):
self._linear_classifier_fn = _linear_classifier_fn
self._linear_regressor_fn = _linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
# Create a directory to save our old checkpoint and vocabularies to.
@@ -2039,7 +2089,7 @@ class BaseLinearWarmStartingTest(object):
def test_classifier_basic_warm_starting(self):
"""Tests correctness of LinearClassifier default warm-start."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
# Create a LinearClassifier and train to save a checkpoint.
linear_classifier = self._linear_classifier_fn(
@@ -2066,7 +2116,7 @@ class BaseLinearWarmStartingTest(object):
def test_regressor_basic_warm_starting(self):
"""Tests correctness of LinearRegressor default warm-start."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
# Create a LinearRegressor and train to save a checkpoint.
linear_regressor = self._linear_regressor_fn(
@@ -2091,7 +2141,7 @@ class BaseLinearWarmStartingTest(object):
def test_warm_starting_selective_variables(self):
"""Tests selecting variables to warm-start."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
# Create a LinearClassifier and train to save a checkpoint.
linear_classifier = self._linear_classifier_fn(
@@ -2128,7 +2178,7 @@ class BaseLinearWarmStartingTest(object):
vocab_file = os.path.join(self._ckpt_and_vocab_dir, 'occupation_vocab')
with open(vocab_file, 'w') as f:
f.write('\n'.join(vocab_list))
- occupation = feature_column_lib.categorical_column_with_vocabulary_file(
+ occupation = self._fc_lib.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=vocab_file,
vocabulary_size=len(vocab_list))
@@ -2152,7 +2202,7 @@ class BaseLinearWarmStartingTest(object):
'new_occupation_vocab')
with open(new_vocab_file, 'w') as f:
f.write('\n'.join(new_vocab_list))
- new_occupation = feature_column_lib.categorical_column_with_vocabulary_file(
+ new_occupation = self._fc_lib.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=new_vocab_file,
vocabulary_size=len(new_vocab_list))
@@ -2205,7 +2255,7 @@ class BaseLinearWarmStartingTest(object):
def test_warm_starting_with_naming_change(self):
"""Tests warm-starting with a Tensor name remapping."""
- age_in_years = feature_column_lib.numeric_column('age_in_years')
+ age_in_years = self._fc_lib.numeric_column('age_in_years')
# Create a LinearClassifier and train to save a checkpoint.
linear_classifier = self._linear_classifier_fn(
@@ -2219,7 +2269,7 @@ class BaseLinearWarmStartingTest(object):
# learning_rate = 0.0 optimizer to check values (use SGD so we don't have
# accumulator values that change).
warm_started_linear_classifier = self._linear_classifier_fn(
- feature_columns=[feature_column_lib.numeric_column('age')],
+ feature_columns=[self._fc_lib.numeric_column('age')],
n_classes=4,
optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
# The 'age' variable correspond to the 'age_in_years' variable in the
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index eec64ad452..e6d82f0db7 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -144,7 +144,7 @@ class Estimator(object):
* `labels`: This is the second item returned from the `input_fn`
passed to `train`, `evaluate`, and `predict`. This should be a
single `tf.Tensor` or `dict` of same (for multi-head models).
- If mode is @{tf.estimator.ModeKeys.PREDICT}, `labels=None` will
+ If mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will
be passed. If the `model_fn`'s signature does not accept
`mode`, the `model_fn` must still be able to handle
`labels=None`.
@@ -468,17 +468,41 @@ class Estimator(object):
with ops.Graph().as_default():
if self._eval_distribution:
+ # We want to create the iterations variable outside the distribution
+ # scope as that is just stored on the host and mainly used to drive
+ # the loop and doesn't need to be a Mirrored/Device variable.
+ training.get_or_create_steps_per_run_variable()
with self._eval_distribution.scope():
return _evaluate()
else:
return _evaluate()
def _convert_eval_steps_to_hooks(self, steps):
+ """Create hooks to run correct number of steps in evaluation.
+
+ Args:
+ steps: number of steps to run during evaluation.
+
+ Raises:
+ ValueError: if steps is less than or equal to zero.
+
+ Returns:
+ List of hooks to be passed to the estimator.
+ """
if steps is None:
return []
if steps <= 0:
raise ValueError('Must specify steps > 0, given: {}'.format(steps))
+
+ # The hooks are declared as private in evaluation.py discourage the use
+ # by other libraries or open source users. This should be the only usage
+ # of the estimator evaluation hooks.
+ if self._eval_distribution:
+ steps_per_run = getattr(self._eval_distribution, 'steps_per_run', 1)
+ if steps_per_run > 1:
+ return [evaluation._MultiStepStopAfterNEvalsHook( # pylint: disable=protected-access
+ num_evals=steps, steps_per_run=steps_per_run)]
return [evaluation._StopAfterNEvalsHook(num_evals=steps)] # pylint: disable=protected-access
def predict(self,
@@ -783,9 +807,9 @@ class Estimator(object):
those features and labels, and restores the given checkpoint
(or, lacking that, the most recent checkpoint) into the graph.
Only one of the modes is used for saving variables to the `SavedModel`
- (order of preference: @{tf.estimator.ModeKeys#TRAIN$TRAIN},
- @{tf.estimator.ModeKeys#EVAL$EVAL}, then
- @{tf.estimator.ModeKeys#PREDICT$PREDICT}), such that up to three
+ (order of preference: `tf.estimator.ModeKeys.TRAIN`,
+ `tf.estimator.ModeKeys.EVAL`, then
+ `tf.estimator.ModeKeys.PREDICT`), such that up to three
`tf.MetaGraphDefs` are saved with a single set of variables in a single
`SavedModel` directory.
@@ -1081,7 +1105,7 @@ class Estimator(object):
"""Creates the global step tensor in graph.
The global step tensor must be an integer type with name 'global_step' and
- be added to the collection @{tf.GraphKeys#GLOBAL_STEP$GLOBAL_STEP}.
+ be added to the collection `tf.GraphKeys.GLOBAL_STEP`.
Args:
graph: The graph in which to create the global step tensor.
@@ -1394,6 +1418,36 @@ class Estimator(object):
# It is expected to have one CheckpointSaverHook. If multiple, we pick
# up the first one to add listener.
saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access
+
+ # Add summary hooks to worker 0 if we are running with a master, to ensure
+ # that summaries are written at correct intervals even with long-running
+ # evaluations.
+ save_summary_steps = self._config.save_summary_steps
+ log_step_count_steps = self._config.log_step_count_steps
+ if (self._config.cluster_spec and self._config.cluster_spec.jobs and
+ (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
+ # Update config values to prevent the default hooks from being created on
+ # the master or other workers.
+ save_summary_steps = 0
+ log_step_count_steps = None
+
+ if (self._config.task_type == run_config.TaskType.WORKER and
+ self._config.task_id == 0):
+ if (self._config.save_summary_steps and
+ self._config.save_summary_steps > 0):
+ worker_hooks.append(
+ training.SummarySaverHook(
+ save_steps=self._config.save_summary_steps,
+ output_dir=self._config.model_dir,
+ scaffold=estimator_spec.scaffold))
+
+ if (self._config.log_step_count_steps and
+ self._config.log_step_count_steps > 0):
+ worker_hooks.append(
+ training.StepCounterHook(
+ every_n_steps=self._config.log_step_count_steps,
+ output_dir=self._config.model_dir))
+
with training.MonitoredTrainingSession(
master=self._config.master,
is_chief=self._config.is_chief,
@@ -1403,9 +1457,9 @@ class Estimator(object):
chief_only_hooks=(
tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
save_checkpoint_secs=0, # Saving is handled by a hook.
- save_summaries_steps=self._config.save_summary_steps,
+ save_summaries_steps=save_summary_steps,
config=self._session_config,
- log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
+ log_step_count_steps=log_step_count_steps) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
@@ -1474,6 +1528,7 @@ class Estimator(object):
self._eval_distribution.__class__.__name__ == 'TPUStrategy')
if is_tpu_strategy:
+ steps_per_run_variable = training.get_or_create_steps_per_run_variable()
def step_fn(ctx, features, labels=None):
"""Runs one step of the eval computation and captures outputs."""
estimator_spec = self._eval_distribution.call_for_each_tower(
@@ -1490,7 +1545,7 @@ class Estimator(object):
# TODO(priyag): Fix eval step hook to account for steps_per_run.
ctx = self._eval_distribution.run_steps_on_dataset(
- step_fn, iterator, iterations=self._eval_distribution.steps_per_run)
+ step_fn, iterator, iterations=steps_per_run_variable)
update_op = ctx.run_op
eval_dict = ctx.non_tensor_outputs['eval_dict']
grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 1ed5e30b0e..246dfb1a4b 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import functools
import glob
+import json
import os
import tempfile
@@ -969,6 +970,99 @@ class EstimatorTrainTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'):
est.train(dummy_input_fn, steps=1)
+ def test_master_distributed_hooks(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.MASTER,
+ 'index': 0
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
+ def test_master_distributed_hooks_for_worker_0(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.WORKER,
+ 'index': 0
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertTrue(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertTrue(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
+ def test_master_distributed_hooks_for_worker_nonzero(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235', 'localhost:1237'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.WORKER,
+ 'index': 1
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
_, _ = features, labels
@@ -1017,7 +1111,7 @@ class EstimatorGetVariablesTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='one')
+ variables.VariableV1(1., name='one')
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=constant_op.constant(0.),
@@ -1033,8 +1127,8 @@ class EstimatorGetVariablesTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='one')
- variables.Variable(3., name='three')
+ variables.VariableV1(1., name='one')
+ variables.VariableV1(3., name='three')
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=constant_op.constant(0.),
@@ -1178,13 +1272,13 @@ class EstimatorEvaluateTest(test.TestCase):
def _model_fn(features, labels, mode, params):
del features, labels, params
mean = metrics_module.Mean()
- mean.update_state(variables.Variable(2.) + 1)
+ mean.update_state(variables.VariableV1(2.) + 1)
return model_fn_lib.EstimatorSpec(
mode,
loss=constant_op.constant(1.),
eval_metric_ops={
'mean1': mean,
- 'mean2': metrics_lib.mean(variables.Variable(2.) + 1)
+ 'mean2': metrics_lib.mean(variables.VariableV1(2.) + 1)
})
est = estimator.Estimator(model_fn=_model_fn)
@@ -1332,7 +1426,7 @@ class EstimatorEvaluateTest(test.TestCase):
def _model_fn_with_incremental_loss(features, labels, mode):
_, _ = features, labels
- local_weight = variables.Variable(
+ local_weight = variables.VariableV1(
0., name='local_weight', collections=[ops.GraphKeys.LOCAL_VARIABLES])
# Loss will be 2, 4, 6, ...
loss = 2 * state_ops.assign_add(local_weight, 1.)
@@ -1385,7 +1479,7 @@ class EstimatorEvaluateTest(test.TestCase):
def _get_model_fn(val=1):
def _model_fn(features, labels, mode):
del features, labels # unused
- variables.Variable(val, name='weight')
+ variables.VariableV1(val, name='weight')
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions=constant_op.constant([[1.]]),
@@ -1409,7 +1503,7 @@ class EstimatorEvaluateTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
self.mock_saver = get_mock_saver()
return model_fn_lib.EstimatorSpec(
mode=mode,
@@ -1603,7 +1697,7 @@ class EstimatorPredictTest(test.TestCase):
def test_no_checkpoint_uses_init(self):
def _model_fn(features, labels, mode, params, config):
del features, labels, params, config
- x = variables.Variable([[3.]], name='x')
+ x = variables.VariableV1([[3.]], name='x')
return model_fn_lib.EstimatorSpec(mode, predictions=math_ops.add(x, 1.))
est = estimator.Estimator(model_fn=_model_fn)
# Expected prediction value is 1 + the value of the Variable that is newly
@@ -1614,7 +1708,7 @@ class EstimatorPredictTest(test.TestCase):
def _make_model_fn(x):
def _variable_creating_and_export_model_fn(features, labels, mode):
_, _ = features, labels
- x_var = variables.Variable([[x]], name='x')
+ x_var = variables.VariableV1([[x]], name='x')
return model_fn_lib.EstimatorSpec(
mode,
predictions=math_ops.add(x_var, 1.),
@@ -1936,7 +2030,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- v = variables.Variable([[16.]], name='weight')
+ v = variables.VariableV1([[16.]], name='weight')
prediction = v * 2
return model_fn_lib.EstimatorSpec(
mode,
@@ -1953,7 +2047,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- v = variables.Variable([[16.]], name='weight')
+ v = variables.VariableV1([[16.]], name='weight')
prediction = v * 2
return model_fn_lib.EstimatorSpec(
mode,
@@ -1974,7 +2068,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
self.mock_saver = get_mock_saver()
return model_fn_lib.EstimatorSpec(
mode=mode,
@@ -2029,7 +2123,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn_for_export_tests(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
scores = constant_op.constant([3.])
classes = constant_op.constant(['wumpus'])
update_global_step = state_ops.assign_add(training.get_global_step(), 1)
@@ -2052,11 +2146,11 @@ def _x_y_input_fn():
def _model_fn_with_x_y(features, labels, mode):
_ = labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
scores = constant_op.constant([3.])
classes = constant_op.constant(['wumpus'])
if mode == model_fn_lib.ModeKeys.PREDICT:
- variables.Variable(36., name='name_collision')
+ variables.VariableV1(36., name='name_collision')
return model_fn_lib.EstimatorSpec(
mode,
predictions=constant_op.constant(10.),
@@ -2076,8 +2170,8 @@ def _model_fn_with_x_y(features, labels, mode):
metrics_lib.mean(
features['x'] - features['y'], name='{}mean'.format(prefix))
}
- variables.Variable(1., name='later_var')
- variables.Variable(3., name='name_collision')
+ variables.VariableV1(1., name='later_var')
+ variables.VariableV1(3., name='name_collision')
return model_fn_lib.EstimatorSpec(
mode,
predictions=multiplied,
@@ -2411,9 +2505,9 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_with_predict_only_vars(features, labels, mode):
_, _ = features, labels
if mode == model_fn_lib.ModeKeys.PREDICT:
- variables.Variable(1., name='only_in_predict')
+ variables.VariableV1(1., name='only_in_predict')
else:
- variables.Variable(1., name='otherwise')
+ variables.VariableV1(1., name='otherwise')
prediction = constant_op.constant(1.)
return model_fn_lib.EstimatorSpec(
@@ -2684,7 +2778,7 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
self.mock_saver = get_mock_saver()
scores = constant_op.constant([3.])
return model_fn_lib.EstimatorSpec(
@@ -2717,7 +2811,7 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
scores = constant_op.constant([3.])
if mode == model_fn_lib.ModeKeys.PREDICT:
@@ -2762,8 +2856,8 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- my_int = variables.Variable(1, name='my_int',
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ my_int = variables.VariableV1(1, name='my_int',
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
_ = training.get_or_create_steps_per_run_variable()
scores = constant_op.constant([3.])
with ops.control_dependencies([
@@ -2808,8 +2902,8 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- my_int = variables.Variable(1, name='my_int',
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ my_int = variables.VariableV1(1, name='my_int',
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
scores = constant_op.constant([3.])
with ops.control_dependencies([
variables.local_variables_initializer(),
@@ -3038,7 +3132,7 @@ class EstimatorExportTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
return model_fn_lib.EstimatorSpec(
mode,
predictions=constant_op.constant(10.),
@@ -3081,7 +3175,7 @@ class EstimatorHookOrderingTest(test.TestCase):
"""A graph that generates NaN's for testing."""
del features, labels
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, name='global_step')
inc_global_step = state_ops.assign_add(global_step, 1)
nan_const = constant_op.constant(np.nan, dtype=dtypes.float32)
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 6b2765be82..7546771ed3 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import os
import re
+import six
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator as estimator_lib
@@ -31,6 +32,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import metrics
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers
from tensorflow.python.ops import check_ops
@@ -214,25 +216,40 @@ def _convert_keras_metrics_to_estimator(model):
if not getattr(model, 'metrics', None):
return None
- # TODO(psv/fchollet): support stateful metrics
eval_metric_ops = {}
+
+ def get_metric_name(metric):
+ if isinstance(metric, metrics.Metric):
+ return metric.name
+ if callable(metric):
+ return metric.__name__
+ assert isinstance(metric, six.string_types)
+ return metric
+
# When each metric maps to an output
if isinstance(model.metrics, dict):
for i, output_name in enumerate(model.metrics.keys()):
- metric_name = model.metrics[output_name]
- if callable(metric_name):
- metric_name = metric_name.__name__
+ # `metric` is the user given metric value in `compile`. This can be
+ # metric name (`acc`), metric function (binary_accuracy) or a metric
+ # object (BinaryAccuracy()).
+ metric = model.metrics[output_name]
+ metric_name = get_metric_name(metric)
# When some outputs use the same metric
if list(model.metrics.values()).count(metric_name) > 1:
metric_name += '_' + output_name
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i - len(model.metrics)])
+ if isinstance(metric, metrics.Metric):
+ eval_metric_ops[metric_name] = metric
+ else:
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i - len(model.metrics)])
else:
- for i, metric_name in enumerate(model.metrics):
- if callable(metric_name):
- metric_name = metric_name.__name__
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i])
+ for i, metric in enumerate(model.metrics):
+ metric_name = get_metric_name(metric)
+ if isinstance(metric, metrics.Metric):
+ eval_metric_ops[metric_name] = metric
+ else:
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i])
return eval_metric_ops
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 3758243d7b..288f9b8906 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -257,7 +257,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -281,7 +281,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
my_hook = MyHook()
with self.cached_session():
@@ -306,7 +306,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
my_hook = MyHook()
with self.cached_session():
keras_model.fit(x_train, y_train, epochs=1)
@@ -328,7 +328,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -351,7 +351,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -370,7 +370,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
# Create state
@@ -662,7 +662,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
tf_config = json.dumps({
'cluster': {
@@ -687,7 +687,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
@@ -706,7 +706,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -736,7 +736,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
@@ -751,7 +751,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in '
@@ -765,7 +765,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
keras_model.train_on_batch(
np.random.random((10,) + _INPUT_SIZE),
@@ -776,7 +776,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=SGD(lr=0.0001, momentum=0.9),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
@@ -786,7 +786,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=optimizer,
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session() as sess:
keras_model_fn = keras_lib._create_keras_model_fn(keras_model)
global_step = training_util.create_global_step()
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index 31e4778e72..fb110c4b7b 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import os
import time
-from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training
@@ -144,14 +143,11 @@ class StrategyInitFinalizeHook(training.SessionRunHook):
self._finalize_fn = finalize_fn
def begin(self):
+ # We only create the init ops, but don't run it. We rely on SessionManager
+ # to run it for us.
self._init_ops = self._initialization_fn()
self._finalize_ops = self._finalize_fn()
- def after_create_session(self, session, coord):
- logging.info('Initialize system')
- session.run(self._init_ops,
- options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
-
def end(self, session):
logging.info('Finalize system.')
session.run(self._finalize_ops)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 9984379e9d..618e70f3a5 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -170,7 +170,8 @@ def _internal_input_layer(features,
trainable=True,
cols_to_vars=None,
scope=None,
- cols_to_output_tensors=None):
+ cols_to_output_tensors=None,
+ from_template=False):
"""See input_layer. `scope` is a name or variable scope to use."""
feature_columns = _normalize_feature_columns(feature_columns)
@@ -186,10 +187,7 @@ def _internal_input_layer(features,
if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
- # a non-None `scope` can allow for variable reuse, when, e.g., this function
- # is wrapped by a `make_template`.
- with variable_scope.variable_scope(
- scope, default_name='input_layer', values=features.values()):
+ def _get_logits(): # pylint: disable=missing-docstring
builder = _LazyBuilder(features)
output_tensors = []
ordered_columns = []
@@ -217,6 +215,16 @@ def _internal_input_layer(features,
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
+ # If we're constructing from the `make_template`, that by default adds a
+ # variable scope with the name of the layer. In that case, we dont want to
+ # add another `variable_scope` as that would break checkpoints.
+ if from_template:
+ return _get_logits()
+ else:
+ with variable_scope.variable_scope(
+ scope, default_name='input_layer', values=features.values()):
+ return _get_logits()
+
@tf_export('feature_column.input_layer')
def input_layer(features,
@@ -301,17 +309,18 @@ class InputLayer(object):
feature_columns,
weight_collections=None,
trainable=True,
- cols_to_vars=None):
+ cols_to_vars=None,
+ name='feature_column_input_layer',
+ create_scope_now=True):
"""See `input_layer`."""
self._feature_columns = feature_columns
self._weight_collections = weight_collections
self._trainable = trainable
self._cols_to_vars = cols_to_vars
+ self._name = name
self._input_layer_template = template.make_template(
- 'feature_column_input_layer',
- _internal_input_layer,
- create_scope_now_=True)
+ self._name, _internal_input_layer, create_scope_now_=create_scope_now)
self._scope = self._input_layer_template.variable_scope
def __call__(self, features):
@@ -321,7 +330,11 @@ class InputLayer(object):
weight_collections=self._weight_collections,
trainable=self._trainable,
cols_to_vars=None,
- scope=self._scope)
+ from_template=True)
+
+ @property
+ def name(self):
+ return self._name
@property
def non_trainable_variables(self):
@@ -2305,7 +2318,7 @@ class _LazyBuilder(object):
# Input_tensor must have rank 1.
if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
return sparse_ops.sparse_reshape(
- input_tensor, [array_ops.shape(input_tensor)[0], -1])
+ input_tensor, [array_ops.shape(input_tensor)[0], 1])
else:
return array_ops.expand_dims(input_tensor, -1)
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index abb79efa68..1ae510250c 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -169,6 +169,18 @@ class LazyColumnTest(test.TestCase):
TypeError, '"key" must be either a "str" or "_FeatureColumn".'):
builder.get(NotAFeatureColumn())
+ def test_expand_dim_rank_1_sparse_tensor_empty_batch(self):
+ # empty 1-D sparse tensor:
+ builder = _LazyBuilder(features={'a': sparse_tensor.SparseTensor(
+ indices=np.reshape(np.array([], dtype=np.int64), (0, 1)),
+ dense_shape=[0],
+ values=np.array([]))})
+ with self.cached_session():
+ spv = builder.get('a').eval()
+ self.assertAllEqual(np.array([0, 1], dtype=np.int64), spv.dense_shape)
+ self.assertAllEqual(
+ np.reshape(np.array([], dtype=np.int64), (0, 2)), spv.indices)
+
class NumericColumnTest(test.TestCase):
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index 57f7af7635..b79373c475 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -136,14 +136,11 @@ import six
from tensorflow.python.eager import context
-from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.base_layer import Layer
-from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@@ -153,7 +150,6 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
@@ -245,28 +241,19 @@ class StateManager(object):
raise NotImplementedError('StateManager.get_resource')
-class _InputLayerStateManager(StateManager):
- """Manages the state of InputLayer."""
+class _StateManagerImpl(StateManager):
+ """Manages the state of FeatureLayer and LinearModel."""
- def __init__(self, layer, feature_columns, trainable):
- """Creates an _InputLayerStateManager object.
+ def __init__(self, layer, trainable):
+ """Creates an _StateManagerImpl object.
Args:
layer: The input layer this state manager is associated with.
- feature_columns: List of feature columns for the input layer
trainable: Whether by default, variables created are trainable or not.
"""
self._trainable = trainable
self._layer = layer
- self._cols_to_vars_map = {}
- self._cols_to_names_map = {}
- for column in sorted(feature_columns, key=lambda x: x.name):
- self._cols_to_vars_map[column] = {}
- base_name = column.name
- if isinstance(column, SharedEmbeddingColumn):
- base_name = column.shared_collection_name
- with variable_scope.variable_scope(base_name) as vs:
- self._cols_to_names_map[column] = _strip_leading_slashes(vs.name)
+ self._cols_to_vars_map = collections.defaultdict(lambda: {})
def create_variable(self,
feature_column,
@@ -277,19 +264,20 @@ class _InputLayerStateManager(StateManager):
initializer=None):
if name in self._cols_to_vars_map[feature_column]:
raise ValueError('Variable already exists.')
- with variable_scope.variable_scope(self._cols_to_names_map[feature_column]):
- var = self._layer.add_variable(
- name=name,
- shape=shape,
- dtype=dtype,
- initializer=initializer,
- trainable=self._trainable and trainable,
- # TODO(rohanj): Get rid of this hack once we have a mechanism for
- # specifying a default partitioner for an entire layer. In that case,
- # the default getter for Layers should work.
- getter=variable_scope.get_variable)
- self._cols_to_vars_map[feature_column][name] = var
- return var
+
+ var = self._layer.add_variable(
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ trainable=self._trainable and trainable,
+ use_resource=True,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
+ self._cols_to_vars_map[feature_column][name] = var
+ return var
def get_variable(self, feature_column, name):
if name in self._cols_to_vars_map[feature_column]:
@@ -313,12 +301,15 @@ class FeatureLayer(Layer):
keywords_embedded = embedding_column(
categorical_column_with_hash_bucket("keywords", 10K), dimensions=16)
columns = [price, keywords_embedded, ...]
- features = tf.parse_example(..., features=make_parse_example_spec(columns))
feature_layer = FeatureLayer(columns)
+
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
dense_tensor = feature_layer(features)
for units in [128, 64, 32]:
dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
- prediction = tf.layers.dense(dense_tensor, 1)."""
+ prediction = tf.layers.dense(dense_tensor, 1).
+ ```
+ """
def __init__(self,
feature_columns,
@@ -375,8 +366,7 @@ class FeatureLayer(Layer):
super(FeatureLayer, self).__init__(name=name, trainable=trainable, **kwargs)
self._feature_columns = _normalize_feature_columns(feature_columns)
- self._state_manager = _InputLayerStateManager(self, self._feature_columns,
- self.trainable)
+ self._state_manager = _StateManagerImpl(self, self.trainable)
self._shared_state_manager = shared_state_manager
for column in sorted(self._feature_columns, key=lambda x: x.name):
if not isinstance(column, DenseColumn):
@@ -394,8 +384,9 @@ class FeatureLayer(Layer):
if isinstance(column, SharedEmbeddingColumn):
column.create_state(self._shared_state_manager)
else:
- with variable_scope.variable_scope(None, default_name=self.name):
- column.create_state(self._state_manager)
+ with variable_scope._pure_variable_scope(self.name): # pylint: disable=protected-access
+ with variable_scope._pure_variable_scope(column.name): # pylint: disable=protected-access
+ column.create_state(self._state_manager)
super(FeatureLayer, self).build(None)
def call(self, features, cols_to_output_tensors=None):
@@ -424,19 +415,20 @@ class FeatureLayer(Layer):
output_tensors = []
ordered_columns = []
for column in sorted(self._feature_columns, key=lambda x: x.name):
- ordered_columns.append(column)
- if isinstance(column, SharedEmbeddingColumn):
- tensor = column.get_dense_tensor(transformation_cache,
- self._shared_state_manager)
- else:
- tensor = column.get_dense_tensor(transformation_cache,
- self._state_manager)
- num_elements = column.variable_shape.num_elements()
- batch_size = array_ops.shape(tensor)[0]
- tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
- output_tensors.append(tensor)
- if cols_to_output_tensors is not None:
- cols_to_output_tensors[column] = tensor
+ with ops.name_scope(column.name):
+ ordered_columns.append(column)
+ if isinstance(column, SharedEmbeddingColumn):
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._shared_state_manager)
+ else:
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._state_manager)
+ num_elements = column.variable_shape.num_elements()
+ batch_size = array_ops.shape(tensor)[0]
+ tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
+ output_tensors.append(tensor)
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors[column] = tensor
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
@@ -448,20 +440,18 @@ class FeatureLayer(Layer):
return (input_shape[0], total_elements)
-def linear_model(features,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- cols_to_vars=None):
- """Returns a linear prediction `Tensor` based on given `feature_columns`.
+def _strip_leading_slashes(name):
+ return name.rsplit('/', 1)[-1]
+
+
+class LinearModel(Layer):
+ """Produces a linear prediction `Tensor` based on given `feature_columns`.
- This function generates a weighted sum based on output dimension `units`.
+ This layer generates a weighted sum based on output dimension `units`.
Weighted sum refers to logits in classification problems. It refers to the
prediction itself for linear regression problems.
- Note on supported columns: `linear_model` treats categorical columns as
+ Note on supported columns: `LinearModel` treats categorical columns as
`indicator_column`s. To be specific, assume the input as `SparseTensor` looks
like:
@@ -486,308 +476,195 @@ def linear_model(features,
keywords = categorical_column_with_hash_bucket("keywords", 10K)
keywords_price = crossed_column('keywords', price_buckets, ...)
columns = [price_buckets, keywords, keywords_price ...]
+ linear_model = LinearModel(columns)
+
features = tf.parse_example(..., features=make_parse_example_spec(columns))
- prediction = linear_model(features, columns)
+ prediction = linear_model(features)
```
-
- Args:
- features: A mapping from key to tensors. `_FeatureColumn`s look up via these
- keys. For example `numeric_column('price')` will look at 'price' key in
- this dict. Values are `Tensor` or `SparseTensor` depending on
- corresponding `_FeatureColumn`.
- feature_columns: An iterable containing the FeatureColumns to use as inputs
- to your model. All items should be instances of classes derived from
- `_FeatureColumn`s.
- units: An integer, dimensionality of the output space. Default value is 1.
- sparse_combiner: A string specifying how to reduce if a categorical column
- is multivalent. Except `numeric_column`, almost all columns passed to
- `linear_model` are considered as categorical columns. It combines each
- categorical column independently. Currently "mean", "sqrtn" and "sum" are
- supported, with "sum" the default for linear model. "sqrtn" often achieves
- good accuracy, in particular with bag-of-words columns.
- * "sum": do not normalize features in the column
- * "mean": do l1 normalization on features in the column
- * "sqrtn": do l2 normalization on features in the column
- For example, for two features represented as the categorical columns:
-
- ```python
- # Feature 1
-
- shape = [2, 2]
- {
- [0, 0]: "a"
- [0, 1]: "b"
- [1, 0]: "c"
- }
-
- # Feature 2
-
- shape = [2, 3]
- {
- [0, 0]: "d"
- [1, 0]: "e"
- [1, 1]: "f"
- [1, 2]: "g"
- }
- ```
- with `sparse_combiner` as "mean", the linear model outputs conceptly are:
- ```
- y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0
- y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1
- ```
- where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
- assigned to the presence of `x` in the input features.
- weight_collections: A list of collection names to which the Variable will be
- added. Note that, variables will also be added to collections
- `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
- trainable: If `True` also add the variable to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- cols_to_vars: If not `None`, must be a dictionary that will be filled with a
- mapping from `_FeatureColumn` to associated list of `Variable`s. For
- example, after the call, we might have cols_to_vars = {
- _NumericColumn(
- key='numeric_feature1', shape=(1,):
- [<tf.Variable 'linear_model/price2/weights:0' shape=(1, 1)>],
- 'bias': [<tf.Variable 'linear_model/bias_weights:0' shape=(1,)>],
- _NumericColumn(
- key='numeric_feature2', shape=(2,)):
- [<tf.Variable 'linear_model/price1/weights:0' shape=(2, 1)>]}
- If a column creates no variables, its value will be an empty list. Note
- that cols_to_vars will also contain a string key 'bias' that maps to a
- list of Variables.
-
- Returns:
- A `Tensor` which represents predictions/logits of a linear model. Its shape
- is (batch_size, units) and its dtype is `float32`.
-
- Raises:
- ValueError: if an item in `feature_columns` is neither a `_DenseColumn`
- nor `_CategoricalColumn`.
- """
- with variable_scope.variable_scope(None, 'linear_model') as vs:
- model_name = _strip_leading_slashes(vs.name)
- linear_model_layer = _LinearModel(
- feature_columns=feature_columns,
- units=units,
- sparse_combiner=sparse_combiner,
- weight_collections=weight_collections,
- trainable=trainable,
- name=model_name)
- retval = linear_model_layer(features) # pylint: disable=not-callable
- if cols_to_vars is not None:
- cols_to_vars.update(linear_model_layer.cols_to_vars())
- return retval
-
-
-def _add_to_collections(var, weight_collections):
- """Adds a var to the list of weight_collections provided.
-
- Handles the case for partitioned and non-partitioned variables.
-
- Args:
- var: A variable or Partitioned Variable.
- weight_collections: List of collections to add variable to.
- """
- for weight_collection in weight_collections:
- # The layer self.add_variable call already adds it to GLOBAL_VARIABLES.
- if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES:
- continue
- # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
- # so that we don't have to do this check.
- if isinstance(var, variables.PartitionedVariable):
- for constituent_var in list(var):
- ops.add_to_collection(weight_collection, constituent_var)
- else:
- ops.add_to_collection(weight_collection, var)
-
-
-class _FCLinearWrapper(base.Layer):
- """Wraps a _FeatureColumn in a layer for use in a linear model.
-
- See `linear_model` above.
"""
def __init__(self,
- feature_column,
+ feature_columns,
units=1,
sparse_combiner='sum',
- weight_collections=None,
trainable=True,
name=None,
+ shared_state_manager=None,
**kwargs):
- super(_FCLinearWrapper, self).__init__(
- trainable=trainable, name=name, **kwargs)
- self._feature_column = feature_column
- self._units = units
- self._sparse_combiner = sparse_combiner
- self._weight_collections = weight_collections
+ """Constructs a LinearModel.
- def build(self, _):
- if isinstance(self._feature_column, fc_old._CategoricalColumn): # pylint: disable=protected-access
- weight = self.add_variable(
- name='weights',
- shape=(self._feature_column._num_buckets, self._units), # pylint: disable=protected-access
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- else:
- num_elements = self._feature_column._variable_shape.num_elements() # pylint: disable=protected-access
- weight = self.add_variable(
- name='weights',
- shape=[num_elements, self._units],
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- _add_to_collections(weight, self._weight_collections)
- self._weight_var = weight
- self.built = True
-
- def call(self, builder):
- weighted_sum = fc_old._create_weighted_sum( # pylint: disable=protected-access
- column=self._feature_column,
- builder=builder,
- units=self._units,
- sparse_combiner=self._sparse_combiner,
- weight_collections=self._weight_collections,
- trainable=self.trainable,
- weight_var=self._weight_var)
- return weighted_sum
+ Args:
+ feature_columns: An iterable containing the FeatureColumns to use as
+ inputs to your model. All items should be instances of classes derived
+ from `_FeatureColumn`s.
+ units: An integer, dimensionality of the output space. Default value is 1.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. Except `numeric_column`, almost all columns passed to
+ `linear_model` are considered as categorical columns. It combines each
+ categorical column independently. Currently "mean", "sqrtn" and "sum"
+ are supported, with "sum" the default for linear model. "sqrtn" often
+ achieves good accuracy, in particular with bag-of-words columns.
+ * "sum": do not normalize features in the column
+ * "mean": do l1 normalization on features in the column
+ * "sqrtn": do l2 normalization on features in the column
+ For example, for two features represented as the categorical columns:
+
+ ```python
+ # Feature 1
+
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [0, 1]: "b"
+ [1, 0]: "c"
+ }
+
+ # Feature 2
+
+ shape = [2, 3]
+ {
+ [0, 0]: "d"
+ [1, 0]: "e"
+ [1, 1]: "f"
+ [1, 2]: "g"
+ }
+ ```
+
+ with `sparse_combiner` as "mean", the linear model outputs conceptly are
+ ```
+ y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0
+ y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1
+ ```
+ where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
+ assigned to the presence of `x` in the input features.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: Name to give to the Linear Model. All variables and ops created will
+ be scoped by this name.
+ shared_state_manager: SharedEmbeddingStateManager that manages the state
+ of SharedEmbeddingColumns. For more info, look at `FeatureLayer`.
+ **kwargs: Keyword arguments to construct a layer.
+ Raises:
+ ValueError: if an item in `feature_columns` is neither a `DenseColumn`
+ nor `CategoricalColumn`.
+ """
+ super(LinearModel, self).__init__(name=name, trainable=trainable, **kwargs)
-class _BiasLayer(base.Layer):
- """A layer for the bias term.
- """
+ self._feature_columns = _normalize_feature_columns(feature_columns)
+ self._feature_columns = sorted(self._feature_columns, key=lambda x: x.name)
+ for column in self._feature_columns:
+ if not isinstance(column, (DenseColumn, CategoricalColumn)):
+ raise ValueError(
+ 'Items of feature_columns must be either a '
+ 'DenseColumn or CategoricalColumn. Given: {}'.format(column))
- def __init__(self,
- units=1,
- trainable=True,
- weight_collections=None,
- name=None,
- **kwargs):
- super(_BiasLayer, self).__init__(trainable=trainable, name=name, **kwargs)
self._units = units
- self._weight_collections = weight_collections
-
- def build(self, _):
- self._bias_variable = self.add_variable(
- 'bias_weights',
- shape=[self._units],
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- _add_to_collections(self._bias_variable, self._weight_collections)
- self.built = True
-
- def call(self, _):
- return self._bias_variable
+ self._sparse_combiner = sparse_combiner
+ self._state_manager = _StateManagerImpl(self, self.trainable)
+ self._shared_state_manager = shared_state_manager
+ self._bias_variable = None
-def _get_expanded_variable_list(var_list):
- returned_list = []
- for variable in var_list:
- if (isinstance(variable, variables.Variable) or
- resource_variable_ops.is_resource_variable(variable)):
- returned_list.append(variable) # Single variable case.
- else: # Must be a PartitionedVariable, so convert into a list.
- returned_list.extend(list(variable))
- return returned_list
+ def build(self, _):
+ # Create state for shared embedding columns.
+ for column in self._feature_columns:
+ if isinstance(column, SharedEmbeddingColumn):
+ column.create_state(self._shared_state_manager)
+ # We need variable scopes for now because we want the variable partitioning
+ # information to percolate down. We also use _pure_variable_scope's here
+ # since we want to open up a name_scope in the `call` method while creating
+ # the ops.
+ with variable_scope._pure_variable_scope(self.name): # pylint: disable=protected-access
+ for column in self._feature_columns:
+ with variable_scope._pure_variable_scope(column.name): # pylint: disable=protected-access
+ # Create the state for each feature column
+ if not isinstance(column, SharedEmbeddingColumn):
+ column.create_state(self._state_manager)
+
+ # Create a weight variable for each column.
+ if isinstance(column, CategoricalColumn):
+ first_dim = column.num_buckets
+ else:
+ first_dim = column.variable_shape.num_elements()
+ self._state_manager.create_variable(
+ column,
+ name='weights',
+ dtype=dtypes.float32,
+ shape=(first_dim, self._units),
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+
+ # Create a bias variable.
+ self._bias_variable = self.add_variable(
+ name='bias_weights',
+ dtype=dtypes.float32,
+ shape=[self._units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable,
+ use_resource=True,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
-def _strip_leading_slashes(name):
- return name.rsplit('/', 1)[-1]
+ super(LinearModel, self).build(None)
+ def call(self, features):
+ """Returns a `Tensor` the represents the predictions of a linear model.
-class _LinearModel(training.Model):
- """Creates a linear model using feature columns.
+ Args:
+ features: A mapping from key to tensors. `_FeatureColumn`s look up via
+ these keys. For example `numeric_column('price')` will look at 'price'
+ key in this dict. Values are `Tensor` or `SparseTensor` depending on
+ corresponding `_FeatureColumn`.
- See `linear_model` for details.
- """
+ Returns:
+ A `Tensor` which represents predictions/logits of a linear model. Its
+ shape is (batch_size, units) and its dtype is `float32`.
- def __init__(self,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- name=None,
- **kwargs):
- super(_LinearModel, self).__init__(name=name, **kwargs)
- self._feature_columns = fc_old._normalize_feature_columns( # pylint: disable=protected-access
- feature_columns)
- self._weight_collections = list(weight_collections or [])
- if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections:
- self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
- if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
- self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
-
- column_layers = {}
- for column in sorted(self._feature_columns, key=lambda x: x.name):
- with variable_scope.variable_scope(
- None, default_name=column._var_scope_name) as vs: # pylint: disable=protected-access
- # Having the fully expressed variable scope name ends up doubly
- # expressing the outer scope (scope with which this method was called)
- # in the name of the variable that would get created.
- column_name = _strip_leading_slashes(vs.name)
- column_layer = _FCLinearWrapper(column, units, sparse_combiner,
- self._weight_collections, trainable,
- column_name, **kwargs)
- column_layers[column_name] = column_layer
- self._column_layers = self._add_layers(column_layers)
- self._bias_layer = _BiasLayer(
- units=units,
- trainable=trainable,
- weight_collections=self._weight_collections,
- name='bias_layer',
- **kwargs)
- self._cols_to_vars = {}
-
- def cols_to_vars(self):
- """Returns a dict mapping _FeatureColumns to variables.
-
- See `linear_model` for more information.
- This is not populated till `call` is called i.e. layer is built.
+ Raises:
+ ValueError: If features are not a dictionary.
"""
- return self._cols_to_vars
-
- def call(self, features):
- with variable_scope.variable_scope(self.name):
- for column in self._feature_columns:
- if not isinstance(
- column,
- (
- fc_old._DenseColumn, # pylint: disable=protected-access
- fc_old._CategoricalColumn)): # pylint: disable=protected-access
- raise ValueError(
- 'Items of feature_columns must be either a '
- '_DenseColumn or _CategoricalColumn. Given: {}'.format(column))
+ if not isinstance(features, dict):
+ raise ValueError('We expected a dictionary here. Instead we got: ',
+ features)
+ with ops.name_scope(self.name):
+ transformation_cache = FeatureTransformationCache(features)
weighted_sums = []
- ordered_columns = []
- builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access
- for layer in sorted(self._column_layers.values(), key=lambda x: x.name):
- column = layer._feature_column # pylint: disable=protected-access
- ordered_columns.append(column)
- weighted_sum = layer(builder)
- weighted_sums.append(weighted_sum)
- self._cols_to_vars[column] = ops.get_collection(
- ops.GraphKeys.GLOBAL_VARIABLES, scope=layer.scope_name)
-
- _verify_static_batch_size_equality(weighted_sums, ordered_columns)
+ for column in self._feature_columns:
+ with ops.name_scope(column.name):
+ # All the weights used in the linear model are owned by the state
+ # manager associated with this Linear Model.
+ weight_var = self._state_manager.get_variable(column, 'weights')
+
+ # The embedding weights for the SharedEmbeddingColumn are owned by
+ # the shared_state_manager and so we need to pass that in while
+ # creating the weighted sum. For all other columns, the state is owned
+ # by the Linear Model's state manager.
+ if isinstance(column, SharedEmbeddingColumn):
+ state_manager = self._shared_state_manager
+ else:
+ state_manager = self._state_manager
+ weighted_sum = _create_weighted_sum(
+ column=column,
+ transformation_cache=transformation_cache,
+ state_manager=state_manager,
+ sparse_combiner=self._sparse_combiner,
+ weight_var=weight_var)
+ weighted_sums.append(weighted_sum)
+
+ _verify_static_batch_size_equality(weighted_sums, self._feature_columns)
predictions_no_bias = math_ops.add_n(
weighted_sums, name='weighted_sum_no_bias')
predictions = nn_ops.bias_add(
- predictions_no_bias,
- self._bias_layer( # pylint: disable=not-callable
- builder,
- scope=variable_scope.get_variable_scope()), # pylint: disable=not-callable
- name='weighted_sum')
- bias = self._bias_layer.variables[0]
- self._cols_to_vars['bias'] = _get_expanded_variable_list([bias])
- return predictions
-
- def _add_layers(self, layers):
- # "Magic" required for keras.Model classes to track all the variables in
- # a list of layers.Layer objects.
- # TODO(ashankar): Figure out API so user code doesn't have to do this.
- for name, layer in layers.items():
- setattr(self, 'layer-%s' % name, layer)
- return layers
+ predictions_no_bias, self._bias_variable, name='weighted_sum')
+ return predictions
+
+ @property
+ def bias_variable(self):
+ return self._bias_variable
def _transform_features(features, feature_columns, state_manager):
@@ -2045,58 +1922,40 @@ class DenseColumn(FeatureColumn):
pass
-def _create_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- weight_var=None):
+def is_feature_column_v2(feature_columns):
+ """Returns True if all feature columns are V2."""
+ for feature_column in feature_columns:
+ if not isinstance(feature_column, FeatureColumn):
+ return False
+ return True
+
+
+def _create_weighted_sum(column, transformation_cache, state_manager,
+ sparse_combiner, weight_var):
"""Creates a weighted sum for a dense/categorical column for linear_model."""
if isinstance(column, CategoricalColumn):
return _create_categorical_column_weighted_sum(
column=column,
transformation_cache=transformation_cache,
state_manager=state_manager,
- units=units,
sparse_combiner=sparse_combiner,
- weight_collections=weight_collections,
- trainable=trainable,
weight_var=weight_var)
else:
return _create_dense_column_weighted_sum(
column=column,
transformation_cache=transformation_cache,
state_manager=state_manager,
- units=units,
- weight_collections=weight_collections,
- trainable=trainable,
weight_var=weight_var)
-def _create_dense_column_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- weight_collections,
- trainable,
- weight_var=None):
+def _create_dense_column_weighted_sum(column, transformation_cache,
+ state_manager, weight_var):
"""Create a weighted sum of a dense column for linear_model."""
tensor = column.get_dense_tensor(transformation_cache, state_manager)
num_elements = column.variable_shape.num_elements()
batch_size = array_ops.shape(tensor)[0]
tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
- if weight_var is not None:
- weight = weight_var
- else:
- weight = variable_scope.get_variable(
- name='weights',
- shape=[num_elements, units],
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
- return math_ops.matmul(tensor, weight, name='weighted_sum')
+ return math_ops.matmul(tensor, weight_var, name='weighted_sum')
class CategoricalColumn(FeatureColumn):
@@ -2137,14 +1996,8 @@ class CategoricalColumn(FeatureColumn):
pass
-def _create_categorical_column_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- weight_var=None):
+def _create_categorical_column_weighted_sum(
+ column, transformation_cache, state_manager, sparse_combiner, weight_var):
# pylint: disable=g-doc-return-or-yield,g-doc-args
"""Create a weighted sum of a categorical column for linear_model.
@@ -2183,17 +2036,8 @@ def _create_categorical_column_weighted_sum(column,
weight_tensor = sparse_ops.sparse_reshape(
weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
- if weight_var is not None:
- weight = weight_var
- else:
- weight = variable_scope.get_variable(
- name='weights',
- shape=(column.num_buckets, units),
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
return _safe_embedding_lookup_sparse(
- weight,
+ weight_var,
id_tensor,
sparse_weights=weight_tensor,
combiner=sparse_combiner,
@@ -2333,7 +2177,7 @@ class FeatureTransformationCache(object):
# Input_tensor must have rank 1.
if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
return sparse_ops.sparse_reshape(
- input_tensor, [array_ops.shape(input_tensor)[0], -1])
+ input_tensor, [array_ops.shape(input_tensor)[0], 1])
else:
return array_ops.expand_dims(input_tensor, -1)
@@ -2769,6 +2613,7 @@ class SharedEmbeddingStateManager(Layer):
dtype=dtype,
trainable=self.trainable and trainable,
initializer=initializer,
+ use_resource=True,
# TODO(rohanj): Get rid of this hack once we have a mechanism for
# specifying a default partitioner for an entire layer. In that case,
# the default getter for Layers should work.
@@ -2782,6 +2627,12 @@ class SharedEmbeddingStateManager(Layer):
return self._var_dict[name]
+def maybe_create_shared_state_manager(feature_columns):
+ if is_feature_column_v2(feature_columns):
+ return SharedEmbeddingStateManager()
+ return None
+
+
class SharedEmbeddingColumn(
DenseColumn, SequenceDenseColumn,
collections.namedtuple(
@@ -2822,6 +2673,10 @@ class SharedEmbeddingColumn(
def create_state(self, state_manager):
"""Creates the shared embedding lookup variable."""
+ if not isinstance(state_manager, SharedEmbeddingStateManager):
+ raise ValueError('Expected state_manager to be of type '
+ 'SharedEmbeddingStateManager. Obtained type: {}'.format(
+ type(state_manager)))
embedding_shape = (self.categorical_column.num_buckets, self.dimension)
state_manager.create_variable(
name=self.shared_collection_name,
@@ -3433,11 +3288,10 @@ def _safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
- if not isinstance(embedding_weights[0],
- resource_variable_ops.ResourceVariable):
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
+ # TODO(rohanj): Look into removing this convert_to_tensor call.
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
with ops.name_scope(name, 'embedding_lookup',
embedding_weights + [sparse_ids,
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 58168e0f9e..d3787146ed 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -31,9 +31,7 @@ from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.estimator.inputs import numpy_io
-from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_v2 as fc
-from tensorflow.python.feature_column.feature_column_v2 import _LinearModel
from tensorflow.python.feature_column.feature_column_v2 import _transform_features
from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn
from tensorflow.python.feature_column.feature_column_v2 import FeatureLayer
@@ -48,7 +46,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
@@ -177,6 +174,22 @@ class LazyColumnTest(test.TestCase):
TypeError, '"key" must be either a "str" or "FeatureColumn".'):
transformation_cache.get(NotAFeatureColumn(), None)
+ def test_expand_dim_rank_1_sparse_tensor_empty_batch(self):
+ # empty 1-D sparse tensor:
+ transformation_cache = FeatureTransformationCache(
+ features={
+ 'a':
+ sparse_tensor.SparseTensor(
+ indices=np.reshape(np.array([], dtype=np.int64), (0, 1)),
+ dense_shape=[0],
+ values=np.array([]))
+ })
+ with self.cached_session():
+ spv = transformation_cache.get('a', None).eval()
+ self.assertAllEqual(np.array([0, 1], dtype=np.int64), spv.dense_shape)
+ self.assertAllEqual(
+ np.reshape(np.array([], dtype=np.int64), (0, 2)), spv.indices)
+
class NumericColumnTest(test.TestCase):
@@ -344,26 +357,12 @@ class NumericColumnTest(test.TestCase):
self.assertEqual(a.default_value, ((3., 2.),))
def test_linear_model(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.]], price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price_var.assign([[10.]]))
- self.assertAllClose([[10.], [50.]], predictions.eval())
-
- def test_keras_linear_model(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.]], price_var.eval())
@@ -548,13 +547,13 @@ class BucketizedColumnTest(test.TestCase):
def test_linear_model_one_input_value(self):
"""Tests linear_model() for input with shape=[1]."""
- price = fc_old.numeric_column('price', shape=[1])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ price = fc.numeric_column('price', shape=[1])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1.], [1.], [5.], [6.]]}
- predictions = fc.linear_model(features, [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ model = fc.LinearModel([bucketized_price])
+ predictions = model(features)
+ bucketized_price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight variable per bucket, all initialized to zero.
@@ -573,13 +572,13 @@ class BucketizedColumnTest(test.TestCase):
def test_linear_model_two_input_values(self):
"""Tests linear_model() for input with shape=[2]."""
- price = fc_old.numeric_column('price', shape=[2])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1., 1.], [5., 6.]]}
- predictions = fc.linear_model(features, [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ model = fc.LinearModel([bucketized_price])
+ predictions = model(features)
+ bucketized_price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight per bucket per input column, all initialized to zero.
@@ -600,62 +599,6 @@ class BucketizedColumnTest(test.TestCase):
sess.run(bias.assign([1.]))
self.assertAllClose([[81.], [141.]], predictions.eval())
- def test_keras_linear_model_one_input_value(self):
- """Tests _LinearModel for input with shape=[1]."""
- price = fc_old.numeric_column('price', shape=[1])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
- with ops.Graph().as_default():
- features = {'price': [[-1.], [1.], [5.], [6.]]}
- predictions = get_keras_linear_model_predictions(features,
- [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- # One weight variable per bucket, all initialized to zero.
- self.assertAllClose([[0.], [0.], [0.], [0.], [0.]],
- bucketized_price_var.eval())
- self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
- sess.run(
- bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]]))
- # price -1. is in the 0th bucket, whose weight is 10.
- # price 1. is in the 1st bucket, whose weight is 20.
- # price 5. is in the 3rd bucket, whose weight is 40.
- # price 6. is in the 4th bucket, whose weight is 50.
- self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
- sess.run(bias.assign([1.]))
- self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
-
- def test_keras_linear_model_two_input_values(self):
- """Tests _LinearModel for input with shape=[2]."""
- price = fc_old.numeric_column('price', shape=[2])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
- with ops.Graph().as_default():
- features = {'price': [[-1., 1.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(features,
- [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- # One weight per bucket per input column, all initialized to zero.
- self.assertAllClose(
- [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
- bucketized_price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(
- bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.],
- [60.], [70.], [80.], [90.], [100.]]))
- # 1st example:
- # price -1. is in the 0th bucket, whose weight is 10.
- # price 1. is in the 6th bucket, whose weight is 70.
- # 2nd example:
- # price 5. is in the 3rd bucket, whose weight is 40.
- # price 6. is in the 9th bucket, whose weight is 100.
- self.assertAllClose([[80.], [140.]], predictions.eval())
- sess.run(bias.assign([1.]))
- self.assertAllClose([[81.], [141.]], predictions.eval())
-
class HashedCategoricalColumnTest(test.TestCase):
@@ -836,39 +779,18 @@ class HashedCategoricalColumnTest(test.TestCase):
transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 3: wire_var[3] = 4
- # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
- self.assertAllClose(((4.,), (6.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
- self.assertEqual(4, wire_column._num_buckets)
+ wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -1087,93 +1009,12 @@ class CrossedColumnTest(test.TestCase):
Uses data from test_get_sparse_tesnsors_simple.
"""
- a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
- b = fc_old.bucketized_column(a, boundaries=(0, 1))
- crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'a': constant_op.constant(((-1., .5), (.5, 1.))),
- 'c': sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=['cA', 'cB', 'cC'],
- dense_shape=(2, 2)),
- }, (crossed,))
- bias = get_linear_model_bias()
- crossed_var = get_linear_model_column_var(crossed)
- with _initialized_session() as sess:
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(
- ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
- # Expected ids after cross = (1, 0, 1, 3, 4, 2)
- self.assertAllClose(((3.,), (14.,)), predictions.eval())
- sess.run(bias.assign((.1,)))
- self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
-
- def test_linear_model_with_weights(self):
-
- class _TestColumnWithWeights(fc_old._CategoricalColumn):
- """Produces sparse IDs and sparse weights."""
-
- @property
- def name(self):
- return 'test_column'
-
- @property
- def _parse_example_spec(self):
- return {
- self.name: parsing_ops.VarLenFeature(dtypes.int32),
- '{}_weights'.format(self.name): parsing_ops.VarLenFeature(
- dtypes.float32),
- }
-
- @property
- def _num_buckets(self):
- return 5
-
- def _transform_feature(self, inputs):
- return (inputs.get(self.name),
- inputs.get('{}_weights'.format(self.name)))
-
- def _get_sparse_tensors(self, inputs, weight_collections=None,
- trainable=None):
- """Populates both id_tensor and weight_tensor."""
- ids_and_weights = inputs.get(self)
- return fc_old._CategoricalColumn.IdWeightPair(
- id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
-
- t = _TestColumnWithWeights()
- crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
- with ops.Graph().as_default():
- with self.assertRaisesRegexp(
- ValueError,
- 'crossed_column does not support weight_tensor.*{}'.format(t.name)):
- fc.linear_model({
- t.name: sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=[0, 1, 2],
- dense_shape=(2, 2)),
- '{}_weights'.format(t.name): sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=[1., 10., 2.],
- dense_shape=(2, 2)),
- 'c': sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=['cA', 'cB', 'cC'],
- dense_shape=(2, 2)),
- }, (crossed,))
-
- def test_keras_linear_model(self):
- """Tests _LinearModel.
-
- Uses data from test_get_sparse_tesnsors_simple.
- """
- a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
- b = fc_old.bucketized_column(a, boundaries=(0, 1))
- crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((crossed,))
+ predictions = model({
'a':
constant_op.constant(((-1., .5), (.5, 1.))),
'c':
@@ -1181,13 +1022,12 @@ class CrossedColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=['cA', 'cB', 'cC'],
dense_shape=(2, 2)),
- }, (crossed,))
- bias = get_linear_model_bias()
- crossed_var = get_linear_model_column_var(crossed)
+ })
+ crossed_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)),
- crossed_var.eval())
+ self.assertAllClose(
+ ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
self.assertAllClose(((0.,), (0.,)), predictions.eval())
sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
# Expected ids after cross = (1, 0, 1, 3, 4, 2)
@@ -1195,9 +1035,9 @@ class CrossedColumnTest(test.TestCase):
sess.run(bias.assign((.1,)))
self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
- def test_keras_linear_model_with_weights(self):
+ def test_linear_model_with_weights(self):
- class _TestColumnWithWeights(fc_old._CategoricalColumn):
+ class _TestColumnWithWeights(fc.CategoricalColumn):
"""Produces sparse IDs and sparse weights."""
@property
@@ -1205,38 +1045,36 @@ class CrossedColumnTest(test.TestCase):
return 'test_column'
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
return {
- self.name:
- parsing_ops.VarLenFeature(dtypes.int32),
- '{}_weights'.format(self.name):
- parsing_ops.VarLenFeature(dtypes.float32),
- }
+ self.name: parsing_ops.VarLenFeature(dtypes.int32),
+ '{}_weights'.format(self.name): parsing_ops.VarLenFeature(
+ dtypes.float32),
+ }
@property
- def _num_buckets(self):
+ def num_buckets(self):
return 5
- def _transform_feature(self, inputs):
- return (inputs.get(self.name),
- inputs.get('{}_weights'.format(self.name)))
+ def transform_feature(self, transformation_cache, state_manager):
+ return (transformation_cache.get(self.name, state_manager),
+ transformation_cache.get('{}_weights'.format(self.name),
+ state_manager))
- def _get_sparse_tensors(self,
- inputs,
- weight_collections=None,
- trainable=None):
+ def get_sparse_tensors(self, transformation_cache, state_manager):
"""Populates both id_tensor and weight_tensor."""
- ids_and_weights = inputs.get(self)
- return fc_old._CategoricalColumn.IdWeightPair(
+ ids_and_weights = transformation_cache.get(self, state_manager)
+ return fc.CategoricalColumn.IdWeightPair(
id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
t = _TestColumnWithWeights()
- crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
+ crossed = fc.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
with ops.Graph().as_default():
with self.assertRaisesRegexp(
ValueError,
'crossed_column does not support weight_tensor.*{}'.format(t.name)):
- get_keras_linear_model_predictions({
+ model = fc.LinearModel((crossed,))
+ model({
t.name:
sparse_tensor.SparseTensor(
indices=((0, 0), (1, 0), (1, 1)),
@@ -1252,37 +1090,7 @@ class CrossedColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=['cA', 'cB', 'cC'],
dense_shape=(2, 2)),
- }, (crossed,))
-
-
-def get_linear_model_bias(name='linear_model'):
- with variable_scope.variable_scope(name, reuse=True):
- return variable_scope.get_variable('bias_weights')
-
-
-def get_linear_model_column_var(column, name='linear_model'):
- return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
- name + '/' + column.name)[0]
-
-
-def get_keras_linear_model_predictions(features,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- cols_to_vars=None):
- keras_linear_model = _LinearModel(
- feature_columns,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- name='linear_model')
- retval = keras_linear_model(features) # pylint: disable=not-callable
- if cols_to_vars is not None:
- cols_to_vars.update(keras_linear_model.cols_to_vars())
- return retval
+ })
class LinearModelTest(test.TestCase):
@@ -1290,56 +1098,50 @@ class LinearModelTest(test.TestCase):
def test_raises_if_empty_feature_columns(self):
with self.assertRaisesRegexp(ValueError,
'feature_columns must not be empty'):
- fc.linear_model(features={}, feature_columns=[])
+ fc.LinearModel(feature_columns=[])
def test_should_be_feature_column(self):
- with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
- fc.linear_model(features={'a': [[0]]}, feature_columns='NotSupported')
+ with self.assertRaisesRegexp(ValueError, 'must be a FeatureColumn'):
+ fc.LinearModel(feature_columns='NotSupported')
def test_should_be_dense_or_categorical_column(self):
- class NotSupportedColumn(fc_old._FeatureColumn):
+ class NotSupportedColumn(fc.FeatureColumn):
@property
def name(self):
return 'NotSupportedColumn'
- def _transform_feature(self, cache):
+ def transform_feature(self, transformation_cache, state_manager):
pass
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
pass
with self.assertRaisesRegexp(
- ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
- fc.linear_model(
- features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
+ ValueError, 'must be either a DenseColumn or CategoricalColumn'):
+ fc.LinearModel(feature_columns=[NotSupportedColumn()])
def test_does_not_support_dict_columns(self):
with self.assertRaisesRegexp(
ValueError, 'Expected feature_columns to be iterable, found dict.'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns={'a': fc_old.numeric_column('a')})
+ fc.LinearModel(feature_columns={'a': fc.numeric_column('a')})
def test_raises_if_duplicate_name(self):
with self.assertRaisesRegexp(
ValueError, 'Duplicate feature column name found for columns'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.numeric_column('a'),
- fc_old.numeric_column('a')
- ])
+ fc.LinearModel(
+ feature_columns=[fc.numeric_column('a'),
+ fc.numeric_column('a')])
def test_dense_bias(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
sess.run(price_var.assign([[10.]]))
@@ -1347,16 +1149,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[15.], [55.]], predictions.eval())
def test_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast])
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
@@ -1365,18 +1167,17 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[1005.], [10015.]], predictions.eval())
def test_dense_and_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- price = fc_old.numeric_column('price')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [wire_cast, price])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([wire_cast, price])
+ predictions = model(features)
+ price_var, wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
@@ -1386,38 +1187,36 @@ class LinearModelTest(test.TestCase):
def test_dense_and_sparse_column(self):
"""When the column is both dense and sparse, uses sparse tensors."""
- class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
+ class _DenseAndSparseColumn(fc.DenseColumn, fc.CategoricalColumn):
@property
def name(self):
return 'dense_and_sparse_column'
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
return {self.name: parsing_ops.VarLenFeature(self.dtype)}
- def _transform_feature(self, inputs):
- return inputs.get(self.name)
+ def transform_feature(self, transformation_cache, state_manager):
+ return transformation_cache.get(self.name, state_manager)
@property
- def _variable_shape(self):
+ def variable_shape(self):
raise ValueError('Should not use this method.')
- def _get_dense_tensor(self, inputs, weight_collections=None,
- trainable=None):
+ def get_dense_tensor(self, transformation_cache, state_manager):
raise ValueError('Should not use this method.')
@property
- def _num_buckets(self):
+ def num_buckets(self):
return 4
- def _get_sparse_tensors(self, inputs, weight_collections=None,
- trainable=None):
+ def get_sparse_tensors(self, transformation_cache, state_manager):
sp_tensor = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 0], [1, 1]],
values=[2, 0, 3],
dense_shape=[2, 2])
- return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
+ return fc.CategoricalColumn.IdWeightPair(sp_tensor, None)
dense_and_sparse_column = _DenseAndSparseColumn()
with ops.Graph().as_default():
@@ -1426,10 +1225,9 @@ class LinearModelTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {dense_and_sparse_column.name: sp_tensor}
- predictions = fc.linear_model(features, [dense_and_sparse_column])
- bias = get_linear_model_bias()
- dense_and_sparse_column_var = get_linear_model_column_var(
- dense_and_sparse_column)
+ model = fc.LinearModel([dense_and_sparse_column])
+ predictions = model(features)
+ dense_and_sparse_column_var, bias = model.variables
with _initialized_session() as sess:
sess.run(dense_and_sparse_column_var.assign(
[[10.], [100.], [1000.], [10000.]]))
@@ -1437,12 +1235,12 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[1005.], [10015.]], predictions.eval())
def test_dense_multi_output(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price], units=3)
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((1, 3)), price_var.eval())
@@ -1452,16 +1250,16 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_sparse_multi_output(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast], units=3)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast], units=3)
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
@@ -1474,18 +1272,19 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_dense_multi_dimension(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- predictions = fc.linear_model(features, [price])
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, _ = model.variables
with _initialized_session() as sess:
self.assertAllClose([[0.], [0.]], price_var.eval())
sess.run(price_var.assign([[10.], [100.]]))
self.assertAllClose([[210.], [650.]], predictions.eval())
def test_sparse_multi_rank(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = array_ops.sparse_placeholder(dtypes.string)
wire_value = sparse_tensor.SparseTensorValue(
@@ -1493,8 +1292,9 @@ class LinearModelTest(test.TestCase):
indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
dense_shape=[2, 2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast])
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast])
+ predictions = model(features)
+ wire_cast_var, _ = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
self.assertAllClose(
@@ -1506,25 +1306,24 @@ class LinearModelTest(test.TestCase):
predictions.eval(feed_dict={wire_tensor: wire_value}))
def test_sparse_combiner(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(
- features, [wire_cast], sparse_combiner='mean')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast], sparse_combiner='mean')
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[1005.], [5010.]], predictions.eval())
def test_sparse_combiner_with_negative_weights(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- wire_cast_weights = fc_old.weighted_categorical_column(wire_cast, 'weights')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast_weights = fc.weighted_categorical_column(wire_cast, 'weights')
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
@@ -1535,22 +1334,21 @@ class LinearModelTest(test.TestCase):
'wire_cast': wire_tensor,
'weights': constant_op.constant([[1., 1., -1.0]])
}
- predictions = fc.linear_model(
- features, [wire_cast_weights], sparse_combiner='sum')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast_weights], sparse_combiner='sum')
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[1005.], [-9985.]], predictions.eval())
def test_dense_multi_dimension_multi_output(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- predictions = fc.linear_model(features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price], units=3)
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((2, 3)), price_var.eval())
@@ -1560,21 +1358,22 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_raises_if_shape_mismatch(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
with self.assertRaisesRegexp(
Exception,
r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ model(features)
def test_dense_reshaping(self):
- price = fc_old.numeric_column('price', shape=[1, 2])
+ price = fc.numeric_column('price', shape=[1, 2])
with ops.Graph().as_default():
features = {'price': [[[1., 2.]], [[5., 6.]]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.]], price_var.eval())
@@ -1583,17 +1382,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[210.], [650.]], predictions.eval())
def test_dense_multi_column(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1., 2.], [5., 6.]],
'price2': [[3.], [4.]]
}
- predictions = fc.linear_model(features, [price1, price2])
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
+ price1_var, price2_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.]], price1_var.eval())
@@ -1604,115 +1402,55 @@ class LinearModelTest(test.TestCase):
sess.run(bias.assign([7.]))
self.assertAllClose([[3217.], [4657.]], predictions.eval())
- def test_fills_cols_to_vars(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- cols_to_vars = {}
- fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- self.assertAllEqual(cols_to_vars['bias'], [bias])
- self.assertAllEqual(cols_to_vars[price1], [price1_var])
- self.assertAllEqual(cols_to_vars[price2], [price2_var])
-
- def test_fills_cols_to_vars_partitioned_variables(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2', shape=3)
- with ops.Graph().as_default():
- features = {
- 'price1': [[1., 2.], [6., 7.]],
- 'price2': [[3., 4., 5.], [8., 9., 10.]]
- }
- cols_to_vars = {}
- with variable_scope.variable_scope(
- 'linear',
- partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
- fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
- with _initialized_session():
- self.assertEqual([0.], cols_to_vars['bias'][0].eval())
- # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
- self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
- # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
- # a [1, 1] Variable.
- self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
-
- def test_dense_collection(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- self.assertIn(bias, my_vars)
- self.assertIn(price_var, my_vars)
-
- def test_sparse_collection(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- fc.linear_model(
- features, [wire_cast], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, my_vars)
- self.assertIn(wire_cast_var, my_vars)
-
def test_dense_trainable_default(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ model(features)
+ price_var, bias = model.variables
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertIn(bias, trainable_vars)
self.assertIn(price_var, trainable_vars)
def test_sparse_trainable_default(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default() as g:
wire_tensor = sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
features = {'wire_cast': wire_tensor}
- fc.linear_model(features, [wire_cast])
+ model = fc.LinearModel([wire_cast])
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ wire_cast_var, bias = model.variables
self.assertIn(bias, trainable_vars)
self.assertIn(wire_cast_var, trainable_vars)
def test_dense_trainable_false(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price], trainable=False)
+ model = fc.LinearModel([price], trainable=False)
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertEqual([], trainable_vars)
def test_sparse_trainable_false(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default() as g:
wire_tensor = sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
features = {'wire_cast': wire_tensor}
- fc.linear_model(features, [wire_cast], trainable=False)
+ model = fc.LinearModel([wire_cast], trainable=False)
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertEqual([], trainable_vars)
def test_column_order(self):
- price_a = fc_old.numeric_column('price_a')
- price_b = fc_old.numeric_column('price_b')
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
+ price_a = fc.numeric_column('price_a')
+ price_b = fc.numeric_column('price_b')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
@@ -1720,15 +1458,15 @@ class LinearModelTest(test.TestCase):
sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
}
- fc.linear_model(
- features, [price_a, wire_cast, price_b],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
+ model = fc.LinearModel([price_a, wire_cast, price_b])
+ model(features)
+
+ my_vars = model.variables
self.assertIn('price_a', my_vars[0].name)
self.assertIn('price_b', my_vars[1].name)
self.assertIn('wire_cast', my_vars[2].name)
- with ops.Graph().as_default() as g:
+ with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
@@ -1736,17 +1474,45 @@ class LinearModelTest(test.TestCase):
sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
}
- fc.linear_model(
- features, [wire_cast, price_b, price_a],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
+ model = fc.LinearModel([wire_cast, price_b, price_a])
+ model(features)
+
+ my_vars = model.variables
self.assertIn('price_a', my_vars[0].name)
self.assertIn('price_b', my_vars[1].name)
self.assertIn('wire_cast', my_vars[2].name)
+ def test_variable_names(self):
+ price1 = fc.numeric_column('price1')
+ dense_feature = fc.numeric_column('dense_feature')
+ dense_feature_bucketized = fc.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10)
+ all_cols = [price1, dense_feature_bucketized, some_embedding_column]
+
+ with ops.Graph().as_default():
+ model = fc.LinearModel(all_cols)
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ }
+ model(features)
+ variable_names = [var.name for var in model.variables]
+ self.assertItemsEqual([
+ 'linear_model/dense_feature_bucketized/weights:0',
+ 'linear_model/price1/weights:0',
+ 'linear_model/sparse_feature_embedding/embedding_weights:0',
+ 'linear_model/sparse_feature_embedding/weights:0',
+ 'linear_model/bias_weights:0',
+ ], variable_names)
+
def test_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1.], [5.], [7.]], # batchsize = 3
@@ -1755,12 +1521,13 @@ class LinearModelTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ model(features)
def test_subset_of_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- price3 = fc_old.numeric_column('price3')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ price3 = fc.numeric_column('price3')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
@@ -1770,17 +1537,19 @@ class LinearModelTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.linear_model(features, [price1, price2, price3])
+ model = fc.LinearModel([price1, price2, price3])
+ model(features)
def test_runtime_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
'price2': [[3.], [4.]] # batchsize = 2
}
- predictions = fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
with _initialized_session() as sess:
with self.assertRaisesRegexp(errors.OpError,
'must have the same size and shape'):
@@ -1788,14 +1557,15 @@ class LinearModelTest(test.TestCase):
predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
def test_runtime_batch_size_matches(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
}
- predictions = fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
with _initialized_session() as sess:
sess.run(
predictions,
@@ -1805,14 +1575,14 @@ class LinearModelTest(test.TestCase):
})
def test_with_numpy_input_fn(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
input_fn = numpy_io.numpy_input_fn(
@@ -1823,15 +1593,14 @@ class LinearModelTest(test.TestCase):
batch_size=2,
shuffle=False)
features = input_fn()
- net = fc.linear_model(features, [price_buckets, body_style])
+ model = fc.LinearModel([price_buckets, body_style])
+ net = model(features)
# self.assertEqual(1 + 3 + 5, net.shape[1])
with _initialized_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ body_style_var, price_buckets_var, bias = model.variables
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1843,14 +1612,14 @@ class LinearModelTest(test.TestCase):
coord.join(threads)
def test_with_1d_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
# Provides 1-dim tensor and dense tensor.
@@ -1864,11 +1633,10 @@ class LinearModelTest(test.TestCase):
self.assertEqual(1, features['price'].shape.ndims)
self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
- net = fc.linear_model(features, [price_buckets, body_style])
+ model = fc.LinearModel([price_buckets, body_style])
+ net = model(features)
with _initialized_session() as sess:
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ body_style_var, price_buckets_var, bias = model.variables
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1877,16 +1645,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
def test_with_1d_unknown_shape_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- country = fc_old.categorical_column_with_vocabulary_list(
+ country = fc.categorical_column_with_vocabulary_list(
'country', vocabulary_list=['US', 'JP', 'CA'])
# Provides 1-dim tensor and dense tensor.
@@ -1905,10 +1673,9 @@ class LinearModelTest(test.TestCase):
dense_shape=(2,))
country_data = np.array(['US', 'CA'])
- net = fc.linear_model(features, [price_buckets, body_style, country])
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ model = fc.LinearModel([price_buckets, body_style, country])
+ net = model(features)
+ body_style_var, _, price_buckets_var, bias = model.variables
with _initialized_session() as sess:
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1924,7 +1691,7 @@ class LinearModelTest(test.TestCase):
}))
def test_with_rank_0_feature(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
features = {
'price': constant_op.constant(0),
}
@@ -1932,29 +1699,31 @@ class LinearModelTest(test.TestCase):
# Static rank 0 should fail
with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ model(features)
# Dynamic rank 0 should fail
features = {
'price': array_ops.placeholder(dtypes.float32),
}
- net = fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ net = model(features)
self.assertEqual(1, net.shape[1])
with _initialized_session() as sess:
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
sess.run(net, feed_dict={features['price']: np.array(1)})
def test_multiple_linear_models(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features1 = {'price': [[1.], [5.]]}
features2 = {'price': [[2.], [10.]]}
- predictions1 = fc.linear_model(features1, [price])
- predictions2 = fc.linear_model(features2, [price])
- bias1 = get_linear_model_bias(name='linear_model')
- bias2 = get_linear_model_bias(name='linear_model_1')
- price_var1 = get_linear_model_column_var(price, name='linear_model')
- price_var2 = get_linear_model_column_var(price, name='linear_model_1')
+ model1 = fc.LinearModel([price])
+ model2 = fc.LinearModel([price])
+ predictions1 = model1(features1)
+ predictions2 = model2(features2)
+ price_var1, bias1 = model1.variables
+ price_var2, bias2 = model2.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias1.eval())
sess.run(price_var1.assign([[10.]]))
@@ -1966,664 +1735,6 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[25.], [105.]], predictions2.eval())
-class _LinearModelTest(test.TestCase):
-
- def test_raises_if_empty_feature_columns(self):
- with self.assertRaisesRegexp(ValueError,
- 'feature_columns must not be empty'):
- get_keras_linear_model_predictions(features={}, feature_columns=[])
-
- def test_should_be_feature_column(self):
- with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]}, feature_columns='NotSupported')
-
- def test_should_be_dense_or_categorical_column(self):
-
- class NotSupportedColumn(fc_old._FeatureColumn):
-
- @property
- def name(self):
- return 'NotSupportedColumn'
-
- def _transform_feature(self, cache):
- pass
-
- @property
- def _parse_example_spec(self):
- pass
-
- with self.assertRaisesRegexp(
- ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
-
- def test_does_not_support_dict_columns(self):
- with self.assertRaisesRegexp(
- ValueError, 'Expected feature_columns to be iterable, found dict.'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns={'a': fc_old.numeric_column('a')})
-
- def test_raises_if_duplicate_name(self):
- with self.assertRaisesRegexp(
- ValueError, 'Duplicate feature column name found for columns'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.numeric_column('a'),
- fc_old.numeric_column('a')
- ])
-
- def test_dense_bias(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- sess.run(price_var.assign([[10.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[15.], [55.]], predictions.eval())
-
- def test_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(features, [wire_cast])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [10015.]], predictions.eval())
-
- def test_dense_and_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features,
- [wire_cast, price])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- sess.run(price_var.assign([[10.]]))
- self.assertAllClose([[1015.], [10065.]], predictions.eval())
-
- def test_dense_and_sparse_column(self):
- """When the column is both dense and sparse, uses sparse tensors."""
-
- class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
-
- @property
- def name(self):
- return 'dense_and_sparse_column'
-
- @property
- def _parse_example_spec(self):
- return {self.name: parsing_ops.VarLenFeature(self.dtype)}
-
- def _transform_feature(self, inputs):
- return inputs.get(self.name)
-
- @property
- def _variable_shape(self):
- raise ValueError('Should not use this method.')
-
- def _get_dense_tensor(self,
- inputs,
- weight_collections=None,
- trainable=None):
- raise ValueError('Should not use this method.')
-
- @property
- def _num_buckets(self):
- return 4
-
- def _get_sparse_tensors(self,
- inputs,
- weight_collections=None,
- trainable=None):
- sp_tensor = sparse_tensor.SparseTensor(
- indices=[[0, 0], [1, 0], [1, 1]],
- values=[2, 0, 3],
- dense_shape=[2, 2])
- return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
-
- dense_and_sparse_column = _DenseAndSparseColumn()
- with ops.Graph().as_default():
- sp_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {dense_and_sparse_column.name: sp_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [dense_and_sparse_column])
- bias = get_linear_model_bias()
- dense_and_sparse_column_var = get_linear_model_column_var(
- dense_and_sparse_column)
- with _initialized_session() as sess:
- sess.run(
- dense_and_sparse_column_var.assign([[10.], [100.], [1000.],
- [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [10015.]], predictions.eval())
-
- def test_dense_multi_output(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(
- features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((1, 3)), price_var.eval())
- sess.run(price_var.assign([[10., 100., 1000.]]))
- sess.run(bias.assign([5., 6., 7.]))
- self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
- predictions.eval())
-
- def test_sparse_multi_output(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [wire_cast], units=3)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
- sess.run(
- wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.],
- [1000., 1100.,
- 1200.], [10000., 11000., 12000.]]))
- sess.run(bias.assign([5., 6., 7.]))
- self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
- predictions.eval())
-
- def test_dense_multi_dimension(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1., 2.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([[0.], [0.]], price_var.eval())
- sess.run(price_var.assign([[10.], [100.]]))
- self.assertAllClose([[210.], [650.]], predictions.eval())
-
- def test_sparse_multi_rank(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = array_ops.sparse_placeholder(dtypes.string)
- wire_value = sparse_tensor.SparseTensorValue(
- values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2]
- indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
- dense_shape=[2, 2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(features, [wire_cast])
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
- self.assertAllClose(
- np.zeros((2, 1)),
- predictions.eval(feed_dict={wire_tensor: wire_value}))
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- self.assertAllClose(
- [[1010.], [11000.]],
- predictions.eval(feed_dict={wire_tensor: wire_value}))
-
- def test_sparse_combiner(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [wire_cast], sparse_combiner='mean')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [5010.]], predictions.eval())
-
- def test_dense_multi_dimension_multi_output(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1., 2.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(
- features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((2, 3)), price_var.eval())
- sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
- sess.run(bias.assign([2., 3., 4.]))
- self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
- predictions.eval())
-
- def test_raises_if_shape_mismatch(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- with self.assertRaisesRegexp(
- Exception,
- r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- get_keras_linear_model_predictions(features, [price])
-
- def test_dense_reshaping(self):
- price = fc_old.numeric_column('price', shape=[1, 2])
- with ops.Graph().as_default():
- features = {'price': [[[1., 2.]], [[5., 6.]]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.]], price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price_var.assign([[10.], [100.]]))
- self.assertAllClose([[210.], [650.]], predictions.eval())
-
- def test_dense_multi_column(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.]], price1_var.eval())
- self.assertAllClose([[0.]], price2_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price1_var.assign([[10.], [100.]]))
- sess.run(price2_var.assign([[1000.]]))
- sess.run(bias.assign([7.]))
- self.assertAllClose([[3217.], [4657.]], predictions.eval())
-
- def test_fills_cols_to_vars(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- cols_to_vars = {}
- get_keras_linear_model_predictions(
- features, [price1, price2], cols_to_vars=cols_to_vars)
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- self.assertAllEqual(cols_to_vars['bias'], [bias])
- self.assertAllEqual(cols_to_vars[price1], [price1_var])
- self.assertAllEqual(cols_to_vars[price2], [price2_var])
-
- def test_fills_cols_to_vars_partitioned_variables(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2', shape=3)
- with ops.Graph().as_default():
- features = {
- 'price1': [[1., 2.], [6., 7.]],
- 'price2': [[3., 4., 5.], [8., 9., 10.]]
- }
- cols_to_vars = {}
- with variable_scope.variable_scope(
- 'linear',
- partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
- get_keras_linear_model_predictions(
- features, [price1, price2], cols_to_vars=cols_to_vars)
- with _initialized_session():
- self.assertEqual([0.], cols_to_vars['bias'][0].eval())
- # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
- self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
- # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
- # a [1, 1] Variable.
- self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
-
- def test_dense_collection(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(
- features, [price], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- self.assertIn(bias, my_vars)
- self.assertIn(price_var, my_vars)
-
- def test_sparse_collection(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(
- features, [wire_cast], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, my_vars)
- self.assertIn(wire_cast_var, my_vars)
-
- def test_dense_trainable_default(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertIn(bias, trainable_vars)
- self.assertIn(price_var, trainable_vars)
-
- def test_sparse_trainable_default(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(features, [wire_cast])
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, trainable_vars)
- self.assertIn(wire_cast_var, trainable_vars)
-
- def test_dense_trainable_false(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(features, [price], trainable=False)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertEqual([], trainable_vars)
-
- def test_sparse_trainable_false(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(features, [wire_cast], trainable=False)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertEqual([], trainable_vars)
-
- def test_column_order(self):
- price_a = fc_old.numeric_column('price_a')
- price_b = fc_old.numeric_column('price_b')
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- features = {
- 'price_a': [[1.]],
- 'price_b': [[3.]],
- 'wire_cast':
- sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- }
- get_keras_linear_model_predictions(
- features, [price_a, wire_cast, price_b],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- self.assertIn('price_a', my_vars[0].name)
- self.assertIn('price_b', my_vars[1].name)
- self.assertIn('wire_cast', my_vars[2].name)
-
- with ops.Graph().as_default() as g:
- features = {
- 'price_a': [[1.]],
- 'price_b': [[3.]],
- 'wire_cast':
- sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- }
- get_keras_linear_model_predictions(
- features, [wire_cast, price_b, price_a],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- self.assertIn('price_a', my_vars[0].name)
- self.assertIn('price_b', my_vars[1].name)
- self.assertIn('wire_cast', my_vars[2].name)
-
- def test_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': [[1.], [5.], [7.]], # batchsize = 3
- 'price2': [[3.], [4.]] # batchsize = 2
- }
- with self.assertRaisesRegexp(
- ValueError,
- 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- get_keras_linear_model_predictions(features, [price1, price2])
-
- def test_subset_of_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- price3 = fc_old.numeric_column('price3')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
- 'price2': [[3.], [4.]], # batchsize = 2
- 'price3': [[3.], [4.], [5.]] # batchsize = 3
- }
- with self.assertRaisesRegexp(
- ValueError,
- 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- get_keras_linear_model_predictions(features, [price1, price2, price3])
-
- def test_runtime_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
- 'price2': [[3.], [4.]] # batchsize = 2
- }
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- with _initialized_session() as sess:
- with self.assertRaisesRegexp(errors.OpError,
- 'must have the same size and shape'):
- sess.run(
- predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
-
- def test_runtime_batch_size_matches(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
- 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
- }
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- with _initialized_session() as sess:
- sess.run(
- predictions,
- feed_dict={
- features['price1']: [[1.], [5.]],
- features['price2']: [[1.], [5.]],
- })
-
- def test_with_numpy_input_fn(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-
- input_fn = numpy_io.numpy_input_fn(
- x={
- 'price': np.array([-1., 2., 13., 104.]),
- 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
- },
- batch_size=2,
- shuffle=False)
- features = input_fn()
- net = get_keras_linear_model_predictions(features,
- [price_buckets, body_style])
- # self.assertEqual(1 + 3 + 5, net.shape[1])
- with _initialized_session() as sess:
- coord = coordinator.Coordinator()
- threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
-
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
-
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
-
- coord.request_stop()
- coord.join(threads)
-
- def test_with_1d_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-
- # Provides 1-dim tensor and dense tensor.
- features = {
- 'price':
- constant_op.constant([
- -1.,
- 12.,
- ]),
- 'body-style':
- sparse_tensor.SparseTensor(
- indices=((0,), (1,)),
- values=('sedan', 'hardtop'),
- dense_shape=(2,)),
- }
- self.assertEqual(1, features['price'].shape.ndims)
- self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
-
- net = get_keras_linear_model_predictions(features,
- [price_buckets, body_style])
- with _initialized_session() as sess:
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
-
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
-
- def test_with_1d_unknown_shape_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- country = fc_old.categorical_column_with_vocabulary_list(
- 'country', vocabulary_list=['US', 'JP', 'CA'])
-
- # Provides 1-dim tensor and dense tensor.
- features = {
- 'price': array_ops.placeholder(dtypes.float32),
- 'body-style': array_ops.sparse_placeholder(dtypes.string),
- 'country': array_ops.placeholder(dtypes.string),
- }
- self.assertIsNone(features['price'].shape.ndims)
- self.assertIsNone(features['body-style'].get_shape().ndims)
-
- price_data = np.array([-1., 12.])
- body_style_data = sparse_tensor.SparseTensorValue(
- indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
- country_data = np.array(['US', 'CA'])
-
- net = get_keras_linear_model_predictions(
- features, [price_buckets, body_style, country])
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
- with _initialized_session() as sess:
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
- sess.run(
- net,
- feed_dict={
- features['price']: price_data,
- features['body-style']: body_style_data,
- features['country']: country_data
- }))
-
- def test_with_rank_0_feature(self):
- price = fc_old.numeric_column('price')
- features = {
- 'price': constant_op.constant(0),
- }
- self.assertEqual(0, features['price'].shape.ndims)
-
- # Static rank 0 should fail
- with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- get_keras_linear_model_predictions(features, [price])
-
- # Dynamic rank 0 should fail
- features = {
- 'price': array_ops.placeholder(dtypes.float32),
- }
- net = get_keras_linear_model_predictions(features, [price])
- self.assertEqual(1, net.shape[1])
- with _initialized_session() as sess:
- with self.assertRaisesOpError('Feature .* cannot have rank 0'):
- sess.run(net, feed_dict={features['price']: np.array(1)})
-
-
class FeatureLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
@@ -3723,47 +2834,22 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
id_weight_pair.id_tensor.eval())
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_file(
- key='wire',
- vocabulary_file=self._wire_vocabulary_file_name,
- vocabulary_size=self._wire_vocabulary_size,
- num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 2: wire_var[2] = 3
- # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
- self.assertAllClose(((3.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_file(
+ wire_column = fc.categorical_column_with_vocabulary_file(
key='wire',
vocabulary_file=self._wire_vocabulary_file_name,
vocabulary_size=self._wire_vocabulary_size,
num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -4124,45 +3210,21 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
id_weight_pair.id_tensor.eval())
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_list(
- key='aaa',
- vocabulary_list=('omar', 'stringer', 'marlo'),
- num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 2: wire_var[2] = 3
- # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
- self.assertAllClose(((3.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_list(
+ wire_column = fc.categorical_column_with_vocabulary_list(
key='aaa',
vocabulary_list=('omar', 'stringer', 'marlo'),
num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -4382,39 +3444,18 @@ class IdentityCategoricalColumnTest(test.TestCase):
}))
def test_linear_model(self):
- column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
- self.assertEqual(3, column.num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] = 1
- # weight_var[2] + weight_var[1] = 3+2 = 5
- self.assertAllClose(((1.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
self.assertEqual(3, column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,))
+ predictions = model({
column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(0, 2, 1),
dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
@@ -4640,27 +3681,8 @@ class IndicatorColumnTest(test.TestCase):
self.assertAllEqual([[0., 1., 1.]], indicator_tensor.eval())
def test_linear_model(self):
- animal = fc_old.indicator_column(
- fc_old.categorical_column_with_identity('animal', num_buckets=4))
- with ops.Graph().as_default():
- features = {
- 'animal':
- sparse_tensor.SparseTensor(
- indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
- }
-
- predictions = fc.linear_model(features, [animal])
- weight_var = get_linear_model_column_var(animal)
- with _initialized_session():
- # All should be zero-initialized.
- self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
- self.assertAllClose([[0.]], predictions.eval())
- weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
- self.assertAllClose([[2. + 3.]], predictions.eval())
-
- def test_keras_linear_model(self):
- animal = fc_old.indicator_column(
- fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
with ops.Graph().as_default():
features = {
'animal':
@@ -4668,8 +3690,9 @@ class IndicatorColumnTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
}
- predictions = get_keras_linear_model_predictions(features, [animal])
- weight_var = get_linear_model_column_var(animal)
+ model = fc.LinearModel([animal])
+ predictions = model(features)
+ weight_var, _ = model.variables
with _initialized_session():
# All should be zero-initialized.
self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
@@ -5121,17 +4144,16 @@ class EmbeddingColumnTest(test.TestCase):
return zeros_embedding_values
# Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_initializer)
with ops.Graph().as_default():
- predictions = fc.linear_model({
- categorical_column.name: sparse_input
- }, (embedding_column,))
+ model = fc.LinearModel((embedding_column,))
+ predictions = model({categorical_column.name: sparse_input})
expected_var_names = (
'linear_model/bias_weights:0',
'linear_model/aaa_embedding/weights:0',
@@ -5173,82 +4195,6 @@ class EmbeddingColumnTest(test.TestCase):
# = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
- def test_keras_linear_model(self):
- # Inputs.
- batch_size = 4
- vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- # example 2, ids []
- # example 3, ids [1]
- indices=((0, 0), (1, 0), (1, 4), (3, 0)),
- values=(2, 0, 1, 1),
- dense_shape=(batch_size, 5))
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_shape = (vocabulary_size, embedding_dimension)
- zeros_embedding_values = np.zeros(embedding_shape)
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual(embedding_shape, shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return zeros_embedding_values
-
- # Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
- categorical_column,
- dimension=embedding_dimension,
- initializer=_initializer)
-
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
- categorical_column.name: sparse_input
- }, (embedding_column,))
- expected_var_names = (
- 'linear_model/bias_weights:0',
- 'linear_model/aaa_embedding/weights:0',
- 'linear_model/aaa_embedding/embedding_weights:0',
- )
- self.assertItemsEqual(
- expected_var_names,
- [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
- trainable_vars = {
- v.name: v
- for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- }
- self.assertItemsEqual(expected_var_names, trainable_vars.keys())
- bias = trainable_vars['linear_model/bias_weights:0']
- embedding_weights = trainable_vars[
- 'linear_model/aaa_embedding/embedding_weights:0']
- linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0']
- with _initialized_session():
- # Predictions with all zero weights.
- self.assertAllClose(np.zeros((1,)), bias.eval())
- self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights.eval())
- self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
-
- # Predictions with all non-zero weights.
- embedding_weights.assign((
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )).eval()
- linear_weights.assign(((4.,), (6.,))).eval()
- # example 0, ids [2], embedding[0] = [7, 11]
- # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
- # example 2, ids [], embedding[2] = [0, 0]
- # example 3, ids [1], embedding[3] = [3, 5]
- # sum(embeddings * linear_weights)
- # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
- self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
-
def test_feature_layer(self):
# Inputs.
vocabulary_size = 3
@@ -5749,27 +4695,31 @@ class SharedEmbeddingColumnTest(test.TestCase):
return zeros_embedding_values
# Build columns.
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
initializer=_initializer)
with ops.Graph().as_default():
- predictions = fc.linear_model({
+ model = fc.LinearModel(
+ (embedding_column_a, embedding_column_b),
+ shared_state_manager=fc.SharedEmbeddingStateManager())
+ predictions = model({
categorical_column_a.name: input_a,
- categorical_column_b.name: input_b,
- }, (embedding_column_a, embedding_column_b))
+ categorical_column_b.name: input_b
+ })
+
# Linear weights do not follow the column name. But this is a rare use
# case, and fixing it would add too much complexity to the code.
expected_var_names = (
'linear_model/bias_weights:0',
- 'linear_model/aaa_bbb_shared_embedding/weights:0',
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
+ 'linear_model/aaa_shared_embedding/weights:0',
+ 'shared_embedding_state_manager/aaa_bbb_shared_embedding:0',
+ 'linear_model/bbb_shared_embedding/weights:0',
)
self.assertItemsEqual(
expected_var_names,
@@ -5781,102 +4731,11 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertItemsEqual(expected_var_names, trainable_vars.keys())
bias = trainable_vars['linear_model/bias_weights:0']
embedding_weights = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
- linear_weights_a = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/weights:0']
- linear_weights_b = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
- with _initialized_session():
- # Predictions with all zero weights.
- self.assertAllClose(np.zeros((1,)), bias.eval())
- self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights_a.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights_b.eval())
- self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
-
- # Predictions with all non-zero weights.
- embedding_weights.assign((
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )).eval()
- linear_weights_a.assign(((4.,), (6.,))).eval()
- # example 0, ids [2], embedding[0] = [7, 11]
- # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
- # sum(embeddings * linear_weights)
- # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29]
- linear_weights_b.assign(((3.,), (5.,))).eval()
- # example 0, ids [0], embedding[0] = [1, 2]
- # example 1, ids [], embedding[1] = 0, 0]
- # sum(embeddings * linear_weights)
- # = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
- self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
-
- def test_keras_linear_model(self):
- # Inputs.
- batch_size = 2
- vocabulary_size = 3
- # -1 values are ignored.
- input_a = np.array([
- [2, -1, -1], # example 0, ids [2]
- [0, 1, -1]
- ]) # example 1, ids [0, 1]
- input_b = np.array([
- [0, -1, -1], # example 0, ids [0]
- [-1, -1, -1]
- ]) # example 1, ids []
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_shape = (vocabulary_size, embedding_dimension)
- zeros_embedding_values = np.zeros(embedding_shape)
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual(embedding_shape, shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return zeros_embedding_values
-
- # Build columns.
- categorical_column_a = fc_old.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc_old.categorical_column_with_identity(
- key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
- [categorical_column_a, categorical_column_b],
- dimension=embedding_dimension,
- initializer=_initializer)
-
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
- categorical_column_a.name: input_a,
- categorical_column_b.name: input_b,
- }, (embedding_column_a, embedding_column_b))
- # Linear weights do not follow the column name. But this is a rare use
- # case, and fixing it would add too much complexity to the code.
- expected_var_names = (
- 'linear_model/bias_weights:0',
- 'linear_model/aaa_bbb_shared_embedding/weights:0',
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
- )
- self.assertItemsEqual(
- expected_var_names,
- [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
- trainable_vars = {
- v.name: v
- for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- }
- self.assertItemsEqual(expected_var_names, trainable_vars.keys())
- bias = trainable_vars['linear_model/bias_weights:0']
- embedding_weights = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
+ 'shared_embedding_state_manager/aaa_bbb_shared_embedding:0']
linear_weights_a = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/weights:0']
+ 'linear_model/aaa_shared_embedding/weights:0']
linear_weights_b = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
+ 'linear_model/bbb_shared_embedding/weights:0']
with _initialized_session():
# Predictions with all zero weights.
self.assertAllClose(np.zeros((1,)), bias.eval())
@@ -6275,13 +5134,14 @@ class WeightedCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2)),
weight_tensor.eval())
- def test_keras_linear_model(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,))
+ predictions = model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
@@ -6292,9 +5152,8 @@ class WeightedCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(.5, 1., .1),
dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
@@ -6305,15 +5164,16 @@ class WeightedCategoricalColumnTest(test.TestCase):
# = 3*1 + 2*.1 = 3+.2 = 3.2
self.assertAllClose(((.5,), (3.2,)), predictions.eval())
- def test_keras_linear_model_mismatched_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model_mismatched_shape(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
with self.assertRaisesRegexp(ValueError,
r'Dimensions.*are not compatible'):
- get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,))
+ model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
@@ -6324,122 +5184,23 @@ class WeightedCategoricalColumnTest(test.TestCase):
indices=((0, 0), (0, 1), (1, 0), (1, 1)),
values=(.5, 11., 1., .1),
dense_shape=(2, 2))
- }, (column,))
+ })
- def test_keras_linear_model_mismatched_dense_values(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions(
- {
- 'ids':
- sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': ((.5,), (1.,))
- }, (column,),
- sparse_combiner='mean')
- # Disabling the constant folding optimizer here since it changes the
- # error message differently on CPU and GPU.
- config = config_pb2.ConfigProto()
- config.graph_options.rewrite_options.constant_folding = (
- rewriter_config_pb2.RewriterConfig.OFF)
- with _initialized_session(config):
- with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
- predictions.eval()
-
- def test_keras_linear_model_mismatched_dense_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model_mismatched_dense_values(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,), sparse_combiner='mean')
+ predictions = model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(0, 2, 1),
dense_shape=(2, 2)),
- 'values': ((.5,), (1.,), (.1,))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] * weights[0, 0] = 1 * .5 = .5
- # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
- # = 3*1 + 2*.1 = 3+.2 = 3.2
- self.assertAllClose(((.5,), (3.2,)), predictions.eval())
-
- def test_linear_model(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(.5, 1., .1),
- dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] * weights[0, 0] = 1 * .5 = .5
- # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
- # = 3*1 + 2*.1 = 3+.2 = 3.2
- self.assertAllClose(((.5,), (3.2,)), predictions.eval())
-
- def test_linear_model_mismatched_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- with self.assertRaisesRegexp(
- ValueError, r'Dimensions.*are not compatible'):
- fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (0, 1), (1, 0), (1, 1)),
- values=(.5, 11., 1., .1),
- dense_shape=(2, 2))
- }, (column,))
-
- def test_linear_model_mismatched_dense_values(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = fc.linear_model(
- {
- 'ids':
- sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': ((.5,), (1.,))
- }, (column,),
- sparse_combiner='mean')
+ 'values': ((.5,), (1.,))
+ })
# Disabling the constant folding optimizer here since it changes the
# error message differently on CPU and GPU.
config = config_pb2.ConfigProto()
@@ -6450,20 +5211,21 @@ class WeightedCategoricalColumnTest(test.TestCase):
predictions.eval()
def test_linear_model_mismatched_dense_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
+ model = fc.LinearModel((column,))
+ predictions = model({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
'values': ((.5,), (1.,), (.1,))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index f740e5cfaa..87f567db0e 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -113,7 +113,7 @@ class FunctionTest(test.TestCase):
return a
with ops.Graph().as_default():
- var = variables.Variable([18.0])
+ var = variables.VariableV1([18.0])
call = MyIdentityFunc(var._ref()) # pylint: disable=protected-access
self.assertEqual("MyIdentity", call.op.name)
for cfg in _OptimizerOptions():
diff --git a/tensorflow/python/framework/graph_util_test.py b/tensorflow/python/framework/graph_util_test.py
index 2dafb94ba7..563a177dd0 100644
--- a/tensorflow/python/framework/graph_util_test.py
+++ b/tensorflow/python/framework/graph_util_test.py
@@ -104,13 +104,13 @@ class DeviceFunctionsTest(test.TestCase):
def testNestedDeviceFunctions(self):
with ops.Graph().as_default():
- var_0 = variables.Variable(0)
+ var_0 = variables.VariableV1(0)
with ops.device(test_device_func_pin_variable_to_cpu):
- var_1 = variables.Variable(1)
+ var_1 = variables.VariableV1(1)
with ops.device(lambda op: "/device:GPU:0"):
- var_2 = variables.Variable(2)
+ var_2 = variables.VariableV1(2)
with ops.device("/device:GPU:0"): # Implicit merging device function.
- var_3 = variables.Variable(3)
+ var_3 = variables.VariableV1(3)
self.assertDeviceEqual(var_0.device, None)
self.assertDeviceEqual(var_1.device, "/device:CPU:0")
diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py
index 1d594e4078..cab426844d 100644
--- a/tensorflow/python/framework/subscribe_test.py
+++ b/tensorflow/python/framework/subscribe_test.py
@@ -212,8 +212,8 @@ class SubscribeTest(test_util.TensorFlowTestCase):
def testSubscribeVariable(self):
"""Confirm that variables can be subscribed."""
- v1 = variables.Variable(0.0)
- v2 = variables.Variable(4.0)
+ v1 = variables.VariableV1(0.0)
+ v2 = variables.VariableV1(4.0)
add = math_ops.add(v1, v2)
assign_v1 = v1.assign(3.0)
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index cd0b03be43..6673bc5561 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -24,8 +24,8 @@ from collections import OrderedDict
import contextlib
import gc
import itertools
-import os
import math
+import os
import random
import re
import tempfile
@@ -402,11 +402,14 @@ def with_c_shapes(cls):
return cls
-def enable_cond_v2(fn):
- """Decorator for enabling CondV2 on a test.
+def enable_control_flow_v2(fn):
+ """Decorator for enabling CondV2 and WhileV2 on a test.
- Note this enables using CondV2 after running the test class's setup/teardown
- methods.
+ Note this enables using CondV2 and WhileV2 after running the test class's
+ setup/teardown methods.
+
+ In addition to this, callers must import the while_v2 module in order to set
+ the _while_v2 module in control_flow_ops.
Args:
fn: the function to be wrapped
@@ -416,21 +419,56 @@ def enable_cond_v2(fn):
"""
def wrapper(*args, **kwargs):
- prev_value = control_flow_ops.ENABLE_COND_V2
+ enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2
+ enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2
control_flow_ops.ENABLE_COND_V2 = True
+ control_flow_ops.ENABLE_WHILE_V2 = True
try:
fn(*args, **kwargs)
finally:
- control_flow_ops.ENABLE_COND_V2 = prev_value
+ control_flow_ops.ENABLE_COND_V2 = enable_cond_v2_old
+ control_flow_ops.ENABLE_WHILE_V2 = enable_while_v2_old
return wrapper
-def with_cond_v2(cls):
- """Adds methods that call original methods but with CondV2 enabled.
+def with_control_flow_v2(cls):
+ """Adds methods that call original methods with WhileV2 and CondV2 enabled.
- Note this enables CondV2 in new methods after running the test class's
- setup method.
+ Note this enables CondV2 and WhileV2 in new methods after running the test
+ class's setup method.
+
+ In addition to this, callers must import the while_v2 module in order to set
+ the _while_v2 module in control_flow_ops.
+
+ If a test function has _disable_control_flow_v2 attr set to True (using the
+ @disable_control_flow_v2 decorator), the v2 function is not generated for it.
+
+ Example:
+
+ @test_util.with_control_flow_v2
+ class ControlFlowTest(test.TestCase):
+
+ def testEnabledForV2(self):
+ ...
+
+ @test_util.disable_control_flow_v2("b/xyzabc")
+ def testDisabledForV2(self):
+ ...
+
+ Generated class:
+ class ControlFlowTest(test.TestCase):
+
+ def testEnabledForV2(self):
+ ...
+
+ def testEnabledForV2WithControlFlowV2(self):
+ // Enable V2 flags.
+ testEnabledForV2(self)
+ // Restore V2 flags.
+
+ def testDisabledForV2(self):
+ ...
Args:
cls: class to decorate
@@ -438,15 +476,33 @@ def with_cond_v2(cls):
Returns:
cls with new test methods added
"""
- if control_flow_ops.ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2:
return cls
for name, value in cls.__dict__.copy().items():
- if callable(value) and name.startswith("test"):
- setattr(cls, name + "WithCondV2", enable_cond_v2(value))
+ if (callable(value) and name.startswith("test") and
+ not getattr(value, "_disable_control_flow_v2", False)):
+ setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(value))
return cls
+def disable_control_flow_v2(unused_msg):
+ """Decorator for a function in a with_control_flow_v2 enabled test class.
+
+ Blocks the function from being run with v2 control flow ops.
+
+ Args:
+ unused_msg: Reason for disabling.
+
+ Returns:
+ The wrapped function with _disable_control_flow_v2 attr set to True.
+ """
+ def wrapper(func):
+ func._disable_control_flow_v2 = True
+ return func
+ return wrapper
+
+
def assert_no_new_pyobjects_executing_eagerly(f):
"""Decorator for asserting that no new Python objects persist after a test.
diff --git a/tensorflow/python/grappler/item_test.py b/tensorflow/python/grappler/item_test.py
index c40de9da0a..d3d96c646c 100644
--- a/tensorflow/python/grappler/item_test.py
+++ b/tensorflow/python/grappler/item_test.py
@@ -110,7 +110,7 @@ class ItemTest(test.TestCase):
def testColocationContraints(self):
with ops.Graph().as_default() as g:
c = constant_op.constant([10])
- v = variables.Variable([3], dtype=dtypes.int32)
+ v = variables.VariableV1([3], dtype=dtypes.int32)
i = gen_array_ops.ref_identity(v)
a = state_ops.assign(i, c)
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py
index b658edff2d..03b42f6453 100644
--- a/tensorflow/python/grappler/memory_optimizer_test.py
+++ b/tensorflow/python/grappler/memory_optimizer_test.py
@@ -39,8 +39,8 @@ class MemoryOptimizerSwapTest(test.TestCase):
def testNoSwapping(self):
"""Make sure the graph is preserved when there is nothing to swap."""
- a = variables.Variable(10, name='a')
- b = variables.Variable(20, name='b')
+ a = variables.VariableV1(10, name='a')
+ b = variables.VariableV1(20, name='b')
c = math_ops.add_n([a, b], name='c')
d = math_ops.add_n([b, c], name='d')
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
@@ -60,8 +60,8 @@ class MemoryOptimizerSwapTest(test.TestCase):
def testSimpleSwap(self):
"""Check that the swap annotations are followed."""
- a = variables.Variable(10, name='a')
- b = variables.Variable(20, name='b')
+ a = variables.VariableV1(10, name='a')
+ b = variables.VariableV1(20, name='b')
c = math_ops.add_n([a, b], name='c')
d = math_ops.add_n([b, c], name='d')
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
@@ -244,7 +244,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
init_op_name=init_op_name,
train_op_name=train_op_name,
loss_op_name=loss_op_name)
- self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-4)
+ self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-2)
def _annotated_graph(self):
graph = ops.Graph()
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py
index 5a9afe7257..eca0f67982 100644
--- a/tensorflow/python/grappler/tf_optimizer_test.py
+++ b/tensorflow/python/grappler/tf_optimizer_test.py
@@ -57,7 +57,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
def testKeepNodes(self):
g = ops.Graph()
with g.as_default():
- a1 = variables.Variable(
+ a1 = variables.VariableV1(
1.0) # Must be preserved since it's in the collection 'variables'.
a2 = constant_op.constant(0, shape=[50, 50], name='keep')
ops.add_to_collection('a2', a2) # Explicitly add to collection.
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index ac011a2940..4a72c4b3f3 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -7,7 +7,6 @@ exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])
-load("@pip_deps//:requirements.bzl", "requirement")
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
@@ -63,7 +62,6 @@ py_library(
":backend",
":engine",
":layers",
- requirement("keras_applications"),
"//tensorflow/python/saved_model",
"//tensorflow/python:training",
],
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index a46f9edb1e..584facc859 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -695,10 +695,8 @@ def track_tf_optimizer(tf_optimizer):
if context.executing_eagerly():
return
graph = ops.get_default_graph()
- if graph not in _GRAPH_TF_OPTIMIZERS:
- _GRAPH_TF_OPTIMIZERS[graph] = set()
- _GRAPH_TF_OPTIMIZERS[graph].add(tf_optimizer)
-
+ optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(graph, weakref.WeakSet())
+ optimizers.add(tf_optimizer)
def track_variable(v):
"""Tracks the given variable for initialization."""
@@ -1513,12 +1511,8 @@ def batch_dot(x, y, axes=None):
out = math_ops.reduce_sum(
math_ops.multiply(array_ops.transpose(x, [1, 0]), y), axes[1])
else:
- if axes is not None:
- adj_x = None if axes[0] == ndim(x) - 1 else True
- adj_y = True if axes[1] == ndim(y) - 1 else None
- else:
- adj_x = None
- adj_y = None
+ adj_x = None if axes[0] == ndim(x) - 1 else True
+ adj_y = True if axes[1] == ndim(y) - 1 else None
out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
if diff:
if x_ndim > y_ndim:
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index b6fae19823..467bc4cdc4 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -30,6 +30,7 @@ import numpy as np
from tensorflow.core.framework import summary_pb2
from tensorflow.python import keras
+from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
@@ -1222,6 +1223,45 @@ class KerasCallbacksTest(test.TestCase):
callbacks=cbks,
epochs=1)
+ def test_fit_generator_with_callback(self):
+
+ class TestCallback(keras.callbacks.Callback):
+ def set_model(self, model):
+ # Check the model operations for the optimizer operations that
+ # the _make_train_function adds under a named scope for the
+ # optimizer. This ensurs the full model is populated before the
+ # set_model callback is called.
+ optimizer_name_scope = 'training/' + model.optimizer.__class__.__name__
+ graph_def = ops.get_default_graph().as_graph_def()
+ for node in graph_def.node:
+ if node.name.startswith(optimizer_name_scope):
+ return
+ raise RuntimeError('The optimizer operations are not present in the '
+ 'model graph when the Callback.set_model function '
+ 'is called')
+ np.random.seed(1337)
+
+ def generator():
+ x = np.random.randn(10, 100).astype(np.float32)
+ y = np.random.randn(10, 10).astype(np.float32)
+ while True:
+ yield x, y
+
+ with self.cached_session():
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=10, input_dim=100)
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='sgd',
+ metrics=['accuracy'])
+ model.fit_generator(
+ generator(),
+ steps_per_epoch=2,
+ epochs=1,
+ validation_data=generator(),
+ validation_steps=2,
+ callbacks=[TestCallback()],
+ verbose=0)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index cb19a412a2..a75ce30d31 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import collections as collections_lib
import enum # pylint: disable=g-bad-import-order
+import functools
import inspect # Necessary supplement to tf_inspect to deal with variadic args.
import numpy as np
@@ -160,9 +161,13 @@ class Layer(checkpointable.CheckpointableBase):
self._trainable_weights = []
self._non_trainable_weights = []
self._updates = []
- # When executing eagerly, _losses is a list of zero-argument lambdas which
- # return tensors. When using graph execution, _losses is a list of ops.
+ # A list of zero-argument lambdas which return Tensors, used for variable
+ # regularizers.
+ self._callable_losses = []
+ # A list of Tensors containing activity regularizers and losses manually
+ # added through `add_loss`. Empty when executing eagerly.
self._losses = []
+ self._in_call = False # Flag for error checking in add_loss
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
self._call_fn_args = function_utils.fn_args(self.call)
self._compute_previous_mask = ('mask' in self._call_fn_args or
@@ -359,20 +364,20 @@ class Layer(checkpointable.CheckpointableBase):
def losses(self):
"""Losses which are associated with this `Layer`.
- Note that when executing eagerly, getting this property evaluates
- regularizers. When using graph execution, variable regularization ops have
- already been created and are simply returned here.
+ Variable regularization tensors are created when this property is accessed,
+ so it is eager safe: accessing `losses` under a `tf.GradientTape` will
+ propagate gradients back to the corresponding variables.
Returns:
A list of tensors.
"""
- if context.executing_eagerly():
- # _losses may only contain variable regularization losses when executing
- # eagerly, and they have been saved as lambdas to be executed when
- # requested.
- return [regularizer() for regularizer in self._losses]
- else:
- return self._losses
+ collected_losses = []
+ collected_losses.extend(self._losses)
+ for regularizer in self._callable_losses:
+ loss_tensor = regularizer()
+ if loss_tensor is not None:
+ collected_losses.append(loss_tensor)
+ return collected_losses
@doc_controls.for_subclass_implementers
def add_loss(self, losses, inputs=None):
@@ -393,7 +398,9 @@ class Layer(checkpointable.CheckpointableBase):
from `Layer.call()`).
Arguments:
- losses: Loss tensor, or list/tuple of tensors.
+ losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
+ may also be zero-argument callables which create a loss tensor. Only
+ callable losses are supported when executing eagerly.
inputs: If anything other than None is passed, it signals the losses
are conditional on some of the layer's inputs,
and thus they should only be run where these inputs are available.
@@ -403,29 +410,45 @@ class Layer(checkpointable.CheckpointableBase):
(e.g. weight regularization losses).
Raises:
- RuntimeError: If called in Eager mode.
+ RuntimeError: If called in Eager mode with a `Tensor` rather than a
+ callable, or if `inputs` is not None.
"""
- if context.executing_eagerly():
- # TODO(fchollet): it should be possible (and highly desirable) to support
- # `add_loss` in eager mode. This allows great convenience and flexibility
- # in defining custom losses on the fly (e.g. in VAEs).
- # Simply appending the loss value to `self._losses`
- # is the correct behavior.
- # The only caveat is that we need to force the user to only call
- # `add_loss` from inside a model or Layer's `call` method
- # (otherwise the loss computation cannot be backproped through).
- raise RuntimeError('Layer.add_loss not supported in Eager mode.')
-
+ executing_eagerly = context.executing_eagerly()
+ if executing_eagerly:
+ if inputs is not None:
+ raise RuntimeError(
+ 'Activity regularization (via the "inputs" argument to '
+ 'Layer.add_loss) is not supported when executing eagerly. Consider '
+ 'returning activity regularization losses from a Model\'s call() '
+ 'method.')
+ if getattr(self, '_in_call', False):
+ # TODO(psv): Support activity regularization and a way to reset losses.
+ raise RuntimeError(
+ 'Adding losses inside a Layer\'s call() method is not currently '
+ 'supported when executing eagerly. Please file a feature request '
+ 'if you need this limitation lifted.')
losses = generic_utils.to_list(losses)
- losses = [ops.convert_to_tensor(loss, dtype=backend.floatx())
- if not tensor_util.is_tensor(loss) else loss for loss in losses]
- self._losses += losses
- if inputs is None:
- for loss in losses:
- loss._unconditional_loss = True # pylint: disable=protected-access
- else:
- for loss in losses:
- loss._unconditional_loss = False # pylint: disable=protected-access
+
+ def _tag_unconditional(loss):
+ if callable(loss):
+ loss = loss()
+ if loss is None:
+ return None # Will be filtered out when computing the .losses property
+ if not tensor_util.is_tensor(loss):
+ loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
+ loss._unconditional_loss = (inputs is None) # pylint: disable=protected-access
+ return loss
+
+ for loss in losses:
+ if callable(loss):
+ self._callable_losses.append(
+ functools.partial(_tag_unconditional, loss))
+ else:
+ if executing_eagerly:
+ raise RuntimeError(
+ 'Layer.add_loss only supported for zero-argument lambdas when '
+ 'executing eagerly.')
+ self._losses.append(_tag_unconditional(loss))
def get_losses_for(self, inputs):
"""Retrieves losses relevant to a specific set of inputs.
@@ -599,56 +622,20 @@ class Layer(checkpointable.CheckpointableBase):
return variable
def _handle_weight_regularization(self, name, variable, regularizer):
- # `init_graph` should point to the graph in which variable initialization
- # will occur; it should be None if and only if initialization will take
- # place in the eager context.
- init_graph = None
- if not context.executing_eagerly():
- default_graph = ops.get_default_graph()
- if default_graph.building_function:
- with ops.init_scope():
- # Retrieve the variables from the graph into which variables
- # will be lifted; if initialization ops will be lifted into
- # the eager context, then there is nothing to retrieve, since variable
- # collections are not supported when eager execution is enabled.
- if not context.executing_eagerly():
- init_graph = ops.get_default_graph()
- else:
- # Initialization ops will not be lifted out of the default graph.
- init_graph = default_graph
-
- if init_graph is not None: # pylint: disable=protected-access
- # The variable was created and initialized in a graph.
- if regularizer:
- if isinstance(variable, tf_variables.PartitionedVariable):
- for v in variable:
- with ops.colocate_with(v.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(v)
- if regularization is not None:
- self.add_loss(regularization)
- else:
- with ops.colocate_with(variable.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(variable)
- if regularization is not None:
- self.add_loss(regularization)
- elif regularizer: # initialization took place in an eager context
- if isinstance(variable, tf_variables.PartitionedVariable):
- raise RuntimeError(
- 'Partitioned variable regularization is not yet '
- 'supported when executing eagerly. File a feature request'
- 'if this is important to you.')
- # Save a zero-argument lambda which runs the regularizer on the
- # variable, to be executed when `Layer.losses` is requested.
- # This makes losses responsive to variable updates when executing
- # eagerly.
- #
- # TODO(akshayka): Do the same for graphs as well, so that losses
- # collected in a while_loop can be run outside its control flow
- # context and so that losses won't be swallowed up by graph functions
- # (i.e., `.losses()` should always create regularizers).
- self._losses.append(lambda: regularizer(variable))
+ """Create lambdas which compute regularization losses."""
+
+ def _loss_for_variable(v):
+ """Creates a regularization loss `Tensor` for variable `v`."""
+ with ops.colocate_with(v):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(v)
+ return regularization
+
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ for v in variable:
+ self.add_loss(functools.partial(_loss_for_variable, v))
+ else:
+ self.add_loss(functools.partial(_loss_for_variable, variable))
def _handle_activity_regularization(self, inputs, outputs):
# Apply activity regularization.
@@ -766,7 +753,9 @@ class Layer(checkpointable.CheckpointableBase):
self._assert_input_compatibility(inputs)
if not in_deferred_mode:
+ self._in_call = True
outputs = self.call(inputs, *args, **kwargs)
+ self._in_call = False
if outputs is None:
raise ValueError('A layer\'s `call` method should return a Tensor '
'or a list of Tensors, not None (layer: ' +
@@ -1972,7 +1961,9 @@ def make_variable(name,
if use_resource is None:
use_resource = True
- v = tf_variables.Variable(
+ # TODO(apassos,rohanj) figure out how to remove collections from here so we
+ # can remove the V1.
+ v = tf_variables.VariableV1(
initial_value=init_val,
name=name,
trainable=trainable,
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index ade8a4b32d..5091cac836 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -647,12 +647,6 @@ class Model(Network):
skip_target_indices=skip_target_indices,
sample_weights=self.sample_weights)
- # If using distribution strategy and stateful_metrics, raise an error
- # since we currently don't support stateful metrics.
- if self._distribution_strategy is not None and self.stateful_metric_names:
- raise NotImplementedError('Stateful metrics are not supported with '
- 'DistributionStrategy.')
-
# Prepare gradient updates and state updates.
self.total_loss = total_loss
@@ -857,7 +851,8 @@ class Model(Network):
# able to clone a Dataset on multiple workers we can remove this lambda.
result = self._distribution_strategy.distribute_dataset(lambda: x)
iterator = result.make_initializable_iterator()
- K.get_session().run(iterator.initializer)
+ with self._distribution_strategy.scope():
+ K.get_session().run(iterator.initializer)
training_utils.validate_iterator_input(x, y, sample_weight,
validation_split)
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 8b434ca444..a6470458d2 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -26,6 +26,7 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import distributed_training_utils
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
@@ -111,96 +112,99 @@ def fit_loop(
dataset_targets = distributed_training_utils.flatten_perdevice_values(
current_strategy, targets)
- # Create a train function that is composed of all the parameters above.
- distributed_train_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_train_function',
- **all_session_args)
-
- # We need to set sample_weights to None since there are sample weight
- # placeholders that are created with default values.
- sample_weights = [None for _ in range(len(model.outputs) *
- current_strategy.num_towers)]
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + dataset_targets + sample_weights + [1]
- else:
- ins = dataset_inputs + dataset_targets
+ # Create a train function that is composed of all the parameters above.
+ distributed_train_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_train_function',
+ **all_session_args)
+
+ # We need to set sample_weights to None since there are sample weight
+ # placeholders that are created with default values.
+ sample_weights = [None for _ in range(len(model.outputs) *
+ current_strategy.num_towers)]
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + dataset_targets + sample_weights + [1]
+ else:
+ ins = dataset_inputs + dataset_targets
- do_validation = False
- if validation_steps:
- do_validation = True
+ do_validation = False
+ if validation_steps:
+ do_validation = True
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- callbacks = cbks.configure_callbacks(
- callbacks,
- model,
- do_validation=do_validation,
- val_inputs=None,
- val_targets=None,
- epochs=epochs,
- steps_per_epoch=steps_per_epoch,
- verbose=verbose)
- out_labels = model.metrics_names or []
- callbacks.on_train_begin()
-
- assert steps_per_epoch is not None
-
- for epoch in range(initial_epoch, epochs):
- callbacks.on_epoch_begin(epoch)
- epoch_logs = {}
- for step_index in range(steps_per_epoch):
- batch_logs = {'batch': step_index, 'size': 1}
- callbacks.on_batch_begin(step_index, batch_logs)
- try:
- outs = distributed_train_function(ins)
- except errors.OutOfRangeError:
- logging.warning('Your dataset iterator ran out of data; '
- 'interrupting training. Make sure that your dataset '
- 'can generate at least `steps_per_epoch * epochs` '
- 'batches (in this case, %d batches).' %
- steps_per_epoch * epochs)
- break
-
- if not isinstance(outs, list):
- outs = [outs]
-
- outs = _aggregate_metrics_across_towers(
- current_strategy.num_towers, out_labels, outs)
- for l, o in zip(out_labels, outs):
- batch_logs[l] = o
- callbacks.on_batch_end(step_index, batch_logs)
+ callbacks = cbks.configure_callbacks(
+ callbacks,
+ model,
+ do_validation=do_validation,
+ val_inputs=None,
+ val_targets=None,
+ epochs=epochs,
+ steps_per_epoch=steps_per_epoch,
+ verbose=verbose)
+ out_labels = model.metrics_names or []
+ callbacks.on_train_begin()
+
+ assert steps_per_epoch is not None
+
+ for epoch in range(initial_epoch, epochs):
+ # Reset stateful metrics
+ for m in model.stateful_metric_functions:
+ m.reset_states()
+ callbacks.on_epoch_begin(epoch)
+ epoch_logs = {}
+ for step_index in range(steps_per_epoch):
+ batch_logs = {'batch': step_index, 'size': 1}
+ callbacks.on_batch_begin(step_index, batch_logs)
+ try:
+ outs = distributed_train_function(ins)
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your dataset '
+ 'can generate at least `steps_per_epoch * epochs` '
+ 'batches (in this case, %d batches).' %
+ steps_per_epoch * epochs)
+ break
+
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ outs = _aggregate_metrics_across_towers(current_strategy.num_towers,
+ out_labels,
+ model.stateful_metric_names,
+ outs)
+ for l, o in zip(out_labels, outs):
+ batch_logs[l] = o
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callbacks.model.stop_training:
+ break
+ if do_validation:
+ val_outs = test_loop(
+ model,
+ val_iterator,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(out_labels, val_outs):
+ epoch_logs['val_' + l] = o
+
+ callbacks.on_epoch_end(epoch, epoch_logs)
if callbacks.model.stop_training:
break
- if do_validation:
- val_outs = test_loop(
- model,
- val_iterator,
- steps=validation_steps,
- verbose=0)
- if not isinstance(val_outs, list):
- val_outs = [val_outs]
- # Same labels assumed.
- for l, o in zip(out_labels, val_outs):
- epoch_logs['val_' + l] = o
+ callbacks.on_train_end()
- callbacks.on_epoch_end(epoch, epoch_logs)
- if callbacks.model.stop_training:
- break
- callbacks.on_train_end()
-
- # Copy the weights back from the replicated model to the original model.
- with current_strategy.scope():
+ # Copy the weights back from the replicated model to the original model.
updated_weights = current_strategy.unwrap(
model._grouped_model)[0].get_weights()
model.set_weights(updated_weights)
- return model.history
+ return model.history
def _experimental_fit_loop(
@@ -422,54 +426,65 @@ def test_loop(model, iterator, verbose=0, steps=None):
dataset_targets = distributed_training_utils.flatten_perdevice_values(
current_strategy, targets)
- distributed_test_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_test_function',
- **all_session_args)
-
- # We need to set sample_weights to None since there are sample weight
- # placeholders that are created with default values.
- sample_weights = [None for _ in range(len(model.outputs) *
- current_strategy.num_towers)]
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + dataset_targets + sample_weights + [0]
- else:
- ins = dataset_inputs + dataset_targets
+ distributed_test_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_test_function',
+ **all_session_args)
- outs = []
- if verbose == 1:
- progbar = Progbar(target=steps)
+ # We need to set sample_weights to None since there are sample weight
+ # placeholders that are created with default values.
+ sample_weights = [None for _ in range(len(model.outputs) *
+ current_strategy.num_towers)]
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + dataset_targets + sample_weights + [0]
+ else:
+ ins = dataset_inputs + dataset_targets
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ for m in model.stateful_metric_functions:
+ m.reset_states()
+ stateful_metric_indices = [
+ i for i, name in enumerate(model.metrics_names)
+ if str(name) in model.stateful_metric_names
+ ]
+
+ outs = []
+ if verbose == 1:
+ progbar = Progbar(target=steps)
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- assert steps is not None
- for step in range(steps):
- batch_outs = distributed_test_function(ins)
- batch_outs = _aggregate_metrics_across_towers(
- current_strategy.num_towers, model.metrics_names, batch_outs)
- if isinstance(batch_outs, list):
- if step == 0:
- outs = [0.] * len(batch_outs)
- for i, batch_out in enumerate(batch_outs):
- outs[i] += batch_out
- else:
- if step == 0:
- outs.append(0.)
- outs[0] += batch_outs
- if verbose >= 1:
- progbar.update(step + 1)
- for i in range(len(outs)):
- outs[i] /= steps
+ assert steps is not None
+ for step in range(steps):
+ batch_outs = distributed_test_function(ins)
+ batch_outs = _aggregate_metrics_across_towers(
+ current_strategy.num_towers, model.metrics_names,
+ model.stateful_metric_names, batch_outs)
+ if isinstance(batch_outs, list):
+ if step == 0:
+ outs = [0.] * len(batch_outs)
+ for i, batch_out in enumerate(batch_outs):
+ if i in stateful_metric_indices:
+ outs[i] = batch_out
+ else:
+ outs[i] += batch_out
+ else:
+ if step == 0:
+ outs.append(0.)
+ outs[0] += batch_outs
+ if verbose >= 1:
+ progbar.update(step + 1)
+ for i in range(len(outs)):
+ if i not in stateful_metric_indices:
+ outs[i] /= steps
- if len(outs) == 1:
- return outs[0]
- return outs
+ if len(outs) == 1:
+ return outs[0]
+ return outs
def _experimental_test_loop(model, iterator, verbose=0, steps=None):
@@ -630,51 +645,50 @@ def predict_loop(model, iterator, verbose=0, steps=None):
dataset_inputs = distributed_training_utils.flatten_perdevice_values(
current_strategy, inputs)
- distributed_predict_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_predict_function',
- **all_session_args)
+ distributed_predict_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_predict_function',
+ **all_session_args)
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + [0]
- else:
- ins = dataset_inputs
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + [0]
+ else:
+ ins = dataset_inputs
- if verbose == 1:
- progbar = Progbar(target=steps)
+ if verbose == 1:
+ progbar = Progbar(target=steps)
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- if steps is not None:
- # Since we do not know how many samples we will see, we cannot pre-allocate
- # the returned Numpy arrays. Instead, we store one array per batch seen
- # and concatenate them upon returning.
- unconcatenated_outs = []
- for step in range(steps):
- batch_outs = distributed_predict_function(ins)
- if not isinstance(batch_outs, list):
- batch_outs = [batch_outs]
- if step == 0:
- for _ in batch_outs:
- unconcatenated_outs.append([])
- # TODO(anjalisridhar): Should combine the outputs from multiple towers
- # correctly here.
- for i, batch_out in enumerate(batch_outs):
- unconcatenated_outs[i].append(batch_out)
- if verbose >= 1:
- progbar.update(step + 1)
- if len(unconcatenated_outs) == 1:
- return np.concatenate(unconcatenated_outs[0], axis=0)
- return [
- np.concatenate(unconcatenated_outs[i], axis=0)
- for i in range(len(unconcatenated_outs))
- ]
+ if steps is not None:
+ # Since we do not know how many samples we will see, we cannot
+ # pre-allocate the returned Numpy arrays. Instead, we store one array per
+ # batch seen and concatenate them upon returning.
+ unconcatenated_outs = []
+ for step in range(steps):
+ batch_outs = distributed_predict_function(ins)
+ if not isinstance(batch_outs, list):
+ batch_outs = [batch_outs]
+ if step == 0:
+ for _ in batch_outs:
+ unconcatenated_outs.append([])
+ # TODO(anjalisridhar): Should combine the outputs from multiple towers
+ # correctly here.
+ for i, batch_out in enumerate(batch_outs):
+ unconcatenated_outs[i].append(batch_out)
+ if verbose >= 1:
+ progbar.update(step + 1)
+ if len(unconcatenated_outs) == 1:
+ return np.concatenate(unconcatenated_outs[0], axis=0)
+ return [
+ np.concatenate(unconcatenated_outs[i], axis=0)
+ for i in range(len(unconcatenated_outs))
+ ]
def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
@@ -816,10 +830,10 @@ def _clone_and_build_model(model, inputs=None, targets=None):
cloned_model.compile(
optimizer,
model.loss,
- metrics=model.metrics,
+ metrics=metrics_module.clone_metrics(model.metrics),
loss_weights=model.loss_weights,
sample_weight_mode=model.sample_weight_mode,
- weighted_metrics=model.weighted_metrics,
+ weighted_metrics=metrics_module.clone_metrics(model.weighted_metrics),
target_tensors=targets)
return cloned_model
@@ -834,8 +848,9 @@ def clone_model_on_towers(
model._make_callback_model()
-def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
- """Aggregate metrics values across all towers.
+def _aggregate_metrics_across_towers(num_devices, out_labels,
+ stateful_metric_names, outs):
+ """Aggregates stateless metrics values across towers.
When using `MirroredStrategy`, the number of towers is equal to the
number of devices over which training is distributed. This may not always be
@@ -844,6 +859,7 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
Args:
num_devices: Number of devices over which the model is being distributed.
out_labels: The list of metric names passed to `compile`.
+ stateful_metric_names: List of stateful metric names on the model.
outs: The output from all the towers.
Returns:
@@ -858,10 +874,16 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
# Each label in `out_labels` corresponds to one set of metrics. The
# number of metric values corresponds to the number of devices. We
# currently take the mean of the values.
- for _ in out_labels[1:]:
- m = np.mean(outs[current_index:current_index + num_devices])
- merged_output.append(m)
- current_index += num_devices
+ for metric_name in out_labels[1:]:
+ if metric_name in stateful_metric_names:
+ # For stateful metrics, we get one aggregated result value.
+ merged_output.append(outs[current_index])
+ current_index += 1
+ else:
+ m = np.mean(outs[current_index:current_index + num_devices])
+ merged_output.append(m)
+ current_index += num_devices
+
return merged_output
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index db7ccb181f..1f5176c4d7 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -192,6 +192,20 @@ class CorrectnessTest(test.TestCase):
history = model.fit(iterator, epochs=1, steps_per_epoch=10)
self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173)
+ def test_no_loss_in_call(self):
+
+ class HasLoss(keras.layers.Layer):
+
+ def call(self, x):
+ self.add_loss(x)
+ return x
+
+ layer = HasLoss()
+ with self.assertRaises(RuntimeError):
+ layer(1.)
+
+ with ops.Graph().as_default():
+ layer(1.)
if __name__ == '__main__':
ops.enable_eager_execution()
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py
index 413c1f4fba..2e074699da 100644
--- a/tensorflow/python/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/engine/training_generator.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer
from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer
@@ -48,6 +49,10 @@ def fit_generator(model,
epoch = initial_epoch
do_validation = bool(validation_data)
+ if not context.executing_eagerly():
+ model._make_train_function()
+ if do_validation:
+ model._make_test_function()
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
@@ -233,6 +238,9 @@ def evaluate_generator(model,
use_multiprocessing=False,
verbose=0):
"""See docstring for `Model.evaluate_generator`."""
+ if not context.executing_eagerly():
+ model._make_test_function()
+
if hasattr(model, 'metrics'):
for m in model.stateful_metric_functions:
m.reset_states()
@@ -342,6 +350,9 @@ def predict_generator(model,
use_multiprocessing=False,
verbose=0):
"""See docstring for `Model.predict_generator`."""
+ if not context.executing_eagerly():
+ model._make_test_function()
+
steps_done = 0
wait_time = 0.01
all_outs = []
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 30be4131a4..54ad74c08b 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -27,6 +27,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util as tf_test_util
@@ -2427,6 +2428,17 @@ class TestTrainingWithMetrics(test.TestCase):
scores = model.train_on_batch(x, y, sample_weight=w)
self.assertArrayNear(scores, [0.2, 0.8, 0.8], 0.1)
+ def test_losses_in_defun(self):
+ with context.eager_mode():
+ layer = keras.layers.Dense(1, kernel_regularizer='l1')
+ layer(array_ops.ones([1, 10]))
+
+ @function.defun
+ def get_losses():
+ return layer.losses
+
+ self.assertAllEqual(self.evaluate(layer.losses),
+ self.evaluate(get_losses()))
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 4032202986..efa21955e6 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -671,22 +671,34 @@ class Lambda(Layer):
if mask is not None:
self.supports_masking = True
self.mask = mask
- if output_shape is None:
- self._output_shape = None
- elif isinstance(output_shape, (tuple, list)):
- self._output_shape = tuple(output_shape)
- else:
- if not callable(output_shape):
- raise TypeError('In Lambda, `output_shape` '
- 'must be a list, a tuple, or a function.')
- self._output_shape = output_shape
+ if (output_shape is not None and not isinstance(output_shape,
+ (tuple, list)) and
+ not callable(output_shape)):
+ raise TypeError('In Lambda, `output_shape` '
+ 'must be a list, a tuple, or a function.')
+ # Convert a list representing a single shape into a tuple.
+ if (isinstance(output_shape, list) and isinstance(output_shape[0],
+ (int, type(None)))):
+ output_shape = tuple(output_shape)
+ self._output_shape = output_shape
@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if self._output_shape is None:
if context.executing_eagerly():
- raise NotImplementedError
- x = K.placeholder(shape=input_shape)
+ # Make use of existing autocomputation for Eager mode but provide
+ # Lambda-specific error message.
+ try:
+ return super(Lambda, self).compute_output_shape(input_shape)
+ except NotImplementedError:
+ raise NotImplementedError('We could not automatically infer '
+ 'the static shape of the Lambda\'s output.'
+ ' Please specify the `output_shape` for'
+ ' this Lambda.')
+ if isinstance(input_shape, list):
+ x = [K.placeholder(shape=shape) for shape in input_shape]
+ else:
+ x = K.placeholder(shape=input_shape)
x = self.call(x)
if isinstance(x, list):
return [tensor_shape.TensorShape(K.int_shape(x_elem)) for x_elem in x]
@@ -697,16 +709,27 @@ class Lambda(Layer):
num_samples = input_shape[0][0]
else:
num_samples = input_shape[0] if input_shape else None
- return tensor_shape.TensorShape((num_samples,) +
- tuple(self._output_shape))
+ # List here represents multiple outputs.
+ if isinstance(self._output_shape, list):
+ return [
+ tensor_shape.TensorShape((num_samples,) + tuple(single_shape))
+ for single_shape in self._output_shape
+ ]
+ return tensor_shape.TensorShape((num_samples,) + self._output_shape)
else:
shape = self._output_shape(input_shape)
if not isinstance(shape, (list, tuple)):
raise ValueError(
'`output_shape` function must return a tuple or a list of tuples.')
+ # List here can represent multiple outputs or single output.
if isinstance(shape, list):
- if isinstance(shape[0], int) or shape[0] is None:
+ # Convert list representing single output into a tuple.
+ if isinstance(shape[0], (int, type(None))):
shape = tuple(shape)
+ else:
+ return [
+ tensor_shape.TensorShape(single_shape) for single_shape in shape
+ ]
return tensor_shape.TensorShape(shape)
def call(self, inputs, mask=None):
diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py
index 1df1d575b1..f0fea1f65c 100644
--- a/tensorflow/python/keras/layers/core_test.py
+++ b/tensorflow/python/keras/layers/core_test.py
@@ -252,6 +252,51 @@ class CoreLayersTest(test.TestCase):
l(keras.backend.variable(np.ones((1, 1))))
self.assertEqual('lambda', l.get_config()['output_shape_type'])
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_autocalculate_multiple_inputs(self):
+
+ def lambda_fn(x):
+ return math_ops.matmul(x[0], x[1])
+
+ l = keras.layers.Lambda(lambda_fn)
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual((10, 20), output_shape)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_list_multiple_outputs(self):
+
+ def lambda_fn(x):
+ return x
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=[(10,), (20,)])
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual([(10, 10), (10, 20)], output_shape)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_tuple_with_none(self):
+
+ def lambda_fn(x):
+ return x
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=(None, 10))
+ output_shape = l.compute_output_shape((5, 10, 20))
+ # Dimension(None) != Dimension(None), so check
+ # str representations for equality.
+ self.assertAllEqual(('5', '?', '10'), tuple([str(s) for s in output_shape]))
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_function_multiple_outputs(self):
+
+ def lambda_fn(x):
+ return x
+
+ def output_shape_fn(input_shape):
+ return input_shape
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=output_shape_fn)
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual([(10, 10), (10, 20)], output_shape)
+
def test_lambda_config_serialization(self):
with self.cached_session():
# test serialization with output_shape and output_shape_type
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index e64241e5cf..f4e8419eb0 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -71,6 +71,22 @@ def check_is_tensor_or_operation(x, name):
name, x))
+def clone_metric(metric):
+ """Returns a clone of the metric if stateful, otherwise returns it as is."""
+ if isinstance(metric, Metric):
+ return metric.__class__.from_config(metric.get_config())
+ return metric
+
+
+def clone_metrics(metrics):
+ """Clones the given metric list/dict."""
+ if metrics is None:
+ return None
+ if isinstance(metrics, dict):
+ return {key: clone_metric(value) for key, value in metrics.items()}
+ return [clone_metric(metric) for metric in metrics]
+
+
def update_state_wrapper(update_state_fn):
"""Decorator to wrap metric `update_state()` with `add_update()`.
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 41c5e3cccf..b04b4df257 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import saving
from tensorflow.python.keras.engine import sequential
@@ -290,7 +291,9 @@ def _in_place_subclassed_model_reset(model):
if isinstance(value, Layer):
attributes_cache[name] = value
assert value in model._layers
- elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'):
+ elif isinstance(
+ value, (list, tuple)) and name not in ('layers', '_layers',
+ 'stateful_metric_functions'):
# Handle case: list/tuple of layers (also tracked by the Network API).
if value and all(isinstance(val, Layer) for val in value):
raise ValueError('We do not support the use of list-of-layers '
@@ -466,10 +469,10 @@ def clone_and_build_model(
clone.compile(
optimizer,
model.loss,
- metrics=model.metrics,
+ metrics=metrics_module.clone_metrics(model.metrics),
loss_weights=model.loss_weights,
sample_weight_mode=model.sample_weight_mode,
- weighted_metrics=model.weighted_metrics,
+ weighted_metrics=metrics_module.clone_metrics(model.weighted_metrics),
target_tensors=target_tensors)
return clone
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py
index 8d7493462e..9664f09fff 100644
--- a/tensorflow/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/optimizers_test.py
@@ -18,10 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import gc
+import weakref
+
import numpy as np
from tensorflow.python import keras
from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
@@ -156,6 +160,19 @@ class KerasOptimizersTest(test.TestCase):
with self.assertRaises(NotImplementedError):
optimizer.from_config(None)
+ def test_optimizer_garbage_collection(self):
+ graph = ops.Graph()
+ with graph.as_default():
+ optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
+ keras.backend.track_tf_optimizer(optimizer)
+ optimizer_weak = weakref.ref(optimizer)
+ graph_weak = weakref.ref(graph)
+ del graph, optimizer
+ gc.collect()
+ # Check that the weak references are dead now.
+ self.assertIs(graph_weak(), None)
+ self.assertIs(optimizer_weak(), None)
+
@test_util.run_in_graph_and_eager_modes
def test_tfoptimizer_iterations(self):
with self.cached_session():
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 5183e4d30c..9490746fd9 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1097,6 +1097,18 @@ tf_py_test(
],
)
+tf_py_test(
+ name = "unicode_script_op_test",
+ size = "small",
+ srcs = ["unicode_script_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:string_ops",
+ ],
+)
+
cuda_py_test(
name = "topk_op_test",
size = "small",
@@ -1468,7 +1480,7 @@ cuda_py_test(
name = "control_flow_ops_py_test",
# TODO(b/70473603): change this back to "small" once the C API is
# permanently enabled
- size = "medium",
+ size = "large",
srcs = ["control_flow_ops_py_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -1500,6 +1512,7 @@ cuda_py_test(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
+ "//tensorflow/python:while_v2",
],
)
@@ -2346,7 +2359,7 @@ cuda_py_test(
cuda_py_test(
name = "transpose_op_test",
- size = "large",
+ size = "medium",
srcs = ["transpose_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2354,10 +2367,11 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
],
- shard_count = 2,
+ shard_count = 4,
tags = [
"no_gpu",
"no_oss",
+ "optonly", # times out
],
)
@@ -2476,6 +2490,7 @@ cuda_py_test(
"//tensorflow/python:nn_grad",
"//tensorflow/python:nn_ops",
],
+ shard_count = 2,
tags = [
"optonly", # flaky timeouts unless optimized
],
@@ -2496,7 +2511,7 @@ cuda_py_test(
cuda_py_test(
name = "conv_ops_test",
- size = "large",
+ size = "medium",
srcs = ["conv_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2515,6 +2530,9 @@ cuda_py_test(
"//tensorflow/python:variables",
],
shard_count = 4,
+ tags = [
+ "optonly", # times out
+ ],
)
cuda_py_test(
@@ -2574,7 +2592,7 @@ cuda_py_test(
cuda_py_test(
name = "fft_ops_test",
- size = "large",
+ size = "medium",
srcs = ["fft_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2584,7 +2602,8 @@ cuda_py_test(
"//tensorflow/python:spectral_ops",
"//tensorflow/python:spectral_ops_test_util",
],
- shard_count = 3,
+ shard_count = 4,
+ tags = ["optonly"],
)
cuda_py_test(
@@ -2649,7 +2668,7 @@ cuda_py_test(
cuda_py_test(
name = "scatter_ops_test",
- size = "large", # NOTE: This is not run by default.
+ size = "medium", # NOTE: This is not run by default.
srcs = ["scatter_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2658,11 +2677,13 @@ cuda_py_test(
"//tensorflow/python:state_ops",
"//tensorflow/python:variables",
],
+ shard_count = 2,
+ tags = ["optonly"],
)
cuda_py_test(
name = "slice_op_test",
- size = "large",
+ size = "medium",
srcs = ["slice_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -3245,8 +3266,7 @@ tf_py_test(
tags = ["no_gpu"], # TODO(b/111656070)
)
-# TODO(b/116053459): Replace with cuda_py_test.
-tf_py_test(
+cuda_py_test(
name = "while_v2_test",
size = "medium",
srcs = ["while_v2_test.py"],
@@ -3266,5 +3286,4 @@ tf_py_test(
"//tensorflow/python:while_v2",
],
grpc_enabled = True,
- tags = ["no_gpu"], # TODO(b/116053459)
)
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 2fe85839d0..c5547b19be 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1001,14 +1001,14 @@ class SliceAssignTest(test_util.TensorFlowTestCase):
errors.FailedPreconditionError,
"Attempting to use uninitialized value Variable"):
with self.cached_session() as sess:
- v = variables.Variable([1, 2])
+ v = variables.VariableV1([1, 2])
sess.run(v[:].assign([1, 2]))
def testTypeError(self):
init_val = constant_op.constant([1, 2], dtype=dtypes.int32)
too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8)
too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64)
- v = variables.Variable(init_val)
+ v = variables.VariableV1(init_val)
with self.assertRaises(TypeError):
v[:].assign(too_small_val)
with self.assertRaises(TypeError):
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
index 467e33ec87..7cdc67f83f 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
@@ -445,6 +445,78 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
# change= 0.1(1.14+7.0-7.0)
self.assertAllClose([[1], [0.114]], logits_updates)
+ def testCategoricalSplits(self):
+ """Tests the training prediction work for categorical splits."""
+ with self.cached_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ categorical_split {
+ feature_id: 1
+ value: 2
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ categorical_split {
+ feature_id: 0
+ value: 13
+ left_id: 3
+ right_id: 4
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ is_finalized: true
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [13, 1, 3]
+ feature_1_values = [2, 2, 1]
+
+ # No previous cached values.
+ cached_tree_ids = [0, 0, 0]
+ cached_node_ids = [0, 0, 0]
+
+ # Grow tree ensemble.
+ predict_op = boosted_trees_ops.training_predict(
+ tree_ensemble_handle,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+ self.assertAllClose([0, 0, 0], new_tree_ids)
+ self.assertAllClose([3, 4, 2], new_node_ids)
+ self.assertAllClose([[5.], [6.], [7.]], logits_updates)
+
def testCachedPredictionFromTheSameTreeWithPostPrunedNodes(self):
"""Tests that prediction based on previous node in the tree works."""
with self.cached_session() as session:
@@ -924,6 +996,68 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
logits = session.run(predict_op)
self.assertAllClose(expected_logits, logits)
+ def testCategoricalSplits(self):
+ """Tests the predictions work for categorical splits."""
+ with self.cached_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ categorical_split {
+ feature_id: 1
+ value: 2
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ categorical_split {
+ feature_id: 0
+ value: 13
+ left_id: 3
+ right_id: 4
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [13, 1, 3]
+ feature_1_values = [2, 2, 1]
+
+ expected_logits = [[5.], [6.], [7.]]
+
+ # Prediction should work fine.
+ predict_op = boosted_trees_ops.predict(
+ tree_ensemble_handle,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits = session.run(predict_op)
+ self.assertAllClose(expected_logits, logits)
+
class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
"""Tests feature contribs ops for model understanding."""
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index fc4d2a3809..d91a848e01 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import math
import time
-import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -63,6 +62,7 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops import while_v2 # pylint: disable=unused-import
# pylint: disable=unused-import
import tensorflow.python.ops.tensor_array_grad
# pylint: enable=unused-import
@@ -125,12 +125,12 @@ def isum(s, maximum_iterations=None):
return r_s
-@test_util.with_cond_v2
+@test_util.with_control_flow_v2
class ControlFlowTest(test.TestCase):
def testRefIdentity(self):
with self.cached_session():
- v = variables.Variable(7)
+ v = variables.VariableV1(7)
v = control_flow_ops._Identity(v)
op = state_ops.assign(v, 9)
@@ -142,7 +142,7 @@ class ControlFlowTest(test.TestCase):
def testRefEnter(self):
with self.cached_session():
- v = variables.Variable(7)
+ v = variables.VariableV1(7)
enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
nine = constant_op.constant(9)
@@ -155,7 +155,7 @@ class ControlFlowTest(test.TestCase):
def testRefSwitch(self):
with self.cached_session():
- v = variables.Variable(7)
+ v = variables.VariableV1(7)
p = constant_op.constant(True)
v1 = control_flow_ops._SwitchRefOrTensor(v._ref(), p) # pylint: disable=protected-access
@@ -332,10 +332,8 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesOpError("has inputs from different frames"):
res.eval(feed_dict={data: 1.0})
+ @test_util.disable_control_flow_v2("b/113294340")
def testCondBool(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296297")
-
values = constant_op.constant(10)
fn1 = lambda: math_ops.add(values, 1)
fn2 = lambda: math_ops.subtract(values, 1)
@@ -366,6 +364,7 @@ class ControlFlowTest(test.TestCase):
"has been marked as not fetchable"):
sess.run(t, feed_dict={x: 3})
+ @test_util.disable_control_flow_v2("Not relevant")
def testFeedable(self):
with self.cached_session() as sess:
c = constant_op.constant(2)
@@ -383,10 +382,8 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "may not be fed"):
sess.run(r, feed_dict={t: 3})
+ @test_util.disable_control_flow_v2("b/113296180 (IndexedSlices)")
def testCondIndexedSlices(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296180")
-
with self.cached_session():
values = constant_op.constant(10)
indices = constant_op.constant(0)
@@ -401,10 +398,8 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(11, val)
self.assertAllEqual(0, ind)
+ @test_util.disable_control_flow_v2("b/113296161 (SparseTensors)")
def testCondSparseTensor(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296161 (SparseTensors)")
-
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant(
@@ -435,10 +430,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
+ @test_util.disable_control_flow_v2("b/113293074")
def testCondIndexedSlicesDifferentTypes(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113293074")
-
with self.cached_session():
values = constant_op.constant(10)
i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
@@ -510,10 +503,8 @@ class ControlFlowTest(test.TestCase):
result = r.eval()
self.assertAllEqual(12, result)
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testCond_4(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v1 = variables.Variable(7)
v2 = variables.Variable(7)
@@ -587,10 +578,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)
self.assertAllEqual([2.0], r.eval())
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testCondWithControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/79881896")
-
with self.cached_session():
control_holder = array_ops.placeholder(dtypes.float32, shape=())
a = constant_op.constant(3)
@@ -629,10 +618,9 @@ class ControlFlowTest(test.TestCase):
merged_op = control_flow_ops.merge([assign_v, orig_v])
self.assertAllEqual([1.0], sess.run(merged_op.output))
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCondSwitchIdentity(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
# Make sure the recv identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
pred = constant_op.constant(True)
@@ -646,10 +634,9 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCondRecvIdentity(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
# Make sure the switch identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
with ops.device(test.gpu_device_name()):
@@ -665,10 +652,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
+ @test_util.disable_control_flow_v2("b/113346829 (gpu failure)")
def testCondGrad_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113346829 (gpu failure)")
-
graph = ops.Graph()
with graph.as_default():
x = constant_op.constant(10.0, name="x")
@@ -694,10 +679,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
+ @test_util.disable_control_flow_v2(
+ "b/110550782 (gradient w.r.t external variable)")
def testCondGrad_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/110550782 (gradient w.r.t external variable)")
-
with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
ox = constant_op.constant(10.0)
@@ -729,10 +713,8 @@ class ControlFlowTest(test.TestCase):
result = gradients_impl.gradients(z, x)[0]
self.assertEqual(1.0, result.eval())
+ @test_util.disable_control_flow_v2("b/113327884")
def testCondGrad_Gather(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113327884")
-
with self.cached_session() as sess:
v1 = variables.Variable([1.0, 42.0])
c = array_ops.placeholder(dtypes.int32, shape=[])
@@ -756,6 +738,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(dense_gv, [0.0, 2.0])
# Microbenchmark: 256,000 iterations/s.
+ @test_util.disable_control_flow_v2("b/116630618 (Times out)")
def testWhile_1(self):
with self.cached_session():
n = constant_op.constant(0)
@@ -764,6 +747,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testWhileExternalControlDependencies(self):
with self.cached_session():
v = variables.Variable(0.0)
@@ -779,6 +763,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(result.eval(), 2)
self.assertAllEqual(v.eval(), 1.0)
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testWhileExternalControlDependenciesNoInput(self):
with self.cached_session():
v = variables.Variable(0.0)
@@ -794,9 +779,10 @@ class ControlFlowTest(test.TestCase):
result.eval()
self.assertAllEqual(v.eval(), 1.0)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileWithRefs_1(self):
with self.cached_session() as sess:
- x = variables.Variable(0)._ref() # pylint: disable=protected-access
+ x = variables.VariableV1(0)._ref() # pylint: disable=protected-access
i = constant_op.constant(0)
c = lambda i, x: math_ops.less(i, 100)
@@ -824,18 +810,22 @@ class ControlFlowTest(test.TestCase):
r = isum(s)
self.assertAllEqual(45, r.eval())
+ @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
def testWhileWithMaximumIterations(self):
with self.cached_session():
s = constant_op.constant([1, 2, 3, 4, 5])
r = isum(s, maximum_iterations=3)
self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithMaximumIterationsAndSingleArgument(self):
with self.cached_session():
r = control_flow_ops.while_loop(
lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
self.assertEqual(1, r.eval())
+ @test_util.disable_control_flow_v2(
+ "b/116248044 (nested), b/115920078 (gradients)")
def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
@@ -861,6 +851,7 @@ class ControlFlowTest(test.TestCase):
# Should execute without issue.
self.assertEqual(3, self.evaluate(loop_execute))
+ @test_util.disable_control_flow_v2("b/116248044 (nested while_loop)")
def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
@@ -904,10 +895,8 @@ class ControlFlowTest(test.TestCase):
r"context '.*' \(currently defined in '.*'\)"):
_ = gradients_impl.gradients(loop_with_maxiter, v)
+ @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
-
v = constant_op.constant(1.0)
def create_while_loop():
@@ -939,6 +928,8 @@ class ControlFlowTest(test.TestCase):
r"while loop context '' \(currently defined in 'cond/.+'\)"):
_ = gradients_impl.gradients(loop, v)
+ @test_util.disable_control_flow_v2(
+ "b/116248044 (nesting), b/115776323 (max_iters)")
def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
v = constant_op.constant(1.0)
@@ -1048,6 +1039,7 @@ class ControlFlowTest(test.TestCase):
result = r[3].eval()
self.assertAllEqual(42, result)
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhile_5(self):
with self.cached_session():
@@ -1072,6 +1064,7 @@ class ControlFlowTest(test.TestCase):
result = r[2].eval()
self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
+ @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)")
def testBufferForwarding(self):
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
@@ -1122,6 +1115,7 @@ class ControlFlowTest(test.TestCase):
self._testWhile_Gpu_1(use_gpu=False)
self._testWhile_Gpu_1(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileShape(self):
with self.cached_session():
i = constant_op.constant(0)
@@ -1139,6 +1133,7 @@ class ControlFlowTest(test.TestCase):
r = r[1] * array_ops.ones([8, 8])
self.assertAllEqual(np.ones((8, 8)), r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithNonTensorInput_Scalar(self):
with self.cached_session():
n = 0
@@ -1147,6 +1142,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithNonTensorInput_Vector(self):
with self.cached_session():
n = np.array([0]) # Note, [0] would not work here; that is a list
@@ -1155,6 +1151,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual([10000], r.eval())
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileShapeInference(self):
with self.cached_session():
i = constant_op.constant(0)
@@ -1169,7 +1166,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(
c, b, [i, m],
[i.get_shape(), tensor_shape.TensorShape([None, 2])])
- self.assertTrue(r[1].get_shape()[0].value is None)
+ self.assertIsNone(r[1].get_shape()[0].value)
self.assertEqual(r[1].get_shape()[1], tensor_shape.Dimension(2))
with self.assertRaisesRegexp(
@@ -1180,6 +1177,7 @@ class ControlFlowTest(test.TestCase):
r"tf.while_loop to specify a less-specific shape."):
r = control_flow_ops.while_loop(c, b, [i, m])
+ @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
def testWhileShapeInferenceSparseTensor(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -1211,6 +1209,7 @@ class ControlFlowTest(test.TestCase):
c, b, [i, x],
[i.get_shape(), tensor_shape.TensorShape([5])])
+ @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
def testWhileShapeInferenceIndexedSlices(self):
with self.cached_session():
values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
@@ -1265,6 +1264,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n])
self.assertEqual(225, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhile_1(self):
self._testNestedWhile_1(use_gpu=False)
self._testNestedWhile_1(use_gpu=True)
@@ -1297,6 +1297,7 @@ class ControlFlowTest(test.TestCase):
outer_c, outer_b, [s0], parallel_iterations=1)
self.assertEqual(1048576.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhile_2(self):
self._testNestedWhile_2(use_gpu=False)
self._testNestedWhile_2(use_gpu=True)
@@ -1350,6 +1351,7 @@ class ControlFlowTest(test.TestCase):
lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0])
self.assertEqual(10, sess.run(r, {b: True}))
+ @test_util.disable_control_flow_v2("b/79881896 (control_deps)")
def testWhileWithControl_5(self):
with self.cached_session() as sess:
b = array_ops.placeholder(dtypes.bool)
@@ -1364,9 +1366,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, sess.run(r, {b: True}))
def testWhileCondWithControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
-
# Ensure that no control edges by an outer control dependency context are
# added to nodes inside cond/while contexts.
with self.cached_session() as sess:
@@ -1380,10 +1379,8 @@ class ControlFlowTest(test.TestCase):
(constant_op.constant(5),))
self.assertEqual(0, sess.run(loop))
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testWhileCondWithControl_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v = variable_scope.get_variable(
"v", [], initializer=init_ops.constant_initializer(2))
@@ -1405,9 +1402,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(4, r.eval())
self.assertAllClose(65536.0, v.eval())
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testWhileCondExitControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
v = variables.Variable(1)
@@ -1432,8 +1428,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(99, v.eval())
def testCondWhile_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1445,8 +1439,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testCondWhile_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -1458,9 +1450,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def _testCondWhile_3(self, use_gpu):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
-
with self.test_session(use_gpu=use_gpu) as sess:
p = array_ops.placeholder(dtypes.bool)
n = constant_op.constant(0.0)
@@ -1477,18 +1466,17 @@ class ControlFlowTest(test.TestCase):
lambda: control_flow_ops.while_loop(c, b, [n]),
lambda: math_ops.multiply(n, 2.0))
r1 = gradients_impl.gradients(r, [n])
- self.assertEqual(10, sess.run(r, {p: True}))
+ self.assertEqual(10., sess.run(r, {p: True}))
self.assertEqual([1.0], sess.run(r1, {p: True}))
self.assertEqual(0.0, sess.run(r, {p: False}))
self.assertEqual([2.0], sess.run(r1, {p: False}))
+ @test_util.disable_control_flow_v2("b/116743589")
def testCondWhile_3(self):
self._testCondWhile_3(use_gpu=False)
self._testCondWhile_3(use_gpu=True)
def testWhileCond_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
i = ops.convert_to_tensor(0, name="i")
@@ -1505,8 +1493,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1516,8 +1502,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -1532,6 +1516,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
# NOTE: It is ok to have parallel_iterations > 1
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_1(self):
with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
@@ -1554,6 +1539,7 @@ class ControlFlowTest(test.TestCase):
result = select.eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_2(self):
with self.cached_session():
select1 = variables.Variable([3.0, 4.0, 5.0])
@@ -1580,6 +1566,7 @@ class ControlFlowTest(test.TestCase):
result2 = select2.eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_3(self):
with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
@@ -1601,7 +1588,7 @@ class ControlFlowTest(test.TestCase):
result = r[1].eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
- # b/24814703
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_4(self):
with self.cached_session():
var_a = variables.Variable(0, name="a")
@@ -1629,7 +1616,7 @@ class ControlFlowTest(test.TestCase):
lpa.eval() # Run the loop
self.assertEqual(10, var_b.eval())
- # b/24736492
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_5(self):
with self.cached_session():
# Create some variables.
@@ -1659,7 +1646,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, var_a.eval())
self.assertEqual(10, var_b.eval())
- # b/24814668
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_6(self):
with self.cached_session():
# Create some variables.
@@ -1689,6 +1676,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(55, var_b.eval())
self.assertEqual(10, var_a.eval())
+ @test_util.disable_control_flow_v2("b/116742472 (resource accumulator)")
def testWhileQueue_1(self):
with self.cached_session():
q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
@@ -1707,6 +1695,7 @@ class ControlFlowTest(test.TestCase):
for i in xrange(10):
self.assertEqual([i], q.dequeue().eval())
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileStack_1(self):
with self.cached_session():
s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
@@ -1775,6 +1764,7 @@ class ControlFlowTest(test.TestCase):
with self.session(graph=graph) as sess:
self.assertAllClose(1024.0, sess.run(r))
+ @test_util.disable_control_flow_v2("b/116351701 (colocation)")
def testWhileGrad_ColocateGradients(self):
self._testWhileGrad_ColocateGradients(colocate=False)
self._testWhileGrad_ColocateGradients(colocate=True)
@@ -1790,6 +1780,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(1024.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileGrad_Shape(self):
with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=[None])
@@ -1861,8 +1852,6 @@ class ControlFlowTest(test.TestCase):
self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
def _testNestedWhileCondWhileGrad(self, use_gpu):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.test_session(use_gpu=use_gpu):
v = constant_op.constant(1.0)
@@ -1885,10 +1874,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhileCondWhileGrad(self):
self._testNestedWhileCondWhileGrad(use_gpu=False)
self._testNestedWhileCondWhileGrad(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116823782")
def testWhileGrad_Variable(self):
with self.cached_session():
a = variables.Variable(3.0)
@@ -1902,8 +1893,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(216.0, r[0].eval())
def testWhileGradInCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/110550782 (gradient w.r.t external variable)")
with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
@@ -1919,6 +1908,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ @test_util.disable_control_flow_v2("b/116340060")
def testGradInWhileWrtInitialLoopVal(self):
with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
@@ -1936,6 +1926,7 @@ class ControlFlowTest(test.TestCase):
"loop invariants or wrt the input parameters to the loop body."):
control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testWhileGradInWhile(self):
with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
@@ -1952,9 +1943,8 @@ class ControlFlowTest(test.TestCase):
[tensor_shape.unknown_shape()])
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testCondGradInNestedWhiles(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113346829 (gpu failure)")
def outer_body(i, x):
_, x = control_flow_ops.while_loop(
@@ -1972,6 +1962,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(i_val, 3)
self.assertAllClose(x_val, 1.0)
+ @test_util.disable_control_flow_v2("b/116255781 (flat_args)")
def testWhile_NestedInput(self):
with self.cached_session() as sess:
named = collections.namedtuple("named", ("a", "b"))
@@ -1999,6 +1990,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0],
sess.run(r_flattened))
+ @test_util.disable_control_flow_v2("b/116255781(flat_args)")
def testWhile_NestedBadArityFails(self):
with self.cached_session():
named = collections.namedtuple("named", ("a", "b"))
@@ -2057,6 +2049,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients([rx], x)
self.assertAllClose(1024.0, r[0].eval())
+ @test_util.disable_control_flow_v2("b/116355153 (back_prop flag)")
def testWhileGrad_NoGradient(self):
with self.cached_session():
v = constant_op.constant(2.0, name="v")
@@ -2067,6 +2060,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)
self.assertAllClose(1.0, r[0].eval())
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGrad_NoDependency(self):
with self.cached_session() as sess:
variable = variables.Variable(array_ops.ones([2, 3]))
@@ -2180,10 +2174,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(8.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_Simple(self):
self._testNestedWhileGrad_Simple(use_gpu=False)
self._testNestedWhileGrad_Simple(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_SerialInner(self):
with self.cached_session():
v = constant_op.constant(1.0)
@@ -2207,6 +2203,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(256.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_ParallelInner(self):
with self.cached_session():
v = constant_op.constant(1.0)
@@ -2230,6 +2227,8 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
+ @test_util.disable_control_flow_v2(
+ "Nested loops and TensorArrays not supported")
def testNestedWhileGrad_ParallelIterations(self):
# Make sure the stack pushes and pops of an inner loop are executed in
# the sequential order of the iterations of its outer loop.
@@ -2268,13 +2267,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(1024.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116272044 (cond_in_while)")
def testWhileCondGrad_Simple(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
-
self._testWhileCondGrad_Simple(use_gpu=False)
self._testWhileCondGrad_Simple(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116272044 (cond_in_while)")
def testWhileCondGrad_UnknownShape(self):
with self.cached_session() as sess:
v = array_ops.placeholder(dtypes.float32)
@@ -2292,6 +2290,7 @@ class ControlFlowTest(test.TestCase):
r = sess.run(r, feed_dict={v: 2.0})
self.assertAllClose(1024.0, r)
+ @test_util.disable_control_flow_v2("b/116283162 (shape_invariants)")
def testWhileGrad_Concat(self):
with self.cached_session() as sess:
x = variable_scope.get_variable("x", initializer=[[1., 2.]])
@@ -2315,9 +2314,10 @@ class ControlFlowTest(test.TestCase):
sess.run(op)
self.assertAllClose([[0.98000002, 1.98000002]], sess.run(x))
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileWithRefsWithGradients_1(self):
with self.cached_session() as sess:
- x = variables.Variable(0.)._ref() # pylint: disable=protected-access
+ x = variables.VariableV1(0.)._ref() # pylint: disable=protected-access
i = constant_op.constant(0)
c = lambda i, x: math_ops.less(i, 10)
@@ -2329,7 +2329,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5)
- grad_ys = [variables.Variable(73)._ref()] # pylint: disable=protected-access
+ grad_ys = [variables.VariableV1(73)._ref()] # pylint: disable=protected-access
grad = gradients_impl.gradients([r[1]], [x], grad_ys=grad_ys)
variables.global_variables_initializer().run()
@@ -2343,6 +2343,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, value_x)
self.assertEqual(73, value_x_grad)
+ @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
def testWhileGrad_IndexedSlices(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -2364,6 +2365,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r.values, values)[0]
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+ @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
def testWhileGrad_SparseTensor(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -2386,6 +2388,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r.values, values)[0]
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+ @test_util.disable_control_flow_v2("b/115920078 (gradients)")
def testCallGradInLoop(self):
with self.cached_session() as sess:
i0 = constant_op.constant(0)
@@ -2405,6 +2408,8 @@ class ControlFlowTest(test.TestCase):
c, b, [i0, constant_op.constant(0.0)])
self.assertAllClose(600.0, sess.run(output_grad)[1])
+ @test_util.disable_control_flow_v2(
+ "b/116255781 (flat_args), b/115660901 (TensorArray)")
def testWhileAndTensorArray(self):
with self.cached_session() as sess:
param = constant_op.constant(2.0)
@@ -2509,6 +2514,7 @@ class ControlFlowTest(test.TestCase):
all_ops = x.graph.get_operations()
self.assertFalse(any([name in op.name for op in all_ops]))
+ @test_util.disable_control_flow_v2("b/116255781 (flat args)")
def testWhileGradGradFail(self):
theta = variables.Variable(initial_value=1.)
@@ -2538,6 +2544,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, y)[0]
self.assertEqual(388.0, r.eval())
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGradientWithNontrainablePath1(self):
q = variables.Variable([7., 8.])
@@ -2555,6 +2562,7 @@ class ControlFlowTest(test.TestCase):
sess.run(q.initializer)
self.assertAllClose([0., 0.], sess.run(dy_dq))
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGradientWithNontrainablePath2(self):
q = variables.Variable([7., 8.])
@@ -2572,6 +2580,7 @@ class ControlFlowTest(test.TestCase):
sess.run(q.initializer)
self.assertAllClose([1., 1.], sess.run(dy_dq))
+ @test_util.disable_control_flow_v2("b/115920078 (gradients)")
def testIssue16504(self):
c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
w = variables.Variable(
@@ -2595,6 +2604,7 @@ class ControlFlowTest(test.TestCase):
grad, = gradients_impl.gradients(w, c)
self.assertIsNotNone(grad)
+ @test_util.disable_control_flow_v2("b/116270461 (resource)")
def testStopGradMultiFlows(self):
with self.cached_session():
@@ -2653,10 +2663,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCase(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
with self.cached_session():
x = constant_op.constant(1)
y = constant_op.constant(2)
@@ -2708,10 +2717,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r6.eval(), 0)
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCaseSideEffects(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
with self.cached_session() as sess:
v0 = variables.Variable(-1)
v1 = variables.Variable(-1)
@@ -2746,10 +2754,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, r0.eval())
self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1])
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testOneOpCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v = variables.Variable(0)
c = ops.convert_to_tensor(0)
@@ -2779,7 +2785,7 @@ class ControlFlowTest(test.TestCase):
def testWithOpsDependencies(self):
with self.cached_session() as sess:
- v = variables.Variable(0.0)
+ v = variables.VariableV1(0.0)
c = constant_op.constant(10)
# Fetching v directly will result in an uninitialized error
@@ -2802,7 +2808,7 @@ class ControlFlowTest(test.TestCase):
def testWithTensorDependencies(self):
with self.cached_session():
- v = variables.Variable(0.0)
+ v = variables.VariableV1(0.0)
c1 = constant_op.constant(10)
c2 = constant_op.constant(20)
@@ -2828,7 +2834,7 @@ class ControlFlowTest(test.TestCase):
def testWithIndexedSlicesDependencies(self):
with self.cached_session():
- v = variables.Variable(
+ v = variables.VariableV1(
np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
v_at_1 = ops.IndexedSlices(v, constant_op.constant([1]))
gather_v_at_1 = array_ops.gather(v_at_1.values, v_at_1.indices)
@@ -2851,18 +2857,18 @@ class ControlFlowTest(test.TestCase):
with ops.Graph().as_default():
# device set on tensor => same device on dep.
with ops.device("/job:ps"):
- vd = variables.Variable([0.0])
+ vd = variables.VariableV1([0.0])
with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd)
self.assertTrue("/job:ps" in with_vd_dep.device)
# No device set on tensor => no device on dep.
- vnod = variables.Variable([0.0])
+ vnod = variables.VariableV1([0.0])
with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer],
vnod)
self.assertDeviceEqual(None, with_vnod_dep.device)
# device set on tensor, default device on graph => default device on dep.
- vdef = variables.Variable([0.0], name="vdef")
+ vdef = variables.VariableV1([0.0], name="vdef")
with ops.device("/job:worker/device:GPU:1"):
with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer],
vdef)
@@ -2872,8 +2878,8 @@ class ControlFlowTest(test.TestCase):
def testGroup(self):
with self.cached_session() as sess:
- v1 = variables.Variable([0.0])
- v2 = variables.Variable([1.0])
+ v1 = variables.VariableV1([0.0])
+ v2 = variables.VariableV1([1.0])
# Group init1 and init2 and run.
init = control_flow_ops.group(v1.initializer, v2.initializer)
@@ -2955,29 +2961,29 @@ class ControlFlowTest(test.TestCase):
p1 = array_ops.placeholder(dtypes.float32)
p2 = array_ops.placeholder(dtypes.float32)
p3 = array_ops.placeholder(dtypes.float32)
- v1 = variables.Variable(p1, validate_shape=False)
- v2 = variables.Variable(p2, validate_shape=False)
- v3 = variables.Variable(p3, validate_shape=False)
+ v1 = variables.VariableV1(p1, validate_shape=False)
+ v2 = variables.VariableV1(p2, validate_shape=False)
+ v3 = variables.VariableV1(p3, validate_shape=False)
self.assertIs(None, v1.get_shape().ndims)
s = control_flow_ops.ref_select(index, [v1, v2, v3])
self.assertIs(None, s.get_shape().ndims)
# All inputs known but different.
- v1 = variables.Variable([[1, 2]])
- v2 = variables.Variable([[2], [1]])
+ v1 = variables.VariableV1([[1, 2]])
+ v2 = variables.VariableV1([[2], [1]])
s = control_flow_ops.ref_select(index, [v1, v2])
self.assertIs(None, s.get_shape().ndims)
# All inputs known and same.
- v1 = variables.Variable([[1, 2]])
- v2 = variables.Variable([[1, 2]])
+ v1 = variables.VariableV1([[1, 2]])
+ v2 = variables.VariableV1([[1, 2]])
s = control_flow_ops.ref_select(index, [v1, v2])
self.assertEqual([1, 2], s.get_shape())
# Possibly the same but not guaranteed.
- v1 = variables.Variable([[1., 2.]])
+ v1 = variables.VariableV1([[1., 2.]])
p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
- v2 = variables.Variable(p2, validate_shape=False)
+ v2 = variables.VariableV1(p2, validate_shape=False)
s = control_flow_ops.ref_select(index, [v1, v2])
self.assertEqual(None, s.get_shape())
@@ -3031,9 +3037,11 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, x)[0]
self.assertEqual(r.eval(), 524288.0)
- self.assertEqual(
- len([op for op in x.graph.get_operations() if op.type == "StackV2"]),
- 1)
+ # while_v2 does not have stacks.
+ if not control_flow_ops.ENABLE_WHILE_V2:
+ self.assertEqual(
+ len([op for op in x.graph.get_operations() if op.type == "StackV2"
+ ]), 1)
class ControlFlowContextCheckTest(test.TestCase):
@@ -3160,11 +3168,11 @@ class TupleTest(test.TestCase):
def testTensors(self):
for v1_first in [True, False]:
with self.cached_session():
- v1 = variables.Variable([1.0])
+ v1 = variables.VariableV1([1.0])
add1 = math_ops.add(
control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access
2.0)
- v2 = variables.Variable([10.0])
+ v2 = variables.VariableV1([10.0])
add2 = math_ops.add(
control_flow_ops.with_dependencies([v2.initializer], v2._ref()), # pylint: disable=protected-access
20.0)
@@ -3190,14 +3198,14 @@ class TupleTest(test.TestCase):
def testIndexedSlices(self):
for v1_first in [True, False]:
with self.cached_session():
- v1 = variables.Variable(
+ v1 = variables.VariableV1(
np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
np.float32))
v1_at_1 = ops.IndexedSlices(
control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access
constant_op.constant([1]))
- v2 = variables.Variable(
+ v2 = variables.VariableV1(
np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype(
np.float32))
v2_at_1 = ops.IndexedSlices(
@@ -3229,7 +3237,7 @@ class TupleTest(test.TestCase):
def testAcceptTensorsAsControlInputs(self):
with self.cached_session():
- var = variables.Variable(0)
+ var = variables.VariableV1(0)
assign = state_ops.assign(var, 1)
t, = control_flow_ops.tuple(
[constant_op.constant(0)], control_inputs=[assign])
@@ -3393,7 +3401,7 @@ class WhileOpBenchmark(test.Benchmark):
name="unroll_same_device", iters=iters, wall_time=duration)
-@test_util.with_cond_v2
+@test_util.with_control_flow_v2
class EagerTest(test.TestCase):
def testCond(self):
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_test.py b/tensorflow/python/kernel_tests/dense_update_ops_test.py
index 06c3271850..120e10314f 100644
--- a/tensorflow/python/kernel_tests/dense_update_ops_test.py
+++ b/tensorflow/python/kernel_tests/dense_update_ops_test.py
@@ -87,7 +87,7 @@ class AssignOpTest(test.TestCase):
def testAssignNonStrictShapeChecking(self):
with self.cached_session():
data = array_ops.fill([1024, 1024], 0)
- p = variables.Variable([1])
+ p = variables.VariableV1([1])
a = state_ops.assign(p, data, validate_shape=False)
a.op.run()
self.assertAllEqual(p.eval(), data.eval())
@@ -100,14 +100,14 @@ class AssignOpTest(test.TestCase):
def testInitRequiredAssignAdd(self):
with self.cached_session():
- p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
+ p = variables.VariableV1(array_ops.fill([1024, 1024], 1), dtypes.int32)
a = state_ops.assign_add(p, array_ops.fill([1024, 1024], 0))
with self.assertRaisesOpError("use uninitialized"):
a.op.run()
def testInitRequiredAssignSub(self):
with self.cached_session():
- p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
+ p = variables.VariableV1(array_ops.fill([1024, 1024], 1), dtypes.int32)
a = state_ops.assign_sub(p, array_ops.fill([1024, 1024], 0))
with self.assertRaisesOpError("use uninitialized"):
a.op.run()
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 200da772e5..6d1ead20be 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -191,7 +191,7 @@ class DepthwiseConv2DTest(test.TestCase):
tf_logging.info(
"Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
"%s", index, input_size, filter_size, stride, padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
tf_logging.info("Testing without grouped_conv")
self._VerifyValues(
input_size, filter_size, stride, padding, data_type, use_gpu=True)
@@ -227,7 +227,7 @@ class DepthwiseConv2DTest(test.TestCase):
tf_logging.info(
"Testing DepthwiseConv2DFormat, %dth config: %r * %r, stride: %d, "
"padding: %s", index, input_size, filter_size, stride, padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._VerifyValues(
input_size,
filter_size,
@@ -434,7 +434,7 @@ class DepthwiseConv2DTest(test.TestCase):
tf_logging.info(
"Testing DepthwiseConv2DInputGrad, %dth config: %r * %r, stride: %d, "
"padding: %s", index, input_size, filter_size, stride, padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -465,7 +465,7 @@ class DepthwiseConv2DTest(test.TestCase):
"Testing DepthwiseConv2DInputGradFormat, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -483,7 +483,7 @@ class DepthwiseConv2DTest(test.TestCase):
tf_logging.info(
"Testing DepthwiseConv2DFilterGrad, %dth config: %r * %r, stride: "
"%d, padding: %s", index, input_size, filter_size, stride, padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -504,7 +504,7 @@ class DepthwiseConv2DTest(test.TestCase):
"Testing DepthwiseConv2DFilterGradFormat, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
index 26d013bccb..37b35ba51a 100644
--- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
@@ -118,7 +118,9 @@ class BernoulliTest(test.TestCase):
self.assertEqual(dist.probs.dtype, dist.stddev().dtype)
self.assertEqual(dist.probs.dtype, dist.entropy().dtype)
self.assertEqual(dist.probs.dtype, dist.prob(0).dtype)
+ self.assertEqual(dist.probs.dtype, dist.prob(0.5).dtype)
self.assertEqual(dist.probs.dtype, dist.log_prob(0).dtype)
+ self.assertEqual(dist.probs.dtype, dist.log_prob(0.5).dtype)
dist64 = make_bernoulli([], dtypes.int64)
self.assertEqual(dist64.dtype, dtypes.int64)
@@ -181,6 +183,16 @@ class BernoulliTest(test.TestCase):
return
self._testPmf(logits=special.logit(p))
+ @test_util.run_in_graph_and_eager_modes
+ def testPmfWithFloatArgReturnsXEntropy(self):
+ p = [[0.2], [0.4], [0.3], [0.6]]
+ samps = [0, 0.1, 0.8]
+ self.assertAllClose(
+ np.float32(samps) * np.log(np.float32(p)) +
+ (1 - np.float32(samps)) * np.log(1 - np.float32(p)),
+ self.evaluate(
+ bernoulli.Bernoulli(probs=p, validate_args=False).log_prob(samps)))
+
def testBroadcasting(self):
with self.cached_session():
p = array_ops.placeholder(dtypes.float32)
diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py
index de73a40b23..6625a88843 100644
--- a/tensorflow/python/kernel_tests/distributions/normal_test.py
+++ b/tensorflow/python/kernel_tests/distributions/normal_test.py
@@ -78,6 +78,14 @@ class NormalTest(test.TestCase):
self.assertEqual(expected, sigma_shape)
@test_util.run_in_graph_and_eager_modes
+ def testSampleLikeArgsGetDistDType(self):
+ dist = normal_lib.Normal(0., 1.)
+ self.assertEqual(dtypes.float32, dist.dtype)
+ for method in ("log_prob", "prob", "log_cdf", "cdf",
+ "log_survival_function", "survival_function", "quantile"):
+ self.assertEqual(dtypes.float32, getattr(dist, method)(1).dtype)
+
+ @test_util.run_in_graph_and_eager_modes
def testParamShapes(self):
sample_shape = [10, 3, 4]
self._testParamShapes(sample_shape, sample_shape)
diff --git a/tensorflow/python/kernel_tests/identity_op_py_test.py b/tensorflow/python/kernel_tests/identity_op_py_test.py
index 37f9f716f8..88ea10c22a 100644
--- a/tensorflow/python/kernel_tests/identity_op_py_test.py
+++ b/tensorflow/python/kernel_tests/identity_op_py_test.py
@@ -61,7 +61,7 @@ class IdentityOpTest(test.TestCase):
def testRefIdentityShape(self):
with self.cached_session():
shape = [2, 3]
- tensor = variables.Variable(
+ tensor = variables.VariableV1(
constant_op.constant(
[[1, 2, 3], [6, 5, 4]], dtype=dtypes.int32))
self.assertEquals(shape, tensor.get_shape())
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
index 7261d4bb3b..f1e151ebd8 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
@@ -37,8 +37,10 @@ class LinearOperatorCirculantBaseTest(object):
"""Common class for circulant tests."""
@contextlib.contextmanager
- def test_session(self, *args, **kwargs):
- with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """We overwrite the FFT operation mapping for testing."""
+ with test.TestCase._constrain_devices_and_set_default(
+ self, sess, use_gpu, force_gpu) as sess:
with spectral_ops_test_util.fft_kernel_label_map():
yield sess
@@ -110,8 +112,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant(
lin_op_spectrum, is_self_adjoint=True, input_output_dtype=dtype)
@@ -121,7 +122,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -171,8 +172,7 @@ class LinearOperatorCirculantTestHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant(
lin_op_spectrum, input_output_dtype=dtype)
@@ -182,7 +182,7 @@ class LinearOperatorCirculantTestHermitianSpectrum(
return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -217,8 +217,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant(
lin_op_spectrum, input_output_dtype=dtype)
@@ -228,7 +227,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -238,7 +237,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
np.testing.assert_allclose(0, imag_matrix.eval(), rtol=0, atol=eps * 3)
def test_simple_positive_real_spectrum_gives_self_adjoint_pos_def_oper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
spectrum = math_ops.cast([6., 4, 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -250,7 +249,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
operator.assert_self_adjoint().run() # Should not fail
def test_defining_operator_using_real_convolution_kernel(self):
- with self.test_session():
+ with self.cached_session():
convolution_kernel = [1., 2., 1.]
spectrum = math_ops.fft(
math_ops.cast(convolution_kernel, dtypes.complex64))
@@ -266,7 +265,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)
def test_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
# Make spectrum the FFT of a real convolution kernel h. This ensures that
# spectrum is Hermitian.
h = linear_operator_test_util.random_normal(shape=(3, 4))
@@ -281,7 +280,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
def test_convolution_kernel_same_as_first_row_of_to_dense(self):
spectrum = [[3., 2., 1.], [2., 1.5, 1.]]
- with self.test_session():
+ with self.cached_session():
operator = linalg.LinearOperatorCirculant(spectrum)
h = operator.convolution_kernel()
c = operator.to_dense()
@@ -293,27 +292,27 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
def test_assert_non_singular_fails_for_singular_operator(self):
spectrum = math_ops.cast([0, 4, 2j + 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Singular operator"):
operator.assert_non_singular().run()
def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
spectrum = math_ops.cast([-3j, 4, 2j + 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_non_singular().run() # Should not fail
def test_assert_positive_definite_fails_for_non_positive_definite(self):
spectrum = math_ops.cast([6., 4, 2j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Not positive definite"):
operator.assert_positive_definite().run()
def test_assert_positive_definite_does_not_fail_when_pos_def(self):
spectrum = math_ops.cast([6., 4, 2j + 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_positive_definite().run() # Should not fail
def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
@@ -331,8 +330,10 @@ class LinearOperatorCirculant2DBaseTest(object):
"""Common class for 2D circulant tests."""
@contextlib.contextmanager
- def test_session(self, *args, **kwargs):
- with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """We overwrite the FFT operation mapping for testing."""
+ with test.TestCase._constrain_devices_and_set_default(
+ self, sess, use_gpu, force_gpu) as sess:
with spectral_ops_test_util.fft_kernel_label_map():
yield sess
@@ -446,8 +447,7 @@ class LinearOperatorCirculant2DTestHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant2D(
lin_op_spectrum, input_output_dtype=dtype)
@@ -482,8 +482,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant2D(
lin_op_spectrum, input_output_dtype=dtype)
@@ -493,7 +492,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
return operator, mat
def test_real_hermitian_spectrum_gives_real_symmetric_operator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# This is a real and hermitian spectrum.
spectrum = [[1., 2., 2.], [3., 4., 4.], [3., 4., 4.]]
operator = linalg.LinearOperatorCirculant(spectrum)
@@ -510,7 +509,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
self.assertAllClose(matrix, matrix_transpose, atol=0)
def test_real_spectrum_gives_self_adjoint_operator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# This is a real and hermitian spectrum.
spectrum = linear_operator_test_util.random_normal(
shape=(3, 3), dtype=dtypes.float32)
@@ -526,27 +525,27 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
def test_assert_non_singular_fails_for_singular_operator(self):
spectrum = math_ops.cast([[0, 4], [2j + 2, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Singular operator"):
operator.assert_non_singular().run()
def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
spectrum = math_ops.cast([[-3j, 4], [2j + 2, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_non_singular().run() # Should not fail
def test_assert_positive_definite_fails_for_non_positive_definite(self):
spectrum = math_ops.cast([[6., 4], [2j, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Not positive definite"):
operator.assert_positive_definite().run()
def test_assert_positive_definite_does_not_fail_when_pos_def(self):
spectrum = math_ops.cast([[6., 4], [2j + 2, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_positive_definite().run() # Should not fail
def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
@@ -574,13 +573,15 @@ class LinearOperatorCirculant3DTest(test.TestCase):
"""Simple test of the 3D case. See also the 1D and 2D tests."""
@contextlib.contextmanager
- def test_session(self, *args, **kwargs):
- with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """We overwrite the FFT operation mapping for testing."""
+ with test.TestCase._constrain_devices_and_set_default(
+ self, sess, use_gpu, force_gpu) as sess:
with spectral_ops_test_util.fft_kernel_label_map():
yield sess
def test_real_spectrum_gives_self_adjoint_operator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# This is a real and hermitian spectrum.
spectrum = linear_operator_test_util.random_normal(
shape=(2, 2, 3, 5), dtype=dtypes.float32)
@@ -597,7 +598,7 @@ class LinearOperatorCirculant3DTest(test.TestCase):
self.assertAllClose(matrix, matrix_h)
def test_defining_operator_using_real_convolution_kernel(self):
- with self.test_session():
+ with self.cached_session():
convolution_kernel = linear_operator_test_util.random_normal(
shape=(2, 2, 3, 5), dtype=dtypes.float32)
# Convolution kernel is real ==> spectrum is Hermitian.
@@ -615,7 +616,7 @@ class LinearOperatorCirculant3DTest(test.TestCase):
np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)
def test_defining_spd_operator_by_taking_real_part(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# S is real and positive.
s = linear_operator_test_util.random_uniform(
shape=(10, 2, 3, 4), dtype=dtypes.float32, minval=1., maxval=2.)
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 0f5607712b..ae413edaec 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -170,6 +170,32 @@ class ListOpsTest(test_util.TensorFlowTestCase):
list_ops.tensor_list_pop_back(
l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
+ @test_util.run_in_graph_and_eager_modes
+ def testCPUGPUCopyNested(self):
+ if not context.num_gpus():
+ return
+ t = constant_op.constant([1.0, 2.0])
+ child_l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape())
+ l = list_ops.empty_tensor_list(
+ element_shape=constant_op.constant([], dtype=dtypes.int32),
+ element_dtype=dtypes.variant)
+ l = list_ops.tensor_list_push_back(l, child_l)
+ with context.device("gpu:0"):
+ l_gpu = array_ops.identity(l)
+ _, child_l_gpu = list_ops.tensor_list_pop_back(
+ l_gpu, element_dtype=dtypes.variant)
+ self.assertAllEqual(
+ self.evaluate(
+ list_ops.tensor_list_pop_back(
+ child_l_gpu, element_dtype=dtypes.float32)[1]), 2.0)
+ l_cpu = array_ops.identity(l_gpu)
+ _, child_l_cpu = list_ops.tensor_list_pop_back(
+ l_cpu, element_dtype=dtypes.variant)
+ self.assertAllEqual(
+ self.evaluate(
+ list_ops.tensor_list_pop_back(
+ child_l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
+
def testGraphStack(self):
with self.cached_session():
tl = list_ops.empty_tensor_list(
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index f90545f84c..1365d4b240 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -290,7 +290,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(self.evaluate(read), [[2]])
def testUseResource(self):
- v = variables.Variable(1.0, use_resource=True)
+ v = variables.VariableV1(1.0, use_resource=True)
self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable))
def testEagerNoUseResource(self):
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index 86e063cb36..4b92309e4d 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -136,7 +136,7 @@ class StatefulScatterNdTest(test.TestCase):
new = ref.copy()
np_scatter(new, indices, updates)
# Scatter via tensorflow
- ref_var = variables.Variable(ref)
+ ref_var = variables.VariableV1(ref)
ref_var.initializer.run()
tf_scatter(ref_var, indices, updates).eval()
@@ -258,7 +258,7 @@ class StatefulScatterNdTest(test.TestCase):
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
updates = np.array([-3, -4, -5]).astype(np.float32)
with self.test_session(use_gpu=False):
- ref = variables.Variable(params)
+ ref = variables.VariableV1(params)
ref.initializer.run()
# Indices all in range, no problem.
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
index 1a0fa744ae..527b7daf10 100644
--- a/tensorflow/python/kernel_tests/scatter_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -178,7 +178,7 @@ class ScatterTest(test.TestCase):
np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
np_scatter(new, indices, updates)
# Scatter via tensorflow
- ref = variables.Variable(old)
+ ref = variables.VariableV1(old)
ref.initializer.run()
tf_scatter(ref, indices, updates).eval()
self.assertAllClose(ref.eval(), new)
@@ -294,7 +294,7 @@ class ScatterTest(test.TestCase):
updates = np.array([-3, -4, -5]).astype(np.float32)
if not test.is_gpu_available():
with self.test_session(use_gpu=False):
- ref = variables.Variable(params)
+ ref = variables.VariableV1(params)
ref.initializer.run()
# Indices all in range, no problem.
diff --git a/tensorflow/python/kernel_tests/softplus_op_test.py b/tensorflow/python/kernel_tests/softplus_op_test.py
index afe3df6178..636ed4747e 100644
--- a/tensorflow/python/kernel_tests/softplus_op_test.py
+++ b/tensorflow/python/kernel_tests/softplus_op_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import errors
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
@@ -125,9 +124,9 @@ class SoftplusTest(test.TestCase):
def testNoInts(self):
with self.cached_session():
with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "No OpKernel was registered to support Op 'Softplus'"):
- nn_ops.softplus(constant_op.constant(7)).eval()
+ TypeError,
+ "'features' has DataType int32 not in list of allowed values"):
+ nn_ops.softplus(constant_op.constant(42)).eval()
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/softsign_op_test.py b/tensorflow/python/kernel_tests/softsign_op_test.py
index 05a7c53dee..1b4db9fa46 100644
--- a/tensorflow/python/kernel_tests/softsign_op_test.py
+++ b/tensorflow/python/kernel_tests/softsign_op_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import errors
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
@@ -69,8 +68,8 @@ class SoftsignTest(test.TestCase):
def testNoInts(self):
with self.cached_session():
with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "No OpKernel was registered to support Op 'Softsign'"):
+ TypeError,
+ "'features' has DataType int32 not in list of allowed values"):
nn_ops.softsign(constant_op.constant(7)).eval()
diff --git a/tensorflow/python/kernel_tests/string_length_op_test.py b/tensorflow/python/kernel_tests/string_length_op_test.py
index 9f013c2c7e..4afe3ad3f4 100644
--- a/tensorflow/python/kernel_tests/string_length_op_test.py
+++ b/tensorflow/python/kernel_tests/string_length_op_test.py
@@ -32,6 +32,33 @@ class StringLengthOpTest(test.TestCase):
values = sess.run(lengths)
self.assertAllEqual(values, [[[1, 2], [3, 4], [5, 6]]])
+ def testUnit(self):
+ unicode_strings = [u"H\xc3llo", u"\U0001f604"]
+ utf8_strings = [s.encode("utf-8") for s in unicode_strings]
+ expected_utf8_byte_lengths = [6, 4]
+ expected_utf8_char_lengths = [5, 1]
+
+ with self.test_session() as sess:
+ utf8_byte_lengths = string_ops.string_length(utf8_strings, unit="BYTE")
+ utf8_char_lengths = string_ops.string_length(
+ utf8_strings, unit="UTF8_CHAR")
+ self.assertAllEqual(
+ sess.run(utf8_byte_lengths), expected_utf8_byte_lengths)
+ self.assertAllEqual(
+ sess.run(utf8_char_lengths), expected_utf8_char_lengths)
+ with self.assertRaisesRegexp(
+ ValueError, "Attr 'unit' of 'StringLength' Op passed string 'XYZ' "
+ 'not in: "BYTE", "UTF8_CHAR"'):
+ string_ops.string_length(utf8_strings, unit="XYZ")
+
+ def testLegacyPositionalName(self):
+ # Code that predates the 'unit' parameter may have used a positional
+ # argument for the 'name' parameter. Check that we don't break such code.
+ strings = [[["1", "12"], ["123", "1234"], ["12345", "123456"]]]
+ lengths = string_ops.string_length(strings, "some_name")
+ with self.test_session():
+ self.assertAllEqual(lengths.eval(), [[[1, 2], [3, 4], [5, 6]]])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/unicode_script_op_test.py b/tensorflow/python/kernel_tests/unicode_script_op_test.py
new file mode 100644
index 0000000000..927e5459ed
--- /dev/null
+++ b/tensorflow/python/kernel_tests/unicode_script_op_test.py
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#===============================================================================
+"""Functional tests for UnicodeScript op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class UnicodeScriptOpTest(test.TestCase):
+
+ def testValidScripts(self):
+ inputs = [
+ ord("a"),
+ 0x0411, # CYRILLIC CAPITAL LETTER BE
+ 0x82b8, # CJK UNIFIED IDEOGRAPH-82B8
+ ord(",")
+ ]
+ with self.cached_session():
+ input_vector = constant_op.constant(inputs, dtypes.int32)
+ outputs = string_ops.unicode_script(input_vector).eval()
+ self.assertAllEqual(
+ outputs,
+ [
+ 25, # USCRIPT_LATIN (LATN)
+ 8, # USCRIPT_CYRILLIC (CYRL)
+ 17, # USCRIPT_HAN (HANI)
+ 0 # USCRIPT_COMMON (ZYYY)
+ ])
+
+ def testInvalidScript(self):
+ inputs = [-100, 0xffffff]
+ with self.cached_session():
+ input_vector = constant_op.constant(inputs, dtypes.int32)
+ outputs = string_ops.unicode_script(input_vector).eval()
+ self.assertAllEqual(outputs, [-1, -1])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 401e1ae102..33f464fb90 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -394,10 +394,10 @@ class VariableScopeTest(test.TestCase):
old = variable_scope._DEFAULT_USE_RESOURCE
try:
variable_scope.enable_resource_variables()
- self.assertTrue(isinstance(variables_lib.Variable(1.0),
+ self.assertTrue(isinstance(variables_lib.VariableV1(1.0),
resource_variable_ops.ResourceVariable))
variable_scope.disable_resource_variables()
- self.assertFalse(isinstance(variables_lib.Variable(1.0),
+ self.assertFalse(isinstance(variables_lib.VariableV1(1.0),
resource_variable_ops.ResourceVariable))
finally:
variable_scope._DEFAULT_USE_RESOURCE = old
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 0b101529fe..c2b86089f4 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -43,14 +43,14 @@ class VariablesTestCase(test.TestCase):
def testInitialization(self):
with self.cached_session():
- var0 = variables.Variable(0.0)
+ var0 = variables.VariableV1(0.0)
self.assertEqual("Variable:0", var0.name)
self.assertEqual("Variable", var0._shared_name)
self.assertEqual([], var0.get_shape())
self.assertEqual([], var0.get_shape())
self.assertEqual([], var0.shape)
- var1 = variables.Variable(1.1)
+ var1 = variables.VariableV1(1.1)
self.assertEqual("Variable_1:0", var1.name)
self.assertEqual("Variable_1", var1._shared_name)
self.assertEqual([], var1.get_shape())
@@ -143,7 +143,7 @@ class VariablesTestCase(test.TestCase):
def testZeroSizeStringAssign(self):
with self.cached_session() as sess:
- array = variables.Variable(
+ array = variables.VariableV1(
initial_value=array_ops.zeros((0,), dtype=dtypes.string),
name="foo",
trainable=False,
@@ -192,7 +192,7 @@ class VariablesTestCase(test.TestCase):
# d get the control dep.
d = constant_op.constant(2.0)
# variables do not.
- var_x = variables.Variable(2.0)
+ var_x = variables.VariableV1(2.0)
self.assertEqual([c.op], d.op.control_inputs)
self.assertEqual([], var_x.initializer.control_inputs)
self.assertEqual([], var_x.value().op.control_inputs)
@@ -280,10 +280,10 @@ class VariablesTestCase(test.TestCase):
def testCollections(self):
with self.cached_session():
- var_x = variables.Variable(2.0)
- var_y = variables.Variable(2.0, trainable=False)
- var_z = variables.Variable(2.0, trainable=True)
- var_t = variables.Variable(
+ var_x = variables.VariableV1(2.0)
+ var_y = variables.VariableV1(2.0, trainable=False)
+ var_z = variables.VariableV1(2.0, trainable=True)
+ var_t = variables.VariableV1(
2.0,
trainable=True,
collections=[
@@ -296,9 +296,9 @@ class VariablesTestCase(test.TestCase):
def testCollectionsWithScope(self):
with self.cached_session():
with ops.name_scope("scope_1"):
- var_x = variables.Variable(2.0)
+ var_x = variables.VariableV1(2.0)
with ops.name_scope("scope_2"):
- var_y = variables.Variable(2.0)
+ var_y = variables.VariableV1(2.0)
self.assertEqual([var_x, var_y], variables.global_variables())
self.assertEqual([var_x], variables.global_variables("scope_1"))
@@ -399,7 +399,7 @@ class VariablesTestCase(test.TestCase):
def testColocation(self):
with ops.device("/job:ps"):
- var = variables.Variable(0, name="v")
+ var = variables.VariableV1(0, name="v")
with ops.device("/job:worker/task:7"):
assign_op = var.assign(1)
self.assertDeviceEqual("/job:ps", assign_op.device)
@@ -522,7 +522,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(np.ones((5, 5), np.float32), var.eval())
def testRepr(self):
- var = variables.Variable(np.zeros((5, 5), np.float32), name="noop")
+ var = variables.VariableV1(np.zeros((5, 5), np.float32), name="noop")
self.assertEqual(
"<tf.Variable 'noop:0' shape=(5, 5) dtype=float32_ref>",
repr(var))
@@ -556,8 +556,8 @@ class IsInitializedTest(test.TestCase):
def testVariableList(self):
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([1, 2], name="v")
- w = variables.Variable([3, 4], name="w")
+ v = variables.VariableV1([1, 2], name="v")
+ w = variables.VariableV1([3, 4], name="w")
uninited = variables.report_uninitialized_variables()
self.assertAllEqual(np.array([b"v", b"w"]), sess.run(uninited))
sess.run(w.initializer)
@@ -593,8 +593,8 @@ class ObsoleteIsInitializedTest(test.TestCase):
def testVariables(self):
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([1, 2])
- w = variables.Variable([3, 4])
+ v = variables.VariableV1([1, 2])
+ w = variables.VariableV1([3, 4])
_ = v, w
inited = variables.assert_variables_initialized()
with self.assertRaisesOpError("Attempting to use uninitialized value"):
@@ -604,8 +604,8 @@ class ObsoleteIsInitializedTest(test.TestCase):
def testVariableList(self):
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([1, 2])
- w = variables.Variable([3, 4])
+ v = variables.VariableV1([1, 2])
+ w = variables.VariableV1([3, 4])
inited = variables.assert_variables_initialized([v])
with self.assertRaisesOpError("Attempting to use uninitialized value"):
inited.op.run()
@@ -714,7 +714,7 @@ class PartitionedVariableTest(test.TestCase):
dtype=v0.dtype,
variable_list=[v0, v1],
partitions=partitions)
-
+
deltas_a = constant_op.constant([1.0, 2.0])
deltas_b = constant_op.constant([3.0, 4.0])
ones = array_ops.ones([2])
@@ -727,17 +727,18 @@ class PartitionedVariableTest(test.TestCase):
self.assertEqual([1.0], v0.eval())
self.assertEqual([3.0], plus_delta[1].eval())
self.assertEqual([3.0], v1.eval())
-
+
self.assertEqual([-2.0], minus_delta[0].eval())
self.assertEqual([-2.0], v0.eval())
self.assertEqual([-1.0], minus_delta[1].eval())
self.assertEqual([-1.0], v1.eval())
-
+
self.assertEqual([1.0], assign_ones[0].eval())
self.assertEqual([1.0], v0.eval())
self.assertEqual([1.0], assign_ones[1].eval())
self.assertEqual([1.0], v1.eval())
+
class VariableContainerTest(test.TestCase):
def testContainer(self):
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 3ba880d7a1..e399ece232 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -131,10 +131,20 @@ class Layer(base_layer.Layer):
def add_loss(self, losses, inputs=None):
previous_losses_length = len(self._losses)
+ previous_callable_losses_length = len(self._callable_losses)
super(Layer, self).add_loss(losses, inputs=inputs)
- # TODO(fchollet): deprecate collection below.
- new_losses = self._losses[previous_losses_length:]
- _add_elements_to_collection(new_losses, ops.GraphKeys.REGULARIZATION_LOSSES)
+ if not context.executing_eagerly():
+ # TODO(fchollet): deprecate collection below.
+ new_losses = self._losses[previous_losses_length:]
+ new_callable_losses = self._callable_losses[
+ previous_callable_losses_length:]
+ for regularizer in new_callable_losses:
+ loss_tensor = regularizer()
+ if loss_tensor is not None:
+ new_losses.append(loss_tensor)
+ _add_elements_to_collection(
+ new_losses,
+ ops.GraphKeys.REGULARIZATION_LOSSES)
def _name_scope(self):
"""Determines op naming for the Layer."""
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index d61d3b6dba..257fa27156 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -207,7 +207,8 @@ class ConvTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DBiasRegularizer(self):
height, width = 7, 9
@@ -217,7 +218,8 @@ class ConvTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DNoBias(self):
height, width = 7, 9
@@ -445,7 +447,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DPointwiseRegularizer(self):
length = 9
@@ -455,7 +458,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DBiasRegularizer(self):
length = 9
@@ -465,7 +469,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DNoBias(self):
length = 9
@@ -682,7 +687,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DPointwiseRegularizer(self):
height, width = 7, 9
@@ -692,7 +698,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DBiasRegularizer(self):
height, width = 7, 9
@@ -702,7 +709,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DNoBias(self):
height, width = 7, 9
@@ -839,7 +847,8 @@ class Conv2DTransposeTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DTransposeBiasRegularizer(self):
height, width = 7, 9
@@ -849,7 +858,8 @@ class Conv2DTransposeTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DTransposeNoBias(self):
height, width = 7, 9
@@ -1017,7 +1027,8 @@ class Conv3DTransposeTest(test.TestCase):
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv3DTransposeBiasRegularizer(self):
depth, height, width = 5, 7, 9
@@ -1027,7 +1038,8 @@ class Conv3DTransposeTest(test.TestCase):
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv3DTransposeNoBias(self):
depth, height, width = 5, 7, 9
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 46009a30ac..d26f3f4789 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -197,7 +197,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testKernelRegularizerWithReuse(self):
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
@@ -218,7 +219,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testFunctionalDense(self):
with self.cached_session():
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 87f8bd85a5..9d7d31df22 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -60,8 +60,17 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
+# The while_v2 module.
+_while_v2 = None
ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
+# Note: Setting this to True is not sufficient to switch to the v2 while_loop.
+# Users must also import the while_v2 module to set the _while_v2 module
+# variable above. We do this to avoid a circular dependency:
+# control_flow_ops -> while_v2 -> gradients_impl -> control_flow_ops
+# A ValueError is raised in tf.while_loop if this is set to True and the
+# `_while_v2` module is not set.
+ENABLE_WHILE_V2 = os.getenv("TF_ENABLE_WHILE_V2", "0") != "0"
# We override the 'tuple' for a control flow op, so we keep python's
@@ -3211,6 +3220,13 @@ def while_loop(cond,
```
"""
+ if ENABLE_WHILE_V2 and not context.executing_eagerly():
+ if not _while_v2:
+ raise ValueError("The while_v2 module is not set. Did you forget to "
+ "import tensorflow.python.ops."
+ "while_v2?")
+ return _while_v2.while_loop(cond, body, loop_vars, name)
+
with ops.name_scope(name, "while", loop_vars):
if not loop_vars:
raise ValueError("No loop variables provided")
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 76d980679e..12fd039392 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -25,6 +25,7 @@ import types
import numpy as np
import six
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -127,6 +128,18 @@ def _update_docstring(old_str, append_str):
return old_str + "\n\n" + append_str
+def _convert_to_tensor(value, name=None, preferred_dtype=None):
+ """Converts to tensor avoiding an eager bug that loses float precision."""
+ # TODO(b/116672045): Remove this function.
+ if (context.executing_eagerly() and preferred_dtype is not None and
+ (preferred_dtype.is_integer or preferred_dtype.is_bool)):
+ v = ops.convert_to_tensor(value, name=name)
+ if v.dtype.is_floating:
+ return v
+ return ops.convert_to_tensor(
+ value, name=name, preferred_dtype=preferred_dtype)
+
+
class _DistributionMeta(abc.ABCMeta):
def __new__(mcs, classname, baseclasses, attrs):
@@ -741,7 +754,8 @@ class Distribution(_BaseDistribution):
def _call_log_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._log_prob(value, **kwargs)
except NotImplementedError as original_exception:
@@ -769,7 +783,8 @@ class Distribution(_BaseDistribution):
def _call_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._prob(value, **kwargs)
except NotImplementedError as original_exception:
@@ -797,7 +812,8 @@ class Distribution(_BaseDistribution):
def _call_log_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._log_cdf(value, **kwargs)
except NotImplementedError as original_exception:
@@ -835,7 +851,8 @@ class Distribution(_BaseDistribution):
def _call_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._cdf(value, **kwargs)
except NotImplementedError as original_exception:
@@ -870,7 +887,8 @@ class Distribution(_BaseDistribution):
def _call_log_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._log_survival_function(value, **kwargs)
except NotImplementedError as original_exception:
@@ -909,7 +927,8 @@ class Distribution(_BaseDistribution):
def _call_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._survival_function(value, **kwargs)
except NotImplementedError as original_exception:
@@ -963,7 +982,8 @@ class Distribution(_BaseDistribution):
def _call_quantile(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
return self._quantile(value, **kwargs)
def quantile(self, value, name="quantile"):
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 60d73a1693..6263041b8d 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -550,11 +550,9 @@ def safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
- if not isinstance(embedding_weights[0],
- resource_variable_ops.ResourceVariable):
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
with ops.name_scope(name, 'embedding_lookup',
embedding_weights + [sparse_ids,
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 4f6e5dc473..3c9b7a01c7 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -273,7 +273,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
def testVariableRefGradient(self):
with ops.Graph().as_default():
init = constant_op.constant(100.0)
- var = variables.Variable(init)
+ var = variables.VariableV1(init)
gradient = gradients.gradients(var._ref(), var)
self.assertIsNotNone(gradient)
diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py
index 78c85db557..76d659f109 100644
--- a/tensorflow/python/ops/linalg/linear_operator_test_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py
@@ -184,7 +184,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -199,7 +199,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -215,7 +215,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -240,7 +240,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
for adjoint in self._adjoint_options:
for adjoint_arg in self._adjoint_arg_options:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -283,7 +283,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
for adjoint in self._adjoint_options:
for adjoint_arg in self._adjoint_arg_options:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -319,7 +319,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -335,7 +335,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -353,7 +353,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 8e11c4bce1..35278d9680 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -516,6 +516,40 @@ def _Log1pGrad(op, grad):
return grad * math_ops.reciprocal(1 + x)
+@ops.RegisterGradient("Xlogy")
+def _XLogyGrad(op, grad):
+ """Returns gradient of xlogy(x, y) with respect to x and y."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+ with ops.control_dependencies([grad]):
+ not_zero_x = math_ops.cast(
+ math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
+ partial_x = gen_math_ops.xlogy(not_zero_x, y)
+ partial_y = gen_math_ops.xdivy(x, y)
+ return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
+
+
+@ops.RegisterGradient("Xdivy")
+def _XDivyGrad(op, grad):
+ """Returns gradient of xdivy(x, y) with respect to x and y."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+ with ops.control_dependencies([grad]):
+ not_zero_x = math_ops.cast(
+ math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
+ partial_x = gen_math_ops.xdivy(not_zero_x, y)
+ partial_y = gen_math_ops.xdivy(math_ops.negative(x), y**2)
+ return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
+
+
@ops.RegisterGradient("Sinh")
def _SinhGrad(op, grad):
"""Returns grad * cosh(x)."""
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py
index 7110e0958c..9cfb050942 100644
--- a/tensorflow/python/ops/math_grad_test.py
+++ b/tensorflow/python/ops/math_grad_test.py
@@ -256,5 +256,93 @@ class DivNoNanGradientTest(test.TestCase):
self.assertAllClose(dy.eval(), np.zeros(y.shape.as_list()))
+class XlogyTest(test.TestCase):
+
+ def _xlogy_gradients(self, x, y):
+ xlogy_xgrad = self.evaluate(gradients.gradients(math_ops.xlogy(x, y), x)[0])
+ xlogy_ygrad = self.evaluate(gradients.gradients(math_ops.xlogy(x, y), y)[0])
+ return xlogy_xgrad, xlogy_ygrad
+
+ def testNonZeroValuesGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ xlogy_expected_xgrad = self.evaluate(math_ops.log(y))
+ xlogy_expected_ygrad = self.evaluate(x / y)
+ self.assertAllClose(xlogy_expected_xgrad, xlogy_xgrad)
+ self.assertAllClose(xlogy_expected_ygrad, xlogy_ygrad)
+
+ def testZeroXGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xlogy_xgrad)
+ self.assertAllClose(zero, xlogy_ygrad)
+
+ def testZeroYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ self.assertAllClose(-np.inf, xlogy_xgrad)
+ self.assertAllClose(np.inf, xlogy_ygrad)
+
+ def testZeroXYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xlogy_xgrad)
+ self.assertAllClose(zero, xlogy_ygrad)
+
+
+class XdivyTest(test.TestCase):
+
+ def _xdivy_gradients(self, x, y):
+ xdivy_xgrad = self.evaluate(gradients.gradients(math_ops.xdivy(x, y), x)[0])
+ xdivy_ygrad = self.evaluate(gradients.gradients(math_ops.xdivy(x, y), y)[0])
+ return xdivy_xgrad, xdivy_ygrad
+
+ def testNonZeroValuesGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ xdivy_expected_xgrad = self.evaluate(1 / y)
+ xdivy_expected_ygrad = self.evaluate(-x / y**2)
+ self.assertAllClose(xdivy_expected_xgrad, xdivy_xgrad)
+ self.assertAllClose(xdivy_expected_ygrad, xdivy_ygrad)
+
+ def testZeroXGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xdivy_xgrad)
+ self.assertAllClose(zero, xdivy_ygrad)
+
+ def testZeroYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ self.assertAllClose(np.inf, xdivy_xgrad)
+ self.assertAllClose(-np.inf, xdivy_ygrad)
+
+ def testZeroXYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xdivy_xgrad)
+ self.assertAllClose(zero, xdivy_ygrad)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 1b01d1d37f..f051850d92 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -21,6 +21,7 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -488,5 +489,75 @@ class DivNoNanTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tf_result, np_result)
+class XlogyTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXlogyNoZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.1, 0.2, 3.5], [-2., -5., 30.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [3.1, 4., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xlogy = self.evaluate(math_ops.xlogy(x, y))
+ xtimeslogy = self.evaluate(x * math_ops.log(y))
+ self.assertAllClose(xlogy, xtimeslogy)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXlogyWithZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(np.zeros((2, 3)), dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xlogy_tf_np = self.evaluate(math_ops.xlogy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y))
+ self.assertAllClose(xlogy_tf_np, zeros_np)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXlogyWithZeroBroadcast(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.], [1.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xlogy_tf_np = self.evaluate(math_ops.xlogy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y[0]))
+ xtimes_logy = self.evaluate(math_ops.log(y[1]))
+ self.assertAllClose(zeros_np, xlogy_tf_np[0])
+ self.assertAllClose(xtimes_logy, xlogy_tf_np[1])
+
+
+class XdivyTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXdivyNoZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.1, 0.2, 3.5], [-2., -5., 30.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [3.1, 4., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xdivy = self.evaluate(math_ops.xdivy(x, y))
+ x_over_y = self.evaluate(x / y)
+ self.assertAllClose(xdivy, x_over_y)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXdivyWithZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(np.zeros((2, 3)), dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xdivy_tf_np = self.evaluate(math_ops.xdivy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y))
+ self.assertAllClose(xdivy_tf_np, zeros_np)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXdivyWithZeroBroadcast(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.], [1.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xdivy_tf_np = self.evaluate(math_ops.xdivy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y[0]))
+ x_over_y = self.evaluate(1 / y[1])
+ self.assertAllClose(zeros_np, xdivy_tf_np[0])
+ self.assertAllClose(x_over_y, xdivy_tf_np[1])
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/matmul_benchmark.py b/tensorflow/python/ops/matmul_benchmark.py
index 6e5fe74290..138149e63d 100644
--- a/tensorflow/python/ops/matmul_benchmark.py
+++ b/tensorflow/python/ops/matmul_benchmark.py
@@ -49,13 +49,13 @@ def build_graph(device, n, m, k, transpose_a, transpose_b, dtype):
"""
with ops.device('%s' % device):
if not transpose_a:
- x = variables.Variable(random_ops.random_uniform([n, m], dtype=dtype))
+ x = variables.VariableV1(random_ops.random_uniform([n, m], dtype=dtype))
else:
- x = variables.Variable(random_ops.random_uniform([m, n], dtype=dtype))
+ x = variables.VariableV1(random_ops.random_uniform([m, n], dtype=dtype))
if not transpose_b:
- y = variables.Variable(random_ops.random_uniform([m, k], dtype=dtype))
+ y = variables.VariableV1(random_ops.random_uniform([m, k], dtype=dtype))
else:
- y = variables.Variable(random_ops.random_uniform([k, m], dtype=dtype))
+ y = variables.VariableV1(random_ops.random_uniform([k, m], dtype=dtype))
z = math_ops.matmul(x, y, transpose_a=transpose_a, transpose_b=transpose_b)
return control_flow_ops.group(z)
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index e0f6d51881..83cbe64ff2 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -1987,14 +1987,12 @@ def _convert_cast(pfor_input):
@RegisterPForWithArgs("Pow", math_ops.pow)
@RegisterPForWithArgs("RealDiv", math_ops.divide)
@RegisterPForWithArgs("Real", math_ops.real)
-@RegisterPForWithArgs("ReciprocalGrad", math_ops.reciprocal_grad)
@RegisterPForWithArgs("Reciprocal", math_ops.reciprocal)
@RegisterPForWithArgs("Relu6", nn_ops.relu6)
@RegisterPForWithArgs("Relu", nn_ops.relu)
@RegisterPForWithArgs("RightShift", bitwise_ops.right_shift)
@RegisterPForWithArgs("Rint", math_ops.rint)
@RegisterPForWithArgs("Round", math_ops.round)
-@RegisterPForWithArgs("RsqrtGrad", math_ops.rsqrt_grad)
@RegisterPForWithArgs("Rsqrt", math_ops.rsqrt)
@RegisterPForWithArgs("Selu", nn_ops.selu)
@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid)
@@ -2003,7 +2001,6 @@ def _convert_cast(pfor_input):
@RegisterPForWithArgs("Sin", math_ops.sin)
@RegisterPForWithArgs("Softplus", nn_ops.softplus)
@RegisterPForWithArgs("Softsign", nn_ops.softsign)
-@RegisterPForWithArgs("SqrtGrad", math_ops.sqrt_grad)
@RegisterPForWithArgs("Sqrt", math_ops.sqrt)
@RegisterPForWithArgs("SquaredDifference", math_ops.squared_difference)
@RegisterPForWithArgs("Square", math_ops.square)
@@ -2095,6 +2092,9 @@ def _convert_biasaddgrad(pfor_input):
@RegisterPForWithArgs("SoftplusGrad")
@RegisterPForWithArgs("SoftsignGrad")
@RegisterPForWithArgs("TanhGrad")
+@RegisterPForWithArgs("SqrtGrad")
+@RegisterPForWithArgs("RsqrtGrad")
+@RegisterPForWithArgs("ReciprocalGrad")
def _convert_grads(pfor_input, op_type, *args, **kw_args):
del args
del kw_args
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 43cca1a498..c2751e529a 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -611,7 +611,7 @@ class LSTMStateTuple(_LSTMStateTuple):
# TODO(scottzhu): Stop exporting this class in TF 2.0.
@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(LayerRNNCell):
- """DEPRECATED: Please use @{tf.nn.rnn_cell.LSTMCell} instead.
+ """DEPRECATED: Please use `tf.nn.rnn_cell.LSTMCell` instead.
Basic LSTM recurrent network cell.
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 5d949467fd..046a48d192 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -36,10 +36,12 @@ from tensorflow.python.ops import math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
+# pylint: disable=g-bad-import-order
from tensorflow.python.ops.gen_string_ops import *
from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
+# pylint: enable=g-bad-import-order
# pylint: enable=wildcard-import
@@ -328,6 +330,17 @@ def reduce_join(inputs, axis=None,
reduce_join.__doc__ = deprecation.rewrite_argument_docstring(
gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis")
+
+# This wrapper provides backwards compatibility for code that predates the
+# unit argument and that passed 'name' as a positional argument.
+@tf_export("strings.length")
+def string_length(input, name=None, unit="BYTE"):
+ return gen_string_ops.string_length(input, unit=unit, name=name)
+
+
+string_length.__doc__ = gen_string_ops.string_length.__doc__
+
+
ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index a43676cd70..5032ca79f9 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -198,7 +198,7 @@ VariableSynchronization = variables.VariableSynchronization # pylint: disable=i
VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name
AUTO_REUSE = _ReuseMode.AUTO_REUSE
-tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
+tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE")
AUTO_REUSE.__doc__ = """
When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that
get_variable() should create the requested variable if it doesn't exist or, if
@@ -515,8 +515,10 @@ class _VariableStore(object):
"synchronization": synchronization,
"aggregation": aggregation,
}
- # `fn_args` can handle functions, `functools.partial`, `lambda`.
- if "constraint" in function_utils.fn_args(custom_getter):
+ # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
+ # `lambda`.
+ if ("constraint" in function_utils.fn_args(custom_getter) or
+ function_utils.has_kwargs(custom_getter)):
custom_getter_kwargs["constraint"] = constraint
return custom_getter(**custom_getter_kwargs)
else:
@@ -906,7 +908,7 @@ class _VariableStore(object):
if use_resource is None:
# Set the default value if unspecified.
use_resource = _DEFAULT_USE_RESOURCE
- v = variable(
+ v = variables.VariableV1(
initial_value=init_val,
name=name,
trainable=trainable,
@@ -937,7 +939,8 @@ class _VariableStore(object):
if regularizer:
with ops.colocate_with(v):
with ops.name_scope(name + "/Regularizer/"):
- loss = regularizer(v)
+ with ops.init_scope():
+ loss = regularizer(v)
if loss is not None:
if context.executing_eagerly():
v_name = "v_%s" % type(v)
@@ -992,7 +995,7 @@ def no_regularizer(_):
# TODO(alive): support caching devices and partitioned variables in Eager mode.
-@tf_export("VariableScope")
+@tf_export(v1=["VariableScope"])
class VariableScope(object):
"""Variable scope object to carry defaults to provide to `get_variable`.
@@ -1340,7 +1343,7 @@ def get_variable_scope_store():
return scope_store
-@tf_export("get_variable_scope")
+@tf_export(v1=["get_variable_scope"])
def get_variable_scope():
"""Returns the current variable scope."""
return get_variable_scope_store().current_scope
@@ -1449,7 +1452,7 @@ class EagerVariableStore(object):
# The argument list for get_variable must match arguments to get_local_variable.
# So, if you are updating the arguments, also update arguments to
# get_local_variable below.
-@tf_export("get_variable")
+@tf_export(v1=["get_variable"])
def get_variable(name,
shape=None,
dtype=None,
@@ -1594,7 +1597,7 @@ get_variable.__doc__ = get_variable_or_local_docstring % (
# The argument list for get_local_variable must match arguments to get_variable.
# So, if you are updating the arguments, also update arguments to get_variable.
-@tf_export("get_local_variable")
+@tf_export(v1=["get_local_variable"])
def get_local_variable( # pylint: disable=missing-docstring
name,
shape=None,
@@ -1939,7 +1942,7 @@ def _get_unique_variable_scope(prefix):
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
-@tf_export("variable_scope") # pylint: disable=invalid-name
+@tf_export(v1=["variable_scope"]) # pylint: disable=invalid-name
class variable_scope(object):
"""A context manager for defining ops that creates variables (layers).
@@ -2320,7 +2323,7 @@ class variable_scope(object):
# pylint: disable=g-doc-return-or-yield
-@tf_export("variable_op_scope")
+@tf_export(v1=["variable_op_scope"])
@tf_contextlib.contextmanager
def variable_op_scope(values,
name_or_scope,
@@ -2441,7 +2444,33 @@ def default_variable_creator(next_creator=None, **kwargs):
expected_shape=expected_shape, import_scope=import_scope)
+def default_variable_creator_v2(next_creator=None, **kwargs):
+ """Default variable creator."""
+ assert next_creator is None
+ initial_value = kwargs.get("initial_value", None)
+ trainable = kwargs.get("trainable", None)
+ validate_shape = kwargs.get("validate_shape", True)
+ caching_device = kwargs.get("caching_device", None)
+ name = kwargs.get("name", None)
+ variable_def = kwargs.get("variable_def", None)
+ dtype = kwargs.get("dtype", None)
+ import_scope = kwargs.get("import_scope", None)
+ constraint = kwargs.get("constraint", None)
+
+ # Set trainable value based on synchronization value.
+ synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO)
+ trainable = _get_trainable_value(
+ synchronization=synchronization, trainable=trainable)
+
+ return resource_variable_ops.ResourceVariable(
+ initial_value=initial_value, trainable=trainable,
+ validate_shape=validate_shape, caching_device=caching_device,
+ name=name, dtype=dtype, constraint=constraint, variable_def=variable_def,
+ import_scope=import_scope)
+
+
variables.default_variable_creator = default_variable_creator
+variables.default_variable_creator_v2 = default_variable_creator_v2
def _make_getter(captured_getter, captured_previous):
@@ -2450,11 +2479,12 @@ def _make_getter(captured_getter, captured_previous):
# TODO(apassos) remove forwarding symbol
-variable = variables.Variable
+variable = variables.VariableV1
+@tf_export(v1=["variable_creator_scope"])
@tf_contextlib.contextmanager
-def variable_creator_scope(variable_creator):
+def variable_creator_scope_v1(variable_creator):
"""Scope which defines a variable creation function to be used by variable().
variable_creator is expected to be a function with the following signature:
@@ -2525,3 +2555,73 @@ def variable_creator_scope(variable_creator):
"""
with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access
yield
+
+
+# Note: only the docstrings differ between this and v1.
+@tf_export(v2=["variable_creator_scope"])
+@tf_contextlib.contextmanager
+def variable_creator_scope(variable_creator):
+ """Scope which defines a variable creation function to be used by variable().
+
+ variable_creator is expected to be a function with the following signature:
+
+ ```
+ def variable_creator(next_creator, **kwargs)
+ ```
+
+ The creator is supposed to eventually call the next_creator to create a
+ variable if it does want to create a variable and not call Variable or
+ ResourceVariable directly. This helps make creators composable. A creator may
+ choose to create multiple variables, return already existing variables, or
+ simply register that a variable was created and defer to the next creators in
+ line. Creators can also modify the keyword arguments seen by the next
+ creators.
+
+ Custom getters in the variable scope will eventually resolve down to these
+ custom creators when they do create variables.
+
+ The valid keyword arguments in kwds are:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
+ trainable: If `True`, the default, GradientTapes automatically watch
+ uses of this Variable.
+ validate_shape: If `False`, allows the variable to be initialized with a
+ value of unknown shape. If `True`, the default, the shape of
+ `initial_value` must be known.
+ caching_device: Optional device string describing where the Variable
+ should be cached for reading. Defaults to the Variable's device.
+ If not `None`, caches on another device. Typical use is to cache
+ on the device where the Ops using the Variable reside, to deduplicate
+ copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ dtype: If set, initial_value will be converted to the given type.
+ If `None`, either the datatype will be kept (if `initial_value` is
+ a Tensor), or `convert_to_tensor` will decide.
+ constraint: A constraint function to be applied to the variable after
+ updates by some algorithms.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ `tf.VariableSynchronization`. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ `tf.VariableAggregation`.
+
+ This set may grow over time, so it's important the signature of creators is as
+ mentioned above.
+
+ Args:
+ variable_creator: the passed creator
+
+ Yields:
+ A scope in which the creator is active
+ """
+ with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access
+ yield
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 262cd61e5a..45c8618610 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -46,6 +46,11 @@ def default_variable_creator(_, **kwds):
raise NotImplementedError("variable_scope needs to be imported")
+def default_variable_creator_v2(_, **kwds):
+ del kwds
+ raise NotImplementedError("variable_scope needs to be imported")
+
+
def _make_getter(captured_getter, captured_previous):
"""To avoid capturing loop variables."""
def getter(**kwargs):
@@ -101,21 +106,21 @@ class VariableAggregation(enum.Enum):
class VariableMetaclass(type):
"""Metaclass to allow construction of tf.Variable to be overridden."""
- def _variable_call(cls,
- initial_value=None,
- trainable=None,
- collections=None,
- validate_shape=True,
- caching_device=None,
- name=None,
- variable_def=None,
- dtype=None,
- expected_shape=None,
- import_scope=None,
- constraint=None,
- use_resource=None,
- synchronization=VariableSynchronization.AUTO,
- aggregation=VariableAggregation.NONE):
+ def _variable_v1_call(cls,
+ initial_value=None,
+ trainable=None,
+ collections=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ expected_shape=None,
+ import_scope=None,
+ constraint=None,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Call on Variable class. Useful to force the signature."""
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
@@ -140,14 +145,49 @@ class VariableMetaclass(type):
synchronization=synchronization,
aggregation=aggregation)
+ def _variable_v2_call(cls,
+ initial_value=None,
+ trainable=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ import_scope=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
+ """Call on Variable class. Useful to force the signature."""
+ previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
+ for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
+ previous_getter = _make_getter(getter, previous_getter)
+
+ # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
+ if aggregation is None:
+ aggregation = VariableAggregation.NONE
+ return previous_getter(
+ initial_value=initial_value,
+ trainable=trainable,
+ validate_shape=validate_shape,
+ caching_device=caching_device,
+ name=name,
+ variable_def=variable_def,
+ dtype=dtype,
+ import_scope=import_scope,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
def __call__(cls, *args, **kwargs):
- if cls is Variable:
- return cls._variable_call(*args, **kwargs)
+ if cls is VariableV1:
+ return cls._variable_v1_call(*args, **kwargs)
+ elif cls is Variable:
+ return cls._variable_v2_call(*args, **kwargs)
else:
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
-@tf_export("Variable")
+@tf_export(v2=["Variable"])
class Variable(six.with_metaclass(VariableMetaclass,
checkpointable.CheckpointableBase)):
"""See the [Variables Guide](https://tensorflow.org/guide/variables).
@@ -267,16 +307,13 @@ class Variable(six.with_metaclass(VariableMetaclass,
def __init__(self,
initial_value=None,
trainable=True,
- collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
- expected_shape=None,
import_scope=None,
constraint=None,
- use_resource=None,
synchronization=VariableSynchronization.AUTO,
aggregation=VariableAggregation.NONE):
"""Creates a new variable with value `initial_value`.
@@ -297,11 +334,8 @@ class Variable(six.with_metaclass(VariableMetaclass,
callable with no argument that returns the initial value when called. In
that case, `dtype` must be specified. (Note that initializer functions
from init_ops.py must first be bound to a shape before being used here.)
- trainable: If `True`, the default, also adds the variable to the graph
- collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
- the default list of variables to use by the `Optimizer` classes.
- collections: List of graph collections keys. The new variable is added to
- these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+ trainable: If `True`, the default, GradientTapes automatically watch uses
+ of this variable.
validate_shape: If `False`, allows the variable to be initialized with a
value of unknown shape. If `True`, the default, the shape of
`initial_value` must be known.
@@ -319,8 +353,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
dtype: If set, initial_value will be converted to the given type.
If `None`, either the datatype will be kept (if `initial_value` is
a Tensor), or `convert_to_tensor` will decide.
- expected_shape: A TensorShape. If set, initial_value is expected
- to have this shape.
import_scope: Optional `string`. Name scope to add to the
`Variable.` Only used when initializing from protocol buffer.
constraint: An optional projection function to be applied to the variable
@@ -330,9 +362,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
variable and return the Tensor for the projected value
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
- use_resource: if True, a ResourceVariable is created; otherwise an
- old-style ref-based variable is created. When eager execution is enabled
- a resource variable is always created.
synchronization: Indicates when a distributed a variable will be
aggregated. Accepted values are constants defined in the class
`tf.VariableSynchronization`. By default the synchronization is set to
@@ -1009,11 +1038,207 @@ class Variable(six.with_metaclass(VariableMetaclass,
raise NotImplementedError
+@tf_export(v1=["Variable"])
+class VariableV1(Variable):
+ """See the [Variables Guide](https://tensorflow.org/guide/variables).
+
+ A variable maintains state in the graph across calls to `run()`. You add a
+ variable to the graph by constructing an instance of the class `Variable`.
+
+ The `Variable()` constructor requires an initial value for the variable,
+ which can be a `Tensor` of any type and shape. The initial value defines the
+ type and shape of the variable. After construction, the type and shape of
+ the variable are fixed. The value can be changed using one of the assign
+ methods.
+
+ If you want to change the shape of a variable later you have to use an
+ `assign` Op with `validate_shape=False`.
+
+ Just like any `Tensor`, variables created with `Variable()` can be used as
+ inputs for other Ops in the graph. Additionally, all the operators
+ overloaded for the `Tensor` class are carried over to variables, so you can
+ also add nodes to the graph by just doing arithmetic on variables.
+
+ ```python
+ import tensorflow as tf
+
+ # Create a variable.
+ w = tf.Variable(<initial-value>, name=<optional-name>)
+
+ # Use the variable in the graph like any Tensor.
+ y = tf.matmul(w, ...another variable or tensor...)
+
+ # The overloaded operators are available too.
+ z = tf.sigmoid(w + y)
+
+ # Assign a new value to the variable with `assign()` or a related method.
+ w.assign(w + 1.0)
+ w.assign_add(1.0)
+ ```
+
+ When you launch the graph, variables have to be explicitly initialized before
+ you can run Ops that use their value. You can initialize a variable by
+ running its *initializer op*, restoring the variable from a save file, or
+ simply running an `assign` Op that assigns a value to the variable. In fact,
+ the variable *initializer op* is just an `assign` Op that assigns the
+ variable's initial value to the variable itself.
+
+ ```python
+ # Launch the graph in a session.
+ with tf.Session() as sess:
+ # Run the variable initializer.
+ sess.run(w.initializer)
+ # ...you now can run ops that use the value of 'w'...
+ ```
+
+ The most common initialization pattern is to use the convenience function
+ `global_variables_initializer()` to add an Op to the graph that initializes
+ all the variables. You then run that Op after launching the graph.
+
+ ```python
+ # Add an Op to initialize global variables.
+ init_op = tf.global_variables_initializer()
+
+ # Launch the graph in a session.
+ with tf.Session() as sess:
+ # Run the Op that initializes global variables.
+ sess.run(init_op)
+ # ...you can now run any Op that uses variable values...
+ ```
+
+ If you need to create a variable with an initial value dependent on another
+ variable, use the other variable's `initialized_value()`. This ensures that
+ variables are initialized in the right order.
+
+ All variables are automatically collected in the graph where they are
+ created. By default, the constructor adds the new variable to the graph
+ collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
+ `global_variables()` returns the contents of that collection.
+
+ When building a machine learning model it is often convenient to distinguish
+ between variables holding the trainable model parameters and other variables
+ such as a `global step` variable used to count training steps. To make this
+ easier, the variable constructor supports a `trainable=<bool>` parameter. If
+ `True`, the new variable is also added to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`. The convenience function
+ `trainable_variables()` returns the contents of this collection. The
+ various `Optimizer` classes use this collection as the default list of
+ variables to optimize.
+
+ WARNING: tf.Variable objects by default have a non-intuitive memory model. A
+ Variable is represented internally as a mutable Tensor which can
+ non-deterministically alias other Tensors in a graph. The set of operations
+ which consume a Variable and can lead to aliasing is undetermined and can
+ change across TensorFlow versions. Avoid writing code which relies on the
+ value of a Variable either changing or not changing as other operations
+ happen. For example, using Variable objects or simple functions thereof as
+ predicates in a `tf.cond` is dangerous and error-prone:
+
+ ```
+ v = tf.Variable(True)
+ tf.cond(v, lambda: v.assign(False), my_false_fn) # Note: this is broken.
+ ```
+
+ Here replacing adding `use_resource=True` when constructing the variable will
+ fix any nondeterminism issues:
+ ```
+ v = tf.Variable(True, use_resource=True)
+ tf.cond(v, lambda: v.assign(False), my_false_fn)
+ ```
+
+ To use the replacement for variables which does
+ not have these issues:
+
+ * Add `use_resource=True` when constructing `tf.Variable`;
+ * Call `tf.get_variable_scope().set_use_resource(True)` inside a
+ `tf.variable_scope` before the `tf.get_variable()` call.
+ """
+
+ def __init__(self, # pylint: disable=super-init-not-called
+ initial_value=None,
+ trainable=True,
+ collections=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ expected_shape=None,
+ import_scope=None,
+ constraint=None,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
+ """Creates a new variable with value `initial_value`.
+
+ The new variable is added to the graph collections listed in `collections`,
+ which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+
+ If `trainable` is `True` the variable is also added to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`.
+
+ This constructor creates both a `variable` Op and an `assign` Op to set the
+ variable to its initial value.
+
+ Args:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
+ trainable: If `True`, the default, also adds the variable to the graph
+ collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
+ the default list of variables to use by the `Optimizer` classes.
+ collections: List of graph collections keys. The new variable is added to
+ these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+ validate_shape: If `False`, allows the variable to be initialized with a
+ value of unknown shape. If `True`, the default, the shape of
+ `initial_value` must be known.
+ caching_device: Optional device string describing where the Variable
+ should be cached for reading. Defaults to the Variable's device.
+ If not `None`, caches on another device. Typical use is to cache
+ on the device where the Ops using the Variable reside, to deduplicate
+ copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ variable_def: `VariableDef` protocol buffer. If not `None`, recreates
+ the Variable object with its contents, referencing the variable's nodes
+ in the graph, which must already exist. The graph is not changed.
+ `variable_def` and the other arguments are mutually exclusive.
+ dtype: If set, initial_value will be converted to the given type.
+ If `None`, either the datatype will be kept (if `initial_value` is
+ a Tensor), or `convert_to_tensor` will decide.
+ expected_shape: A TensorShape. If set, initial_value is expected
+ to have this shape.
+ import_scope: Optional `string`. Name scope to add to the
+ `Variable.` Only used when initializing from protocol buffer.
+ constraint: An optional projection function to be applied to the variable
+ after being updated by an `Optimizer` (e.g. used to implement norm
+ constraints or value constraints for layer weights). The function must
+ take as input the unprojected Tensor representing the value of the
+ variable and return the Tensor for the projected value
+ (which must have the same shape). Constraints are not safe to
+ use when doing asynchronous distributed training.
+ use_resource: whether to use resource variables.
+ synchronization: unused
+ aggregation: unused
+
+ Raises:
+ ValueError: If both `variable_def` and initial_value are specified.
+ ValueError: If the initial value is not specified, or does not have a
+ shape and `validate_shape` is `True`.
+ RuntimeError: If eager execution is enabled.
+ """
+
+ SaveSliceInfo = Variable.SaveSliceInfo
+
+
# TODO(apassos): do not repeat all comments here
-class RefVariable(Variable):
+class RefVariable(VariableV1):
"""Ref-based implementation of variables."""
- def __init__(self,
+ def __init__(self, # pylint: disable=super-init-not-called
initial_value=None,
trainable=True,
collections=None,
@@ -1873,7 +2098,7 @@ class RefVariable(Variable):
def _OverloadAllOperators(): # pylint: disable=invalid-name
"""Register overloads for all operators."""
for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
- Variable._OverloadOperator(operator)
+ Variable._OverloadOperator(operator) # pylint: disable=protected-access
# For slicing, bind getitem differently than a tensor (use SliceHelperVar
# instead)
# pylint: disable=protected-access
@@ -2395,23 +2620,22 @@ class PartitionedVariable(object):
def _get_partitions(self):
return self._partitions
- def _apply_assign_fn(self,
- assign_fn,
- value):
+ def _apply_assign_fn(self, assign_fn, value):
partition_axes = self._partition_axes()
if len(partition_axes) > 1:
raise NotImplementedError(
"Cannot do assign action along more than one dimension: %s. "
- "Multi-axis partition assign action is not supported "
- % str(partition_axes))
+ "Multi-axis partition assign action is not supported " %
+ str(partition_axes))
partition_ix = partition_axes[0]
size_splits_list = [
- var.shape[partition_ix].value for var in self._variable_list]
- value_list = array_ops.split(
- value, size_splits_list, axis=partition_ix)
+ var.shape[partition_ix].value for var in self._variable_list
+ ]
+ value_list = array_ops.split(value, size_splits_list, axis=partition_ix)
op_list = [
assign_fn(var, value_list[idx], idx)
- for idx, var in enumerate(self._variable_list)]
+ for idx, var in enumerate(self._variable_list)
+ ]
return op_list
def assign(self, value, use_locking=False, name=None, read_value=True):
@@ -2441,7 +2665,8 @@ class PartitionedVariable(object):
return assign_list
return [assign.op for assign in assign_list]
-@tf_export("global_variables")
+
+@tf_export(v1=["global_variables"])
def global_variables(scope=None):
"""Returns global variables.
@@ -2467,7 +2692,7 @@ def global_variables(scope=None):
return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)
-@tf_export("all_variables")
+@tf_export(v1=["all_variables"])
@deprecated("2017-03-02", "Please use tf.global_variables instead.")
def all_variables():
"""See `tf.global_variables`."""
@@ -2492,7 +2717,7 @@ def _all_saveable_objects(scope=None):
ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))
-@tf_export("local_variables")
+@tf_export(v1=["local_variables"])
def local_variables(scope=None):
"""Returns local variables.
@@ -2520,7 +2745,7 @@ def local_variables(scope=None):
return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope)
-@tf_export("model_variables")
+@tf_export(v1=["model_variables"])
def model_variables(scope=None):
"""Returns all variables in the MODEL_VARIABLES collection.
@@ -2537,7 +2762,7 @@ def model_variables(scope=None):
return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)
-@tf_export("trainable_variables")
+@tf_export(v1=["trainable_variables"])
def trainable_variables(scope=None):
"""Returns all variables created with `trainable=True`.
@@ -2559,7 +2784,7 @@ def trainable_variables(scope=None):
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope)
-@tf_export("moving_average_variables")
+@tf_export(v1=["moving_average_variables"])
def moving_average_variables(scope=None):
"""Returns all variables that maintain their moving averages.
@@ -2581,7 +2806,7 @@ def moving_average_variables(scope=None):
return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope)
-@tf_export("initializers.variables", "variables_initializer")
+@tf_export(v1=["initializers.variables", "variables_initializer"])
def variables_initializer(var_list, name="init"):
"""Returns an Op that initializes a list of variables.
@@ -2607,7 +2832,7 @@ def variables_initializer(var_list, name="init"):
return control_flow_ops.no_op(name=name)
-@tf_export("initialize_variables")
+@tf_export(v1=["initialize_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
def initialize_variables(var_list, name="init"):
@@ -2615,7 +2840,7 @@ def initialize_variables(var_list, name="init"):
return variables_initializer(var_list, name=name)
-@tf_export("initializers.global_variables", "global_variables_initializer")
+@tf_export(v1=["initializers.global_variables", "global_variables_initializer"])
def global_variables_initializer():
"""Returns an Op that initializes global variables.
@@ -2629,7 +2854,7 @@ def global_variables_initializer():
return variables_initializer(global_variables())
-@tf_export("initialize_all_variables")
+@tf_export(v1=["initialize_all_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
def initialize_all_variables():
@@ -2637,7 +2862,7 @@ def initialize_all_variables():
return global_variables_initializer()
-@tf_export("initializers.local_variables", "local_variables_initializer")
+@tf_export(v1=["initializers.local_variables", "local_variables_initializer"])
def local_variables_initializer():
"""Returns an Op that initializes all local variables.
@@ -2651,7 +2876,7 @@ def local_variables_initializer():
return variables_initializer(local_variables())
-@tf_export("initialize_local_variables")
+@tf_export(v1=["initialize_local_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
def initialize_local_variables():
@@ -2659,7 +2884,7 @@ def initialize_local_variables():
return local_variables_initializer()
-@tf_export("is_variable_initialized")
+@tf_export(v1=["is_variable_initialized"])
@tf_should_use.should_use_result
def is_variable_initialized(variable):
"""Tests if a variable has been initialized.
@@ -2674,7 +2899,7 @@ def is_variable_initialized(variable):
return state_ops.is_variable_initialized(variable)
-@tf_export("assert_variables_initialized")
+@tf_export(v1=["assert_variables_initialized"])
@tf_should_use.should_use_result
def assert_variables_initialized(var_list=None):
"""Returns an Op to check if variables are initialized.
@@ -2717,7 +2942,7 @@ def assert_variables_initialized(var_list=None):
return array_ops.stack(ranks)
-@tf_export("report_uninitialized_variables")
+@tf_export(v1=["report_uninitialized_variables"])
@tf_should_use.should_use_result
def report_uninitialized_variables(var_list=None,
name="report_uninitialized_variables"):
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 875be31602..6791e1cd61 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import sys
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import function
@@ -33,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl as cond_v2
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
@@ -41,6 +43,8 @@ from tensorflow.python.util import nest
# pylint: disable=protected-access
+control_flow_ops._while_v2 = sys.modules[__name__]
+
# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
# control dependencies on external nodes with at least 1 output.
# Another idea is to create const nodes outside the loop and add control edges
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py
index b7e217a35b..924b2e7c06 100644
--- a/tensorflow/python/saved_model/loader_test.py
+++ b/tensorflow/python/saved_model/loader_test.py
@@ -47,8 +47,8 @@ class SavedModelLoaderTest(test.TestCase):
def setUp(self):
"""Write test SavedModels to a temp directory."""
with session.Session(graph=ops.Graph()) as sess:
- x = variables.Variable(5, name="x")
- y = variables.Variable(11, name="y")
+ x = variables.VariableV1(5, name="x")
+ y = variables.VariableV1(11, name="y")
z = x + y
sess.run(variables.global_variables_initializer())
@@ -134,8 +134,8 @@ class SavedModelLoaderTest(test.TestCase):
def test_restore_variables(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
with self.session(graph=ops.Graph()) as sess:
- x = variables.Variable(0, name="x")
- y = variables.Variable(0, name="y")
+ x = variables.VariableV1(0, name="x")
+ y = variables.VariableV1(0, name="y")
z = x * y
sess.run(variables.global_variables_initializer())
@@ -186,8 +186,10 @@ class SavedModelLoaderTest(test.TestCase):
"""
path = _get_export_dir("no_variable_saved_model")
with session.Session(graph=ops.Graph()) as sess:
- x = variables.Variable(5, name="x", collections=["not_global_variable"])
- y = variables.Variable(11, name="y", collections=["not_global_variable"])
+ x = variables.VariableV1(
+ 5, name="x", collections=["not_global_variable"])
+ y = variables.VariableV1(
+ 11, name="y", collections=["not_global_variable"])
self.assertFalse(variables._all_saveable_objects())
z = x + y
sess.run(variables.variables_initializer([x, y]))
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 49d52d3bee..80b75b7ee6 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -60,7 +60,7 @@ class SavedModelTest(test.TestCase):
return os.path.join(test.get_temp_dir(), label)
def _init_and_validate_variable(self, sess, variable_name, variable_value):
- v = variables.Variable(variable_value, name=variable_name)
+ v = variables.VariableV1(variable_value, name=variable_name)
sess.run(variables.global_variables_initializer())
self.assertEqual(variable_value, v.eval())
@@ -458,7 +458,7 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable added to a collection. SavedModel invoked to:
# - add with weights.
with self.session(graph=ops.Graph()) as sess:
- v = variables.Variable(42, name="v")
+ v = variables.VariableV1(42, name="v")
ops.add_to_collection("foo_vars", v)
sess.run(variables.global_variables_initializer())
self.assertEqual(42, v.eval())
@@ -468,7 +468,7 @@ class SavedModelTest(test.TestCase):
# SavedModel invoked to:
# - simply add the model (weights are not updated).
with self.session(graph=ops.Graph()) as sess:
- v = variables.Variable(43, name="v")
+ v = variables.VariableV1(43, name="v")
ops.add_to_collection("bar_vars", v)
sess.run(variables.global_variables_initializer())
self.assertEqual(43, v.eval())
@@ -780,13 +780,13 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
# Initialize another variable `v3` to 42.
- v3 = variables.Variable(42, name="v3")
+ v3 = variables.VariableV1(42, name="v3")
ops.add_to_collection("v", v3)
# Set up an assignment op to be run as part of the main_op.
@@ -815,13 +815,13 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
# Initialize another variable `v3` to 42.
- v3 = variables.Variable(42, name="v3", trainable=False, collections=[])
+ v3 = variables.VariableV1(42, name="v3", trainable=False, collections=[])
ops.add_to_collection("v", v3)
# Set up an assignment op to be run as part of the legacy_init_op.
@@ -860,11 +860,11 @@ class SavedModelTest(test.TestCase):
g = ops.Graph()
with self.session(graph=g) as sess:
# Initialize variable `v1` to 1.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
# Initialize another variable `v2` to 42.
- v2 = variables.Variable(42, name="v2", trainable=False, collections=[])
+ v2 = variables.VariableV1(42, name="v2", trainable=False, collections=[])
ops.add_to_collection("v", v2)
# Set up an assignment op to be run as part of the init op.
@@ -889,9 +889,9 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
sess.run(variables.global_variables_initializer())
@@ -918,9 +918,9 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
sess.run(variables.global_variables_initializer())
@@ -947,9 +947,9 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
sess.run(variables.global_variables_initializer())
@@ -1071,13 +1071,13 @@ class SavedModelTest(test.TestCase):
graph=ops.Graph(),
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
with sess.graph.device("/cpu:1"):
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
# v3 is an unsaved variable derived from v1 and v2. It is used to
# exercise the ability to run an init op when restoring a graph.
- v3 = variables.Variable(1, name="v3", trainable=False, collections=[])
+ v3 = variables.VariableV1(1, name="v3", trainable=False, collections=[])
assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2))
init_op = control_flow_ops.group(assign_v3, name="init_op")
@@ -1140,7 +1140,7 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.session(graph=ops.Graph()) as sess:
- variables.Variable(1, name="v1")
+ variables.VariableV1(1, name="v1")
sess.run(variables.global_variables_initializer())
custom_saver = training.Saver(name="my_saver")
builder.add_meta_graph_and_variables(sess, ["tag"], saver=custom_saver)
@@ -1162,7 +1162,7 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.session(graph=ops.Graph()) as sess:
- variables.Variable(1, name="v1")
+ variables.VariableV1(1, name="v1")
sess.run(variables.global_variables_initializer())
training.Saver(name="my_saver")
builder.add_meta_graph_and_variables(sess, ["tag"])
@@ -1184,7 +1184,7 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.session(graph=ops.Graph()) as sess:
- variables.Variable(1, name="v1")
+ variables.VariableV1(1, name="v1")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(sess, ["tag_0"])
@@ -1293,8 +1293,8 @@ class SavedModelTest(test.TestCase):
# Add a graph with two float32 variables and a Complex Op composing them
# with strip_default_attrs enabled.
with session.Session(graph=ops.Graph()) as sess:
- real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
- imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(
@@ -1303,8 +1303,8 @@ class SavedModelTest(test.TestCase):
# Add a graph with the same float32 variables and a Complex Op composing
# them with strip_default_attrs disabled.
with session.Session(graph=ops.Graph()) as sess:
- real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
- imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph(["bar"], strip_default_attrs=False)
@@ -1366,7 +1366,7 @@ class SavedModelTest(test.TestCase):
# Add a graph with a single variable and a test op with a defaultless
# float32 attr, "test_attr".
with session.Session(graph=ops.Graph()) as sess:
- variables.Variable(1.0, dtype=dtypes.float64, name="var")
+ variables.VariableV1(1.0, dtype=dtypes.float64, name="var")
test_ops.test_attr(T=dtypes.float32, name="test_attr")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(sess, ["foo"])
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 1c1a1a54cd..384c7a82d2 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -8,6 +8,7 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
# Transitive dependencies of this target will be included in the pip package.
py_library(
@@ -21,6 +22,13 @@ py_library(
":saved_model_cli",
":saved_model_utils",
":strip_unused",
+ # The following py_library are needed because
+ # py_binary may not depend on them when --define=no_tensorflow_py_deps=true
+ # is specified. See https://github.com/tensorflow/tensorflow/issues/22390
+ ":freeze_graph_lib",
+ ":optimize_for_inference_lib",
+ ":selective_registration_header_lib",
+ ":strip_unused_lib",
],
)
@@ -44,6 +52,7 @@ py_library(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:training",
+ "//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/saved_model:loader",
"@six_archive//:six",
],
diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py
index e38945fabc..5dc14a6961 100644
--- a/tensorflow/python/tools/freeze_graph_test.py
+++ b/tensorflow/python/tools/freeze_graph_test.py
@@ -60,7 +60,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
# We'll create an input graph that has a single variable containing 1.0,
# and that then multiplies it by 2.
with ops.Graph().as_default():
- variable_node = variables.Variable(1.0, name="variable_node")
+ variable_node = variables.VariableV1(1.0, name="variable_node")
output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
sess = session.Session()
init = variables.global_variables_initializer()
@@ -138,7 +138,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
features = parsing_ops.parse_example(examples, feature_configs)
feature = features[feature_name]
- variable_node = variables.Variable(1.0, name="variable_node")
+ variable_node = variables.VariableV1(1.0, name="variable_node")
scores = math_ops.multiply(variable_node, feature, name="output_node")
class_feature = array_ops.fill(array_ops.shape(feature),
"class_%s" % feature_name)
@@ -174,7 +174,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")
with ops.Graph().as_default():
- variable_node = variables.Variable(1.0, name="variable_node")
+ variable_node = variables.VariableV1(1.0, name="variable_node")
output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
sess = session.Session()
init = variables.global_variables_initializer()
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index 3bd4bd75bd..1efabcd854 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -344,7 +344,7 @@ class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
raise ValueError("steps_per_run should be greater than 0")
self._num_steps = num_steps
self._last_step = last_step
- self._steps_per_run = steps_per_run
+ self._steps_per_run_initial_value = steps_per_run
def begin(self):
self._global_step_tensor = training_util.get_global_step()
@@ -353,7 +353,8 @@ class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
self._steps_per_run_variable = get_or_create_steps_per_run_variable()
def _update_steps_per_run_variable(self, global_step, session):
- steps = min(self._last_step - global_step, self._steps_per_run)
+ steps = min(self._last_step - global_step,
+ self._steps_per_run_initial_value)
self._steps_per_run_variable.load(steps, session=session)
def after_create_session(self, session, coord):
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 56c4043d9d..eff15b24ce 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -247,7 +247,7 @@ def _default_getter(name, shape, dtype, initializer=None,
def initial_value():
return initializer(
shape_object.as_list(), dtype=dtype, partition_info=partition_info)
- return variables.Variable(
+ return variables.VariableV1(
initial_value=initial_value,
name=name,
dtype=variable_dtype,
diff --git a/tensorflow/python/training/evaluation.py b/tensorflow/python/training/evaluation.py
index b36444a14c..2c4eb02d53 100644
--- a/tensorflow/python/training/evaluation.py
+++ b/tensorflow/python/training/evaluation.py
@@ -18,13 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import time
import math
+import time
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
@@ -77,6 +78,59 @@ def _get_latest_eval_step_value(update_ops):
return array_ops.identity(_get_or_create_eval_step().read_value())
+class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook):
+ """Run hook used by the evaluation routines to run the `eval_ops` N times."""
+
+ def __init__(self, num_evals, steps_per_run=1):
+ """Constructs the run hook.
+
+ Args:
+ num_evals: The number of evaluations to run for. if set to None, will
+ iterate the dataset until all inputs are exhausted.
+ steps_per_run: Number of steps executed per run call.
+ """
+ self._num_evals = num_evals
+ self._evals_completed = None
+ self._steps_per_run_initial_value = steps_per_run
+
+ def _set_evals_completed_tensor(self, updated_eval_step):
+ self._evals_completed = updated_eval_step
+
+ def begin(self):
+ self._steps_per_run_variable = \
+ basic_session_run_hooks.get_or_create_steps_per_run_variable()
+
+ def after_create_session(self, session, coord):
+ # Update number of steps to run in the first run call
+ if self._num_evals is None:
+ steps = self._steps_per_run_initial_value
+ else:
+ steps = min(self._steps_per_run_initial_value, self._num_evals)
+ self._steps_per_run_variable.load(steps, session=session)
+
+ def before_run(self, run_context):
+ return session_run_hook.SessionRunArgs({
+ 'evals_completed': self._evals_completed
+ })
+
+ def after_run(self, run_context, run_values):
+ evals_completed = run_values.results['evals_completed']
+ # Update number of steps to run in the next iteration
+ if self._num_evals is None:
+ steps = self._steps_per_run_initial_value
+ else:
+ steps = min(self._num_evals - evals_completed,
+ self._steps_per_run_initial_value)
+ self._steps_per_run_variable.load(steps, session=run_context.session)
+
+ if self._num_evals is None:
+ logging.info('Evaluation [%d]', evals_completed)
+ else:
+ logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals)
+ if self._num_evals is not None and evals_completed >= self._num_evals:
+ run_context.request_stop()
+
+
class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
"""Run hook used by the evaluation routines to run the `eval_ops` N times."""
@@ -176,7 +230,15 @@ def _evaluate_once(checkpoint_path,
hooks = list(hooks or [])
if eval_ops is not None:
- update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True)
+ if any([isinstance(h, _MultiStepStopAfterNEvalsHook) for h in hooks]):
+ steps_per_run_variable = \
+ basic_session_run_hooks.get_or_create_steps_per_run_variable()
+ update_eval_step = state_ops.assign_add(
+ eval_step,
+ math_ops.cast(steps_per_run_variable, dtype=eval_step.dtype),
+ use_locking=True)
+ else:
+ update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True)
if isinstance(eval_ops, dict):
eval_ops['update_eval_step'] = update_eval_step
@@ -188,7 +250,7 @@ def _evaluate_once(checkpoint_path,
eval_step_value = _get_latest_eval_step_value(eval_ops)
for h in hooks:
- if isinstance(h, _StopAfterNEvalsHook):
+ if isinstance(h, (_StopAfterNEvalsHook, _MultiStepStopAfterNEvalsHook)):
h._set_evals_completed_tensor(eval_step_value) # pylint: disable=protected-access
logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py
index 5a9215730e..03a32f6ca0 100644
--- a/tensorflow/python/training/learning_rate_decay_test.py
+++ b/tensorflow/python/training/learning_rate_decay_test.py
@@ -63,7 +63,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
def testVariables(self):
with self.cached_session():
- step = variables.Variable(1)
+ step = variables.VariableV1(1)
assign_1 = step.assign(1)
assign_2 = step.assign(2)
assign_100 = step.assign(100)
@@ -121,7 +121,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
# Test that ref types are valid.
if not context.executing_eagerly():
- x = variables.Variable(0.0)
+ x = variables.VariableV1(0.0)
x_ref = x.op.outputs[0] # float32_ref tensor should be accepted
boundaries, values = [1.0, 2.0], [1, 2, 3]
learning_rate_decay.piecewise_constant(x_ref, boundaries, values)
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 2d7799d66a..c870d99de9 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -69,8 +69,8 @@ class ScaffoldTest(test.TestCase):
def test_defaults_empty_graph(self):
with ops.Graph().as_default():
scaffold = monitored_session.Scaffold()
- variables.Variable(1, name='my_var')
- variables.Variable(
+ variables.VariableV1(1, name='my_var')
+ variables.VariableV1(
2, name='my_local_var', collections=[ops.GraphKeys.LOCAL_VARIABLES])
scaffold.finalize()
self.assertTrue(isinstance(scaffold.init_op, ops.Operation))
@@ -105,7 +105,7 @@ class ScaffoldTest(test.TestCase):
def test_caches_values(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
scaffold1 = monitored_session.Scaffold()
scaffold1.finalize()
scaffold2 = monitored_session.Scaffold()
@@ -119,7 +119,7 @@ class ScaffoldTest(test.TestCase):
def test_raise_error_if_more_than_one_cached_item(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
with self.assertRaisesRegexp(RuntimeError, 'More than one item'):
@@ -127,7 +127,7 @@ class ScaffoldTest(test.TestCase):
def test_uses_passed_values(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
saver = saver_lib.Saver()
scaffold = monitored_session.Scaffold(
init_op=2,
@@ -148,7 +148,7 @@ class ScaffoldTest(test.TestCase):
def test_graph_is_finalized(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
monitored_session.Scaffold().finalize()
with self.assertRaisesRegexp(RuntimeError,
'Graph is finalized and cannot be modified'):
@@ -157,7 +157,7 @@ class ScaffoldTest(test.TestCase):
def test_new_scaffold_from_default_scaffold(self):
scaffold1 = monitored_session.Scaffold()
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
saver = saver_lib.Saver()
scaffold2 = monitored_session.Scaffold(
init_op=2,
@@ -180,7 +180,7 @@ class ScaffoldTest(test.TestCase):
def test_new_scaffold_from_existing_scaffold(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
saver = saver_lib.Saver()
scaffold1 = monitored_session.Scaffold(
init_op=2,
@@ -1374,7 +1374,7 @@ class MonitoredSessionTest(test.TestCase):
def test_defaults(self):
with ops.Graph().as_default():
- a_var = variables.Variable(0)
+ a_var = variables.VariableV1(0)
with monitored_session.MonitoredSession() as session:
self.assertEqual(0, session.run(a_var))
@@ -1700,7 +1700,7 @@ class MonitoredSessionTest(test.TestCase):
def test_graph_finalized_during_run_unfinalized_after_exit(self):
with ops.Graph().as_default() as g:
- a_var = variables.Variable(0)
+ a_var = variables.VariableV1(0)
with monitored_session.MonitoredSession() as session:
self.assertEqual(0, session.run(a_var))
self.assertTrue(g.finalized)
@@ -1708,7 +1708,7 @@ class MonitoredSessionTest(test.TestCase):
def test_keep_finalized_graph_as_finalized(self):
with ops.Graph().as_default() as g:
- a_var = variables.Variable(0)
+ a_var = variables.VariableV1(0)
monitored_session.Scaffold().finalize()
with monitored_session.MonitoredSession() as session:
self.assertEqual(0, session.run(a_var))
@@ -2032,7 +2032,7 @@ class MonitoredSessionTest(test.TestCase):
with ops.Graph().as_default():
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
- graph_state = variables.Variable(0.0)
+ graph_state = variables.VariableV1(0.0)
graph_side_effect = state_ops.assign_add(graph_state, 0.31)
def step_fn(step_context):
@@ -2088,7 +2088,7 @@ class MonitoredSessionTest(test.TestCase):
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
vv = constant_op.constant(3.2)
- graph_state = variables.Variable(0.0)
+ graph_state = variables.VariableV1(0.0)
graph_side_effect = state_ops.assign_add(graph_state, 0.31)
class Hook(session_run_hook.SessionRunHook):
@@ -2125,7 +2125,7 @@ class SingularMonitoredSessionTest(test.TestCase):
def test_handles_initialization(self):
with ops.Graph().as_default():
- a_var = variables.Variable(0)
+ a_var = variables.VariableV1(0)
with monitored_session.SingularMonitoredSession() as session:
# If it's not initialized, following statement raises an error.
self.assertEqual(0, session.run(a_var))
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 699162b30c..30b0ed20c8 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -471,7 +471,10 @@ class Optimizer(
if var_list is None:
var_list = tape.watched_variables()
- grads = tape.gradient(loss_value, var_list, grad_loss)
+ # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
+ # to be executed.
+ with ops.control_dependencies([loss_value]):
+ grads = tape.gradient(loss_value, var_list, grad_loss)
return list(zip(grads, var_list))
# Non-callable/Tensor loss case
@@ -585,7 +588,7 @@ class Optimizer(
var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
if not var_list:
raise ValueError("No gradients provided for any variable: %s." %
- ([str(v) for _, _, v in converted_grads_and_vars],))
+ ([str(v) for _, v, _ in converted_grads_and_vars],))
with ops.init_scope():
self._create_slots(var_list)
update_ops = []
diff --git a/tensorflow/python/training/quantize_training_test.py b/tensorflow/python/training/quantize_training_test.py
index 9754adea85..6edbf7665f 100644
--- a/tensorflow/python/training/quantize_training_test.py
+++ b/tensorflow/python/training/quantize_training_test.py
@@ -58,7 +58,8 @@ class PywrapQuantizeTrainingTest(test.TestCase):
g = ops.Graph()
with session.Session(graph=g) as sess:
a = constant_op.constant(6.0, shape=[1, 1], name='a')
- b = variables.Variable(constant_op.constant(7.0, shape=[1, 1]), name='b')
+ b = variables.VariableV1(
+ constant_op.constant(7.0, shape=[1, 1]), name='b')
c = math_ops.matmul(a, b, name='matmul')
init_op = variables.global_variables_initializer()
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py
index 9b9e28af2b..15fe42bbd8 100644
--- a/tensorflow/python/training/queue_runner_test.py
+++ b/tensorflow/python/training/queue_runner_test.py
@@ -44,7 +44,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -64,9 +64,9 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var0 = variables.Variable(zero64)
+ var0 = variables.VariableV1(zero64)
count_up_to_3 = var0.count_up_to(3)
- var1 = variables.Variable(zero64)
+ var1 = variables.VariableV1(zero64)
count_up_to_30 = var1.count_up_to(30)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
qr = queue_runner_impl.QueueRunner(queue, [count_up_to_3, count_up_to_30])
@@ -131,7 +131,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -184,7 +184,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
with session.Session() as other_sess:
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -199,7 +199,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -215,7 +215,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -250,7 +250,7 @@ class QueueRunnerTest(test.TestCase):
def testStartQueueRunners(self):
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
@@ -267,7 +267,7 @@ class QueueRunnerTest(test.TestCase):
def testStartQueueRunnersRaisesIfNotASession(self):
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
@@ -280,7 +280,7 @@ class QueueRunnerTest(test.TestCase):
def testStartQueueRunnersIgnoresMonitoredSession(self):
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
@@ -297,7 +297,7 @@ class QueueRunnerTest(test.TestCase):
graph = ops.Graph()
with graph.as_default():
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 69b1055ebe..49e6e6546d 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -311,8 +311,8 @@ class SaverTest(test.TestCase):
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(20.0, name="v1")
+ v0 = variables.VariableV1(10.0, name="v0")
+ v1 = variables.VariableV1(20.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
save = saver_module.Saver(
@@ -350,8 +350,8 @@ class SaverTest(test.TestCase):
# Start a second session. In that session the parameter nodes
# have not been initialized either.
with self.cached_session() as sess:
- v0 = variables.Variable(-1.0, name="v0")
- v1 = variables.Variable(-1.0, name="v1")
+ v0 = variables.VariableV1(-1.0, name="v0")
+ v1 = variables.VariableV1(-1.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
@@ -370,7 +370,7 @@ class SaverTest(test.TestCase):
self.assertEqual(30.0, v2.values().eval())
def testFilenameTensor(self):
- v0 = variables.Variable(0, name="v0")
+ v0 = variables.VariableV1(0, name="v0")
filename = b"somerandomfilename"
save = saver_module.Saver({"v0": v0}, filename=filename)
with self.cached_session() as sess:
@@ -379,7 +379,7 @@ class SaverTest(test.TestCase):
self.assertEqual(sess.run(tensor), filename)
def testInvalidPath(self):
- v0 = variables.Variable(0, name="v0")
+ v0 = variables.VariableV1(0, name="v0")
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
with self.cached_session() as sess:
save = saver_module.Saver({"v0": v0}, write_version=ver)
@@ -392,7 +392,7 @@ class SaverTest(test.TestCase):
with self.cached_session() as sess:
# Build a graph with 1 node, and save and restore for them.
- v = variables.Variable(np.int64(15), name="v")
+ v = variables.VariableV1(np.int64(15), name="v")
save = saver_module.Saver({"v": v}, restore_sequentially=True)
variables.global_variables_initializer().run()
@@ -402,7 +402,7 @@ class SaverTest(test.TestCase):
self.assertEqual(save_path, val)
with self.cached_session() as sess:
- v = variables.Variable(np.int64(-1), name="v")
+ v = variables.VariableV1(np.int64(-1), name="v")
save = saver_module.Saver({"v": v})
with self.assertRaisesWithPredicateMatch(
@@ -416,9 +416,9 @@ class SaverTest(test.TestCase):
def testSomeErrors(self):
with ops_lib.Graph().as_default():
- v0 = variables.Variable([10.0], name="v0")
- v1 = variables.Variable([20.0], name="v1")
- v2 = variables.Variable([20.0], name="v2")
+ v0 = variables.VariableV1([10.0], name="v0")
+ v1 = variables.VariableV1([20.0], name="v1")
+ v2 = variables.VariableV1([20.0], name="v2")
v2._set_save_slice_info(
variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))
@@ -446,7 +446,7 @@ class SaverTest(test.TestCase):
def testSameName(self):
with ops_lib.Graph().as_default():
- v0 = variables.Variable([10.0], name="v0")
+ v0 = variables.VariableV1([10.0], name="v0")
v2 = saver_test_utils.CheckpointedOp(name="v2")
# Saving one variable under two names raises an error.
@@ -468,8 +468,8 @@ class SaverTest(test.TestCase):
with self.session(graph=ops_lib.Graph()) as sess:
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(20.0, name="v1")
+ v0 = variables.VariableV1(10.0, name="v0")
+ v1 = variables.VariableV1(20.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
save = saver_module.Saver([v0, v1, v2.saveable])
@@ -490,8 +490,8 @@ class SaverTest(test.TestCase):
# Start a second session. In that session the variables
# have not been initialized either.
with self.session(graph=ops_lib.Graph()) as sess:
- v0 = variables.Variable(-1.0, name="v0")
- v1 = variables.Variable(-1.0, name="v1")
+ v0 = variables.VariableV1(-1.0, name="v0")
+ v1 = variables.VariableV1(-1.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
save = saver_module.Saver([v0, v1, v2.saveable])
@@ -515,8 +515,8 @@ class SaverTest(test.TestCase):
# Build another graph with 2 nodes, initialized
# differently, and a Restore node for them.
with self.session(graph=ops_lib.Graph()) as sess:
- v0_2 = variables.Variable(1000.0, name="v0")
- v1_2 = variables.Variable(2000.0, name="v1")
+ v0_2 = variables.VariableV1(1000.0, name="v0")
+ v1_2 = variables.VariableV1(2000.0, name="v1")
v2_2 = saver_test_utils.CheckpointedOp(name="v2")
save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable])
v2_2.insert("k1000", 3000.0).run()
@@ -574,14 +574,14 @@ class SaverTest(test.TestCase):
save_path = os.path.join(self.get_temp_dir(), "gpu")
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
- v0_1 = variables.Variable(123.45)
+ v0_1 = variables.VariableV1(123.45)
save = saver_module.Saver({"v0": v0_1})
variables.global_variables_initializer().run()
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
- v0_2 = variables.Variable(543.21)
+ v0_2 = variables.VariableV1(543.21)
save = saver_module.Saver({"v0": v0_2})
variables.global_variables_initializer().run()
@@ -591,22 +591,22 @@ class SaverTest(test.TestCase):
save_path = os.path.join(self.get_temp_dir(), "gpu")
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
- v0_1 = variables.Variable(123.45)
+ v0_1 = variables.VariableV1(123.45)
save = saver_module.Saver({"v0": v0_1}, sharded=True, allow_empty=True)
variables.global_variables_initializer().run()
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
- v0_2 = variables.Variable(543.21)
+ v0_2 = variables.VariableV1(543.21)
save = saver_module.Saver({"v0": v0_2}, sharded=True, allow_empty=True)
variables.global_variables_initializer().run()
def testVariables(self):
save_path = os.path.join(self.get_temp_dir(), "variables")
with session.Session("", graph=ops_lib.Graph()) as sess:
- one = variables.Variable(1.0)
- twos = variables.Variable([2.0, 2.0, 2.0])
+ one = variables.VariableV1(1.0)
+ twos = variables.VariableV1([2.0, 2.0, 2.0])
v2 = saver_test_utils.CheckpointedOp(name="v2")
init = variables.global_variables_initializer()
save = saver_module.Saver()
@@ -615,8 +615,8 @@ class SaverTest(test.TestCase):
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
- one = variables.Variable(0.0)
- twos = variables.Variable([0.0, 0.0, 0.0])
+ one = variables.VariableV1(0.0)
+ twos = variables.VariableV1([0.0, 0.0, 0.0])
v2 = saver_test_utils.CheckpointedOp(name="v2")
# Saver with no arg, defaults to 'all variables'.
save = saver_module.Saver()
@@ -628,14 +628,14 @@ class SaverTest(test.TestCase):
def testVarListShouldBeEmptyInDeferredBuild(self):
with ops_lib.Graph().as_default():
- v = variables.Variable(1.0)
+ v = variables.VariableV1(1.0)
with self.assertRaisesRegexp(ValueError, "defer_build"):
saver_module.Saver([v], defer_build=True)
def testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self):
save_path = os.path.join(self.get_temp_dir(), "error_deferred_build")
with ops_lib.Graph().as_default(), session.Session() as sess:
- variables.Variable(1.0)
+ variables.VariableV1(1.0)
saver = saver_module.Saver(defer_build=True)
with self.assertRaisesRegexp(RuntimeError, "build"):
saver.save(sess, save_path)
@@ -643,18 +643,18 @@ class SaverTest(test.TestCase):
def testDeferredBuild(self):
save_path = os.path.join(self.get_temp_dir(), "deferred_build")
with session.Session("", graph=ops_lib.Graph()) as sess:
- one = variables.Variable(1.0)
+ one = variables.VariableV1(1.0)
save = saver_module.Saver(defer_build=True)
# if build is not deferred, saver cannot save the `twos`.
- twos = variables.Variable([2.0, 2.0, 2.0])
+ twos = variables.VariableV1([2.0, 2.0, 2.0])
init = variables.global_variables_initializer()
save.build()
init.run()
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
- one = variables.Variable(0.0)
- twos = variables.Variable([0.0, 0.0, 0.0])
+ one = variables.VariableV1(0.0)
+ twos = variables.VariableV1([0.0, 0.0, 0.0])
# Saver with no arg, defaults to 'all variables'.
save = saver_module.Saver()
save.restore(sess, save_path)
@@ -664,7 +664,7 @@ class SaverTest(test.TestCase):
def testReshape(self):
save_path = os.path.join(self.get_temp_dir(), "variables_reshape")
with session.Session("", graph=ops_lib.Graph()) as sess:
- var = variables.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ var = variables.VariableV1([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
init = variables.global_variables_initializer()
save = saver_module.Saver()
init.run()
@@ -672,7 +672,7 @@ class SaverTest(test.TestCase):
# Error when restoring with default reshape=False
with session.Session("", graph=ops_lib.Graph()) as sess:
- var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
+ var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
save = saver_module.Saver()
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
@@ -681,7 +681,7 @@ class SaverTest(test.TestCase):
# Restored to new shape with reshape=True
with session.Session("", graph=ops_lib.Graph()) as sess:
- var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
+ var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
save = saver_module.Saver(reshape=True)
save.restore(sess, save_path)
self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], var.eval())
@@ -731,8 +731,8 @@ class SaverTest(test.TestCase):
for save_path in paths:
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(20.0, name="v1")
+ v0 = variables.VariableV1(10.0, name="v0")
+ v1 = variables.VariableV1(20.0, name="v1")
save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
init_all_op = variables.global_variables_initializer()
@@ -770,8 +770,8 @@ class SaverTest(test.TestCase):
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(20.0, name="v1")
+ v0 = variables.VariableV1(10.0, name="v0")
+ v1 = variables.VariableV1(20.0, name="v1")
save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
init_all_op = variables.global_variables_initializer()
@@ -859,10 +859,10 @@ class SaveRestoreShardedTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(10, name="v0")
+ v0 = variables.VariableV1(10, name="v0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
with sess.graph.device("/cpu:1"):
- v1 = variables.Variable(20, name="v1")
+ v1 = variables.VariableV1(20, name="v1")
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
@@ -890,7 +890,7 @@ class SaveRestoreShardedTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(111, name="v0")
+ v0 = variables.VariableV1(111, name="v0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
save = saver_module.Saver(
{
@@ -914,7 +914,7 @@ class SaveRestoreShardedTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v1 = variables.Variable(222)
+ v1 = variables.VariableV1(222)
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
@@ -938,10 +938,10 @@ class SaveRestoreShardedTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(111, name="v0")
+ v0 = variables.VariableV1(111, name="v0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
with sess.graph.device("/cpu:1"):
- v1 = variables.Variable(222, name="v1")
+ v1 = variables.VariableV1(222, name="v1")
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
@@ -984,7 +984,7 @@ class SaveRestoreShardedTest(test.TestCase):
def testSaverDef(self):
with self.cached_session():
- v0 = variables.Variable(123, name="v0")
+ v0 = variables.VariableV1(123, name="v0")
save = saver_module.Saver({"v0": v0}, sharded=True)
sd = save.as_saver_def()
self.assertTrue(sd.sharded)
@@ -1023,7 +1023,7 @@ class SaveRestoreShardedTest(test.TestCase):
if use_resource:
vs = [resource_variable_ops.ResourceVariable(rnd, name=var_name)]
else:
- vs = [variables.Variable(rnd, name=var_name)]
+ vs = [variables.VariableV1(rnd, name=var_name)]
variables.global_variables_initializer().run()
if call_saver_with_dict:
@@ -1054,7 +1054,7 @@ class SaveRestoreShardedTest(test.TestCase):
]
else:
new_vs = [
- variables.Variable(
+ variables.VariableV1(
array_ops.zeros(
shape=var_full_shape), # != original contents.
name=var_name)
@@ -1210,7 +1210,7 @@ class MaxToKeepTest(test.TestCase):
save_dir = self._get_test_dir("max_to_keep_non_sharded")
with self.cached_session() as sess:
- v = variables.Variable(10.0, name="v")
+ v = variables.VariableV1(10.0, name="v")
save = saver_module.Saver({"v": v}, max_to_keep=2)
variables.global_variables_initializer().run()
self.assertEqual([], save.last_checkpoints)
@@ -1389,9 +1389,9 @@ class MaxToKeepTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(111, name="v0")
+ v0 = variables.VariableV1(111, name="v0")
with sess.graph.device("/cpu:1"):
- v1 = variables.Variable(222, name="v1")
+ v1 = variables.VariableV1(222, name="v1")
save = saver_module.Saver(
{
"v0": v0,
@@ -1448,7 +1448,7 @@ class MaxToKeepTest(test.TestCase):
save_dir2 = self._get_test_dir("max_to_keep_0")
with self.cached_session() as sess:
- v = variables.Variable(10.0, name="v")
+ v = variables.VariableV1(10.0, name="v")
variables.global_variables_initializer().run()
# Test max_to_keep being None.
@@ -1475,7 +1475,7 @@ class MaxToKeepTest(test.TestCase):
save_dir = self._get_test_dir("no_meta_graph")
with self.cached_session() as sess:
- v = variables.Variable(10.0, name="v")
+ v = variables.VariableV1(10.0, name="v")
save = saver_module.Saver({"v": v})
variables.global_variables_initializer().run()
@@ -1632,13 +1632,13 @@ class MetaGraphTest(test.TestCase):
filename = os.path.join(test_dir, "metafile")
with self.cached_session():
# Creates a graph.
- v0 = variables.Variable(1.0, name="v0")
+ v0 = variables.VariableV1(1.0, name="v0")
control_flow_ops.cond(
math_ops.less(v0, 10), lambda: math_ops.add(v0, 1),
lambda: math_ops.subtract(v0, 1))
control_flow_ops.while_loop(lambda i: math_ops.less(i, 10),
lambda i: math_ops.add(i, 1), [v0])
- var = variables.Variable(constant_op.constant(0, dtype=dtypes.int64))
+ var = variables.VariableV1(constant_op.constant(0, dtype=dtypes.int64))
count_up_to = var.count_up_to(3)
input_queue = data_flow_ops.FIFOQueue(
30, dtypes.float32, shared_name="collection_queue")
@@ -1687,7 +1687,7 @@ class MetaGraphTest(test.TestCase):
def testAddCollectionDefFails(self):
with self.cached_session():
# Creates a graph.
- v0 = variables.Variable(10.0, name="v0")
+ v0 = variables.VariableV1(10.0, name="v0")
# Creates a saver.
save = saver_module.Saver({"v0": v0})
# Generates MetaGraphDef.
@@ -1711,8 +1711,8 @@ class MetaGraphTest(test.TestCase):
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
with self.session(graph=ops_lib.Graph()) as sess:
# Creates a graph.
- v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
- v1 = variables.Variable(11.0, name="v1")
+ v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
+ v1 = variables.VariableV1(11.0, name="v1")
# Creates 2 savers.
saver0 = saver_module.Saver({"v0": v0}, name="saver0")
saver1 = saver_module.Saver({"v1": v1}, name="saver1")
@@ -1788,8 +1788,8 @@ class MetaGraphTest(test.TestCase):
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
with self.session(graph=ops_lib.Graph()) as sess:
# Creates a graph.
- v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
- v1 = variables.Variable(11.0, name="v1")
+ v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
+ v1 = variables.VariableV1(11.0, name="v1")
# Creates 2 savers.
saver0 = saver_module.Saver({"v0": v0}, name="saver0")
@@ -1840,7 +1840,7 @@ class MetaGraphTest(test.TestCase):
filename = os.path.join(test_dir, "metafile")
with self.session(graph=ops_lib.Graph()):
# Creates a graph.
- variables.Variable(10.0, name="v0")
+ variables.VariableV1(10.0, name="v0")
# Exports the graph as binary format.
saver_module.export_meta_graph(filename, as_text=False)
with self.session(graph=ops_lib.Graph()):
@@ -1871,8 +1871,8 @@ class MetaGraphTest(test.TestCase):
test_dir = self._get_test_dir("slice_saver")
filename = os.path.join(test_dir, "metafile")
with self.cached_session():
- v1 = variables.Variable([20.0], name="v1")
- v2 = variables.Variable([20.0], name="v2")
+ v1 = variables.VariableV1([20.0], name="v1")
+ v2 = variables.VariableV1([20.0], name="v2")
v2._set_save_slice_info(
variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))
@@ -1899,7 +1899,7 @@ class MetaGraphTest(test.TestCase):
# Hidden 1
images = constant_op.constant(1.2, dtypes.float32, shape=[100, 28])
with ops_lib.name_scope("hidden1"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[28, 128], stddev=1.0 / math.sqrt(float(28))),
name="weights")
@@ -1907,7 +1907,7 @@ class MetaGraphTest(test.TestCase):
# the save and restore of control flow context (which doesn't make any
# sense here from a machine learning perspective). The typical biases is
# a simple Variable without the conditions.
- biases = variables.Variable(
+ biases = variables.VariableV1(
control_flow_ops.cond(
math_ops.less(random.random(), 0.5),
lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
@@ -1915,7 +1915,7 @@ class MetaGraphTest(test.TestCase):
hidden1 = nn_ops.relu(math_ops.matmul(images, weights) + biases)
# Hidden 2
with ops_lib.name_scope("hidden2"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[128, 32], stddev=1.0 / math.sqrt(float(128))),
name="weights")
@@ -1933,15 +1933,16 @@ class MetaGraphTest(test.TestCase):
_, biases = control_flow_ops.while_loop(
loop_cond, loop_body,
- [constant_op.constant(0), variables.Variable(array_ops.zeros([32]))])
+ [constant_op.constant(0),
+ variables.VariableV1(array_ops.zeros([32]))])
hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
# Linear
with ops_lib.name_scope("softmax_linear"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[32, 10], stddev=1.0 / math.sqrt(float(32))),
name="weights")
- biases = variables.Variable(array_ops.zeros([10]), name="biases")
+ biases = variables.VariableV1(array_ops.zeros([10]), name="biases")
logits = math_ops.matmul(hidden2, weights) + biases
ops_lib.add_to_collection("logits", logits)
init_all_op = variables.global_variables_initializer()
@@ -2028,7 +2029,7 @@ class MetaGraphTest(test.TestCase):
# Create while loop using `outer_body_fn`.
with ops_lib.Graph().as_default():
- var = variables.Variable(0.0)
+ var = variables.VariableV1(0.0)
var_name = var.name
output = graph_fn(var)
output_name = output.name
@@ -2122,8 +2123,8 @@ class MetaGraphTest(test.TestCase):
def testStrippedOpListDef(self):
with self.cached_session():
# Creates a graph.
- v0 = variables.Variable(0.0)
- var = variables.Variable(10.0)
+ v0 = variables.VariableV1(0.0)
+ var = variables.VariableV1(10.0)
math_ops.add(v0, var)
@function.Defun(dtypes.float32)
@@ -2161,8 +2162,8 @@ class MetaGraphTest(test.TestCase):
# With strip_default_attrs enabled, attributes "T" (float32) and "Tout"
# (complex64) in the "Complex" op must be removed.
with self.cached_session():
- real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
- imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
@@ -2178,8 +2179,8 @@ class MetaGraphTest(test.TestCase):
# (complex64) in the "Complex" op must *not* be removed, even if they map
# to their defaults.
with self.session(graph=ops_lib.Graph()):
- real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
- imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
@@ -2198,9 +2199,9 @@ class MetaGraphTest(test.TestCase):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
with session.Session() as sess:
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
- bias = variables.Variable(array_ops.zeros([10]), name="bias")
+ bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
@@ -2243,7 +2244,7 @@ class MetaGraphTest(test.TestCase):
self.assertIsNone(new_saver_1)
# Create a variable in graph_2 under scope "my_scope".
- variables.Variable(array_ops.zeros([10]), name="my_scope/my_var")
+ variables.VariableV1(array_ops.zeros([10]), name="my_scope/my_var")
sess.run(variables.global_variables_initializer())
# Restore the checkpoint into a different scope "subgraph_2".
new_saver_2 = saver_module.import_meta_graph(
@@ -2268,9 +2269,9 @@ class MetaGraphTest(test.TestCase):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
with session.Session() as sess:
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
- bias = variables.Variable(array_ops.zeros([10]), name="bias")
+ bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
@@ -2299,9 +2300,9 @@ class MetaGraphTest(test.TestCase):
with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
- bias = variables.Variable(array_ops.zeros([10]), name="bias")
+ bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
@@ -2332,9 +2333,9 @@ class MetaGraphTest(test.TestCase):
with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
- bias = variables.Variable(array_ops.zeros([10]), name="bias")
+ bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
@@ -2385,9 +2386,9 @@ class CheckpointReaderTest(test.TestCase):
def testDebugString(self):
# Builds a graph.
- v0 = variables.Variable(
+ v0 = variables.VariableV1(
[[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
- v1 = variables.Variable(
+ v1 = variables.VariableV1(
[[[1], [2]], [[3], [4]], [[5], [6]]], dtype=dtypes.float32, name="v1")
init_all_op = variables.global_variables_initializer()
save = saver_module.Saver(
@@ -2444,7 +2445,8 @@ class WriteGraphTest(test.TestCase):
def testWriteGraph(self):
test_dir = self._get_test_dir("write_graph_dir")
- variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
+ variables.VariableV1(
+ [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
path = graph_io.write_graph(ops_lib.get_default_graph(),
os.path.join(test_dir, "l1"), "graph.pbtxt")
truth = os.path.join(test_dir, "l1", "graph.pbtxt")
@@ -2453,7 +2455,8 @@ class WriteGraphTest(test.TestCase):
def testRecursiveCreate(self):
test_dir = self._get_test_dir("deep_dir")
- variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
+ variables.VariableV1(
+ [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
os.path.join(test_dir, "l1", "l2", "l3"),
"graph.pbtxt")
@@ -2477,7 +2480,7 @@ class ScopedGraphTest(test.TestCase):
images = constant_op.constant(
1.2, dtypes.float32, shape=[100, 28], name="images")
with ops_lib.name_scope("hidden1"):
- weights1 = variables.Variable(
+ weights1 = variables.VariableV1(
random_ops.truncated_normal(
[28, 128], stddev=1.0 / math.sqrt(float(28))),
name="weights")
@@ -2485,7 +2488,7 @@ class ScopedGraphTest(test.TestCase):
# coverage the save and restore of control flow context (which doesn't
# make any sense here from a machine learning perspective). The typical
# biases is a simple Variable without the conditions.
- biases1 = variables.Variable(
+ biases1 = variables.VariableV1(
control_flow_ops.cond(
math_ops.less(random.random(), 0.5),
lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
@@ -2494,7 +2497,7 @@ class ScopedGraphTest(test.TestCase):
# Hidden 2
with ops_lib.name_scope("hidden2"):
- weights2 = variables.Variable(
+ weights2 = variables.VariableV1(
random_ops.truncated_normal(
[128, 32], stddev=1.0 / math.sqrt(float(128))),
name="weights")
@@ -2511,16 +2514,16 @@ class ScopedGraphTest(test.TestCase):
return it + 1, biases2
_, biases2 = control_flow_ops.while_loop(loop_cond, loop_body, [
- constant_op.constant(0), variables.Variable(array_ops.zeros([32]))
+ constant_op.constant(0), variables.VariableV1(array_ops.zeros([32]))
])
hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2)
# Linear
with ops_lib.name_scope("softmax_linear"):
- weights3 = variables.Variable(
+ weights3 = variables.VariableV1(
random_ops.truncated_normal(
[32, 10], stddev=1.0 / math.sqrt(float(32))),
name="weights")
- biases3 = variables.Variable(array_ops.zeros([10]), name="biases")
+ biases3 = variables.VariableV1(array_ops.zeros([10]), name="biases")
logits = math_ops.matmul(hidden2, weights3) + biases3
ops_lib.add_to_collection("logits", logits)
@@ -2566,7 +2569,7 @@ class ScopedGraphTest(test.TestCase):
with graph.as_default():
# Hidden 2
with ops_lib.name_scope("hidden2"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[128, 32], stddev=1.0 / math.sqrt(float(128))),
name="weights")
@@ -2583,16 +2586,16 @@ class ScopedGraphTest(test.TestCase):
return it + 1, biases
_, biases = control_flow_ops.while_loop(loop_cond, loop_body, [
- constant_op.constant(0), variables.Variable(array_ops.zeros([32]))
+ constant_op.constant(0), variables.VariableV1(array_ops.zeros([32]))
])
hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
# Linear
with ops_lib.name_scope("softmax_linear"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[32, 10], stddev=1.0 / math.sqrt(float(32))),
name="weights")
- biases = variables.Variable(array_ops.zeros([10]), name="biases")
+ biases = variables.VariableV1(array_ops.zeros([10]), name="biases")
logits = math_ops.matmul(hidden2, weights) + biases
ops_lib.add_to_collection("logits", logits)
@@ -2629,9 +2632,9 @@ class ScopedGraphTest(test.TestCase):
with ops_lib.name_scope("hidden1"):
images = constant_op.constant(
1.0, dtypes.float32, shape=[3, 2], name="images")
- weights1 = variables.Variable(
+ weights1 = variables.VariableV1(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
- biases1 = variables.Variable([0.1] * 3, name="biases")
+ biases1 = variables.VariableV1([0.1] * 3, name="biases")
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
# Run the graph and save scoped checkpoint.
@@ -2685,9 +2688,9 @@ class ScopedGraphTest(test.TestCase):
with ops_lib.name_scope("hidden1"):
images = constant_op.constant(
1.0, dtypes.float32, shape=[3, 2], name="images")
- weights1 = variables.Variable(
+ weights1 = variables.VariableV1(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
- biases1 = variables.Variable([0.1] * 3, name="biases")
+ biases1 = variables.VariableV1([0.1] * 3, name="biases")
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
# Run the graph and save scoped checkpoint.
@@ -2720,12 +2723,12 @@ class ScopedGraphTest(test.TestCase):
graph = ops_lib.Graph()
with graph.as_default():
with ops_lib.name_scope("hidden1"):
- variable1 = variables.Variable([1.0], name="variable1")
+ variable1 = variables.VariableV1([1.0], name="variable1")
saver1 = saver_module.Saver(var_list=[variable1])
graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver1)
with ops_lib.name_scope("hidden2"):
- variable2 = variables.Variable([2.0], name="variable2")
+ variable2 = variables.VariableV1([2.0], name="variable2")
saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/")
graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2)
@@ -2978,7 +2981,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
with ops_lib.Graph().as_default() as g:
- a = variables.Variable(1., name="a")
+ a = variables.VariableV1(1., name="a")
a_saver = saver_module.Saver([a])
with self.session(graph=g) as sess:
@@ -2986,7 +2989,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
with ops_lib.Graph().as_default() as g:
- a = variables.Variable([1.], name="a")
+ a = variables.VariableV1([1.], name="a")
a_saver = saver_module.Saver([a])
with self.session(graph=g) as sess:
with self.assertRaisesRegexp(
diff --git a/tensorflow/python/training/server_lib_same_variables_no_clear_test.py b/tensorflow/python/training/server_lib_same_variables_no_clear_test.py
index c7e84e9ba1..5aa7f45c2b 100644
--- a/tensorflow/python/training/server_lib_same_variables_no_clear_test.py
+++ b/tensorflow/python/training/server_lib_same_variables_no_clear_test.py
@@ -37,8 +37,8 @@ class SameVariablesNoClearTest(test.TestCase):
server = server_lib.Server.create_local_server()
with session.Session(server.target) as sess_1:
- v0 = variables.Variable([[2, 1]], name="v0")
- v1 = variables.Variable([[1], [2]], name="v1")
+ v0 = variables.VariableV1([[2, 1]], name="v0")
+ v1 = variables.VariableV1([[1], [2]], name="v1")
v2 = math_ops.matmul(v0, v1)
sess_1.run([v0.initializer, v1.initializer])
self.assertAllEqual([[4]], sess_1.run(v2))
diff --git a/tensorflow/python/training/server_lib_test.py b/tensorflow/python/training/server_lib_test.py
index 063044f0d0..cf995707fc 100644
--- a/tensorflow/python/training/server_lib_test.py
+++ b/tensorflow/python/training/server_lib_test.py
@@ -76,9 +76,9 @@ class GrpcServerTest(test.TestCase):
def testResetFails(self):
# Creates variable with container name.
with ops.container("test0"):
- v0 = variables.Variable(1.0, name="v0")
+ v0 = variables.VariableV1(1.0, name="v0")
# Creates variable with default container.
- v1 = variables.Variable(2.0, name="v1")
+ v1 = variables.VariableV1(2.0, name="v1")
# Verifies resetting the non-existent target returns error.
with self.assertRaises(errors_impl.NotFoundError):
session.Session.reset("nonexistent", ["test0"])
@@ -234,8 +234,8 @@ class GrpcServerTest(test.TestCase):
[0.], dtype=dtypes.float32))
self.assertIsNotNone(input_queue)
- var = variables.Variable(1., dtype=dtypes.float32, trainable=False,
- name="var")
+ var = variables.VariableV1(1., dtype=dtypes.float32, trainable=False,
+ name="var")
sess.run(variables.global_variables_initializer())
queue_runner_impl.start_queue_runners(sess)
@@ -245,7 +245,7 @@ class GrpcServerTest(test.TestCase):
server = self._cached_server
init_value = array_ops.placeholder(dtypes.int32)
- v = variables.Variable(init_value, validate_shape=False, name="v")
+ v = variables.VariableV1(init_value, validate_shape=False, name="v")
sharing_config = config_pb2.ConfigProto(isolate_session_state=False)
sharing_sess_0 = session.Session(server.target, config=sharing_config)
@@ -302,7 +302,7 @@ class GrpcServerTest(test.TestCase):
isolate_config = config_pb2.ConfigProto(isolate_session_state=True)
with ops.Graph().as_default():
- w_vector = variables.Variable([1, 2, 3], name="w")
+ w_vector = variables.VariableV1([1, 2, 3], name="w")
with session.Session(server.target, config=sharing_config) as sess:
with self.assertRaises(errors_impl.FailedPreconditionError):
sess.run(w_vector)
@@ -310,20 +310,20 @@ class GrpcServerTest(test.TestCase):
self.assertAllEqual([1, 2, 3], sess.run(w_vector))
with ops.Graph().as_default():
- w_vector = variables.Variable([4, 5, 6], name="w")
+ w_vector = variables.VariableV1([4, 5, 6], name="w")
with session.Session(server.target, config=sharing_config) as sess:
self.assertAllEqual([1, 2, 3], sess.run(w_vector))
sess.run(w_vector.initializer)
self.assertAllEqual([4, 5, 6], sess.run(w_vector))
with ops.Graph().as_default():
- w_scalar = variables.Variable(86, name="w")
+ w_scalar = variables.VariableV1(86, name="w")
with session.Session(server.target, config=sharing_config) as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run(w_scalar.initializer)
with ops.Graph().as_default():
- w_scalar = variables.Variable(37, name="w")
+ w_scalar = variables.VariableV1(37, name="w")
with session.Session(server.target, config=isolate_config) as sess:
with self.assertRaises(errors_impl.FailedPreconditionError):
sess.run(w_scalar)
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index a2e0645ba8..5e4749f306 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util.tf_export import tf_export
@@ -182,6 +183,10 @@ class SessionManager(object):
"""
self._target = master
sess = session.Session(self._target, graph=self._graph, config=config)
+ # TODO(jhseu): Delete once tpu.initialize_system() goes away.
+ sess.run(
+ distribution_strategy_context.get_distribution_strategy().initialize()
+ )
if checkpoint_dir and checkpoint_filename_with_path:
raise ValueError("Can not provide both checkpoint_dir and "
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py
index f1d18f7704..2b5c3b01de 100644
--- a/tensorflow/python/training/session_manager_test.py
+++ b/tensorflow/python/training/session_manager_test.py
@@ -40,7 +40,7 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionSucceeds(self):
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
sess = sm.prepare_session(
@@ -50,7 +50,7 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionSucceedsWithInitFeedDict(self):
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
- v = variables.Variable(p, name="v")
+ v = variables.VariableV1(p, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
sess = sm.prepare_session(
@@ -61,7 +61,7 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionSucceedsWithInitFn(self):
with ops.Graph().as_default():
- v = variables.Variable([125], name="v")
+ v = variables.VariableV1([125], name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
sess = sm.prepare_session(
@@ -79,7 +79,7 @@ class SessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
saver = saver_lib.Saver({"v": v})
@@ -97,7 +97,7 @@ class SessionManagerTest(test.TestCase):
# Renames the checkpoint directory.
os.rename(checkpoint_dir, checkpoint_dir2)
gfile.MakeDirs(checkpoint_dir)
- v = variables.Variable([6.0, 7.0, 8.0], name="v")
+ v = variables.VariableV1([6.0, 7.0, 8.0], name="v")
with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
session_manager.SessionManager(
@@ -134,7 +134,7 @@ class SessionManagerTest(test.TestCase):
checkpoint_filename_with_path=None):
# Create a new Graph and SessionManager and recover from a checkpoint.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
+ v = variables.VariableV1(2, name="v")
with session_lib.Session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
sm2 = session_manager.SessionManager(
@@ -162,7 +162,7 @@ class SessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
+ v = variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
saver = saver_lib.Saver({"v": v})
@@ -186,7 +186,7 @@ class SessionManagerTest(test.TestCase):
def testWaitForSessionReturnsNoneAfterTimeout(self):
with ops.Graph().as_default():
- variables.Variable(1, name="v")
+ variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables(),
recovery_wait_secs=1)
@@ -217,7 +217,7 @@ class SessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
+ v = variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
saver = saver_lib.Saver({"v": v})
@@ -230,8 +230,8 @@ class SessionManagerTest(test.TestCase):
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(2, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -275,7 +275,7 @@ class SessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
+ v = variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
saver = saver_lib.Saver({"v": v})
@@ -288,8 +288,8 @@ class SessionManagerTest(test.TestCase):
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(2, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -321,7 +321,7 @@ class SessionManagerTest(test.TestCase):
# local_init_op exactly once, regardless of whether the session was
# successfully recovered.
with ops.Graph().as_default():
- w = variables.Variable(
+ w = variables.VariableV1(
1,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -356,8 +356,8 @@ class SessionManagerTest(test.TestCase):
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(2, name="v")
+ w = variables.VariableV1(
1,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -389,8 +389,8 @@ class SessionManagerTest(test.TestCase):
def testWaitForSessionLocalInit(self):
server = server_lib.Server.create_local_server()
with ops.Graph().as_default() as graph:
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -420,8 +420,8 @@ class SessionManagerTest(test.TestCase):
def testWaitForSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
with ops.Graph().as_default() as graph:
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -439,8 +439,8 @@ class SessionManagerTest(test.TestCase):
def testWaitForSessionInsufficientReadyForLocalInitCheck(self):
with ops.Graph().as_default() as graph:
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -456,13 +456,13 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionWithReadyForLocalInitOp(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- x = variables.Variable(
+ x = variables.VariableV1(
3 * v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -495,25 +495,25 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionWithPartialInitOp(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- x = variables.Variable(
+ x = variables.VariableV1(
3 * v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="x")
# TODO(b/70206927): Use ResourceVariables once they are handled properly.
- v_res = variables.Variable(1, name="v_res")
- w_res = variables.Variable(
+ v_res = variables.VariableV1(1, name="v_res")
+ w_res = variables.VariableV1(
v_res,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w_res")
- x_res = variables.Variable(
+ x_res = variables.VariableV1(
3 * v_res,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -565,7 +565,7 @@ class SessionManagerTest(test.TestCase):
# cyclic dependencies.
with ops.Graph().as_default():
i = control_flow_ops.while_loop(lambda i: i < 1, lambda i: i + 1, [0])
- v = variables.Variable(array_ops.identity(i), name="v")
+ v = variables.VariableV1(array_ops.identity(i), name="v")
with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
sm = session_manager.SessionManager(
@@ -579,8 +579,8 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionDidNotInitLocalVariable(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -596,8 +596,8 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionDidNotInitLocalVariableList(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -613,8 +613,8 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionWithReadyNotReadyForLocal(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -634,8 +634,8 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -656,7 +656,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
def testPrepareSessionSucceeds(self):
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
sess = sm.prepare_session(
@@ -666,7 +666,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
def testPrepareSessionSucceedsWithInitFeedDict(self):
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
- v = variables.Variable(p, name="v")
+ v = variables.VariableV1(p, name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
sess = sm.prepare_session(
@@ -677,7 +677,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
def testPrepareSessionSucceedsWithInitFn(self):
with ops.Graph().as_default():
- v = variables.Variable([125], name="v")
+ v = variables.VariableV1([125], name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
sess = sm.prepare_session(
@@ -695,7 +695,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
saver = saver_lib.Saver({"v": v})
@@ -713,7 +713,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
# Renames the checkpoint directory.
os.rename(checkpoint_dir, checkpoint_dir2)
gfile.MakeDirs(checkpoint_dir)
- v = variables.Variable([6.0, 7.0, 8.0], name="v")
+ v = variables.VariableV1([6.0, 7.0, 8.0], name="v")
with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
session_manager.SessionManager(
@@ -755,7 +755,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
+ v = variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
saver = saver_lib.Saver({"v": v})
@@ -768,7 +768,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
+ v = variables.VariableV1(2, name="v")
with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
sm2 = session_manager.SessionManager(
@@ -785,7 +785,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
def testWaitForSessionReturnsNoneAfterTimeout(self):
with ops.Graph().as_default():
- variables.Variable(1, name="v")
+ variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized(),
recovery_wait_secs=1)
diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py
index 0755364bbe..a5e626d320 100644
--- a/tensorflow/python/training/supervisor.py
+++ b/tensorflow/python/training/supervisor.py
@@ -242,10 +242,9 @@ class Supervisor(object):
ready_for_local_init_op: 1-D string `Tensor`. This tensor is evaluated by
supervisors in `prepare_or_wait_for_session()` to check if the model is
ready to run the local_init_op.
- The model is considered ready if it returns an empty array. Defaults to
- the tensor returned from
- `tf.report_uninitialized_variables(tf.global_variables())`. If `None`,
- the model is not checked for readiness before running local_init_op.
+ The model is considered ready if it returns an empty array. Defaults to
+ `None`. If `None`, the model is not checked for readiness before running
+ local_init_op.
is_chief: If True, create a chief supervisor in charge of initializing
and restoring the model. If False, create a supervisor that relies
on a chief supervisor for inits and restore.
diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py
index caf6eba3e0..7cd99d8680 100644
--- a/tensorflow/python/training/supervisor_test.py
+++ b/tensorflow/python/training/supervisor_test.py
@@ -423,7 +423,7 @@ class SupervisorTest(test.TestCase):
def testLogdirButExplicitlyNoSummaryWriter(self):
logdir = self._test_dir("explicit_no_summary_writer")
with ops.Graph().as_default():
- variables.Variable([1.0], name="foo")
+ variables.VariableV1([1.0], name="foo")
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
summary.scalar("c3", constant_op.constant(3))
@@ -491,7 +491,7 @@ class SupervisorTest(test.TestCase):
def testNoLogdirSucceeds(self):
with ops.Graph().as_default():
- variables.Variable([1.0, 2.0, 3.0])
+ variables.VariableV1([1.0, 2.0, 3.0])
sv = supervisor.Supervisor(logdir="", summary_op=None)
sess = sv.prepare_or_wait_for_session("")
sess.close()
@@ -499,7 +499,7 @@ class SupervisorTest(test.TestCase):
def testUseSessionManager(self):
with ops.Graph().as_default():
- variables.Variable([1.0, 2.0, 3.0])
+ variables.VariableV1([1.0, 2.0, 3.0])
sm = session_manager_lib.SessionManager()
# Pass in session_manager. The additional init_op is ignored.
sv = supervisor.Supervisor(logdir="", session_manager=sm)
@@ -508,7 +508,7 @@ class SupervisorTest(test.TestCase):
def testInitOp(self):
logdir = self._test_dir("default_init_op")
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0])
+ v = variables.VariableV1([1.0, 2.0, 3.0])
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session("")
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
@@ -517,7 +517,7 @@ class SupervisorTest(test.TestCase):
def testInitFn(self):
logdir = self._test_dir("default_init_op")
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0])
+ v = variables.VariableV1([1.0, 2.0, 3.0])
def _init_fn(sess):
sess.run(v.initializer)
@@ -531,7 +531,7 @@ class SupervisorTest(test.TestCase):
logdir = self._test_dir("feed_dict_init_op")
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
- v = variables.Variable(p, name="v")
+ v = variables.VariableV1(p, name="v")
sv = supervisor.Supervisor(
logdir=logdir,
init_op=variables.global_variables_initializer(),
@@ -550,10 +550,10 @@ class SupervisorTest(test.TestCase):
g = ops.Graph()
with g.as_default():
with ops.device("/job:local"):
- v = variables.Variable(
+ v = variables.VariableV1(
1, name="default_ready_for_local_init_op_v_" + str(uid))
vadd = v.assign_add(1)
- w = variables.Variable(
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -590,7 +590,7 @@ class SupervisorTest(test.TestCase):
# Create a checkpoint.
with ops.Graph().as_default():
- v = variables.Variable(
+ v = variables.VariableV1(
10.0, name="ready_for_local_init_op_restore_v_" + str(uid))
summary.scalar("ready_for_local_init_op_restore_v_" + str(uid), v)
sv = supervisor.Supervisor(logdir=logdir)
@@ -607,10 +607,10 @@ class SupervisorTest(test.TestCase):
g = ops.Graph()
with g.as_default():
with ops.device("/job:local"):
- v = variables.Variable(
+ v = variables.VariableV1(
1.0, name="ready_for_local_init_op_restore_v_" + str(uid))
vadd = v.assign_add(1)
- w = variables.Variable(
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -642,13 +642,13 @@ class SupervisorTest(test.TestCase):
logdir = self._test_dir("default_local_init_op")
with ops.Graph().as_default():
# A local variable.
- v = variables.Variable(
+ v = variables.VariableV1(
[1.0, 2.0, 3.0],
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES])
# An entity which is initialized through a TABLE_INITIALIZER.
- w = variables.Variable([4, 5, 6], trainable=False, collections=[])
+ w = variables.VariableV1([4, 5, 6], trainable=False, collections=[])
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, w.initializer)
# This shouldn't add a variable to the VARIABLES collection responsible
@@ -668,7 +668,7 @@ class SupervisorTest(test.TestCase):
with ops.Graph().as_default():
with ops.device("/job:localhost"):
# A local variable.
- v = variables.Variable(
+ v = variables.VariableV1(
[1.0, 2.0, 3.0],
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES])
@@ -687,8 +687,8 @@ class SupervisorTest(test.TestCase):
server = server_lib.Server.create_local_server()
logdir = self._test_dir("default_init_op_fails")
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
- variables.Variable([4.0, 5.0, 6.0], name="w")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
+ variables.VariableV1([4.0, 5.0, 6.0], name="w")
# w will not be initialized.
sv = supervisor.Supervisor(logdir=logdir, init_op=v.initializer)
with self.assertRaisesRegexp(RuntimeError,
@@ -699,11 +699,11 @@ class SupervisorTest(test.TestCase):
server = server_lib.Server.create_local_server()
logdir = self._test_dir("default_init_op_fails_for_local_variable")
with ops.Graph().as_default():
- v = variables.Variable(
+ v = variables.VariableV1(
[1.0, 2.0, 3.0],
name="v",
collections=[ops.GraphKeys.LOCAL_VARIABLES])
- variables.Variable(
+ variables.VariableV1(
[1.0, 2.0, 3.0],
name="w",
collections=[ops.GraphKeys.LOCAL_VARIABLES])
@@ -716,17 +716,17 @@ class SupervisorTest(test.TestCase):
def testSetupFail(self):
logdir = self._test_dir("setup_fail")
with ops.Graph().as_default():
- variables.Variable([1.0, 2.0, 3.0], name="v")
+ variables.VariableV1([1.0, 2.0, 3.0], name="v")
with self.assertRaisesRegexp(ValueError, "must have their device set"):
supervisor.Supervisor(logdir=logdir, is_chief=False)
with ops.Graph().as_default(), ops.device("/job:ps"):
- variables.Variable([1.0, 2.0, 3.0], name="v")
+ variables.VariableV1([1.0, 2.0, 3.0], name="v")
supervisor.Supervisor(logdir=logdir, is_chief=False)
def testDefaultGlobalStep(self):
logdir = self._test_dir("default_global_step")
with ops.Graph().as_default():
- variables.Variable(287, name="global_step")
+ variables.VariableV1(287, name="global_step")
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session("")
self.assertEquals(287, sess.run(sv.global_step))
@@ -735,7 +735,7 @@ class SupervisorTest(test.TestCase):
def testRestoreFromMetaGraph(self):
logdir = self._test_dir("restore_from_meta_graph")
with ops.Graph().as_default():
- variables.Variable(1, name="v0")
+ variables.VariableV1(1, name="v0")
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session("")
filename = sv.saver.save(sess, sv.save_path)
@@ -757,7 +757,7 @@ class SupervisorTest(test.TestCase):
logdir = self._test_dir("standard_services_without_global_step")
# Create a checkpoint.
with ops.Graph().as_default():
- v = variables.Variable([1.0], name="foo")
+ v = variables.VariableV1([1.0], name="foo")
summary.scalar("v", v[0])
sv = supervisor.Supervisor(logdir=logdir)
meta_graph_def = meta_graph.create_meta_graph_def(
@@ -796,7 +796,7 @@ class SupervisorTest(test.TestCase):
self.assertRaises(StopIteration, lambda: next(rr))
# There should be a checkpoint file with the variable "foo"
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([10.10], name="foo")
+ v = variables.VariableV1([10.10], name="foo")
sav = saver_lib.Saver([v])
sav.restore(sess, save_path)
self.assertEqual(1.0, v.eval()[0])
@@ -807,7 +807,7 @@ class SupervisorTest(test.TestCase):
logdir = self._test_dir("standard_services_with_global_step")
# Create a checkpoint.
with ops.Graph().as_default():
- v = variables.Variable([123], name="global_step")
+ v = variables.VariableV1([123], name="global_step")
sv = supervisor.Supervisor(logdir=logdir)
meta_graph_def = meta_graph.create_meta_graph_def(
saver_def=sv.saver.saver_def)
@@ -860,7 +860,7 @@ class SupervisorTest(test.TestCase):
self.assertRaises(StopIteration, lambda: next(rr))
# There should be a checkpoint file with the variable "foo"
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([-12], name="global_step")
+ v = variables.VariableV1([-12], name="global_step")
sav = saver_lib.Saver([v])
sav.restore(sess, save_path)
self.assertEqual(123, v.eval()[0])
diff --git a/tensorflow/python/training/sync_replicas_optimizer_test.py b/tensorflow/python/training/sync_replicas_optimizer_test.py
index fff17402e2..1ef8756ef6 100644
--- a/tensorflow/python/training/sync_replicas_optimizer_test.py
+++ b/tensorflow/python/training/sync_replicas_optimizer_test.py
@@ -40,11 +40,12 @@ def get_workers(num_workers, replicas_to_aggregate, workers):
is_chief = (worker_id == 0)
with graph.as_default():
with ops.device("/job:ps/task:0"):
- global_step = variables.Variable(0, name="global_step", trainable=False)
- var_0 = variables.Variable(0.0, name="v0")
+ global_step = variables.VariableV1(
+ 0, name="global_step", trainable=False)
+ var_0 = variables.VariableV1(0.0, name="v0")
with ops.device("/job:ps/task:1"):
- var_1 = variables.Variable(1.0, name="v1")
- var_sparse = variables.Variable([[3.0], [4.0]], name="v_sparse")
+ var_1 = variables.VariableV1(1.0, name="v1")
+ var_sparse = variables.VariableV1([[3.0], [4.0]], name="v_sparse")
with ops.device("/job:worker/task:" + str(worker_id)):
grads_0 = constant_op.constant(0.1 + worker_id * 0.2)
@@ -272,8 +273,8 @@ class SyncReplicasOptimizerHookTest(test.TestCase):
replicas_to_aggregate=1,
total_num_replicas=1)
hook = opt.make_session_run_hook(True)
- v = variables.Variable([0.])
- global_step = variables.Variable(0, name="global_step", trainable=False)
+ v = variables.VariableV1([0.])
+ global_step = variables.VariableV1(0, name="global_step", trainable=False)
opt.minimize(v, global_step=global_step)
hook.begin()
@@ -282,8 +283,8 @@ class SyncReplicasOptimizerHookTest(test.TestCase):
opt=adam.AdamOptimizer(0.01),
replicas_to_aggregate=1,
total_num_replicas=1)
- v = variables.Variable([0.], name="fetch_variable_test")
- global_step = variables.Variable(0, name="global_step", trainable=False)
+ v = variables.VariableV1([0.], name="fetch_variable_test")
+ global_step = variables.VariableV1(0, name="global_step", trainable=False)
opt.minimize(v, global_step=global_step)
opt_variables = opt.variables()
beta1_power, beta2_power = opt._opt._get_beta_accumulators()
diff --git a/tensorflow/python/training/training_ops_test.py b/tensorflow/python/training/training_ops_test.py
index d131a11067..f410ceaaff 100644
--- a/tensorflow/python/training/training_ops_test.py
+++ b/tensorflow/python/training/training_ops_test.py
@@ -51,7 +51,7 @@ class TrainingOpsTest(TensorFlowTestCase):
def _testTypes(self, x, alpha, delta, use_gpu=None):
self.setUp()
with self.test_session(use_gpu=use_gpu):
- var = variables.Variable(x)
+ var = variables.VariableV1(x)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
apply_sgd = training_ops.apply_gradient_descent(var, alpha, delta)
@@ -70,8 +70,8 @@ class TrainingOpsTest(TensorFlowTestCase):
def _testTypesForAdagrad(self, x, y, lr, grad, use_gpu=None):
self.setUp()
with self.test_session(use_gpu=use_gpu):
- var = variables.Variable(x)
- accum = variables.Variable(y)
+ var = variables.VariableV1(x)
+ accum = variables.VariableV1(y)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
@@ -94,9 +94,9 @@ class TrainingOpsTest(TensorFlowTestCase):
lr_power=-0.5):
self.setUp()
with self.test_session(use_gpu=use_gpu):
- var = variables.Variable(x)
- accum = variables.Variable(y)
- linear = variables.Variable(z)
+ var = variables.VariableV1(x)
+ accum = variables.VariableV1(y)
+ linear = variables.VariableV1(z)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
@@ -148,8 +148,8 @@ class TrainingOpsTest(TensorFlowTestCase):
def _testTypesForSparseAdagrad(self, x, y, lr, grad, indices):
self.setUp()
with self.test_session(use_gpu=False):
- var = variables.Variable(x)
- accum = variables.Variable(y)
+ var = variables.VariableV1(x)
+ accum = variables.VariableV1(y)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
@@ -178,9 +178,9 @@ class TrainingOpsTest(TensorFlowTestCase):
lr_power=-0.5):
self.setUp()
with self.test_session(use_gpu=False):
- var = variables.Variable(x)
- accum = variables.Variable(y)
- linear = variables.Variable(z)
+ var = variables.VariableV1(x)
+ accum = variables.VariableV1(y)
+ linear = variables.VariableV1(z)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
@@ -257,9 +257,9 @@ class TrainingOpsTest(TensorFlowTestCase):
def _testTypesForAdam(self, var, m, v, grad, use_gpu):
self.setUp()
with self.test_session(use_gpu=use_gpu):
- var_t = variables.Variable(var)
- m_t = variables.Variable(m)
- v_t = variables.Variable(v)
+ var_t = variables.VariableV1(var)
+ m_t = variables.VariableV1(m)
+ v_t = variables.VariableV1(v)
t = 1
beta1 = np.array(0.9, dtype=var.dtype)
@@ -270,8 +270,8 @@ class TrainingOpsTest(TensorFlowTestCase):
epsilon = np.array(1e-8, dtype=var.dtype)
beta1_t = constant_op.constant(beta1, self._toType(var.dtype), [])
beta2_t = constant_op.constant(beta2, self._toType(var.dtype), [])
- beta1_power_t = variables.Variable(beta1_power)
- beta2_power_t = variables.Variable(beta2_power)
+ beta1_power_t = variables.VariableV1(beta1_power)
+ beta2_power_t = variables.VariableV1(beta2_power)
lr_t = constant_op.constant(lr, self._toType(var.dtype), [])
epsilon_t = constant_op.constant(epsilon, self._toType(var.dtype), [])
variables.global_variables_initializer().run()
diff --git a/tensorflow/python/training/training_util_test.py b/tensorflow/python/training/training_util_test.py
index 6cc177e0e8..ba64e785ac 100644
--- a/tensorflow/python/training/training_util_test.py
+++ b/tensorflow/python/training/training_util_test.py
@@ -49,7 +49,7 @@ class GlobalStepTest(test.TestCase):
def test_invalid_shape(self):
with ops.Graph().as_default() as g:
self.assertIsNone(training_util.get_global_step())
- variables.Variable(
+ variables.VariableV1(
[0],
trainable=False,
dtype=dtypes.int32,
@@ -73,7 +73,7 @@ class GlobalStepTest(test.TestCase):
def test_get_global_step(self):
with ops.Graph().as_default() as g:
self.assertIsNone(training_util.get_global_step())
- variables.Variable(
+ variables.VariableV1(
0,
trainable=False,
dtype=dtypes.int32,
diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py
index 4e9b07e20a..a56dfbff8e 100644
--- a/tensorflow/python/util/function_utils.py
+++ b/tensorflow/python/util/function_utils.py
@@ -59,6 +59,29 @@ def fn_args(fn):
return tuple(args)
+def has_kwargs(fn):
+ """Returns whether the passed callable has **kwargs in its signature.
+
+ Args:
+ fn: Function, or function-like object (e.g., result of `functools.partial`).
+
+ Returns:
+ `bool`: if `fn` has **kwargs in its signature.
+
+ Raises:
+ `TypeError`: If fn is not a Function, or function-like object.
+ """
+ if isinstance(fn, functools.partial):
+ fn = fn.func
+ elif _is_callable_object(fn):
+ fn = fn.__call__
+ elif not callable(fn):
+ raise TypeError(
+ 'fn should be a function-like object, but is of type {}.'.format(
+ type(fn)))
+ return tf_inspect.getfullargspec(fn).varkw is not None
+
+
def get_func_name(func):
"""Returns name of passed callable."""
_, func = tf_decorator.unwrap(func)
diff --git a/tensorflow/python/util/function_utils_test.py b/tensorflow/python/util/function_utils_test.py
index 1588328c26..e5b0843e4b 100644
--- a/tensorflow/python/util/function_utils_test.py
+++ b/tensorflow/python/util/function_utils_test.py
@@ -135,6 +135,101 @@ class FnArgsTest(test.TestCase):
self.assertEqual(3, double_wrapped_fn(a=3))
+class HasKwargsTest(test.TestCase):
+
+ def test_simple_function(self):
+
+ fn_has_kwargs = lambda **x: x
+ self.assertTrue(function_utils.has_kwargs(fn_has_kwargs))
+
+ fn_has_no_kwargs = lambda x: x
+ self.assertFalse(function_utils.has_kwargs(fn_has_no_kwargs))
+
+ def test_callable(self):
+
+ class FooHasKwargs(object):
+
+ def __call__(self, **x):
+ del x
+ self.assertTrue(function_utils.has_kwargs(FooHasKwargs()))
+
+ class FooHasNoKwargs(object):
+
+ def __call__(self, x):
+ del x
+ self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs()))
+
+ def test_bounded_method(self):
+
+ class FooHasKwargs(object):
+
+ def fn(self, **x):
+ del x
+ self.assertTrue(function_utils.has_kwargs(FooHasKwargs().fn))
+
+ class FooHasNoKwargs(object):
+
+ def fn(self, x):
+ del x
+ self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs().fn))
+
+ def test_partial_function(self):
+ expected_test_arg = 123
+
+ def fn_has_kwargs(test_arg, **x):
+ if test_arg != expected_test_arg:
+ return ValueError('partial fn does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_kwargs, test_arg=123)
+ self.assertTrue(function_utils.has_kwargs(wrapped_fn))
+ some_kwargs = dict(x=1, y=2, z=3)
+ self.assertEqual(wrapped_fn(**some_kwargs), some_kwargs)
+
+ def fn_has_no_kwargs(x, test_arg):
+ if test_arg != expected_test_arg:
+ return ValueError('partial fn does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_no_kwargs, test_arg=123)
+ self.assertFalse(function_utils.has_kwargs(wrapped_fn))
+ some_arg = 1
+ self.assertEqual(wrapped_fn(some_arg), some_arg)
+
+ def test_double_partial(self):
+ expected_test_arg1 = 123
+ expected_test_arg2 = 456
+
+ def fn_has_kwargs(test_arg1, test_arg2, **x):
+ if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
+ return ValueError('partial does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_kwargs, test_arg2=456)
+ double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
+
+ self.assertTrue(function_utils.has_kwargs(double_wrapped_fn))
+ some_kwargs = dict(x=1, y=2, z=3)
+ self.assertEqual(double_wrapped_fn(**some_kwargs), some_kwargs)
+
+ def fn_has_no_kwargs(x, test_arg1, test_arg2):
+ if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
+ return ValueError('partial does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_no_kwargs, test_arg2=456)
+ double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
+
+ self.assertFalse(function_utils.has_kwargs(double_wrapped_fn))
+ some_arg = 1
+ self.assertEqual(double_wrapped_fn(some_arg), some_arg)
+
+ def test_raises_type_error(self):
+ with self.assertRaisesRegexp(
+ TypeError, 'fn should be a function-like object'):
+ function_utils.has_kwargs('not a function')
+
+
class GetFuncNameTest(test.TestCase):
def testWithSimpleFunction(self):
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 653ca525dc..758cba7487 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -19,6 +19,9 @@ This module can perform operations on nested structures. A nested structure is a
Python sequence, tuple (including `namedtuple`), or dict that can contain
further sequences, tuples, and dicts.
+attr.s decorated classes (http://www.attrs.org) are also supported, in the
+same way as `namedtuple`.
+
The utilities here assume (and do not check) that the nested structures form a
'tree', i.e., no references in the structure of the input of these functions
should be recursive.
@@ -38,6 +41,12 @@ import six as _six
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
+def _get_attrs_values(obj):
+ """Returns the list of values from an attrs instance."""
+ attrs = getattr(obj.__class__, "__attrs_attrs__")
+ return [getattr(obj, a.name) for a in attrs]
+
+
def _sorted(dict_):
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
try:
@@ -64,6 +73,7 @@ def _is_namedtuple(instance, strict=False):
# See the swig file (util.i) for documentation.
_is_mapping = _pywrap_tensorflow.IsMapping
+_is_attrs = _pywrap_tensorflow.IsAttrs
def _sequence_like(instance, args):
@@ -85,7 +95,7 @@ def _sequence_like(instance, args):
# corresponding `OrderedDict` to pack it back).
result = dict(zip(_sorted(instance), args))
return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
- elif _is_namedtuple(instance):
+ elif _is_namedtuple(instance) or _is_attrs(instance):
return type(instance)(*args)
else:
# Not a namedtuple
@@ -93,6 +103,7 @@ def _sequence_like(instance, args):
def _yield_value(iterable):
+ """Yields the next value from the given iterable."""
if _is_mapping(iterable):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict`
@@ -101,6 +112,9 @@ def _yield_value(iterable):
# corresponding `OrderedDict` to pack it back).
for key in _sorted(iterable):
yield iterable[key]
+ elif _is_attrs(iterable):
+ for value in _get_attrs_values(iterable):
+ yield value
else:
for value in iterable:
yield value
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index bfb4c6f910..e03a8daaa1 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -33,6 +33,11 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest
+try:
+ import attr # pylint:disable=g-import-not-at-top
+except ImportError:
+ attr = None
+
class _CustomMapping(collections.Mapping):
@@ -53,6 +58,35 @@ class NestTest(parameterized.TestCase, test.TestCase):
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
+ if attr:
+ class BadAttr(object):
+ """Class that has a non-iterable __attrs_attrs__."""
+ __attrs_attrs__ = None
+
+ @attr.s
+ class SampleAttr(object):
+ field1 = attr.ib()
+ field2 = attr.ib()
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testAttrsFlattenAndPack(self):
+ if attr is None:
+ self.skipTest("attr module is unavailable.")
+
+ field_values = [1, 2]
+ sample_attr = NestTest.SampleAttr(*field_values)
+ self.assertFalse(nest._is_attrs(field_values))
+ self.assertTrue(nest._is_attrs(sample_attr))
+ flat = nest.flatten(sample_attr)
+ self.assertEqual(field_values, flat)
+ restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
+ self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
+ self.assertEqual(restructured_from_flat, sample_attr)
+
+ # Check that flatten fails if attributes are not iterable
+ with self.assertRaisesRegexp(TypeError, "object is not iterable"):
+ flat = nest.flatten(NestTest.BadAttr())
+
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack(self):
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 562bbdcfeb..38b8491c66 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -15,9 +15,11 @@ limitations under the License.
#include "tensorflow/python/util/util.h"
#include <functional>
+#include <memory>
#include <unordered_map>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -190,6 +192,19 @@ int IsMappingHelper(PyObject* o) {
return check_cache->CachedLookup(o);
}
+// Returns 1 if `o` is an instance of attrs-decorated class.
+// Returns 0 otherwise.
+int IsAttrsHelper(PyObject* o) {
+ Safe_PyObjectPtr cls(PyObject_GetAttrString(o, "__class__"));
+ if (cls) {
+ return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
+ } else {
+ // PyObject_GetAttrString returns null on error
+ PyErr_Clear();
+ return 0;
+ }
+}
+
// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
// Returns 0 otherwise.
// Returns -1 if an error occurred.
@@ -204,6 +219,7 @@ int IsSequenceHelper(PyObject* o) {
});
// We treat dicts and other mappings as special cases of sequences.
if (IsMappingHelper(o)) return true;
+ if (IsAttrsHelper(o)) return true;
if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
LOG(WARNING) << "Sets are not currently considered sequences, "
"but this may change in the future, "
@@ -222,93 +238,168 @@ int IsSequenceHelper(PyObject* o) {
return check_cache->CachedLookup(o);
}
-// Implements the same idea as tensorflow.util.nest._yield_value
-// During construction we check if the iterable is a dictionary.
-// If so, we construct a sequence from its sorted keys that will be used
-// for iteration.
-// If not, we construct a sequence directly from the iterable.
-// At each step, we get the next element from the sequence and use it
-// either as a key or return it directly.
-//
-// 'iterable' must not be modified while ValIterator is used.
-class ValIterator {
+// ValueIterator interface
+class ValueIterator {
public:
- explicit ValIterator(PyObject* iterable)
- : dict_(nullptr),
- mapping_(nullptr),
- last_mapping_element_(nullptr),
- seq_(nullptr),
- index_(0) {
- if (PyDict_Check(iterable)) {
- dict_ = iterable;
- // PyDict_Keys returns a list, which can be used with
- // PySequence_Fast_GET_ITEM.
- seq_ = PyDict_Keys(iterable);
- // Iterate through dictionaries in a deterministic order by sorting the
- // keys. Notice this means that we ignore the original order of
- // `OrderedDict` instances. This is intentional, to avoid potential
- // bugs caused by mixing ordered and plain dicts (e.g., flattening
- // a dict but using a corresponding `OrderedDict` to pack it back).
- PyList_Sort(seq_);
- } else if (IsMappingHelper(iterable)) {
- mapping_ = iterable;
- seq_ = MappingKeys(iterable);
- PyList_Sort(seq_);
+ virtual ~ValueIterator() {}
+ virtual Safe_PyObjectPtr next() = 0;
+
+ bool valid() const { return is_valid_; }
+
+ protected:
+ void invalidate() { is_valid_ = false; }
+
+ private:
+ bool is_valid_ = true;
+};
+
+using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
+
+// Iterate through dictionaries in a deterministic order by sorting the
+// keys. Notice this means that we ignore the original order of
+// `OrderedDict` instances. This is intentional, to avoid potential
+// bugs caused by mixing ordered and plain dicts (e.g., flattening
+// a dict but using a corresponding `OrderedDict` to pack it back).
+class DictValueIterator : public ValueIterator {
+ public:
+ explicit DictValueIterator(PyObject* dict)
+ : dict_(dict), keys_(PyDict_Keys(dict)) {
+ if (PyList_Sort(keys_.get()) == -1) {
+ invalidate();
} else {
- seq_ = PySequence_Fast(iterable, "");
+ iter_.reset(PyObject_GetIter(keys_.get()));
}
- size_ = PySequence_Fast_GET_SIZE(seq_);
}
- ~ValIterator() { Py_DECREF(seq_); }
-
- // Return a borrowed reference to the next element from iterable.
- // Return nullptr when iteration is over.
- PyObject* next() {
- if (TF_PREDICT_FALSE(seq_ == nullptr)) {
- return nullptr;
- }
- PyObject* element = nullptr;
- if (index_ < size_) {
- // Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed
- // references. For general mappings, ValIterator keeps a reference to the
- // last retrieved element (and decrefs it before producing the next
- // element) to abstract away the borrowed/new difference.
- element = PySequence_Fast_GET_ITEM(seq_, index_);
- ++index_;
- if (dict_ != nullptr) {
- element = PyDict_GetItem(dict_, element);
- if (element == nullptr) {
- PyErr_SetString(PyExc_RuntimeError,
- "Dictionary was modified during iteration over it");
- return nullptr;
- }
- } else if (mapping_ != nullptr) {
- element = PyObject_GetItem(mapping_, element);
- if (element == nullptr) {
- PyErr_SetString(PyExc_RuntimeError,
- "Mapping was modified during iteration over it");
- return nullptr;
- }
- last_mapping_element_.reset(element);
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
+ if (key) {
+ // PyDict_GetItem returns a borrowed reference.
+ PyObject* elem = PyDict_GetItem(dict_, key.get());
+ if (elem) {
+ Py_INCREF(elem);
+ result.reset(elem);
+ } else {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Dictionary was modified during iteration over it");
}
}
- return element;
+ return result;
}
private:
- // Special casing for things that pass PyDict_Check (faster, no Python calls)
PyObject* dict_;
+ Safe_PyObjectPtr keys_;
+ Safe_PyObjectPtr iter_;
+};
- // General mappings which have custom Python logic
+// Iterate over mapping objects by sorting the keys first
+class MappingValueIterator : public ValueIterator {
+ public:
+ explicit MappingValueIterator(PyObject* mapping)
+ : mapping_(mapping), keys_(MappingKeys(mapping)) {
+ if (!keys_ || PyList_Sort(keys_.get()) == -1) {
+ invalidate();
+ } else {
+ iter_.reset(PyObject_GetIter(keys_.get()));
+ }
+ }
+
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
+ if (key) {
+ // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
+ PyObject* elem = PyObject_GetItem(mapping_, key.get());
+ if (elem) {
+ result.reset(elem);
+ } else {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Mapping was modified during iteration over it");
+ }
+ }
+ return result;
+ }
+
+ private:
PyObject* mapping_;
- Safe_PyObjectPtr last_mapping_element_;
+ Safe_PyObjectPtr keys_;
+ Safe_PyObjectPtr iter_;
+};
+
+// Iterate over a sequence, by index.
+class SequenceValueIterator : public ValueIterator {
+ public:
+ explicit SequenceValueIterator(PyObject* iterable)
+ : seq_(PySequence_Fast(iterable, "")),
+ size_(PySequence_Fast_GET_SIZE(seq_.get())),
+ index_(0) {}
+
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ if (index_ < size_) {
+ // PySequence_Fast_GET_ITEM returns a borrowed reference.
+ PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
+ ++index_;
+ Py_INCREF(elem);
+ result.reset(elem);
+ }
- PyObject* seq_;
- Py_ssize_t size_;
+ return result;
+ }
+
+ private:
+ Safe_PyObjectPtr seq_;
+ const Py_ssize_t size_;
Py_ssize_t index_;
};
+// Just return itself as a single item.
+class SparseTensorValueIterator : public ValueIterator {
+ public:
+ explicit SparseTensorValueIterator(PyObject* tensor) : tensor_(tensor) {
+ Py_INCREF(tensor);
+ }
+
+ Safe_PyObjectPtr next() override { return std::move(tensor_); }
+
+ private:
+ Safe_PyObjectPtr tensor_;
+};
+
+class AttrsValueIterator : public ValueIterator {
+ public:
+ explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
+ Py_INCREF(nested);
+ cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
+ if (cls_) {
+ attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
+ if (attrs_) {
+ iter_.reset(PyObject_GetIter(attrs_.get()));
+ }
+ }
+ if (!iter_ || PyErr_Occurred()) invalidate();
+ }
+
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
+ if (item) {
+ Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
+ result.reset(PyObject_GetAttr(nested_.get(), name.get()));
+ }
+
+ return result;
+ }
+
+ private:
+ Safe_PyObjectPtr nested_;
+ Safe_PyObjectPtr cls_;
+ Safe_PyObjectPtr attrs_;
+ Safe_PyObjectPtr iter_;
+};
+
bool IsSparseTensorValueType(PyObject* o) {
if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
return false;
@@ -322,93 +413,37 @@ int IsSequenceForDataHelper(PyObject* o) {
!IsSparseTensorValueType(o);
}
-bool GetNextValuesForDict(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
- Safe_PyObjectPtr keys(PyDict_Keys(nested));
- if (PyList_Sort(keys.get()) == -1) return false;
- Py_ssize_t size = PyList_Size(keys.get());
- for (Py_ssize_t i = 0; i < size; ++i) {
- // We know that key and item will not be deleted because nested owns
- // a reference to them and callers of flatten must not modify nested
- // while the method is running.
- PyObject* key = PyList_GET_ITEM(keys.get(), i);
- PyObject* item = PyDict_GetItem(nested, key);
- Py_INCREF(item);
- next_values->emplace_back(item);
- }
- return true;
-}
-
-bool GetNextValuesForMapping(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
- Safe_PyObjectPtr keys(MappingKeys(nested));
- if (keys.get() == nullptr) {
- return false;
- }
- if (PyList_Sort(keys.get()) == -1) return false;
- Py_ssize_t size = PyList_Size(keys.get());
- for (Py_ssize_t i = 0; i < size; ++i) {
- PyObject* key = PyList_GET_ITEM(keys.get(), i);
- // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
- PyObject* item = PyObject_GetItem(nested, key);
- next_values->emplace_back(item);
- }
- return true;
-}
-
-bool GetNextValuesForIterable(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
- PyObject* item;
- PyObject* iterator = PyObject_GetIter(nested);
- if (iterator == nullptr || PyErr_Occurred()) {
- return false;
- }
- while ((item = PyIter_Next(iterator)) != nullptr) {
- next_values->emplace_back(item);
- }
- Py_DECREF(iterator);
- return true;
-}
-
-// GetNextValues returns the values that the FlattenHelper function will recurse
-// over next.
-bool GetNextValues(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
+ValueIteratorPtr GetValueIterator(PyObject* nested) {
if (PyDict_Check(nested)) {
- // if nested is dictionary, sort it by key and recurse on each value
- return GetNextValuesForDict(nested, next_values);
+ return absl::make_unique<DictValueIterator>(nested);
} else if (IsMappingHelper(nested)) {
- // same treatment as dictionaries, but for custom mapping types
- return GetNextValuesForMapping(nested, next_values);
+ return absl::make_unique<MappingValueIterator>(nested);
+ } else if (IsAttrsHelper(nested)) {
+ return absl::make_unique<AttrsValueIterator>(nested);
+ } else {
+ return absl::make_unique<SequenceValueIterator>(nested);
}
- // iterate and recurse
- return GetNextValuesForIterable(nested, next_values);
}
-// Similar to above, just specialized for the functions in the data pacakage.
-bool GetNextValuesForData(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
+// Similar to above, just specialized for the functions in the data package.
+ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
if (PyDict_Check(nested)) {
- // if nested is dictionary, sort it by key and recurse on each value
- return GetNextValuesForDict(nested, next_values);
+ return absl::make_unique<DictValueIterator>(nested);
} else if (IsMappingHelper(nested)) {
- // same treatment as dictionaries, but for custom mapping types
- return GetNextValuesForMapping(nested, next_values);
+ return absl::make_unique<MappingValueIterator>(nested);
+ } else if (IsAttrsHelper(nested)) {
+ return absl::make_unique<AttrsValueIterator>(nested);
} else if (IsSparseTensorValueType(nested)) {
- // if nested is a SparseTensorValue, just return itself as a single item
- Py_INCREF(nested);
- next_values->emplace_back(nested);
- return true;
+ return absl::make_unique<SparseTensorValueIterator>(nested);
+ } else {
+ return absl::make_unique<SequenceValueIterator>(nested);
}
- // iterate and recurse
- return GetNextValuesForIterable(nested, next_values);
}
bool FlattenHelper(
PyObject* nested, PyObject* list,
const std::function<int(PyObject*)>& is_sequence_helper,
- const std::function<bool(PyObject*, std::vector<Safe_PyObjectPtr>*)>&
- next_values_getter) {
+ const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
// if nested is not a sequence, append itself and exit
int is_seq = is_sequence_helper(nested);
if (is_seq == -1) return false;
@@ -416,16 +451,15 @@ bool FlattenHelper(
return PyList_Append(list, nested) != -1;
}
- std::vector<Safe_PyObjectPtr> next_values;
- // Get the next values to recurse over.
- if (!next_values_getter(nested, &next_values)) return false;
+ ValueIteratorPtr iter = value_iterator_getter(nested);
+ if (!iter->valid()) return false;
- for (const auto& item : next_values) {
+ for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
if (Py_EnterRecursiveCall(" in flatten")) {
return false;
}
- const bool success =
- FlattenHelper(item.get(), list, is_sequence_helper, next_values_getter);
+ const bool success = FlattenHelper(item.get(), list, is_sequence_helper,
+ value_iterator_getter);
Py_LeaveRecursiveCall();
if (!success) {
return false;
@@ -579,22 +613,25 @@ bool AssertSameStructureHelper(
}
}
- ValIterator iter1(o1);
- ValIterator iter2(o2);
+ ValueIteratorPtr iter1 = GetValueIterator(o1);
+ ValueIteratorPtr iter2 = GetValueIterator(o2);
+
+ if (!iter1->valid() || !iter2->valid()) return false;
while (true) {
- PyObject* v1 = iter1.next();
- PyObject* v2 = iter2.next();
- if (v1 != nullptr && v2 != nullptr) {
+ Safe_PyObjectPtr v1 = iter1->next();
+ Safe_PyObjectPtr v2 = iter2->next();
+ if (v1 && v2) {
if (Py_EnterRecursiveCall(" in assert_same_structure")) {
return false;
}
- bool no_internal_errors = AssertSameStructureHelper(
- v1, v2, check_types, error_msg, is_type_error, is_sequence_helper);
+ bool no_internal_errors =
+ AssertSameStructureHelper(v1.get(), v2.get(), check_types, error_msg,
+ is_type_error, is_sequence_helper);
Py_LeaveRecursiveCall();
if (!no_internal_errors) return false;
if (!error_msg->empty()) return true;
- } else if (v1 == nullptr && v2 == nullptr) {
+ } else if (!v1 && !v2) {
// Done with all recursive calls. Structure matched.
return true;
} else {
@@ -652,10 +689,11 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
+bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
PyObject* Flatten(PyObject* nested) {
PyObject* list = PyList_New(0);
- if (FlattenHelper(nested, list, IsSequenceHelper, GetNextValues)) {
+ if (FlattenHelper(nested, list, IsSequenceHelper, GetValueIterator)) {
return list;
} else {
Py_DECREF(list);
@@ -668,7 +706,7 @@ bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; }
PyObject* FlattenForData(PyObject* nested) {
PyObject* list = PyList_New(0);
if (FlattenHelper(nested, list, IsSequenceForDataHelper,
- GetNextValuesForData)) {
+ GetValueIteratorForData)) {
return list;
} else {
Py_DECREF(list);
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 343605285e..01f85ea1dc 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -56,6 +56,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict);
// True if the sequence subclasses mapping.
bool IsMapping(PyObject* o);
+// Returns a true if its input is an instance of an attr.s decorated class.
+//
+// Args:
+// o: the input to be checked.
+//
+// Returns:
+// True if the object is an instance of an attr.s decorated class.
+bool IsAttrs(PyObject* o);
+
// Implements the same interface as tensorflow.util.nest._same_namedtuples
// Returns Py_True iff the two namedtuples have the same name and fields.
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 104a615636..32a6e684fa 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -65,6 +65,18 @@ Returns:
%unignore tensorflow::swig::IsMapping;
%noexception tensorflow::swig::IsMapping;
+%feature("docstring") tensorflow::swig::IsAttrs
+"""Returns True iff `instance` is an instance of an `attr.s` decorated class.
+
+Args:
+ instance: An instance of a Python object.
+
+Returns:
+ True if `instance` is an instance of an `attr.s` decorated class.
+"""
+%unignore tensorflow::swig::IsAttrs;
+%noexception tensorflow::swig::IsAttrs;
+
%feature("docstring") tensorflow::swig::SameNamedtuples
"Returns True if the two namedtuples have the same name and fields."
%unignore tensorflow::swig::SameNamedtuples;
diff --git a/tensorflow/requirements.txt b/tensorflow/requirements.txt
deleted file mode 100644
index 6e111edefc..0000000000
--- a/tensorflow/requirements.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-keras_applications >= 1.0.5
-keras_preprocessing >= 1.0.3
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 689679c838..cad5de1b0c 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -19,9 +19,18 @@ load(
"@local_config_cuda//cuda:build_defs.bzl",
"cuda_default_copts",
"if_cuda",
+ "if_cuda_is_configured",
+)
+load(
+ "@local_config_rocm//rocm:build_defs.bzl",
+ "if_rocm",
+ "if_rocm_is_configured",
+ "rocm_copts",
+ "rocm_default_copts",
)
load(
"//third_party/mkl:build_defs.bzl",
+ "if_enable_mkl",
"if_mkl",
"if_mkl_lnx_x64",
"if_mkl_ml",
@@ -38,6 +47,8 @@ load(
def register_extension_info(**kwargs):
pass
+# if_cuda_is_configured def placeholder
+
# Given a source file, generate a test name.
# i.e. "common_runtime/direct_session_test.cc" becomes
# "common_runtime_direct_session_test"
@@ -237,6 +248,7 @@ def tf_copts(android_optimization_level_override = "-O2", is_external = False):
if_tensorrt(["-DGOOGLE_TENSORRT=1"]) +
if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) +
if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
+ if_enable_mkl(["-DENABLE_MKL"]) +
if_ngraph(["-DINTEL_NGRAPH=1"]) +
if_mkl_lnx_x64(["-fopenmp"]) +
if_android_arm(["-mfpu=neon"]) +
@@ -861,12 +873,16 @@ def tf_cuda_only_cc_test(
srcs = srcs + tf_binary_additional_srcs(),
size = size,
args = args,
- copts = _cuda_copts() + tf_copts(),
+ copts = _cuda_copts() + rocm_copts() + tf_copts(),
data = data + tf_binary_dynamic_kernel_dsos(kernels),
- deps = deps + tf_binary_dynamic_kernel_deps(kernels) + if_cuda([
- clean_dep("//tensorflow/core:cuda"),
- clean_dep("//tensorflow/core:gpu_lib"),
- ]),
+ deps = deps + tf_binary_dynamic_kernel_deps(kernels) +
+ if_cuda_is_configured([
+ clean_dep("//tensorflow/core:cuda"),
+ clean_dep("//tensorflow/core:gpu_lib"),
+ ]) +
+ if_rocm_is_configured([
+ clean_dep("//tensorflow/core:gpu_lib"),
+ ]),
linkopts = if_not_windows(["-lpthread", "-lm"]) + linkopts + _rpath_linkopts(name),
linkstatic = linkstatic or select({
# cc_tests with ".so"s in srcs incorrectly link on Darwin
@@ -1001,7 +1017,7 @@ register_extension_info(
label_regex_for_dep = "{extension_name}",
)
-def _cuda_copts():
+def _cuda_copts(opts = []):
"""Gets the appropriate set of copts for (maybe) CUDA compilation.
If we're doing CUDA compilation, returns copts for our particular CUDA
@@ -1017,13 +1033,17 @@ def _cuda_copts():
"@local_config_cuda//cuda:using_clang": ([
"-fcuda-flush-denormals-to-zero",
]),
- })
+ }) + if_cuda_is_configured(opts)
# Build defs for TensorFlow kernels
# When this target is built using --config=cuda, a cc_library is built
# that passes -DGOOGLE_CUDA=1 and '-x cuda', linking in additional
# libraries needed by GPU kernels.
+#
+# When this target is built using --config=rocm, a cc_library is built
+# that passes -DTENSORFLOW_USE_ROCM and '-x rocm', linking in additional
+# libraries needed by GPU kernels.
def tf_gpu_kernel_library(
srcs,
copts = [],
@@ -1031,16 +1051,18 @@ def tf_gpu_kernel_library(
deps = [],
hdrs = [],
**kwargs):
- copts = copts + _cuda_copts() + if_cuda(cuda_copts) + tf_copts()
+ copts = copts + tf_copts() + _cuda_copts(opts = cuda_copts) + rocm_copts(opts = cuda_copts)
kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"]
native.cc_library(
srcs = srcs,
hdrs = hdrs,
copts = copts,
- deps = deps + if_cuda([
+ deps = deps + if_cuda_is_configured([
clean_dep("//tensorflow/core:cuda"),
clean_dep("//tensorflow/core:gpu_lib"),
+ ]) + if_rocm_is_configured([
+ clean_dep("//tensorflow/core:gpu_lib"),
]),
alwayslink = 1,
**kwargs
@@ -1079,9 +1101,12 @@ def tf_cuda_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs)
deps = deps + if_cuda(cuda_deps + [
clean_dep("//tensorflow/core:cuda"),
"@local_config_cuda//cuda:cuda_headers",
+ ]) + if_rocm_is_configured(cuda_deps + [
+ # rocm_header placeholder
]),
- copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]) +
+ copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_mkl(["-DINTEL_MKL=1"]) +
if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
+ if_enable_mkl(["-DENABLE_MKL"]) +
if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
**kwargs
)
@@ -1462,6 +1487,9 @@ def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = [
"@local_config_cuda//cuda:cuda_headers",
"@local_config_cuda//cuda:cudart_static",
]
+ rocm_deps = [
+ clean_dep("//tensorflow/core:stream_executor_headers_lib"),
+ ]
deps = deps + tf_custom_op_library_additional_deps()
if gpu_srcs:
basename = name.split(".")[0]
@@ -1470,13 +1498,14 @@ def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = [
srcs = gpu_srcs,
copts = _cuda_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]),
features = if_cuda(["-use_header_modules"]),
- deps = deps + if_cuda(cuda_deps),
+ deps = deps + if_cuda_is_configured(cuda_deps) + if_rocm_is_configured(rocm_deps),
)
cuda_deps.extend([":" + basename + "_gpu"])
+ rocm_deps.extend([":" + basename + "_gpu"])
check_deps(
name = name + "_check_deps",
- deps = deps + if_cuda(cuda_deps),
+ deps = deps + if_cuda_is_configured(cuda_deps) + if_rocm_is_configured(rocm_deps),
disallowed_deps = [
clean_dep("//tensorflow/core:framework"),
clean_dep("//tensorflow/core:lib"),
@@ -1485,7 +1514,7 @@ def tf_custom_op_library(name, srcs = [], gpu_srcs = [], deps = [], linkopts = [
tf_cc_shared_object(
name = name,
srcs = srcs,
- deps = deps + if_cuda(cuda_deps),
+ deps = deps + if_cuda_is_configured(cuda_deps) + if_rocm_is_configured(rocm_deps),
data = if_static([name + "_check_deps"]),
copts = tf_copts(is_external = True),
features = ["windows_export_all_symbols"],
@@ -1677,7 +1706,7 @@ def py_test(deps = [], data = [], kernels = [], **kwargs):
deps = select({
"//conditions:default": deps,
clean_dep("//tensorflow:no_tensorflow_py_deps"): [],
- }) + tf_binary_dynamic_kernel_deps(kernels),
+ }),
data = data + select({
"//conditions:default": [],
clean_dep("//tensorflow:no_tensorflow_py_deps"): ["//tensorflow/tools/pip_package:win_pip_package_marker"],
@@ -1690,6 +1719,29 @@ register_extension_info(
label_regex_for_dep = "{extension_name}",
)
+# Similar to py_test above, this macro is used to exclude dependencies for some py_binary
+# targets in order to reduce the size of //tensorflow/tools/pip_package:simple_console_windows.
+# See https://github.com/tensorflow/tensorflow/issues/22390
+def py_binary(name, deps = [], **kwargs):
+ # Add an extra target for dependencies to avoid nested select statement.
+ native.py_library(
+ name = name + "_deps",
+ deps = deps,
+ )
+ native.py_binary(
+ name = name,
+ deps = select({
+ "//conditions:default": [":" + name + "_deps"],
+ clean_dep("//tensorflow:no_tensorflow_py_deps"): [],
+ }),
+ **kwargs
+ )
+
+register_extension_info(
+ extension_name = "py_binary",
+ label_regex_for_dep = "{extension_name}",
+)
+
def tf_py_test(
name,
srcs,
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt
deleted file mode 100644
index eb41deee13..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt
+++ /dev/null
@@ -1,24 +0,0 @@
-path: "tensorflow.ConfigProto.Experimental"
-tf_proto {
- descriptor {
- name: "Experimental"
- field {
- name: "collective_group_leader"
- number: 1
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "executor_type"
- number: 3
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
deleted file mode 100644
index e565b903d2..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
+++ /dev/null
@@ -1,148 +0,0 @@
-path: "tensorflow.ConfigProto"
-tf_proto {
- descriptor {
- name: "ConfigProto"
- field {
- name: "device_count"
- number: 1
- label: LABEL_REPEATED
- type: TYPE_MESSAGE
- type_name: ".tensorflow.ConfigProto.DeviceCountEntry"
- }
- field {
- name: "intra_op_parallelism_threads"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_INT32
- }
- field {
- name: "inter_op_parallelism_threads"
- number: 5
- label: LABEL_OPTIONAL
- type: TYPE_INT32
- }
- field {
- name: "use_per_session_threads"
- number: 9
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "session_inter_op_thread_pool"
- number: 12
- label: LABEL_REPEATED
- type: TYPE_MESSAGE
- type_name: ".tensorflow.ThreadPoolOptionProto"
- }
- field {
- name: "placement_period"
- number: 3
- label: LABEL_OPTIONAL
- type: TYPE_INT32
- }
- field {
- name: "device_filters"
- number: 4
- label: LABEL_REPEATED
- type: TYPE_STRING
- }
- field {
- name: "gpu_options"
- number: 6
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.GPUOptions"
- }
- field {
- name: "allow_soft_placement"
- number: 7
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "log_device_placement"
- number: 8
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "graph_options"
- number: 10
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.GraphOptions"
- }
- field {
- name: "operation_timeout_in_ms"
- number: 11
- label: LABEL_OPTIONAL
- type: TYPE_INT64
- }
- field {
- name: "rpc_options"
- number: 13
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.RPCOptions"
- }
- field {
- name: "cluster_def"
- number: 14
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.ClusterDef"
- }
- field {
- name: "isolate_session_state"
- number: 15
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "experimental"
- number: 16
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.ConfigProto.Experimental"
- }
- nested_type {
- name: "DeviceCountEntry"
- field {
- name: "key"
- number: 1
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- field {
- name: "value"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_INT32
- }
- options {
- map_entry: true
- }
- }
- nested_type {
- name: "Experimental"
- field {
- name: "collective_group_leader"
- number: 1
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "executor_type"
- number: 3
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- }
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt
deleted file mode 100644
index 4f0147a523..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.data.Iterator"
-tf_class {
- is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "initializer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_classes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shapes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_types"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'iterator_resource\', \'initializer\', \'output_types\', \'output_shapes\', \'output_classes\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "from_string_handle"
- argspec: "args=[\'string_handle\', \'output_types\', \'output_shapes\', \'output_classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "from_structure"
- argspec: "args=[\'output_types\', \'output_shapes\', \'shared_name\', \'output_classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "get_next"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "make_initializer"
- argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "string_handle"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
deleted file mode 100644
index c23b04b4ef..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ /dev/null
@@ -1,58 +0,0 @@
-path: "tensorflow.estimator.BoostedTreesClassifier"
-tf_class {
- is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesClassifier\'>"
- is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "config"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_dir"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_fn"
- mtype: "<type \'property\'>"
- }
- member {
- name: "params"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
- }
- member_method {
- name: "eval_dir"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
- }
- member_method {
- name: "get_variable_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_variable_value"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "latest_checkpoint"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
- }
- member_method {
- name: "train"
- argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
deleted file mode 100644
index 6878d28fff..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ /dev/null
@@ -1,58 +0,0 @@
-path: "tensorflow.estimator.BoostedTreesRegressor"
-tf_class {
- is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesRegressor\'>"
- is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "config"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_dir"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_fn"
- mtype: "<type \'property\'>"
- }
- member {
- name: "params"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
- }
- member_method {
- name: "eval_dir"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
- }
- member_method {
- name: "get_variable_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_variable_value"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "latest_checkpoint"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
- }
- member_method {
- name: "train"
- argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
deleted file mode 100644
index bf1f94b6ae..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
+++ /dev/null
@@ -1,105 +0,0 @@
-path: "tensorflow.estimator.RunConfig"
-tf_class {
- is_instance: "<class \'tensorflow.python.estimator.run_config.RunConfig\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "cluster_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "device_fn"
- mtype: "<type \'property\'>"
- }
- member {
- name: "eval_distribute"
- mtype: "<type \'property\'>"
- }
- member {
- name: "evaluation_master"
- mtype: "<type \'property\'>"
- }
- member {
- name: "global_id_in_cluster"
- mtype: "<type \'property\'>"
- }
- member {
- name: "is_chief"
- mtype: "<type \'property\'>"
- }
- member {
- name: "keep_checkpoint_every_n_hours"
- mtype: "<type \'property\'>"
- }
- member {
- name: "keep_checkpoint_max"
- mtype: "<type \'property\'>"
- }
- member {
- name: "log_step_count_steps"
- mtype: "<type \'property\'>"
- }
- member {
- name: "master"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_dir"
- mtype: "<type \'property\'>"
- }
- member {
- name: "num_ps_replicas"
- mtype: "<type \'property\'>"
- }
- member {
- name: "num_worker_replicas"
- mtype: "<type \'property\'>"
- }
- member {
- name: "protocol"
- mtype: "<type \'property\'>"
- }
- member {
- name: "save_checkpoints_secs"
- mtype: "<type \'property\'>"
- }
- member {
- name: "save_checkpoints_steps"
- mtype: "<type \'property\'>"
- }
- member {
- name: "save_summary_steps"
- mtype: "<type \'property\'>"
- }
- member {
- name: "service"
- mtype: "<type \'property\'>"
- }
- member {
- name: "session_config"
- mtype: "<type \'property\'>"
- }
- member {
- name: "task_id"
- mtype: "<type \'property\'>"
- }
- member {
- name: "task_type"
- mtype: "<type \'property\'>"
- }
- member {
- name: "tf_random_seed"
- mtype: "<type \'property\'>"
- }
- member {
- name: "train_distribute"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "replace"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
deleted file mode 100644
index 5c46dc5ee7..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ /dev/null
@@ -1,251 +0,0 @@
-path: "tensorflow.image"
-tf_module {
- member {
- name: "ResizeMethod"
- mtype: "<type \'type\'>"
- }
- member_method {
- name: "adjust_brightness"
- argspec: "args=[\'image\', \'delta\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "adjust_contrast"
- argspec: "args=[\'images\', \'contrast_factor\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "adjust_gamma"
- argspec: "args=[\'image\', \'gamma\', \'gain\'], varargs=None, keywords=None, defaults=[\'1\', \'1\'], "
- }
- member_method {
- name: "adjust_hue"
- argspec: "args=[\'image\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "adjust_jpeg_quality"
- argspec: "args=[\'image\', \'jpeg_quality\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "adjust_saturation"
- argspec: "args=[\'image\', \'saturation_factor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "central_crop"
- argspec: "args=[\'image\', \'central_fraction\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "convert_image_dtype"
- argspec: "args=[\'image\', \'dtype\', \'saturate\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "crop_and_resize"
- argspec: "args=[\'image\', \'boxes\', \'box_ind\', \'crop_size\', \'method\', \'extrapolation_value\', \'name\'], varargs=None, keywords=None, defaults=[\'bilinear\', \'0\', \'None\'], "
- }
- member_method {
- name: "crop_to_bounding_box"
- argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "decode_and_crop_jpeg"
- argspec: "args=[\'contents\', \'crop_window\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], "
- }
- member_method {
- name: "decode_bmp"
- argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
- }
- member_method {
- name: "decode_gif"
- argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "decode_image"
- argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'uint8\'>\", \'None\'], "
- }
- member_method {
- name: "decode_jpeg"
- argspec: "args=[\'contents\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], "
- }
- member_method {
- name: "decode_png"
- argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'uint8\'>\", \'None\'], "
- }
- member_method {
- name: "draw_bounding_boxes"
- argspec: "args=[\'images\', \'boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "encode_jpeg"
- argspec: "args=[\'image\', \'format\', \'quality\', \'progressive\', \'optimize_size\', \'chroma_downsampling\', \'density_unit\', \'x_density\', \'y_density\', \'xmp_metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'95\', \'False\', \'False\', \'True\', \'in\', \'300\', \'300\', \'\', \'None\'], "
- }
- member_method {
- name: "encode_png"
- argspec: "args=[\'image\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
- }
- member_method {
- name: "extract_glimpse"
- argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'None\'], "
- }
- member_method {
- name: "extract_image_patches"
- argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "extract_jpeg_shape"
- argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
- }
- member_method {
- name: "flip_left_right"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "flip_up_down"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "grayscale_to_rgb"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "hsv_to_rgb"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "image_gradients"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "is_jpeg"
- argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "non_max_suppression"
- argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
- }
- member_method {
- name: "non_max_suppression_overlaps"
- argspec: "args=[\'overlaps\', \'scores\', \'max_output_size\', \'overlap_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
- }
- member_method {
- name: "non_max_suppression_padded"
- argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
- }
- member_method {
- name: "pad_to_bounding_box"
- argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "per_image_standardization"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "psnr"
- argspec: "args=[\'a\', \'b\', \'max_val\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_brightness"
- argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_contrast"
- argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_flip_left_right"
- argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_flip_up_down"
- argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_hue"
- argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_jpeg_quality"
- argspec: "args=[\'image\', \'min_jpeg_quality\', \'max_jpeg_quality\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_saturation"
- argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "resize_area"
- argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "resize_bicubic"
- argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "resize_bilinear"
- argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "resize_image_with_crop_or_pad"
- argspec: "args=[\'image\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "resize_image_with_pad"
- argspec: "args=[\'image\', \'target_height\', \'target_width\', \'method\'], varargs=None, keywords=None, defaults=[\'0\'], "
- }
- member_method {
- name: "resize_images"
- argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\', \'preserve_aspect_ratio\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\'], "
- }
- member_method {
- name: "resize_nearest_neighbor"
- argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "rgb_to_grayscale"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "rgb_to_hsv"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "rgb_to_yiq"
- argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "rgb_to_yuv"
- argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "rot90"
- argspec: "args=[\'image\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
- }
- member_method {
- name: "sample_distorted_bounding_box"
- argspec: "args=[\'image_size\', \'bounding_boxes\', \'seed\', \'seed2\', \'min_object_covered\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.1\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "sobel_edges"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "ssim"
- argspec: "args=[\'img1\', \'img2\', \'max_val\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "ssim_multiscale"
- argspec: "args=[\'img1\', \'img2\', \'max_val\', \'power_factors\'], varargs=None, keywords=None, defaults=[\'(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)\'], "
- }
- member_method {
- name: "total_variation"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "transpose_image"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "yiq_to_rgb"
- argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "yuv_to_rgb"
- argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
deleted file mode 100644
index e579fe6a1a..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
+++ /dev/null
@@ -1,268 +0,0 @@
-path: "tensorflow.keras.Model"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "layers"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "stateful"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "uses_learning_phase"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
- }
- member_method {
- name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
- }
- member_method {
- name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_layer"
- argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load_weights"
- argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
- }
- member_method {
- name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "predict_on_batch"
- argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "reset_states"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "save"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
- }
- member_method {
- name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "summary"
- argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "test_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "to_json"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "to_yaml"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "train_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
deleted file mode 100644
index 6f05cdd093..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
+++ /dev/null
@@ -1,289 +0,0 @@
-path: "tensorflow.keras.Sequential"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "layers"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "stateful"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "uses_learning_phase"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'layers\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "add"
- argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
- }
- member_method {
- name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
- }
- member_method {
- name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_layer"
- argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load_weights"
- argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "pop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
- }
- member_method {
- name: "predict_classes"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], "
- }
- member_method {
- name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "predict_on_batch"
- argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict_proba"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], "
- }
- member_method {
- name: "reset_states"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "save"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
- }
- member_method {
- name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "summary"
- argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "symbolic_set_inputs"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "test_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "to_json"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "to_yaml"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "train_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt
deleted file mode 100644
index 2e9de9ebb2..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt
+++ /dev/null
@@ -1,55 +0,0 @@
-path: "tensorflow.keras.activations"
-tf_module {
- member_method {
- name: "deserialize"
- argspec: "args=[\'name\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "elu"
- argspec: "args=[\'x\', \'alpha\'], varargs=None, keywords=None, defaults=[\'1.0\'], "
- }
- member_method {
- name: "get"
- argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "hard_sigmoid"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "linear"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "relu"
- argspec: "args=[\'x\', \'alpha\', \'max_value\', \'threshold\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\', \'0\'], "
- }
- member_method {
- name: "selu"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "serialize"
- argspec: "args=[\'activation\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "sigmoid"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "softmax"
- argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], "
- }
- member_method {
- name: "softplus"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "softsign"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "tanh"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
deleted file mode 100644
index 56914e1746..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
+++ /dev/null
@@ -1,268 +0,0 @@
-path: "tensorflow.keras.models.Model"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "layers"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "stateful"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "uses_learning_phase"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
- }
- member_method {
- name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
- }
- member_method {
- name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_layer"
- argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load_weights"
- argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
- }
- member_method {
- name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "predict_on_batch"
- argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "reset_states"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "save"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
- }
- member_method {
- name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "summary"
- argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "test_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "to_json"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "to_yaml"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "train_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
deleted file mode 100644
index 4c1c54001d..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
+++ /dev/null
@@ -1,289 +0,0 @@
-path: "tensorflow.keras.models.Sequential"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "layers"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "stateful"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "uses_learning_phase"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'layers\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "add"
- argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
- }
- member_method {
- name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
- }
- member_method {
- name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_layer"
- argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load_weights"
- argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "pop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
- }
- member_method {
- name: "predict_classes"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], "
- }
- member_method {
- name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "predict_on_batch"
- argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict_proba"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], "
- }
- member_method {
- name: "reset_states"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "save"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
- }
- member_method {
- name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "summary"
- argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "symbolic_set_inputs"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "test_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "to_json"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "to_yaml"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "train_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt
index 537e73aa89..47b5b56faf 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt
@@ -8,5 +8,11 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_INT64
}
+ field {
+ name: "use_run_handler_pool"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
}
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt
index cec04a2bf0..c0c2e7b9f8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt
@@ -55,6 +55,12 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_INT64
}
+ field {
+ name: "use_run_handler_pool"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
}
enum_type {
name: "TraceLevel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt
index 05698b03ee..af7fc9d4ef 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-variable.pbtxt
@@ -1,5 +1,6 @@
path: "tensorflow.Variable"
tf_class {
+ is_instance: "<class \'tensorflow.python.ops.variables.VariableV1\'>"
is_instance: "<class \'tensorflow.python.ops.variables.Variable\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
index c3ba2dba57..825afb622f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
@@ -91,6 +91,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 3541671bee..cdad5f6360 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -92,6 +92,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
index b113c18ee0..df41bff1b5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -92,6 +92,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
index 7210bf5db4..028bcc2ce9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
@@ -92,6 +92,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 9e429a32a5..ef3409b1b5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -33,6 +33,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_feature_importances"
+ argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
name: "experimental_predict_with_explanations"
argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index 56af1d137c..775130468f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -33,6 +33,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_feature_importances"
+ argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
name: "experimental_predict_with_explanations"
argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
index a308c76ebc..72856466ec 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
@@ -233,6 +233,14 @@ tf_module {
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "xdivy"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "xlogy"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "zeta"
argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 503e145a91..509ceff9df 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -2221,6 +2221,10 @@ tf_module {
argspec: "args=[\'max_shard_bytes\', \'axis\', \'bytes_per_string_element\', \'max_shards\'], varargs=None, keywords=None, defaults=[\'0\', \'16\', \'None\'], "
}
member_method {
+ name: "variable_creator_scope"
+ argspec: "args=[\'variable_creator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "variable_op_scope"
argspec: "args=[\'values\', \'name_or_scope\', \'default_name\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
index c81c156518..312e94b41d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
@@ -10,7 +10,7 @@ tf_module {
}
member_method {
name: "length"
- argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "regex_full_match"
@@ -48,4 +48,8 @@ tf_module {
name: "to_number"
argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
+ member_method {
+ name: "unicode_script"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt
index 537e73aa89..47b5b56faf 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt
@@ -8,5 +8,11 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_INT64
}
+ field {
+ name: "use_run_handler_pool"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt
index cec04a2bf0..c0c2e7b9f8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt
@@ -55,6 +55,12 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_INT64
}
+ field {
+ name: "use_run_handler_pool"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
}
enum_type {
name: "TraceLevel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable-scope.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable-scope.pbtxt
deleted file mode 100644
index c13eb7b8bb..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-variable-scope.pbtxt
+++ /dev/null
@@ -1,105 +0,0 @@
-path: "tensorflow.VariableScope"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.variable_scope.VariableScope\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "caching_device"
- mtype: "<type \'property\'>"
- }
- member {
- name: "constraint"
- mtype: "<type \'property\'>"
- }
- member {
- name: "custom_getter"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "initializer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "original_name_scope"
- mtype: "<type \'property\'>"
- }
- member {
- name: "partitioner"
- mtype: "<type \'property\'>"
- }
- member {
- name: "regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "reuse"
- mtype: "<type \'property\'>"
- }
- member {
- name: "use_resource"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'reuse\', \'name\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'name_scope\', \'dtype\', \'use_resource\', \'constraint\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'None\', \'None\', \'None\', \'None\', \'\', \"<dtype: \'float32\'>\", \'None\', \'None\'], "
- }
- member_method {
- name: "get_collection"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_variable"
- argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
- }
- member_method {
- name: "global_variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "local_variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "reuse_variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_caching_device"
- argspec: "args=[\'self\', \'caching_device\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_custom_getter"
- argspec: "args=[\'self\', \'custom_getter\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_dtype"
- argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_initializer"
- argspec: "args=[\'self\', \'initializer\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_partitioner"
- argspec: "args=[\'self\', \'partitioner\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_regularizer"
- argspec: "args=[\'self\', \'regularizer\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_use_resource"
- argspec: "args=[\'self\', \'use_resource\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "trainable_variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt
deleted file mode 100644
index ac3ccd468b..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-variable.-save-slice-info.pbtxt
+++ /dev/null
@@ -1,17 +0,0 @@
-path: "tensorflow.Variable.SaveSliceInfo"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.variables.SaveSliceInfo\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "spec"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'full_name\', \'full_shape\', \'var_offset\', \'var_shape\', \'save_slice_info_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "to_proto"
- argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt
deleted file mode 100644
index 05698b03ee..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.-variable.pbtxt
+++ /dev/null
@@ -1,130 +0,0 @@
-path: "tensorflow.Variable"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.variables.Variable\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "SaveSliceInfo"
- mtype: "<type \'type\'>"
- }
- member {
- name: "constraint"
- mtype: "<type \'property\'>"
- }
- member {
- name: "device"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
- name: "initial_value"
- mtype: "<type \'property\'>"
- }
- member {
- name: "initializer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
- }
- member_method {
- name: "assign"
- argspec: "args=[\'self\', \'value\', \'use_locking\', \'name\', \'read_value\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'True\'], "
- }
- member_method {
- name: "assign_add"
- argspec: "args=[\'self\', \'delta\', \'use_locking\', \'name\', \'read_value\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'True\'], "
- }
- member_method {
- name: "assign_sub"
- argspec: "args=[\'self\', \'delta\', \'use_locking\', \'name\', \'read_value\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'True\'], "
- }
- member_method {
- name: "count_up_to"
- argspec: "args=[\'self\', \'limit\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "eval"
- argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "from_proto"
- argspec: "args=[\'variable_def\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_shape"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "initialized_value"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load"
- argspec: "args=[\'self\', \'value\', \'session\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "read_value"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "scatter_add"
- argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "scatter_nd_add"
- argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "scatter_nd_sub"
- argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "scatter_nd_update"
- argspec: "args=[\'self\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "scatter_sub"
- argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "scatter_update"
- argspec: "args=[\'self\', \'sparse_delta\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "set_shape"
- argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "to_proto"
- argspec: "args=[\'self\', \'export_scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "value"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index c3ba2dba57..825afb622f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -91,6 +91,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 3541671bee..cdad5f6360 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -92,6 +92,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index b113c18ee0..df41bff1b5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -92,6 +92,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index 7210bf5db4..028bcc2ce9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -92,6 +92,10 @@ tf_class {
argspec: "args=[], varargs=args, keywords=None, defaults=None"
}
member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "repeat"
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 9e429a32a5..ef3409b1b5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -33,6 +33,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_feature_importances"
+ argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
name: "experimental_predict_with_explanations"
argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index 56af1d137c..775130468f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -33,6 +33,10 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_feature_importances"
+ argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
name: "experimental_predict_with_explanations"
argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
index d499c67d89..e3c63fe737 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
@@ -49,10 +49,6 @@ tf_module {
mtype: "<type \'type\'>"
}
member_method {
- name: "global_variables"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "he_normal"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -68,12 +64,4 @@ tf_module {
name: "lecun_uniform"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
- member_method {
- name: "local_variables"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "variables"
- argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'init\'], "
- }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
index a308c76ebc..72856466ec 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
@@ -233,6 +233,14 @@ tf_module {
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "xdivy"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "xlogy"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "zeta"
argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 96212f5528..d2dc8bc85f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -1,10 +1,6 @@
path: "tensorflow"
tf_module {
member {
- name: "AUTO_REUSE"
- mtype: "<enum \'_ReuseMode\'>"
- }
- member {
name: "AggregationMethod"
mtype: "<type \'type\'>"
}
@@ -233,18 +229,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
- name: "Variable"
- mtype: "<class \'tensorflow.python.ops.variables.VariableMetaclass\'>"
- }
- member {
name: "VariableAggregation"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
- name: "VariableScope"
- mtype: "<type \'type\'>"
- }
- member {
name: "VariableSynchronization"
mtype: "<class \'enum.EnumMeta\'>"
}
@@ -553,10 +541,6 @@ tf_module {
mtype: "<type \'module\'>"
}
member {
- name: "variable_scope"
- mtype: "<type \'type\'>"
- }
- member {
name: "variance_scaling_initializer"
mtype: "<type \'type\'>"
}
@@ -617,10 +601,6 @@ tf_module {
argspec: "args=[\'names\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "all_variables"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "angle"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -733,10 +713,6 @@ tf_module {
argspec: "args=[\'tensor\', \'tf_type\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
- name: "assert_variables_initialized"
- argspec: "args=[\'var_list\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "atan"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1137,10 +1113,6 @@ tf_module {
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_local_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
- }
- member_method {
name: "get_seed"
argspec: "args=[\'op_seed\'], varargs=None, keywords=None, defaults=None"
}
@@ -1153,26 +1125,10 @@ tf_module {
argspec: "args=[\'handle\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "get_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
- }
- member_method {
- name: "get_variable_scope"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "global_norm"
argspec: "args=[\'t_list\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "global_variables"
- argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "global_variables_initializer"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "gradients"
argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
}
@@ -1249,18 +1205,6 @@ tf_module {
argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], "
}
member_method {
- name: "initialize_all_variables"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "initialize_local_variables"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "initialize_variables"
- argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'init\'], "
- }
- member_method {
name: "invert_permutation"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1289,10 +1233,6 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "is_variable_initialized"
- argspec: "args=[\'variable\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "lbeta"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1329,14 +1269,6 @@ tf_module {
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "local_variables"
- argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "local_variables_initializer"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "log"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1449,14 +1381,6 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "model_variables"
- argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "moving_average_variables"
- argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "multinomial"
argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'name\', \'output_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
@@ -1657,10 +1581,6 @@ tf_module {
argspec: "args=[\'base_type\', \'conversion_func\', \'priority\'], varargs=None, keywords=None, defaults=[\'100\'], "
}
member_method {
- name: "report_uninitialized_variables"
- argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'report_uninitialized_variables\'], "
- }
- member_method {
name: "required_space_to_batch_paddings"
argspec: "args=[\'input_shape\', \'block_shape\', \'base_paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
@@ -2069,10 +1989,6 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "trainable_variables"
- argspec: "args=[\'scope\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "transpose"
argspec: "args=[\'a\', \'perm\', \'name\', \'conjugate\'], varargs=None, keywords=None, defaults=[\'None\', \'transpose\', \'False\'], "
}
@@ -2141,14 +2057,6 @@ tf_module {
argspec: "args=[\'max_shard_bytes\', \'axis\', \'bytes_per_string_element\', \'max_shards\'], varargs=None, keywords=None, defaults=[\'0\', \'16\', \'None\'], "
}
member_method {
- name: "variable_op_scope"
- argspec: "args=[\'values\', \'name_or_scope\', \'default_name\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "variables_initializer"
- argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'init\'], "
- }
- member_method {
name: "verify_tensor_all_finite"
argspec: "args=[\'t\', \'msg\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index c81c156518..312e94b41d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -10,7 +10,7 @@ tf_module {
}
member_method {
name: "length"
- argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "regex_full_match"
@@ -48,4 +48,8 @@ tf_module {
name: "to_number"
argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
}
+ member_method {
+ name: "unicode_script"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.variable_scope.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.variable_scope.pbtxt
deleted file mode 100644
index e62dec93e6..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.variable_scope.pbtxt
+++ /dev/null
@@ -1,9 +0,0 @@
-path: "tensorflow.variable_scope"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.variable_scope.variable_scope\'>"
- is_instance: "<type \'object\'>"
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'name_or_scope\', \'default_name\', \'values\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\', \'constraint\', \'auxiliary_name_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\'], "
- }
-}
diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD
index 4efa4a9651..3cbea41dca 100644
--- a/tensorflow/tools/api/tests/BUILD
+++ b/tensorflow/tools/api/tests/BUILD
@@ -19,6 +19,7 @@ py_test(
"api_compatibility_test.py",
"//tensorflow:tf_python_api_gen_v2",
],
+ args = ["--only_test_core_api=true"],
data = [
"//tensorflow/tools/api/golden:api_golden_v1",
"//tensorflow/tools/api/golden:api_golden_v2",
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index d06c7f2d49..6487a6267e 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -56,6 +56,14 @@ _UPDATE_GOLDENS_HELP = """
have to be authorized by TensorFlow leads.
"""
+# DEFINE_boolean, only_test_core_api, default False:
+_ONLY_TEST_CORE_API_HELP = """
+ Some TF APIs are being moved outside of the tensorflow/ directory. There is
+ no garuntee which versions of these APIs will be present when running this
+ test. Therefore, do not error out on API changes in non-core TF code
+ if this flag is set.
+"""
+
# DEFINE_boolean, verbose_diffs, default True:
_VERBOSE_DIFFS_HELP = """
If set to true, print line by line diffs on all libraries. If set to
@@ -67,6 +75,8 @@ _API_GOLDEN_FOLDER_V2 = 'tensorflow/tools/api/golden/v2'
_TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt'
_UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt'
+_NON_CORE_PACKAGES = ['estimator']
+
def _KeyToFilePath(key, api_version):
"""From a given key, construct a filepath.
@@ -111,6 +121,19 @@ def _VerifyNoSubclassOfMessageVisitor(path, parent, unused_children):
'They are not yet supported by the API tools.' % path)
+def _FilterNonCoreGoldenFiles(golden_file_list):
+ """Filter out non-core API pbtxt files."""
+ filtered_file_list = []
+ filtered_package_prefixes = [
+ 'tensorflow.%s.' % p for p in _NON_CORE_PACKAGES]
+ for f in golden_file_list:
+ if any([f.rsplit('/')[-1].startswith(pre)
+ for pre in filtered_package_prefixes]):
+ continue
+ filtered_file_list.append(f)
+ return filtered_file_list
+
+
class ApiCompatibilityTest(test.TestCase):
def __init__(self, *args, **kwargs):
@@ -233,6 +256,9 @@ class ApiCompatibilityTest(test.TestCase):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
+ if FLAGS.only_test_core_api:
+ visitor.do_not_descend_map['tf'].extend(
+ _NON_CORE_PACKAGES)
traverse.traverse(tf_v2.compat.v1, visitor)
def testNoSubclassOfMessageV2(self):
@@ -240,6 +266,9 @@ class ApiCompatibilityTest(test.TestCase):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
+ if FLAGS.only_test_core_api:
+ visitor.do_not_descend_map['tf'].extend(
+ _NON_CORE_PACKAGES)
traverse.traverse(tf_v2, visitor)
def _checkBackwardsCompatibility(
@@ -252,6 +281,9 @@ class ApiCompatibilityTest(test.TestCase):
public_api_visitor.do_not_descend_map['tf'].append('contrib')
public_api_visitor.do_not_descend_map['tf.GPUOptions'] = [
'Experimental']
+ if FLAGS.only_test_core_api:
+ public_api_visitor.do_not_descend_map['tf'].extend(
+ _NON_CORE_PACKAGES)
if additional_private_map:
public_api_visitor.private_map.update(additional_private_map)
@@ -260,6 +292,8 @@ class ApiCompatibilityTest(test.TestCase):
# Read all golden files.
golden_file_list = file_io.get_matching_files(golden_file_pattern)
+ if FLAGS.only_test_core_api:
+ golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list)
def _ReadFileToProto(filename):
"""Read a filename, create a protobuf from its contents."""
@@ -325,6 +359,11 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--update_goldens', type=bool, default=False, help=_UPDATE_GOLDENS_HELP)
+ # TODO(mikecase): Create Estimator's own API compatibility test or
+ # a more general API compatibility test for use for TF components.
+ parser.add_argument(
+ '--only_test_core_api', type=bool, default=False,
+ help=_ONLY_TEST_CORE_API_HELP)
parser.add_argument(
'--verbose_diffs', type=bool, default=True, help=_VERBOSE_DIFFS_HELP)
FLAGS, unparsed = parser.parse_known_args()
diff --git a/tensorflow/tools/benchmark/README.md b/tensorflow/tools/benchmark/README.md
index e64af2bfe1..dee1a20f3f 100644
--- a/tensorflow/tools/benchmark/README.md
+++ b/tensorflow/tools/benchmark/README.md
@@ -32,7 +32,7 @@ adb push bazel-bin/tensorflow/tools/benchmark/benchmark_model /data/local/tmp
(4) Run the benchmark. For example:
```
-adb shell "/data/local/tmp/benchmark_model \
+adb shell /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/tensorflow_inception_graph.pb \
--input_layer="input:0" \
--input_layer_shape="1,224,224,3" \
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
index a30858db82..dd8d705331 100644
--- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
@@ -26,7 +26,7 @@ ENV NVIDIA_VISIBLE_DEVICES all
ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
ENV NVIDIA_REQUIRE_CUDA "cuda>=9.0"
ENV NCCL_VERSION 2.2.13
-ENV CUDNN_VERSION 7.2.1.38
+ENV CUDNN_VERSION 7.1.4.18
# TODO(b/110903506): /usr/loca/cuda/lib64/stubs should not be needed in
# LD_LIBRARY_PATH. The stubs/libcuda.so is not meant to used at runtime. The
diff --git a/tensorflow/tools/ci_build/Dockerfile.rocm b/tensorflow/tools/ci_build/Dockerfile.rocm
new file mode 100644
index 0000000000..aadaa8bac1
--- /dev/null
+++ b/tensorflow/tools/ci_build/Dockerfile.rocm
@@ -0,0 +1,97 @@
+# This Dockerfile provides a starting point for a ROCm installation of
+# MIOpen and tensorflow.
+FROM ubuntu:xenial
+MAINTAINER Jeff Poznanovic <jeffrey.poznanovic@amd.com>
+
+ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/debian/
+ARG ROCM_PATH=/opt/rocm
+
+ENV DEBIAN_FRONTEND noninteractive
+ENV TF_NEED_ROCM 1
+ENV HOME /root/
+RUN apt update && apt install -y wget software-properties-common
+
+# Add rocm repository
+RUN apt-get clean all
+RUN wget -qO - $DEB_ROCM_REPO/rocm.gpg.key | apt-key add -
+RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO xenial main > /etc/apt/sources.list.d/rocm.list"
+
+# Install misc pkgs
+RUN apt-get update --allow-insecure-repositories && DEBIAN_FRONTEND=noninteractive apt-get install -y \
+ build-essential \
+ clang-3.8 \
+ clang-format-3.8 \
+ clang-tidy-3.8 \
+ cmake \
+ cmake-qt-gui \
+ ssh \
+ curl \
+ apt-utils \
+ pkg-config \
+ g++-multilib \
+ git \
+ libunwind-dev \
+ libfftw3-dev \
+ libelf-dev \
+ libncurses5-dev \
+ libpthread-stubs0-dev \
+ vim \
+ gfortran \
+ libboost-program-options-dev \
+ libssl-dev \
+ libboost-dev \
+ libboost-system-dev \
+ libboost-filesystem-dev \
+ rpm \
+ libnuma-dev \
+ virtualenv \
+ python-pip \
+ python3-pip \
+ wget && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+# Install rocm pkgs
+RUN apt-get update --allow-insecure-repositories && \
+ DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
+ rocm-dev rocm-libs rocm-utils \
+ rocfft miopen-hip miopengemm rocblas hipblas rocrand \
+ rocm-profiler cxlactivitylogger && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN cd ~ && git clone https://github.com/GPUOpen-ProfessionalCompute-Tools/HIP.git
+RUN cd ~/HIP && mkdir -p build && cd build && cmake .. && make package -j && dpkg -i *.deb
+
+ENV HCC_HOME=$ROCM_PATH/hcc
+ENV HIP_PATH=$ROCM_PATH/hip
+ENV OPENCL_ROOT=$ROCM_PATH/opencl
+ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}"
+ENV PATH="$ROCM_PATH/bin:${PATH}"
+ENV PATH="$OPENCL_ROOT/bin:${PATH}"
+
+# Add target file to help determine which device(s) to build for
+RUN echo -e "gfx803\ngfx900" >> /opt/rocm/bin/target.lst
+
+# Setup environment variables, and add those environment variables at the end of ~/.bashrc
+ARG HCC_HOME=/opt/rocm/hcc
+ARG HIP_PATH=/opt/rocm/hip
+ARG PATH=$HCC_HOME/bin:$HIP_PATH/bin:$PATH
+
+# Copy and run the install scripts.
+COPY install/*.sh /install/
+ARG DEBIAN_FRONTEND=noninteractive
+RUN /install/install_bootstrap_deb_packages.sh
+RUN add-apt-repository -y ppa:openjdk-r/ppa && \
+ add-apt-repository -y ppa:george-edison55/cmake-3.x
+RUN /install/install_deb_packages.sh
+RUN /install/install_pip_packages.sh
+RUN /install/install_bazel.sh
+RUN /install/install_golang.sh
+
+# Set up the master bazelrc configuration file.
+COPY install/.bazelrc /etc/bazel.bazelrc
+
+# Configure the build for our CUDA configuration.
+ENV TF_NEED_ROCM 1
+
diff --git a/tensorflow/tools/ci_build/builds/docker_test.sh b/tensorflow/tools/ci_build/builds/docker_test.sh
index e337ea4b05..38891b60e5 100755
--- a/tensorflow/tools/ci_build/builds/docker_test.sh
+++ b/tensorflow/tools/ci_build/builds/docker_test.sh
@@ -19,7 +19,7 @@
#
# Usage: docker_test.sh <IMAGE_TYPE> <TAG> <WHL_PATH>
# Arguments:
-# IMAGE_TYPE : Type of the image: (CPU|GPU)
+# IMAGE_TYPE : Type of the image: (CPU|GPU|ROCM)
# TAG : Docker image tag
# WHL_PATH : Path to the whl file to be installed inside the docker image
#
@@ -60,6 +60,8 @@ if [[ "${IMAGE_TYPE}" == "cpu" ]]; then
DOCKERFILE="tensorflow/tools/docker/Dockerfile"
elif [[ "${IMAGE_TYPE}" == "gpu" ]]; then
DOCKERFILE="tensorflow/tools/docker/Dockerfile.gpu"
+elif [[ "${IMAGE_TYPE}" == "rocm" ]]; then
+ DOCKERFILE="tensorflow/tools/docker/Dockerfile.rocm"
else
die "Unrecognized image type: $1"
fi
@@ -106,13 +108,16 @@ if [ "${IMAGE_TYPE}" == "gpu" ]; then
devices=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
libs=$(\ls /usr/lib/x86_64-linux-gnu/libcuda.* | xargs -I{} echo '-v {}:{}')
GPU_EXTRA_PARAMS="${devices} ${libs}"
+elif [ "${IMAGE_TYPE}" == "rocm" ]; then
+ ROCM_EXTRA_PARAMS="--device=/dev/kfd --device=/dev/dri --group-add video"
else
GPU_EXTRA_PARAMS=""
+ ROCM_EXTRA_PARAMS=""
fi
# Run docker image with source directory mapped
docker run -v ${BASE_DIR}:/tensorflow-src -w /tensorflow-src \
-${GPU_EXTRA_PARAMS} \
+${GPU_EXTRA_PARAMS} ${ROCM_EXTRA_PARAMS} \
"${DOCKER_IMG_TAG}" \
/bin/bash -c "tensorflow/tools/ci_build/builds/run_pip_tests.sh && "\
"tensorflow/tools/ci_build/builds/test_tutorials.sh && "\
diff --git a/tensorflow/tools/ci_build/builds/pip.sh b/tensorflow/tools/ci_build/builds/pip.sh
index fef121ab5a..6543779022 100755
--- a/tensorflow/tools/ci_build/builds/pip.sh
+++ b/tensorflow/tools/ci_build/builds/pip.sh
@@ -132,6 +132,7 @@ echo "Using Bazel flags: ${BAZEL_FLAGS}"
PIP_BUILD_TARGET="//tensorflow/tools/pip_package:build_pip_package"
GPU_FLAG=""
if [[ ${CONTAINER_TYPE} == "cpu" ]] || \
+ [[ ${CONTAINER_TYPE} == "rocm" ]] || \
[[ ${CONTAINER_TYPE} == "debian.jessie.cpu" ]]; then
bazel build ${BAZEL_FLAGS} ${PIP_BUILD_TARGET} || \
die "Build failed."
@@ -255,7 +256,8 @@ if [[ $(uname) == "Linux" ]]; then
die "ERROR: Cannot find repaired wheel."
fi
# Copy and rename for gpu manylinux as we do not want auditwheel to package in libcudart.so
- elif [[ ${CONTAINER_TYPE} == "gpu" ]]; then
+ elif [[ ${CONTAINER_TYPE} == "gpu" ]] || \
+ [[ ${CONTAINER_TYPE} == "rocm" ]]; then
WHL_PATH=${AUDITED_WHL_NAME}
cp ${WHL_DIR}/${WHL_BASE_NAME} ${WHL_PATH}
echo "Copied manylinx1 wheel file at ${WHL_PATH}"
diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
index 17198a6560..7d5cf3f843 100755
--- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh
+++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
@@ -111,7 +111,6 @@ bazel clean
# virtualenv.
export TF_NEED_GCP=0
export TF_NEED_HDFS=0
-export TF_ENABLE_XLA=0
# Obtain the path to Python binary
if [[ ${IS_VIRTUALENV} == "1" ]]; then
diff --git a/tensorflow/tools/ci_build/builds/with_the_same_user b/tensorflow/tools/ci_build/builds/with_the_same_user
index b216e3549f..1cc5aed15d 100755
--- a/tensorflow/tools/ci_build/builds/with_the_same_user
+++ b/tensorflow/tools/ci_build/builds/with_the_same_user
@@ -48,6 +48,12 @@ getent passwd "${CI_BUILD_UID}" || adduser ${ADDUSER_OPTS} \
usermod -a -G sudo "${CI_BUILD_USER}"
echo "${CI_BUILD_USER} ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-nopasswd-sudo
+if [[ "${TF_NEED_ROCM}" -eq 1 ]]; then
+ # ROCm requires the video group in order to use the GPU for compute. If it
+ # exists on the host, add it to the container.
+ getent group video || addgroup video && adduser "${CI_BUILD_USER}" video
+fi
+
if [ -e /root/.bazelrc ]; then
cp /root/.bazelrc "${CI_BUILD_HOME}/.bazelrc"
chown "${CI_BUILD_UID}:${CI_BUILD_GID}" "${CI_BUILD_HOME}/.bazelrc"
diff --git a/tensorflow/tools/ci_build/ci_build.sh b/tensorflow/tools/ci_build/ci_build.sh
index 77265e0f50..eab0616513 100755
--- a/tensorflow/tools/ci_build/ci_build.sh
+++ b/tensorflow/tools/ci_build/ci_build.sh
@@ -18,7 +18,7 @@
# <COMMAND>
#
# CONTAINER_TYPE: Type of the docker container used the run the build:
-# e.g., (cpu | gpu | android | tensorboard)
+# e.g., (cpu | gpu | rocm | android | tensorboard)
#
# DOCKERFILE_PATH: (Optional) Path to the Dockerfile used for docker build.
# If this optional value is not supplied (via the
@@ -103,6 +103,14 @@ if [[ "${CONTAINER_TYPE}" != gpu* ]]; then
GPU_EXTRA_PARAMS=""
fi
+# Add extra params for rocm devices and libraries for ROCm container.
+if [[ "${CONTAINER_TYPE}" == "rocm" ]]; then
+ ROCM_EXTRA_PARAMS="--device=/dev/kfd --device=/dev/dri --group-add video"
+else
+ ROCM_EXTRA_PARAMS=""
+fi
+
+
# Determine the docker image name
DOCKER_IMG_NAME="${BUILD_TAG}.${CONTAINER_TYPE}"
@@ -159,6 +167,7 @@ ${DOCKER_BINARY} run --rm --pid=host \
-v ${WORKSPACE}:/workspace \
-w /workspace \
${GPU_EXTRA_PARAMS} \
+ ${ROCM_EXTRA_PARAMS} \
${CI_DOCKER_EXTRA_PARAMS[@]} \
"${DOCKER_IMG_NAME}" \
${CI_COMMAND_PREFIX[@]} \
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
index 8eeddcdb82..3b5c92d148 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
@@ -26,6 +26,7 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
export CC_OPT_FLAGS='-mavx'
# Only running cc tests, python version does not matter.
export PYTHON_BIN_PATH=`which python`
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
index 8eca1987f0..52eff6330f 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py2_core.sh
@@ -26,6 +26,7 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
export CC_OPT_FLAGS='-mavx'
export PYTHON_BIN_PATH=`which python2`
yes "" | $PYTHON_BIN_PATH configure.py
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
index f6fa9251d4..d12027599a 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_contrib.sh
@@ -26,6 +26,7 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
export CC_OPT_FLAGS='-mavx'
export PYTHON_BIN_PATH=`which python3`
yes "" | $PYTHON_BIN_PATH configure.py
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
index 51eb2cd7e6..7c531a4d68 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_py3_core.sh
@@ -26,6 +26,7 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
export CC_OPT_FLAGS='-mavx'
export PYTHON_BIN_PATH=`which python3`
yes "" | $PYTHON_BIN_PATH configure.py
diff --git a/tensorflow/tools/ci_build/linux/libtensorflow.sh b/tensorflow/tools/ci_build/linux/libtensorflow.sh
index beef8e063b..3b6e15feb9 100755
--- a/tensorflow/tools/ci_build/linux/libtensorflow.sh
+++ b/tensorflow/tools/ci_build/linux/libtensorflow.sh
@@ -27,5 +27,8 @@ SUFFIX="-cpu-linux-"
if [ "${TF_NEED_CUDA}" == "1" ]; then
SUFFIX="-gpu-linux-"
fi
+if [ "${TF_NEED_ROCM}" == "1" ]; then
+ SUFFIX="-rocm-linux-"
+fi
build_libtensorflow_tarball "${SUFFIX}$(uname -m)"
diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/linux/libtensorflow_cpu.sh
index 4bf34dd299..b76262b6e9 100755
--- a/tensorflow/tools/ci_build/linux/libtensorflow_cpu.sh
+++ b/tensorflow/tools/ci_build/linux/libtensorflow_cpu.sh
@@ -19,4 +19,5 @@
set -ex
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
"${SCRIPT_DIR}/libtensorflow_docker.sh"
diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
index 60c974c36b..467b8dc808 100755
--- a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
+++ b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh
@@ -38,6 +38,11 @@ if [ "${TF_NEED_CUDA}" == "1" ]; then
DOCKER_BINARY="nvidia-docker"
DOCKER_FILE="Dockerfile.gpu"
fi
+if [ "${TF_NEED_ROCM}" == "1" ]; then
+ DOCKER_IMAGE="tf-tensorflow-rocm"
+ DOCKER_BINARY="docker"
+ DOCKER_FILE="Dockerfile.rocm"
+fi
docker build \
-t "${DOCKER_IMAGE}" \
@@ -53,6 +58,7 @@ ${DOCKER_BINARY} run \
-e "TF_NEED_HDFS=0" \
-e "TF_NEED_CUDA=${TF_NEED_CUDA}" \
-e "TF_NEED_TENSORRT=${TF_NEED_CUDA}" \
+ -e "TF_NEED_ROCM=${TF_NEED_ROCM}" \
-e "TF_NEED_OPENCL_SYCL=0" \
"${DOCKER_IMAGE}" \
"/workspace/tensorflow/tools/ci_build/linux/libtensorflow.sh"
diff --git a/tensorflow/contrib/data/python/ops/contrib_op_loader.py b/tensorflow/tools/ci_build/linux/libtensorflow_rocm.sh
index 8f495a9dc9..c1ebbe3630 100644..100755
--- a/tensorflow/contrib/data/python/ops/contrib_op_loader.py
+++ b/tensorflow/tools/ci_build/linux/libtensorflow_rocm.sh
@@ -1,4 +1,5 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#!/usr/bin/env bash
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,13 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Python helper for loading contrib ops and kernels."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.util import loader
-from tensorflow.python.platform import resource_loader
+#
+# Script to build a binary releases of libtensorflow with GPU support.
-_dataset_ops = loader.load_op_library(
- resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
+set -ex
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+export TF_NEED_ROCM=1
+"${SCRIPT_DIR}/libtensorflow_docker.sh"
diff --git a/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh b/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh
new file mode 100755
index 0000000000..200089f90e
--- /dev/null
+++ b/tensorflow/tools/ci_build/linux/rocm/run_cc_core.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+set -e
+set -x
+
+N_JOBS=$(grep -c ^processor /proc/cpuinfo)
+
+echo ""
+echo "Bazel will use ${N_JOBS} concurrent job(s)."
+echo ""
+
+# Run configure.
+export PYTHON_BIN_PATH=`which python3`
+export CC_OPT_FLAGS='-mavx'
+
+export TF_NEED_ROCM=1
+
+yes "" | $PYTHON_BIN_PATH configure.py
+
+# Run bazel test command. Double test timeouts to avoid flakes.
+bazel test --config=rocm --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \
+ --test_lang_filters=cc --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
+ --build_tests_only --test_output=errors --local_test_jobs=1 --config=opt \
+ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh b/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh
new file mode 100755
index 0000000000..1d0b838c1b
--- /dev/null
+++ b/tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+set -e
+set -x
+
+N_JOBS=$(grep -c ^processor /proc/cpuinfo)
+
+echo ""
+echo "Bazel will use ${N_JOBS} concurrent job(s)."
+echo ""
+
+# Run configure.
+export PYTHON_BIN_PATH=`which python3`
+export CC_OPT_FLAGS='-mavx'
+
+export TF_NEED_ROCM=1
+
+yes "" | $PYTHON_BIN_PATH configure.py
+
+# Run bazel test command. Double test timeouts to avoid flakes.
+bazel test --config=rocm --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test -k \
+ --test_lang_filters=py --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
+ --build_tests_only --test_output=errors --local_test_jobs=1 --config=opt \
+ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
index c7cc16e669..adee0d3171 100755
--- a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
+++ b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh
@@ -27,6 +27,7 @@ echo ""
# Run configure.
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
export CC_OPT_FLAGS='-mavx'
export PYTHON_BIN_PATH=$(which python2)
yes "" | $PYTHON_BIN_PATH configure.py
diff --git a/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh
index 9ae5fc6bea..06798adc03 100755
--- a/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh
+++ b/tensorflow/tools/ci_build/osx/libtensorflow_cpu.sh
@@ -26,6 +26,7 @@ source "${SCRIPT_DIR}/../builds/libtensorflow.sh"
export PYTHON_BIN_PATH="/usr/bin/python"
export TF_NEED_HDFS=0
export TF_NEED_CUDA=0
+export TF_NEED_ROCM=0
export TF_NEED_OPENCL_SYCL=0
export TF_NEED_MKL=0
export COMPUTECPP_PATH="/usr/local"
diff --git a/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh b/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh
index d95fcdeb85..95f1992d7d 100755
--- a/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh
+++ b/tensorflow/tools/ci_build/osx/libtensorflow_gpu.sh
@@ -27,6 +27,7 @@ export TF_NEED_CUDA=1
export LD_LIBRARY_PATH="/usr/local/cuda/lib:/usr/local/cuda/extras/CUPTI/lib:${LD_LIBRARY_PATH}"
export PYTHON_BIN_PATH="/usr/bin/python"
export TF_NEED_HDFS=0
+export TF_NEED_ROCM=0
export TF_NEED_OPENCL_SYCL=0
export TF_NEED_MKL=0
export COMPUTECPP_PATH="/usr/local"
diff --git a/tensorflow/tools/ci_build/osx/libtensorflow_rocm.sh b/tensorflow/tools/ci_build/osx/libtensorflow_rocm.sh
new file mode 100755
index 0000000000..aeabc0e39e
--- /dev/null
+++ b/tensorflow/tools/ci_build/osx/libtensorflow_rocm.sh
@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Script to produce binary release of libtensorflow (C API, Java jars etc.).
+
+set -ex
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# See comments at the top of this file for details.
+source "${SCRIPT_DIR}/../builds/libtensorflow.sh"
+
+# Configure script
+export TF_NEED_ROCM=1
+export PYTHON_BIN_PATH="/usr/bin/python"
+export TF_NEED_GCP=0
+export TF_NEED_HDFS=0
+export TF_NEED_CUDA=0
+export TF_NEED_OPENCL_SYCL=0
+export TF_NEED_MKL=0
+export COMPUTECPP_PATH="/usr/local"
+
+export PATH="/usr/local/cuda/bin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"
+build_libtensorflow_tarball "-gpu-darwin-$(uname -m)"
diff --git a/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh b/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh
new file mode 100755
index 0000000000..a0de128020
--- /dev/null
+++ b/tensorflow/tools/ci_build/xla/linux/rocm/run_py3.sh
@@ -0,0 +1,41 @@
+#!/usr/bin/env bash
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+set -e
+set -x
+
+N_JOBS=$(grep -c ^processor /proc/cpuinfo)
+
+echo ""
+echo "Bazel will use ${N_JOBS} concurrent job(s)."
+echo ""
+
+# Run configure.
+export PYTHON_BIN_PATH=`which python3`
+
+export TF_NEED_ROCM=1
+
+yes "" | $PYTHON_BIN_PATH configure.py
+echo "build --distinct_host_configuration=false" >> .tf_configure.bazelrc
+
+bazel clean
+# Run bazel test command. Double test timeouts to avoid flakes.
+bazel test --config=rocm --test_tag_filters=-no_gpu,-benchmark-test,-no_oss -k \
+ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
+ --build_tests_only --test_output=errors --local_test_jobs=1 \
+ --config=xla -- \
+ //tensorflow/compiler/...
diff --git a/tensorflow/tools/dist_test/server/BUILD b/tensorflow/tools/dist_test/server/BUILD
index 003a19a9ab..3aa53a5615 100644
--- a/tensorflow/tools/dist_test/server/BUILD
+++ b/tensorflow/tools/dist_test/server/BUILD
@@ -8,6 +8,7 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
py_binary(
name = "grpc_tensorflow_server",
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index b218e900bf..2a858b4fd6 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -37,6 +37,7 @@ py_library(
name = "doc_controls",
srcs = ["doc_controls.py"],
srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
)
py_test(
diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD
index b450bc42c5..b9f4902639 100644
--- a/tensorflow/tools/lib_package/BUILD
+++ b/tensorflow/tools/lib_package/BUILD
@@ -125,6 +125,7 @@ genrule(
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@highwayhash//:LICENSE",
+ "@icu//:icu4c/LICENSE",
"@jpeg//:LICENSE.md",
"@llvm//:LICENSE.TXT",
"@lmdb//:LICENSE",
@@ -136,16 +137,6 @@ genrule(
"@snappy//:COPYING",
"@zlib_archive//:zlib.h",
] + select({
- "//tensorflow:with_aws_support": [
- "@aws//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_gcp_support": [
- "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow:with_jemalloc_linux_x86_64": [
"@jemalloc//:COPYING",
],
@@ -170,7 +161,14 @@ genrule(
"@grpc//third_party/nanopb:LICENSE.txt",
"@grpc//third_party/address_sorting:LICENSE",
],
- ),
+ ) + select({
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "@aws//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
+ ],
+ }),
outs = ["include/tensorflow/c/LICENSE"],
cmd = "$(location :concat_licenses.sh) $(SRCS) >$@",
tools = [":concat_licenses.sh"],
@@ -192,6 +190,7 @@ genrule(
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@highwayhash//:LICENSE",
+ "@icu//:icu4j/main/shared/licenses/LICENSE",
"@jpeg//:LICENSE.md",
"@llvm//:LICENSE.TXT",
"@lmdb//:LICENSE",
@@ -203,16 +202,6 @@ genrule(
"@snappy//:COPYING",
"@zlib_archive//:zlib.h",
] + select({
- "//tensorflow:with_aws_support": [
- "@aws//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_gcp_support": [
- "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow:with_jemalloc_linux_x86_64": [
"@jemalloc//:COPYING",
],
@@ -230,7 +219,14 @@ genrule(
]) + if_mkl([
"//third_party/mkl:LICENSE",
"//third_party/mkl_dnn:LICENSE",
- ]),
+ ]) + select({
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "@aws//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
+ ],
+ }),
outs = ["include/tensorflow/jni/LICENSE"],
cmd = "$(location :concat_licenses.sh) $(SRCS) >$@",
tools = [":concat_licenses.sh"],
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 12354a6ab2..c621812535 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -66,8 +66,6 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
"//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:test_utils",
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
"//tensorflow/contrib/eager/python/examples:examples_pip",
"//tensorflow/contrib/eager/python:evaluator",
"//tensorflow/contrib/gan:gan",
@@ -108,6 +106,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/python:meta_graph_testdata",
"//tensorflow/python:spectral_ops_test_util",
"//tensorflow/python:util_example_parser_configuration",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/debug:debug_pip",
"//tensorflow/python/eager:eager_pip",
"//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files",
@@ -152,6 +151,7 @@ filegroup(
"@gemmlowp//:LICENSE",
"@gif_archive//:COPYING",
"@highwayhash//:LICENSE",
+ "@icu//:icu4c/LICENSE",
"@jpeg//:LICENSE.md",
"@lmdb//:LICENSE",
"@local_config_sycl//sycl:LICENSE.text",
@@ -167,17 +167,6 @@ filegroup(
"@zlib_archive//:zlib.h",
"@org_python_pypi_backports_weakref//:LICENSE",
] + select({
- "//tensorflow:with_aws_support": [
- "@aws//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_gcp_support": [
- "@com_github_googleapis_googleapis//:LICENSE",
- "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow:with_jemalloc_linux_x86_64": [
"@jemalloc//:COPYING",
],
@@ -186,11 +175,6 @@ filegroup(
],
"//conditions:default": [],
}) + select({
- "//tensorflow:with_kafka_support": [
- "@kafka//:LICENSE",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow/core/kernels:xsmm": [
"@libxsmm_archive//:LICENSE.md",
],
@@ -213,7 +197,16 @@ filegroup(
"@ngraph_tf//:LICENSE",
"@nlohmann_json_lib//:LICENSE.MIT",
"@tbb//:LICENSE",
- ]) + tf_additional_license_deps(),
+ ]) + tf_additional_license_deps() + select({
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "@aws//:LICENSE",
+ "@com_github_googleapis_googleapis//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
+ "@kafka//:LICENSE",
+ ],
+ }),
)
sh_binary(
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index bfc007bc39..c6ef82ccdc 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -90,6 +90,7 @@ BLACKLIST = [
"//tensorflow/contrib/lite/python:interpreter.py",
"//tensorflow/contrib/lite/python:interpreter_test.py",
"//tensorflow/contrib/ffmpeg:test_data",
+ "//tensorflow/contrib/fused_conv:fused_conv2d_bias_activation_op_test_base",
"//tensorflow/contrib/hadoop:test_data",
"//tensorflow/contrib/factorization/examples:mnist",
"//tensorflow/contrib/factorization/examples:mnist.py",
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 1481b53920..88c9c20d36 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -55,8 +55,7 @@ REQUIRED_PACKAGES = [
'keras_preprocessing >= 1.0.3',
'numpy >= 1.13.3',
'six >= 1.10.0',
- 'protobuf >= 3.6.0',
- 'setuptools <= 39.1.0',
+ 'protobuf >= 3.6.1',
'tensorboard >= 1.11.0, < 1.12.0',
'termcolor >= 1.1.0',
]
@@ -86,7 +85,7 @@ else:
if 'tf_nightly' in project_name:
for i, pkg in enumerate(REQUIRED_PACKAGES):
if 'tensorboard' in pkg:
- REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.11.0a0, < 1.12.0a0'
+ REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.12.0a0, < 1.13.0a0'
break
# weakref.finalize and enum were introduced in Python 3.4
diff --git a/tensorflow/tools/quantization/BUILD b/tensorflow/tools/quantization/BUILD
deleted file mode 100644
index 17443a8617..0000000000
--- a/tensorflow/tools/quantization/BUILD
+++ /dev/null
@@ -1,78 +0,0 @@
-# Description:
-# Utilities for quantizing TensorFlow graphs to lower bit depths.
-
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_library(
- name = "quantize_graph_lib",
- srcs = ["quantize_graph.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:graph_util",
- "//tensorflow/python:platform",
- "//tensorflow/python:session",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:tensor_util",
- "//third_party/py/numpy",
- ],
-)
-
-py_binary(
- name = "quantize_graph",
- srcs = ["quantize_graph.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python", # TODO(b/34059704): remove when fixed
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:graph_util",
- "//tensorflow/python:platform",
- "//tensorflow/python:tensor_util",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "quantize_graph_test",
- size = "small",
- srcs = ["quantize_graph_test.py"],
- srcs_version = "PY2AND3",
- tags = ["nomsan"], # http://b/32242946
- deps = [
- ":quantize_graph",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:graph_util",
- "//tensorflow/python:platform",
- "//third_party/py/numpy",
- ],
-)
-
-py_binary(
- name = "graph_to_dot",
- srcs = ["graph_to_dot.py"],
- main = "graph_to_dot.py",
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:platform",
- ],
-)
diff --git a/tensorflow/tools/quantization/graph_to_dot.py b/tensorflow/tools/quantization/graph_to_dot.py
deleted file mode 100644
index 81d6aa62c8..0000000000
--- a/tensorflow/tools/quantization/graph_to_dot.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Converts a GraphDef file into a DOT format suitable for visualization.
-
-This script takes a GraphDef representing a network, and produces a DOT file
-that can then be visualized by GraphViz tools like dot and xdot.
-
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import re
-
-from google.protobuf import text_format
-
-from tensorflow.core.framework import graph_pb2
-from tensorflow.python.platform import app
-from tensorflow.python.platform import flags
-from tensorflow.python.platform import gfile
-
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string("graph", "", """TensorFlow 'GraphDef' file to load.""")
-flags.DEFINE_bool("input_binary", True,
- """Whether the input files are in binary format.""")
-flags.DEFINE_string("dot_output", "", """Where to write the DOT output.""")
-
-
-def main(unused_args):
- if not gfile.Exists(FLAGS.graph):
- print("Input graph file '" + FLAGS.graph + "' does not exist!")
- return -1
-
- graph = graph_pb2.GraphDef()
- with open(FLAGS.graph, "r") as f:
- if FLAGS.input_binary:
- graph.ParseFromString(f.read())
- else:
- text_format.Merge(f.read(), graph)
-
- with open(FLAGS.dot_output, "wb") as f:
- print("digraph graphname {", file=f)
- for node in graph.node:
- output_name = node.name
- print(" \"" + output_name + "\" [label=\"" + node.op + "\"];", file=f)
- for input_full_name in node.input:
- parts = input_full_name.split(":")
- input_name = re.sub(r"^\^", "", parts[0])
- print(" \"" + input_name + "\" -> \"" + output_name + "\";", file=f)
- print("}", file=f)
- print("Created DOT file '" + FLAGS.dot_output + "'.")
-
-
-if __name__ == "__main__":
- app.run()
diff --git a/tensorflow/tools/quantization/quantize_graph.py b/tensorflow/tools/quantization/quantize_graph.py
deleted file mode 100644
index 3acb532263..0000000000
--- a/tensorflow/tools/quantization/quantize_graph.py
+++ /dev/null
@@ -1,1302 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Transforms a float-trained graph into an equivalent quantized version.
-
-An example of command-line usage is:
-bazel build tensorflow/tools/quantization:quantize_graph \
-&& bazel-bin/tensorflow/tools/quantization/quantize_graph \
---input=tensorflow_inception_graph.pb
---output_node_names="softmax2" --print_nodes --output=/tmp/quantized_graph.pb \
---mode=eightbit --logtostderr
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import collections
-import re
-import numpy as np
-
-from tensorflow.core.framework import attr_value_pb2
-from tensorflow.core.framework import graph_pb2
-from tensorflow.core.framework import node_def_pb2
-from tensorflow.python.client import session
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import graph_util
-from tensorflow.python.framework import importer
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import app
-from tensorflow.python.platform import flags as flags_lib
-from tensorflow.python.platform import gfile
-
-flags = flags_lib
-FLAGS = flags.FLAGS
-
-flags.DEFINE_boolean("print_nodes", False, """Lists all nodes in the model.""")
-flags.DEFINE_string("input", "", """TensorFlow 'GraphDef' file to load.""")
-flags.DEFINE_string("output_node_names", "",
- """Output node names, comma separated.""")
-flags.DEFINE_string("output", "", """File to save the output graph to.""")
-flags.DEFINE_integer("bitdepth", 8,
- """How many bits to quantize the graph to.""")
-flags.DEFINE_string("mode", "round",
- """What transformation to apply (round, quantize,"""
- """ eightbit, weights, or weights_rounded).""")
-flags.DEFINE_string("test_input_dims", "1,224,224,3",
- """The size of the input tensor to use when testing a"""
- """ graph loaded from a file.""")
-flags.DEFINE_boolean("strip_redundant_quantization", True,
- """Removes redundant dequantize/quantize pairs.""")
-flags.DEFINE_boolean("quantized_input", False,
- "If true, assume Placeholders are quantized with values "
- "covering [--quantized_input_min,--quantized_input_max]. "
- "Only supported when --mode=eightbit")
-flags.DEFINE_float("quantized_input_min", 0,
- "The minimum of the actual input range when "
- "--quantized_input")
-flags.DEFINE_float("quantized_input_max", 1,
- "The maximum of the actual input range when "
- "--quantized_input")
-flags.DEFINE_float(
- "quantized_fallback_min", None,
- "The fallback 'min' value to use for layers which lack min-max "
- "information. Note: this should be considered a coarse tool just good "
- "enough for experimentation purposes, since graphs quantized in this way "
- "would be very inaccurate.")
-flags.DEFINE_float(
- "quantized_fallback_max", None,
- "The fallback 'max' value to use for layers which lack min-max "
- "information. Note: this should be considered a coarse tool just good "
- "enough for experimentation purposes, since graphs quantized in this way "
- "would be very inaccurate.")
-
-
-def print_input_nodes(current_node, nodes_map, indent, already_visited):
- print(" " * indent + current_node.op + ":" + current_node.name)
- already_visited[current_node.name] = True
- for input_node_name in current_node.input:
- if input_node_name in already_visited:
- continue
- input_node = nodes_map[input_node_name]
- print_input_nodes(input_node, nodes_map, indent + 1, already_visited)
-
-
-def create_node(op, name, inputs):
- new_node = node_def_pb2.NodeDef()
- new_node.op = op
- new_node.name = name
- for input_name in inputs:
- new_node.input.extend([input_name])
- return new_node
-
-
-def create_constant_node(name, value, dtype, shape=None):
- node = create_node("Const", name, [])
- set_attr_dtype(node, "dtype", dtype)
- set_attr_tensor(node, "value", value, dtype, shape)
- return node
-
-
-def copy_attr(node, key, attr_value):
- try:
- node.attr[key].CopyFrom(attr_value)
- except KeyError:
- pass
-
-
-def set_attr_dtype(node, key, value):
- try:
- node.attr[key].CopyFrom(
- attr_value_pb2.AttrValue(type=value.as_datatype_enum))
- except KeyError:
- pass
-
-
-def set_attr_shape(node, key, value):
- try:
- node.attr[key].CopyFrom(
- attr_value_pb2.AttrValue(shape=tensor_shape.as_shape(value).as_proto()))
- except KeyError:
- pass
-
-
-def set_attr_tensor(node, key, value, dtype, shape=None):
- try:
- node.attr[key].CopyFrom(
- attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(
- value, dtype=dtype, shape=shape)))
- except KeyError:
- pass
-
-
-def set_attr_string(node, key, value):
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(s=value))
- except KeyError:
- pass
-
-
-def set_attr_int_list(node, key, value):
- list_value = attr_value_pb2.AttrValue.ListValue(i=value)
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(list=list_value))
- except KeyError:
- pass
-
-
-def set_attr_bool(node, key, value):
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(b=value))
- except KeyError:
- pass
-
-
-def set_attr_int(node, key, value):
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(i=value))
- except KeyError:
- pass
-
-
-def set_attr_float(node, key, value):
- try:
- node.attr[key].CopyFrom(attr_value_pb2.AttrValue(f=value))
- except KeyError:
- pass
-
-
-def node_name_from_input(node_name):
- """Strips off ports and other decorations to get the underlying node name."""
- if node_name.startswith("^"):
- node_name = node_name[1:]
- m = re.search(r"(.*):\d+$", node_name)
- if m:
- node_name = m.group(1)
- return node_name
-
-
-def ensure_tensor_name_has_port(node_name):
- """Makes sure that a tensor name has :0 if no explicit port exists."""
- m = re.search(r"(.*):\d+$", node_name)
- if m:
- name_with_port = node_name
- else:
- name_with_port = node_name + ":0"
- return name_with_port
-
-
-def unique_node_name_from_input(node_name):
- """Replaces invalid characters in input names to get a unique node name."""
- return node_name.replace(":", "__port__").replace("^", "__hat__")
-
-
-def quantize_array(arr, num_buckets):
- """Quantizes a numpy array.
-
- This function maps each scalar in arr to the center of one of num_buckets
- buckets. For instance,
- quantize_array([0, 0.3, 0.6, 1], 2) => [0.25, 0.25, 0.75, 0.75]
-
- Args:
- arr: The numpy array to quantize.
- num_buckets: The number of buckets to map "var" to.
- Returns:
- The quantized numpy array.
- Raises:
- ValueError: when num_buckets < 1.
- """
- if num_buckets < 1:
- raise ValueError("num_buckets must be >= 1")
- arr_max = arr.max()
- arr_min = arr.min()
- if arr_max == arr_min:
- return arr
- bucket_width = (arr_max - arr_min) / num_buckets
- # Map scalars to bucket indices. Take special care of max(arr).
- bucket_indices = np.floor((arr - arr_min) / bucket_width)
- bucket_indices[bucket_indices == num_buckets] = num_buckets - 1
- # Map each scalar to the center of a bucket.
- arr = arr_min + bucket_width * (bucket_indices + 0.5)
- return arr
-
-
-def quantize_weight_rounded(input_node):
- """Returns a replacement node for input_node containing bucketed floats."""
- input_tensor = input_node.attr["value"].tensor
- tensor_value = tensor_util.MakeNdarray(input_tensor)
- shape = input_tensor.tensor_shape
- # Currently, the parameter FLAGS.bitdepth is used to compute the
- # number of buckets as 1 << FLAGS.bitdepth, meaning the number of
- # buckets can only be a power of 2.
- # This could be fixed by introducing a new parameter, num_buckets,
- # which would allow for more flexibility in chosing the right model
- # size/accuracy tradeoff. But I didn't want to add more parameters
- # to this script than absolutely necessary.
- num_buckets = 1 << FLAGS.bitdepth
- tensor_value_rounded = quantize_array(tensor_value, num_buckets)
- tensor_shape_list = tensor_util.TensorShapeProtoToList(shape)
- return [
- create_constant_node(
- input_node.name,
- tensor_value_rounded,
- dtypes.float32,
- shape=tensor_shape_list)
- ]
-
-
-def quantize_weight_eightbit(input_node, quantization_mode):
- """Returns replacement nodes for input_node using the Dequantize op."""
- base_name = input_node.name + "_"
- quint8_const_name = base_name + "quint8_const"
- min_name = base_name + "min"
- max_name = base_name + "max"
- float_tensor = tensor_util.MakeNdarray(input_node.attr["value"].tensor)
- min_value = np.min(float_tensor.flatten())
- max_value = np.max(float_tensor.flatten())
- # Make sure that the range includes zero.
- if min_value > 0.0:
- min_value = 0.0
- # min_value == max_value is a tricky case. It can occur for general
- # tensors, and of course for scalars. The quantized ops cannot deal
- # with this case, so we set max_value to something else.
- # It's a tricky question what is the numerically best solution to
- # deal with this degeneracy.
- # TODO(petewarden): Better use a tolerance than a hard comparison?
- if min_value == max_value:
- if abs(min_value) < 0.000001:
- max_value = min_value + 1.0
- elif min_value > 0:
- max_value = 2 * min_value
- else:
- max_value = min_value / 2.0
-
- sess = session.Session()
- with sess.as_default():
- quantize_op = array_ops.quantize_v2(
- float_tensor,
- min_value,
- max_value,
- dtypes.quint8,
- mode=quantization_mode)
- quint8_tensor = quantize_op[0].eval()
- shape = tensor_util.TensorShapeProtoToList(input_node.attr["value"]
- .tensor.tensor_shape)
- quint8_const_node = create_constant_node(
- quint8_const_name, quint8_tensor, dtypes.quint8, shape=shape)
- min_node = create_constant_node(min_name, min_value, dtypes.float32)
- max_node = create_constant_node(max_name, max_value, dtypes.float32)
- dequantize_node = create_node("Dequantize", input_node.name,
- [quint8_const_name, min_name, max_name])
- set_attr_dtype(dequantize_node, "T", dtypes.quint8)
- set_attr_string(dequantize_node, "mode", quantization_mode)
- return [quint8_const_node, min_node, max_node, dequantize_node]
-
-
-EightbitizeRecursionState = collections.namedtuple(
- "EightbitizeRecursionState",
- ["already_visited", "output_node_stack", "merged_with_fake_quant"])
-
-
-class GraphRewriter(object):
- """Takes a float graph, and rewrites it in quantized form."""
-
- def __init__(self,
- input_graph,
- mode,
- quantized_input_range,
- fallback_quantization_range=None):
- """Sets up the class to rewrite a float graph.
-
- Args:
- input_graph: A float graph to transform.
- mode: A string controlling how quantization is performed -
- round, quantize, eightbit, or weights.
- quantized_input_range: if set, assume the input is
- quantized and represents the range
- [quantized_input_range[0], quantized_input_range[1]]
- fallback_quantization_range: if set, then for nodes where the quantization
- range can't be inferred from the graph, use the range
- [fallback_quantization_range[0], fallback_quantization_range[1]) instead
- of using a RequantizationRange node in the graph.
-
- Raises:
- ValueError: Two nodes with the same name were found in the graph.
- """
- self.input_graph = input_graph
- self.nodes_map = self.create_nodes_map(input_graph)
- self.output_graph = None
- self.mode = mode
- self.final_node_renames = {}
- if quantized_input_range:
- self.input_range = (quantized_input_range[0], quantized_input_range[1])
- if self.input_range[0] >= self.input_range[1]:
- raise ValueError("Invalid quantized_input_range: [%s,%s]" %
- self.input_range)
- if self.mode != "eightbit":
- raise ValueError(
- "quantized_input_range can only be specified in eightbit mode")
- else:
- self.input_range = None
-
- if fallback_quantization_range:
- self.fallback_quantization_range = [
- fallback_quantization_range[0], fallback_quantization_range[1]
- ]
- if (self.fallback_quantization_range[0] >=
- self.fallback_quantization_range[1]):
- raise ValueError("Invalid fallback_quantization_range: [%s,%s]" %
- self.fallback_quantization_range)
- if self.mode != "eightbit":
- raise ValueError("fallback_quantization_range can only be "
- "specified in eightbit mode")
- else:
- self.fallback_quantization_range = None
-
- # Data that is valid only during the recursive call to rewrite the graph.
- self.state = None
-
- def create_nodes_map(self, graph):
- """Builds a mapping of node names to their defs from the graph."""
- nodes_map = {}
- for node in graph.node:
- if node.name not in nodes_map.keys():
- nodes_map[node.name] = node
- else:
- raise ValueError("Duplicate node names detected.")
- return nodes_map
-
- def rewrite(self, output_node_names):
- """Triggers rewriting of the float graph.
-
- Args:
- output_node_names: A list of names of the nodes that produce the final
- results.
-
- Returns:
- A quantized version of the float graph.
- """
- self.output_graph = graph_pb2.GraphDef()
- output_nodes = [
- self.nodes_map[output_node_name]
- for output_node_name in output_node_names
- ]
- if self.mode == "round":
- self.already_visited = {}
- for output_node in output_nodes:
- self.round_nodes_recursively(output_node)
- elif self.mode == "quantize":
- self.already_visited = {}
- self.already_quantized = {}
- for output_node in output_nodes:
- self.quantize_nodes_recursively(output_node)
- elif self.mode == "eightbit":
- self.set_input_graph(graph_util.remove_training_nodes(
- self.input_graph, protected_nodes=output_node_names))
- output_nodes = [
- self.nodes_map[output_node_name]
- for output_node_name in output_node_names
- ]
-
- self.state = EightbitizeRecursionState(
- already_visited={}, output_node_stack=[], merged_with_fake_quant={})
- for output_node in output_nodes:
- self.eightbitize_nodes_recursively(output_node)
- self.state = None
- if self.input_range:
- self.add_output_graph_node(
- create_constant_node("quantized_input_min_value", self.input_range[
- 0], dtypes.float32, []))
- self.add_output_graph_node(
- create_constant_node("quantized_input_max_value", self.input_range[
- 1], dtypes.float32, []))
- if self.fallback_quantization_range:
- self.add_output_graph_node(
- create_constant_node("fallback_quantization_min_value",
- self.fallback_quantization_range[0],
- dtypes.float32, []))
- self.add_output_graph_node(
- create_constant_node("fallback_quantization_max_value",
- self.fallback_quantization_range[1],
- dtypes.float32, []))
- if FLAGS.strip_redundant_quantization:
- self.output_graph = self.remove_redundant_quantization(
- self.output_graph)
- self.remove_dead_nodes(output_node_names)
- self.apply_final_node_renames()
- elif self.mode == "weights":
- self.output_graph = self.quantize_weights(self.input_graph,
- b"MIN_COMBINED")
- self.remove_dead_nodes(output_node_names)
- elif self.mode == "weights_rounded":
- self.output_graph = self.quantize_weights(self.input_graph, self.mode)
- self.remove_dead_nodes(output_node_names)
- else:
- print("Bad mode - " + self.mode + ".")
- return self.output_graph
-
- def round_nodes_recursively(self, current_node):
- """The entry point for simple rounding quantization."""
- if (current_node.name in self.already_visited
- ) and self.already_visited[current_node.name]:
- return
- self.already_visited[current_node.name] = True
- for input_node_name in current_node.input:
- input_node_name = node_name_from_input(input_node_name)
- input_node = self.nodes_map[input_node_name]
- self.round_nodes_recursively(input_node)
- nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
- if any(current_node.op in s for s in nodes_to_quantize):
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- new_node.name = current_node.name + "_original"
- self.add_output_graph_node(new_node)
- levels = 1 << FLAGS.bitdepth
- constant_name = current_node.name + "_round_depth"
- constant_tensor = constant_op.constant(
- levels, dtype=dtypes.int32, name=constant_name)
- constant_node = constant_tensor.op.node_def
- self.add_output_graph_node(constant_node)
- quantize_node = node_def_pb2.NodeDef()
- quantize_node.op = "RoundToSteps"
- quantize_node.name = current_node.name
- quantize_node.input.extend([current_node.name + "_original"])
- quantize_node.input.extend([constant_node.name])
- self.add_output_graph_node(quantize_node)
- else:
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- self.add_output_graph_node(new_node)
-
- def quantize_nodes_recursively(self, current_node):
- """The entry point for quantizing nodes to eight bit and back."""
- if self.already_visited[current_node.name]:
- return
- self.already_visited[current_node.name] = True
- for input_node_name in current_node.input:
- input_node_name = node_name_from_input(input_node_name)
- input_node = self.nodes_map[input_node_name]
- self.quantize_nodes_recursively(input_node)
- nodes_to_quantize = ["Conv2D", "BiasAdd", "MatMul"]
- if any(current_node.op in s for s in nodes_to_quantize):
- for input_name in current_node.input:
- input_name = node_name_from_input(input_name)
- input_node = self.nodes_map[input_name]
- self.quantize_node(input_node)
- self.quantize_node(current_node)
- else:
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- self.add_output_graph_node(new_node)
-
- def quantize_node(self, input_node):
- """Handles quantizing a single node."""
- input_name = input_node.name
- if input_name in self.already_quantized:
- return
- self.already_quantized[input_name] = True
- original_input_name = input_name + "_original"
- reshape_name = input_name + "_reshape"
- reshape_dims_name = input_name + "_reshape_dims"
- max_name = input_name + "_max"
- min_name = input_name + "_min"
- dims_name = input_name + "_dims"
- quantize_name = input_name + "_quantize"
- dequantize_name = input_name
- original_input_node = node_def_pb2.NodeDef()
- original_input_node.CopyFrom(input_node)
- original_input_node.name = original_input_name
- self.add_output_graph_node(original_input_node)
- reshape_dims_node = create_constant_node(reshape_dims_name, -1,
- dtypes.int32, [1])
- self.add_output_graph_node(reshape_dims_node)
- reshape_node = create_node("Reshape", reshape_name,
- [original_input_name, reshape_dims_name])
- set_attr_dtype(reshape_node, "T", dtypes.float32)
- self.add_output_graph_node(reshape_node)
- dims_node = create_constant_node(dims_name, 0, dtypes.int32, [1])
- self.add_output_graph_node(dims_node)
- max_node = create_node("Max", max_name, [reshape_name, dims_name])
- set_attr_dtype(max_node, "T", dtypes.float32)
- set_attr_bool(max_node, "keep_dims", False)
- self.add_output_graph_node(max_node)
- min_node = create_node("Min", min_name, [reshape_name, dims_name])
- set_attr_dtype(min_node, "T", dtypes.float32)
- set_attr_bool(min_node, "keep_dims", False)
- self.add_output_graph_node(min_node)
- quantize_node = create_node("Quantize", quantize_name,
- [original_input_name, min_name, max_name])
- set_attr_dtype(quantize_node, "T", dtypes.quint8)
- set_attr_string(quantize_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(quantize_node)
- dequantize_node = create_node("Dequantize", dequantize_name,
- [quantize_name, min_name, max_name])
- set_attr_dtype(dequantize_node, "T", dtypes.quint8)
- set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(dequantize_node)
-
- def should_merge_with_fake_quant_node(self):
- """Should the current node merge with self.state.output_node_stack[-1]?"""
- if not self.state.output_node_stack:
- return False
- top = self.state.output_node_stack[-1]
- return top[1] == 0 and top[0].op in ["FakeQuantWithMinMaxVars"]
-
- def should_quantize_const(self, node):
- if not self.state.output_node_stack:
- return False
- top = self.state.output_node_stack[-1]
- if not top[2]:
- return False
- dtype = dtypes.as_dtype(node.attr["dtype"].type)
- assert dtype == dtypes.float32, (
- "Failed to quantized constant %s of type %s" % (node.name, dtype))
- return True
-
- def eightbitize_nodes_recursively(self, current_node):
- """The entry point for transforming a graph into full eight bit."""
- if current_node.name in self.state.already_visited:
- if (self.should_merge_with_fake_quant_node() or
- current_node.name in self.state.merged_with_fake_quant):
- raise ValueError("Unsupported graph structure: output of node %s "
- "is processed by a FakeQuant* node and should have "
- "no other outputs.", current_node.name)
- return
- self.state.already_visited[current_node.name] = True
-
- for i, input_node_name in enumerate(current_node.input):
- quantize_input = False
- if current_node.op in ("MatMul", "Conv2D", "BiasAdd", "MaxPool",
- "AvgPool", "Relu", "Relu6",
- "BatchNormWithGlobalNormalization"):
- quantize_input = True
- elif current_node.op == "Concat" and i > 0:
- quantize_input = (
- dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32)
- elif current_node.op == "Reshape" and i == 0:
- quantize_input = (
- dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32)
-
- self.state.output_node_stack.append((current_node, i, quantize_input))
-
- input_node_name = node_name_from_input(input_node_name)
- input_node = self.nodes_map[input_node_name]
- self.eightbitize_nodes_recursively(input_node)
-
- self.state.output_node_stack.pop()
-
- if current_node.op == "MatMul":
- self.eightbitize_mat_mul_node(current_node)
- elif current_node.op == "Conv2D":
- self.eightbitize_conv_node(current_node)
- elif current_node.op == "BiasAdd":
- self.eightbitize_bias_add_node(current_node)
- elif current_node.op == "MaxPool" or current_node.op == "AvgPool":
- self.eightbitize_single_input_tensor_node(current_node,
- self.add_pool_function)
- elif current_node.op == "Relu" or current_node.op == "Relu6":
- self.eightbitize_single_input_tensor_node(current_node,
- self.add_relu_function)
- elif (current_node.op == "Concat" and
- dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32):
- self.eightbitize_concat_node(current_node)
- elif current_node.op == "BatchNormWithGlobalNormalization":
- self.eightbitize_batch_norm_node(current_node)
- elif (current_node.op == "Reshape" and
- dtypes.as_dtype(current_node.attr["T"].type) == dtypes.float32):
- self.eightbitize_reshape_node(current_node)
- elif (self.input_range and
- current_node.op in ("Placeholder", "PlaceholderV2")):
- self.eightbitize_placeholder_node(current_node)
- elif current_node.op == "FakeQuantWithMinMaxVars":
- # It will have been merged into the underlying node.
- pass
- elif current_node.op == "Const":
- if self.should_quantize_const(current_node):
- for n in quantize_weight_eightbit(current_node, b"MIN_FIRST"):
- self.add_output_graph_node(n)
- else:
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- self.add_output_graph_node(new_node)
-
- ###################################################################
- # Note: if more cases are added here, you may need to update the op
- # name lists in the loop over children at the start of the function.
- ###################################################################
- else:
- new_node = node_def_pb2.NodeDef()
- new_node.CopyFrom(current_node)
- self.add_output_graph_node(new_node)
-
- if (self.should_merge_with_fake_quant_node() and
- current_node.name not in self.state.merged_with_fake_quant):
- raise ValueError(
- "FakeQuant* node %s failed to merge with node %s of type %s" %
- (self.state.output_node_stack[-1][0], current_node.name,
- current_node.op))
-
- def add_eightbit_prologue_nodes(self, original_node):
- """Adds input conversion nodes to handle quantizing the underlying node."""
- namespace_prefix = original_node.name + "_eightbit"
- reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
- namespace_prefix)
- input_names = []
- min_max_names = []
- for original_input_name in original_node.input:
- quantize_input_name, min_input_name, max_input_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_input_name,
- reshape_dims_name,
- reduction_dims_name))
- input_names.append(quantize_input_name)
- min_max_names.append(min_input_name)
- min_max_names.append(max_input_name)
- all_input_names = []
- all_input_names.extend(input_names)
- all_input_names.extend(min_max_names)
- return all_input_names
-
- def add_common_quantization_nodes(self, namespace_prefix):
- """Builds constant nodes needed for quantization of inputs."""
- reshape_dims_name = namespace_prefix + "_reshape_dims"
- reduction_dims_name = namespace_prefix + "_reduction_dims"
-
- reshape_dims_node = create_constant_node(reshape_dims_name, -1,
- dtypes.int32, [1])
- self.add_output_graph_node(reshape_dims_node)
- reduction_dims_node = create_constant_node(reduction_dims_name, 0,
- dtypes.int32, [1])
- self.add_output_graph_node(reduction_dims_node)
- return reshape_dims_name, reduction_dims_name
-
- def eightbitize_input_to_node(self, namespace_prefix, original_input_name,
- reshape_dims_name, reduction_dims_name):
- """Takes one float input to an op, and converts it to quantized form."""
- unique_input_name = unique_node_name_from_input(original_input_name)
- reshape_input_name = namespace_prefix + "_reshape_" + unique_input_name
- min_input_name = namespace_prefix + "_min_" + unique_input_name
- max_input_name = namespace_prefix + "_max_" + unique_input_name
- quantize_input_name = namespace_prefix + "_quantize_" + unique_input_name
- reshape_input_node = create_node("Reshape", reshape_input_name,
- [original_input_name, reshape_dims_name])
- set_attr_dtype(reshape_input_node, "T", dtypes.float32)
- self.add_output_graph_node(reshape_input_node)
- min_input_node = create_node("Min", min_input_name,
- [reshape_input_name, reduction_dims_name])
- set_attr_dtype(min_input_node, "T", dtypes.float32)
- set_attr_bool(min_input_node, "keep_dims", False)
- self.add_output_graph_node(min_input_node)
- max_input_node = create_node("Max", max_input_name,
- [reshape_input_name, reduction_dims_name])
- set_attr_dtype(max_input_node, "T", dtypes.float32)
- set_attr_bool(max_input_node, "keep_dims", False)
- self.add_output_graph_node(max_input_node)
- quantize_input_node = create_node(
- "QuantizeV2", quantize_input_name,
- [original_input_name, min_input_name, max_input_name])
- set_attr_dtype(quantize_input_node, "T", dtypes.quint8)
- set_attr_string(quantize_input_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(quantize_input_node)
- min_output_name = quantize_input_name + ":1"
- max_output_name = quantize_input_name + ":2"
- return quantize_input_name, min_output_name, max_output_name
-
- def add_quantize_down_nodes(self, original_node, quantized_output_name):
- quantized_outputs = [
- quantized_output_name, quantized_output_name + ":1",
- quantized_output_name + ":2"
- ]
- min_max_inputs = None
- if self.should_merge_with_fake_quant_node():
- # Use the inputs to the FakeQuantWithMinMaxVars node as the inputs to
- # Requantize.
- fake_quant_node = self.state.output_node_stack[-1][0]
- min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
- assert original_node.name not in self.state.merged_with_fake_quant
- self.state.merged_with_fake_quant[original_node.name] = True
- elif self.fallback_quantization_range:
- min_max_inputs = [
- "fallback_quantization_min_value:0",
- "fallback_quantization_max_value:0"
- ]
- else:
- # Add a RequantizationRange node for finding the min and max values.
- requant_range_node = create_node(
- "RequantizationRange", original_node.name + "_eightbit_requant_range",
- quantized_outputs)
- set_attr_dtype(requant_range_node, "Tinput", dtypes.qint32)
- self.add_output_graph_node(requant_range_node)
- min_max_inputs = [
- requant_range_node.name + ":0", requant_range_node.name + ":1"
- ]
- requantize_node = create_node("Requantize",
- original_node.name + "_eightbit_requantize",
- quantized_outputs + min_max_inputs)
- set_attr_dtype(requantize_node, "Tinput", dtypes.qint32)
- set_attr_dtype(requantize_node, "out_type", dtypes.quint8)
- self.add_output_graph_node(requantize_node)
- return requantize_node.name
-
- def add_dequantize_result_node(self,
- quantized_output_name,
- original_node_name,
- min_tensor_index=1):
- min_max_inputs = [
- "%s:%s" % (quantized_output_name, min_tensor_index),
- "%s:%s" % (quantized_output_name, (min_tensor_index + 1))
- ]
- dequantize_name = original_node_name
- if self.should_merge_with_fake_quant_node():
- fake_quant_node = self.state.output_node_stack[-1][0]
- if original_node_name not in self.state.merged_with_fake_quant:
- min_max_inputs = [fake_quant_node.input[1], fake_quant_node.input[2]]
- self.state.merged_with_fake_quant[original_node_name] = True
- dequantize_name = fake_quant_node.name
-
- dequantize_node = create_node(
- "Dequantize", dequantize_name,
- [quantized_output_name, min_max_inputs[0], min_max_inputs[1]])
- set_attr_dtype(dequantize_node, "T", dtypes.quint8)
- set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(dequantize_node)
-
- def eightbitize_mat_mul_node(self, original_node):
- """Replaces a MatMul node with the eight bit equivalent sub-graph."""
- quantized_mat_mul_name = original_node.name + "_eightbit_quantized_mat_mul"
- all_input_names = self.add_eightbit_prologue_nodes(original_node)
- quantized_mat_mul_node = create_node("QuantizedMatMul",
- quantized_mat_mul_name,
- all_input_names)
- set_attr_dtype(quantized_mat_mul_node, "T1", dtypes.quint8)
- set_attr_dtype(quantized_mat_mul_node, "T2", dtypes.quint8)
- set_attr_dtype(quantized_mat_mul_node, "Toutput", dtypes.qint32)
- copy_attr(quantized_mat_mul_node, "transpose_a",
- original_node.attr["transpose_a"])
- copy_attr(quantized_mat_mul_node, "transpose_b",
- original_node.attr["transpose_b"])
- self.add_output_graph_node(quantized_mat_mul_node)
- quantize_down_name = self.add_quantize_down_nodes(original_node,
- quantized_mat_mul_name)
- self.add_dequantize_result_node(quantize_down_name, original_node.name)
-
- def eightbitize_conv_node(self, original_node):
- """Replaces a Conv2D node with the eight bit equivalent sub-graph."""
- all_input_names = self.add_eightbit_prologue_nodes(original_node)
- quantized_conv_name = original_node.name + "_eightbit_quantized_conv"
- quantized_conv_node = create_node("QuantizedConv2D", quantized_conv_name,
- all_input_names)
- copy_attr(quantized_conv_node, "strides", original_node.attr["strides"])
- copy_attr(quantized_conv_node, "padding", original_node.attr["padding"])
- set_attr_dtype(quantized_conv_node, "Tinput", dtypes.quint8)
- set_attr_dtype(quantized_conv_node, "Tfilter", dtypes.quint8)
- set_attr_dtype(quantized_conv_node, "out_type", dtypes.qint32)
- self.add_output_graph_node(quantized_conv_node)
- quantize_down_name = self.add_quantize_down_nodes(original_node,
- quantized_conv_name)
- self.add_dequantize_result_node(quantize_down_name, original_node.name)
-
- def eightbitize_bias_add_node(self, original_node):
- """Replaces a BiasAdd node with the eight bit equivalent sub-graph."""
- quantized_bias_add_name = (
- original_node.name + "_eightbit_quantized_bias_add")
- all_input_names = self.add_eightbit_prologue_nodes(original_node)
- quantized_bias_add_node = create_node("QuantizedBiasAdd",
- quantized_bias_add_name,
- all_input_names)
- set_attr_dtype(quantized_bias_add_node, "T1", dtypes.quint8)
- set_attr_dtype(quantized_bias_add_node, "T2", dtypes.quint8)
- set_attr_dtype(quantized_bias_add_node, "out_type", dtypes.qint32)
- self.add_output_graph_node(quantized_bias_add_node)
- quantize_down_name = self.add_quantize_down_nodes(original_node,
- quantized_bias_add_name)
- self.add_dequantize_result_node(quantize_down_name, original_node.name)
-
- def eightbitize_single_input_tensor_node(self, original_node,
- add_op_function):
- """Replaces a single-tensor node with the eight bit equivalent sub-graph.
-
- Converts a node like this:
-
- Shape(f) Input(f)
- | |
- +--------v v
- Operation
- |
- v
- (f)
-
- Into a quantized equivalent:
-
- Input(f) ReshapeDims
- +------v v-------------+
- | Reshape
- | |
- | | ReductionDims
- | +-----+ |
- | | +---c---------+
- | v v v v-------+
- | Min Max
- | +----+ |
- v v v--------+
- Quantize
- |
- v
- QuantizedOperation
- | | |
- v v v
- Dequantize
- |
- v
- (f)
-
-
- Args:
- original_node: Float node to be converted.
- add_op_function: Function to create the actual node.
-
- Returns:
- Subgraph representing the quantized version of the original node.
-
- """
- quantized_op_name = original_node.name + "_eightbit_quantized"
- quantized_op_type = "Quantized" + original_node.op
- all_input_names = self.add_eightbit_prologue_nodes(original_node)
- quantized_op_node = create_node(quantized_op_type, quantized_op_name,
- all_input_names)
- add_op_function(original_node, quantized_op_node)
- self.add_output_graph_node(quantized_op_node)
- self.add_dequantize_result_node(quantized_op_name, original_node.name)
-
- def add_pool_function(self, original_node, quantized_op_node):
- set_attr_dtype(quantized_op_node, "T", dtypes.quint8)
- copy_attr(quantized_op_node, "ksize", original_node.attr["ksize"])
- copy_attr(quantized_op_node, "strides", original_node.attr["strides"])
- copy_attr(quantized_op_node, "padding", original_node.attr["padding"])
-
- def add_relu_function(self, unused_arg_node, quantized_op_node):
- set_attr_dtype(quantized_op_node, "Tinput", dtypes.quint8)
-
- def eightbitize_concat_node(self, original_node):
- """Replaces a Concat node with the eight bit equivalent sub-graph.
-
- Converts a node like this:
-
- Shape(f) Input0(f) Input1(f)
- | | |
- +--------v v v----------+
- Concat
- |
- v
- (f)
-
- Into a quantized equivalent:
-
- Shape(f) Input0(f) ReshapeDims Input1(f)
- | +------v v--------------+------------------v v------+
- | | Reshape Reshape |
- | | | | |
- | | | ReductionDims | |
- | | +------+ | +--------+ |
- | | | +---c---------+-----------c-----+ | |
- | | +v v v v-------+---------v v v v+ |
- | | Min Max Min Max |
- | | +----+ | | +-----+ |
- | v v v--------+ +----------v v v
- | Quantize Quantize
- | +------------------+ +----------------------+
- +-------------------------------+ | |
- v v v
- QuantizedConcat
- | | |
- v v v
- Dequantize
- |
- v
- (f)
- Args:
- original_node: Float node to be converted.
-
- Returns:
- Subgraph representing the quantized version of the original node.
-
- """
- namespace_prefix = original_node.name + "_eightbit"
- quantized_concat_name = namespace_prefix + "_quantized_concat"
- reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
- namespace_prefix)
- shape_input_name = original_node.input[0]
- original_inputs = original_node.input[1:]
- input_names = []
- min_names = []
- max_names = []
- for original_input_name in original_inputs:
- quantize_input_name, min_input_name, max_input_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_input_name,
- reshape_dims_name,
- reduction_dims_name))
- input_names.append(quantize_input_name)
- min_names.append(min_input_name)
- max_names.append(max_input_name)
- all_input_names = [shape_input_name]
- all_input_names.extend(input_names)
- all_input_names.extend(min_names)
- all_input_names.extend(max_names)
- quantized_concat_node = create_node("QuantizedConcat",
- quantized_concat_name, all_input_names)
- set_attr_int(quantized_concat_node, "N", len(original_inputs))
- set_attr_dtype(quantized_concat_node, "T", dtypes.quint8)
- self.add_output_graph_node(quantized_concat_node)
- self.add_dequantize_result_node(quantized_concat_name, original_node.name)
-
- def eightbitize_placeholder_node(self, current_node):
- """Replaces a placeholder node with a quint8 placeholder node+dequantize."""
- name = current_node.name
-
- # Convert the placeholder into a quantized type.
- output_node = node_def_pb2.NodeDef()
- output_node.CopyFrom(current_node)
- set_attr_dtype(output_node, "dtype", dtypes.quint8)
- output_node.name += "_original_input"
- self.add_output_graph_node(output_node)
-
- # Add a dequantize to convert back to float.
- dequantize_node = create_node("Dequantize", name, [
- output_node.name, "quantized_input_min_value",
- "quantized_input_max_value"
- ])
- set_attr_dtype(dequantize_node, "T", dtypes.quint8)
- set_attr_string(dequantize_node, "mode", b"MIN_FIRST")
- self.add_output_graph_node(dequantize_node)
-
- # For the descent over the graph to work, the dequantize node must be named
- # current_node.name. However, for the feeding of the graph to work, the
- # placeholder must have the name current_node.name; so record a final set
- # of renames to apply after all processing has been done.
- self.final_node_renames[output_node.name] = name
- self.final_node_renames[dequantize_node.name] = name + "_dequantize"
-
- def eightbitize_reshape_node(self, original_node):
- """Replaces a Reshape node with the eight bit equivalent sub-graph.
-
- Args:
- original_node: Float node to be converted.
-
- Returns:
- Subgraph representing the quantized version of the original node.
-
- """
- namespace_prefix = original_node.name + "_eightbit"
- quantized_reshape_name = namespace_prefix + "_quantized_reshape"
- reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
- namespace_prefix)
- shape_input_name = original_node.input[1]
- quantize_input_name, min_input_name, max_input_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_node.input[0],
- reshape_dims_name, reduction_dims_name))
- quantized_reshape_node = create_node(
- "QuantizedReshape", quantized_reshape_name,
- [quantize_input_name, shape_input_name, min_input_name, max_input_name])
- set_attr_dtype(quantized_reshape_node, "T", dtypes.quint8)
- self.add_output_graph_node(quantized_reshape_node)
- self.add_dequantize_result_node(quantized_reshape_name, original_node.name)
-
- def eightbitize_batch_norm_node(self, original_node):
- """Replaces a MatMul node with the eight bit equivalent sub-graph."""
- namespace_prefix = original_node.name + "_eightbit"
- original_input_name = original_node.input[0]
- original_mean_name = original_node.input[1]
- original_variance_name = original_node.input[2]
- original_beta_name = original_node.input[3]
- original_gamma_name = original_node.input[4]
- quantized_batch_norm_name = namespace_prefix + "_quantized_batch_norm"
-
- reshape_dims_name, reduction_dims_name = self.add_common_quantization_nodes(
- namespace_prefix)
- quantize_input_name, min_input_name, max_input_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_input_name,
- reshape_dims_name, reduction_dims_name))
- quantize_mean_name, min_mean_name, max_mean_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_mean_name,
- reshape_dims_name, reduction_dims_name))
- quantize_variance_name, min_variance_name, max_variance_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_variance_name,
- reshape_dims_name, reduction_dims_name))
- quantize_beta_name, min_beta_name, max_beta_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_beta_name,
- reshape_dims_name, reduction_dims_name))
- quantize_gamma_name, min_gamma_name, max_gamma_name = (
- self.eightbitize_input_to_node(namespace_prefix, original_gamma_name,
- reshape_dims_name, reduction_dims_name))
- quantized_batch_norm_node = create_node(
- "QuantizedBatchNormWithGlobalNormalization", quantized_batch_norm_name,
- [
- quantize_input_name, min_input_name, max_input_name,
- quantize_mean_name, min_mean_name, max_mean_name,
- quantize_variance_name, min_variance_name, max_variance_name,
- quantize_beta_name, min_beta_name, max_beta_name,
- quantize_gamma_name, min_gamma_name, max_gamma_name
- ])
- set_attr_dtype(quantized_batch_norm_node, "Tinput", dtypes.quint8)
- set_attr_dtype(quantized_batch_norm_node, "out_type", dtypes.qint32)
- copy_attr(quantized_batch_norm_node, "scale_after_normalization",
- original_node.attr["scale_after_normalization"])
- copy_attr(quantized_batch_norm_node, "variance_epsilon",
- original_node.attr["variance_epsilon"])
- self.add_output_graph_node(quantized_batch_norm_node)
- quantize_down_name = self.add_quantize_down_nodes(original_node,
- quantized_batch_norm_name)
- self.add_dequantize_result_node(quantize_down_name, original_node.name)
-
- def add_output_graph_node(self, output_node):
- """Inserts one node into the new graph."""
- self.output_graph.node.extend([output_node])
-
- def remove_redundant_quantization(self, old_graph):
- """Removes unneeded pairs of quantize/dequantize ops from the graph.
-
- This is a bit of a tricky function, because it's attempting to spot the
- pattern of dequantizing from eight-bit up to float, and then immediately
- quantizing back down to eight bits again, that's introduced by previous
- passes that do 'key-hole' conversions of individual nodes but have to
- convert back to float to match the previous output interface, since they
- don't know that the next op can handle quantized tensors.
- It works by:
- - Looking for Quantize nodes.
- - Checking to see if their first input is a Dequantize node.
- - Seeing if their min/max inputs come from Min/Max nodes.
- - Making sure those Min/Max nodes are being fed from the same Dequantize.
- - Or that the Min is indirectly being fed from the same Dequantize as Max.
- - Making sure the Dequantize is going through a Reshape (which we add
- during the previous pass when we create the quantize sub-graph).
- - Looking for the dims Const op for the Min/Max dims.
- If all of these conditions are met, then it's a sub-graph pattern that
- we know how to optimize out (and is likely the common one we've introduced).
- We then rewire the graph to skip it entirely, and then rely on the dead node
- removal pass to get rid of any nodes that are no longer needed.
-
- Args:
- old_graph: The model we'll be stripping redundant nodes from.
-
- Returns:
- A graph with the unnecessary nodes removed.
-
- Raises:
- ValueError: Two nodes with the same name were found in the graph.
- """
- old_nodes_map = self.create_nodes_map(old_graph)
- self.output_graph = graph_pb2.GraphDef()
- inputs_to_rename = {}
- # We go through all the nodes, looking for any that match the patterns we
- # know how to optimize away.
- for node in old_graph.node:
- # We always start with a Quantize node, and examine its inputs to see if
- # they are in a form that can be removed.
- if node.op not in ["Quantize", "QuantizeV2"]:
- continue
- dequantize_node_name = node_name_from_input(node.input[0])
- if dequantize_node_name not in old_nodes_map:
- raise ValueError("Input node name '" + dequantize_node_name +
- "' not found in node '" + node.name + "'")
- dequantize_node = old_nodes_map[dequantize_node_name]
- # Do we have a Dequantize feeding in, with the same type as the Quantize?
- if dequantize_node.op != "Dequantize":
- continue
- if node.attr["T"] != dequantize_node.attr["T"]:
- continue
- # Now look at the other inputs, and ensure they're Min/Max nodes.
- min_node_name = node_name_from_input(node.input[1])
- max_node_name = node_name_from_input(node.input[2])
- min_node = old_nodes_map[min_node_name]
- max_node = old_nodes_map[max_node_name]
- is_min_right_type = (min_node.op in ["Min", "Dequantize"])
- is_max_right_type = (max_node.op in ["Max", "Dequantize"])
- if not is_min_right_type or not is_max_right_type:
- print("Didn't find expected types on inputs : %s, %s." % (min_node.op,
- max_node.op))
- continue
- min_node_input_name = node_name_from_input(min_node.input[0])
- max_node_input_name = node_name_from_input(max_node.input[0])
- # There are two different patterns for Min nodes we can recognize, one
- # where the input comes directly from the same one as the Max, and
- # another where we run it through another Min first, so check for both.
- is_same_input = False
- if min_node_input_name == max_node_input_name:
- is_same_input = True
- else:
- first_min_node_input = old_nodes_map[min_node_input_name]
- if first_min_node_input.op == "Concat":
- second_min_node_name = node_name_from_input(
- first_min_node_input.input[1])
- second_min_node = old_nodes_map[second_min_node_name]
- if second_min_node.op == "Min":
- second_min_node_input_name = node_name_from_input(
- second_min_node.input[0])
- is_same_input = (second_min_node_input_name == max_node_input_name)
- if not is_same_input:
- print("Different min/max inputs: " + min_node_input_name)
- continue
- # We recognize this pattern, so mark the graph edges to be rewired to
- # route around it entirely, since we know it's a no-op.
- dequantize_source_name = node_name_from_input(dequantize_node.input[0])
- node_tensor_name = ensure_tensor_name_has_port(node.name)
- min_tensor_name = node.name + ":1"
- max_tensor_name = node.name + ":2"
- inputs_to_rename[node_tensor_name] = dequantize_source_name
- inputs_to_rename[min_tensor_name] = dequantize_node.input[1]
- inputs_to_rename[max_tensor_name] = dequantize_node.input[2]
- # Finally we apply all the rewiring we've marked to the graph.
- for node in old_graph.node:
- for index, input_full_name in enumerate(node.input):
- input_name = ensure_tensor_name_has_port(input_full_name)
- if input_name in inputs_to_rename:
- node.input[index] = inputs_to_rename[input_name]
- self.add_output_graph_node(node)
- return self.output_graph
-
- def apply_final_node_renames(self):
- """Applies node renames in self.final_node_renames to self.output_graph."""
- old_graph = self.output_graph
- self.output_graph = graph_pb2.GraphDef()
- for node in old_graph.node:
- node.name = self.final_node_renames.get(node.name, node.name)
- for index, input_name in enumerate(node.input):
- node_name = node_name_from_input(input_name)
- input_full_name = ensure_tensor_name_has_port(input_name)
- if node_name in self.final_node_renames:
- node.input[index] = "%s%s" % (self.final_node_renames[node_name],
- input_full_name[len(node_name):])
- self.add_output_graph_node(node)
- return self.output_graph
-
- def remove_dead_nodes(self, output_names):
- """Removes nodes that are no longer needed for inference from the graph."""
- old_output_graph = self.output_graph
- self.output_graph = graph_util.extract_sub_graph(old_output_graph,
- output_names)
-
- def quantize_weights(self, input_graph, quantization_mode):
- """Quantize float Const ops.
-
- There are two modes of operations, both replace float Const ops with
- quantized values.
- 1. If quantization_mode is "weights_rounded", this function replaces float
- Const ops with quantized float Const ops - same as the original op, but
- float values being mapped to the center of one of 1<<FLAGS.bitdepth buckets.
- This does not change the raw model size, but compression algorithms such as
- zip (as used for compressing apks) or bzip2 will achieve a very good
- compression ratio.
- 2. For other quantization modes ("MIN_COMBINED" or "MIN_FIRST"), float
- Const ops are quantized and replaced by a tuple of four ops to perform
- the dequantization at runtime:
- * eight-bit Const (bucket indices, same shape as original float Const op
- * two float Const ops (min and max value of original float Const op)
- * Dequantize op to convert the eight-bit consts to float tensors.
- The quantization mode is important because we see accuracy problems when
- quantizing weights for different situations depending on the algorithm
- used. We haven't figured out exactly what the underlying cause is yet,
- unfortunately.
-
- Args:
- input_graph: A GraphDef of the model containing float Const ops.
- quantization_mode: How to quantize and dequantize the values.
-
- Returns:
- A GraphDef of the converted graph.
-
- Raises:
- ValueError: If quantization_mode is unsupported.
- """
- output_graph = graph_pb2.GraphDef()
- for input_node in input_graph.node:
- should_quantize = False
- if input_node.op == "Const":
- dtype = dtypes.as_dtype(input_node.attr["dtype"].type)
- if dtype == dtypes.float32:
- should_quantize = True
- if should_quantize:
- if quantization_mode == "weights_rounded":
- output_graph.node.extend(quantize_weight_rounded(input_node))
- elif quantization_mode in (b"MIN_COMBINED", b"MIN_FIRST"):
- output_graph.node.extend(
- quantize_weight_eightbit(input_node, quantization_mode))
- else:
- raise ValueError("Unsupported quantization mode %s." %
- quantization_mode)
- else:
- output_node = node_def_pb2.NodeDef()
- output_node.CopyFrom(input_node)
- output_graph.node.extend([output_node])
- return output_graph
-
- def set_input_graph(self, new_input_graph):
- self.input_graph = new_input_graph
- self.nodes_map = self.create_nodes_map(self.input_graph)
-
-
-def main(unused_args):
- if not gfile.Exists(FLAGS.input):
- print("Input graph file '" + FLAGS.input + "' does not exist!")
- return -1
-
- known_modes = [
- "round", "quantize", "eightbit", "weights", "test", "weights_rounded"
- ]
- if not any(FLAGS.mode in s for s in known_modes):
- print("mode is '" + FLAGS.mode + "', not in " + ", ".join(known_modes) +
- ".")
- return -1
-
- tf_graph = graph_pb2.GraphDef()
- with gfile.Open(FLAGS.input, "rb") as f:
- data = f.read()
- tf_graph.ParseFromString(data)
-
- graph = ops.Graph()
- with graph.as_default():
- importer.import_graph_def(tf_graph, input_map={}, name="")
-
- quantized_input_range = None
- if FLAGS.quantized_input:
- quantized_input_range = [
- FLAGS.quantized_input_min, FLAGS.quantized_input_max
- ]
-
- fallback_quantization_range = None
- if (FLAGS.quantized_fallback_min is not None or
- FLAGS.quantized_fallback_max is not None):
- assert FLAGS.quantized_fallback_min is not None
- assert FLAGS.quantized_fallback_max is not None
- fallback_quantization_range = [
- FLAGS.quantized_fallback_min, FLAGS.quantized_fallback_max
- ]
-
- rewriter = GraphRewriter(tf_graph, FLAGS.mode, quantized_input_range,
- fallback_quantization_range)
-
- output_graph = rewriter.rewrite(FLAGS.output_node_names.split(","))
-
- f = gfile.FastGFile(FLAGS.output, "wb")
- f.write(output_graph.SerializeToString())
-
- return 0
-
-
-if __name__ == "__main__":
- app.run()
diff --git a/tensorflow/tools/quantization/quantize_graph_test.py b/tensorflow/tools/quantization/quantize_graph_test.py
deleted file mode 100644
index 92bb5127da..0000000000
--- a/tensorflow/tools/quantization/quantize_graph_test.py
+++ /dev/null
@@ -1,966 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests the graph quantization script.
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-import numpy as np
-
-from tensorflow.core.framework import graph_pb2
-from tensorflow.python.client import session
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import graph_util
-from tensorflow.python.framework import importer
-from tensorflow.python.framework import ops as ops_lib
-from tensorflow.python.platform import flags as flags_lib
-from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
-from tensorflow.tools.quantization import quantize_graph
-
-flags = flags_lib
-FLAGS = flags.FLAGS
-
-
-def run_graph_def(graph_def, input_map, outputs):
- graph = ops_lib.Graph()
- with graph.as_default():
- importer.import_graph_def(graph_def, input_map={}, name="")
- with session.Session(graph=graph) as sess:
- results = sess.run(outputs, feed_dict=input_map)
- return results
-
-
-def test_mat_mul(m, n, k, a, b):
- """Tests a MatMul replacement."""
- a_constant_name = "a_constant"
- b_constant_name = "b_constant"
- mat_mul_name = "mat_mul"
-
- float_graph_def = graph_pb2.GraphDef()
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=a, dtype=dtypes.float32, shape=[m, k])
- float_graph_def.node.extend([a_constant])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=b, dtype=dtypes.float32, shape=[k, n])
- float_graph_def.node.extend([b_constant])
- mat_mul_node = quantize_graph.create_node("MatMul", mat_mul_name,
- [a_constant_name, b_constant_name])
- quantize_graph.set_attr_dtype(mat_mul_node, "T", dtypes.float32)
- quantize_graph.set_attr_bool(mat_mul_node, "transpose_a", False)
- quantize_graph.set_attr_bool(mat_mul_node, "transpose_b", False)
- float_graph_def.node.extend([mat_mul_node])
-
- test_graph(float_graph_def, {}, [mat_mul_name])
-
-
-def test_conv(depth, image_width, image_height, image_batch_count, filter_size,
- filter_count, stride, padding, input_values, filter_values):
- """Tests a Conv replacement."""
- input_constant_name = "input_constant"
- filter_constant_name = "filter_constant"
- conv_name = "conv"
-
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=input_values,
- dtype=dtypes.float32,
- shape=[image_batch_count, image_height, image_width, depth])
- float_graph_def.node.extend([input_constant])
- filter_constant = quantize_graph.create_constant_node(
- filter_constant_name,
- value=filter_values,
- dtype=dtypes.float32,
- shape=[filter_size, filter_size, depth, filter_count])
- float_graph_def.node.extend([filter_constant])
- conv_node = quantize_graph.create_node(
- "Conv2D", conv_name, [input_constant_name, filter_constant_name])
- quantize_graph.set_attr_dtype(conv_node, "T", dtypes.float32)
- quantize_graph.set_attr_int_list(conv_node, "strides", [1, stride, stride, 1])
- quantize_graph.set_attr_string(conv_node, "padding", padding)
- float_graph_def.node.extend([conv_node])
-
- test_graph(float_graph_def, {}, [conv_name])
-
-
-def are_tensors_near(a, b, tolerance):
- """Tests whether two tensors are nearly identical.
-
- This is a specialized comparison function designed to help debug problems with
- quantization. It prints out information about the differences between tensors
- on failure, paying special attention to possible biases by looking at the mean
- and absolute average errors.
-
- Args:
- a: First comparison tensor.
- b: Second comparison tensor.
- tolerance: Float value indicating how large an error between values is ok.
-
- Returns:
- Boolean indicating whether the two inputs were close enough.
- """
- flat_a = a.flatten()
- flat_b = b.flatten()
- if len(flat_a) != len(flat_b):
- tf_logging.info("Tensors are different sizes: " + str(len(flat_a)) + " vs "
- + str(len(flat_b)))
- return False
- value_count = len(flat_a)
- how_many_different = 0
- total_difference = 0
- total_abs_difference = 0
- for index in range(value_count):
- a_value = flat_a[index]
- b_value = flat_b[index]
- difference = a_value - b_value
- total_difference += difference
- total_abs_difference += abs(difference)
- if abs(difference) > tolerance:
- how_many_different += 1
- mean_difference = total_difference / value_count
- mean_abs_difference = total_abs_difference / value_count
- proportion_different = (how_many_different * 1.0) / value_count
- if how_many_different == 0:
- return True
- else:
- tf_logging.info("Tensors have {0} different values ({1}%), with mean"
- " difference {2} and mean absolute difference {3}".format(
- how_many_different, proportion_different * 100,
- mean_difference, mean_abs_difference))
- return False
-
-
-def get_top_value(input_values):
- max_value = None
- max_index = None
- for index, value in enumerate(input_values.flatten()):
- if max_value is None or value > max:
- max_value = value
- max_index = index
- return max_index, max_value
-
-
-def test_graph(float_graph_def, input_map, output_names, log_graph=False):
- """Runs the float graph through the rewriter and tests the results."""
- float_results = run_graph_def(
- float_graph_def, input_map,
- [output_name + ":0" for output_name in output_names])
- # TODO(petewarden): round test is currently failing because there is no
- # RoundToSteps op available.
- # round_rewriter = quantize_graph.GraphRewriter(float_graph_def, "round")
- # round_graph_def = round_rewriter.rewrite(output_name)
- # round_results = run_graph_def(round_graph_def, input_map,
- # [output_name + ":0"])
- # assert are_tensors_near(expected, round_results[0], 1.0)
- #
- # TODO(petewarden): Add test for "quantize" mode.
-
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "eightbit", quantized_input_range=None)
- eightbit_graph_def = eightbit_rewriter.rewrite(output_names)
- eightbit_results = run_graph_def(
- eightbit_graph_def, input_map,
- [output_name + ":0" for output_name in output_names])
- for expected, result in zip(float_results, eightbit_results):
- assert are_tensors_near(expected, result, 1.0)
-
- if log_graph:
- tf_logging.info("8bit:\n%s", str(eightbit_graph_def))
-
- # Test the weights_rounded mode. This uses the default bit_depth.
- weights_rounded_rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "weights_rounded", quantized_input_range=None)
- weights_rounded_graph_def = weights_rounded_rewriter.rewrite(output_names)
- weights_rounded_results = run_graph_def(
- weights_rounded_graph_def, input_map,
- [output_name + ":0" for output_name in output_names])
- for expected, result in zip(float_results, weights_rounded_results):
- assert are_tensors_near(expected, result, 1.0)
-
-
-class QuantizeGraphTest(test.TestCase):
-
- def test_negative_const_problem(self):
- shape_constant_name = "shape_constant"
- shape_constant = quantize_graph.create_constant_node(
- shape_constant_name, value=-0.8, dtype=dtypes.float32, shape=[1])
- quantization_result = quantize_graph.quantize_weight_eightbit(
- shape_constant, b"MIN_COMBINED")
- self.assertEqual(4, len(quantization_result))
-
- def test_odd_padding_problem(self):
- """Tests one error case we ran into in a real graph."""
- test_conv(1, 4, 4, 1, 3, 1, 2, b"SAME",
- [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
- [1, 2, 3, 4, 5, 6, 7, 8, 9])
-
- def test_mat_mul_tiny(self):
- # These tests are added to test the generate case where
- # min(matrix) == max(matrix), which used to cause problems.
- test_mat_mul(1, 1, 1, [2], [3])
- test_mat_mul(1, 2, 1, [1], [2, 3])
- test_mat_mul(1, 1, 2, [1, 1], [1, 1])
- test_mat_mul(1, 1, 2, [0, 0], [1, 1])
- # The general case.
- test_mat_mul(1, 1, 2, [1, 2], [1, 2])
-
- def test_mat_mul_small(self):
- test_mat_mul(2, 4, 3, [1, 2, 3, 4, 5, 6],
- [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18])
-
- def test_conv(self):
- test_conv(1, 4, 3, 1, 3, 1, 1, b"SAME",
- [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- [1, 4, 7, 2, 5, 8, 3, 6, 9])
-
- def test_reshape(self):
- """Tests that MatMul->Reshape->MatMul avoids extra quantize/dequantize."""
-
- def make_matmul(name, a, b):
- n = quantize_graph.create_node("MatMul", name, [a.name, b.name])
- quantize_graph.set_attr_dtype(n, "T", dtypes.float32)
- quantize_graph.set_attr_bool(n, "transpose_a", False)
- quantize_graph.set_attr_bool(n, "transpose_b", False)
- return n
-
- # matmul_1 = input*weight_1
- input_node = quantize_graph.create_constant_node(
- "input", value=[0, 1, 2, 3], dtype=dtypes.float32, shape=[4, 1])
- weight_1_node = quantize_graph.create_constant_node(
- "weight_1",
- value=[.5, .6, .7, .8, .9],
- dtype=dtypes.float32,
- shape=[1, 5])
- matmul_1_node = make_matmul("matmul_1", input_node, weight_1_node)
-
- # Reshape 4x5 to 10x2.
- new_shape_node = quantize_graph.create_constant_node(
- "new_shape_node", value=[10, 2], dtype=dtypes.int32, shape=[2])
- reshape_node = quantize_graph.create_node(
- "Reshape", "reshape", [matmul_1_node.name, new_shape_node.name])
- quantize_graph.set_attr_dtype(reshape_node, "T", dtypes.float32)
-
- # matmul_2_node = reshape*weight_2
- weight_2_node = quantize_graph.create_constant_node(
- "weight_2", value=[1.5, 2.5], dtype=dtypes.float32, shape=[2, 1])
- matmul_2_node = make_matmul("matmul_2", reshape_node, weight_2_node)
-
- g = graph_pb2.GraphDef()
- g.node.extend([
- input_node, weight_1_node, matmul_1_node, new_shape_node, reshape_node,
- weight_2_node, matmul_2_node
- ])
-
- # Test the graph
- test_graph(g, {}, ["matmul_2"])
-
- # Verify there is only one Quantize and one Requantize op.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- g, "eightbit", quantized_input_range=None)
- eightbit_graph_def = eightbit_rewriter.rewrite(["matmul_2"])
-
- ops = [node.op for node in eightbit_graph_def.node]
- # No quantize since all inputs are const and can be quantized up-front.
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
- self.assertEqual(1, ops.count("QuantizedReshape"))
-
- # One dequantize at the end.
- self.assertEqual(1, ops.count("Dequantize"))
-
- def test_quantize_array(self):
- # Test invalid parameters (empty array, or 0 buckets.
- self.assertRaises(ValueError, quantize_graph.quantize_array, np.array([]),
- 2)
- self.assertRaises(ValueError, quantize_graph.quantize_array,
- np.array([1, 2]), 0)
- # Test input array of length 1.
- arr = np.array([1])
- qarr = quantize_graph.quantize_array(arr, 1)
- self.assertEqual(arr, qarr)
- qarr = quantize_graph.quantize_array(arr, 2)
- self.assertEqual(arr, qarr)
- # Test input array with all elements equal.
- arr = np.array([1, 1, 1])
- qarr = quantize_graph.quantize_array(arr, 10)
- self.assertTrue((np.array([1, 1, 1]) == qarr).all())
- # Test "normal" input arrays.
- arr = np.array([0, 0.3, 0.6, 1])
- qarr = quantize_graph.quantize_array(arr, 1)
- self.assertTrue((np.array([0.5, 0.5, 0.5, 0.5]) == qarr).all())
- qarr = quantize_graph.quantize_array(arr, 2)
- self.assertTrue((np.array([0.25, 0.25, 0.75, 0.75]) == qarr).all())
- qarr = quantize_graph.quantize_array(arr.reshape((2, 2)), 2)
- self.assertTrue((np.array([[0.25, 0.25], [0.75, 0.75]]) == qarr).all())
-
- def test_non_float_concat(self):
- concat_dim = quantize_graph.create_constant_node(
- "concat_dim", value=0, dtype=dtypes.int32, shape=[])
- a = quantize_graph.create_constant_node(
- "a",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.int32,
- shape=[2, 2, 3])
- b = quantize_graph.create_constant_node(
- "b",
- value=[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
- dtype=dtypes.int32,
- shape=[2, 2, 3])
- concat = quantize_graph.create_node("Concat", "concat",
- [concat_dim.name, a.name, b.name])
- quantize_graph.set_attr_int(concat, "N", 2)
- quantize_graph.set_attr_dtype(concat, "T", dtypes.int32)
-
- g = graph_pb2.GraphDef()
- g.node.extend([concat_dim, a, b, concat])
- test_graph(g, {}, [concat.name])
-
- def test_non_float_reshape(self):
- a = quantize_graph.create_constant_node(
- "a",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.int32,
- shape=[2, 2, 3])
- shape = quantize_graph.create_constant_node(
- "shape", value=[12], dtype=dtypes.int32, shape=[1])
- reshape = quantize_graph.create_node("Reshape", "reshape",
- [a.name, shape.name])
- quantize_graph.set_attr_dtype(reshape, "T", dtypes.int32)
-
- g = graph_pb2.GraphDef()
- g.node.extend([a, shape, reshape])
- test_graph(g, {}, [reshape.name])
-
- def test_concat(self):
- shape_constant_name = "shape_constant"
- a_constant_name = "a_constant"
- b_constant_name = "b_constant"
- concat_name = "concat"
-
- float_graph_def = graph_pb2.GraphDef()
- shape_constant = quantize_graph.create_constant_node(
- shape_constant_name, value=0, dtype=dtypes.int32, shape=[])
- float_graph_def.node.extend([shape_constant])
- a_constant = quantize_graph.create_constant_node(
- a_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[2, 2, 3])
- float_graph_def.node.extend([a_constant])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name,
- value=[13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
- dtype=dtypes.float32,
- shape=[2, 2, 3])
- float_graph_def.node.extend([b_constant])
- concat_node = quantize_graph.create_node(
- "Concat", concat_name,
- [shape_constant_name, a_constant_name, b_constant_name])
- quantize_graph.set_attr_int(concat_node, "N", 2)
- quantize_graph.set_attr_dtype(concat_node, "T", dtypes.float32)
- float_graph_def.node.extend([concat_node])
-
- test_graph(float_graph_def, {}, [concat_name])
-
- # Verify the concat is quantized.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "eightbit", quantized_input_range=None)
- eightbit_graph_def = eightbit_rewriter.rewrite([concat_name])
-
- ops = [node.op for node in eightbit_graph_def.node]
- self.assertEqual(1, ops.count("QuantizedConcat"))
-
- def test_multiple_outputs(self):
- input_constant_name = "input_constant"
- split_constant_name = "split_constant"
- split_name = "split"
- concat_constant_name = "concat_constant"
- concat_name = "concat"
-
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[2, 6])
- float_graph_def.node.extend([input_constant])
- split_constant = quantize_graph.create_constant_node(
- split_constant_name, value=1, dtype=dtypes.int32, shape=[])
- float_graph_def.node.extend([split_constant])
- split_node = quantize_graph.create_node(
- "Split", split_name, [split_constant_name, input_constant_name])
- quantize_graph.set_attr_int(split_node, "num_split", 2)
- quantize_graph.set_attr_dtype(split_node, "T", dtypes.float32)
- float_graph_def.node.extend([split_node])
- concat_constant = quantize_graph.create_constant_node(
- concat_constant_name, value=1, dtype=dtypes.int32, shape=[])
- float_graph_def.node.extend([concat_constant])
- concat_node = quantize_graph.create_node(
- "Concat", concat_name,
- [concat_constant_name, split_name + ":0", split_name + ":1"])
- quantize_graph.set_attr_int(concat_node, "N", 2)
- quantize_graph.set_attr_dtype(concat_node, "T", dtypes.float32)
- float_graph_def.node.extend([concat_node])
-
- test_graph(float_graph_def, {}, [concat_name])
-
- def test_node_name_from_input(self):
- self.assertEqual("SomeName",
- quantize_graph.node_name_from_input("^SomeName:2"))
-
- def test_unique_node_name_from_input(self):
- self.assertEqual("__hat__SomeName__port__2",
- quantize_graph.unique_node_name_from_input("^SomeName:2"))
-
- def test_identity(self):
- input_constant_name = "input_constant"
- identity_name = "identity"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[2, 6])
- float_graph_def.node.extend([input_constant])
- identity_node = quantize_graph.create_node("Identity", identity_name,
- [input_constant_name])
- quantize_graph.set_attr_dtype(identity_node, "T", dtypes.float32)
- float_graph_def.node.extend([identity_node])
-
- mul_name = "mul"
- mul_node = quantize_graph.create_node("Mul", mul_name,
- [identity_name, identity_name])
- quantize_graph.set_attr_dtype(mul_node, "T", dtypes.float32)
- float_graph_def.node.extend([mul_node])
-
- test_graph(float_graph_def, {}, [mul_name])
-
- def test_keep_control_edges(self):
- no_op_name = "no_op"
- a_constant_name = "a_constant"
- b_constant_name = "b_constant"
- a_check_name = "a_check"
- b_check_name = "b_check"
- a_identity_name = "a_identity"
- b_identity_name = "b_identity"
- add_name = "add"
- graph_def = graph_pb2.GraphDef()
- no_op = quantize_graph.create_node("NoOp", no_op_name, [])
- graph_def.node.extend([no_op])
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=1, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([a_constant])
- a_check_node = quantize_graph.create_node("CheckNumerics", a_check_name,
- [a_constant_name])
- graph_def.node.extend([a_check_node])
- a_identity_node = quantize_graph.create_node(
- "Identity", a_identity_name,
- [a_constant_name, "^" + a_check_name, "^" + no_op_name])
- graph_def.node.extend([a_identity_node])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=1, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([b_constant])
- b_check_node = quantize_graph.create_node("CheckNumerics", b_check_name,
- [b_constant_name])
- graph_def.node.extend([b_check_node])
- b_identity_node = quantize_graph.create_node(
- "Identity", b_identity_name, [b_constant_name, "^" + b_check_name])
- graph_def.node.extend([b_identity_node])
- add_node = quantize_graph.create_node("Add", add_name,
- [a_identity_name, b_identity_name])
- quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
- graph_def.node.extend([add_node])
-
- expected_output = graph_pb2.GraphDef()
- no_op = quantize_graph.create_node("NoOp", no_op_name, [])
- expected_output.node.extend([no_op])
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=1, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([a_constant])
- a_identity_node = quantize_graph.create_node(
- "Identity", a_identity_name, [a_constant_name, "^" + no_op_name])
- expected_output.node.extend([a_identity_node])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=1, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([b_constant])
- add_node = quantize_graph.create_node("Add", add_name,
- [a_identity_name, b_constant_name])
- quantize_graph.set_attr_dtype(add_node, "T", dtypes.float32)
- expected_output.node.extend([add_node])
- expected_output.versions.CopyFrom(graph_def.versions)
- expected_output.library.CopyFrom(graph_def.library)
-
- output = graph_util.remove_training_nodes(graph_def)
- stripped_output = graph_util.extract_sub_graph(output, [add_name])
- self.assertProtoEquals(expected_output, stripped_output)
-
- def test_batch_norm(self):
- input_constant_name = "input_constant"
- mean_constant_name = "mean_constant"
- variance_constant_name = "variance_constant"
- beta_constant_name = "beta_constant"
- gamma_constant_name = "gamma_constant"
- batch_norm_name = "batch_norm"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6],
- dtype=dtypes.float32,
- shape=[1, 1, 6, 2])
- float_graph_def.node.extend([input_constant])
- mean_constant = quantize_graph.create_constant_node(
- mean_constant_name, value=[10, 20], dtype=dtypes.float32, shape=[2])
- float_graph_def.node.extend([mean_constant])
- variance_constant = quantize_graph.create_constant_node(
- variance_constant_name,
- value=[0.25, 0.5],
- dtype=dtypes.float32,
- shape=[2])
- float_graph_def.node.extend([variance_constant])
- beta_constant = quantize_graph.create_constant_node(
- beta_constant_name, value=[0.1, 0.6], dtype=dtypes.float32, shape=[2])
- float_graph_def.node.extend([beta_constant])
- gamma_constant = quantize_graph.create_constant_node(
- gamma_constant_name, value=[0, 0], dtype=dtypes.float32, shape=[2])
- float_graph_def.node.extend([gamma_constant])
- batch_norm_node = quantize_graph.create_node(
- "BatchNormWithGlobalNormalization", batch_norm_name, [
- input_constant_name, mean_constant_name, variance_constant_name,
- beta_constant_name, gamma_constant_name
- ])
- quantize_graph.set_attr_dtype(batch_norm_node, "T", dtypes.float32)
- quantize_graph.set_attr_bool(batch_norm_node, "scale_after_normalization",
- False)
- quantize_graph.set_attr_float(batch_norm_node, "variance_epsilon", 0.001)
- float_graph_def.node.extend([batch_norm_node])
- test_graph(float_graph_def, {}, [batch_norm_name])
-
- def test_max_pool(self):
- input_constant_name = "input_constant"
- max_pool_name = "max_pool"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- float_graph_def.node.extend([input_constant])
- max_pool_node = quantize_graph.create_node("MaxPool", max_pool_name,
- [input_constant_name])
- quantize_graph.set_attr_int_list(max_pool_node, "ksize", [1, 2, 2, 1])
- quantize_graph.set_attr_int_list(max_pool_node, "strides", [1, 1, 1, 1])
- quantize_graph.set_attr_string(max_pool_node, "padding", b"SAME")
- float_graph_def.node.extend([max_pool_node])
- test_graph(float_graph_def, {}, [max_pool_name])
-
- def test_avg_pool(self):
- input_constant_name = "input_constant"
- avg_pool_name = "avg_pool"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- float_graph_def.node.extend([input_constant])
- avg_pool_node = quantize_graph.create_node("AvgPool", avg_pool_name,
- [input_constant_name])
- quantize_graph.set_attr_dtype(avg_pool_node, "T", dtypes.float32)
- quantize_graph.set_attr_int_list(avg_pool_node, "ksize", [1, 2, 2, 1])
- quantize_graph.set_attr_int_list(avg_pool_node, "strides", [1, 1, 1, 1])
- quantize_graph.set_attr_string(avg_pool_node, "padding", b"SAME")
- float_graph_def.node.extend([avg_pool_node])
- test_graph(float_graph_def, {}, [avg_pool_name])
-
- def test_relu(self):
- input_constant_name = "input_constant"
- relu_name = "relu"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- float_graph_def.node.extend([input_constant])
- relu_node = quantize_graph.create_node("Relu", relu_name,
- [input_constant_name])
- quantize_graph.set_attr_dtype(relu_node, "T", dtypes.float32)
- float_graph_def.node.extend([relu_node])
- test_graph(float_graph_def, {}, [relu_name])
-
- def test_relu_w_fake_quant_w_min_max_vars(self):
- input_node = quantize_graph.create_constant_node(
- "input",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- relu_node = quantize_graph.create_node("Relu", "relu", [input_node.name])
- quantize_graph.set_attr_dtype(relu_node, "T", dtypes.float32)
-
- min_node = quantize_graph.create_constant_node(
- "min_bias_add", value=0, dtype=dtypes.float32, shape=[])
- max_node = quantize_graph.create_constant_node(
- "max_bias_add", value=12, dtype=dtypes.float32, shape=[])
- fake_quant_node = quantize_graph.create_node(
- "FakeQuantWithMinMaxVars", "fake_quant",
- [relu_node.name, min_node.name, max_node.name])
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend(
- [input_node, relu_node, min_node, max_node, fake_quant_node])
- test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
-
- # Verify there is only one Quantize and one Requantize op.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "eightbit", quantized_input_range=None)
- eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
-
- ops = [node.op for node in eightbit_graph_def.node]
- # No quantize since all inputs are const and can be quantized up-front.
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
-
- # One dequantize at the end.
- self.assertEqual(1, ops.count("Dequantize"))
-
- def test_relu6(self):
- input_constant_name = "input_constant"
- relu6_name = "relu6"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 2, 6, 1])
- float_graph_def.node.extend([input_constant])
- relu6_node = quantize_graph.create_node("Relu6", relu6_name,
- [input_constant_name])
- quantize_graph.set_attr_dtype(relu6_node, "T", dtypes.float32)
- float_graph_def.node.extend([relu6_node])
- test_graph(float_graph_def, {}, [relu6_name])
-
- def test_bias_add(self):
- input_constant_name = "input_constant"
- offset_constant_name = "offset_constant"
- bias_add_name = "bias_add"
- float_graph_def = graph_pb2.GraphDef()
- input_constant = quantize_graph.create_constant_node(
- input_constant_name,
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
- dtype=dtypes.float32,
- shape=[1, 1, 2, 6])
- float_graph_def.node.extend([input_constant])
- offset_constant = quantize_graph.create_constant_node(
- offset_constant_name,
- value=[1, 2, 3, 4, 5, 6],
- dtype=dtypes.float32,
- shape=[6])
- float_graph_def.node.extend([offset_constant])
- bias_add_node = quantize_graph.create_node(
- "BiasAdd", bias_add_name, [input_constant_name, offset_constant_name])
- quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)
- float_graph_def.node.extend([bias_add_node])
- test_graph(float_graph_def, {}, [bias_add_name])
-
- def test_quantized_input_range_errors(self):
- with self.assertRaises(ValueError):
- # Invalid mode.
- quantize_graph.GraphRewriter(graph_pb2.GraphDef(), "weights_rounded",
- [0, 1])
- with self.assertRaises(ValueError):
- # Invalid range.
- quantize_graph.GraphRewriter(graph_pb2.GraphDef(), "eightbit", [0, -1])
-
- def test_quantized_input_range_bias_add(self):
- input_shape = [1, 1, 2, 6]
- input_n = quantize_graph.create_node("Placeholder", "input", [])
- quantize_graph.set_attr_dtype(input_n, "dtype", dtypes.float32)
- quantize_graph.set_attr_shape(input_n, "shape", input_shape)
- offset_n = quantize_graph.create_constant_node(
- "offset", value=[1, 2, 3, 4, 5, 6], dtype=dtypes.float32, shape=[6])
- bias_add_n = quantize_graph.create_node("BiasAdd", "bias_add",
- [input_n.name, offset_n.name])
- quantize_graph.set_attr_dtype(bias_add_n, "T", dtypes.float32)
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend([input_n, offset_n, bias_add_n])
-
- input_map = {
- input_n.name + ":0":
- np.reshape([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], input_shape)
- }
- self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
- [bias_add_n.name], [-1, 20.])
- self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
- [bias_add_n.name], [0, 12.])
-
- def test_quantized_input_range_mat_mul(self):
- shapes = [[3, 2], [2, 4]]
- inputs = []
- for i, shape in enumerate(shapes):
- node = quantize_graph.create_node("Placeholder", "input_%s" % i, [])
- quantize_graph.set_attr_dtype(node, "dtype", dtypes.float32)
- quantize_graph.set_attr_shape(node, "shape", shape)
- inputs.append(node)
- mat_mul_node = quantize_graph.create_node("MatMul", "mat_mul",
- [n.name for n in inputs])
- quantize_graph.set_attr_dtype(mat_mul_node, "T", dtypes.float32)
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend(inputs + [mat_mul_node])
-
- input_map = {
- inputs[0].name + ":0":
- np.reshape([1, 2, 3, 4, 5, 6], shapes[0]),
- inputs[1].name + ":0":
- np.reshape([.8, .7, .6, .5, .4, .3, .2, .1], shapes[1])
- }
- self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
- [mat_mul_node.name], [-1, 20.])
- self._RunTestsForQuantizedInputRange(float_graph_def, input_map,
- [mat_mul_node.name], [0, 6.])
-
- def _RunTestsForQuantizedInputRange(self, float_graph_def, input_map,
- output_names, input_range):
- if sys.version_info[0] == 3:
- # uint8->quint8 conversion for numpy is not working currently.
- return
-
- quantized_input_map = {}
- for k, v in input_map.items():
- arr = [
- int(
- round((n - input_range[0]) * 255 / (input_range[1] - input_range[
- 0]))) for n in v.flat
- ]
- arr = np.array(arr, np.uint8)
- arr = arr.reshape(v.shape)
- arr = arr.astype(dtypes.quint8.as_numpy_dtype)
- quantized_input_map[k] = arr
- output_tensors = [output_name + ":0" for output_name in output_names]
- float_results = run_graph_def(float_graph_def, input_map, output_tensors)
-
- # Quantize treating the input as quantized in range <input_range>.
- rewriter = quantize_graph.GraphRewriter(float_graph_def, "eightbit",
- input_range)
- graph_def = rewriter.rewrite(output_names)
- results = run_graph_def(graph_def, quantized_input_map, output_tensors)
- for expected, result in zip(float_results, results):
- assert are_tensors_near(expected, result, .5)
- ops = [node.op for node in graph_def.node]
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
- self.assertEqual(len(output_names), ops.count("Dequantize"))
-
- # Quantize without treating input as quantized.
- rewriter = quantize_graph.GraphRewriter(
- float_graph_def, "eightbit", quantized_input_range=None)
- graph_def = rewriter.rewrite(output_names)
- results = run_graph_def(graph_def, input_map, output_tensors)
- for expected, result in zip(float_results, results):
- assert are_tensors_near(expected, result, .5)
- ops = [node.op for node in graph_def.node]
- self.assertEqual(
- len(input_map), ops.count("QuantizeV2") + ops.count("Quantize"))
- self.assertEqual(len(output_names), ops.count("Dequantize"))
-
- def test_bias_add_w_fake_quant_w_min_max_vars(self):
- input_node = quantize_graph.create_constant_node(
- "input",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
- dtype=dtypes.float32,
- shape=[1, 1, 2, 5])
- offset_node = quantize_graph.create_constant_node(
- "offset", value=[1, 2, 3, 4, 5], dtype=dtypes.float32, shape=[5])
- bias_add_node = quantize_graph.create_node(
- "BiasAdd", "bias_add", [input_node.name, offset_node.name])
- quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)
-
- min_node = quantize_graph.create_constant_node(
- "min_bias_add", value=-.5, dtype=dtypes.float32, shape=[])
- max_node = quantize_graph.create_constant_node(
- "max_bias_add", value=15.5, dtype=dtypes.float32, shape=[])
- fake_quant_node = quantize_graph.create_node(
- "FakeQuantWithMinMaxVars", "fake_quant",
- [bias_add_node.name, min_node.name, max_node.name])
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend([
- input_node, offset_node, bias_add_node, min_node, max_node,
- fake_quant_node
- ])
- test_graph(float_graph_def, {}, [fake_quant_node.name], log_graph=True)
-
- # Verify there is only one Quantize and one Requantize op.
- # Pass in fallback_quantization_range, although it will have no effect
- # because the FakeQuantWithMinMaxVars are used instead.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def,
- "eightbit",
- quantized_input_range=None,
- fallback_quantization_range=[-100, 100])
- eightbit_graph_def = eightbit_rewriter.rewrite([fake_quant_node.name])
-
- ops = [node.op for node in eightbit_graph_def.node]
- node_names = [node.name for node in eightbit_graph_def.node]
- # No quantize since all inputs are const and can be quantized up-front.
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
-
- # One dequantize at the end.
- self.assertEqual(1, ops.count("Dequantize"))
-
- # The fallback constants are not in the graph.
- self.assertEqual(0, node_names.count("fallback_quantization_min_value"))
- self.assertEqual(0, node_names.count("fallback_quantization_max_value"))
-
- def test_bias_add_w_fallback_min_max_vars(self):
- input_node = quantize_graph.create_constant_node(
- "input",
- value=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
- dtype=dtypes.float32,
- shape=[1, 1, 2, 5])
- offset_node = quantize_graph.create_constant_node(
- "offset", value=[1, 2, 3, 4, 5], dtype=dtypes.float32, shape=[5])
- bias_add_node = quantize_graph.create_node(
- "BiasAdd", "bias_add", [input_node.name, offset_node.name])
- quantize_graph.set_attr_dtype(bias_add_node, "T", dtypes.float32)
-
- float_graph_def = graph_pb2.GraphDef()
- float_graph_def.node.extend([input_node, offset_node, bias_add_node])
- test_graph(float_graph_def, {}, [bias_add_node.name], log_graph=True)
-
- # Verify there is only one Quantize, one Requantize op, and no
- # RequantizationRange op.
- eightbit_rewriter = quantize_graph.GraphRewriter(
- float_graph_def,
- "eightbit",
- quantized_input_range=None,
- fallback_quantization_range=[-.5, 15.5])
- eightbit_graph_def = eightbit_rewriter.rewrite([bias_add_node.name])
-
- ops = [node.op for node in eightbit_graph_def.node]
- node_names = [node.name for node in eightbit_graph_def.node]
- # No quantize since all inputs are const and can be quantized up-front.
- self.assertEqual(0, ops.count("QuantizeV2") + ops.count("Quantize"))
-
- # One dequantize at the end.
- self.assertEqual(1, ops.count("Dequantize"))
-
- # No RequantizationRange
- self.assertEqual(0, ops.count("RequantizationRange"))
-
- # The fallback constants are in the graph.
- self.assertEqual(1, node_names.count("fallback_quantization_min_value"))
- self.assertEqual(1, node_names.count("fallback_quantization_max_value"))
-
- def test_remove_redundant_quantization(self):
- a_constant_name = "a_constant"
- a_constant_min_name = "a_constant_min"
- a_constant_max_name = "a_constant_max"
- a_dequantize_name = "a_dequantize"
- a_quantize_name = "a_quantize"
- b_constant_name = "b_constant"
- b_constant_min_name = "b_constant_min"
- b_constant_max_name = "b_constant_max"
- b_dequantize_name = "b_dequantize"
- b_quantize_name = "b_quantize"
- mat_mul_name = "mat_mul"
- graph_def = graph_pb2.GraphDef()
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
- graph_def.node.extend([a_constant])
- a_constant_min = quantize_graph.create_constant_node(
- a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([a_constant_min])
- a_constant_max = quantize_graph.create_constant_node(
- a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([a_constant_max])
- a_dequantize_node = quantize_graph.create_node(
- "Dequantize", a_dequantize_name,
- [a_constant_name, a_constant_min_name, a_constant_max_name])
- quantize_graph.set_attr_dtype(a_dequantize_node, "T", dtypes.uint8)
- graph_def.node.extend([a_dequantize_node])
- a_quantize_node = quantize_graph.create_node(
- "QuantizeV2", a_quantize_name,
- [a_dequantize_name, a_dequantize_name + ":1", a_dequantize_name + ":2"])
- quantize_graph.set_attr_dtype(a_quantize_node, "T", dtypes.uint8)
- graph_def.node.extend([a_quantize_node])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
- graph_def.node.extend([b_constant])
- b_constant_min = quantize_graph.create_constant_node(
- b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([b_constant_min])
- b_constant_max = quantize_graph.create_constant_node(
- b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
- graph_def.node.extend([b_constant_max])
- b_dequantize_node = quantize_graph.create_node(
- "Dequantize", b_dequantize_name,
- [b_constant_name, b_constant_min_name, b_constant_max_name])
- quantize_graph.set_attr_dtype(b_dequantize_node, "T", dtypes.uint8)
- graph_def.node.extend([b_dequantize_node])
- b_quantize_node = quantize_graph.create_node(
- "QuantizeV2", b_quantize_name,
- [b_dequantize_name, b_dequantize_name + ":1", b_dequantize_name + ":2"])
- quantize_graph.set_attr_dtype(b_quantize_node, "T", dtypes.uint8)
- graph_def.node.extend([b_quantize_node])
- mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
- a_quantize_name, b_quantize_name, a_quantize_name + ":1",
- a_quantize_name + ":2", b_quantize_name + ":1", b_quantize_name + ":2"
- ])
- quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
- quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
- graph_def.node.extend([mat_mul_node])
-
- expected_output = graph_pb2.GraphDef()
- a_constant = quantize_graph.create_constant_node(
- a_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
- expected_output.node.extend([a_constant])
- a_constant_min = quantize_graph.create_constant_node(
- a_constant_min_name, value=2, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([a_constant_min])
- a_constant_max = quantize_graph.create_constant_node(
- a_constant_max_name, value=2, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([a_constant_max])
- b_constant = quantize_graph.create_constant_node(
- b_constant_name, value=(0,), dtype=dtypes.quint8, shape=[])
- expected_output.node.extend([b_constant])
- b_constant_min = quantize_graph.create_constant_node(
- b_constant_min_name, value=3, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([b_constant_min])
- b_constant_max = quantize_graph.create_constant_node(
- b_constant_max_name, value=3, dtype=dtypes.float32, shape=[])
- expected_output.node.extend([b_constant_max])
- mat_mul_node = quantize_graph.create_node("QuantizedMatMul", mat_mul_name, [
- a_constant_name, b_constant_name, a_constant_min_name,
- a_constant_max_name, b_constant_min_name, b_constant_max_name
- ])
- quantize_graph.set_attr_dtype(mat_mul_node, "T1", dtypes.uint8)
- quantize_graph.set_attr_dtype(mat_mul_node, "T2", dtypes.int32)
- expected_output.node.extend([mat_mul_node])
- expected_output.versions.CopyFrom(graph_def.versions)
- expected_output.library.CopyFrom(graph_def.library)
-
- rewriter = quantize_graph.GraphRewriter(
- graph_def, [mat_mul_name], quantized_input_range=None)
- output = rewriter.remove_redundant_quantization(graph_def)
- stripped_output = graph_util.extract_sub_graph(output, [mat_mul_name])
- self.assertProtoEquals(expected_output, stripped_output)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index b850c5a17f..9b4b698874 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -1,6 +1,7 @@
# TensorFlow external dependencies that can be loaded in WORKSPACE files.
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
+load("//third_party/gpus:rocm_configure.bzl", "rocm_configure")
load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
load("//third_party:nccl/nccl_configure.bzl", "nccl_configure")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
@@ -20,9 +21,11 @@ load(
"def_file_filter_configure",
)
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
+load("//third_party/icu:workspace.bzl", icu = "repo")
def initialize_third_party():
flatbuffers()
+ icu()
# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
@@ -43,6 +46,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
sycl_configure(name = "local_config_sycl")
syslibs_configure(name = "local_config_syslibs")
python_configure(name = "local_config_python")
+ rocm_configure(name = "local_config_rocm")
initialize_third_party()
@@ -53,39 +57,39 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
# Point //external/local_config_arm_compiler to //external/arm_compiler
arm_compiler_configure(
name = "local_config_arm_compiler",
- remote_config_repo = "../arm_compiler",
build_file = clean_dep("//third_party/toolchains/cpus/arm:BUILD"),
+ remote_config_repo = "../arm_compiler",
)
mkl_repository(
name = "mkl_linux",
+ build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
+ sha256 = "e2233534a9d15c387e22260997af4312a39e9f86f791768409be273b5453c4e6",
+ strip_prefix = "mklml_lnx_2019.0.20180710",
urls = [
"https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_lnx_2019.0.20180710.tgz",
"https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_lnx_2019.0.20180710.tgz",
],
- sha256 = "e2233534a9d15c387e22260997af4312a39e9f86f791768409be273b5453c4e6",
- strip_prefix = "mklml_lnx_2019.0.20180710",
- build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
mkl_repository(
name = "mkl_windows",
+ build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
+ sha256 = "3fdcff17b018a0082491adf3ba143358265336a801646e46e0191ec8d58d24a2",
+ strip_prefix = "mklml_win_2019.0.20180710",
urls = [
"https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_win_2019.0.20180710.zip",
"https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_win_2019.0.20180710.zip",
],
- sha256 = "3fdcff17b018a0082491adf3ba143358265336a801646e46e0191ec8d58d24a2",
- strip_prefix = "mklml_win_2019.0.20180710",
- build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
mkl_repository(
name = "mkl_darwin",
+ build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
+ sha256 = "411a30014a938eb83fb9f37b3dbe8e371b106fc1dd621fc23123cadc72737ce6",
+ strip_prefix = "mklml_mac_2019.0.20180710",
urls = [
"https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_mac_2019.0.20180710.tgz",
"https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_mac_2019.0.20180710.tgz",
],
- sha256 = "411a30014a938eb83fb9f37b3dbe8e371b106fc1dd621fc23123cadc72737ce6",
- strip_prefix = "mklml_mac_2019.0.20180710",
- build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
)
if path_prefix:
@@ -94,39 +98,40 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "mkl_dnn",
+ build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
+ sha256 = "363cc9239eacf8e7917753c6d8c94f767e4cd049160d0654a61ef32d5e1b3049",
+ strip_prefix = "mkl-dnn-4e333787e0d66a1dca1218e99a891d493dbc8ef1",
urls = [
"https://mirror.bazel.build/github.com/intel/mkl-dnn/archive/4e333787e0d66a1dca1218e99a891d493dbc8ef1.tar.gz",
"https://github.com/intel/mkl-dnn/archive/4e333787e0d66a1dca1218e99a891d493dbc8ef1.tar.gz",
],
- sha256 = "363cc9239eacf8e7917753c6d8c94f767e4cd049160d0654a61ef32d5e1b3049",
- strip_prefix = "mkl-dnn-4e333787e0d66a1dca1218e99a891d493dbc8ef1",
- build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"),
)
tf_http_archive(
name = "com_google_absl",
+ build_file = clean_dep("//third_party:com_google_absl.BUILD"),
+ sha256 = "7dd09690ae7ca4551de3111d4a86b75b23ec17445f273d3c42bdcdc1c7b02e4e",
+ strip_prefix = "abseil-cpp-48cd2c3f351ff188bc85684b84a91b6e6d17d896",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/e01d95528ea2137a4a27a88d1f57c6cb260aafed.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/e01d95528ea2137a4a27a88d1f57c6cb260aafed.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/48cd2c3f351ff188bc85684b84a91b6e6d17d896.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/48cd2c3f351ff188bc85684b84a91b6e6d17d896.tar.gz",
],
- sha256 = "84043ed402d2a2a6ba4cdddb7e85118b1158fd81fe4ac3a14adc343d054c1e2e",
- strip_prefix = "abseil-cpp-e01d95528ea2137a4a27a88d1f57c6cb260aafed",
- build_file = clean_dep("//third_party:com_google_absl.BUILD"),
)
tf_http_archive(
name = "eigen_archive",
+ build_file = clean_dep("//third_party:eigen.BUILD"),
+ sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9",
+ strip_prefix = "eigen-eigen-fd6845384b86",
urls = [
"https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
"https://bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
],
- sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9",
- strip_prefix = "eigen-eigen-fd6845384b86",
- build_file = clean_dep("//third_party:eigen.BUILD"),
)
tf_http_archive(
name = "arm_compiler",
+ build_file = clean_dep("//:arm_compiler.BUILD"),
sha256 = "970285762565c7890c6c087d262b0a18286e7d0384f13a37786d8521773bc969",
strip_prefix = "tools-0e906ebc527eab1cdbf7adabff5b474da9562e9f/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf",
urls = [
@@ -135,216 +140,211 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
# remove the whitelist entry in third_party/repo.bzl.
# "https://github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz",
],
- build_file = clean_dep("//:arm_compiler.BUILD"),
)
tf_http_archive(
name = "libxsmm_archive",
+ build_file = clean_dep("//third_party:libxsmm.BUILD"),
+ sha256 = "cd8532021352b4a0290d209f7f9bfd7c2411e08286a893af3577a43457287bfa",
+ strip_prefix = "libxsmm-1.9",
urls = [
"https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.9.tar.gz",
"https://github.com/hfp/libxsmm/archive/1.9.tar.gz",
],
- sha256 = "cd8532021352b4a0290d209f7f9bfd7c2411e08286a893af3577a43457287bfa",
- strip_prefix = "libxsmm-1.9",
- build_file = clean_dep("//third_party:libxsmm.BUILD"),
)
tf_http_archive(
name = "ortools_archive",
+ build_file = clean_dep("//third_party:ortools.BUILD"),
+ sha256 = "d025a95f78b5fc5eaa4da5f395f23d11c23cf7dbd5069f1f627f002de87b86b9",
+ strip_prefix = "or-tools-6.7.2/src",
urls = [
"https://mirror.bazel.build/github.com/google/or-tools/archive/v6.7.2.tar.gz",
"https://github.com/google/or-tools/archive/v6.7.2.tar.gz",
],
- sha256 = "d025a95f78b5fc5eaa4da5f395f23d11c23cf7dbd5069f1f627f002de87b86b9",
- strip_prefix = "or-tools-6.7.2/src",
- build_file = clean_dep("//third_party:ortools.BUILD"),
)
tf_http_archive(
name = "com_googlesource_code_re2",
+ sha256 = "803c7811146edeef8f91064de37c6f19136ff01a2a8cdb3230e940b2fd9f07fe",
+ strip_prefix = "re2-2018-07-01",
+ system_build_file = clean_dep("//third_party/systemlibs:re2.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/re2/archive/2018-07-01.tar.gz",
"https://github.com/google/re2/archive/2018-07-01.tar.gz",
],
- sha256 = "803c7811146edeef8f91064de37c6f19136ff01a2a8cdb3230e940b2fd9f07fe",
- strip_prefix = "re2-2018-07-01",
- system_build_file = clean_dep("//third_party/systemlibs:re2.BUILD"),
)
tf_http_archive(
name = "com_github_googlecloudplatform_google_cloud_cpp",
- urls = [
- "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz",
- "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz",
- ],
sha256 = "fdd3b3aecce60987e5525e55bf3a21d68a8695320bd5b980775af6507eec3944",
strip_prefix = "google-cloud-cpp-14760a86c4ffab9943b476305c4fe927ad95db1c",
system_build_file = clean_dep("//third_party/systemlibs:google_cloud_cpp.BUILD"),
system_link_files = {
"//third_party/systemlibs:google_cloud_cpp.google.cloud.bigtable.BUILD": "google/cloud/bigtable/BUILD",
},
+ urls = [
+ "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz",
+ "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz",
+ ],
)
tf_http_archive(
name = "com_github_googleapis_googleapis",
+ build_file = clean_dep("//third_party:googleapis.BUILD"),
+ sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378",
+ strip_prefix = "googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb",
+ system_build_file = clean_dep("//third_party/systemlibs:googleapis.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
"https://github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
],
- sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378",
- strip_prefix = "googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb",
- build_file = clean_dep("//third_party:googleapis.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:googleapis.BUILD"),
)
tf_http_archive(
name = "gemmlowp",
+ sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658",
+ strip_prefix = "gemmlowp-38ebac7b059e84692f53e5938f97a9943c120d98",
urls = [
"https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
"https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
],
- sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658",
- strip_prefix = "gemmlowp-38ebac7b059e84692f53e5938f97a9943c120d98",
)
tf_http_archive(
name = "farmhash_archive",
+ build_file = clean_dep("//third_party:farmhash.BUILD"),
+ sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0",
+ strip_prefix = "farmhash-816a4ae622e964763ca0862d9dbd19324a1eaf45",
urls = [
"https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz",
"https://github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz",
],
- sha256 = "6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0",
- strip_prefix = "farmhash-816a4ae622e964763ca0862d9dbd19324a1eaf45",
- build_file = clean_dep("//third_party:farmhash.BUILD"),
)
tf_http_archive(
name = "highwayhash",
+ build_file = clean_dep("//third_party:highwayhash.BUILD"),
+ sha256 = "9c3e0e87d581feeb0c18d814d98f170ff23e62967a2bd6855847f0b2fe598a37",
+ strip_prefix = "highwayhash-fd3d9af80465e4383162e4a7c5e2f406e82dd968",
urls = [
"http://mirror.bazel.build/github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz",
"https://github.com/google/highwayhash/archive/fd3d9af80465e4383162e4a7c5e2f406e82dd968.tar.gz",
],
- sha256 = "9c3e0e87d581feeb0c18d814d98f170ff23e62967a2bd6855847f0b2fe598a37",
- strip_prefix = "highwayhash-fd3d9af80465e4383162e4a7c5e2f406e82dd968",
- build_file = clean_dep("//third_party:highwayhash.BUILD"),
)
tf_http_archive(
name = "nasm",
+ build_file = clean_dep("//third_party:nasm.BUILD"),
+ sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
+ strip_prefix = "nasm-2.13.03",
+ system_build_file = clean_dep("//third_party/systemlibs:nasm.BUILD"),
urls = [
"https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
"http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.13.03.tar.bz2/sha512/d7a6b4cee8dfd603d8d4c976e5287b5cc542fa0b466ff989b743276a6e28114e64289bf02a7819eca63142a5278aa6eed57773007e5f589e15768e6456a8919d/nasm-2.13.03.tar.bz2",
"http://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
],
- sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
- strip_prefix = "nasm-2.13.03",
- build_file = clean_dep("//third_party:nasm.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:nasm.BUILD"),
)
tf_http_archive(
name = "jpeg",
+ build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"),
+ sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b",
+ strip_prefix = "libjpeg-turbo-2.0.0",
+ system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
"https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
],
- sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b",
- strip_prefix = "libjpeg-turbo-2.0.0",
- build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"),
)
tf_http_archive(
name = "png_archive",
+ build_file = clean_dep("//third_party:png.BUILD"),
+ patch_file = clean_dep("//third_party:png_fix_rpi.patch"),
+ sha256 = "e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef",
+ strip_prefix = "libpng-1.6.34",
+ system_build_file = clean_dep("//third_party/systemlibs:png.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz",
"https://github.com/glennrp/libpng/archive/v1.6.34.tar.gz",
],
- sha256 = "e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef",
- strip_prefix = "libpng-1.6.34",
- build_file = clean_dep("//third_party:png.BUILD"),
- patch_file = clean_dep("//third_party:png_fix_rpi.patch"),
- system_build_file = clean_dep("//third_party/systemlibs:png.BUILD"),
)
tf_http_archive(
name = "org_sqlite",
+ build_file = clean_dep("//third_party:sqlite.BUILD"),
+ sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6",
+ strip_prefix = "sqlite-amalgamation-3240000",
+ system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"),
urls = [
"https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
"https://www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
],
- sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6",
- strip_prefix = "sqlite-amalgamation-3240000",
- build_file = clean_dep("//third_party:sqlite.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:sqlite.BUILD"),
)
tf_http_archive(
name = "gif_archive",
+ build_file = clean_dep("//third_party:gif.BUILD"),
+ sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1",
+ strip_prefix = "giflib-5.1.4",
+ system_build_file = clean_dep("//third_party/systemlibs:gif.BUILD"),
urls = [
"https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz",
"http://pilotfiber.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz",
],
- sha256 = "34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1",
- strip_prefix = "giflib-5.1.4",
- build_file = clean_dep("//third_party:gif.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:gif.BUILD"),
)
tf_http_archive(
name = "six_archive",
+ build_file = clean_dep("//third_party:six.BUILD"),
+ sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a",
+ strip_prefix = "six-1.10.0",
+ system_build_file = clean_dep("//third_party/systemlibs:six.BUILD"),
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
"https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz",
],
- sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a",
- strip_prefix = "six-1.10.0",
- build_file = clean_dep("//third_party:six.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:six.BUILD"),
)
tf_http_archive(
name = "astor_archive",
+ build_file = clean_dep("//third_party:astor.BUILD"),
+ sha256 = "ff6d2e2962d834acb125cc4dcc80c54a8c17c253f4cc9d9c43b5102a560bb75d",
+ strip_prefix = "astor-0.6.2",
+ system_build_file = clean_dep("//third_party/systemlibs:astor.BUILD"),
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz",
"https://pypi.python.org/packages/d8/be/c4276b3199ec3feee2a88bc64810fbea8f26d961e0a4cd9c68387a9f35de/astor-0.6.2.tar.gz",
],
- sha256 = "ff6d2e2962d834acb125cc4dcc80c54a8c17c253f4cc9d9c43b5102a560bb75d",
- strip_prefix = "astor-0.6.2",
- build_file = clean_dep("//third_party:astor.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:astor.BUILD"),
)
tf_http_archive(
name = "gast_archive",
+ build_file = clean_dep("//third_party:gast.BUILD"),
+ sha256 = "7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930",
+ strip_prefix = "gast-0.2.0",
+ system_build_file = clean_dep("//third_party/systemlibs:gast.BUILD"),
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz",
"https://pypi.python.org/packages/5c/78/ff794fcae2ce8aa6323e789d1f8b3b7765f601e7702726f430e814822b96/gast-0.2.0.tar.gz",
],
- sha256 = "7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930",
- strip_prefix = "gast-0.2.0",
- build_file = clean_dep("//third_party:gast.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:gast.BUILD"),
)
tf_http_archive(
name = "termcolor_archive",
+ build_file = clean_dep("//third_party:termcolor.BUILD"),
+ sha256 = "1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b",
+ strip_prefix = "termcolor-1.1.0",
+ system_build_file = clean_dep("//third_party/systemlibs:termcolor.BUILD"),
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz",
"https://pypi.python.org/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz",
],
- sha256 = "1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b",
- strip_prefix = "termcolor-1.1.0",
- build_file = clean_dep("//third_party:termcolor.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:termcolor.BUILD"),
)
tf_http_archive(
name = "absl_py",
- urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
- "https://github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
- ],
sha256 = "95160f778a62c7a60ddeadc7bf2d83f85a23a27359814aca12cf949e896fa82c",
strip_prefix = "abseil-py-pypi-v0.2.2",
system_build_file = clean_dep("//third_party/systemlibs:absl_py.BUILD"),
@@ -352,17 +352,21 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
"//third_party/systemlibs:absl_py.absl.flags.BUILD": "absl/flags/BUILD",
"//third_party/systemlibs:absl_py.absl.testing.BUILD": "absl/testing/BUILD",
},
+ urls = [
+ "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
+ "https://github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
+ ],
)
tf_http_archive(
name = "org_python_pypi_backports_weakref",
+ build_file = clean_dep("//third_party:backports_weakref.BUILD"),
+ sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892",
+ strip_prefix = "backports.weakref-1.0rc1/src",
urls = [
"https://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz",
"https://pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz",
],
- sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892",
- strip_prefix = "backports.weakref-1.0rc1/src",
- build_file = clean_dep("//third_party:backports_weakref.BUILD"),
)
filegroup_external(
@@ -385,9 +389,9 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "protobuf_archive",
- urls = PROTOBUF_URLS,
sha256 = PROTOBUF_SHA256,
strip_prefix = PROTOBUF_STRIP_PREFIX,
+ urls = PROTOBUF_URLS,
)
# We need to import the protobuf library under the names com_google_protobuf
@@ -395,222 +399,222 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
# Unfortunately there is no way to alias http_archives at the moment.
tf_http_archive(
name = "com_google_protobuf",
- urls = PROTOBUF_URLS,
sha256 = PROTOBUF_SHA256,
strip_prefix = PROTOBUF_STRIP_PREFIX,
+ urls = PROTOBUF_URLS,
)
tf_http_archive(
name = "com_google_protobuf_cc",
- urls = PROTOBUF_URLS,
sha256 = PROTOBUF_SHA256,
strip_prefix = PROTOBUF_STRIP_PREFIX,
+ urls = PROTOBUF_URLS,
)
tf_http_archive(
name = "nsync",
+ sha256 = "692f9b30e219f71a6371b98edd39cef3cbda35ac3abc4cd99ce19db430a5591a",
+ strip_prefix = "nsync-1.20.1",
+ system_build_file = clean_dep("//third_party/systemlibs:nsync.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/nsync/archive/1.20.1.tar.gz",
"https://github.com/google/nsync/archive/1.20.1.tar.gz",
],
- sha256 = "692f9b30e219f71a6371b98edd39cef3cbda35ac3abc4cd99ce19db430a5591a",
- strip_prefix = "nsync-1.20.1",
- system_build_file = clean_dep("//third_party/systemlibs:nsync.BUILD"),
)
tf_http_archive(
name = "com_google_googletest",
+ sha256 = "353ab86e35cea1cd386115279cf4b16695bbf21b897bfbf2721cf4cb5f64ade8",
+ strip_prefix = "googletest-997d343dd680e541ef96ce71ee54a91daf2577a0",
urls = [
"https://mirror.bazel.build/github.com/google/googletest/archive/997d343dd680e541ef96ce71ee54a91daf2577a0.zip",
"https://github.com/google/googletest/archive/997d343dd680e541ef96ce71ee54a91daf2577a0.zip",
],
- sha256 = "353ab86e35cea1cd386115279cf4b16695bbf21b897bfbf2721cf4cb5f64ade8",
- strip_prefix = "googletest-997d343dd680e541ef96ce71ee54a91daf2577a0",
)
tf_http_archive(
name = "com_github_gflags_gflags",
+ sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e",
+ strip_prefix = "gflags-2.2.1",
urls = [
"https://mirror.bazel.build/github.com/gflags/gflags/archive/v2.2.1.tar.gz",
"https://github.com/gflags/gflags/archive/v2.2.1.tar.gz",
],
- sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e",
- strip_prefix = "gflags-2.2.1",
)
tf_http_archive(
name = "pcre",
+ build_file = clean_dep("//third_party:pcre.BUILD"),
sha256 = "69acbc2fbdefb955d42a4c606dfde800c2885711d2979e356c0636efde9ec3b5",
+ strip_prefix = "pcre-8.42",
+ system_build_file = clean_dep("//third_party/systemlibs:pcre.BUILD"),
urls = [
"https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.42.tar.gz",
"http://ftp.exim.org/pub/pcre/pcre-8.42.tar.gz",
],
- strip_prefix = "pcre-8.42",
- build_file = clean_dep("//third_party:pcre.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:pcre.BUILD"),
)
tf_http_archive(
name = "swig",
+ build_file = clean_dep("//third_party:swig.BUILD"),
sha256 = "58a475dbbd4a4d7075e5fe86d4e54c9edde39847cdb96a3053d87cb64a23a453",
+ strip_prefix = "swig-3.0.8",
+ system_build_file = clean_dep("//third_party/systemlibs:swig.BUILD"),
urls = [
"https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
"http://ufpr.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
"http://pilotfiber.dl.sourceforge.net/project/swig/swig/swig-3.0.8/swig-3.0.8.tar.gz",
],
- strip_prefix = "swig-3.0.8",
- build_file = clean_dep("//third_party:swig.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:swig.BUILD"),
)
tf_http_archive(
name = "curl",
+ build_file = clean_dep("//third_party:curl.BUILD"),
sha256 = "e9c37986337743f37fd14fe8737f246e97aec94b39d1b71e8a5973f72a9fc4f5",
+ strip_prefix = "curl-7.60.0",
+ system_build_file = clean_dep("//third_party/systemlibs:curl.BUILD"),
urls = [
"https://mirror.bazel.build/curl.haxx.se/download/curl-7.60.0.tar.gz",
"https://curl.haxx.se/download/curl-7.60.0.tar.gz",
],
- strip_prefix = "curl-7.60.0",
- build_file = clean_dep("//third_party:curl.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:curl.BUILD"),
)
tf_http_archive(
name = "grpc",
+ sha256 = "50db9cf2221354485eb7c3bd55a4c27190caef7048a2a1a15fbe60a498f98b44",
+ strip_prefix = "grpc-1.13.0",
+ system_build_file = clean_dep("//third_party/systemlibs:grpc.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/grpc/grpc/archive/v1.13.0.tar.gz",
"https://github.com/grpc/grpc/archive/v1.13.0.tar.gz",
],
- sha256 = "50db9cf2221354485eb7c3bd55a4c27190caef7048a2a1a15fbe60a498f98b44",
- strip_prefix = "grpc-1.13.0",
- system_build_file = clean_dep("//third_party/systemlibs:grpc.BUILD"),
)
tf_http_archive(
name = "linenoise",
+ build_file = clean_dep("//third_party:linenoise.BUILD"),
sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7",
+ strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3",
urls = [
"https://mirror.bazel.build/github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz",
"https://github.com/antirez/linenoise/archive/c894b9e59f02203dbe4e2be657572cf88c4230c3.tar.gz",
],
- strip_prefix = "linenoise-c894b9e59f02203dbe4e2be657572cf88c4230c3",
- build_file = clean_dep("//third_party:linenoise.BUILD"),
)
# TODO(phawkins): currently, this rule uses an unofficial LLVM mirror.
# Switch to an official source of snapshots if/when possible.
tf_http_archive(
name = "llvm",
+ build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
+ sha256 = "a4f8bfe7e3e69069934a87e612a1d4d3b8b6af13e0f1213a42a6046e1bcd50d8",
+ strip_prefix = "llvm-d3429e96fe1e45b1dc0106463832523f37faf271",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/db98902adc6431c9cc4ddec50fe174cfc9e626d6.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/db98902adc6431c9cc4ddec50fe174cfc9e626d6.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/d3429e96fe1e45b1dc0106463832523f37faf271.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/d3429e96fe1e45b1dc0106463832523f37faf271.tar.gz",
],
- sha256 = "8c02d312b3d417cf9bc7e58ff53c2528bf77a5d839ce4a23b95bd04b9e5da023",
- strip_prefix = "llvm-db98902adc6431c9cc4ddec50fe174cfc9e626d6",
- build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
tf_http_archive(
name = "lmdb",
+ build_file = clean_dep("//third_party:lmdb.BUILD"),
+ sha256 = "f3927859882eb608868c8c31586bb7eb84562a40a6bf5cc3e13b6b564641ea28",
+ strip_prefix = "lmdb-LMDB_0.9.22/libraries/liblmdb",
+ system_build_file = clean_dep("//third_party/systemlibs:lmdb.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz",
"https://github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz",
],
- sha256 = "f3927859882eb608868c8c31586bb7eb84562a40a6bf5cc3e13b6b564641ea28",
- strip_prefix = "lmdb-LMDB_0.9.22/libraries/liblmdb",
- build_file = clean_dep("//third_party:lmdb.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:lmdb.BUILD"),
)
tf_http_archive(
name = "jsoncpp_git",
+ build_file = clean_dep("//third_party:jsoncpp.BUILD"),
+ sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6",
+ strip_prefix = "jsoncpp-1.8.4",
+ system_build_file = clean_dep("//third_party/systemlibs:jsoncpp.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz",
"https://github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz",
],
- sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6",
- strip_prefix = "jsoncpp-1.8.4",
- build_file = clean_dep("//third_party:jsoncpp.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:jsoncpp.BUILD"),
)
tf_http_archive(
name = "boringssl",
+ sha256 = "1188e29000013ed6517168600fc35a010d58c5d321846d6a6dfee74e4c788b45",
+ strip_prefix = "boringssl-7f634429a04abc48e2eb041c81c5235816c96514",
+ system_build_file = clean_dep("//third_party/systemlibs:boringssl.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/boringssl/archive/7f634429a04abc48e2eb041c81c5235816c96514.tar.gz",
"https://github.com/google/boringssl/archive/7f634429a04abc48e2eb041c81c5235816c96514.tar.gz",
],
- sha256 = "1188e29000013ed6517168600fc35a010d58c5d321846d6a6dfee74e4c788b45",
- strip_prefix = "boringssl-7f634429a04abc48e2eb041c81c5235816c96514",
- system_build_file = clean_dep("//third_party/systemlibs:boringssl.BUILD"),
)
tf_http_archive(
name = "zlib_archive",
+ build_file = clean_dep("//third_party:zlib.BUILD"),
+ sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
+ strip_prefix = "zlib-1.2.11",
+ system_build_file = clean_dep("//third_party/systemlibs:zlib.BUILD"),
urls = [
"https://mirror.bazel.build/zlib.net/zlib-1.2.11.tar.gz",
"https://zlib.net/zlib-1.2.11.tar.gz",
],
- sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1",
- strip_prefix = "zlib-1.2.11",
- build_file = clean_dep("//third_party:zlib.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:zlib.BUILD"),
)
tf_http_archive(
name = "fft2d",
+ build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"),
+ sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296",
urls = [
"https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
"http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz",
],
- sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296",
- build_file = clean_dep("//third_party/fft2d:fft2d.BUILD"),
)
tf_http_archive(
name = "snappy",
+ build_file = clean_dep("//third_party:snappy.BUILD"),
+ sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4",
+ strip_prefix = "snappy-1.1.7",
+ system_build_file = clean_dep("//third_party/systemlibs:snappy.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/snappy/archive/1.1.7.tar.gz",
"https://github.com/google/snappy/archive/1.1.7.tar.gz",
],
- sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4",
- strip_prefix = "snappy-1.1.7",
- build_file = clean_dep("//third_party:snappy.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:snappy.BUILD"),
)
tf_http_archive(
name = "nccl_archive",
+ build_file = clean_dep("//third_party:nccl/nccl_archive.BUILD"),
+ sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176",
+ strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7",
urls = [
"https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
"https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
],
- sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176",
- strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7",
- build_file = clean_dep("//third_party:nccl/nccl_archive.BUILD"),
)
tf_http_archive(
name = "kafka",
+ build_file = clean_dep("//third_party:kafka/BUILD"),
+ patch_file = clean_dep("//third_party/kafka:config.patch"),
+ sha256 = "cc6ebbcd0a826eec1b8ce1f625ffe71b53ef3290f8192b6cae38412a958f4fd3",
+ strip_prefix = "librdkafka-0.11.5",
urls = [
"https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.5.tar.gz",
"https://github.com/edenhill/librdkafka/archive/v0.11.5.tar.gz",
],
- sha256 = "cc6ebbcd0a826eec1b8ce1f625ffe71b53ef3290f8192b6cae38412a958f4fd3",
- strip_prefix = "librdkafka-0.11.5",
- build_file = clean_dep("//third_party:kafka/BUILD"),
- patch_file = clean_dep("//third_party/kafka:config.patch"),
)
tf_http_archive(
name = "aws",
+ build_file = clean_dep("//third_party:aws.BUILD"),
+ sha256 = "b888d8ce5fc10254c3dd6c9020c7764dd53cf39cf011249d0b4deda895de1b7c",
+ strip_prefix = "aws-sdk-cpp-1.3.15",
urls = [
"https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz",
"https://github.com/aws/aws-sdk-cpp/archive/1.3.15.tar.gz",
],
- sha256 = "b888d8ce5fc10254c3dd6c9020c7764dd53cf39cf011249d0b4deda895de1b7c",
- strip_prefix = "aws-sdk-cpp-1.3.15",
- build_file = clean_dep("//third_party:aws.BUILD"),
)
java_import_external(
@@ -640,14 +644,14 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "jemalloc",
+ build_file = clean_dep("//third_party:jemalloc.BUILD"),
+ sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8",
+ strip_prefix = "jemalloc-4.4.0",
+ system_build_file = clean_dep("//third_party/systemlibs:jemalloc.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz",
"https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz",
],
- sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8",
- strip_prefix = "jemalloc-4.4.0",
- build_file = clean_dep("//third_party:jemalloc.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:jemalloc.BUILD"),
)
java_import_external(
@@ -696,196 +700,196 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "com_google_pprof",
+ build_file = clean_dep("//third_party:pprof.BUILD"),
+ sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4",
+ strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650",
urls = [
"https://mirror.bazel.build/github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz",
"https://github.com/google/pprof/archive/c0fb62ec88c411cc91194465e54db2632845b650.tar.gz",
],
- sha256 = "e0928ca4aa10ea1e0551e2d7ce4d1d7ea2d84b2abbdef082b0da84268791d0c4",
- strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650",
- build_file = clean_dep("//third_party:pprof.BUILD"),
)
tf_http_archive(
name = "cub_archive",
+ build_file = clean_dep("//third_party:cub.BUILD"),
+ sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
+ strip_prefix = "cub-1.8.0",
urls = [
"https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip",
"https://github.com/NVlabs/cub/archive/1.8.0.zip",
],
- sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
- strip_prefix = "cub-1.8.0",
- build_file = clean_dep("//third_party:cub.BUILD"),
)
tf_http_archive(
name = "cython",
+ build_file = clean_dep("//third_party:cython.BUILD"),
+ delete = ["BUILD.bazel"],
sha256 = "bccc9aa050ea02595b2440188813b936eaf345e85fb9692790cecfe095cf91aa",
+ strip_prefix = "cython-0.28.4",
+ system_build_file = clean_dep("//third_party/systemlibs:cython.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/cython/cython/archive/0.28.4.tar.gz",
"https://github.com/cython/cython/archive/0.28.4.tar.gz",
],
- strip_prefix = "cython-0.28.4",
- build_file = clean_dep("//third_party:cython.BUILD"),
- delete = ["BUILD.bazel"],
- system_build_file = clean_dep("//third_party/systemlibs:cython.BUILD"),
)
tf_http_archive(
name = "bazel_toolchains",
+ sha256 = "3b604699685c5c65dd3f6f17425570a4b2f00ddba2f750db15acc72e55bb098b",
+ strip_prefix = "bazel-toolchains-37acf1841ab1475c98a152cb9e446460c8ae29e1",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
"https://github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
],
- strip_prefix = "bazel-toolchains-37acf1841ab1475c98a152cb9e446460c8ae29e1",
- sha256 = "3b604699685c5c65dd3f6f17425570a4b2f00ddba2f750db15acc72e55bb098b",
)
tf_http_archive(
name = "arm_neon_2_x86_sse",
+ build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"),
sha256 = "c8d90aa4357f8079d427e87a6f4c493da1fa4140aee926c05902d7ec1533d9a5",
strip_prefix = "ARM_NEON_2_x86_SSE-0f77d9d182265259b135dad949230ecbf1a2633d",
urls = [
"https://mirror.bazel.build/github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz",
"https://github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz",
],
- build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"),
)
tf_http_archive(
name = "double_conversion",
+ build_file = clean_dep("//third_party:double_conversion.BUILD"),
+ sha256 = "2f7fbffac0d98d201ad0586f686034371a6d152ca67508ab611adc2386ad30de",
+ strip_prefix = "double-conversion-3992066a95b823efc8ccc1baf82a1cfc73f6e9b8",
+ system_build_file = clean_dep("//third_party/systemlibs:double_conversion.BUILD"),
urls = [
"https://mirror.bazel.build/github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
"https://github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
],
- sha256 = "2f7fbffac0d98d201ad0586f686034371a6d152ca67508ab611adc2386ad30de",
- strip_prefix = "double-conversion-3992066a95b823efc8ccc1baf82a1cfc73f6e9b8",
- build_file = clean_dep("//third_party:double_conversion.BUILD"),
- system_build_file = clean_dep("//third_party/systemlibs:double_conversion.BUILD"),
)
tf_http_archive(
name = "tflite_mobilenet",
+ build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"),
sha256 = "23f814d1c076bdf03715dfb6cab3713aa4fbdf040fd5448c43196bd2e97a4c1b",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip",
],
- build_file = clean_dep("//third_party:tflite_mobilenet.BUILD"),
)
tf_http_archive(
name = "tflite_mobilenet_ssd",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
sha256 = "767057f2837a46d97882734b03428e8dd640b93236052b312b2f0e45613c1cf0",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip",
],
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
tf_http_archive(
name = "tflite_mobilenet_ssd_quant",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
sha256 = "a809cd290b4d6a2e8a9d5dad076e0bd695b8091974e0eed1052b480b2f21b6dc",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
],
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
tf_http_archive(
name = "tflite_mobilenet_ssd_quant_protobuf",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
sha256 = "09280972c5777f1aa775ef67cb4ac5d5ed21970acd8535aeca62450ef14f0d79",
+ strip_prefix = "ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
"http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
],
- strip_prefix = "ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18",
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
tf_http_archive(
name = "tflite_conv_actions_frozen",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
sha256 = "d947b38cba389b5e2d0bfc3ea6cc49c784e187b41a071387b3742d1acac7691e",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/conv_actions_tflite.zip",
],
- build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
tf_http_archive(
name = "tflite_smartreply",
+ build_file = clean_dep("//third_party:tflite_smartreply.BUILD"),
sha256 = "8980151b85a87a9c1a3bb1ed4748119e4a85abd3cb5744d83da4d4bd0fbeef7c",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip",
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip",
],
- build_file = clean_dep("//third_party:tflite_smartreply.BUILD"),
)
tf_http_archive(
name = "tflite_ovic_testdata",
+ build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"),
sha256 = "a9a705d8d519220178e2e65d383fdb21da37fdb31d1e909b0a1acdac46479e9c",
+ strip_prefix = "ovic",
urls = [
"https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
"https://storage.googleapis.com/download.tensorflow.org/data/ovic.zip",
],
- build_file = clean_dep("//third_party:tflite_ovic_testdata.BUILD"),
- strip_prefix = "ovic",
)
tf_http_archive(
name = "build_bazel_rules_android",
sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806",
+ strip_prefix = "rules_android-0.1.1",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/rules_android/archive/v0.1.1.zip",
"https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip",
],
- strip_prefix = "rules_android-0.1.1",
)
tf_http_archive(
name = "tbb",
+ build_file = clean_dep("//third_party/ngraph:tbb.BUILD"),
+ sha256 = "724686f90bcda78f13b76f297d964008737ccd6399328143c1c0093e73ae6a13",
+ strip_prefix = "tbb-tbb_2018",
urls = [
"https://mirror.bazel.build/github.com/01org/tbb/archive/tbb_2018.zip",
"https://github.com/01org/tbb/archive/tbb_2018.zip",
],
- sha256 = "724686f90bcda78f13b76f297d964008737ccd6399328143c1c0093e73ae6a13",
- strip_prefix = "tbb-tbb_2018",
- build_file = clean_dep("//third_party/ngraph:tbb.BUILD"),
)
tf_http_archive(
name = "ngraph",
+ build_file = clean_dep("//third_party/ngraph:ngraph.BUILD"),
+ sha256 = "bf9dcc88e5c66021e3aac80491a231711211540d613bf9b6bd28db3f5bb86b62",
+ strip_prefix = "ngraph-0.8.1",
urls = [
- "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.7.0.tar.gz",
- "https://github.com/NervanaSystems/ngraph/archive/v0.7.0.tar.gz",
+ "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.8.1.tar.gz",
+ "https://github.com/NervanaSystems/ngraph/archive/v0.8.1.tar.gz",
],
- sha256 = "34434b6d5993ac5233538c84f498840db7ac91df82e225c379ee7c8f6de644a5",
- strip_prefix = "ngraph-0.7.0",
- build_file = clean_dep("//third_party/ngraph:ngraph.BUILD"),
)
tf_http_archive(
name = "nlohmann_json_lib",
+ build_file = clean_dep("//third_party/ngraph:nlohmann_json.BUILD"),
+ sha256 = "9f3549824af3ca7e9707a2503959886362801fb4926b869789d6929098a79e47",
+ strip_prefix = "json-3.1.1",
urls = [
"https://mirror.bazel.build/github.com/nlohmann/json/archive/v3.1.1.tar.gz",
"https://github.com/nlohmann/json/archive/v3.1.1.tar.gz",
],
- sha256 = "9f3549824af3ca7e9707a2503959886362801fb4926b869789d6929098a79e47",
- strip_prefix = "json-3.1.1",
- build_file = clean_dep("//third_party/ngraph:nlohmann_json.BUILD"),
)
tf_http_archive(
name = "ngraph_tf",
+ build_file = clean_dep("//third_party/ngraph:ngraph_tf.BUILD"),
+ sha256 = "402f84c748c113780a60f35f39aab118435285543aee4900d712b76fbf8a21ee",
+ strip_prefix = "ngraph-tf-0.6.1",
urls = [
- "https://mirror.bazel.build/github.com/NervanaSystems/ngraph-tf/archive/v0.5.0.tar.gz",
- "https://github.com/NervanaSystems/ngraph-tf/archive/v0.5.0.tar.gz",
+ "https://mirror.bazel.build/github.com/NervanaSystems/ngraph-tf/archive/v0.6.1.tar.gz",
+ "https://github.com/NervanaSystems/ngraph-tf/archive/v0.6.1.tar.gz",
],
- sha256 = "23b4566d8e40d6f1f236b0ffe3905dd964ae42ca54bacff67f24abcefd443afb",
- strip_prefix = "ngraph-tf-0.5.0",
- build_file = clean_dep("//third_party/ngraph:ngraph_tf.BUILD"),
)
##############################################################################
diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel
index 934c0d9650..d0be482fda 100644
--- a/third_party/flatbuffers/BUILD.bazel
+++ b/third_party/flatbuffers/BUILD.bazel
@@ -108,11 +108,14 @@ cc_binary(
"grpc/src/compiler/schema_interface.h",
"src/flatc_main.cpp",
"src/idl_gen_cpp.cpp",
+ "src/idl_gen_dart.cpp",
"src/idl_gen_general.cpp",
"src/idl_gen_go.cpp",
"src/idl_gen_grpc.cpp",
"src/idl_gen_js.cpp",
"src/idl_gen_json_schema.cpp",
+ "src/idl_gen_lobster.cpp",
+ "src/idl_gen_lua.cpp",
"src/idl_gen_php.cpp",
"src/idl_gen_python.cpp",
"src/idl_gen_text.cpp",
diff --git a/third_party/flatbuffers/workspace.bzl b/third_party/flatbuffers/workspace.bzl
index 3aeef96a72..7613767fc4 100644
--- a/third_party/flatbuffers/workspace.bzl
+++ b/third_party/flatbuffers/workspace.bzl
@@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
def repo():
third_party_http_archive(
name = "flatbuffers",
- strip_prefix = "flatbuffers-1.9.0",
- sha256 = "5ca5491e4260cacae30f1a5786d109230db3f3a6e5a0eb45d0d0608293d247e3",
+ strip_prefix = "flatbuffers-1f5eae5d6a135ff6811724f6c57f911d1f46bb15",
+ sha256 = "b2bb0311ca40b12ebe36671bdda350b10c7728caf0cfe2d432ea3b6e409016f3",
urls = [
- "https://mirror.bazel.build/github.com/google/flatbuffers/archive/v1.9.0.tar.gz",
- "https://github.com/google/flatbuffers/archive/v1.9.0.tar.gz",
+ "https://mirror.bazel.build/github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz",
+ "https://github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz",
],
build_file = "//third_party/flatbuffers:BUILD.bazel",
system_build_file = "//third_party/flatbuffers:BUILD.system",
diff --git a/third_party/gpus/crosstool/BUILD.tpl b/third_party/gpus/crosstool/BUILD.tpl
index f638756d23..c8812fab33 100644
--- a/third_party/gpus/crosstool/BUILD.tpl
+++ b/third_party/gpus/crosstool/BUILD.tpl
@@ -2,6 +2,20 @@ licenses(["restricted"])
package(default_visibility = ["//visibility:public"])
+toolchain(
+ name = "toolchain-linux-x86_64",
+ exec_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ target_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ toolchain = ":cc-compiler-local",
+ toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
+)
+
cc_toolchain_suite(
name = "toolchain",
toolchains = {
diff --git a/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl b/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
new file mode 100644
index 0000000000..0e175b3ef6
--- /dev/null
+++ b/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
@@ -0,0 +1,158 @@
+major_version: "local"
+minor_version: ""
+default_target_cpu: "same_as_host"
+
+default_toolchain {
+ cpu: "k8"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "piii"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "arm"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "ppc"
+ toolchain_identifier: "local_linux"
+}
+
+toolchain {
+ abi_version: "local"
+ abi_libc_version: "local"
+ builtin_sysroot: ""
+ compiler: "compiler"
+ host_system_name: "local"
+ needsPic: true
+ supports_gold_linker: false
+ supports_incremental_linker: false
+ supports_fission: false
+ supports_interface_shared_objects: false
+ supports_normalizing_ar: false
+ supports_start_end_lib: false
+ supports_thin_archives: false
+ target_libc: "local"
+ target_cpu: "local"
+ target_system_name: "local"
+ toolchain_identifier: "local_linux"
+
+ tool_path { name: "ar" path: "/usr/bin/ar" }
+ tool_path { name: "compat-ld" path: "/usr/bin/ld" }
+ tool_path { name: "cpp" path: "/usr/bin/cpp" }
+ tool_path { name: "dwp" path: "/usr/bin/dwp" }
+ # As part of the TensorFlow release, we place some ROCm-related compilation
+ # files in @local_config_rocm//crosstool/clang/bin, and this relative
+ # path, combined with the rest of our Bazel configuration causes our
+ # compilation to use those files.
+ tool_path { name: "gcc" path: "clang/bin/crosstool_wrapper_driver_rocm" }
+ # Use "-std=c++11" for hipcc. For consistency, force both the host compiler
+ # and the device compiler to use "-std=c++11".
+ cxx_flag: "-std=c++11"
+ linker_flag: "-Wl,-no-as-needed"
+ linker_flag: "-lstdc++"
+ #linker_flag: "-B/usr/bin/"
+ linker_flag: "-B/opt/rocm/hcc/compiler/bin"
+
+%{host_compiler_includes}
+ tool_path { name: "gcov" path: "/usr/bin/gcov" }
+
+ # C(++) compiles invoke the compiler (as that is the one knowing where
+ # to find libraries), but we provide LD so other rules can invoke the linker.
+ tool_path { name: "ld" path: "/usr/bin/ld" }
+
+ tool_path { name: "nm" path: "/usr/bin/nm" }
+ tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
+ objcopy_embed_flag: "-I"
+ objcopy_embed_flag: "binary"
+ tool_path { name: "objdump" path: "/usr/bin/objdump" }
+ tool_path { name: "strip" path: "/usr/bin/strip" }
+
+ # Anticipated future default.
+ unfiltered_cxx_flag: "-no-canonical-prefixes"
+
+ # Make C++ compilation deterministic. Use linkstamping instead of these
+ # compiler symbols.
+ unfiltered_cxx_flag: "-Wno-builtin-macro-redefined"
+ unfiltered_cxx_flag: "-D__DATE__=\"redacted\""
+ unfiltered_cxx_flag: "-D__TIMESTAMP__=\"redacted\""
+ unfiltered_cxx_flag: "-D__TIME__=\"redacted\""
+ unfiltered_cxx_flag: "-D__HIP_PLATFORM_HCC__"
+ # The macro EIGEN_USE_HIP is used to tell Eigen to use the HIP platform headers
+ # It needs to be always set when compiling Eigen headers
+ # (irrespective of whether the source file is being compiled via HIPCC)
+ # so adding -DEIGEN_USE_HIP as a default CXX flag here
+ unfiltered_cxx_flag: "-DEIGEN_USE_HIP"
+
+
+ # Security hardening on by default.
+ # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases.
+ # We need to undef it before redefining it as some distributions now have
+ # it enabled by default.
+ #compiler_flag: "-U_FORTIFY_SOURCE"
+ #compiler_flag: "-D_FORTIFY_SOURCE=1"
+ #compiler_flag: "-fstack-protector"
+ #compiler_flag: "-fPIE"
+ #linker_flag: "-pie"
+ #linker_flag: "-Wl,-z,relro,-z,now"
+
+ # Enable coloring even if there's no attached terminal. Bazel removes the
+ # escape sequences if --nocolor is specified. This isn't supported by gcc
+ # on Ubuntu 14.04.
+ # compiler_flag: "-fcolor-diagnostics"
+
+ # All warnings are enabled. Maybe enable -Werror as well?
+ compiler_flag: "-Wall"
+ # Enable a few more warnings that aren't part of -Wall.
+ compiler_flag: "-Wunused-but-set-parameter"
+ # But disable some that are problematic.
+ compiler_flag: "-Wno-free-nonheap-object" # has false positives
+
+ # Keep stack frames for debugging, even in opt mode.
+ compiler_flag: "-fno-omit-frame-pointer"
+
+ # Anticipated future default.
+ linker_flag: "-no-canonical-prefixes"
+ unfiltered_cxx_flag: "-fno-canonical-system-headers"
+ # Have gcc return the exit code from ld.
+ linker_flag: "-pass-exit-codes"
+ # Stamp the binary with a unique identifier.
+ linker_flag: "-Wl,--build-id=md5"
+ linker_flag: "-Wl,--hash-style=gnu"
+ # Gold linker only? Can we enable this by default?
+ # linker_flag: "-Wl,--warn-execstack"
+ # linker_flag: "-Wl,--detect-odr-violations"
+
+ # Include directory for ROCm headers.
+%{rocm_include_path}
+
+ compilation_mode_flags {
+ mode: DBG
+ # Enable debug symbols.
+ compiler_flag: "-g"
+ }
+ compilation_mode_flags {
+ mode: OPT
+
+ # No debug symbols.
+ # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt or
+ # even generally? However, that can't happen here, as it requires special
+ # handling in Bazel.
+ compiler_flag: "-g0"
+
+ # Conservative choice for -O
+ # -O3 can increase binary size and even slow down the resulting binaries.
+ # Profile first and / or use FDO if you need better performance than this.
+ compiler_flag: "-O2"
+
+ # Disable assertions
+ compiler_flag: "-DNDEBUG"
+
+ # Removal of unused code and data at link time (can this increase binary size in some cases?).
+ compiler_flag: "-ffunction-sections"
+ compiler_flag: "-fdata-sections"
+ linker_flag: "-Wl,--gc-sections"
+ }
+ linking_mode_flags { mode: DYNAMIC }
+}
diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
new file mode 100755
index 0000000000..824238022b
--- /dev/null
+++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
@@ -0,0 +1,241 @@
+#!/usr/bin/env python
+"""Crosstool wrapper for compiling ROCm programs.
+
+SYNOPSIS:
+ crosstool_wrapper_driver_rocm [options passed in by cc_library()
+ or cc_binary() rule]
+
+DESCRIPTION:
+ This script is expected to be called by the cc_library() or cc_binary() bazel
+ rules. When the option "-x rocm" is present in the list of arguments passed
+ to this script, it invokes the hipcc compiler. Most arguments are passed
+ as is as a string to --compiler-options of hipcc. When "-x rocm" is not
+ present, this wrapper invokes gcc with the input arguments as is.
+"""
+
+from __future__ import print_function
+
+__author__ = 'whchung@gmail.com (Wen-Heng (Jack) Chung)'
+
+from argparse import ArgumentParser
+import os
+import subprocess
+import re
+import sys
+import pipes
+
+# Template values set by rocm_configure.bzl.
+CPU_COMPILER = ('%{cpu_compiler}')
+GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}')
+
+HIPCC_PATH = '%{hipcc_path}'
+PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
+
+def Log(s):
+ print('gpus/crosstool: {0}'.format(s))
+
+
+def GetOptionValue(argv, option):
+ """Extract the list of values for option from the argv list.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ option: The option whose value to extract, without the leading '-'.
+
+ Returns:
+ A list of values, either directly following the option,
+ (eg., -opt val1 val2) or values collected from multiple occurrences of
+ the option (eg., -opt val1 -opt val2).
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-' + option, nargs='*', action='append')
+ args, _ = parser.parse_known_args(argv)
+ if not args or not vars(args)[option]:
+ return []
+ else:
+ return sum(vars(args)[option], [])
+
+
+def GetHostCompilerOptions(argv):
+ """Collect the -isystem, -iquote, and --sysroot option values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ The string that can be used as the --compiler-options to hipcc.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-isystem', nargs='*', action='append')
+ parser.add_argument('-iquote', nargs='*', action='append')
+ parser.add_argument('--sysroot', nargs=1)
+ parser.add_argument('-g', nargs='*', action='append')
+ parser.add_argument('-fno-canonical-system-headers', action='store_true')
+
+ args, _ = parser.parse_known_args(argv)
+
+ opts = ''
+
+ if args.isystem:
+ opts += ' -isystem ' + ' -isystem '.join(sum(args.isystem, []))
+ if args.iquote:
+ opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, []))
+ if args.g:
+ opts += ' -g' + ' -g'.join(sum(args.g, []))
+ #if args.fno_canonical_system_headers:
+ # opts += ' -fno-canonical-system-headers'
+ if args.sysroot:
+ opts += ' --sysroot ' + args.sysroot[0]
+
+ return opts
+
+def GetHipccOptions(argv):
+ """Collect the -hipcc_options values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ The string that can be passed directly to hipcc.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-hipcc_options', nargs='*', action='append')
+
+ args, _ = parser.parse_known_args(argv)
+
+ if args.hipcc_options:
+ options = _update_options(sum(args.hipcc_options, []))
+ return ' '.join(['--'+a for a in options])
+ return ''
+
+
+def InvokeHipcc(argv, log=False):
+ """Call hipcc with arguments assembled from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ log: True if logging is requested.
+
+ Returns:
+ The return value of calling os.system('hipcc ' + args)
+ """
+
+ host_compiler_options = GetHostCompilerOptions(argv)
+ hipcc_compiler_options = GetHipccOptions(argv)
+ opt_option = GetOptionValue(argv, 'O')
+ m_options = GetOptionValue(argv, 'm')
+ m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
+ include_options = GetOptionValue(argv, 'I')
+ out_file = GetOptionValue(argv, 'o')
+ depfiles = GetOptionValue(argv, 'MF')
+ defines = GetOptionValue(argv, 'D')
+ defines = ''.join([' -D' + define for define in defines])
+ undefines = GetOptionValue(argv, 'U')
+ undefines = ''.join([' -U' + define for define in undefines])
+ std_options = GetOptionValue(argv, 'std')
+ hipcc_allowed_std_options = ["c++11"]
+ std_options = ''.join([' -std=' + define
+ for define in std_options if define in hipcc_allowed_std_options])
+
+ # The list of source files get passed after the -c option. I don't know of
+ # any other reliable way to just get the list of source files to be compiled.
+ src_files = GetOptionValue(argv, 'c')
+
+ if len(src_files) == 0:
+ return 1
+ if len(out_file) != 1:
+ return 1
+
+ opt = (' -O2' if (len(opt_option) > 0 and int(opt_option[0]) > 0)
+ else ' -g')
+
+ includes = (' -I ' + ' -I '.join(include_options)
+ if len(include_options) > 0
+ else '')
+
+ # Unfortunately, there are other options that have -c prefix too.
+ # So allowing only those look like C/C++ files.
+ src_files = [f for f in src_files if
+ re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
+ srcs = ' '.join(src_files)
+ out = ' -o ' + out_file[0]
+
+ hipccopts = ' '
+ hipccopts += ' ' + hipcc_compiler_options
+ hipccopts += undefines
+ hipccopts += defines
+ hipccopts += std_options
+ hipccopts += m_options
+
+ if depfiles:
+ # Generate the dependency file
+ depfile = depfiles[0]
+ cmd = (HIPCC_PATH + ' ' + hipccopts +
+ host_compiler_options +
+ ' ' + GCC_HOST_COMPILER_PATH +
+ ' -I .' + includes + ' ' + srcs + ' -M -o ' + depfile)
+ if log: Log(cmd)
+ exit_status = os.system(cmd)
+ if exit_status != 0:
+ return exit_status
+
+ cmd = (HIPCC_PATH + ' ' + hipccopts +
+ host_compiler_options + ' -fPIC' +
+ ' ' + GCC_HOST_COMPILER_PATH +
+ ' -I .' + opt + includes + ' -c ' + srcs + out)
+
+ # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'.
+ # Need to investigate and fix.
+ cmd = 'PATH=' + PREFIX_DIR + ':$PATH ' + cmd
+ if log: Log(cmd)
+ return os.system(cmd)
+
+
+def main():
+ # ignore PWD env var
+ os.environ['PWD']=''
+
+ parser = ArgumentParser()
+ parser.add_argument('-x', nargs=1)
+ parser.add_argument('--rocm_log', action='store_true')
+ parser.add_argument('-pass-exit-codes', action='store_true')
+ args, leftover = parser.parse_known_args(sys.argv[1:])
+
+ if args.x and args.x[0] == 'rocm':
+ if args.rocm_log: Log('-x rocm')
+ leftover = [pipes.quote(s) for s in leftover]
+ if args.rocm_log: Log('using hipcc')
+ return InvokeHipcc(leftover, log=args.rocm_log)
+
+ # XXX use hipcc to link
+ if args.pass_exit_codes:
+ gpu_compiler_flags = [flag for flag in sys.argv[1:]
+ if not flag.startswith(('-pass-exit-codes'))]
+
+ # special handling for $ORIGIN
+ # - guard every argument with ''
+ modified_gpu_compiler_flags = []
+ for flag in gpu_compiler_flags:
+ modified_gpu_compiler_flags.append("'" + flag + "'")
+
+ if args.rocm_log: Log('Link with hipcc: %s' % (' '.join([HIPCC_PATH] + modified_gpu_compiler_flags)))
+ return subprocess.call([HIPCC_PATH] + modified_gpu_compiler_flags)
+
+ # Strip our flags before passing through to the CPU compiler for files which
+ # are not -x rocm. We can't just pass 'leftover' because it also strips -x.
+ # We not only want to pass -x to the CPU compiler, but also keep it in its
+ # relative location in the argv list (the compiler is actually sensitive to
+ # this).
+ cpu_compiler_flags = [flag for flag in sys.argv[1:]
+ if not flag.startswith(('--rocm_log'))]
+
+ # XXX: SE codes need to be built with gcc, but need this macro defined
+ cpu_compiler_flags.append("-D__HIP_PLATFORM_HCC__")
+
+ return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 5648b1525a..69f4599c16 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -48,6 +48,7 @@ _DEFAULT_CUDA_COMPUTE_CAPABILITIES = ["3.5", "5.2"]
CUDA_LIB_PATHS = [
"lib64/",
"lib64/stubs/",
+ "lib/powerpc64le-linux-gnu/",
"lib/x86_64-linux-gnu/",
"lib/x64/",
"lib/",
@@ -70,6 +71,7 @@ CUPTI_HEADER_PATHS = [
# the other CUDA libraries but rather in a special extras/CUPTI directory.
CUPTI_LIB_PATHS = [
"extras/CUPTI/lib64/",
+ "lib/powerpc64le-linux-gnu/",
"lib/x86_64-linux-gnu/",
"lib64/",
"extras/CUPTI/libx64/",
@@ -1105,8 +1107,8 @@ def symlink_genrule_for_dir(
# $(@D) will include the full path to the file.
dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i]
- # On Windows, symlink is not supported, so we just copy all the files.
- cmd = "cp -f" if _is_windows(repository_ctx) else "ln -s"
+ # Copy the headers to create a sandboxable setup.
+ cmd = "cp -f"
command.append(cmd + ' "%s" "%s"' % (src_files[i], dest))
outs.append(' "' + dest_dir + dest_files[i] + '",')
genrule = _genrule(
@@ -1332,27 +1334,14 @@ def _create_local_cuda_repository(repository_ctx):
cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
cuda_defines["%{host_compiler_warnings}"] = ""
- # TODO(klimek): We currently need to inject "/" as builtin directory path
- # to disable bazel's dependency checks.
- # The problem is that:
- # - the python rules symlink the python headers into the bazel root
- # - the rules use 'includes' in the BUILD file to redirect includes of the
- # python headers through those paths
- # - bazel currently uses -isystem for include paths specified via 'includes'
- # - gcc follows symlinks when resolving files via -isystem paths, and puts
- # the resolved paths into the .d file, which makes the dependency check
- # fail for bazel
- # There are multiple possible ways to solve this:
- # 1. make bazel not use -isystem for paths specified via 'includes'
- # 2. cp the headers instead of symlinking them
- #
- # Once this is fixed, the right builtin directory path is:
- # (host_compiler_includes +
- # "\n cxx_builtin_include_directory: \"%s\"" % cuda_include_path)
- # The cuda directory needs to be passed, as there is currently no rule
- # providing the cuda headers in the same way the python headers are
- # provided.
- cuda_defines["%{host_compiler_includes}"] = "\n cxx_builtin_include_directory: \"/\""
+ # nvcc has the system include paths built in and will automatically
+ # search them; we cannot work around that, so we add the relevant cuda
+ # system paths to the allowed compiler specific include paths.
+ cuda_defines["%{host_compiler_includes}"] = (
+ host_compiler_includes + "\n" +
+ _cuda_include_path(repository_ctx, cuda_config) +
+ "\n cxx_builtin_include_directory: \"%s\"" % cupti_header_dir +
+ "\n cxx_builtin_include_directory: \"%s\"" % cudnn_header_dir)
nvcc_path = str(repository_ctx.path("%s/bin/nvcc%s" %
(
cuda_config.cuda_toolkit_path,
diff --git a/third_party/gpus/rocm/BUILD b/third_party/gpus/rocm/BUILD
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/third_party/gpus/rocm/BUILD
diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl
new file mode 100644
index 0000000000..8258bb3589
--- /dev/null
+++ b/third_party/gpus/rocm/BUILD.tpl
@@ -0,0 +1,99 @@
+licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like
+
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+ name = "using_hipcc",
+ values = {
+ "define": "using_rocm_hipcc=true",
+ },
+)
+
+cc_library(
+ name = "rocm_headers",
+ hdrs = [
+ "rocm/rocm_config.h",
+ %{rocm_headers}
+ ],
+ includes = [
+ ".",
+ "rocm/include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "hip",
+ srcs = ["rocm/lib/%{hip_lib}"],
+ data = ["rocm/lib/%{hip_lib}"],
+ includes = [
+ ".",
+ "rocm/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "rocblas",
+ srcs = ["rocm/lib/%{rocblas_lib}"],
+ data = ["rocm/lib/%{rocblas_lib}"],
+ includes = [
+ ".",
+ "rocm/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "rocfft",
+ srcs = ["rocm/lib/%{rocfft_lib}"],
+ data = ["rocm/lib/%{rocfft_lib}"],
+ includes = [
+ ".",
+ "rocm/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "hiprand",
+ srcs = ["rocm/lib/%{hiprand_lib}"],
+ data = ["rocm/lib/%{hiprand_lib}"],
+ includes = [
+ ".",
+ "rocm/include",
+ "rocm/include/rocrand",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "miopen",
+ srcs = ["rocm/lib/%{miopen_lib}"],
+ data = ["rocm/lib/%{miopen_lib}"],
+ includes = [
+ ".",
+ "rocm/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "rocm",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":rocm_headers",
+ ":hip",
+ ":rocblas",
+ ":rocfft",
+ ":hiprand",
+ ":miopen",
+ ],
+)
+
+%{rocm_include_genrules}
diff --git a/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/gpus/rocm/build_defs.bzl.tpl
new file mode 100644
index 0000000000..08c59f95a0
--- /dev/null
+++ b/third_party/gpus/rocm/build_defs.bzl.tpl
@@ -0,0 +1,45 @@
+# Macros for building ROCm code.
+def if_rocm(if_true, if_false = []):
+ """Shorthand for select()'ing on whether we're building with ROCm.
+
+ Returns a select statement which evaluates to if_true if we're building
+ with ROCm enabled. Otherwise, the select statement evaluates to if_false.
+
+ """
+ return select({
+ "@local_config_rocm//rocm:using_hipcc": if_true,
+ "//conditions:default": if_false
+ })
+
+
+def rocm_default_copts():
+ """Default options for all ROCm compilations."""
+ return if_rocm(["-x", "rocm"] + %{rocm_extra_copts})
+
+def rocm_copts(opts = []):
+ """Gets the appropriate set of copts for (maybe) ROCm compilation.
+
+ If we're doing ROCm compilation, returns copts for our particular ROCm
+ compiler. If we're not doing ROCm compilation, returns an empty list.
+
+ """
+ return rocm_default_copts() + select({
+ "//conditions:default": [],
+ "@local_config_rocm//rocm:using_hipcc": ([
+ "",
+ ]),
+ }) + if_rocm_is_configured(opts)
+
+def rocm_is_configured():
+ """Returns true if ROCm was enabled during the configure process."""
+ return %{rocm_is_configured}
+
+def if_rocm_is_configured(x):
+ """Tests if the ROCm was enabled during the configure process.
+
+ Unlike if_rocm(), this does not require that we are building with
+ --config=rocm. Used to allow non-ROCm code to depend on ROCm libraries.
+ """
+ if rocm_is_configured():
+ return x
+ return []
diff --git a/third_party/gpus/rocm/rocm_config.h.tpl b/third_party/gpus/rocm/rocm_config.h.tpl
new file mode 100644
index 0000000000..c5f25a845c
--- /dev/null
+++ b/third_party/gpus/rocm/rocm_config.h.tpl
@@ -0,0 +1,21 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef ROCM_ROCM_CONFIG_H_
+#define ROCM_ROCM_CONFIG_H_
+
+#define TF_ROCM_TOOLKIT_PATH "/opt/rocm"
+
+#endif // ROCM_ROCM_CONFIG_H_
diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl
new file mode 100644
index 0000000000..9108639b0b
--- /dev/null
+++ b/third_party/gpus/rocm_configure.bzl
@@ -0,0 +1,784 @@
+# -*- Python -*-
+"""Repository rule for ROCm autoconfiguration.
+
+`rocm_configure` depends on the following environment variables:
+
+ * `TF_NEED_ROCM`: Whether to enable building with ROCm.
+ * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path
+ * `ROCM_TOOLKIT_PATH`: The path to the ROCm toolkit. Default is
+ `/opt/rocm`.
+ * `TF_ROCM_VERSION`: The version of the ROCm toolkit. If this is blank, then
+ use the system default.
+ * `TF_MIOPEN_VERSION`: The version of the MIOpen library.
+ * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. Default is
+ `gfx803,gfx900`.
+"""
+
+_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
+_ROCM_TOOLKIT_PATH = "ROCM_TOOLKIT_PATH"
+_TF_ROCM_VERSION = "TF_ROCM_VERSION"
+_TF_MIOPEN_VERSION = "TF_MIOPEN_VERSION"
+_TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS"
+_TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO"
+
+_DEFAULT_ROCM_VERSION = ""
+_DEFAULT_MIOPEN_VERSION = ""
+_DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm"
+_DEFAULT_ROCM_AMDGPU_TARGETS = ["gfx803", "gfx900"]
+
+def find_cc(repository_ctx):
+ """Find the C++ compiler."""
+
+ # Return a dummy value for GCC detection here to avoid error
+ target_cc_name = "gcc"
+ cc_path_envvar = _GCC_HOST_COMPILER_PATH
+ cc_name = target_cc_name
+
+ if cc_path_envvar in repository_ctx.os.environ:
+ cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
+ if cc_name_from_env:
+ cc_name = cc_name_from_env
+ if cc_name.startswith("/"):
+ # Absolute path, maybe we should make this supported by our which function.
+ return cc_name
+ cc = repository_ctx.which(cc_name)
+ if cc == None:
+ fail(("Cannot find {}, either correct your path or set the {}" +
+ " environment variable").format(target_cc_name, cc_path_envvar))
+ return cc
+
+_INC_DIR_MARKER_BEGIN = "#include <...>"
+
+def _cxx_inc_convert(path):
+ """Convert path returned by cc -E xc++ in a complete path."""
+ path = path.strip()
+ return path
+
+def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
+ """Compute the list of default C or C++ include directories."""
+ if lang_is_cpp:
+ lang = "c++"
+ else:
+ lang = "c"
+
+ # TODO: We pass -no-canonical-prefixes here to match the compiler flags,
+ # but in rocm_clang CROSSTOOL file that is a `feature` and we should
+ # handle the case when it's disabled and no flag is passed
+ result = repository_ctx.execute([
+ cc,
+ "-no-canonical-prefixes",
+ "-E",
+ "-x" + lang,
+ "-",
+ "-v",
+ ])
+ index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
+ if index1 == -1:
+ return []
+ index1 = result.stderr.find("\n", index1)
+ if index1 == -1:
+ return []
+ index2 = result.stderr.rfind("\n ")
+ if index2 == -1 or index2 < index1:
+ return []
+ index2 = result.stderr.find("\n", index2 + 1)
+ if index2 == -1:
+ inc_dirs = result.stderr[index1 + 1:]
+ else:
+ inc_dirs = result.stderr[index1 + 1:index2].strip()
+
+ return [
+ str(repository_ctx.path(_cxx_inc_convert(p)))
+ for p in inc_dirs.split("\n")
+ ]
+
+def get_cxx_inc_directories(repository_ctx, cc):
+ """Compute the list of default C and C++ include directories."""
+
+ # For some reason `clang -xc` sometimes returns include paths that are
+ # different from the ones from `clang -xc++`. (Symlink and a dir)
+ # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
+ includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
+ includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)
+
+ includes_cpp_set = depset(includes_cpp)
+ return includes_cpp + [
+ inc
+ for inc in includes_c
+ if inc not in includes_cpp_set
+ ]
+
+def auto_configure_fail(msg):
+ """Output failure message when rocm configuration fails."""
+ red = "\033[0;31m"
+ no_color = "\033[0m"
+ fail("\n%sROCm Configuration Error:%s %s\n" % (red, no_color, msg))
+
+# END cc_configure common functions (see TODO above).
+
+def _host_compiler_includes(repository_ctx, cc):
+ """Generates the cxx_builtin_include_directory entries for gcc inc dirs.
+
+ Args:
+ repository_ctx: The repository context.
+ cc: The path to the gcc host compiler.
+
+ Returns:
+ A string containing the cxx_builtin_include_directory for each of the gcc
+ host compiler include directories, which can be added to the CROSSTOOL
+ file.
+ """
+ inc_dirs = get_cxx_inc_directories(repository_ctx, cc)
+
+ # Add numpy headers
+ inc_dirs.append("/usr/lib/python2.7/dist-packages/numpy/core/include")
+
+ entries = []
+ for inc_dir in inc_dirs:
+ entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir)
+
+ # define TENSORFLOW_USE_ROCM
+ entries.append(" unfiltered_cxx_flag: \"-DTENSORFLOW_USE_ROCM\"")
+
+ return "\n".join(entries)
+
+def _rocm_include_path(repository_ctx, rocm_config):
+ """Generates the cxx_builtin_include_directory entries for rocm inc dirs.
+
+ Args:
+ repository_ctx: The repository context.
+ cc: The path to the gcc host compiler.
+
+ Returns:
+ A string containing the cxx_builtin_include_directory for each of the gcc
+ host compiler include directories, which can be added to the CROSSTOOL
+ file.
+ """
+ inc_dirs = []
+
+ # general ROCm include path
+ inc_dirs.append(rocm_config.rocm_toolkit_path + "/include")
+
+ # Add HSA headers
+ inc_dirs.append("/opt/rocm/hsa/include")
+
+ # Add HIP headers
+ inc_dirs.append("/opt/rocm/include/hip")
+ inc_dirs.append("/opt/rocm/include/hip/hcc_detail")
+
+ # Add rocrand and hiprand headers
+ inc_dirs.append("/opt/rocm/rocrand/include")
+ inc_dirs.append("/opt/rocm/hiprand/include")
+
+ # Add rocfft headers
+ inc_dirs.append("/opt/rocm/rocfft/include")
+
+ # Add rocBLAS headers
+ inc_dirs.append("/opt/rocm/rocblas/include")
+
+ # Add MIOpen headers
+ inc_dirs.append("/opt/rocm/miopen/include")
+
+ # Add hcc headers
+ inc_dirs.append("/opt/rocm/hcc/include")
+ inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/7.0.0/include/")
+ inc_dirs.append("/opt/rocm/hcc/lib/clang/7.0.0/include")
+
+ # Newer hcc builds use/are based off of clang 8.0.0.
+ inc_dirs.append("/opt/rocm/hcc/compiler/lib/clang/8.0.0/include/")
+ inc_dirs.append("/opt/rocm/hcc/lib/clang/8.0.0/include")
+
+ inc_entries = []
+ for inc_dir in inc_dirs:
+ inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir)
+ return "\n".join(inc_entries)
+
+def _enable_rocm(repository_ctx):
+ if "TF_NEED_ROCM" in repository_ctx.os.environ:
+ enable_rocm = repository_ctx.os.environ["TF_NEED_ROCM"].strip()
+ return enable_rocm == "1"
+ return False
+
+def _rocm_toolkit_path(repository_ctx):
+ """Finds the rocm toolkit directory.
+
+ Args:
+ repository_ctx: The repository context.
+
+ Returns:
+ A speculative real path of the rocm toolkit install directory.
+ """
+ rocm_toolkit_path = _DEFAULT_ROCM_TOOLKIT_PATH
+ if _ROCM_TOOLKIT_PATH in repository_ctx.os.environ:
+ rocm_toolkit_path = repository_ctx.os.environ[_ROCM_TOOLKIT_PATH].strip()
+ if not repository_ctx.path(rocm_toolkit_path).exists:
+ auto_configure_fail("Cannot find rocm toolkit path.")
+ return str(repository_ctx.path(rocm_toolkit_path).realpath)
+
+def _amdgpu_targets(repository_ctx):
+ """Returns a list of strings representing AMDGPU targets."""
+ if _TF_ROCM_AMDGPU_TARGETS not in repository_ctx.os.environ:
+ return _DEFAULT_ROCM_AMDGPU_TARGETS
+ amdgpu_targets_str = repository_ctx.os.environ[_TF_ROCM_AMDGPU_TARGETS]
+ amdgpu_targets = amdgpu_targets_str.split(",")
+ for amdgpu_target in amdgpu_targets:
+ if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit():
+ auto_configure_fail("Invalid AMDGPU target: %s" % amdgpu_target)
+ return amdgpu_targets
+
+def _cpu_value(repository_ctx):
+ """Returns the name of the host operating system.
+
+ Args:
+ repository_ctx: The repository context.
+
+ Returns:
+ A string containing the name of the host operating system.
+ """
+ os_name = repository_ctx.os.name.lower()
+ if os_name.startswith("mac os"):
+ return "Darwin"
+ if os_name.find("windows") != -1:
+ return "Windows"
+ result = repository_ctx.execute(["uname", "-s"])
+ return result.stdout.strip()
+
+def _lib_name(lib, cpu_value, version = "", static = False):
+ """Constructs the platform-specific name of a library.
+
+ Args:
+ lib: The name of the library, such as "hip"
+ cpu_value: The name of the host operating system.
+ version: The version of the library.
+ static: True the library is static or False if it is a shared object.
+
+ Returns:
+ The platform-specific name of the library.
+ """
+ if cpu_value in ("Linux"):
+ if static:
+ return "lib%s.a" % lib
+ else:
+ if version:
+ version = ".%s" % version
+ return "lib%s.so%s" % (lib, version)
+ elif cpu_value == "Windows":
+ return "%s.lib" % lib
+ elif cpu_value == "Darwin":
+ if static:
+ return "lib%s.a" % lib
+ elif version:
+ version = ".%s" % version
+ return "lib%s%s.dylib" % (lib, version)
+ else:
+ auto_configure_fail("Invalid cpu_value: %s" % cpu_value)
+
+def _find_rocm_lib(
+ lib,
+ repository_ctx,
+ cpu_value,
+ basedir,
+ version = "",
+ static = False):
+ """Finds the given ROCm libraries on the system.
+
+ Args:
+ lib: The name of the library, such as "hip"
+ repository_ctx: The repository context.
+ cpu_value: The name of the host operating system.
+ basedir: The install directory of ROCm.
+ version: The version of the library.
+ static: True if static library, False if shared object.
+
+ Returns:
+ Returns a struct with the following fields:
+ file_name: The basename of the library found on the system.
+ path: The full path to the library.
+ """
+ file_name = _lib_name(lib, cpu_value, version, static)
+ if cpu_value == "Linux":
+ path = repository_ctx.path("%s/lib64/%s" % (basedir, file_name))
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+ path = repository_ctx.path("%s/lib64/stubs/%s" % (basedir, file_name))
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+ path = repository_ctx.path(
+ "%s/lib/x86_64-linux-gnu/%s" % (basedir, file_name),
+ )
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+
+ path = repository_ctx.path("%s/lib/%s" % (basedir, file_name))
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+ path = repository_ctx.path("%s/%s" % (basedir, file_name))
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+
+ auto_configure_fail("Cannot find rocm library %s" % file_name)
+
+def _find_libs(repository_ctx, rocm_config):
+ """Returns the ROCm libraries on the system.
+
+ Args:
+ repository_ctx: The repository context.
+ rocm_config: The ROCm config as returned by _get_rocm_config
+
+ Returns:
+ Map of library names to structs of filename and path as returned by
+ _find_rocm_lib.
+ """
+ cpu_value = rocm_config.cpu_value
+ return {
+ "hip": _find_rocm_lib(
+ "hip_hcc",
+ repository_ctx,
+ cpu_value,
+ rocm_config.rocm_toolkit_path,
+ ),
+ "rocblas": _find_rocm_lib(
+ "rocblas",
+ repository_ctx,
+ cpu_value,
+ rocm_config.rocm_toolkit_path + "/rocblas",
+ ),
+ "rocfft": _find_rocm_lib(
+ "rocfft",
+ repository_ctx,
+ cpu_value,
+ rocm_config.rocm_toolkit_path + "/rocfft",
+ ),
+ "hiprand": _find_rocm_lib(
+ "hiprand",
+ repository_ctx,
+ cpu_value,
+ rocm_config.rocm_toolkit_path + "/hiprand",
+ ),
+ "miopen": _find_rocm_lib(
+ "MIOpen",
+ repository_ctx,
+ cpu_value,
+ rocm_config.rocm_toolkit_path + "/miopen",
+ ),
+ }
+
+def _get_rocm_config(repository_ctx):
+ """Detects and returns information about the ROCm installation on the system.
+
+ Args:
+ repository_ctx: The repository context.
+
+ Returns:
+ A struct containing the following fields:
+ rocm_toolkit_path: The ROCm toolkit installation directory.
+ amdgpu_targets: A list of the system's AMDGPU targets.
+ cpu_value: The name of the host operating system.
+ """
+ cpu_value = _cpu_value(repository_ctx)
+ rocm_toolkit_path = _rocm_toolkit_path(repository_ctx)
+ return struct(
+ rocm_toolkit_path = rocm_toolkit_path,
+ amdgpu_targets = _amdgpu_targets(repository_ctx),
+ cpu_value = cpu_value,
+ )
+
+def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
+ if not out:
+ out = tpl.replace(":", "/")
+ repository_ctx.template(
+ out,
+ Label("//third_party/gpus/%s.tpl" % tpl),
+ substitutions,
+ )
+
+def _file(repository_ctx, label):
+ repository_ctx.template(
+ label.replace(":", "/"),
+ Label("//third_party/gpus/%s.tpl" % label),
+ {},
+ )
+
+_DUMMY_CROSSTOOL_BZL_FILE = """
+def error_gpu_disabled():
+ fail("ERROR: Building with --config=rocm but TensorFlow is not configured " +
+ "to build with GPU support. Please re-run ./configure and enter 'Y' " +
+ "at the prompt to build with GPU support.")
+
+ native.genrule(
+ name = "error_gen_crosstool",
+ outs = ["CROSSTOOL"],
+ cmd = "echo 'Should not be run.' && exit 1",
+ )
+
+ native.filegroup(
+ name = "crosstool",
+ srcs = [":CROSSTOOL"],
+ output_licenses = ["unencumbered"],
+ )
+"""
+
+_DUMMY_CROSSTOOL_BUILD_FILE = """
+load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")
+
+error_gpu_disabled()
+"""
+
+def _create_dummy_repository(repository_ctx):
+ cpu_value = _cpu_value(repository_ctx)
+
+ # Set up BUILD file for rocm/.
+ _tpl(
+ repository_ctx,
+ "rocm:build_defs.bzl",
+ {
+ "%{rocm_is_configured}": "False",
+ "%{rocm_extra_copts}": "[]",
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "rocm:BUILD",
+ {
+ "%{hip_lib}": _lib_name("hip", cpu_value),
+ "%{rocblas_lib}": _lib_name("rocblas", cpu_value),
+ "%{miopen_lib}": _lib_name("miopen", cpu_value),
+ "%{rocfft_lib}": _lib_name("rocfft", cpu_value),
+ "%{hiprand_lib}": _lib_name("hiprand", cpu_value),
+ "%{rocm_include_genrules}": "",
+ "%{rocm_headers}": "",
+ },
+ )
+
+ # Create dummy files for the ROCm toolkit since they are still required by
+ # tensorflow/core/platform/default/build_config:rocm.
+ repository_ctx.file("rocm/hip/include/hip/hip_runtime.h", "")
+
+ # Set up rocm_config.h, which is used by
+ # tensorflow/stream_executor/dso_loader.cc.
+ _tpl(
+ repository_ctx,
+ "rocm:rocm_config.h",
+ {
+ "%{rocm_toolkit_path}": _DEFAULT_ROCM_TOOLKIT_PATH,
+ },
+ "rocm/rocm/rocm_config.h",
+ )
+
+ # If rocm_configure is not configured to build with GPU support, and the user
+ # attempts to build with --config=rocm, add a dummy build rule to intercept
+ # this and fail with an actionable error message.
+ repository_ctx.file(
+ "crosstool/error_gpu_disabled.bzl",
+ _DUMMY_CROSSTOOL_BZL_FILE,
+ )
+ repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
+
+def _execute(
+ repository_ctx,
+ cmdline,
+ error_msg = None,
+ error_details = None,
+ empty_stdout_fine = False):
+ """Executes an arbitrary shell command.
+
+ Args:
+ repository_ctx: the repository_ctx object
+ cmdline: list of strings, the command to execute
+ error_msg: string, a summary of the error if the command fails
+ error_details: string, details about the error or steps to fix it
+ empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
+ it's an error
+ Return:
+ the result of repository_ctx.execute(cmdline)
+ """
+ result = repository_ctx.execute(cmdline)
+ if result.stderr or not (empty_stdout_fine or result.stdout):
+ auto_configure_fail(
+ "\n".join([
+ error_msg.strip() if error_msg else "Repository command failed",
+ result.stderr.strip(),
+ error_details if error_details else "",
+ ]),
+ )
+ return result
+
+def _norm_path(path):
+ """Returns a path with '/' and remove the trailing slash."""
+ path = path.replace("\\", "/")
+ if path[-1] == "/":
+ path = path[:-1]
+ return path
+
+def _symlink_genrule_for_dir(
+ repository_ctx,
+ src_dir,
+ dest_dir,
+ genrule_name,
+ src_files = [],
+ dest_files = []):
+ """Returns a genrule to symlink(or copy if on Windows) a set of files.
+
+ If src_dir is passed, files will be read from the given directory; otherwise
+ we assume files are in src_files and dest_files
+ """
+ if src_dir != None:
+ src_dir = _norm_path(src_dir)
+ dest_dir = _norm_path(dest_dir)
+ files = _read_dir(repository_ctx, src_dir)
+
+ # Create a list with the src_dir stripped to use for outputs.
+ dest_files = files.replace(src_dir, "").splitlines()
+ src_files = files.splitlines()
+ command = []
+
+ # We clear folders that might have been generated previously to avoid
+ # undesired inclusions
+ command.append('if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi')
+ command.append('if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi')
+ outs = []
+ for i in range(len(dest_files)):
+ if dest_files[i] != "":
+ # If we have only one file to link we do not want to use the dest_dir, as
+ # $(@D) will include the full path to the file.
+ dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i]
+
+ # On Windows, symlink is not supported, so we just copy all the files.
+ cmd = "ln -s"
+ command.append(cmd + ' "%s" "%s"' % (src_files[i], dest))
+ outs.append(' "' + dest_dir + dest_files[i] + '",')
+ genrule = _genrule(
+ src_dir,
+ genrule_name,
+ " && ".join(command),
+ "\n".join(outs),
+ )
+ return genrule
+
+def _genrule(src_dir, genrule_name, command, outs):
+ """Returns a string with a genrule.
+
+ Genrule executes the given command and produces the given outputs.
+ """
+ return (
+ "genrule(\n" +
+ ' name = "' +
+ genrule_name + '",\n' +
+ " outs = [\n" +
+ outs +
+ "\n ],\n" +
+ ' cmd = """\n' +
+ command +
+ '\n """,\n' +
+ ")\n"
+ )
+
+def _read_dir(repository_ctx, src_dir):
+ """Returns a string with all files in a directory.
+
+ Finds all files inside a directory, traversing subfolders and following
+ symlinks. The returned string contains the full path of all files
+ separated by line breaks.
+ """
+ find_result = _execute(
+ repository_ctx,
+ ["find", src_dir, "-follow", "-type", "f"],
+ empty_stdout_fine = True,
+ )
+ result = find_result.stdout
+ return result
+
+def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets):
+ if False:
+ amdgpu_target_flags = ["--amdgpu-target=" +
+ amdgpu_target for amdgpu_target in amdgpu_targets]
+ else:
+ # AMDGPU targets are handled in the "crosstool_wrapper_driver_is_not_gcc"
+ amdgpu_target_flags = []
+ return str(amdgpu_target_flags)
+
+def _create_local_rocm_repository(repository_ctx):
+ """Creates the repository containing files set up to build with ROCm."""
+ rocm_config = _get_rocm_config(repository_ctx)
+
+ # Set up symbolic links for the rocm toolkit by creating genrules to do
+ # symlinking. We create one genrule for each directory we want to track under
+ # rocm_toolkit_path
+ rocm_toolkit_path = rocm_config.rocm_toolkit_path
+ rocm_include_path = rocm_toolkit_path + "/include"
+ genrules = [_symlink_genrule_for_dir(
+ repository_ctx,
+ rocm_include_path,
+ "rocm/include",
+ "rocm-include",
+ )]
+ genrules.append(_symlink_genrule_for_dir(
+ repository_ctx,
+ rocm_toolkit_path + "/rocfft/include",
+ "rocm/include/rocfft",
+ "rocfft-include",
+ ))
+ genrules.append(_symlink_genrule_for_dir(
+ repository_ctx,
+ rocm_toolkit_path + "/rocblas/include",
+ "rocm/include/rocblas",
+ "rocblas-include",
+ ))
+ genrules.append(_symlink_genrule_for_dir(
+ repository_ctx,
+ rocm_toolkit_path + "/miopen/include",
+ "rocm/include/miopen",
+ "miopen-include",
+ ))
+
+ rocm_libs = _find_libs(repository_ctx, rocm_config)
+ rocm_lib_src = []
+ rocm_lib_dest = []
+ for lib in rocm_libs.values():
+ rocm_lib_src.append(lib.path)
+ rocm_lib_dest.append("rocm/lib/" + lib.file_name)
+ genrules.append(_symlink_genrule_for_dir(
+ repository_ctx,
+ None,
+ "",
+ "rocm-lib",
+ rocm_lib_src,
+ rocm_lib_dest,
+ ))
+
+ included_files = _read_dir(repository_ctx, rocm_include_path).replace(
+ rocm_include_path,
+ "",
+ ).splitlines()
+
+ # Set up BUILD file for rocm/
+ _tpl(
+ repository_ctx,
+ "rocm:build_defs.bzl",
+ {
+ "%{rocm_is_configured}": "True",
+ "%{rocm_extra_copts}": _compute_rocm_extra_copts(
+ repository_ctx,
+ rocm_config.amdgpu_targets,
+ ),
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "rocm:BUILD",
+ {
+ "%{hip_lib}": rocm_libs["hip"].file_name,
+ "%{rocblas_lib}": rocm_libs["rocblas"].file_name,
+ "%{rocfft_lib}": rocm_libs["rocfft"].file_name,
+ "%{hiprand_lib}": rocm_libs["hiprand"].file_name,
+ "%{miopen_lib}": rocm_libs["miopen"].file_name,
+ "%{rocm_include_genrules}": "\n".join(genrules),
+ "%{rocm_headers}": ('":rocm-include",\n' +
+ '":rocfft-include",\n' +
+ '":rocblas-include",\n' +
+ '":miopen-include",'),
+ },
+ )
+
+ # Set up crosstool/
+ _tpl(repository_ctx, "crosstool:BUILD", {"%{linker_files}": ":empty", "%{win_linker_files}": ":empty"})
+ cc = find_cc(repository_ctx)
+ host_compiler_includes = _host_compiler_includes(repository_ctx, cc)
+ rocm_defines = {
+ "%{rocm_include_path}": _rocm_include_path(
+ repository_ctx,
+ rocm_config,
+ ),
+ "%{host_compiler_includes}": host_compiler_includes,
+ "%{clang_path}": str(cc),
+ }
+
+ _tpl(repository_ctx, "crosstool:CROSSTOOL_hipcc", rocm_defines, out = "crosstool/CROSSTOOL")
+
+ _tpl(
+ repository_ctx,
+ "crosstool:clang/bin/crosstool_wrapper_driver_rocm",
+ {
+ "%{cpu_compiler}": str(cc),
+ "%{hipcc_path}": "/opt/rocm/bin/hipcc",
+ "%{gcc_host_compiler_path}": str(cc),
+ "%{rocm_amdgpu_targets}": ",".join(
+ ["\"%s\"" % c for c in rocm_config.amdgpu_targets],
+ ),
+ },
+ )
+
+ # Set up rocm_config.h, which is used by
+ # tensorflow/stream_executor/dso_loader.cc.
+ _tpl(
+ repository_ctx,
+ "rocm:rocm_config.h",
+ {
+ "%{rocm_amdgpu_targets}": ",".join(
+ ["\"%s\"" % c for c in rocm_config.amdgpu_targets],
+ ),
+ "%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path,
+ },
+ "rocm/rocm/rocm_config.h",
+ )
+
+def _create_remote_rocm_repository(repository_ctx, remote_config_repo):
+ """Creates pointers to a remotely configured repo set up to build with ROCm."""
+ _tpl(
+ repository_ctx,
+ "rocm:build_defs.bzl",
+ {
+ "%{rocm_is_configured}": "True",
+ "%{rocm_extra_copts}": _compute_rocm_extra_copts(
+ repository_ctx, #_compute_capabilities(repository_ctx)
+ ),
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "rocm:remote.BUILD",
+ {
+ "%{remote_rocm_repo}": remote_config_repo,
+ },
+ "rocm/BUILD",
+ )
+ _tpl(repository_ctx, "crosstool:remote.BUILD", {
+ "%{remote_rocm_repo}": remote_config_repo,
+ }, "crosstool/BUILD")
+
+def _rocm_autoconf_impl(repository_ctx):
+ """Implementation of the rocm_autoconf repository rule."""
+ if not _enable_rocm(repository_ctx):
+ _create_dummy_repository(repository_ctx)
+ elif _TF_ROCM_CONFIG_REPO in repository_ctx.os.environ:
+ _create_remote_rocm_repository(
+ repository_ctx,
+ repository_ctx.os.environ[_TF_ROCM_CONFIG_REPO],
+ )
+ else:
+ _create_local_rocm_repository(repository_ctx)
+
+rocm_configure = repository_rule(
+ implementation = _rocm_autoconf_impl,
+ environ = [
+ _GCC_HOST_COMPILER_PATH,
+ "TF_NEED_ROCM",
+ _ROCM_TOOLKIT_PATH,
+ _TF_ROCM_VERSION,
+ _TF_MIOPEN_VERSION,
+ _TF_ROCM_AMDGPU_TARGETS,
+ _TF_ROCM_CONFIG_REPO,
+ ],
+)
+
+"""Detects and configures the local ROCm toolchain.
+
+Add the following to your WORKSPACE FILE:
+
+```python
+rocm_configure(name = "local_config_rocm")
+```
+
+Args:
+ name: A unique name for this workspace rule.
+"""
diff --git a/third_party/icu/BUILD b/third_party/icu/BUILD
new file mode 100644
index 0000000000..82bab3ffd9
--- /dev/null
+++ b/third_party/icu/BUILD
@@ -0,0 +1 @@
+# This empty BUILD file is required to make Bazel treat this directory as a package.
diff --git a/third_party/icu/BUILD.bazel b/third_party/icu/BUILD.bazel
new file mode 100644
index 0000000000..36d6b9006b
--- /dev/null
+++ b/third_party/icu/BUILD.bazel
@@ -0,0 +1,88 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files([
+ "icu4c/LICENSE",
+ "icu4j/main/shared/licenses/LICENSE",
+])
+
+cc_library(
+ name = "headers",
+ hdrs = glob(["icu4c/source/common/unicode/*.h"]),
+ includes = [
+ "icu4c/source/common",
+ ],
+ deps = [
+ ],
+)
+
+cc_library(
+ name = "common",
+ hdrs = glob(["icu4c/source/common/unicode/*.h"]),
+ includes = [
+ "icu4c/source/common",
+ ],
+ deps = [
+ ":icuuc",
+ ],
+)
+
+cc_library(
+ name = "icuuc",
+ srcs = glob(
+ [
+ "icu4c/source/common/*.c",
+ "icu4c/source/common/*.cpp",
+ "icu4c/source/stubdata/*.cpp",
+ ],
+ ),
+ hdrs = glob([
+ "icu4c/source/common/*.h",
+ ]),
+ copts = [
+ "-DU_COMMON_IMPLEMENTATION",
+ "-DU_HAVE_STD_ATOMICS",
+ ] + select({
+ ":android": [
+ "-fdata-sections",
+ "-DGOOGLE_VENDOR_SRC_BRANCH",
+ "-DU_HAVE_NL_LANGINFO_CODESET=0",
+ "-Wno-deprecated-declarations",
+ ],
+ ":apple": [
+ "-DGOOGLE_VENDOR_SRC_BRANCH",
+ "-Wno-shorten-64-to-32",
+ "-Wno-unused-variable",
+ ],
+ ":windows": [
+ "/utf-8",
+ "/DLOCALE_ALLOW_NEUTRAL_NAMES=0",
+ ],
+ "//conditions:default": [],
+ }),
+ tags = ["requires-rtti"],
+ visibility = [
+ "//visibility:private",
+ ],
+ deps = [
+ ":headers",
+ ],
+)
+
+config_setting(
+ name = "android",
+ values = {"crosstool_top": "//external:android/crosstool"},
+)
+
+config_setting(
+ name = "apple",
+ values = {"cpu": "darwin"},
+)
+
+config_setting(
+ name = "windows",
+ values = {"cpu": "x64_windows"},
+)
diff --git a/third_party/icu/workspace.bzl b/third_party/icu/workspace.bzl
new file mode 100644
index 0000000000..bfebf4219b
--- /dev/null
+++ b/third_party/icu/workspace.bzl
@@ -0,0 +1,15 @@
+"""Loads a lightweight subset of the ICU library for Unicode processing."""
+
+load("//third_party:repo.bzl", "third_party_http_archive")
+
+def repo():
+ third_party_http_archive(
+ name = "icu",
+ strip_prefix = "icu-release-62-1",
+ sha256 = "e15ffd84606323cbad5515bf9ecdf8061cc3bf80fb883b9e6aa162e485aa9761",
+ urls = [
+ "https://mirror.bazel.build/github.com/unicode-org/icu/archive/release-62-1.tar.gz",
+ "https://github.com/unicode-org/icu/archive/release-62-1.tar.gz",
+ ],
+ build_file = "//third_party/icu:BUILD.bazel",
+ )
diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD
index efff7fd51b..15a3e5cfa7 100644
--- a/third_party/mkl/BUILD
+++ b/third_party/mkl/BUILD
@@ -1,26 +1,26 @@
licenses(["notice"]) # 3-Clause BSD
config_setting(
- name = "using_mkl",
+ name = "build_with_mkl",
define_values = {
- "using_mkl": "true",
+ "build_with_mkl": "true",
},
visibility = ["//visibility:public"],
)
config_setting(
- name = "using_mkl_ml_only",
+ name = "build_with_mkl_ml_only",
define_values = {
- "using_mkl": "true",
- "using_mkl_ml_only": "true",
+ "build_with_mkl": "true",
+ "build_with_mkl_ml_only": "true",
},
visibility = ["//visibility:public"],
)
config_setting(
- name = "using_mkl_lnx_x64",
+ name = "build_with_mkl_lnx_x64",
define_values = {
- "using_mkl": "true",
+ "build_with_mkl": "true",
},
values = {
"cpu": "k8",
@@ -28,6 +28,15 @@ config_setting(
visibility = ["//visibility:public"],
)
+config_setting(
+ name = "enable_mkl",
+ define_values = {
+ "enable_mkl": "true",
+ "build_with_mkl": "true",
+ },
+ visibility = ["//visibility:public"],
+)
+
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
diff --git a/third_party/mkl/build_defs.bzl b/third_party/mkl/build_defs.bzl
index b645c0fc5c..10c2d90c84 100644
--- a/third_party/mkl/build_defs.bzl
+++ b/third_party/mkl/build_defs.bzl
@@ -1,9 +1,11 @@
# -*- Python -*-
"""Skylark macros for MKL.
-if_mkl is a conditional to check if MKL is enabled or not.
-if_mkl_ml is a conditional to check if MKL-ML is enabled.
+
+if_mkl is a conditional to check if we are building with MKL.
+if_mkl_ml is a conditional to check if we are building with MKL-ML.
if_mkl_ml_only is a conditional to check for MKL-ML-only (no MKL-DNN) mode.
if_mkl_lnx_x64 is a conditional to check for MKL
+if_enable_mkl is a conditional to check if building with MKL and MKL is enabled.
mkl_repository is a repository rule for creating MKL repository rule that can
be pointed to either a local folder, or download it from the internet.
@@ -24,7 +26,7 @@ def if_mkl(if_true, if_false = []):
a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl:using_mkl")): if_true,
+ str(Label("//third_party/mkl:build_with_mkl")): if_true,
"//conditions:default": if_false,
})
@@ -40,8 +42,8 @@ def if_mkl_ml(if_true, if_false = []):
a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): if_false,
- str(Label("//third_party/mkl:using_mkl")): if_true,
+ str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): if_false,
+ str(Label("//third_party/mkl:build_with_mkl")): if_true,
"//conditions:default": if_false,
})
@@ -56,12 +58,12 @@ def if_mkl_ml_only(if_true, if_false = []):
a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl:using_mkl_ml_only")): if_true,
+ str(Label("//third_party/mkl:build_with_mkl_ml_only")): if_true,
"//conditions:default": if_false,
})
def if_mkl_lnx_x64(if_true, if_false = []):
- """Shorthand to select() on if MKL is on and the target is Linux x86-64.
+ """Shorthand to select() if building with MKL and the target is Linux x86-64.
Args:
if_true: expression to evaluate if building with MKL is enabled and the
@@ -73,7 +75,24 @@ def if_mkl_lnx_x64(if_true, if_false = []):
a select evaluating to either if_true or if_false as appropriate.
"""
return select({
- str(Label("//third_party/mkl:using_mkl_lnx_x64")): if_true,
+ str(Label("//third_party/mkl:build_with_mkl_lnx_x64")): if_true,
+ "//conditions:default": if_false,
+ })
+
+def if_enable_mkl(if_true, if_false = []):
+ """Shorthand to select() if we are building with MKL and MKL is enabled.
+
+ This is only effective when built with MKL.
+
+ Args:
+ if_true: expression to evaluate if building with MKL and MKL is enabled
+ if_false: expression to evaluate if building without MKL or MKL is not enabled.
+
+ Returns:
+ A select evaluating to either if_true or if_false as appropriate.
+ """
+ return select({
+ str(Label("//third_party/mkl:enable_mkl")): if_true,
"//conditions:default": if_false,
})
@@ -87,9 +106,9 @@ def mkl_deps():
inclusion in the deps attribute of rules.
"""
return select({
- str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): ["@mkl_dnn"],
- str(Label("//third_party/mkl:using_mkl_ml_only")): ["//third_party/mkl:intel_binary_blob"],
- str(Label("//third_party/mkl:using_mkl")): [
+ str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): ["@mkl_dnn"],
+ str(Label("//third_party/mkl:build_with_mkl_ml_only")): ["//third_party/mkl:intel_binary_blob"],
+ str(Label("//third_party/mkl:build_with_mkl")): [
"//third_party/mkl:intel_binary_blob",
"@mkl_dnn",
],
diff --git a/third_party/mkl_dnn/BUILD b/third_party/mkl_dnn/BUILD
index 3e567fa9fc..58ecda55e6 100644
--- a/third_party/mkl_dnn/BUILD
+++ b/third_party/mkl_dnn/BUILD
@@ -3,10 +3,10 @@ licenses(["notice"])
exports_files(["LICENSE"])
config_setting(
- name = "using_mkl_dnn_only",
+ name = "build_with_mkl_dnn_only",
define_values = {
- "using_mkl": "true",
- "using_mkl_dnn_only": "true",
+ "build_with_mkl": "true",
+ "build_with_mkl_dnn_only": "true",
},
visibility = ["//visibility:public"],
)
diff --git a/third_party/mkl_dnn/build_defs.bzl b/third_party/mkl_dnn/build_defs.bzl
index 7ce2a7d9b0..6388f31971 100644
--- a/third_party/mkl_dnn/build_defs.bzl
+++ b/third_party/mkl_dnn/build_defs.bzl
@@ -8,6 +8,6 @@ def if_mkl_open_source_only(if_true, if_false = []):
"""
return select({
- str(Label("//third_party/mkl_dnn:using_mkl_dnn_only")): if_true,
+ str(Label("//third_party/mkl_dnn:build_with_mkl_dnn_only")): if_true,
"//conditions:default": if_false,
})
diff --git a/third_party/ngraph/ngraph.BUILD b/third_party/ngraph/ngraph.BUILD
index 1fd1b8e8e0..6602a480af 100644
--- a/third_party/ngraph/ngraph.BUILD
+++ b/third_party/ngraph/ngraph.BUILD
@@ -11,41 +11,35 @@ cc_library(
cc_library(
name = "ngraph_cpu_backend",
srcs = [
- "src/ngraph/runtime/cpu/cpu_backend.cpp",
- "src/ngraph/runtime/cpu/cpu_builder.cpp",
- "src/ngraph/runtime/cpu/cpu_call_frame.cpp",
- "src/ngraph/runtime/cpu/cpu_external_function.cpp",
- "src/ngraph/runtime/cpu/cpu_kernels.cpp",
- "src/ngraph/runtime/cpu/cpu_layout_descriptor.cpp",
- "src/ngraph/runtime/cpu/cpu_tensor_view_wrapper.cpp",
- "src/ngraph/runtime/cpu/cpu_tensor_view.cpp",
- "src/ngraph/runtime/cpu/cpu_tracing.cpp",
"src/ngraph/runtime/cpu/builder/add.cpp",
"src/ngraph/runtime/cpu/builder/allreduce.cpp",
- "src/ngraph/runtime/cpu/builder/avg_pool.cpp",
- "src/ngraph/runtime/cpu/builder/argmin.cpp",
"src/ngraph/runtime/cpu/builder/argmax.cpp",
+ "src/ngraph/runtime/cpu/builder/argmin.cpp",
+ "src/ngraph/runtime/cpu/builder/avg_pool.cpp",
"src/ngraph/runtime/cpu/builder/batch_norm.cpp",
- "src/ngraph/runtime/cpu/builder/broadcast.cpp",
"src/ngraph/runtime/cpu/builder/bounded_relu.cpp",
+ "src/ngraph/runtime/cpu/builder/broadcast.cpp",
"src/ngraph/runtime/cpu/builder/concat.cpp",
"src/ngraph/runtime/cpu/builder/convert.cpp",
"src/ngraph/runtime/cpu/builder/convert_layout.cpp",
"src/ngraph/runtime/cpu/builder/convolution.cpp",
"src/ngraph/runtime/cpu/builder/dot.cpp",
"src/ngraph/runtime/cpu/builder/function_call.cpp",
- "src/ngraph/runtime/cpu/builder/lstm.cpp",
"src/ngraph/runtime/cpu/builder/lrn.cpp",
+ "src/ngraph/runtime/cpu/builder/lstm.cpp",
"src/ngraph/runtime/cpu/builder/matmul_bias.cpp",
"src/ngraph/runtime/cpu/builder/max.cpp",
"src/ngraph/runtime/cpu/builder/max_pool.cpp",
"src/ngraph/runtime/cpu/builder/min.cpp",
"src/ngraph/runtime/cpu/builder/one_hot.cpp",
- "src/ngraph/runtime/cpu/builder/relu.cpp",
"src/ngraph/runtime/cpu/builder/pad.cpp",
"src/ngraph/runtime/cpu/builder/product.cpp",
+ "src/ngraph/runtime/cpu/builder/quantize.cpp",
+ "src/ngraph/runtime/cpu/builder/quantized_avg_pool.cpp",
+ "src/ngraph/runtime/cpu/builder/quantized_max_pool.cpp",
"src/ngraph/runtime/cpu/builder/reduce_function.cpp",
"src/ngraph/runtime/cpu/builder/reduce_function_window.cpp",
+ "src/ngraph/runtime/cpu/builder/relu.cpp",
"src/ngraph/runtime/cpu/builder/replace_slice.cpp",
"src/ngraph/runtime/cpu/builder/reshape.cpp",
"src/ngraph/runtime/cpu/builder/reverse.cpp",
@@ -57,6 +51,16 @@ cc_library(
"src/ngraph/runtime/cpu/builder/slice.cpp",
"src/ngraph/runtime/cpu/builder/softmax.cpp",
"src/ngraph/runtime/cpu/builder/sum.cpp",
+ "src/ngraph/runtime/cpu/builder/topk.cpp",
+ "src/ngraph/runtime/cpu/cpu_backend.cpp",
+ "src/ngraph/runtime/cpu/cpu_builder.cpp",
+ "src/ngraph/runtime/cpu/cpu_call_frame.cpp",
+ "src/ngraph/runtime/cpu/cpu_external_function.cpp",
+ "src/ngraph/runtime/cpu/cpu_kernels.cpp",
+ "src/ngraph/runtime/cpu/cpu_layout_descriptor.cpp",
+ "src/ngraph/runtime/cpu/cpu_tensor_view.cpp",
+ "src/ngraph/runtime/cpu/cpu_tensor_view_wrapper.cpp",
+ "src/ngraph/runtime/cpu/cpu_tracing.cpp",
"src/ngraph/runtime/cpu/kernel/eigen_thread_pool.cpp",
"src/ngraph/runtime/cpu/kernel/pad.cpp",
"src/ngraph/runtime/cpu/kernel/reduce_max.cpp",
@@ -68,14 +72,19 @@ cc_library(
"src/ngraph/runtime/cpu/op/batch_dot.cpp",
"src/ngraph/runtime/cpu/op/batch_norm_relu.cpp",
"src/ngraph/runtime/cpu/op/bounded_relu.cpp",
- "src/ngraph/runtime/cpu/op/group_conv.cpp",
+ "src/ngraph/runtime/cpu/op/conv_add.cpp",
"src/ngraph/runtime/cpu/op/conv_bias.cpp",
"src/ngraph/runtime/cpu/op/conv_relu.cpp",
"src/ngraph/runtime/cpu/op/convert_layout.cpp",
+ "src/ngraph/runtime/cpu/op/dequantize.cpp",
+ "src/ngraph/runtime/cpu/op/group_conv.cpp",
"src/ngraph/runtime/cpu/op/loop_kernel.cpp",
"src/ngraph/runtime/cpu/op/lstm.cpp",
"src/ngraph/runtime/cpu/op/matmul_bias.cpp",
"src/ngraph/runtime/cpu/op/max_pool_with_indices.cpp",
+ "src/ngraph/runtime/cpu/op/quantize.cpp",
+ "src/ngraph/runtime/cpu/op/quantized_avg_pool.cpp",
+ "src/ngraph/runtime/cpu/op/quantized_max_pool.cpp",
"src/ngraph/runtime/cpu/op/rnn.cpp",
"src/ngraph/runtime/cpu/op/sigmoid_mul.cpp",
"src/ngraph/runtime/cpu/pass/cpu_assignment.cpp",
@@ -101,7 +110,7 @@ cc_library(
"-I external/ngraph/src",
"-I external/nlohmann_json_lib/include/",
'-D SHARED_LIB_EXT=\\".so\\"',
- '-D NGRAPH_VERSION=\\"0.7.0\\"',
+ '-D NGRAPH_VERSION=\\"0.8.1\\"',
"-D NGRAPH_DEX_ONLY",
],
visibility = ["//visibility:public"],
@@ -135,7 +144,7 @@ cc_library(
"-I external/ngraph/src",
"-I external/nlohmann_json_lib/include/",
'-D SHARED_LIB_EXT=\\".so\\"',
- '-D NGRAPH_VERSION=\\"0.7.0\\"',
+ '-D NGRAPH_VERSION=\\"0.8.1\\"',
],
visibility = ["//visibility:public"],
alwayslink = 1,
diff --git a/third_party/ngraph/ngraph_tf.BUILD b/third_party/ngraph/ngraph_tf.BUILD
index 979318d7c2..dbedca0a03 100644
--- a/third_party/ngraph/ngraph_tf.BUILD
+++ b/third_party/ngraph/ngraph_tf.BUILD
@@ -10,41 +10,42 @@ load(
cc_library(
name = "ngraph_tf",
srcs = [
- "src/ngraph_api.h",
"src/ngraph_api.cc",
- "src/ngraph_assign_clusters.h",
+ "src/ngraph_api.h",
"src/ngraph_assign_clusters.cc",
- "src/ngraph_builder.h",
+ "src/ngraph_assign_clusters.h",
"src/ngraph_builder.cc",
- "src/ngraph_capture_variables.h",
+ "src/ngraph_builder.h",
"src/ngraph_capture_variables.cc",
- "src/ngraph_conversions.h",
- "src/ngraph_cluster_manager.h",
+ "src/ngraph_capture_variables.h",
"src/ngraph_cluster_manager.cc",
- "src/ngraph_deassign_clusters.h",
+ "src/ngraph_cluster_manager.h",
+ "src/ngraph_conversions.h",
"src/ngraph_deassign_clusters.cc",
- "src/ngraph_encapsulate_op.cc",
- "src/ngraph_encapsulate_clusters.h",
+ "src/ngraph_deassign_clusters.h",
"src/ngraph_encapsulate_clusters.cc",
- "src/ngraph_freshness_tracker.h",
+ "src/ngraph_encapsulate_clusters.h",
+ "src/ngraph_encapsulate_op.cc",
"src/ngraph_freshness_tracker.cc",
- "src/ngraph_mark_for_clustering.h",
+ "src/ngraph_freshness_tracker.h",
"src/ngraph_mark_for_clustering.cc",
- "src/ngraph_rewrite_pass.cc",
- "src/ngraph_rewrite_for_tracking.h",
+ "src/ngraph_mark_for_clustering.h",
"src/ngraph_rewrite_for_tracking.cc",
+ "src/ngraph_rewrite_for_tracking.h",
+ "src/ngraph_rewrite_pass.cc",
"src/ngraph_tracked_variable.cc",
- "src/ngraph_utils.h",
"src/ngraph_utils.cc",
+ "src/ngraph_utils.h",
+ "src/ngraph_version_utils.h",
+ "src/tf_deadness_analysis.cc",
+ "src/tf_deadness_analysis.h",
"src/tf_graphcycles.cc",
+ "src/tf_graphcycles.h",
"logging/ngraph_log.h",
"logging/ngraph_log.cc",
"logging/tf_graph_writer.h",
"logging/tf_graph_writer.cc",
],
- hdrs = [
- "src/tf_graphcycles.h",
- ],
deps = [
"@org_tensorflow//tensorflow/core:protos_all_proto_text",
"@org_tensorflow//tensorflow/core:framework_headers_lib",
@@ -64,17 +65,19 @@ tf_cc_test(
name = "ngraph_tf_tests",
size = "small",
srcs = [
- "test/tf_exec.cpp",
"test/conversions.cpp",
- "test/padding.cpp",
"test/graph_rewrites/assign_clusters.cc",
- "test/test_utilities.h",
- "test/test_utilities.cpp",
+ "test/graph_rewrites/deadness_test.cc",
+ "test/main.cpp",
+ "test/opexecuter.cpp",
+ "test/opexecuter.h",
+ "test/padding.cpp",
+ "test/test_array_ops.cpp",
"test/test_math_ops.cpp",
"test/test_nn_ops.cpp",
- "test/opexecuter.h",
- "test/opexecuter.cpp",
- "test/main.cpp",
+ "test/test_utilities.cpp",
+ "test/test_utilities.h",
+ "test/tf_exec.cpp",
],
deps = [
":ngraph_tf",
diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl
index 3c7e5c8469..53264630a1 100644
--- a/third_party/py/python_configure.bzl
+++ b/third_party/py/python_configure.bzl
@@ -130,8 +130,8 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
# If we have only one file to link we do not want to use the dest_dir, as
# $(@D) will include the full path to the file.
dest = '$(@D)/' + dest_dir + dest_files[i] if len(dest_files) != 1 else '$(@D)/' + dest_files[i]
- # On Windows, symlink is not supported, so we just copy all the files.
- cmd = 'cp -f' if _is_windows(repository_ctx) else 'ln -s'
+ # Copy the headers to create a sandboxable setup.
+ cmd = 'cp -f'
command.append(cmd + ' "%s" "%s"' % (src_files[i] , dest))
outs.append(' "' + dest_dir + dest_files[i] + '",')
genrule = _genrule(src_dir, genrule_name, " && ".join(command),
diff --git a/third_party/toolchains/BUILD b/third_party/toolchains/BUILD
index 7256a7d96e..bcbc4dda11 100644
--- a/third_party/toolchains/BUILD
+++ b/third_party/toolchains/BUILD
@@ -26,12 +26,10 @@ platform(
constraint_values = [
"@bazel_tools//platforms:x86_64",
"@bazel_tools//platforms:linux",
- "@bazel_tools//tools/cpp:clang",
- "@bazel_toolchains//constraints:xenial",
],
remote_execution_properties = """
properties: {
name: "container-image"
- value:"docker://gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04@sha256:06b585f42eed3b2030e9566b8f88f48d7472fa0f47e59765bc115376c8801bdf"
+ value:"docker://gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04@sha256:e5099ff15650986e268a43ee99e2d2b7ffe2459b8b6935385078d1d3b2ed4d02"
}""",
)
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
index 2d3e41127d..05abcb56d8 100755
--- a/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
+++ b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
@@ -1253,7 +1253,7 @@ genrule(
"cuda/lib/libcupti.so.9.0",
],
cmd = """
-if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0.480" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0.176" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0.176" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0.176" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7.2.1" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda-9.0/extras/CUPTI/lib64/libcupti.so.9.0.176" "$(@D)/cuda/lib/libcupti.so.9.0"
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0.480" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0.176" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0.176" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0.176" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.4" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda-9.0/extras/CUPTI/lib64/libcupti.so.9.0.176" "$(@D)/cuda/lib/libcupti.so.9.0"
""",
)
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
index a56b4513fb..6442e7628a 100755
--- a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
+++ b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
@@ -2,6 +2,20 @@ licenses(["restricted"])
package(default_visibility = ["//visibility:public"])
+toolchain(
+ name = "toolchain-linux-x86_64",
+ exec_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ target_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ toolchain = ":cc-compiler-local",
+ toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
+)
+
cc_toolchain_suite(
name = "toolchain",
toolchains = {
diff --git a/tools/bazel.rc b/tools/bazel.rc
index ccf62629d1..0cd148ed87 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -24,12 +24,13 @@ build --define framework_shared_object=true
# Please note that MKL on MacOS or windows is still not supported.
# If you would like to use a local MKL instead of downloading, please set the
# environment variable "TF_MKL_ROOT" every time before build.
-build:mkl --define=using_mkl=true
+build:mkl --define=build_with_mkl=true --define=enable_mkl=true
build:mkl -c opt
# This config option is used to enable MKL-DNN open source library only,
# without depending on MKL binary version.
-build:mkl_open_source_only --define=using_mkl_dnn_only=true
+build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
+build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
build:download_clang --crosstool_top=@local_config_download_clang//:toolchain
build:download_clang --define=using_clang=true
@@ -42,6 +43,9 @@ build:download_clang_use_lld --linkopt='-fuse-ld=lld'
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
+build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
+build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
+
build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true
@@ -57,6 +61,11 @@ build:sycl_asan --define=using_sycl=true --define=using_trisycl=false --copt -fn
build:sycl_trisycl --crosstool_top=@local_config_sycl//crosstool:toolchain
build:sycl_trisycl --define=using_sycl=true --define=using_trisycl=true
+# Options extracted from configure script
+build:gdr --define=with_gdr_support=true
+build:ngraph --define=with_ngraph_support=true
+build:verbs --define=with_verbs_support=true
+
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true
build --define=grpc_no_ares=true
@@ -64,6 +73,10 @@ build --define=grpc_no_ares=true
build --spawn_strategy=standalone
build --genrule_strategy=standalone
build -c opt
+build --define=with_jemalloc=false
+
+# Other build flags.
+build --define=grpc_no_ares=true
# Modular TF build options
build:dynamic_kernels --define=dynamic_loaded_kernels=true