author    avijit-nervana <avijit.chakraborty@intel.com> 2018-09-07 18:54:26 -0700
committer avijit-nervana <avijit.chakraborty@intel.com> 2018-09-07 18:54:26 -0700
commit    d9a738d5fff96ecb6db62d67e049ab12202dcb42 (patch)
tree      292539c9ca4036ea55ae4763d3029f32829c9722
parent    18b80bbd4b8db8bd35afad7264258c1c5c269226 (diff)
parent    3e1b06ee93d7a638db1fdd5f733d66064c1acf59 (diff)
Merge branch 'master' into avijit/add-cpu-backend
-rw-r--r-- tensorflow/BUILD | 51
-rw-r--r-- tensorflow/api_template.__init__.py | 22
-rw-r--r-- tensorflow/c/BUILD | 1
-rw-r--r-- tensorflow/c/c_api_experimental.cc | 210
-rw-r--r-- tensorflow/c/c_api_experimental.h | 39
-rwxr-xr-x tensorflow/c/eager/c_api.cc | 13
-rwxr-xr-x tensorflow/c/eager/c_api.h | 6
-rw-r--r-- tensorflow/c/eager/c_api_test.cc | 25
-rw-r--r-- tensorflow/compiler/aot/embedded_protocol_buffers.h | 1
-rw-r--r-- tensorflow/compiler/aot/tfcompile_main.cc | 6
-rw-r--r-- tensorflow/compiler/jit/legacy_flags/BUILD | 12
-rw-r--r-- tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc | 68
-rw-r--r-- tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h | 52
-rw-r--r-- tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 2
-rw-r--r-- tensorflow/compiler/jit/xla_cluster_util.h | 1
-rw-r--r-- tensorflow/compiler/jit/xla_device_context.cc | 6
-rw-r--r-- tensorflow/compiler/jit/xla_device_context.h | 8
-rw-r--r-- tensorflow/compiler/tf2xla/BUILD | 13
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/BUILD | 18
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/if_op.cc | 30
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/if_op.h | 2
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/while_op.cc | 31
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/while_op.h | 2
-rw-r--r-- tensorflow/compiler/tf2xla/resource_operation_table.cc | 18
-rw-r--r-- tensorflow/compiler/tf2xla/side_effect_util.cc | 67
-rw-r--r-- tensorflow/compiler/tf2xla/side_effect_util.h | 47
-rw-r--r-- tensorflow/compiler/tf2xla/tf2xla_util.h | 1
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler.cc | 113
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler.h | 23
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler_test.cc | 68
-rw-r--r-- tensorflow/compiler/tf2xla/xla_context.cc | 11
-rw-r--r-- tensorflow/compiler/tf2xla/xla_context.h | 3
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 11
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_registry.h | 1
-rw-r--r-- tensorflow/compiler/xla/packed_literal_reader.cc | 5
-rw-r--r-- tensorflow/compiler/xla/python/local_computation_builder.i | 6
-rw-r--r-- tensorflow/compiler/xla/service/BUILD | 32
-rw-r--r-- tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 20
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_propagation_test.cc | 241
-rw-r--r-- tensorflow/compiler/xla/service/cpu/BUILD | 1
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc | 5
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc | 18
-rw-r--r-- tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/cpu/tests/BUILD | 1
-rw-r--r-- tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc | 7
-rw-r--r-- tensorflow/compiler/xla/service/gpu/BUILD | 4
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc | 55
-rw-r--r-- tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc | 9
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter.cc | 15
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter.h | 6
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc | 16
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 111
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h | 16
-rw-r--r-- tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc | 51
-rw-r--r-- tensorflow/compiler/xla/service/hlo.proto | 26
-rw-r--r-- tensorflow/compiler/xla/service/hlo_computation.cc | 12
-rw-r--r-- tensorflow/compiler/xla/service/hlo_computation.h | 5
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cost_analysis.cc | 13
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cost_analysis.h | 1
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc | 55
-rw-r--r-- tensorflow/compiler/xla/service/hlo_cse_test.cc | 91
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 5
-rw-r--r-- tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 47
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc | 188
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.h | 61
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.cc | 147
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.h | 92
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module.cc | 33
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module.h | 20
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_config.h | 3
-rw-r--r-- tensorflow/compiler/xla/service/hlo_module_test.cc | 59
-rw-r--r-- tensorflow/compiler/xla/service/hlo_ordering.cc | 17
-rw-r--r-- tensorflow/compiler/xla/service/hlo_ordering.h | 4
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser.cc | 39
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser_test.cc | 112
-rw-r--r-- tensorflow/compiler/xla/service/hlo_proto_util.cc | 3
-rw-r--r-- tensorflow/compiler/xla/service/hlo_schedule.cc | 52
-rw-r--r-- tensorflow/compiler/xla/service/hlo_schedule.h | 13
-rw-r--r-- tensorflow/compiler/xla/service/layout_assignment.cc | 3
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc | 10
-rw-r--r-- tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h | 5
-rw-r--r-- tensorflow/compiler/xla/tests/BUILD | 1
-rw-r--r-- tensorflow/compiler/xla/tests/reduce_test.cc | 5
-rw-r--r-- tensorflow/compiler/xla/tests/test_utils.cc | 14
-rw-r--r-- tensorflow/compiler/xla/tests/test_utils.h | 8
-rw-r--r-- tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc | 4
-rw-r--r-- tensorflow/contrib/autograph/converters/logical_expressions.py | 4
-rw-r--r-- tensorflow/contrib/autograph/converters/logical_expressions_test.py | 9
-rw-r--r-- tensorflow/contrib/autograph/impl/api_test.py | 3
-rw-r--r-- tensorflow/contrib/autograph/pyct/common_transformers/anf.py | 10
-rw-r--r-- tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py | 40
-rw-r--r-- tensorflow/contrib/autograph/pyct/static_analysis/live_values.py | 7
-rw-r--r-- tensorflow/contrib/boosted_trees/estimator_batch/estimator.py | 43
-rw-r--r-- tensorflow/contrib/boosted_trees/estimator_batch/model.py | 8
-rw-r--r-- tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py | 9
-rw-r--r-- tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py | 4
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/BUILD | 2
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py | 45
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py | 104
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py | 22
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/BUILD | 1
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py | 45
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py | 56
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py | 13
-rw-r--r-- tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py | 144
-rw-r--r-- tensorflow/contrib/data/python/ops/map_defun.py | 2
-rw-r--r-- tensorflow/contrib/distribute/python/keras_test.py | 3
-rw-r--r-- tensorflow/contrib/distribute/python/tpu_strategy.py | 3
-rw-r--r-- tensorflow/contrib/estimator/BUILD | 1
-rw-r--r-- tensorflow/contrib/estimator/python/estimator/rnn.py | 14
-rw-r--r-- tensorflow/contrib/estimator/python/estimator/rnn_test.py | 41
-rw-r--r-- tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc | 2
-rw-r--r-- tensorflow/contrib/learn/BUILD | 1
-rw-r--r-- tensorflow/contrib/lite/BUILD | 46
-rw-r--r-- tensorflow/contrib/lite/allocation.cc | 4
-rw-r--r-- tensorflow/contrib/lite/allocation.h | 4
-rw-r--r-- tensorflow/contrib/lite/arena_planner.h | 6
-rw-r--r-- tensorflow/contrib/lite/build_def.bzl | 21
-rw-r--r-- tensorflow/contrib/lite/builtin_op_data.h | 292
-rw-r--r-- tensorflow/contrib/lite/c/BUILD | 39
-rw-r--r-- tensorflow/contrib/lite/c/builtin_op_data.h | 298
-rw-r--r-- tensorflow/contrib/lite/c/builtin_op_data_test.cc | 83
-rw-r--r-- tensorflow/contrib/lite/c/c_api_internal.c (renamed from tensorflow/contrib/lite/context.c) | 6
-rw-r--r-- tensorflow/contrib/lite/c/c_api_internal.h | 491
-rw-r--r-- tensorflow/contrib/lite/c/c_api_internal_test.cc (renamed from tensorflow/contrib/lite/context_test.cc) | 10
-rw-r--r-- tensorflow/contrib/lite/context.h | 478
-rw-r--r-- tensorflow/contrib/lite/context_util.h | 2
-rw-r--r-- tensorflow/contrib/lite/core/api/BUILD | 57
-rw-r--r-- tensorflow/contrib/lite/core/api/error_reporter.cc | 38
-rw-r--r-- tensorflow/contrib/lite/core/api/error_reporter.h | 45
-rw-r--r-- tensorflow/contrib/lite/core/api/error_reporter_test.cc | 49
-rw-r--r-- tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc | 622
-rw-r--r-- tensorflow/contrib/lite/core/api/flatbuffer_conversions.h | 48
-rw-r--r-- tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc | 104
-rw-r--r-- tensorflow/contrib/lite/core/api/op_resolver.cc | 60
-rw-r--r-- tensorflow/contrib/lite/core/api/op_resolver.h | 47
-rw-r--r-- tensorflow/contrib/lite/core/api/op_resolver_test.cc | 197
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/BUILD | 5
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/buffer_map.h | 2
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/delegate.h | 2
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc | 2
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/kernel.cc | 2
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/kernel.h | 2
-rw-r--r-- tensorflow/contrib/lite/delegates/eager/util.h | 2
-rw-r--r-- tensorflow/contrib/lite/delegates/nnapi/BUILD | 2
-rw-r--r-- tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc | 2
-rw-r--r-- tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h | 2
-rw-r--r-- tensorflow/contrib/lite/error_reporter.h | 38
-rw-r--r-- tensorflow/contrib/lite/experimental/c/BUILD | 1
-rw-r--r-- tensorflow/contrib/lite/experimental/c/c_api.cc | 12
-rw-r--r-- tensorflow/contrib/lite/experimental/c/c_api.h | 13
-rw-r--r-- tensorflow/contrib/lite/experimental/c/c_api_internal.h | 6
-rw-r--r-- tensorflow/contrib/lite/experimental/c/c_api_test.cc | 4
-rw-r--r-- tensorflow/contrib/lite/experimental/kernels/BUILD | 3
-rw-r--r-- tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc | 2
-rw-r--r-- tensorflow/contrib/lite/graph_info.h | 2
-rw-r--r-- tensorflow/contrib/lite/interpreter.cc | 4
-rw-r--r-- tensorflow/contrib/lite/interpreter.h | 5
-rw-r--r-- tensorflow/contrib/lite/interpreter_test.cc | 2
-rw-r--r-- tensorflow/contrib/lite/java/ovic/BUILD | 3
-rw-r--r-- tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h | 8
-rw-r--r-- tensorflow/contrib/lite/java/src/main/native/tensor_jni.h | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/BUILD | 44
-rw-r--r-- tensorflow/contrib/lite/kernels/activation_functor.h | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/activations.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/add.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/arg_min_max.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/audio_spectrogram.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/basic_rnn.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/batch_to_space_nd.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/cast.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/comparisons.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/concatenation.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/conv.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/depthwise_conv.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/dequantize.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/detection_postprocess.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/div.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/eigen_support.h | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/elementwise.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/embedding_lookup.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/exp.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/expand_dims.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/expand_dims_test.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/fake_quant.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/floor.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/floor_div.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/fully_connected.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/gather.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/gather_test.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/gemm_support.h | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/hashtable_lookup.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/BUILD | 46
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/common.h | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/kernel_utils.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/kernel_utils.h | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 129
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h | 92
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/tensor.h | 111
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h | 135
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/tensor_utils.h | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/kernel_util.h | 5
-rw-r--r-- tensorflow/contrib/lite/kernels/l2norm.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/local_response_norm.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/logical.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/lsh_projection.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/lstm.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/maximum_minimum.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/mfcc.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/mul.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/neg.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/one_hot.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/pack.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/pad.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/padding.h | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/pooling.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/pow.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/reduce.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/register.h | 3
-rw-r--r-- tensorflow/contrib/lite/kernels/reshape.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/resize_bilinear.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/select.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/shape.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/skip_gram.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/slice.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/space_to_batch_nd.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/space_to_depth.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/sparse_to_dense.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/split.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/squeeze.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/strided_slice.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/sub.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/svdf.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/tile.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/tile_test.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/topk_v2.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/topk_v2_test.cc | 2
-rw-r--r-- tensorflow/contrib/lite/kernels/transpose.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/transpose_conv.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc | 4
-rw-r--r-- tensorflow/contrib/lite/kernels/unpack.cc | 4
-rw-r--r-- tensorflow/contrib/lite/memory_planner.h | 2
-rw-r--r-- tensorflow/contrib/lite/mmap_allocation.cc | 2
-rw-r--r-- tensorflow/contrib/lite/model.cc | 636
-rw-r--r-- tensorflow/contrib/lite/model.h | 5
-rw-r--r-- tensorflow/contrib/lite/model_test.cc | 2
-rw-r--r-- tensorflow/contrib/lite/mutable_op_resolver.cc (renamed from tensorflow/contrib/lite/op_resolver.cc) | 3
-rw-r--r-- tensorflow/contrib/lite/mutable_op_resolver.h | 79
-rw-r--r-- tensorflow/contrib/lite/mutable_op_resolver_test.cc (renamed from tensorflow/contrib/lite/op_resolver_test.cc) | 2
-rw-r--r-- tensorflow/contrib/lite/nnapi_delegate.cc | 4
-rw-r--r-- tensorflow/contrib/lite/nnapi_delegate.h | 4
-rw-r--r-- tensorflow/contrib/lite/op_resolver.h | 78
-rw-r--r-- tensorflow/contrib/lite/simple_memory_arena.h | 2
-rw-r--r-- tensorflow/contrib/lite/stderr_reporter.cc (renamed from tensorflow/contrib/lite/error_reporter.cc) | 22
-rw-r--r-- tensorflow/contrib/lite/stderr_reporter.h | 34
-rw-r--r-- tensorflow/contrib/lite/string_util.cc | 2
-rw-r--r-- tensorflow/contrib/lite/string_util.h | 2
-rw-r--r-- tensorflow/contrib/lite/string_util_test.cc | 2
-rw-r--r-- tensorflow/contrib/lite/testing/BUILD | 2
-rw-r--r-- tensorflow/contrib/lite/testing/util.h | 2
-rw-r--r-- tensorflow/contrib/lite/toco/BUILD | 1
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 12
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc | 19
-rw-r--r-- tensorflow/contrib/lite/toco/tooling_util.h | 5
-rw-r--r-- tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD | 1
-rw-r--r-- tensorflow/contrib/lite/tools/make/Makefile | 108
-rw-r--r-- tensorflow/contrib/lite/tools/optimize/quantize_weights.cc | 1
-rw-r--r-- tensorflow/contrib/lite/tutorials/BUILD | 20
-rw-r--r-- tensorflow/contrib/lite/tutorials/dataset.py | 122
-rw-r--r-- tensorflow/contrib/lite/tutorials/mnist_tflite.py | 87
-rw-r--r-- tensorflow/contrib/lite/util.h | 2
-rw-r--r-- tensorflow/contrib/lite/util_test.cc | 2
-rw-r--r-- tensorflow/contrib/makefile/proto_text_cc_files.txt | 1
-rw-r--r-- tensorflow/contrib/opt/BUILD | 2
-rw-r--r-- tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py | 91
-rw-r--r-- tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py | 240
-rw-r--r-- tensorflow/contrib/quantize/BUILD | 3
-rw-r--r-- tensorflow/contrib/quantize/python/common.py | 26
-rw-r--r-- tensorflow/contrib/quantize/python/common_test.py | 25
-rw-r--r-- tensorflow/contrib/quantize/python/fold_batch_norms.py | 25
-rw-r--r-- tensorflow/contrib/quantize/python/quantize.py | 5
-rw-r--r-- tensorflow/contrib/rnn/BUILD | 8
-rw-r--r-- tensorflow/contrib/rnn/python/ops/rnn_cell.py | 20
-rw-r--r-- tensorflow/contrib/tensor_forest/client/random_forest.py | 6
-rw-r--r-- tensorflow/contrib/tpu/__init__.py | 1
-rw-r--r-- tensorflow/contrib/tpu/python/tpu/keras_support.py | 10
-rw-r--r-- tensorflow/contrib/tpu/python/tpu/tpu.py | 15
-rw-r--r-- tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc | 2
-rw-r--r-- tensorflow/core/BUILD | 17
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt | 13
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt | 2
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt | 2
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt | 2
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt | 2
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt | 2
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt | 2
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt | 29
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt | 16
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt | 15
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt | 15
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt | 2
-rw-r--r-- tensorflow/core/common_runtime/function.cc | 40
-rw-r--r-- tensorflow/core/common_runtime/function_test.cc | 22
-rw-r--r-- tensorflow/core/common_runtime/tracing_device.h | 5
-rw-r--r-- tensorflow/core/debug/debug_io_utils.cc | 2
-rw-r--r-- tensorflow/core/framework/dataset_stateful_op_whitelist.h | 22
-rw-r--r-- tensorflow/core/grappler/costs/graph_properties.cc | 50
-rw-r--r-- tensorflow/core/grappler/costs/graph_properties_test.cc | 55
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 7
-rw-r--r-- tensorflow/core/grappler/optimizers/meta_optimizer.cc | 9
-rw-r--r-- tensorflow/core/grappler/optimizers/meta_optimizer.h | 3
-rw-r--r-- tensorflow/core/grappler/optimizers/meta_optimizer_test.cc | 46
-rw-r--r-- tensorflow/core/kernels/BUILD | 43
-rw-r--r-- tensorflow/core/kernels/conditional_accumulator.h | 6
-rw-r--r-- tensorflow/core/kernels/conditional_accumulator_base.cc | 13
-rw-r--r-- tensorflow/core/kernels/conditional_accumulator_base.h | 3
-rw-r--r-- tensorflow/core/kernels/conditional_accumulator_base_op.h | 3
-rw-r--r-- tensorflow/core/kernels/conditional_accumulator_op.cc | 3
-rw-r--r-- tensorflow/core/kernels/data/map_dataset_op.cc | 5
-rw-r--r-- tensorflow/core/kernels/data/map_defun_op.cc | 98
-rw-r--r-- tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 594
-rw-r--r-- tensorflow/core/kernels/data/parallel_map_dataset_op.cc | 43
-rw-r--r-- tensorflow/core/kernels/data/prefetch_dataset_op.cc | 1
-rw-r--r-- tensorflow/core/kernels/dynamic_stitch_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/eigen_benchmark_cpu_test.cc | 31
-rw-r--r-- tensorflow/core/kernels/lookup_table_op.cc | 26
-rw-r--r-- tensorflow/core/kernels/map_stage_op.cc | 10
-rw-r--r-- tensorflow/core/kernels/mkl_conv_ops.cc | 8
-rw-r--r-- tensorflow/core/kernels/mkl_pooling_ops_common.cc | 28
-rw-r--r-- tensorflow/core/kernels/mkl_relu_op.cc | 1
-rw-r--r-- tensorflow/core/kernels/mkl_softmax_op.cc | 1
-rw-r--r-- tensorflow/core/kernels/non_max_suppression_op.cc | 130
-rw-r--r-- tensorflow/core/kernels/partitioned_function_ops.cc | 49
-rw-r--r-- tensorflow/core/kernels/regex_full_match_op.cc | 33
-rw-r--r-- tensorflow/core/kernels/sparse_conditional_accumulator.h | 4
-rw-r--r-- tensorflow/core/kernels/sparse_conditional_accumulator_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/typed_conditional_accumulator_base.h | 5
-rw-r--r-- tensorflow/core/lib/core/stringpiece.cc | 54
-rw-r--r-- tensorflow/core/lib/core/stringpiece.h | 117
-rw-r--r-- tensorflow/core/lib/io/record_writer.h | 2
-rw-r--r-- tensorflow/core/lib/strings/strcat.h | 3
-rw-r--r-- tensorflow/core/ops/compat/ops_history.v1.pbtxt | 240
-rw-r--r-- tensorflow/core/ops/data_flow_ops.cc | 2
-rw-r--r-- tensorflow/core/ops/dataset_ops.cc | 20
-rw-r--r-- tensorflow/core/ops/image_ops.cc | 15
-rw-r--r-- tensorflow/core/ops/ops.pbtxt | 96
-rw-r--r-- tensorflow/core/ops/string_ops.cc | 6
-rw-r--r-- tensorflow/core/platform/default/device_tracer.cc | 5
-rw-r--r-- tensorflow/core/platform/tracing.h | 4
-rw-r--r-- tensorflow/go/op/wrappers.go | 1078
-rw-r--r-- tensorflow/python/BUILD | 1
-rw-r--r-- tensorflow/python/compat/compat.py | 2
-rw-r--r-- tensorflow/python/data/kernel_tests/BUILD | 3
-rw-r--r-- tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py | 167
-rw-r--r-- tensorflow/python/data/kernel_tests/map_dataset_op_test.py | 32
-rw-r--r-- tensorflow/python/data/ops/dataset_ops.py | 56
-rw-r--r-- tensorflow/python/data/util/nest.py | 33
-rw-r--r-- tensorflow/python/eager/backprop.py | 77
-rw-r--r-- tensorflow/python/eager/backprop_test.py | 66
-rw-r--r-- tensorflow/python/eager/benchmarks_test.py | 5
-rw-r--r-- tensorflow/python/eager/function.py | 2
-rw-r--r-- tensorflow/python/eager/function_test.py | 9
-rw-r--r-- tensorflow/python/eager/imperative_grad.py | 10
-rw-r--r-- tensorflow/python/eager/pywrap_tensor.cc | 45
-rwxr-xr-x tensorflow/python/eager/pywrap_tfe.h | 25
-rw-r--r-- tensorflow/python/eager/pywrap_tfe_src.cc | 69
-rw-r--r-- tensorflow/python/eager/tape.py | 18
-rw-r--r-- tensorflow/python/eager/tensor_test.py | 14
-rw-r--r-- tensorflow/python/estimator/BUILD | 1
-rw-r--r-- tensorflow/python/estimator/canned/boosted_trees.py | 12
-rw-r--r-- tensorflow/python/estimator/keras_test.py | 102
-rw-r--r-- tensorflow/python/feature_column/BUILD | 1
-rw-r--r-- tensorflow/python/feature_column/feature_column_v2.py | 16
-rw-r--r-- tensorflow/python/feature_column/feature_column_v2_test.py | 15
-rw-r--r-- tensorflow/python/framework/constant_op.py | 3
-rw-r--r-- tensorflow/python/framework/tensor_shape.py | 4
-rw-r--r-- tensorflow/python/framework/test_util.py | 38
-rwxr-xr-x tensorflow/python/keras/BUILD | 14
-rw-r--r-- tensorflow/python/keras/engine/feature_columns_integration_test.py | 237
-rw-r--r-- tensorflow/python/keras/engine/training.py | 148
-rw-r--r-- tensorflow/python/keras/engine/training_arrays.py | 5
-rw-r--r-- tensorflow/python/keras/engine/training_distributed.py | 57
-rw-r--r-- tensorflow/python/keras/engine/training_eager.py | 23
-rw-r--r-- tensorflow/python/keras/engine/training_test.py | 74
-rw-r--r-- tensorflow/python/keras/engine/training_utils.py | 141
-rw-r--r-- tensorflow/python/keras/engine/training_utils_test.py | 89
-rw-r--r-- tensorflow/python/keras/metrics.py | 47
-rw-r--r-- tensorflow/python/keras/metrics_test.py | 4
-rw-r--r-- tensorflow/python/keras/models.py | 10
-rw-r--r-- tensorflow/python/keras/models_test.py | 54
-rw-r--r-- tensorflow/python/kernel_tests/BUILD | 3
-rw-r--r-- tensorflow/python/kernel_tests/clip_ops_test.py | 9
-rw-r--r-- tensorflow/python/kernel_tests/conditional_accumulator_test.py | 88
-rw-r--r-- tensorflow/python/kernel_tests/dynamic_stitch_op_test.py | 21
-rw-r--r-- tensorflow/python/kernel_tests/functional_ops_test.py | 35
-rw-r--r-- tensorflow/python/kernel_tests/regex_full_match_op_test.py | 60
-rw-r--r-- tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py | 83
-rw-r--r-- tensorflow/python/lib/io/py_record_reader.cc | 2
-rw-r--r-- tensorflow/python/lib/io/py_record_writer.cc | 6
-rw-r--r-- tensorflow/python/lib/io/py_record_writer.h | 5
-rw-r--r-- tensorflow/python/lib/io/py_record_writer.i | 22
-rw-r--r-- tensorflow/python/lib/io/tf_record.py | 108
-rw-r--r-- tensorflow/python/lib/io/tf_record_test.py | 107
-rw-r--r-- tensorflow/python/ops/array_ops.py | 2
-rw-r--r-- tensorflow/python/ops/clip_ops.py | 6
-rw-r--r-- tensorflow/python/ops/data_flow_ops.py | 20
-rw-r--r-- tensorflow/python/ops/image_ops_impl.py | 2
-rw-r--r-- tensorflow/python/ops/image_ops_test.py | 41
-rw-r--r-- tensorflow/python/ops/math_ops.py | 48
-rw-r--r-- tensorflow/python/ops/resource_variable_ops.py | 11
-rw-r--r-- tensorflow/python/ops/rnn_cell_impl.py | 14
-rw-r--r-- tensorflow/python/ops/string_ops.py | 34
-rwxr-xr-x tensorflow/python/pywrap_tfe.i | 4
-rw-r--r-- tensorflow/python/saved_model/BUILD | 1
-rw-r--r-- tensorflow/python/saved_model/signature_def_utils_impl.py | 79
-rw-r--r-- tensorflow/python/saved_model/signature_def_utils_test.py | 38
-rw-r--r-- tensorflow/python/tools/BUILD | 1
-rw-r--r-- tensorflow/python/tools/api/generator/api_gen.bzl | 34
-rw-r--r-- tensorflow/python/tools/saved_model_cli.py | 2
-rw-r--r-- tensorflow/python/training/basic_session_run_hooks.py | 6
-rw-r--r-- tensorflow/python/training/basic_session_run_hooks_test.py | 37
-rw-r--r-- tensorflow/python/training/checkpointable/base.py | 66
-rw-r--r-- tensorflow/python/training/checkpointable/util.py | 192
-rw-r--r-- tensorflow/python/training/checkpointable/util_test.py | 40
-rw-r--r-- tensorflow/python/util/util.i | 27
-rw-r--r-- tensorflow/stream_executor/cuda/cuda_dnn.cc | 20
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt | 2
-rw-r--r-- tensorflow/tools/api/tests/BUILD | 5
-rw-r--r-- tensorflow/tools/api/tests/api_compatibility_test.py | 14
-rw-r--r-- tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu | 43
-rwxr-xr-x tensorflow/tools/ci_build/ci_parameterized_build.sh | 14
-rwxr-xr-x tensorflow/tools/ci_build/install/install_pip_packages.sh | 2
-rwxr-xr-x tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh | 1
-rwxr-xr-x tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh | 1
-rw-r--r-- tensorflow/tools/ci_build/windows/bazel/common_env.sh | 2
-rw-r--r-- tensorflow/tools/dockerfiles/README.md | 4
-rw-r--r-- tensorflow/tools/docs/generate_lib.py | 29
-rw-r--r-- tensorflow/tools/docs/parser.py | 90
-rw-r--r-- tensorflow/tools/docs/parser_test.py | 86
-rw-r--r-- tensorflow/tools/docs/pretty_docs.py | 3
-rw-r--r-- tensorflow/tools/pip_package/setup.py | 20
-rwxr-xr-x tensorflow/workspace.bzl | 16
-rw-r--r-- third_party/gpus/cuda/remote.BUILD.tpl | 5
-rw-r--r-- third_party/llvm/llvm.autogenerated.BUILD | 11
-rw-r--r-- third_party/nccl/BUILD | 0
-rw-r--r-- third_party/nccl/nccl_configure.bzl | 35
-rw-r--r-- third_party/nccl/remote.BUILD.tpl | 6
-rw-r--r-- third_party/nccl/system.BUILD.tpl | 26
474 files changed, 10937 insertions, 4827 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 661cba5ff0..386e0096ff 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -12,6 +12,7 @@ exports_files([
# The leakr files are used by //third_party/cloud_tpu.
"leakr_badwords.dic",
"leakr_badfiles.dic",
+ "leakr_file_type_recipe.ftrcp",
])
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
@@ -23,6 +24,11 @@ load(
"//tensorflow/python/tools/api/generator:api_gen.bzl",
"gen_api_init_files", # @unused
)
+load("//tensorflow/python/tools/api/generator:api_gen.bzl", "get_compat_files")
+load(
+ "//tensorflow/python/tools/api/generator:api_init_files.bzl",
+ "TENSORFLOW_API_INIT_FILES", # @unused
+)
load(
"//tensorflow/python/tools/api/generator:api_init_files_v1.bzl",
"TENSORFLOW_API_INIT_FILES_V1", # @unused
@@ -32,6 +38,11 @@ load(
"if_ngraph",
)
+# @unused
+TENSORFLOW_API_INIT_FILES_V2 = (
+ TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1)
+)
+
# Config setting used when building for products
# which requires restricted licenses to be avoided.
config_setting(
@@ -427,6 +438,13 @@ config_setting(
visibility = ["//visibility:public"],
)
+# This flag specifies whether the TensorFlow 2.0 API should be built instead
+# of the 1.* API. Note that the TensorFlow 2.0 API is currently under
+# development.
+config_setting(
+ name = "api_version_2",
+ define_values = {"tf_api_version": "2"},
+)
+
package_group(
name = "internal",
packages = [
@@ -591,13 +609,39 @@ exports_files(
)
gen_api_init_files(
- name = "tensorflow_python_api_gen",
+ name = "tf_python_api_gen_v1",
srcs = ["api_template.__init__.py"],
api_version = 1,
+ output_dir = "_api/v1/",
output_files = TENSORFLOW_API_INIT_FILES_V1,
+ output_package = "tensorflow._api.v1",
+ root_init_template = "api_template.__init__.py",
+)
+
+gen_api_init_files(
+ name = "tf_python_api_gen_v2",
+ srcs = ["api_template.__init__.py"],
+ api_version = 2,
+ compat_api_versions = [1],
+ output_dir = "_api/v2/",
+ output_files = TENSORFLOW_API_INIT_FILES_V2,
+ output_package = "tensorflow._api.v2",
root_init_template = "api_template.__init__.py",
)
+genrule(
+ name = "root_init_gen",
+ srcs = select({
+ "api_version_2": [":tf_python_api_gen_v2"],
+ "//conditions:default": [":tf_python_api_gen_v1"],
+ }),
+ outs = ["__init__.py"],
+ cmd = select({
+ "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
+ "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
+ }),
+)
+
py_library(
name = "tensorflow_py",
srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"],
@@ -612,7 +656,10 @@ py_library(
py_library(
name = "tensorflow_py_no_contrib",
- srcs = [":tensorflow_python_api_gen"],
+ srcs = select({
+ "api_version_2": [":tf_python_api_gen_v2"],
+ "//conditions:default": [":tf_python_api_gen_v1"],
+ }) + [":root_init_gen"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = ["//tensorflow/python:no_contrib"],
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 779f65d5b1..53a72b8443 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -18,11 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os as _os
+
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
try:
- import os # pylint: disable=g-import-not-at-top
# Add `estimator` attribute to allow access to estimator APIs via
# "tf.estimator..."
from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top
@@ -30,9 +31,8 @@ try:
# Add `estimator` to the __path__ to allow "from tensorflow.estimator..."
# style imports.
from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top
- __path__ += [os.path.dirname(estimator_api.__file__)]
+ __path__ += [_os.path.dirname(estimator_api.__file__)]
del estimator_api
- del os
except (ImportError, AttributeError):
print('tf.estimator package not installed.')
@@ -45,6 +45,12 @@ del LazyLoader
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable
+# Make sure the directory containing top-level submodules is in
+# the __path__ so that "from tensorflow.foo import bar" works.
+_tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disable=undefined-variable
+if _tf_api_dir not in __path__:
+ __path__.append(_tf_api_dir)
+
del absolute_import
del division
del print_function
@@ -54,6 +60,12 @@ del print_function
# must come from this module. So python adds these symbols for the
# resolution to succeed.
# pylint: disable=undefined-variable
-del python
-del core
+try:
+ del python
+ del core
+except NameError:
+ # Don't fail if these modules are not available.
+ # For example, this file is also used for the compat.v1 module, and the
+ # 'python' and 'core' directories are not under compat/v1.
+ pass
# pylint: enable=undefined-variable
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 109b3b37aa..43c279bd80 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -204,6 +204,7 @@ tf_cuda_cc_test(
"//tensorflow:darwin": ["-headerpad_max_install_names"],
"//conditions:default": [],
}),
+ tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 69b3ffe2a1..c046bd66cd 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -79,6 +79,18 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
auto* gpu_options = config.mutable_gpu_options();
gpu_options->set_allow_growth(gpu_memory_allow_growth);
+ // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
+ // threadpool, so that we avoid the possibility of running the runner_ in the
+ // threadpool of GPU event mgr, as that can trigger more callbacks to be
+ // scheduled on that same threadpool, causing a deadlock in cases where the
+ // caller of event_mgr->ThenExecute() blocks on the completion of the callback
+ // (as in the case of ConstOp kernel creation on GPU, which involves copying a
+ // CPU tensor to GPU).
+ // Setting a larger thread pool does not help with the Swift caller, as we use
+ // a different TFE context for each thread of execution (for running graph
+ // functions, and their send/recv coroutines).
+ config.set_inter_op_parallelism_threads(1);
+
TF_Buffer* ret = TF_NewBuffer();
TF_CHECK_OK(MessageToBuffer(config, ret));
return ret;
@@ -8494,3 +8506,201 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
/*run_metadata*/ nullptr, status);
VLOG(1) << "Enqueuing is done.";
}
+
+TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
+ TF_Status* status) {
+ auto* opts = TFE_NewContextOptions();
+
+ // Reduce GPU memory allocation, and set appropriate config options for TFE
+ // context.
+ auto* config =
+ TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true);
+ TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
+ if (!status->status.ok()) {
+ CHECK(!config);
+ TFE_DeleteContextOptions(opts);
+ return nullptr;
+ }
+
+ auto* ctx = TFE_NewContextFromSession(opts, session, status);
+ TF_DeleteBuffer(config);
+ TFE_DeleteContextOptions(opts);
+ return ctx;
+}
+
+// TODO: retrieve the device string via TFE_ContextListDevices()
+static const char DEFAULT_CPU_DEVICE[] =
+ "/job:localhost/replica:0/task:0/device:CPU:0";
+
+static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
+ int tensor_id, TF_Status* status) {
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
+ TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
+ TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
+ if (!status->status.ok()) return nullptr;
+ // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
+ TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
+ TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
+ auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
+ TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
+ shared_name.size());
+ TFE_OpSetAttrString(queueOp.get(), "container", "", 0);
+
+ // TODO: consider making this an unknown shape.
+ const int64_t* dims_ptr = nullptr;
+ int num_dims = 0;
+ TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
+ /*num_values*/ 0, status);
+ if (!status->status.ok()) return nullptr;
+
+ int num_retvals = 1;
+ TFE_TensorHandle* queue = nullptr;
+ TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
+ if (!status->status.ok()) return nullptr;
+ CHECK_EQ(num_retvals, 1);
+
+ return queue;
+}
+
+static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
+ TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
+ TF_Status* status) {
+ TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
+ TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
+ if (!status->status.ok()) return;
+ TFE_OpAddInput(op, queue, status);
+ if (!status->status.ok()) return;
+ TFE_OpAddInput(op, tensor, status);
+ if (!status->status.ok()) return;
+ TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
+ TFE_OpSetAttrInt(op, "timeout_ms", -1);
+
+ int num_retvals = 0;
+ TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
+ if (!status->status.ok()) return;
+ CHECK_EQ(num_retvals, 0);
+}
+
+static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
+ TF_DataType inputType,
+ TFE_TensorHandle* queue,
+ TF_Status* status) {
+ TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
+ TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
+ if (!status->status.ok()) return nullptr;
+
+ TFE_OpAddInput(op, queue, status);
+ if (!status->status.ok()) return nullptr;
+ TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
+ TFE_OpSetAttrInt(op, "timeout_ms", -1);
+ TFE_TensorHandle* ret;
+ int num_retvals = 1;
+ TFE_Execute(op, &ret, &num_retvals, status);
+ if (!status->status.ok()) return nullptr;
+ CHECK_EQ(num_retvals, 1);
+ return ret;
+}
+
+TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
+ TF_DataType inputType,
+ TF_Status* status) {
+ assert(session);
+ VLOG(1) << "Dequeuing data tensor with id " << tensor_id;
+
+ auto ctx = TFE_CreateContextFromSession(session, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+ ctx, TFE_DeleteContext);
+
+ TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ auto* ret = createTFEDequeue(ctx, inputType, queue, status);
+ return ret;
+}
+
+TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
+ TF_DataType inputType,
+ TF_Status* status) {
+ TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ auto* ret = createTFEDequeue(ctx, inputType, queue, status);
+
+ return ret;
+}
+
+void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
+ TFE_TensorHandle* tensor, TF_Status* status) {
+ assert(session);
+ VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
+
+ auto ctx = TFE_CreateContextFromSession(session, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+ ctx, TFE_DeleteContext);
+
+ TF_DataType inputType = TFE_TensorHandleDataType(tensor);
+ TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ createTFEEnqueue(ctx, inputType, queue, tensor, status);
+}
+
+void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
+ TFE_TensorHandle* tensor,
+ TF_Status* status) {
+ VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
+
+ TF_DataType inputType = TFE_TensorHandleDataType(tensor);
+ TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ createTFEEnqueue(ctx, inputType, queue, tensor, status);
+}
+
+void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
+ TFE_TensorHandle* tensor, TF_Status* status) {
+ VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;
+
+ auto ctx = TFE_CreateContextFromSession(session, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+ ctx, TFE_DeleteContext);
+
+ TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
+ if (!status->status.ok()) return;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
+}
+
+TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
+ TF_Status* status) {
+ VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;
+
+ auto ctx = TFE_CreateContextFromSession(session, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
+ ctx, TFE_DeleteContext);
+
+ TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
+ if (!status->status.ok()) return nullptr;
+ std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
+ queue_deleter(queue, TFE_DeleteTensorHandle);
+
+ return createTFEDequeue(ctx, TF_VARIANT, queue, status);
+}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 09d482d6df..522c91f67e 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -132,9 +132,48 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
TF_Tensor* tensor,
TF_Status* status);
+// TODO: remove this API in favor of the next one.
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
+// Creates from `session` a new eager context to run a graph function or
+// send/recvs, so that these concurrent TFE executions can share (via
+// `session` and its associated device mgr) the same set of fifo queue resource
+// ops, used for host<->TF tensor transfers. This way the send/recv calls and
+// graph function execution can access the same fifo queue resource handles
+// (associated with devices managed by the device manager, which can be obtained
+// from `session`).
+//
+// TODO: Remove this function once we migrate away from using session.
+TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession(
+ TF_Session* session, TF_Status* status);
+
+// TODO: Retire this API in favor of the next one.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor(
+ TF_Session* session, int tensor_id, TF_DataType inputType,
+ TF_Status* status);
+
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(
+ TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status);
+
+TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session,
+ int tensor_id,
+ TFE_TensorHandle* tensor,
+ TF_Status* status);
+
+TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx(
+ TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor,
+ TF_Status* status);
+
+// TODO: consider folding the 2 APIs below into the ones above.
+TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
+ int tensor_id,
+ TFE_TensorHandle* tensor,
+ TF_Status* status);
+
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
+ TF_Session* session, int tensor_id, TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
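
The declarations above pair up: one side of a host<->TF transfer enqueues, the other dequeues, both keyed by the same `tensor_id`. A minimal sketch of that round trip, assuming a valid TF_Session and an existing TFE_TensorHandle (the `RoundTrip` wrapper and the tensor id 0 are illustrative, not part of this patch):

#include "tensorflow/c/c_api_experimental.h"

void RoundTrip(TF_Session* session, TFE_TensorHandle* input) {
  TF_Status* status = TF_NewStatus();
  // Producer side: enqueue `input` on the FIFO queue shared under id 0. Per
  // the implementation above, this creates an eager context from `session`.
  TFE_EnqueueNamedTensor(session, /*tensor_id=*/0, input, status);
  if (TF_GetCode(status) == TF_OK) {
    // Consumer side: dequeue the tensor with the same id; QueueDequeueV2 runs
    // with timeout_ms = -1, so this blocks until a tensor is available.
    TFE_TensorHandle* out = TFE_DequeueNamedTensor(
        session, /*tensor_id=*/0, TFE_TensorHandleDataType(input), status);
    if (TF_GetCode(status) == TF_OK) TFE_DeleteTensorHandle(out);
  }
  TF_DeleteStatus(status);
}

In practice the two sides run on different threads of execution (e.g. a graph function and a Swift host program), which is why both must reach the same session-owned queue resources rather than create private ones.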
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 77e3878a94..349d9bcd7c 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -399,6 +399,19 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
: d->name().c_str();
}
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
+ TFE_TensorHandle* h, TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return nullptr;
+ }
+
+ h->handle->Ref();
+
+ return new TFE_TensorHandle(h->handle);
+}
+
TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
status->status = tensorflow::errors::InvalidArgument(
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index eec2750d6e..337447eec9 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -171,6 +171,12 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
TFE_TensorHandle* h, TF_Status* status);
+// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor
+// with `h`. On success, `status` is set to OK. On failure, `status` reflects
+// the error and a nullptr is returned.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
+ TFE_TensorHandle* h, TF_Status* status);
+
// This function will block till the operation that produces `h` has
// completed. The memory returned might alias the internal memory used by
// TensorFlow. Hence, callers should not mutate this memory (for example by
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 7126227cf5..55331022b9 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -1528,4 +1528,29 @@ TEST(CAPI, StringAttributes) {
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
}
+
+TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) {
+ TFE_TensorHandle* h = TestMatrixTensorHandle();
+ EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
+
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+
+ TFE_TensorHandle* h_shares_tensor =
+ TFE_TensorHandleCopySharingTensor(h, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ TF_Tensor* t = TFE_TensorHandleResolve(h_shares_tensor, status.get());
+ ASSERT_EQ(16, TF_TensorByteSize(t));
+ float data[4] = {0};
+ memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t));
+ EXPECT_EQ(1.0, data[0]);
+ EXPECT_EQ(2.0, data[1]);
+ EXPECT_EQ(3.0, data[2]);
+ EXPECT_EQ(4.0, data[3]);
+ TF_DeleteTensor(t);
+
+ TFE_DeleteTensorHandle(h);
+ TFE_DeleteTensorHandle(h_shares_tensor);
+}
} // namespace
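
The test above resolves the sharing handle while the original is still alive. Because TFE_TensorHandleCopySharingTensor takes its own reference on the underlying handle (the h->handle->Ref() call in the implementation), the copy should stay usable even after the original handle is deleted. A short sketch of that property, under the same assumptions as the test (TestMatrixTensorHandle is the existing test helper):

  TFE_TensorHandle* h = TestMatrixTensorHandle();
  TF_Status* status = TF_NewStatus();
  TFE_TensorHandle* copy = TFE_TensorHandleCopySharingTensor(h, status);
  CHECK_EQ(TF_OK, TF_GetCode(status));
  // Deleting the original drops only one reference; `copy` still holds one,
  // so the underlying tensor is kept alive and remains resolvable.
  TFE_DeleteTensorHandle(h);
  TF_Tensor* t = TFE_TensorHandleResolve(copy, status);
  CHECK_EQ(TF_OK, TF_GetCode(status));
  TF_DeleteTensor(t);
  TFE_DeleteTensorHandle(copy);
  TF_DeleteStatus(status);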
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h
index cf5c04ac4b..bd270045e3 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.h
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h
@@ -20,6 +20,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
#define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
+#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/protobuf.h"
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index b95b063348..1c9d30d7b0 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
@@ -92,9 +93,8 @@ Status Main(const MainFlags& flags) {
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
- TF_RETURN_IF_ERROR(
- WriteStringToFile(env, flags.out_function_object,
- absl::string_view(obj.data(), obj.size())));
+ TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object,
+ StringPiece(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD
index 5b6692f523..07c5b23188 100644
--- a/tensorflow/compiler/jit/legacy_flags/BUILD
+++ b/tensorflow/compiler/jit/legacy_flags/BUILD
@@ -29,18 +29,6 @@ cc_library(
)
cc_library(
- name = "parallel_check_op_flags",
- srcs = ["parallel_check_op_flags.cc"],
- hdrs = ["parallel_check_op_flags.h"],
- deps =
- [
- "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
name = "xla_device_flags",
srcs = ["xla_device_flags.cc"],
hdrs = ["xla_device_flags.h"],
diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc
deleted file mode 100644
index a61694b494..0000000000
--- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc
+++ /dev/null
@@ -1,68 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Legacy flags for the XLA bridge's parallel_check_op module.
-
-#include <mutex>
-#include <vector>
-
-#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Pointers to the parsed value of the flags and flag descriptors, initialized
-// via flags_init.
-static ParallelCheckOpFlags* flags;
-static std::vector<Flag>* flag_list;
-static std::once_flag flags_init;
-
-// Allocate *flags. Called via call_once(&flags_init,...).
-static void AllocateFlags() {
- flags = new ParallelCheckOpFlags;
- flags->parallel_check_failfast = true;
- flags->parallel_check_atol = "1e-5";
- flags->parallel_check_rtol = "1e-5";
- flag_list = new std::vector<Flag>({
- Flag("parallel_check_failfast", &flags->parallel_check_failfast,
- "Fail immediately on first parallel-check comparison error."),
- Flag("parallel_check_atol", &flags->parallel_check_atol,
- "Absolute error tolerance for parallel-check comparison."),
- Flag("parallel_check_rtol", &flags->parallel_check_rtol,
- "Relative error tolerance for parallel-check comparison."),
- });
- xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
-}
-
-// Append to *append_to flag definitions associated with the XLA bridge's
-// parallel_check_op module.
-void AppendParallelCheckOpFlags(std::vector<Flag>* append_to) {
- std::call_once(flags_init, &AllocateFlags);
- append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
-}
-
-// Return a pointer to the ParallelCheckOpFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-ParallelCheckOpFlags* GetParallelCheckOpFlags() {
- std::call_once(flags_init, &AllocateFlags);
- return flags;
-}
-
-} // namespace legacy_flags
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h
deleted file mode 100644
index 156a2a2a71..0000000000
--- a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
-#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
-
-// Legacy flags for the XLA bridge's parallel_check_op module.
-
-#include <vector>
-
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace tensorflow {
-namespace legacy_flags {
-
-// Append to *flag_list flag definitions associated with the XLA bridge's
-// parallel_check_op module.
-void AppendParallelCheckOpFlags(std::vector<tensorflow::Flag>* flag_list);
-
-// The values of flags associated with the XLA bridge's
-// parallel_check_op module.
-typedef struct {
- bool parallel_check_failfast; // Fail immediately on first parallel-check
- // comparison error.
- string parallel_check_atol; // Absolute error tolerance for parallel-check
- // comparison.
- string parallel_check_rtol; // Relative error tolerance for parallel-check
- // comparison.
-} ParallelCheckOpFlags;
-
-// Return a pointer to the ParallelCheckOpFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-ParallelCheckOpFlags* GetParallelCheckOpFlags();
-
-} // namespace legacy_flags
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 9473ac0a4c..807ab51fd3 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -633,7 +633,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
Scope root = Scope::NewRootScope().ExitOnError();
{
- auto BuildNoopNode = [](absl::string_view name, Graph* graph) {
+ auto BuildNoopNode = [](StringPiece name, Graph* graph) {
NodeDefBuilder builder(name, "NoOp");
NodeDef def;
TF_CHECK_OK(builder.Finalize(&def));
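
Several hunks in this merge step back from absl::string_view to StringPiece (see the TODO in xla_device_context.h below), converting explicitly wherever the two types meet. A minimal sketch of that conversion pattern, using a hypothetical stand-in type rather than TF's real StringPiece:

#include <cstddef>
#include "absl/strings/string_view.h"

// Hypothetical minimal stand-in for tensorflow::StringPiece.
struct StringPiece {
  const char* ptr;
  size_t len;
  const char* data() const { return ptr; }
  size_t size() const { return len; }
};

// Explicit, allocation-free round trips between the two view types, as done
// in the hunks below.
absl::string_view ToAbsl(StringPiece sp) {
  return absl::string_view(sp.data(), sp.size());
}
StringPiece FromAbsl(absl::string_view sv) {
  return StringPiece{sv.data(), sv.size()};
}
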
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index 17ae510a0e..debd9038c7 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
+#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/core/graph/algorithm.h"
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index af83c792e5..6d4160a968 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -339,11 +339,11 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
}
void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
- absl::string_view tensor_name,
+ StringPiece tensor_name,
Device* device, Tensor* cpu_tensor,
StatusCallback done) {
- manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor,
- done);
+ manager_.CopyDeviceTensorToCPU(device_tensor, absl::string_view(tensor_name),
+ device, cpu_tensor, done);
}
void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor,
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index df82421294..1effd6628f 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
namespace tensorflow {
@@ -110,9 +111,12 @@ class XlaDeviceContext : public DeviceContext {
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,
StatusCallback done) const override;
+ // TODO(rlahaye): Replace StringPiece with absl::string_view when the
+ // StringPiece->absl::string_view change is rolled forward.
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
- absl::string_view tensor_name, Device* device,
- Tensor* cpu_tensor, StatusCallback done) override;
+ StringPiece tensor_name, // non-ABSL OK
+ Device* device, Tensor* cpu_tensor,
+ StatusCallback done) override;
void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
const StatusCallback& done);
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 22be7f048f..3821dced63 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -191,6 +191,7 @@ cc_library(
":functionalize_control_flow",
":host_compute_metadata_proto",
":sharding_util",
+ ":side_effect_util",
":tf2xla_util",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal",
@@ -214,6 +215,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
alwayslink = 1,
@@ -359,6 +361,7 @@ tf_cc_test(
name = "xla_compiler_test",
srcs = ["xla_compiler_test.cc"],
deps = [
+ ":side_effect_util",
":xla_compiler",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
@@ -370,6 +373,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:core_cpu_internal",
@@ -631,3 +635,12 @@ tf_cc_test(
"@com_google_absl//absl/strings",
],
)
+
+cc_library(
+ name = "side_effect_util",
+ srcs = ["side_effect_util.cc"],
+ hdrs = ["side_effect_util.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ ],
+)
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 4c776fb178..46794f7b50 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -115,9 +115,6 @@ tf_kernel_library(
deps = [
":if_op",
":while_op",
- "@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/types:span",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
@@ -168,14 +165,11 @@ tf_kernel_library(
"//tensorflow/core/kernels:sparse_to_dense_op",
"//tensorflow/core/kernels:stack_ops",
"//tensorflow/core/kernels:training_ops",
- ] + if_mkl(
- [
- "//tensorflow/core/kernels:mkl_transpose_op",
- ],
- [
- "//tensorflow/core/kernels:transpose_op",
- ],
- ),
+ "//tensorflow/core/kernels:transpose_op",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
)
tf_kernel_library(
@@ -184,6 +178,7 @@ tf_kernel_library(
hdrs = ["while_op.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",
@@ -201,6 +196,7 @@ tf_kernel_library(
hdrs = ["if_op.h"],
deps = [
"//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:side_effect_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:literal",
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index 6e1dbf5472..56da50f140 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/if_op.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -33,6 +34,11 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcond", &cond_type_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
+ if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
+ has_token_input_output_ = false;
+ } else {
+ has_token_input_output_ = !token_input_nodes_.empty();
+ }
}
// TODO(b/35949885): There is duplication here with the handling of the
@@ -90,6 +96,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
options.resolve_compile_time_constants = false;
options.return_updated_values_for_all_resources = true;
options.is_entry_computation = false;
+ options.add_token_input_output = has_token_input_output_;
XlaCompiler* compiler = ctx->compiler();
XlaCompiler::CompilationResult then_result;
@@ -191,7 +198,16 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = then_result.input_mapping[i] + 1;
- if (ctx->input_type(input_num) == DT_RESOURCE) {
+ if (has_token_input_output_ && i == num_inputs - 1) {
+ // Set token input for this "if" op.
+ std::vector<xla::XlaOp> token_inputs;
+ for (const string& node_name : token_input_nodes_) {
+ auto token_or = compiler->GetNodeToken(node_name);
+ OP_REQUIRES_OK(ctx, token_or.status());
+ token_inputs.push_back(token_or.ValueOrDie());
+ }
+ inputs[i] = xla::AfterAll(b, token_inputs);
+ } else if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
@@ -219,6 +235,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
}
ctx->SetOutput(i, output_handle);
}
+ if (has_token_input_output_) {
+ // Set token output for this "if" op.
+ xla::XlaOp token_output =
+ xla::GetTupleElement(outputs, output_types_.size());
+ auto shape_or = b->GetShape(token_output);
+ OP_REQUIRES_OK(ctx, shape_or.status());
+ OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
+ errors::FailedPrecondition(
+ "Token output is not token type: ",
+ xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
+ OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
+ }
// Updates the values of any resource variables modified by the conditional
// bodies.
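
The token plumbing above gathers one XlaOp per entry in _xla_token_input_nodes and joins them with xla::AfterAll before threading the result into the conditional. A condensed sketch of that join, assuming only the public XlaBuilder API:

#include <vector>
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Joins N token-shaped ops into a single token that orders all of them.
xla::XlaOp MergeTokens(xla::XlaBuilder* b,
                       const std::vector<xla::XlaOp>& tokens) {
  return xla::AfterAll(b, tokens);
}
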
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.h b/tensorflow/compiler/tf2xla/kernels/if_op.h
index f9bc98a198..7783e13a8a 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.h
@@ -52,6 +52,8 @@ class XlaIfOp : public XlaOpKernel {
DataType cond_type_;
DataTypeVector input_types_;
DataTypeVector output_types_;
+ bool has_token_input_output_;
+ std::vector<string> token_input_nodes_;
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 296518229e..559414eeaa 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/while_op.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
@@ -90,6 +91,11 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
cond_name_attr_ = *name_attr;
OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr));
body_name_attr_ = *name_attr;
+ if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
+ has_token_input_output_ = false;
+ } else {
+ has_token_input_output_ = !token_input_nodes_.empty();
+ }
}
void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
@@ -120,6 +126,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
body_options.return_updated_values_for_all_resources = true;
body_options.resolve_compile_time_constants = false;
body_options.is_entry_computation = false;
+ body_options.add_token_input_output = has_token_input_output_;
XlaCompiler::CompilationResult body;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
arguments, &body));
@@ -192,6 +199,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
cond_options.use_tuple_arg = true;
cond_options.resolve_compile_time_constants = false;
cond_options.is_entry_computation = false;
+ cond_options.add_token_input_output = has_token_input_output_;
XlaCompiler::CompilationResult cond;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
arguments, &cond));
@@ -238,7 +246,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
std::vector<xla::XlaOp> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = body.input_mapping[i];
- if (ctx->input_type(input_num) == DT_RESOURCE) {
+ if (has_token_input_output_ && i == num_inputs - 1) {
+ // Set token input for this "while" op.
+ std::vector<xla::XlaOp> token_inputs;
+ for (const string& node_name : token_input_nodes_) {
+ auto token_or = compiler->GetNodeToken(node_name);
+ OP_REQUIRES_OK(ctx, token_or.status());
+ token_inputs.push_back(token_or.ValueOrDie());
+ }
+ inputs[i] = xla::AfterAll(builder, token_inputs);
+ } else if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder));
@@ -273,6 +290,18 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
xla::GetTupleElement(while_result, i));
}
}
+ if (has_token_input_output_) {
+ // Set token output for this "while" op.
+ xla::XlaOp token_output =
+ xla::GetTupleElement(while_result, ctx->num_outputs());
+ auto shape_or = builder->GetShape(token_output);
+ OP_REQUIRES_OK(ctx, shape_or.status());
+ OP_REQUIRES(ctx, xla::ShapeUtil::IsToken(shape_or.ValueOrDie()),
+ errors::FailedPrecondition(
+ "Token output is not token type: ",
+ xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
+ OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
+ }
// Updates the values of any resource variables modified by the loop.
for (int i = 0; i < body.resource_updates.size(); ++i) {
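
Both the "if" and "while" kernels validate the extra tuple element the same way; a sketch of that check, assuming the public builder API (error handling via StatusOr rather than OP_REQUIRES):

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

// Returns true iff `op` is token-shaped, mirroring the OP_REQUIRES checks.
xla::StatusOr<bool> IsTokenShaped(xla::XlaBuilder* b, xla::XlaOp op) {
  auto shape_or = b->GetShape(op);
  if (!shape_or.ok()) return shape_or.status();
  return xla::ShapeUtil::IsToken(shape_or.ValueOrDie());
}
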
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h
index 67edebabf9..aeeff40e68 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.h
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.h
@@ -56,6 +56,8 @@ class XlaWhileOp : public XlaOpKernel {
private:
NameAttrList cond_name_attr_;
NameAttrList body_name_attr_;
+ bool has_token_input_output_;
+ std::vector<string> token_input_nodes_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp);
};
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc
index 20f2ce2919..92577b5bc8 100644
--- a/tensorflow/compiler/tf2xla/resource_operation_table.cc
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "absl/algorithm/container.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
namespace tensorflow {
@@ -30,11 +31,10 @@ namespace tensorflow {
}
}
-static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>*
-CreateResourceOpInfoMap() {
- auto* result = new gtl::FlatMap<absl::string_view, XlaResourceOpInfo>;
+static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* CreateResourceOpInfoMap() {
+ auto* result = new gtl::FlatMap<StringPiece, XlaResourceOpInfo>;
- auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
+ auto add = [&](StringPiece op, XlaResourceOpKind op_kind,
XlaResourceKind resource_kind) {
auto insert_result =
result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)});
@@ -103,17 +103,17 @@ CreateResourceOpInfoMap() {
return result;
}
-static const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>&
+static const gtl::FlatMap<StringPiece, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap() {
- static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* op_info_map =
+ static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* op_info_map =
CreateResourceOpInfoMap();
return *op_info_map;
}
const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
- const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& op_infos =
+ const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& op_infos =
GetStaticResourceOpInfoMap();
- auto it = op_infos.find(op);
+ auto it = op_infos.find(StringPiece(op.data(), op.length()));
return it == op_infos.end() ? nullptr : &it->second;
}
@@ -121,7 +121,7 @@ namespace resource_op_table_internal {
std::vector<absl::string_view> GetKnownResourceOps() {
std::vector<absl::string_view> result;
for (const auto& p : GetStaticResourceOpInfoMap()) {
- result.push_back(p.first);
+ result.push_back(absl::string_view(p.first));
}
absl::c_sort(result);
return result;
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc
new file mode 100644
index 0000000000..6cd7b24592
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/side_effect_util.cc
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
+
+#include "tensorflow/core/graph/algorithm.h"
+
+namespace tensorflow {
+
+const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes";
+
+const char kXlaTokenArgNodeName[] = "_xla_token_arg_node";
+
+std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g) {
+ std::set<std::string> results;
+ Node* first_side_effecting_node_on_path = nullptr;
+ ReverseDFS(g,
+ [&](Node* n) {
+ std::vector<string> token_input_nodes;
+ if (!GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName,
+ &token_input_nodes)
+ .ok() ||
+ token_input_nodes.empty()) {
+ return;
+ }
+
+ if (first_side_effecting_node_on_path != nullptr) {
+ return;
+ }
+
+ first_side_effecting_node_on_path = n;
+ results.insert(n->name());
+ },
+ [&](Node* n) {
+ if (first_side_effecting_node_on_path == n) {
+ first_side_effecting_node_on_path = nullptr;
+ }
+ },
+ NodeComparatorName());
+ return results;
+}
+
+bool HasSideEffectingNodes(const Graph& g) {
+ for (Node* n : g.nodes()) {
+ std::vector<string> token_input_nodes;
+ if (GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes)
+ .ok() &&
+ !token_input_nodes.empty()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace tensorflow
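
CalculateTokenInputsForOutputToken() uses ReverseDFS's enter/leave callbacks so that, walking backwards from the outputs, only the first side-effecting node on each path is collected; side-effecting nodes shadowed by a later one are skipped because the later node's token already orders them. A toy illustration of the same bookkeeping on a hypothetical minimal graph type (not TF's):

#include <set>
#include <string>
#include <vector>

struct Node {
  std::string name;
  bool marked = false;  // stands in for "has a non-empty token-input attr"
  std::vector<Node*> children;
};

// Collects marked nodes with no marked node earlier on the current DFS path.
void Visit(Node* n, Node** first_marked_on_path, std::set<std::string>* out) {
  bool entered_here = false;
  if (n->marked && *first_marked_on_path == nullptr) {
    *first_marked_on_path = n;  // the "enter" callback's bookkeeping
    out->insert(n->name);
    entered_here = true;
  }
  for (Node* c : n->children) Visit(c, first_marked_on_path, out);
  if (entered_here) *first_marked_on_path = nullptr;  // the "leave" callback
}
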
diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h
new file mode 100644
index 0000000000..ad07624729
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/side_effect_util.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
+
+#include <set>
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Side-effecting nodes will have this attribute set. Its value is the list of
+// node names on which this node has side-effect dependencies.
+//
+// Nodes like HostCompute, SendToHost, and RecvFromHost always have this
+// attribute, because they always have side effects.
+// If and While nodes may or may not have this attribute, depending on whether
+// their bodies contain side-effecting nodes.
+extern const char kXlaTokenInputNodesAttrName[];
+
+// This node name is used in the kXlaTokenInputNodesAttrName attr to signal
+// that a node has a side-effect dependency on the current graph's token input.
+extern const char kXlaTokenArgNodeName[];
+
+// Calculates side-effect dependencies for the graph's token output.
+// Returns a set of node names representing these dependencies.
+std::set<std::string> CalculateTokenInputsForOutputToken(const Graph& g);
+
+// Returns whether a graph contains side-effecting nodes.
+bool HasSideEffectingNodes(const Graph& g);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_
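
The attribute itself is attached with ordinary NodeDef plumbing, as the new xla_compiler_test below does. A sketch of tagging a node as depending on the graph's token argument:

#include <string>
#include <vector>
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"

// Marks `node` as side-effecting, with a dependency on the token argument.
void MarkDependsOnTokenArg(tensorflow::NodeDef* node) {
  tensorflow::AddNodeAttr(
      tensorflow::kXlaTokenInputNodesAttrName,
      std::vector<std::string>{tensorflow::kXlaTokenArgNodeName}, node);
}
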
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index a29e764466..dcddef8418 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 41d305d461..dcb455779d 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
@@ -291,6 +292,10 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
"Invalid resource type in XLAShapeForArgument()");
}
}
+ case XlaCompiler::Argument::kToken: {
+ *xla_shape = xla::ShapeUtil::MakeTokenShape();
+ return Status::OK();
+ }
case XlaCompiler::Argument::kInvalid:
return errors::Internal("Invalid argument type in XLAShapeForArgument()");
}
@@ -489,7 +494,8 @@ Status XlaCompiler::BuildArguments(
}
break;
- case XlaCompiler::Argument::kParameter: {
+ case XlaCompiler::Argument::kParameter:
+ case XlaCompiler::Argument::kToken: {
input_mapping->push_back(i);
break;
}
@@ -616,6 +622,10 @@ Status XlaCompiler::BuildArguments(
arg_expression.set_handle(arg_handles[i]);
}
break;
+ case XlaCompiler::Argument::kToken: {
+ arg_expression.set_handle(arg_handles[i]);
+ break;
+ }
case XlaCompiler::Argument::kConstant:
case XlaCompiler::Argument::kInvalid:
return errors::Internal(
@@ -757,23 +767,71 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
&options_.shape_representation_fn);
core::ScopedUnref context_unref(context);
+ std::vector<XlaCompiler::Argument> real_args(args);
+ int token_input_index = -1;
+ if (options.add_token_input_output) {
+ // Add extra token input.
+ token_input_index = real_args.size();
+
+ XlaCompiler::Argument token_arg;
+ token_arg.kind = XlaCompiler::Argument::kToken;
+ real_args.push_back(token_arg);
+ }
+
std::vector<XlaExpression> arg_expressions;
std::vector<int> arg_cores;
- TF_RETURN_IF_ERROR(
- BuildArguments(*graph, args, options.use_tuple_arg, &builder, context,
- &arg_cores, &arg_expressions, &result->input_mapping,
- &result->xla_input_shapes, options.is_entry_computation));
+ TF_RETURN_IF_ERROR(BuildArguments(
+ *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores,
+ &arg_expressions, &result->input_mapping, &result->xla_input_shapes,
+ options.is_entry_computation));
context->set_args(std::move(arg_expressions));
+ PushNodeTokenMapping();
+ // Use std::set instead of std::unordered_set to ensure determinism.
+ std::set<std::string> output_node_token_inputs;
+ if (token_input_index != -1) {
+ // Original token comes from input.
+ auto arg_expression = context->args()[token_input_index];
+ TF_RETURN_IF_ERROR(
+ SetNodeToken(kXlaTokenArgNodeName, arg_expression.handle()));
+
+ // Calculate token inputs for output token.
+ output_node_token_inputs = CalculateTokenInputsForOutputToken(*graph);
+
+ // If there's no side-effecting op in the graph, use the token input as the
+ // token output.
+ if (output_node_token_inputs.empty()) {
+ output_node_token_inputs.insert(kXlaTokenArgNodeName);
+ }
+ } else if (options.is_entry_computation) {
+ // Original token is manually created.
+ if (HasSideEffectingNodes(*graph)) {
+ TF_RETURN_IF_ERROR(
+ SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(&builder)));
+ }
+ }
+
TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_,
flib_runtime_, NextStepId()));
+ if (token_input_index != -1) {
+ // Add extra token output.
+ std::vector<xla::XlaOp> token_inputs;
+ for (const auto& node_name : output_node_token_inputs) {
+ auto token_or = GetNodeToken(node_name);
+ TF_RETURN_IF_ERROR(token_or.status());
+ token_inputs.push_back(token_or.ValueOrDie());
+ }
+ TF_RETURN_IF_ERROR(
+ context->AppendTokenRetval(xla::AfterAll(&builder, token_inputs)));
+ }
+ TF_RETURN_IF_ERROR(PopNodeTokenMapping());
int num_nonconst_outputs;
int num_computation_outputs;
result->computation = std::make_shared<xla::XlaComputation>();
result->outputs.resize(context->retvals().size());
TF_RETURN_IF_ERROR(BuildComputation(
- args, arg_cores, context->retvals(), context->resources(),
+ real_args, arg_cores, context->retvals(), context->resources(),
options.return_updated_values_for_all_resources,
options.always_return_tuple, &builder, result->computation.get(),
&num_computation_outputs, &num_nonconst_outputs, &result->outputs,
@@ -912,4 +970,47 @@ Status XlaCompiler::SetHostComputeControlDependency(
return Status::OK();
}
+void XlaCompiler::PushNodeTokenMapping() {
+ node_token_mapping_stack_.emplace(std::map<string, xla::XlaOp>{});
+}
+
+Status XlaCompiler::PopNodeTokenMapping() {
+ if (node_token_mapping_stack_.empty()) {
+ return errors::FailedPrecondition(
+ "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is "
+ "empty.");
+ }
+ node_token_mapping_stack_.pop();
+ return Status::OK();
+}
+
+Status XlaCompiler::SetNodeToken(const string& node_name,
+ const xla::XlaOp& op) {
+ if (node_token_mapping_stack_.empty()) {
+ return errors::FailedPrecondition(
+ "Calling SetNodeToken() when node_token_mapping_stack_ is "
+ "empty.");
+ }
+ auto insert_result = node_token_mapping_stack_.top().insert({node_name, op});
+ if (!insert_result.second) {
+ return errors::FailedPrecondition("Token mapping already exists for node ",
+ node_name);
+ }
+ return Status::OK();
+}
+
+xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
+ if (node_token_mapping_stack_.empty()) {
+ return errors::FailedPrecondition(
+ "Calling GetNodeToken() when node_token_mapping_stack_ is "
+ "empty.");
+ }
+ auto iter = node_token_mapping_stack_.top().find(node_name);
+ if (iter == node_token_mapping_stack_.top().end()) {
+ return errors::FailedPrecondition("Cannot find token mapping for node ",
+ node_name);
+ }
+ return iter->second;
+}
+
} // namespace tensorflow
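
The new node-token API is used in matched Push/Pop pairs around each (possibly nested) CompileGraph() call. A sketch of the calling convention, reconstructed from CompileGraph() above rather than taken from the commit:

#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace tensorflow {
Status CompileWithTokens(XlaCompiler* compiler, xla::XlaBuilder* builder) {
  compiler->PushNodeTokenMapping();
  // Seed the entry token under the reserved arg-node name.
  TF_RETURN_IF_ERROR(
      compiler->SetNodeToken(kXlaTokenArgNodeName, xla::CreateToken(builder)));
  // ... compile side-effecting ops here: each reads its predecessors' tokens
  // via GetNodeToken(node_name) and records its own via SetNodeToken() ...
  return compiler->PopNodeTokenMapping();
}
}  // namespace tensorflow
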
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 8f4a9858ed..2cc603a580 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
+#include <stack>
+
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -26,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
@@ -106,6 +109,9 @@ class XlaCompiler {
// Argument is a run-time parameter.
kParameter,
+
+ // Argument is an XLA token.
+ kToken,
};
Kind kind = kInvalid;
@@ -179,6 +185,9 @@ class XlaCompiler {
// True when compiling the entry computation, false for subcomputations
// (while, call, etc.)
bool is_entry_computation = true;
+
+ // True when we should add an XLA token input & output to the graph/function.
+ bool add_token_input_output = false;
};
struct OutputDescription {
@@ -384,6 +393,11 @@ class XlaCompiler {
xla::Client* client() const { return options_.client; }
FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
+ void PushNodeTokenMapping();
+ Status PopNodeTokenMapping();
+ Status SetNodeToken(const string& node_name, const xla::XlaOp& op);
+ xla::StatusOr<xla::XlaOp> GetNodeToken(const string& node_name);
+
private:
// Sets the function body `fbody` to the one registered as `function`.
Status FindFunctionBody(const NameAttrList& function,
@@ -448,6 +462,15 @@ class XlaCompiler {
std::unordered_map<string, xla::XlaOp> host_compute_control_output_;
+ // This is used to store the <node name, token output> mapping. Side-effecting
+ // ops call SetNodeToken() to record their token outputs, so later
+ // side-effecting ops can use GetNodeToken() to read them back as token inputs.
+ //
+ // It's a stack because we need a mapping like this for each level of nested
+ // CompileGraph() call. In CompileGraph(), we will push a new mapping to the
+ // stack, and pop the mapping before returning.
+ std::stack<std::map<string, xla::XlaOp>> node_token_mapping_stack_;
+
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index be3c93ae47..40ce9fb41c 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -20,10 +20,12 @@ limitations under the License.
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -32,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -1274,5 +1277,70 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
}
}
+class DummySideEffectingOp : public XlaOpKernel {
+ public:
+ explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ void Compile(XlaOpKernelContext* ctx) override {
+ OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken(
+ name(), xla::CreateToken(ctx->builder())));
+ }
+};
+
+REGISTER_OP("DummySideEffectingOp");
+
+REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp);
+
+TEST_F(XlaCompilerTest, TokenInputAndOutput) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ NodeDef side_effecting_op;
+ side_effecting_op.set_name("DummySideEffectingOp");
+ side_effecting_op.set_op("DummySideEffectingOp");
+ AddNodeAttr(kXlaTokenInputNodesAttrName,
+ std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op);
+ Status status;
+ graph->AddNode(side_effecting_op, &status);
+ TF_ASSERT_OK(status);
+ EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get()));
+
+ const std::vector<XlaCompiler::Argument> empty_args;
+ {
+ // The case for the entry computation: we don't add a token input/output.
+ // Instead, we use a CreateToken HLO to create the entry token.
+ XlaCompiler::CompileOptions options;
+ options.is_entry_computation = true;
+ options.add_token_input_output = false;
+ XlaCompiler compiler(DefaultOptions());
+
+ std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+ CopyGraph(*graph, graph_copy.get());
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
+ empty_args, &result));
+ EXPECT_EQ(result.xla_input_shapes.size(), 0);
+ EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
+ EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 0);
+ }
+ {
+ // The case for a non-entry computation (e.g. a while loop body): we add a
+ // token input/output.
+ XlaCompiler::CompileOptions options;
+ options.is_entry_computation = false;
+ options.add_token_input_output = true;
+ XlaCompiler compiler(DefaultOptions());
+
+ std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+ CopyGraph(*graph, graph_copy.get());
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
+ empty_args, &result));
+ EXPECT_EQ(result.xla_input_shapes.size(), 1);
+ EXPECT_TRUE(xla::ShapeUtil::IsToken(result.xla_input_shapes[0]));
+ EXPECT_TRUE(xla::ShapeUtil::IsTuple(result.xla_output_shape));
+ EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
+ EXPECT_TRUE(xla::ShapeUtil::IsToken(
+ xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 0)));
+ }
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index e8b4b0eb36..f247570d72 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -119,6 +119,17 @@ Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) {
return Status::OK();
}
+Status XlaContext::AppendTokenRetval(const xla::XlaOp& token) {
+ VLOG(1) << "Adding retval index " << retvals_.size()
+ << " with token to XLA computation";
+ XlaExpression e;
+ e.set_handle(token);
+ // We use DT_INVALID because there is no TF DataType that corresponds to an
+ // XLA token. XlaCompiler handles this case separately, so putting it here is OK.
+ retvals_.push_back(Retval{DT_INVALID, TensorShape(), e});
+ return Status::OK();
+}
+
xla::XlaBuilder* XlaContext::builder() { return builder_; }
Status XlaContext::CreateResource(
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 4da891634e..d7dbdc957f 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -89,6 +89,9 @@ class XlaContext : public ResourceBase {
// As for Retval, but for return values that are resource handles.
Status AddResourceRetval(int retval_index, XlaResource* resource);
+ // As for Retval, but for return values that are XLA tokens.
+ Status AppendTokenRetval(const xla::XlaOp& token);
+
// Creates a resource with resource `kind` and initial value `handle`. `name`
// is a descriptive name for use in error messages. See the `XlaResource`
// constructor for a description of the remaining arguments.
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index d67e50375b..636cb71e21 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -102,7 +102,8 @@ Status XlaOpKernelContext::ConstantInput(int index,
static xla::StatusOr<int> InputIndex(XlaOpKernelContext* context,
absl::string_view name) {
int start, stop;
- TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(context->op_kernel().InputRange(
+ StringPiece(name.data(), name.length()), &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
@@ -365,7 +366,8 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
std::vector<xla::XlaOp>* handles,
std::vector<TensorShape>* shapes) {
OpInputList inputs;
- TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
+ TF_RETURN_IF_ERROR(
+ context_->input_list(StringPiece(name.data(), name.size()), &inputs));
handles->clear();
shapes->clear();
for (const Tensor& input : inputs) {
@@ -378,7 +380,8 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
Status XlaOpKernelContext::ConstantInputList(
absl::string_view name, std::vector<xla::Literal>* outputs) {
int start, stop;
- TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
+ TF_RETURN_IF_ERROR(op_kernel().InputRange(
+ StringPiece(name.data(), name.size()), &start, &stop));
outputs->resize(stop - start);
for (int i = start; i < stop; ++i) {
TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i]));
@@ -612,7 +615,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
const Tensor* tensor;
- CHECK(context_->input(name, &tensor).ok());
+ CHECK(context_->input(StringPiece(name.data(), name.length()), &tensor).ok());
return *tensor;
}
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 74a4885f1f..5d53169f68 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/string_view.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc
index f9473d372b..bddb664149 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.cc
+++ b/tensorflow/compiler/xla/packed_literal_reader.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -64,7 +65,7 @@ StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
absl::Span<const float> field = result->data<float>();
char* data = absl::bit_cast<char*>(field.data());
uint64 bytes = elements * sizeof(float);
- absl::string_view sp;
+ tensorflow::StringPiece sp;
auto s = file_->Read(offset_, bytes, &sp, data);
offset_ += sp.size();
if (!s.ok()) {
@@ -85,7 +86,7 @@ bool PackedLiteralReader::IsExhausted() const {
// Try to read a single byte from offset_. If we can't, we've
// exhausted the data.
char single_byte[1];
- absl::string_view sp;
+ tensorflow::StringPiece sp;
auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte);
return !s.ok();
}
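
The Read() idiom in this file is easy to misread: the result StringPiece may alias either the scratch buffer or memory-mapped storage, and its size() reports how many bytes actually arrived. A hedged wrapper sketch, assuming TF's RandomAccessFile API:

#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {
// Reads up to n bytes at `offset`; *bytes_read is the count actually read.
Status ReadBytes(RandomAccessFile* file, uint64 offset, size_t n,
                 char* scratch, size_t* bytes_read) {
  StringPiece sp;
  Status s = file->Read(offset, n, &sp, scratch);
  *bytes_read = sp.size();  // sp may point into scratch or mmap'd data
  return s;
}
}  // namespace tensorflow
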
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 76c09512d8..450d3fe5af 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -109,12 +109,12 @@ limitations under the License.
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"
-#include "third_party/absl/strings/str_cat.h"
-#include "third_party/absl/strings/str_format.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "third_party/absl/types/span.h"
+#include "absl/types/span.h"
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index ab86dce510..e784663ff6 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -159,6 +159,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
],
@@ -291,6 +292,7 @@ cc_library(
"hlo_instructions.cc",
"hlo_module.cc",
"hlo_opcode.cc",
+ "hlo_schedule.cc",
"hlo_sharding.cc",
],
hdrs = [
@@ -303,6 +305,7 @@ cc_library(
"hlo_instructions.h",
"hlo_module.h",
"hlo_opcode.h",
+ "hlo_schedule.h",
"hlo_sharding.h",
],
deps = [
@@ -331,6 +334,8 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
@@ -1037,7 +1042,6 @@ tf_cc_test(
":flatten_call_graph",
":hlo",
":hlo_ordering",
- ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -1065,7 +1069,6 @@ cc_library(
":hlo",
":hlo_dataflow_analysis",
":hlo_proto",
- ":hlo_schedule",
":hlo_value",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1086,7 +1089,6 @@ tf_cc_test(
":hlo",
":hlo_dataflow_analysis",
":hlo_ordering",
- ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -1108,7 +1110,6 @@ cc_library(
":hlo",
":hlo_ordering",
":hlo_proto",
- ":hlo_schedule",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
@@ -1177,22 +1178,6 @@ cc_library(
],
)
-cc_library(
- name = "hlo_schedule",
- srcs = ["hlo_schedule.cc"],
- hdrs = ["hlo_schedule.h"],
- deps = [
- ":hlo",
- "//tensorflow/compiler/xla:status",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/core:lib_internal",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/types:span",
- ],
-)
-
tf_cc_test(
name = "hlo_schedule_test",
srcs = ["hlo_schedule_test.cc"],
@@ -1202,7 +1187,6 @@ tf_cc_test(
":hlo_dce",
":hlo_ordering",
":hlo_parser",
- ":hlo_schedule",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -1222,7 +1206,6 @@ cc_library(
":heap_simulator",
":hlo",
":hlo_ordering",
- ":hlo_schedule",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -1969,6 +1952,8 @@ tf_cc_test(
srcs = ["hlo_module_test.cc"],
deps = [
":hlo",
+ ":hlo_matchers",
+ ":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1977,6 +1962,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
+ "//tensorflow/core:test",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
@@ -2413,7 +2399,6 @@ cc_library(
":hlo",
":hlo_dce",
":hlo_ordering",
- ":hlo_schedule",
":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
@@ -2587,6 +2572,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index aa40fba9bb..a0db4563fb 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2369,20 +2369,20 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
rhs_pad->shape().dimensions(3),
testcase.orig_conv_window))
.ValueOrDie();
- auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
- /*feature_group_count=*/1, window,
- dnums)
- .ValueOrDie(),
- input, rhs_pad, /*feature_group_count=*/1, window, dnums,
- DefaultPrecisionConfig(2)));
// Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place
// after the transformation.
PrecisionConfig precision_config;
precision_config.add_operand_precision(PrecisionConfig::HIGH);
precision_config.add_operand_precision(PrecisionConfig::HIGHEST);
- orig_conv->set_precision_config(precision_config);
+
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(),
+ /*feature_group_count=*/1, window,
+ dnums)
+ .ValueOrDie(),
+ input, rhs_pad, /*feature_group_count=*/1, window, dnums,
+ precision_config));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -2401,7 +2401,9 @@ TEST_P(ConvFilterPaddingTest, DoIt) {
conv->operand(1)->shape().dimensions(2),
conv->operand(1)->shape().dimensions(3),
testcase.expected_conv_window));
- EXPECT_THAT(conv->precision_config().operand_precision(),
+ EXPECT_THAT(Cast<HloConvolutionInstruction>(conv)
+ ->precision_config()
+ .operand_precision(),
ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST));
}
}
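
The test now passes its PrecisionConfig directly to CreateConvolve instead of mutating the instruction afterwards. For reference, the config is a plain proto; a minimal sketch of building the one used above:

#include "tensorflow/compiler/xla/xla_data.pb.h"

// Two operand precisions, one per convolution operand, as in the test.
xla::PrecisionConfig MakeHighHighestConfig() {
  xla::PrecisionConfig precision_config;
  precision_config.add_operand_precision(xla::PrecisionConfig::HIGH);
  precision_config.add_operand_precision(xla::PrecisionConfig::HIGHEST);
  return precision_config;
}
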
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index 69b654d30e..388fd5df99 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -55,8 +55,12 @@ class TestBFloat16Support : public BFloat16Support {
}
};
-class BFloat16PropagationTest : public HloTestBase {
+class BFloat16PropagationTest : public HloVerifiedTestBase {
protected:
+ BFloat16PropagationTest()
+ : HloVerifiedTestBase(/*layout_sensitive=*/false,
+ /*allow_mixed_precision=*/true) {}
+
// Runs the propagation pass on the given module, and returns whether the
// module is changed after this pass.
bool PropagatePrecision(HloModule* module) {
@@ -77,6 +81,16 @@ class BFloat16PropagationTest : public HloTestBase {
inst->users()[0]->opcode() == HloOpcode::kConvert &&
inst->users()[0]->shape().element_type() == BF16;
}
+
+ std::unique_ptr<HloInstruction> CreateDot(const Shape& shape,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums,
+ DefaultPrecisionConfig(2));
+ }
};
// Tests that BF16 can propagate through select over non-tuple buffers, but not
@@ -95,22 +109,22 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
HloInstruction* add1 = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b));
- HloInstruction* pred = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kEq, a, b));
+ HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {2, 4}), HloOpcode::kEq, a, b));
HloInstruction* sel = builder.AddInstruction(
HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1));
HloInstruction* xpose =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0}));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, a));
- HloInstruction* root = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
+ HloInstruction* dot = builder.AddInstruction(
+ CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, a));
+ HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kAdd, dot, dot));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), root);
EXPECT_TRUE(OutputsBF16(xpose));
@@ -136,13 +150,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a)));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b)));
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b));
+ HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a, b));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(OutputsBF16(dot->operand(0)));
@@ -189,8 +202,8 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple0->shape(), tuple1, 0)),
0));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs));
+ HloInstruction* dot = builder.AddInstruction(
+ CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
HloInstruction* output_tuple =
builder.AddInstruction(HloInstruction::CreateTuple({dot, add2}));
@@ -198,7 +211,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTuples) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), output_tuple);
EXPECT_TRUE(OutputsBF16(xpose));
@@ -231,13 +244,13 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1));
// lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1.
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs));
+ HloInstruction* dot = builder.AddInstruction(
+ CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), lhs, rhs));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(OutputsBF16(add1));
@@ -249,7 +262,7 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
// Tests that a non-fusion computation's root should not be changed.
TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
auto builder = HloComputation::Builder(TestName());
- Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
HloInstruction* a =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
@@ -258,8 +271,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add, add));
+ HloInstruction* dot = builder.AddInstruction(CreateDot(shape, add, add));
HloInstruction* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({add, dot}));
@@ -267,7 +279,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(PropagatePrecision(module.get()));
+ EXPECT_FALSE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), tuple);
EXPECT_FALSE(OutputsBF16(add));
@@ -277,7 +289,7 @@ TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
@@ -303,15 +315,14 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
HloInstruction::CreateGetTupleElement(shape, p_f1, 0));
HloInstruction* b_f1 = builder_f1.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, p_f1, 1));
- HloInstruction* dot = builder_f1.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, a_f1, b_f1));
+ HloInstruction* dot = builder_f1.AddInstruction(CreateDot(shape, a_f1, b_f1));
auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build());
auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion(
dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), fusion1);
EXPECT_TRUE(OutputsBF16(add));
@@ -326,7 +337,7 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
@@ -340,15 +351,15 @@ TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) {
builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
HloInstruction* add_f = builder_f.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f));
- HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f));
+ HloInstruction* dot_f =
+ builder_f.AddInstruction(CreateDot(shape, add_f, add_f));
auto comp_f = module->AddEmbeddedComputation(builder_f.Build());
auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(PropagatePrecision(module.get()));
+ EXPECT_FALSE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), fusion);
}
@@ -390,12 +401,11 @@ TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) {
HloInstruction::CreateGetTupleElement(shape, fusion, 0));
HloInstruction* gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, fusion, 1));
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1));
+ HloInstruction* dot = builder.AddInstruction(CreateDot(shape, gte0, gte1));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(OutputsBF16(gte0));
@@ -440,12 +450,12 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) {
HloInstruction* xpose =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0}));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, gte1));
+ HloInstruction* dot = builder.AddInstruction(
+ CreateDot(ShapeUtil::MakeShape(F32, {4, 4}), xpose, gte1));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_FALSE(OutputsBF16(add0));
@@ -472,31 +482,36 @@ TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) {
auto builder_cond = HloComputation::Builder("cond");
auto cond_param = builder_cond.AddInstruction(
HloInstruction::CreateParameter(0, shape, "cond_param"));
- auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, cond_param, cond_param));
+ auto cond_dot =
+ builder_cond.AddInstruction(CreateDot(shape, cond_param, cond_param));
auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})),
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(
+ HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+ cond_dot, {0, 0}, {1, 1}, {1, 1})))),
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
auto builder_body = HloComputation::Builder("body");
auto body_param = builder_body.AddInstruction(
HloInstruction::CreateParameter(0, shape, "body_param"));
- auto body_dot = builder_body.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, body_param, body_param));
+ auto body_dot =
+ builder_body.AddInstruction(CreateDot(shape, body_param, body_param));
auto body = module->AddEmbeddedComputation(builder_body.Build());
auto while_hlo = builder.AddInstruction(
HloInstruction::CreateWhile(shape, cond, body, add));
- auto dot = builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, while_hlo, while_hlo));
+ auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(
@@ -528,10 +543,16 @@ TEST_F(BFloat16PropagationTest,
HloInstruction::CreateParameter(0, shape, "cond_param"));
builder_cond.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_param, {0, 0}, {1, 1}, {1, 1})),
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_param, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {0, 0}, {1, 1},
+ {1, 1})))),
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond_param, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
auto builder_body = HloComputation::Builder("body");
@@ -552,11 +573,10 @@ TEST_F(BFloat16PropagationTest,
auto while_hlo = builder.AddInstruction(
HloInstruction::CreateWhile(shape, cond, body, add));
- auto dot = builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, while_hlo, while_hlo));
+ auto dot = builder.AddInstruction(CreateDot(shape, while_hlo, while_hlo));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(PropagatePrecision(module.get()));
+ EXPECT_FALSE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_FALSE(OutputsBF16(add));
EXPECT_FALSE(OutputsBF16(body_fusion));
@@ -593,14 +613,20 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
// This add should prevent RHS from using BF16
auto cond_add_rhs = builder_cond.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, cond_rhs, cond_rhs));
- auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, cond_lhs, cond_add_rhs));
+ auto cond_dot =
+ builder_cond.AddInstruction(CreateDot(shape, cond_lhs, cond_add_rhs));
builder_cond.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})),
- builder_cond.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(
+ HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+ cond_dot, {0, 0}, {1, 1}, {1, 1})))),
+ builder_cond.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond_dot, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond = module->AddEmbeddedComputation(builder_cond.Build());
auto builder_body = HloComputation::Builder("body");
@@ -610,10 +636,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
HloInstruction::CreateGetTupleElement(shape, body_param, 0));
auto body_rhs = builder_body.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, body_param, 1));
- auto body_dot1 = builder_body.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs));
- auto body_dot2 = builder_body.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_rhs, body_lhs));
+ auto body_dot1 =
+ builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
+ auto body_dot2 =
+ builder_body.AddInstruction(CreateDot(shape, body_rhs, body_lhs));
auto body_transpose = builder_body.AddInstruction(
HloInstruction::CreateTranspose(shape, body_dot2, {0, 1}));
builder_body.AddInstruction(
@@ -627,11 +653,10 @@ TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) {
HloInstruction::CreateGetTupleElement(shape, while_hlo, 0));
auto rhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, while_hlo, 1));
- auto dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs));
+ auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), dot);
EXPECT_TRUE(OutputsBF16(lhs));
@@ -683,14 +708,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
auto cond0_add_rhs =
builder_cond0.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, cond0_rhs, cond0_rhs));
- auto cond0_dot = builder_cond0.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, cond0_lhs, cond0_add_rhs));
+ auto cond0_dot =
+ builder_cond0.AddInstruction(CreateDot(shape, cond0_lhs, cond0_add_rhs));
builder_cond0.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond0.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond0_dot, {0, 0}, {1, 1}, {1, 1})),
- builder_cond0.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond0_dot, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond0.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond0.AddInstruction(
+ HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+ cond0_dot, {0, 0}, {1, 1}, {1, 1})))),
+ builder_cond0.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond0.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond0_dot, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond0 = module->AddEmbeddedComputation(builder_cond0.Build());
// Condition computation for the second while.
@@ -705,14 +736,20 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
auto cond1_add_lhs =
builder_cond1.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, cond1_lhs, cond1_lhs));
- auto cond1_dot = builder_cond1.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, cond1_add_lhs, cond1_rhs));
+ auto cond1_dot =
+ builder_cond1.AddInstruction(CreateDot(shape, cond1_add_lhs, cond1_rhs));
builder_cond1.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt,
- builder_cond1.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond1_dot, {0, 0}, {1, 1}, {1, 1})),
- builder_cond1.AddInstruction(HloInstruction::CreateSlice(
- ShapeUtil::MakeShape(F32, {}), cond1_dot, {1, 1}, {2, 2}, {1, 1}))));
+ builder_cond1.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond1.AddInstruction(
+ HloInstruction::CreateSlice(ShapeUtil::MakeShape(F32, {1, 1}),
+ cond1_dot, {0, 0}, {1, 1}, {1, 1})))),
+ builder_cond1.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {}),
+ builder_cond1.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {1, 1}), cond1_dot, {1, 1}, {2, 2},
+ {1, 1}))))));
auto cond1 = module->AddEmbeddedComputation(builder_cond1.Build());
// Body computation shared by both whiles.
@@ -723,8 +760,8 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
HloInstruction::CreateGetTupleElement(shape, body_param, 0));
auto body_rhs = builder_body.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, body_param, 1));
- auto body_dot = builder_body.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, body_lhs, body_rhs));
+ auto body_dot =
+ builder_body.AddInstruction(CreateDot(shape, body_lhs, body_rhs));
builder_body.AddInstruction(
HloInstruction::CreateTuple({body_dot, body_rhs}));
auto body = module->AddEmbeddedComputation(builder_body.Build());
@@ -734,23 +771,22 @@ TEST_F(BFloat16PropagationTest, DoNotPropagateWhilesCallingSameComputation) {
auto while1 = builder.AddInstruction(
HloInstruction::CreateWhile(tuple1->shape(), cond1, body, tuple1));
- auto lhs = builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot,
- builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, while0, 0)),
- builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, while0, 1))));
- auto rhs = builder.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot,
- builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, while1, 0)),
- builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, while1, 1))));
- auto dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, lhs, rhs));
+ auto lhs = builder.AddInstruction(
+ CreateDot(shape,
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while0, 0)),
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while0, 1))));
+ auto rhs = builder.AddInstruction(
+ CreateDot(shape,
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while1, 0)),
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, while1, 1))));
+ auto dot = builder.AddInstruction(CreateDot(shape, lhs, rhs));
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_FALSE(OutputsBF16(body_dot));
EXPECT_FALSE(OutputsBF16(body_rhs));
EXPECT_FALSE(OutputsBF16(body_lhs));
@@ -792,7 +828,7 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) {
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), add2);
EXPECT_EQ(add2->operand(0), add0);
@@ -821,15 +857,14 @@ TEST_F(BFloat16PropagationTest, TupleDomain) {
HloInstruction::CreateGetTupleElement(shape, domain, 0));
HloInstruction* b_gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, domain, 1));
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte));
+ HloInstruction* dot = builder.AddInstruction(CreateDot(shape, a_gte, b_gte));
HloInstruction* root = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), root);
// test BF16 propagated through domain
@@ -867,15 +902,15 @@ TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) {
HloInstruction::CreateTranspose(shape, a_gte, {0, 1}));
HloInstruction* b_trans = builder.AddInstruction(
HloInstruction::CreateTranspose(shape, b_gte, {0, 1}));
- HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans));
+ HloInstruction* dot =
+ builder.AddInstruction(CreateDot(shape, a_trans, b_trans));
HloInstruction* root = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_TRUE(PropagatePrecision(module));
EXPECT_EQ(computation->root_instruction(), root);
EXPECT_TRUE(OutputsBF16(a_trans));
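
The dot constructions in this test now go through a CreateDot helper instead of HloInstruction::CreateBinary with HloOpcode::kDot, and PropagatePrecision takes the module directly, which suggests the fixture now hands out a raw HloModule* rather than a std::unique_ptr. The helper's definition is not part of this hunk; a minimal sketch of what it presumably looks like, assuming a canonical matrix multiply (contract LHS dimension 1 with RHS dimension 0):

static std::unique_ptr<HloInstruction> CreateDot(const Shape& shape,
                                                 HloInstruction* lhs,
                                                 HloInstruction* rhs) {
  // Canonical dot dimension numbers for a rank-2 matrix multiply.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums);
}
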
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index d412578619..2368ac8c6a 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -670,6 +670,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 0fea462c85..7d99b914d4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
namespace op = xla::testing::opcode_matchers;
@@ -696,8 +697,8 @@ void CreateComputationForDotAddOutputFusionTest(const string& test_name,
auto* addend = builder.AddInstruction(
HloInstruction::CreateParameter(2, dot_shape, "param2"));
- auto* dot = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
+ auto* dot =
+ builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
builder.AddInstruction(
HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend));
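
Here and in the files below, HloInstruction::CreateCanonicalDot is replaced by a free CreateCanonicalDot, which the new test_utils.h include (and the matching //tensorflow/compiler/xla/tests:test_utils BUILD deps) pulls in. Presumably it is a thin wrapper along these lines; a sketch under that assumption, not the verbatim definition:

std::unique_ptr<HloInstruction> CreateCanonicalDot(const Shape& shape,
                                                   HloInstruction* lhs,
                                                   HloInstruction* rhs) {
  // Canonical dots are rank-2 matrix multiplies.
  CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
  CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums);
}
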
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 9363af3b89..4668f3872d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -70,7 +70,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
auto result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -107,9 +107,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
auto dot_a_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs));
auto dot_b_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs));
builder.AddInstruction(HloInstruction::CreateBinary(
result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result));
@@ -151,9 +151,9 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape)));
auto dot_a_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
+ CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs));
auto dot_b_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
+ CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs));
auto tuple_result = builder.AddInstruction(
HloInstruction::CreateTuple({dot_a_result, dot_b_result}));
@@ -189,7 +189,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateParameter(0, rhs_shape, "param0"));
auto dot_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -229,7 +229,7 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) {
auto dot_rhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1));
auto dot_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
+ CreateCanonicalDot(result_shape, dot_lhs, dot_rhs));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
@@ -276,8 +276,8 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion(
HloInstruction::CreateParameter(1, dot_shape, "param1"));
HloInstruction* dot_rhs = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape)));
- HloInstruction* dot_result = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
+ HloInstruction* dot_result =
+ builder.AddInstruction(CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs));
HloInstruction* add_result;
if (dot_operand_idx_in_add == 0) {
add_result = builder.AddInstruction(HloInstruction::CreateBinary(
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index a84ee78b19..fad76338a5 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -35,9 +35,7 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase {
cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_;
ParallelTaskAssignmentTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false),
- target_machine_features_([](int64 shape_size) {
+ : HloVerifiedTestBase(), target_machine_features_([](int64 shape_size) {
return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment;
}) {}
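
This only type-checks if HloVerifiedTestBase gained a default constructor equivalent to the flags the initializer list used to spell out. A sketch of the assumed declaration (the real one lives in hlo_verified_test_base.h):

HloVerifiedTestBase()
    : HloVerifiedTestBase(/*layout_sensitive=*/false,
                          /*allow_mixed_precision=*/false) {}
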
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index 2384166fd2..f11aff0573 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -121,6 +121,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
index fcd87b36b3..18ee25ba91 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -69,8 +70,7 @@ TEST_P(CpuEigenDotOperationTest, SimpleDotOp) {
HloInstruction* rhs = builder.AddInstruction(
HloInstruction::CreateParameter(1, param_shape, "input"));
- builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs));
+ builder.AddInstruction(CreateCanonicalDot(param_shape, lhs, rhs));
CompileAndCheck(builder.Build(), spec.filecheck_lines);
}
@@ -87,8 +87,7 @@ TEST_P(CpuEigenDotOperationTest, DotTransposeOp) {
HloInstruction* lhs_transposed = builder.AddInstruction(
HloInstruction::CreateTranspose(param_shape, lhs, {1, 0}));
- builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs));
+ builder.AddInstruction(CreateCanonicalDot(param_shape, lhs_transposed, rhs));
CompileAndCheck(builder.Build(), spec.filecheck_lines);
}
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 13ccff35f8..6791e15ee0 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -108,6 +108,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
@@ -480,6 +481,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
@@ -813,7 +815,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_ordering",
"//tensorflow/compiler/xla/service:hlo_reachability",
- "//tensorflow/compiler/xla/service:hlo_schedule",
"//tensorflow/compiler/xla/service:hlo_scheduling",
"@com_google_absl//absl/memory",
],
@@ -831,6 +832,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
index 0922e44a12..59ade96f7d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
@@ -73,10 +74,10 @@ TEST_F(GpuHloScheduleTest, SequentialMatMul) {
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(dot2));
@@ -201,12 +202,12 @@ TEST_F(GpuHloScheduleTest, ConcurrentMatMul) {
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, y, x));
- HloInstruction* add = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x));
+ HloInstruction* add =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, dot2));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(add));
@@ -269,23 +270,23 @@ TEST_F(GpuHloScheduleTest, LatticeMatMul) {
i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i))));
}
HloInstruction* d00 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
- HloInstruction* d10 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00));
- HloInstruction* d11 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4]));
- HloInstruction* d20 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10));
- HloInstruction* d21 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11));
- HloInstruction* d22 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5]));
- HloInstruction* d30 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21));
- HloInstruction* d31 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22));
- HloInstruction* d40 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31));
+ CreateCanonicalDot(f32_2x2_, params[2], params[3]));
+ HloInstruction* d10 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00));
+ HloInstruction* d11 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4]));
+ HloInstruction* d20 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10));
+ HloInstruction* d21 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11));
+ HloInstruction* d22 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5]));
+ HloInstruction* d30 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21));
+ HloInstruction* d31 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22));
+ HloInstruction* d40 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(d40));
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index bca775c475..96bfe0c12e 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
namespace op = xla::testing::opcode_matchers;
@@ -111,8 +112,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
- ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
+ auto dot1 = builder.AddInstruction(
+ CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1));
@@ -128,8 +129,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
- ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
+ auto dot1 = builder.AddInstruction(
+ CreateCanonicalDot(ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index ffca5d6549..b7c37bcf3c 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -764,5 +764,20 @@ StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
return Load(return_buffer);
}
+std::vector<llvm_ir::IrArray> IrEmitter::ConstructIrArrayForOutputs(
+ const HloInstruction& hlo) {
+ std::vector<llvm_ir::IrArray> output_arrays;
+ if (ShapeUtil::IsTuple(hlo.shape())) {
+ int64 num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+ output_arrays.reserve(num_outputs);
+ for (int64 i = 0; i < num_outputs; ++i) {
+ output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
+ }
+ } else {
+ output_arrays.push_back(GetIrArray(hlo, hlo));
+ }
+ return output_arrays;
+}
+
} // namespace gpu
} // namespace xla
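
A typical call site for the new helper, mirroring how the nested and unnested emitters below consume it (element_generator stands for whatever element generator the caller already has in scope):

// Emit a loop that writes every output of a (possibly multi-output) hlo.
std::vector<llvm_ir::IrArray> output_arrays = ConstructIrArrayForOutputs(hlo);
TF_RETURN_IF_ERROR(
    llvm_ir::LoopEmitter(element_generator, output_arrays, &b_).EmitLoop());
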
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 579268f071..8805201480 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -124,6 +124,12 @@ class IrEmitter : public DfsHloVisitorWithDefault,
llvm::Value* GetBasePointer(const HloInstruction& inst) const {
return bindings_.GetBasePointer(inst);
}
+
+ // Generates the IrArray for each output of an HLO instruction and returns
+ // a vector containing such IrArrays.
+ std::vector<llvm_ir::IrArray> ConstructIrArrayForOutputs(
+ const HloInstruction& hlo);
+
// A convenient helper for calling BufferAssignment::GetUniqueSlice.
BufferAllocation::Slice GetAllocationSlice(
const HloInstruction& hlo, const ShapeIndex& index = {}) const {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
index 5c827e5f9c..66c65f6975 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
@@ -119,21 +119,11 @@ Status IrEmitterNested::EmitTargetElementLoop(
// For MOF we give the loop emitter an array for every output it should
// generate.
if (hlo.IsMultiOutputFusion()) {
- const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape());
- std::vector<llvm_ir::IrArray> target_arrays;
- target_arrays.reserve(num_elems);
- for (int64 i = 0; i != num_elems; ++i) {
- target_arrays.push_back(GetIrArray(hlo, hlo, {i}));
- }
+ std::vector<llvm_ir::IrArray> target_arrays =
+ ConstructIrArrayForOutputs(hlo);
TF_RETURN_IF_ERROR(
llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop());
-
- std::vector<llvm::Value*> tuple_operand_ptrs;
- tuple_operand_ptrs.reserve(num_elems);
- for (const llvm_ir::IrArray& array : target_arrays) {
- tuple_operand_ptrs.push_back(array.GetBasePointer());
- }
- llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_);
+ llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_, module_);
return Status::OK();
}
return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_)
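
The hand-rolled GetBasePointer loop disappears because EmitTuple apparently gained an overload taking the IrArrays directly. A sketch of what that overload presumably does, namely extract each operand's base pointer and delegate to the existing llvm::Value* overload:

// In namespace llvm_ir; a sketch, assuming a Span-of-IrArray overload.
void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> operands,
               llvm::IRBuilder<>* b, llvm::Module* module) {
  std::vector<llvm::Value*> buffer_ptrs;
  buffer_ptrs.reserve(operands.size());
  for (const IrArray& array : operands) {
    buffer_ptrs.push_back(array.GetBasePointer());
  }
  EmitTuple(tuple, buffer_ptrs, b, module);
}
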
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 389a98facb..f91cc00d71 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2521,15 +2521,15 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk(
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
- const HloInstruction* hlo, const ShapeIndex& index) {
+ HloInstruction* hlo, const ShapeIndex& index) {
bool fused = HloOpcode::kFusion == hlo->opcode();
- const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
- const HloInstruction* init_value_operand = [&] {
+ HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
+ HloInstruction* init_value_operand = [&] {
switch (inst->opcode()) {
case HloOpcode::kSelectAndScatter:
- return inst->operand(2);
+ return inst->mutable_operand(2);
case HloOpcode::kReduce:
- return inst->operand(1);
+ return inst->mutable_operand(1);
case HloOpcode::kTuple:
CHECK(hlo->IsMultiOutputFusion())
<< ": " << hlo->ToString() << " is not a multi-output fusion.";
@@ -2537,7 +2537,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
<< ": Found '" << inst->operand(index.back())->opcode() << "' in "
<< inst->ToString() << " but expected 'reduce'.";
// For multi-output fusion look through the tuple.
- return inst->operand(index.back())->operand(1);
+ return inst->mutable_operand(index.back())->mutable_operand(1);
default:
LOG(FATAL) << "Opcode " << inst->opcode()
<< " should not need an initializer.";
@@ -2609,28 +2609,35 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
ir_emitter_context_->device_description());
UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
ir_emitter_context_->llvm_module());
- // If the init_value was fused into this reduce we have to generate it first.
- if (fused && init_value_operand->opcode() != HloOpcode::kParameter) {
- CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode());
- const Literal& literal = init_value_operand->literal();
- llvm::Constant* initializer =
- llvm_ir::ConvertLiteralToIrConstant(literal, module_);
+ if (fused) {
+ // If init_value was fused into this reduce, we have to generate it first.
+ std::vector<IrArray> parameter_arrays;
+ for (HloInstruction* operand : hlo->operands()) {
+ parameter_arrays.push_back(GetIrArray(*operand, *hlo));
+ }
+ GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
+ ir_emitter_context_->llvm_module(),
+ &b_, GetNestedComputer());
- llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
- *module_, initializer->getType(),
- /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer,
- /*Name=*/"");
- global_for_const->setAlignment(kConstantBufferAlignBytes);
- bindings_.BindHloToIrValue(*init_value_operand, global_for_const);
+ FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
+ TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter));
+ TF_RETURN_IF_ERROR(
+ ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand),
+ GetIrArray(*hlo, *hlo, index), launch_dimensions,
+ &b_)
+ .EmitLoop(IrName(hlo)));
+ } else {
+ // In the unfused case the element is already there; just read from it.
+ TF_RETURN_IF_ERROR(ParallelLoopEmitter(
+ [=](const IrArray::Index& index) {
+ return GetIrArray(*init_value, *hlo)
+ .EmitReadArrayElement(index, &b_);
+ },
+ GetIrArray(*hlo, *hlo, index), launch_dimensions,
+ &b_)
+ .EmitLoop(IrName(hlo)));
}
- TF_RETURN_IF_ERROR(ParallelLoopEmitter(
- [=](const IrArray::Index& index) {
- return GetIrArray(*init_value, *hlo)
- .EmitReadArrayElement(index, &b_);
- },
- GetIrArray(*hlo, *hlo, index), launch_dimensions, &b_)
- .EmitLoop(IrName(hlo)));
// Clean up state left behind by emitting the loop above. (This is normally
// done in IrEmitterUnnested::Postprocess().)
@@ -2819,10 +2826,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
}
// For multioutput fusion, we need to emit each operand and the root.
- std::vector<IrArray> output_arrays;
- for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
- output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
- }
+ std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(hlo);
TF_RETURN_IF_ERROR(
ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions,
&b_, unroll_factor)
@@ -2830,12 +2834,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
GetIndexTypeForKernel(
&hlo, launch_dimensions.launch_bound(), &b_)));
- std::vector<llvm::Value*> tuple_operand_ptrs;
- for (int64 i = 0; i < output_arrays.size(); ++i) {
- tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
- }
b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
- llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_);
+ llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_);
+
return Status::OK();
}
@@ -2847,29 +2848,14 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
static_cast<KernelThunk*>(LastThunk()));
}
-int IrEmitterUnnested::ConstructIrArrayForOutputs(
- const HloInstruction& hlo, std::vector<IrArray>* output_arrays) {
- int64 num_outputs = 1;
- if (hlo.IsMultiOutputFusion()) {
- num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
- output_arrays->reserve(num_outputs);
- for (int64 i = 0; i < num_outputs; ++i) {
- output_arrays->push_back(GetIrArray(hlo, hlo, {i}));
- }
- } else {
- output_arrays->push_back(GetIrArray(hlo, hlo));
- }
- return num_outputs;
-}
-
-int IrEmitterUnnested::ConstructIrArrayForInputs(
- const HloInstruction& hlo, std::vector<IrArray>* param_arrays) {
- int64 num_params = hlo.operands().size();
- param_arrays->reserve(num_params);
+std::vector<IrArray> IrEmitterUnnested::ConstructIrArrayForInputs(
+ const HloInstruction& hlo) {
+ std::vector<IrArray> param_arrays;
+ param_arrays.reserve(hlo.operands().size());
for (const HloInstruction* param : hlo.operands()) {
- param_arrays->push_back(GetIrArray(*param, hlo));
+ param_arrays.push_back(GetIrArray(*param, hlo));
}
- return num_params;
+ return param_arrays;
}
int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
@@ -3050,10 +3036,10 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
constexpr int64 kThreadsPerTile = kTileSize * kNumRows;
// Construct IrArrays for the inputs and outputs.
- std::vector<IrArray> output_arrays;
- int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays);
- std::vector<IrArray> param_arrays;
- int64 num_params = ConstructIrArrayForInputs(*hlo, &param_arrays);
+ std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
+ int64 num_outputs = output_arrays.size();
+ std::vector<IrArray> param_arrays = ConstructIrArrayForInputs(*hlo);
+ int64 num_params = param_arrays.size();
// Allocate shared memory buffers to store the tiled inputs.
std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr);
@@ -3251,12 +3237,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
// For multioutput fusion, emit a tuple with all the individual outputs.
if (hlo->IsMultiOutputFusion()) {
- std::vector<llvm::Value*> tuple_operand_ptrs;
- for (int64 i = 0; i < output_arrays.size(); ++i) {
- tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
- }
- llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_,
- module_);
+ llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_);
}
return launch_dimensions;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 084462330e..bd5db72051 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -193,14 +193,12 @@ class IrEmitterUnnested : public IrEmitter {
LaunchDimensions EmitHlo021Tile(HloInstruction* hlo,
absl::Span<const int64> reduced_output_dims,
absl::Span<const int64> tiled_param_ids);
- // Generates the IrArray for each output of hlo and returns the number of
- // outputs.
- int ConstructIrArrayForOutputs(const HloInstruction& hlo,
- std::vector<llvm_ir::IrArray>* output_arrays);
- // Generates the IrArray for each input of hlo and returns the number of
- // inputs.
- int ConstructIrArrayForInputs(const HloInstruction& hlo,
- std::vector<llvm_ir::IrArray>* param_arrays);
+
+ // Generates the IrArray for each input of an HLO instruction and returns a
+ // vector that contains such IrArrays.
+ std::vector<llvm_ir::IrArray> ConstructIrArrayForInputs(
+ const HloInstruction& hlo);
+
// For each output of the `hlo` instruction, constructs the reduced shape for
// the output with the given `reduced_output_dims` and cast the original
// output IrArray element in `output_arrays` to the reduced shape. Returns
@@ -244,7 +242,7 @@ class IrEmitterUnnested : public IrEmitter {
// Returns a thunk that, given a reduce or select-and-scatter op, initializes
// its memory to the appropriate initial value.
StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(
- const HloInstruction* hlo, const ShapeIndex& index = {});
+ HloInstruction* hlo, const ShapeIndex& index = {});
// Returns a thunk that calls host-to-device cuMemcpy to implement `inst`.
std::unique_ptr<Thunk> BuildHostToDeviceCopyThunk(const HloInstruction* inst);
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 091aca23e5..8f0dedfa40 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
@@ -49,10 +50,10 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) {
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, dot1, z));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(dot2));
@@ -68,10 +69,10 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) {
/*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
- HloInstruction* dot1 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, x, y));
- HloInstruction* dot2 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, y, x));
+ HloInstruction* dot1 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, x, y));
+ HloInstruction* dot2 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, y, x));
HloInstruction* add = builder.AddInstruction(
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
@@ -101,23 +102,23 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) {
i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i))));
}
HloInstruction* d00 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3]));
- HloInstruction* d10 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00));
- HloInstruction* d11 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4]));
- HloInstruction* d20 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10));
- HloInstruction* d21 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11));
- HloInstruction* d22 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5]));
- HloInstruction* d30 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21));
- HloInstruction* d31 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22));
- HloInstruction* d40 = builder.AddInstruction(
- HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31));
+ CreateCanonicalDot(f32_2x2_, params[2], params[3]));
+ HloInstruction* d10 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[1], d00));
+ HloInstruction* d11 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d00, params[4]));
+ HloInstruction* d20 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, params[0], d10));
+ HloInstruction* d21 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d10, d11));
+ HloInstruction* d22 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d11, params[5]));
+ HloInstruction* d30 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d20, d21));
+ HloInstruction* d31 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d21, d22));
+ HloInstruction* d40 =
+ builder.AddInstruction(CreateCanonicalDot(f32_2x2_, d30, d31));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build(d40));
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 99d0cf50ca..93ec2c9438 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -199,6 +199,17 @@ message HloComputationProto {
int64 root_id = 6;
}
+// Serialization of an HLO schedule. An HLO schedule contains a total order of
+// instructions for each non-fusion computation in the module.
+message HloScheduleProto {
+ message InstructionSequence {
+ repeated int64 instruction_ids = 1;
+ }
+
+ // Map from computation id to sequence.
+ map<int64, InstructionSequence> sequences = 1;
+}
+
// Serialization of HloModule.
message HloModuleProto {
string name = 1;
@@ -214,16 +225,9 @@ message HloModuleProto {
// The id of this module.
int64 id = 5;
-}
-// Serialization of HloOrdering.
-message HloOrderingProto {
- // NOTE: currently only sequential orderings are serialized.
- message SequentialComputation {
- string computation_name = 1;
- repeated string instruction_names = 2;
- }
- repeated SequentialComputation sequential_computations = 1;
+ // The schedule for this module.
+ HloScheduleProto schedule = 7;
}
// Serialization of LogicalBuffer.
@@ -322,8 +326,10 @@ message BufferAssignmentProto {
// Grouping message that contains all of the information above.
message HloProto {
+ reserved 2;
+ reserved "hlo_ordering";
+
HloModuleProto hlo_module = 1;
- HloOrderingProto hlo_ordering = 2;
BufferAssignmentProto buffer_assignment = 3;
}
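
A sketch of how the new schedule message might be filled in from an instruction order; the loop below is hypothetical, but per the message comment the map is keyed by computation id, and the ids are presumably HloInstruction::unique_id() values:

HloScheduleProto proto;
HloScheduleProto::InstructionSequence& sequence =
    (*proto.mutable_sequences())[computation->unique_id()];
for (const HloInstruction* instruction : order) {
  sequence.add_instruction_ids(instruction->unique_id());
}
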
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index fe7f2be888..233d2199d1 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -464,6 +464,14 @@ std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
}
string HloComputation::ToString(const HloPrintOptions& options) const {
+ return ToString(options, MakeInstructionPostOrder());
+}
+
+string HloComputation::ToString(
+ const HloPrintOptions& options,
+ absl::Span<const HloInstruction* const> instruction_order) const {
+ CHECK_EQ(instruction_order.size(), instruction_count());
+
std::ostringstream s;
for (int i = 0; i < options.indent_amount(); i++) {
s << " ";
@@ -486,7 +494,9 @@ string HloComputation::ToString(const HloPrintOptions& options) const {
new_options.set_indent_amount(options.indent_amount() + 1)
.set_is_in_nested_computation(true);
CanonicalNameMap name_map;
- for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
+ for (const HloInstruction* instruction : instruction_order) {
+ CHECK_EQ(this, instruction->parent());
+
for (int i = 0; i < new_options.indent_amount(); i++) {
s << " ";
}
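
Hypothetical usage of the new overload, printing the computation in an explicitly supplied order (here just the default post-order, so the output matches ToString(options)):

HloPrintOptions options;
std::vector<HloInstruction*> post_order =
    computation->MakeInstructionPostOrder();
string str = computation->ToString(
    options, std::vector<const HloInstruction*>(post_order.begin(),
                                                post_order.end()));
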
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index fe2d3bbbe5..91c5234a6f 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -170,6 +170,11 @@ class HloComputation {
string ToString() const { return ToString(HloPrintOptions()); }
string ToString(const HloPrintOptions& options) const;
+ // Overload which accepts an order to emit the instructions in.
+ string ToString(
+ const HloPrintOptions& options,
+ absl::Span<const HloInstruction* const> instruction_order) const;
+
// Returns a serialized representation of this computation.
HloComputationProto ToProto() const;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 939b5114c3..a502fff9a0 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -227,6 +227,14 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
return Status::OK();
}
+Status HloCostAnalysis::HandleDomain(const HloInstruction* domain) {
+ // Domain does not have any computation or data transfer.
+ current_should_compute_bottleneck_time_ = false;
+ current_properties_[kBytesAccessedKey] = 0;
+ current_properties_[kOptimalSecondsKey] = 0;
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
const Shape& lhs_shape = dot->operand(0)->shape();
const Shape& rhs_shape = dot->operand(1)->shape();
@@ -507,8 +515,9 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
valid_position_counts.push_back(valid_position_count);
}
- const int64 fma_count =
- input_feature * output_feature * batch * Product(valid_position_counts);
+ const int64 fma_count = (input_feature / convolution->feature_group_count()) *
+ output_feature * batch *
+ Product(valid_position_counts);
current_properties_[kFlopsKey] = fma_count * kFmaFlops;
return Status::OK();
}
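
As a sanity check of the feature_group_count factor against the grouped-convolution test added below: input_feature = 120, feature_group_count = 120, output_feature = 120, batch = 1, and per-dimension valid position counts of 8 * 3 and 18 * 3 give fma_count = (120 / 120) * 120 * 1 * (8 * 3) * (18 * 3) = 155520, hence flop_count = 2 * 155520 = 311040, which is exactly the test's expected 120 * 8 * 18 * 2 * 3 * 3.
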
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 9bb3f12ee2..46b4bbeef2 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -67,6 +67,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleRecvDone(const HloInstruction* recv_done) override;
Status HandleConvert(const HloInstruction* convert) override;
Status HandleCopy(const HloInstruction* copy) override;
+ Status HandleDomain(const HloInstruction* domain) override;
Status HandleDot(const HloInstruction* dot) override;
Status HandleConvolution(const HloInstruction* convolution) override;
Status HandleFft(const HloInstruction* fft) override;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 2c854eea18..d76ce9ecbc 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -203,6 +203,35 @@ TEST_F(HloCostAnalysisTest, Convolution) {
sizeof(float) * (10 * 20 + 3 * 3 + 8 * 18));
}
+TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) {
+ XlaBuilder builder("convolution");
+ auto input = Parameter(
+ &builder, 0,
+ ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/120, /*y_dim=*/10,
+ /*x_dim=*/20}),
+ "input");
+ auto kernel = Parameter(
+ &builder, 1,
+ ShapeUtil::MakeShape(F32, {/*p_dim=*/120, /*z_dim=*/1, /*y_dim=*/3,
+ /*x_dim=*/3}),
+ "kernel");
+ Conv(input, kernel, {1, 1}, Padding::kValid, /*feature_group_count=*/120);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis(ShapeSize);
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ // Output shape is [1x120x8x18]; each output element requires (3x3) FMAs,
+ // and one FMA is 2 flops.
+ EXPECT_EQ(analysis.flop_count(), 120 * 8 * 18 * 2 * 3 * 3);
+
+ // Bytes accessed is sum of inputs and output.
+ EXPECT_EQ(analysis.bytes_accessed(),
+ sizeof(float) * (120 * 10 * 20 + 120 * 3 * 3 + 120 * 8 * 18));
+}
+
TEST_F(HloCostAnalysisTest, Reduce) {
XlaBuilder builder("reduce");
auto input =
@@ -415,7 +444,7 @@ TEST_F(FusionCostAnalysis, NoLayout) {
TEST_F(HloCostAnalysisTest, TupleCost) {
HloCostAnalysis analysis(ShapeSize);
{
- XlaBuilder builder("matmul");
+ XlaBuilder builder("tuple");
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x");
auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y");
Tuple(&builder, {x, y});
@@ -430,6 +459,30 @@ TEST_F(HloCostAnalysisTest, TupleCost) {
EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2);
}
+using DomainCostAnalysis = HloTestBase;
+TEST_F(DomainCostAnalysis, DomainCost) {
+ HloCostAnalysis analysis(ShapeSize);
+
+ HloComputation::Builder builder("domain");
+ auto x = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {123}), "x"));
+ auto y = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {42}), "y"));
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({x, y}));
+ auto domain = builder.AddInstruction(
+ HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr));
+
+ auto hlo_module = CreateNewModule();
+ hlo_module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(hlo_module->entry_computation()->root_instruction(), domain);
+ ASSERT_IS_OK(domain->Accept(&analysis));
+
+ EXPECT_EQ(analysis.flop_count(*domain), 0);
+ EXPECT_EQ(analysis.transcendental_count(*domain), 0);
+ EXPECT_EQ(analysis.bytes_accessed(*domain), 0);
+}
+
TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
XlaBuilder builder("BaseDilatedConvolution");
auto input = Parameter(
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 406d712ec6..e09d5868f2 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -29,7 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
@@ -44,7 +44,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-class HloCseTest : public HloTestBase {
+class HloCseTest : public HloVerifiedTestBase {
protected:
HloCseTest() {}
};
@@ -65,13 +65,13 @@ TEST_F(HloCseTest, CombineTwoConstants) {
EXPECT_EQ(3, computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(2, computation->instruction_count());
HloInstruction* constant = *computation->instructions().begin();
EXPECT_EQ(42.0f, constant->literal().Get<float>({}));
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR0<float>(84.0);
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
@@ -96,14 +96,14 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
EXPECT_THAT(add, op::Add(constant1, constant2));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(2, computation->instruction_count());
auto first_operand = add->operand(0);
EXPECT_THAT(first_operand, ::testing::AnyOf(constant1, constant2));
EXPECT_THAT(add, op::Add(first_operand, first_operand));
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
@@ -128,12 +128,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
EXPECT_THAT(add, op::Add(constant1, constant2));
HloCSE cse(/*is_layout_sensitive=*/true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
EXPECT_THAT(add, op::Add(constant1, constant2));
- auto result = ExecuteAndTransfer(std::move(module), {});
+ auto result = ExecuteAndTransfer(module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
@@ -177,7 +177,7 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
EXPECT_EQ(20, computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
// CSE will remove both the second float(42.0f) and the corresponding
// convert/cast.
@@ -209,7 +209,7 @@ TEST_F(HloCseTest, NonscalarConstants) {
op::Tuple(common_constant1, common_constant2, uncommon_constant));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
auto first_operand = tuple->operand(0);
@@ -240,7 +240,7 @@ TEST_F(HloCseTest, IdenticalInstructions) {
EXPECT_THAT(tuple, op::Tuple(exp1, exp2, exp3));
HloCSE cse(/*is_layout_sensitive=*/true);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
auto first_operand = tuple->operand(0);
@@ -250,7 +250,7 @@ TEST_F(HloCseTest, IdenticalInstructions) {
// Test two identical while loops with same inputs
TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesSameInput) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule WhileLoopsIdenticalConditionsAndBodiesSameInput
%body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -278,21 +278,20 @@ f32[]) while((f32[], f32[]) %tuple.1), condition=%condition, body=%body ROOT
%while.1 = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
condition=%condition.1, body=%body
}
- )")
- .ValueOrDie();
+ )");
- auto computation = module->entry_computation();
+ auto computation = module().entry_computation();
EXPECT_EQ(5, computation->instruction_count());
HloCSE cse(true);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(&module()).ValueOrDie());
EXPECT_EQ(4, computation->instruction_count());
}
// Test two while loops with same conditions, same inputs, but different
// bodies
TEST_F(HloCseTest, WhileLoopsIdenticalConditionsSameInputAndDifferentBodies) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule WhileLoopsIdenticalConditionsSameInputAndDifferentBodies
%body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -329,20 +328,19 @@ index=1 %sub = f32[] subtract(f32[] %get-tuple-element.2, f32[]
condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
f32[]) %tuple.1), condition=%condition.1, body=%body2
}
- )")
- .ValueOrDie();
+ )");
- auto computation = module->entry_computation();
+ auto computation = module().entry_computation();
EXPECT_EQ(5, computation->instruction_count());
HloCSE cse(true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(&module()).ValueOrDie());
EXPECT_EQ(5, computation->instruction_count());
}
// Test two identical while loops with different inputs
TEST_F(HloCseTest, WhileLoopsIdenticalConditionsAndBodiesDifferentInput) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule WhileLoopsIdenticalConditionsAndBodiesDifferentInput
%body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -373,21 +371,20 @@ f32[] constant(2) %tuple.2 = (f32[], f32[]) tuple(f32[] %constant.4, f32[]
condition=%condition.1, body=%body
}
- )")
- .ValueOrDie();
+ )");
- auto computation = module->entry_computation();
+ auto computation = module().entry_computation();
EXPECT_EQ(8, computation->instruction_count());
HloCSE cse(true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(&module()).ValueOrDie());
EXPECT_EQ(8, computation->instruction_count());
}
// Test two while loops with identical bodies and same inputs, but different
// conditions
TEST_F(HloCseTest, WhileLoopsIdenticalBodiesAndInputDifferntConditions) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule WhileLoopsIdenticalBodiesAndInputDifferntConditions
%body (param: (f32[], f32[])) -> (f32[], f32[]) {
@@ -414,14 +411,13 @@ f32[]) { %constant.2 = f32[] constant(1) %constant.3 = f32[] constant(2)
%while = (f32[], f32[]) while((f32[], f32[]) %tuple.1),
condition=%condition, body=%body ROOT %while.1 = (f32[], f32[]) while((f32[],
f32[]) %tuple.1), condition=%condition.1, body=%body
- })")
- .ValueOrDie();
+ })");
- auto computation = module->entry_computation();
+ auto computation = module().entry_computation();
EXPECT_EQ(5, computation->instruction_count());
HloCSE cse(true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(&module()).ValueOrDie());
EXPECT_EQ(5, computation->instruction_count());
}
@@ -450,7 +446,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
HloCSE cse(/*is_layout_sensitive=*/true);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
EXPECT_EQ(4, computation->instruction_count());
EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
@@ -481,7 +477,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) {
EXPECT_THAT(tuple, op::Tuple(exp1, exp2));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(3, computation->instruction_count());
auto first_operand = tuple->operand(0);
@@ -516,7 +512,7 @@ TEST_F(HloCseTest, FusionInternalCSE) {
EXPECT_EQ(5, fused_computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(4, fused_computation->instruction_count());
auto root = fused_computation->root_instruction();
@@ -565,7 +561,7 @@ TEST_F(HloCseTest, IdenticalExpressions) {
EXPECT_THAT(tuple, op::Tuple(op::Add(negate1, exp1), op::Add(negate2, exp2)));
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_TRUE(cse.Run(module).ValueOrDie());
EXPECT_EQ(5, computation->instruction_count());
auto operand = tuple->operand(0);
@@ -599,7 +595,7 @@ TEST_F(HloCseTest, DoNotCombineRng) {
uint32 count_before = computation->instruction_count();
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
uint32 count_after = computation->instruction_count();
EXPECT_EQ(count_before, count_after);
@@ -653,7 +649,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
VLOG(3) << "before: " << module->ToString();
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
VLOG(3) << "after: " << module->ToString();
@@ -663,7 +659,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
}
TEST_F(HloCseTest, CompareComputations) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule m
add_computation {
@@ -684,12 +680,11 @@ TEST_F(HloCseTest, CompareComputations) {
r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation
r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2
ROOT f2 = (f32[],f32[]) tuple(r1, r2)
- })")
- .ValueOrDie();
+ })");
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
- HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_TRUE(cse.Run(&module()).ValueOrDie());
+ HloInstruction* root = module().entry_computation()->root_instruction();
EXPECT_EQ(root->operand(0), root->operand(1));
}
@@ -708,13 +703,13 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) {
EXPECT_EQ(2, computation->instruction_count());
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(cse.Run(module).ValueOrDie());
EXPECT_EQ(2, computation->instruction_count());
}
TEST_F(HloCseTest, Domain) {
- auto module = ParseHloString(R"(
+ ParseAndVerifyModule(R"(
HloModule module
ENTRY %entry {
%param = f32[] parameter(0), sharding={maximal device=0}
@@ -735,13 +730,11 @@ ENTRY %entry {
domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}}
%add = f32[] add(%domain.3, %domain.4)
ROOT %sub = f32[] subtract(%add, %domain.5)
-})")
- .ValueOrDie();
+})");
HloCSE cse(/*is_layout_sensitive=*/false);
- EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
- LOG(INFO) << "AAAAA " << module->ToString();
- const HloInstruction* sub = module->entry_computation()->root_instruction();
+ EXPECT_TRUE(cse.Run(&module()).ValueOrDie());
+ const HloInstruction* sub = module().entry_computation()->root_instruction();
const HloInstruction* add = sub->operand(0);
EXPECT_EQ(add->operand(0), add->operand(1));
EXPECT_NE(add->operand(0), sub->operand(1));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index abd4bb1f73..102ebb24ab 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -52,10 +52,7 @@ static std::array<bool, 2> use_bf16_params{true, false};
class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
public HloVerifiedTestBase {
protected:
- HloEvaluatorTest()
- : HloVerifiedTestBase(/*layout_sensitive=*/false,
- /*allow_mixed_precision=*/false),
- use_bfloat16_(GetParam()) {
+ HloEvaluatorTest() : HloVerifiedTestBase(), use_bfloat16_(GetParam()) {
evaluator_ = absl::make_unique<HloEvaluator>();
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 6a09bb08f4..63303aef1e 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1052,7 +1052,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
&lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
rhs_literal_data,
- feature_group_count](absl::Span<const int64> out_index) {
+ feature_group_count](const absl::Span<const int64> out_index) {
// Dimension number applicable for input (lhs).
const int64 input_batch_dim = dnums.input_batch_dimension();
const int64 input_z_dim = dnums.input_feature_dimension();
@@ -1063,9 +1063,22 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
const int64 output_batch_dim = dnums.output_batch_dimension();
const int64 output_z_dim = dnums.output_feature_dimension();
- const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+ const int64 input_z_size =
+ ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+ // The size of an input feature group.
+ const int64 input_feature_group_size = input_z_size / feature_group_count;
+
const int64 output_z_size =
ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim);
+ // The output feature dimension is a concatenation of convolution results
+ // from the different groups.
+ const int64 output_feature_group_size =
+ output_z_size / feature_group_count;
+
+ // Calculate the group index to which the current output index
+ // belongs.
+ const int64 feature_group_index =
+ out_index[output_z_dim] / output_feature_group_size;
ElementwiseT result_val = static_cast<ElementwiseT>(0);
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
@@ -1073,33 +1086,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Convolve input feature with kernel.
do {
- for (int64 iz = 0; iz < z_size; ++iz) {
- int64 rhs_iz = iz;
- // Handle grouped convolutions.
- if (feature_group_count > 1) {
- // The size of a feature group.
- int64 feature_group_size = z_size / feature_group_count;
- rhs_iz = iz % feature_group_size;
-
- // The output feature dimension is a concatenation of convolution
- // results from the different groups.
- int64 output_feature_group_size =
- output_z_size / feature_group_count;
-
- // Calculate the group index to which the current input feature
- // index belongs.
- int64 input_group_index = iz / feature_group_size;
-
- // Calculate the group index to which the current output index
- // belongs.
- int64 output_group_index =
- out_index[output_z_dim] / output_feature_group_size;
- if (input_group_index != output_group_index) {
- // If the current output index does not belong to the current
- // feature group, skip it.
- continue;
- }
- }
+ for (int64 rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) {
+ const int64 iz =
+ feature_group_index * input_feature_group_size + rhs_iz;
int64 lhs_linear_index = 0;
lhs_linear_index += out_index[output_batch_dim] *
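For intuition about the regrouped loop above, here is a small standalone sketch with assumed sample sizes (feature_group_count=2, input feature size 4, output feature size 6 — illustration values, not taken from the patch) showing which input features each output feature index now reads:

  #include <cstdint>
  #include <iostream>

  int main() {
    const int64_t feature_group_count = 2;
    const int64_t input_z_size = 4;   // size of the input feature dimension
    const int64_t output_z_size = 6;  // size of the output feature dimension
    const int64_t input_feature_group_size = input_z_size / feature_group_count;    // 2
    const int64_t output_feature_group_size = output_z_size / feature_group_count;  // 3

    for (int64_t out_z = 0; out_z < output_z_size; ++out_z) {
      // Group to which this output feature index belongs.
      const int64_t feature_group_index = out_z / output_feature_group_size;
      // The rewritten inner loop visits only the input features of the
      // matching group, instead of iterating over all of them and skipping
      // the mismatches as the deleted code did.
      for (int64_t rhs_iz = 0; rhs_iz < input_feature_group_size; ++rhs_iz) {
        const int64_t iz = feature_group_index * input_feature_group_size + rhs_iz;
        std::cout << "out_z=" << out_z << " reads input feature iz=" << iz << "\n";
      }
    }
    return 0;
  }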
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 471a12d6aa..25ae344ea5 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -451,6 +451,28 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< proto.dimensions_size();
instruction = CreateIota(proto.shape(), proto.dimensions(0));
break;
+ case HloOpcode::kDot: {
+ TF_RET_CHECK(proto.has_dot_dimension_numbers())
+ << "Dot instruction should have dot_dimension_numbers.";
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Dot instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ PrecisionConfig precision_config = proto.precision_config();
+ precision_config.mutable_operand_precision()->Resize(
+ proto.operand_ids_size(), PrecisionConfig::DEFAULT);
+ instruction = absl::make_unique<HloDotInstruction>(
+ proto.shape(), operands(0), operands(1),
+ proto.dot_dimension_numbers(), precision_config);
+ break;
+ }
+ case HloOpcode::kDomain:
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Domain instruction should have 1 operands but sees "
+ << proto.operand_ids_size();
+ instruction = absl::make_unique<HloDomainInstruction>(
+ proto.shape(), operands(0), /*operand_side_metadata=*/nullptr,
+ /*user_side_metadata=*/nullptr);
+ break;
default: {
instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
@@ -472,20 +494,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
computation_map.at(computation_id));
}
}
- if (instruction->opcode() == HloOpcode::kDot) {
- instruction->precision_config_ = proto.precision_config();
- instruction->precision_config_.mutable_operand_precision()->Resize(
- instruction->operand_count(), PrecisionConfig::DEFAULT);
- TF_RET_CHECK(proto.has_dot_dimension_numbers());
- instruction->dot_dimension_numbers_ =
- absl::make_unique<DotDimensionNumbers>(
- proto.dot_dimension_numbers());
- } else {
- TF_RET_CHECK(!proto.has_precision_config())
- << instruction->opcode() << proto.DebugString();
- TF_RET_CHECK(!proto.has_dot_dimension_numbers())
- << instruction->opcode();
- }
+ TF_RET_CHECK(!proto.has_precision_config())
+ << instruction->opcode() << proto.DebugString();
+ TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
break;
}
}
@@ -564,7 +575,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kClz:
- case HloOpcode::kDomain:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kFloor:
@@ -596,7 +606,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kAtan2:
case HloOpcode::kDivide:
case HloOpcode::kComplex:
- case HloOpcode::kDot:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
@@ -674,30 +683,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config) {
- auto instruction =
- absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
- instruction->AppendOperand(lhs);
- instruction->AppendOperand(rhs);
- instruction->dot_dimension_numbers_ =
- absl::make_unique<DotDimensionNumbers>(dimension_numbers);
- instruction->set_precision_config(precision_config);
- return instruction;
-}
-
-/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCanonicalDot(
- const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) {
- CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
- CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
-
- auto instruction =
- absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
- instruction->AppendOperand(lhs);
- instruction->AppendOperand(rhs);
- instruction->dot_dimension_numbers_ =
- absl::make_unique<DotDimensionNumbers>();
- instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1);
- instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0);
- return instruction;
+ return absl::make_unique<HloDotInstruction>(
+ shape, lhs, rhs, dimension_numbers, precision_config);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -1157,12 +1144,9 @@ bool HloInstruction::HasSideEffect() const {
const Shape& shape, HloInstruction* operand,
std::unique_ptr<DomainMetadata> operand_side_metadata,
std::unique_ptr<DomainMetadata> user_side_metadata) {
- auto instruction =
- absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape));
- instruction->operand_side_metadata_ = std::move(operand_side_metadata);
- instruction->user_side_metadata_ = std::move(user_side_metadata);
- instruction->AppendOperand(operand);
- return instruction;
+ return absl::make_unique<HloDomainInstruction>(
+ shape, operand, std::move(operand_side_metadata),
+ std::move(user_side_metadata));
}
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
@@ -1218,6 +1202,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kGather:
case HloOpcode::kScatter:
case HloOpcode::kIota:
+ case HloOpcode::kDot:
+ case HloOpcode::kDomain:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
// Unary ops.
@@ -1290,11 +1276,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateBitcastConvert(shape, new_operands[0]);
break;
- case HloOpcode::kDot:
- CHECK_EQ(new_operands.size(), 2);
- clone = CreateDot(shape, new_operands[0], new_operands[1],
- *dot_dimension_numbers_, precision_config());
- break;
case HloOpcode::kReshape:
CHECK_EQ(new_operands.size(), 1);
clone = CreateReshape(shape, new_operands[0]);
@@ -1319,12 +1300,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
true_computation(), new_operands[2],
false_computation());
break;
- case HloOpcode::kDomain:
- CHECK_EQ(new_operands.size(), 1);
- clone =
- CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(),
- user_side_metadata_->Clone());
- break;
case HloOpcode::kAfterAll:
if (new_operands.empty()) {
clone = CreateToken();
@@ -1620,11 +1595,6 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kAfterAll:
return false;
- // Check dot dimension numbers.
- case HloOpcode::kDot:
- return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
- other.dot_dimension_numbers());
-
// Remaining instructions with special values.
case HloOpcode::kCall:
return eq_computations(to_apply(), other.to_apply());
@@ -1640,10 +1610,6 @@ bool HloInstruction::IdenticalSlowPath(
return false;
}
- case HloOpcode::kDomain:
- return operand_side_metadata().Matches(other.operand_side_metadata()) &&
- user_side_metadata().Matches(other.user_side_metadata());
-
// Ops migrated to subclasses should never come to this line.
// TODO(b/80131774): Remove this switch when migration is complete.
case HloOpcode::kBatchNormTraining:
@@ -1683,6 +1649,8 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
case HloOpcode::kScatter:
+ case HloOpcode::kDot:
+ case HloOpcode::kDomain:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
@@ -2052,15 +2020,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
const HloPrintOptions& options) const {
std::vector<string> extra = ExtraAttributesToStringImpl(options);
- if (dot_dimension_numbers_ != nullptr) {
- extra.push_back(DotDimensionNumbersToString());
- }
-
- string precision_config_string = PrecisionConfigToString();
- if (!precision_config_string.empty()) {
- extra.push_back(precision_config_string);
- }
-
if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
if (opcode() == HloOpcode::kWhile) {
@@ -2146,11 +2105,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
}),
"}"));
}
- if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
- extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
- "\", entry=", user_side_metadata_->ToString(),
- ", exit=", operand_side_metadata_->ToString(), "}"));
- }
return extra;
}
@@ -2182,19 +2136,12 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_metadata() = metadata_;
proto.set_backend_config(backend_config_);
- if (opcode() == HloOpcode::kConvolution || opcode() == HloOpcode::kDot) {
- *proto.mutable_precision_config() = precision_config_;
- }
if (opcode() != HloOpcode::kFusion) {
for (const HloComputation* computation : called_computations_) {
proto.add_called_computation_ids(computation->unique_id());
}
}
- if (dot_dimension_numbers_ != nullptr) {
- *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
- }
-
if (has_sharding()) {
*proto.mutable_sharding() = sharding().ToProto();
}
@@ -2921,31 +2868,6 @@ string ConvolutionDimensionNumbersToString(
StrJoin(output_dims, ""));
}
-string HloInstruction::DotDimensionNumbersToString() const {
- std::vector<string> result;
- if (dot_dimension_numbers_ == nullptr) {
- return "";
- }
- const DotDimensionNumbers& dnums = *dot_dimension_numbers_;
- if (!dnums.lhs_batch_dimensions().empty()) {
- result.push_back(StrCat("lhs_batch_dims={",
- StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
- }
- result.push_back(StrCat("lhs_contracting_dims={",
- StrJoin(dnums.lhs_contracting_dimensions(), ","),
- "}"));
-
- if (!dnums.rhs_batch_dimensions().empty()) {
- result.push_back(StrCat("rhs_batch_dims={",
- StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
- }
- result.push_back(StrCat("rhs_contracting_dims={",
- StrJoin(dnums.rhs_contracting_dimensions(), ","),
- "}"));
-
- return StrJoin(result, ", ");
-}
-
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
static std::unordered_map<string, RandomDistribution>* map = [] {
static auto* map = new std::unordered_map<string, RandomDistribution>;
@@ -2964,27 +2886,6 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
return found->second;
}
-string HloInstruction::PrecisionConfigToString() const {
- if (absl::c_all_of(
- precision_config_.operand_precision(), [](int32 precision) {
- return static_cast<PrecisionConfig::Precision>(precision) ==
- PrecisionConfig::DEFAULT;
- })) {
- return "";
- }
- return StrCat(
- "operand_precision={",
- StrJoin(
- precision_config_.operand_precision(), ",",
- [](string* out, int32 precision) {
- CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
- StrAppend(out,
- PrecisionToString(
- static_cast<PrecisionConfig::Precision>(precision)));
- }),
- "}");
-}
-
StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
static auto* map =
@@ -3044,6 +2945,16 @@ Status HloInstruction::set_backend_config(
return ret;
}
+const PrecisionConfig& HloInstruction::precision_config() const {
+ if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
+ return convolution->precision_config();
+ }
+ if (auto* dot = DynCast<HloDotInstruction>(this)) {
+ return dot->precision_config();
+ }
+ LOG(FATAL) << "Unimplemented method.";
+}
+
HloModule* HloInstruction::GetModule() const {
if (parent_) {
return parent_->parent();
@@ -3348,4 +3259,15 @@ const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
}
+const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
+ return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
+}
+
+const DomainMetadata& HloInstruction::operand_side_metadata() const {
+ return Cast<HloDomainInstruction>(this)->operand_side_metadata();
+}
+
+const DomainMetadata& HloInstruction::user_side_metadata() const {
+ return Cast<HloDomainInstruction>(this)->user_side_metadata();
+}
} // namespace xla
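As a rough usage sketch of the migrated kDot API, mirroring the test style earlier in this patch (shapes and values are assumed for illustration, not taken from the change itself):

  auto lhs = HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {2, 3}), "lhs");
  auto rhs = HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {3, 4}), "rhs");

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);

  PrecisionConfig precision;
  precision.mutable_operand_precision()->Resize(2, PrecisionConfig::DEFAULT);

  // CreateDot now returns an HloDotInstruction; the dimension numbers live on
  // the subclass and are reached through the delegating accessor on
  // HloInstruction, which uses Cast<HloDotInstruction> internally.
  auto dot = HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 4}),
                                       lhs.get(), rhs.get(), dnums, precision);
  const DotDimensionNumbers& stored = dot->dot_dimension_numbers();

Calling dot_dimension_numbers() on a non-dot instruction now CHECK-fails inside the Cast, where it previously CHECK-failed on the null dot_dimension_numbers_ pointer.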
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 691f8155f9..5581c17c2d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -421,12 +421,6 @@ class HloInstruction {
const DotDimensionNumbers& dimension_numbers,
const PrecisionConfig& precision_config);
- // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
- // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS
- // and the RHS must be of rank 2.
- static std::unique_ptr<HloInstruction> CreateCanonicalDot(
- const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
-
// Creates a reduce-precision op, where operand is the data to reduce in
// precision, and exponent_bits and mantissa_bits describe the precision to
// reduce it to.
@@ -866,11 +860,6 @@ class HloInstruction {
return false;
}
- if (!absl::c_equal(precision_config_.operand_precision(),
- other.precision_config_.operand_precision())) {
- return false;
- }
-
return IdenticalSlowPath(other, eq_computations);
}
@@ -1085,15 +1074,6 @@ class HloInstruction {
return other->has_sharding() ? sharding() == other->sharding() : false;
}
- // Retrieves the operand side metadata of a kDomain instruction.
- const DomainMetadata& operand_side_metadata() const {
- return *operand_side_metadata_;
- }
- // Retrieves the user side metadata of a kDomain instruction.
- const DomainMetadata& user_side_metadata() const {
- return *user_side_metadata_;
- }
-
// When creating a new instruction which either replaces, or shifts up (kCopy
// insertion case), another instruction, we need to make sure the certain
// properties of the new instruction are copied into the derived one. As of
@@ -1101,18 +1081,6 @@ class HloInstruction {
// instruction.
void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
- // Returns data on the dimension numbers used for a dot operation.
- const DotDimensionNumbers& dot_dimension_numbers() const {
- CHECK(dot_dimension_numbers_ != nullptr);
- return *dot_dimension_numbers_;
- }
-
- // Returns the dump string of the dot dimension numbers.
- string DotDimensionNumbersToString() const;
-
- // Returns the dump string of the precision configuration.
- string PrecisionConfigToString() const;
-
// Clones the HLO instruction. The clone will have the same opcode, shape, and
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
@@ -1262,10 +1230,8 @@ class HloInstruction {
// information. Transformations to other HLOs will not preserve this
// information but it is presumed that the alternate lowering is strictly
// superior.
- const PrecisionConfig& precision_config() const { return precision_config_; }
- void set_precision_config(const PrecisionConfig& precision_config) {
- precision_config_ = precision_config;
- }
+ // Precondition: opcode must be kConvolution or kDot.
+ const PrecisionConfig& precision_config() const;
// Sets the debug metadata for this instruction.
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
@@ -1508,6 +1474,15 @@ class HloInstruction {
// Delegates to HloScatterInstruction::scatter_dimension_numbers().
const ScatterDimensionNumbers& scatter_dimension_numbers() const;
+ // Delegates to HloDotInstruction::dot_dimension_numbers().
+ const DotDimensionNumbers& dot_dimension_numbers() const;
+
+ // Delegates to HloDomainInstruction::operand_side_metadata().
+ const DomainMetadata& operand_side_metadata() const;
+
+ // Delegates to HloDomainInstruction::user_side_metadata().
+ const DomainMetadata& user_side_metadata() const;
+
// Old methods kept for smooth subclassing transition END.
protected:
@@ -1647,22 +1622,12 @@ class HloInstruction {
// Result shape of this instruction.
Shape shape_;
- // Describes the dimension numbers used for a dot.
- std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
-
- // Used to tag kCopy instructions that are eligible for copy elision.
- bool copy_elision_allowed_ = true;
-
// The sharding, if one exists.
// Uses std::shared_ptr to allow reuse of the same sharding object between
// HloInstructions and other components as HloSharding can be very large for
// many element tuples.
std::shared_ptr<const HloSharding> sharding_;
- // Fields used by the kDomain instruction.
- std::unique_ptr<DomainMetadata> operand_side_metadata_;
- std::unique_ptr<DomainMetadata> user_side_metadata_;
-
// Computations called by this instruction.
std::vector<HloComputation*> called_computations_;
@@ -1676,10 +1641,6 @@ class HloInstruction {
// HLO. See the documentation on backend_config().
string backend_config_;
- // Information used to communicate to the implementation about the algorithm
- // used to produce results. See the documentation on precision_config().
- PrecisionConfig precision_config_;
-
// String identifier for instruction.
string name_;
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index ad87aa1123..fb7345a2ad 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -47,6 +47,27 @@ bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
return instruction->IsElementwiseOnOperand(operand_index);
});
}
+
+string PrecisionConfigToString(const PrecisionConfig& precision_config) {
+ if (absl::c_all_of(precision_config.operand_precision(), [](int32 precision) {
+ return static_cast<PrecisionConfig::Precision>(precision) ==
+ PrecisionConfig::DEFAULT;
+ })) {
+ return "";
+ }
+
+ return StrCat(
+ "operand_precision={",
+ StrJoin(
+ precision_config.operand_precision(), ",",
+ [](string* out, int32 precision) {
+ CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision;
+ StrAppend(out,
+ PrecisionToString(
+ static_cast<PrecisionConfig::Precision>(precision)));
+ }),
+ "}");
+}
} // namespace
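As an illustration of the helper's behavior (assumed values, not asserted anywhere in this patch): with all operand precisions at DEFAULT it returns the empty string, so the attribute is omitted from dumps; otherwise it prints the lowercase precision names:

  PrecisionConfig config;
  config.add_operand_precision(PrecisionConfig::HIGHEST);
  config.add_operand_precision(PrecisionConfig::HIGHEST);
  // PrecisionConfigToString(config) would yield something like
  // "operand_precision={highest,highest}".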
HloBatchNormInstruction::HloBatchNormInstruction(
@@ -1634,7 +1655,8 @@ HloConvolutionInstruction::HloConvolutionInstruction(
: HloInstruction(HloOpcode::kConvolution, shape),
feature_group_count_(feature_group_count),
window_(window),
- convolution_dimension_numbers_(dimension_numbers) {
+ convolution_dimension_numbers_(dimension_numbers),
+ precision_config_(precision_config) {
if (window_util::HasBaseDilation(window)) {
SetAndSanitizeName(StrCat(name(), "-base-dilated"));
}
@@ -1643,7 +1665,6 @@ HloConvolutionInstruction::HloConvolutionInstruction(
}
AppendOperand(lhs);
AppendOperand(rhs);
- set_precision_config(precision_config);
}
string HloConvolutionInstruction::ToCategory() const {
@@ -1663,6 +1684,7 @@ HloInstructionProto HloConvolutionInstruction::ToProto() const {
*proto.mutable_convolution_dimension_numbers() =
convolution_dimension_numbers_;
proto.set_feature_group_count(feature_group_count_);
+ *proto.mutable_precision_config() = precision_config_;
return proto;
}
@@ -1677,6 +1699,12 @@ std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
if (feature_group_count_ != 1) {
extra.push_back(StrCat("feature_group_count=", feature_group_count_));
}
+
+ string precision_config_string = PrecisionConfigToString(precision_config_);
+ if (!precision_config_string.empty()) {
+ extra.push_back(precision_config_string);
+ }
+
return extra;
}
@@ -1692,7 +1720,9 @@ bool HloConvolutionInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
protobuf_util::ProtobufEquals(
convolution_dimension_numbers(),
- casted_other.convolution_dimension_numbers());
+ casted_other.convolution_dimension_numbers()) &&
+ protobuf_util::ProtobufEquals(precision_config(),
+ casted_other.precision_config());
}
std::unique_ptr<HloInstruction>
@@ -1702,7 +1732,7 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
CHECK_EQ(new_operands.size(), 2);
return absl::make_unique<HloConvolutionInstruction>(
shape, new_operands[0], new_operands[1], feature_group_count_, window(),
- convolution_dimension_numbers_, precision_config());
+ convolution_dimension_numbers_, precision_config_);
}
HloReduceWindowInstruction::HloReduceWindowInstruction(
@@ -2161,4 +2191,113 @@ std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl(
return absl::make_unique<HloIotaInstruction>(shape, iota_dimension());
}
+HloDotInstruction::HloDotInstruction(
+ const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config)
+ : HloInstruction(HloOpcode::kDot, shape),
+ dot_dimension_numbers_(dimension_numbers),
+ precision_config_(precision_config) {
+ AppendOperand(lhs);
+ AppendOperand(rhs);
+}
+
+HloInstructionProto HloDotInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_dot_dimension_numbers() = dot_dimension_numbers_;
+ *proto.mutable_precision_config() = precision_config_;
+ return proto;
+}
+
+std::vector<string> HloDotInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ std::vector<string> extra = {DotDimensionNumbersToString()};
+
+ string precision_config_string = PrecisionConfigToString(precision_config_);
+ if (!precision_config_string.empty()) {
+ extra.push_back(precision_config_string);
+ }
+ return extra;
+}
+
+bool HloDotInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloDotInstruction&>(other);
+ return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
+ casted_other.dot_dimension_numbers()) &&
+ protobuf_util::ProtobufEquals(precision_config(),
+ casted_other.precision_config());
+}
+
+std::unique_ptr<HloInstruction> HloDotInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return absl::make_unique<HloDotInstruction>(
+ shape, new_operands[0], new_operands[1], dot_dimension_numbers_,
+ precision_config_);
+}
+
+string HloDotInstruction::DotDimensionNumbersToString() const {
+ std::vector<string> result;
+ const DotDimensionNumbers& dnums = dot_dimension_numbers_;
+ if (!dnums.lhs_batch_dimensions().empty()) {
+ result.push_back(StrCat("lhs_batch_dims={",
+ StrJoin(dnums.lhs_batch_dimensions(), ","), "}"));
+ }
+ result.push_back(StrCat("lhs_contracting_dims={",
+ StrJoin(dnums.lhs_contracting_dimensions(), ","),
+ "}"));
+
+ if (!dnums.rhs_batch_dimensions().empty()) {
+ result.push_back(StrCat("rhs_batch_dims={",
+ StrJoin(dnums.rhs_batch_dimensions(), ","), "}"));
+ }
+ result.push_back(StrCat("rhs_contracting_dims={",
+ StrJoin(dnums.rhs_contracting_dimensions(), ","),
+ "}"));
+
+ return StrJoin(result, ", ");
+}
+
+HloDomainInstruction::HloDomainInstruction(
+ const Shape& shape, HloInstruction* operand,
+ std::unique_ptr<DomainMetadata> operand_side_metadata,
+ std::unique_ptr<DomainMetadata> user_side_metadata)
+ : HloInstruction(HloOpcode::kDomain, shape),
+ operand_side_metadata_(std::move(operand_side_metadata)),
+ user_side_metadata_(std::move(user_side_metadata)) {
+ AppendOperand(operand);
+}
+
+std::vector<string> HloDomainInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
+ return {StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
+ "\", entry=", user_side_metadata_->ToString(),
+ ", exit=", operand_side_metadata_->ToString(), "}")};
+ }
+ return {};
+}
+
+bool HloDomainInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloDomainInstruction&>(other);
+ return operand_side_metadata().Matches(
+ casted_other.operand_side_metadata()) &&
+ user_side_metadata().Matches(casted_other.user_side_metadata());
+}
+
+std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return absl::make_unique<HloDomainInstruction>(
+ shape, new_operands[0], operand_side_metadata_->Clone(),
+ user_side_metadata_->Clone());
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index e1215a7566..c3a7801164 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -957,6 +957,16 @@ class HloConvolutionInstruction : public HloInstruction {
// The number of feature groups. Must be a divisor of the input feature
// dimension and output feature dimension.
int64 feature_group_count() const { return feature_group_count_; }
+
+ // Returns the information used to tell the implementation about what sort
+ // of precision is requested. The meaning of the field is backend
+ // specific. At the moment, it is only supported for kConvolution and kDot.
+ // Transformations on one kDot or kConvolution to another will preserve this
+ // information. Transformations to other HLOs will not preserve this
+ // information but it is presumed that the alternate lowering is strictly
+ // superior.
+ const PrecisionConfig& precision_config() const { return precision_config_; }
+
string ToCategory() const override;
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
@@ -979,6 +989,9 @@ class HloConvolutionInstruction : public HloInstruction {
Window window_;
// Describes the dimension numbers used for a convolution.
ConvolutionDimensionNumbers convolution_dimension_numbers_;
+ // Information used to communicate to the implementation about the algorithm
+ // used to produce results. See the documentation on precision_config().
+ PrecisionConfig precision_config_;
};
class HloReduceWindowInstruction : public HloInstruction {
@@ -1271,6 +1284,85 @@ class HloIotaInstruction : public HloInstruction {
const int64 iota_dimension_;
};
+class HloDotInstruction : public HloInstruction {
+ public:
+ // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
+ // dimensions specified in 'dimension_numbers'.
+ explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
+ HloInstruction* rhs,
+ const DotDimensionNumbers& dimension_numbers,
+ const PrecisionConfig& precision_config);
+
+ // Returns data on the dimension numbers used for a dot operation.
+ const DotDimensionNumbers& dot_dimension_numbers() const {
+ return dot_dimension_numbers_;
+ }
+
+ // Returns the information used to tell the implementation about what sort
+ // of precision is requested. The meaning of the field is backend
+ // specific. At the moment, it is only supported for kConvolution and kDot.
+ // Transformations on one kDot or kConvolution to another will preserve this
+ // information. Transformations to other HLOs will not preserve this
+ // information but it is presumed that the alternate lowering is strictly
+ // superior.
+ const PrecisionConfig& precision_config() const { return precision_config_; }
+
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const override;
+ // Returns the dump string of the dot dimension numbers.
+ string DotDimensionNumbersToString() const;
+
+ // Describes the dimension numbers used for a dot.
+ DotDimensionNumbers dot_dimension_numbers_;
+
+ // Information used to communicate to the implementation about the algorithm
+ // used to produce results. See the documentation on precision_config().
+ PrecisionConfig precision_config_;
+};
+
+class HloDomainInstruction : public HloInstruction {
+ public:
+ explicit HloDomainInstruction(
+ const Shape& shape, HloInstruction* operand,
+ std::unique_ptr<DomainMetadata> operand_side_metadata,
+ std::unique_ptr<DomainMetadata> user_side_metadata);
+
+ // Retrieves the operand side metadata of a kDomain instruction.
+ const DomainMetadata& operand_side_metadata() const {
+ return *operand_side_metadata_;
+ }
+ // Retrieves the user side metadata of a kDomain instruction.
+ const DomainMetadata& user_side_metadata() const {
+ return *user_side_metadata_;
+ }
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+ HloCloneContext* context) const override;
+
+ std::unique_ptr<DomainMetadata> operand_side_metadata_;
+ std::unique_ptr<DomainMetadata> user_side_metadata_;
+};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 3a1bc4e328..cfe906d9c5 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -50,6 +51,13 @@ StatusOr<HloInstruction*> HloModule::LaunderConstInstructionFromModule(
return const_cast<HloInstruction*>(hlo);
}
+Status HloModule::set_schedule(HloSchedule schedule) {
+ TF_RET_CHECK(schedule.module() == this);
+ TF_RETURN_IF_ERROR(schedule.Verify());
+ schedule_ = std::move(schedule);
+ return Status::OK();
+}
+
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_names) {
@@ -198,12 +206,23 @@ void HloModule::ReplaceComputations(
string HloModule::ToString(const HloPrintOptions& options) const {
std::ostringstream s;
- s << "HloModule " << name() << "\n\n";
+ s << "HloModule " << name();
+ if (has_schedule()) {
+ TF_CHECK_OK(schedule().Verify());
+ s << ", is_scheduled=true";
+ }
+ s << "\n\n";
for (const HloComputation* computation : MakeComputationPostOrder()) {
if (computation == entry_computation()) {
s << "ENTRY ";
}
- s << computation->ToString(options) << "\n\n";
+ if (has_schedule() && schedule().is_computation_scheduled(computation)) {
+ s << computation->ToString(
+ options, schedule().sequence(computation).instructions())
+ << "\n\n";
+ } else {
+ s << computation->ToString(options) << "\n\n";
+ }
}
return s.str();
}
@@ -221,6 +240,9 @@ HloModuleProto HloModule::ToProto() const {
}
proto.add_computations()->Swap(&computation_proto);
}
+ if (has_schedule()) {
+ *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
+ }
return proto;
}
@@ -309,6 +331,13 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
}
}
+ if (proto.has_schedule()) {
+ TF_ASSIGN_OR_RETURN(
+ HloSchedule schedule,
+ HloSchedule::CreateFromProto(module.get(), proto.schedule()));
+ TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
+ }
+
return std::move(module);
}
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 3c3371426b..26fd1b2438 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -25,6 +25,7 @@ limitations under the License.
#include <vector>
#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -32,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
@@ -235,6 +237,19 @@ class HloModule {
StatusOr<HloInstruction*> LaunderConstInstructionFromModule(
const HloInstruction* hlo);
+ // Sets the schedule of the module to the given schedule.
+ Status set_schedule(HloSchedule schedule);
+
+ // Clears the schedule of the module.
+ void clear_schedule() { schedule_.reset(); }
+
+ // Returns true if the module has a schedule set.
+ bool has_schedule() const { return schedule_.has_value(); }
+
+ // Returns the schedule of the module. CHECK fails if no schedule is set.
+ const HloSchedule& schedule() const { return *schedule_; }
+ HloSchedule& schedule() { return *schedule_; }
+
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
@@ -262,6 +277,11 @@ class HloModule {
static std::atomic<int> next_unique_module_id_;
// A unique id to label modules with.
int unique_id_;
+
+ // The HloSchedule of the module. The schedule, if it exists, contains a
+ // sequential order of instructions for each non-fusion computation in the
+ // module.
+ absl::optional<HloSchedule> schedule_;
};
} // namespace xla
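A minimal sketch of the new schedule API on HloModule, reusing the same instruction-order construction as the parser's ScheduleFromInstructionOrder helper later in this diff ('module' is an assumed HloModule* whose computations are already built):

  HloSchedule schedule(module);
  for (const HloComputation* computation : module->computations()) {
    if (!computation->IsFusionComputation()) {
      for (const HloInstruction* instruction : computation->instructions()) {
        schedule.GetOrCreateSequence(computation).push_back(instruction);
      }
    }
  }
  // set_schedule verifies the schedule against this module before storing it.
  TF_CHECK_OK(module->set_schedule(std::move(schedule)));
  CHECK(module->has_schedule());
  TF_CHECK_OK(module->schedule().Verify());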
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index 3f1e1cc73e..68c18836eb 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -106,9 +106,6 @@ class HloModuleConfig {
absl::optional<ComputationLayout> entry_computation_layout_;
- // Whether this is a 'host module'.
- bool is_host_module_ = false;
-
// Module/graph-level seed handle.
uint64 seed_ = 0;
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 4bc1bacd7d..400bd4d947 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -19,9 +19,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/test.h"
@@ -30,6 +33,8 @@ namespace xla {
namespace {
+namespace op = ::xla::testing::opcode_matchers;
+
class HloModuleTest : public HloTestBase {
protected:
HloModuleTest() {}
@@ -194,6 +199,60 @@ TEST_F(HloModuleTest, UniqueModuleId) {
EXPECT_NE(module_a->unique_id(), module_b->unique_id());
}
+TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) {
+ const string text = R"(
+HloModule axpy_module
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+ ASSERT_FALSE(module_copy->has_schedule());
+}
+
+TEST_F(HloModuleTest, ProtoSerializationWithSchedule) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+ ASSERT_TRUE(module_copy->has_schedule());
+ TF_ASSERT_OK(module_copy->schedule().Verify());
+ EXPECT_EQ(module_copy->schedule().sequences().size(), 1);
+ ASSERT_TRUE(module_copy->schedule().is_computation_scheduled(
+ module_copy->entry_computation()));
+ EXPECT_THAT(
+ module_copy->schedule()
+ .sequence(module_copy->entry_computation())
+ .instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
+ op::Broadcast(), op::Multiply(), op::Add()));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 2105f7a349..f1dc08bafa 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -293,23 +293,6 @@ bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
!LiveRangeStrictlyBefore(b, a, dataflow);
}
-HloOrderingProto HloOrdering::ToProto() const {
- HloOrderingProto proto;
- for (const auto& computation : module_->computations()) {
- const std::vector<const HloInstruction*>* sequence =
- SequentialOrder(*computation);
- if (sequence != nullptr) {
- HloOrderingProto::SequentialComputation* proto_computation =
- proto.add_sequential_computations();
- proto_computation->set_computation_name(computation->name());
- for (const HloInstruction* instruction : *sequence) {
- *proto_computation->add_instruction_names() = instruction->name();
- }
- }
- }
- return proto;
-}
-
PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
: HloOrdering(module) {}
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index b21071c4b2..b0361c3f02 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -72,10 +72,6 @@ class HloOrdering {
virtual string ToString() const = 0;
- // Returns the serialized representation of this ordering.
- // Only sequential computation orders are represented.
- HloOrderingProto ToProto() const;
-
protected:
// Returns true if instruction 'a' executes before instruction 'b'.
// Precondition: 'a' and 'b' are in the same computation.
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 0f26ed4235..c54360b063 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
@@ -44,6 +45,20 @@ using absl::StrJoin;
const double kF16max = 65504;
+// Creates and returns a schedule based on the order of the instructions in
+// the HloComputation::instructions() vectors in the module.
+HloSchedule ScheduleFromInstructionOrder(const HloModule* module) {
+ HloSchedule schedule(module);
+ for (const HloComputation* computation : module->computations()) {
+ if (!computation->IsFusionComputation()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ schedule.GetOrCreateSequence(computation).push_back(instruction);
+ }
+ }
+ }
+ return schedule;
+}
+
// Parser for the HloModule::ToString() format text.
class HloParser {
public:
@@ -366,9 +381,25 @@ bool HloParser::ParseHloModule() {
return false;
}
+ absl::optional<bool> is_scheduled;
+ std::unordered_map<string, AttrConfig> attrs;
+ attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled};
+ if (!ParseAttributes(attrs)) {
+ return false;
+ }
+
module_ = absl::make_unique<HloModule>(name, config_);
- return ParseComputations();
+ if (!ParseComputations()) {
+ return false;
+ }
+
+ if (is_scheduled.has_value() && *is_scheduled) {
+ TF_CHECK_OK(
+ module_->set_schedule(ScheduleFromInstructionOrder(module_.get())));
+ }
+
+ return true;
}
// computations ::= (computation)+
@@ -1248,11 +1279,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<string> custom_call_target;
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
+ optional<int64> feature_group_count;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/false,
AttrTy::kConvolutionDimensionNumbers, &dnums};
+ attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
+ &feature_group_count};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
@@ -1264,6 +1298,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (dnums.has_value()) {
instruction->set_convolution_dimension_numbers(*dnums);
}
+ if (feature_group_count.has_value()) {
+ instruction->set_feature_group_count(*feature_group_count);
+ }
break;
}
case HloOpcode::kDot: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 0dfc0a4d1c..cca50fab54 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1123,18 +1123,31 @@ ENTRY Iota {
)"
},
-// custom-call with window and dim_labels
+// custom-call with window, dim_labels and feature_group_count
{
-"CustomCallWithWindowAndDimLabels",
-R"(HloModule CustomCallWithWindowAndDimLabels
+"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount",
+R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount
ENTRY Computation {
- ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target"
+ ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target"
}
)"
+ },
+// is_scheduled=true attribute
+{
+"ScheduledModule",
+R"(HloModule scheduled_module, is_scheduled=true
+
+ENTRY Sort {
+ keys = f32[1024]{0} parameter(0)
+ values = s32[1024]{0} parameter(1)
+ ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}
}
- });
+
+)"
+}
+});
// clang-format on
}
@@ -1790,5 +1803,94 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
EXPECT_EQ(convolution->feature_group_count(), 1);
}
+TEST_F(HloParserTest, IsScheduledIsFalse) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=false
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+}
+
+TEST_F(HloParserTest, IsScheduledNotPresent) {
+ const string text = R"(
+HloModule axpy_module
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_FALSE(module->has_schedule());
+}
+
+TEST_F(HloParserTest, IsScheduledIsTrue) {
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %x = f32[2,4]{1,0} parameter(1)
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ %y = f32[2,4]{1,0} parameter(2)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+ EXPECT_EQ(module->schedule().sequences().size(), 1);
+ ASSERT_TRUE(
+ module->schedule().is_computation_scheduled(module->entry_computation()));
+ EXPECT_THAT(
+ module->schedule().sequence(module->entry_computation()).instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(),
+ op::Multiply(), op::Parameter(), op::Add()));
+}
+
+TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
+  // As above, but with a different schedule order.
+ const string text = R"(
+HloModule axpy_module, is_scheduled=true
+
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+ %alpha = f32[] parameter(0)
+ %x = f32[2,4]{1,0} parameter(1)
+ %y = f32[2,4]{1,0} parameter(2)
+ %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
+ %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
+ ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ ASSERT_TRUE(module->has_schedule());
+ TF_ASSERT_OK(module->schedule().Verify());
+ EXPECT_EQ(module->schedule().sequences().size(), 1);
+ ASSERT_TRUE(
+ module->schedule().is_computation_scheduled(module->entry_computation()));
+ EXPECT_THAT(
+ module->schedule().sequence(module->entry_computation()).instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
+ op::Broadcast(), op::Multiply(), op::Add()));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc
index 3460679558..b9c0b0c4ee 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc
@@ -23,11 +23,8 @@ namespace xla {
HloProto MakeHloProto(const HloModule& module,
const BufferAssignment& assignment) {
- HloOrderingProto proto_ordering =
- assignment.liveness().hlo_ordering().ToProto();
BufferAssignmentProto proto_assignment = assignment.ToProto();
HloProto proto = MakeHloProto(module);
- proto.mutable_hlo_ordering()->Swap(&proto_ordering);
proto.mutable_buffer_assignment()->Swap(&proto_assignment);
return proto;
}
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc
index a65b33bf40..3fc5dbeb02 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/hlo_schedule.cc
@@ -21,12 +21,64 @@ limitations under the License.
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
namespace xla {
+/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
+ const HloModule* module, const HloScheduleProto& proto) {
+ tensorflow::gtl::FlatMap<int64, const HloComputation*> id_to_computation;
+ for (const HloComputation* computation : module->computations()) {
+ id_to_computation[computation->unique_id()] = computation;
+ }
+
+ HloSchedule schedule(module);
+ for (const auto& id_sequence : proto.sequences()) {
+ int64 computation_id = id_sequence.first;
+
+ auto comp_it = id_to_computation.find(computation_id);
+ TF_RET_CHECK(comp_it != id_to_computation.end())
+ << "No computation exists in HLO module with id " << computation_id;
+ const HloComputation* computation = comp_it->second;
+
+ tensorflow::gtl::FlatMap<int64, const HloInstruction*> id_to_instruction;
+ for (const HloInstruction* instruction : computation->instructions()) {
+ id_to_instruction[instruction->unique_id()] = instruction;
+ }
+
+ HloInstructionSequence& sequence =
+ schedule.GetOrCreateSequence(computation);
+ for (const int64 instruction_id : id_sequence.second.instruction_ids()) {
+ auto instr_it = id_to_instruction.find(instruction_id);
+ TF_RET_CHECK(instr_it != id_to_instruction.end())
+ << "No instruction exists in HLO computation " << computation->name()
+ << " with id " << instruction_id;
+ sequence.push_back(instr_it->second);
+ }
+ }
+ TF_RETURN_IF_ERROR(schedule.Verify());
+ return std::move(schedule);
+}
+
+StatusOr<HloScheduleProto> HloSchedule::ToProto() const {
+ TF_RETURN_IF_ERROR(Verify());
+ HloScheduleProto proto;
+ for (const auto& id_sequence : sequences_) {
+ int64 computation_id = id_sequence.first;
+ const HloInstructionSequence& sequence = id_sequence.second;
+ HloScheduleProto::InstructionSequence& proto_sequence =
+ (*proto.mutable_sequences())[computation_id];
+ proto_sequence.mutable_instruction_ids()->Reserve(sequence.size());
+ for (const int64 id : sequence.ids()) {
+ proto_sequence.add_instruction_ids(id);
+ }
+ }
+ return std::move(proto);
+}
+
void HloSchedule::set_sequence(
const HloComputation* computation,
absl::Span<const HloInstruction* const> sequence) {
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h
index 21c6988638..270fe6039f 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule.h
+++ b/tensorflow/compiler/xla/service/hlo_schedule.h
@@ -21,18 +21,20 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/status.h"
namespace xla {
+class HloModule;
+
// Class representing a sequence of HLO instructions such as the sequential
// execution order of an HLO computation.
class HloInstructionSequence {
public:
HloInstructionSequence() = default;
- HloInstructionSequence(absl::Span<const HloInstruction* const> instructions) {
+ explicit HloInstructionSequence(
+ absl::Span<const HloInstruction* const> instructions) {
for (const HloInstruction* instruction : instructions) {
push_back(instruction);
}
@@ -77,7 +79,12 @@ class HloInstructionSequence {
// non-fusion computation in the HLO module.
class HloSchedule {
public:
- HloSchedule(const HloModule* module) : module_(module) {}
+ explicit HloSchedule(const HloModule* module) : module_(module) {}
+
+  // (De)Serialize an HloSchedule to/from an HloScheduleProto.
+ static StatusOr<HloSchedule> CreateFromProto(const HloModule* module,
+ const HloScheduleProto& proto);
+ StatusOr<HloScheduleProto> ToProto() const;
// Returns a reference to the sequence for the given computation.
const HloInstructionSequence& sequence(
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 6e17711f57..082bf8bffe 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -855,8 +855,7 @@ void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
? instruction.sharding().GetSubSharding(instruction.shape(), index)
: instruction.sharding();
// We propagate the sharding to the copied instruction only if it is a
- // special sharding, like tiled ones, or special devices like the
- // HostCompute module.
+ // special sharding, like tiled ones.
// Otherwise it is preferable to leave the new instruction without a device,
// and let the automatic device placer choose the best location.
auto device = sharding.UniqueDevice();
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
index 7d49b8d6c2..a60643bc75 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc
@@ -75,6 +75,16 @@ void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
}
}
+void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
+ llvm::IRBuilder<>* b, llvm::Module* module) {
+ std::vector<llvm::Value*> buffer_ptrs;
+ buffer_ptrs.reserve(buffers.size());
+ absl::c_transform(
+ buffers, std::back_inserter(buffer_ptrs),
+ [](const llvm_ir::IrArray& buffer) { return buffer.GetBasePointer(); });
+ llvm_ir::EmitTuple(tuple, buffer_ptrs, b, module);
+}
+
llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
int alignment, llvm::Value* operand,
llvm::IRBuilder<>* b, llvm::Module* module) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
index 887fb61371..94340b91d8 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h
@@ -68,6 +68,11 @@ void EmitTupleSelect(const IrArray& select, const IrArray& pred,
void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands,
llvm::IRBuilder<>* b, llvm::Module* module);
+// Similar to EmitTuple above, except that the output buffers are provided in
+// the form of IrArrays.
+void EmitTuple(const IrArray& tuple, absl::Span<const IrArray> buffers,
+ llvm::IRBuilder<>* b, llvm::Module* module);
+
// A tuple is an array of pointers, one for each operand. Each pointer points to
// the output buffer of its corresponding operand. A GetTupleElement instruction
// forwards the pointer to underlying tuple element buffer at the given index.
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 36b8fb2644..d0bda45cf8 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -75,7 +75,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib",
- "//tensorflow/core:stream_executor_headers_lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 8c62adea23..57f7fed61f 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -866,10 +866,7 @@ INSTANTIATE_TEST_CASE_P(
BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}},
BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}}));
-// TODO(b/64093391) Disabled on GPU due to an assertion failure when running
-// IrEmitterUnnested::EmitInitializer() for the Reduce operator. Failed on
-// 2017-07-26.
-XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) {
+XLA_TEST_F(ReduceTest, OperationOnConstantAsInitValue) {
XlaBuilder builder(TestName());
XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder);
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index c20a7c8fe4..3ae31191a0 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -417,4 +417,18 @@ Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
.status();
}
+std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2);
+ CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2);
+ PrecisionConfig precision_config;
+ precision_config.mutable_operand_precision()->Resize(
+ 2, PrecisionConfig::DEFAULT);
+ DotDimensionNumbers dot_dimension_numbers;
+ dot_dimension_numbers.add_lhs_contracting_dimensions(1);
+ dot_dimension_numbers.add_rhs_contracting_dimensions(0);
+ return absl::make_unique<HloDotInstruction>(
+ shape, lhs, rhs, dot_dimension_numbers, precision_config);
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index 7790737c09..a260271b1b 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -24,10 +24,10 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/stream_executor/platform.h"
namespace xla {
@@ -98,6 +98,12 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
Status VerifyHloModule(HloModule* const module, bool layout_sensitive,
bool allow_mixed_precision);
+// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 of
+// the LHS with dimension 0 of the RHS with no batch dimensions.
+// Both the LHS and the RHS must be of rank 2.
+std::unique_ptr<HloDotInstruction> CreateCanonicalDot(const Shape& shape,
+ HloInstruction* lhs,
+ HloInstruction* rhs);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
index 23ce1d235b..0c3ec5934e 100644
--- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
+++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
@@ -67,8 +67,8 @@ int main(int argc, char** argv) {
floats.push_back(value);
}
- absl::string_view content(absl::bit_cast<const char*>(floats.data()),
- floats.size() * sizeof(float));
+ tensorflow::StringPiece content(absl::bit_cast<const char*>(floats.data()),
+ floats.size() * sizeof(float));
TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
output_file, content));
return 0;
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py
index 16eb1f0e3f..41c3424fa3 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions.py
+++ b/tensorflow/contrib/autograph/converters/logical_expressions.py
@@ -57,8 +57,8 @@ class LogicalExpressionTransformer(converter.Base):
gast.NotEq: 'tf.not_equal',
gast.Or: 'tf.logical_or',
gast.USub: 'tf.negative',
- gast.Is: 'autograph_utils.dynamic_is',
- gast.IsNot: 'autograph_utils.dynamic_is_not'
+ gast.Is: 'ag__.utils.dynamic_is',
+ gast.IsNot: 'ag__.utils.dynamic_is_not'
}
def _expect_simple_symbol(self, operand):
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py
index 8f9eee7081..409a73afba 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py
+++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py
@@ -47,6 +47,15 @@ class GradientsFunctionTest(converter_testing.TestCase):
with self.cached_session() as sess:
self.assertTrue(sess.run(result.test_fn(True, False, True)))
+ def test_ag_utils_lookup(self):
+ def test_fn(a, b):
+ return a is b or a is not b
+
+    with self.converted(test_fn, logical_expressions, {},
+                        math_ops.logical_or) as result:
+ with self.cached_session() as sess:
+ self.assertTrue(sess.run(result.test_fn(True, False)))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py
index 803fde9089..a4c6fed265 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/contrib/autograph/impl/api_test.py
@@ -38,9 +38,6 @@ class ApiTest(test.TestCase):
def setUp(self):
config.COMPILED_IMPORT_STATEMENTS = (
'from __future__ import print_function',
- 'from tensorflow.contrib.autograph import utils'
- ' as autograph_utils',
- 'tf = autograph_utils.fake_tf()',
)
def test_decorator_recurses(self):
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
index e42f679cfe..d77c15915b 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
@@ -394,10 +394,16 @@ class AnfTransformer(transformer.Base):
# just recur.
def visit_List(self, node):
- return self._visit_strict_expression(node)
+ node = self.generic_visit(node)
+ if not isinstance(node.ctx, gast.Store):
+ self._ensure_fields_trivial(node)
+ return node
def visit_Tuple(self, node):
- return self._visit_strict_expression(node)
+ node = self.generic_visit(node)
+ if not isinstance(node.ctx, gast.Store):
+ self._ensure_fields_trivial(node)
+ return node
def transform(node, entity_info, gensym_source=None):
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
index 951974820c..1ffd4bbe55 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
@@ -165,6 +165,46 @@ class AnfTransformerTest(test.TestCase):
self.assert_body_anfs_as_expected(expected_result, test_function)
+ def test_nested_multi_value_assign(self):
+
+ def test_function(a, b, c):
+ x, y = a, a + b
+ (z, y), x = (c, y + b), x + a
+ return z, (y, x)
+
+ def expected_result(a, b, c):
+ tmp_1001 = a + b
+ x, y = a, tmp_1001
+ tmp_1002 = y + b
+ tmp_1003 = (c, tmp_1002)
+ tmp_1004 = x + a
+ (z, y), x = tmp_1003, tmp_1004
+ tmp_1005 = y, x
+ tmp_1006 = z, tmp_1005
+ return tmp_1006
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_deeply_nested_multi_value_assign(self):
+
+ def test_function(a):
+ [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
+ return [([(b, c), [d, e]], (f, g)), [(h, i, j), k]]
+
+ def expected_result(a):
+ [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
+ tmp_1001 = b, c
+ tmp_1002 = [d, e]
+ tmp_1003 = [tmp_1001, tmp_1002]
+ tmp_1004 = f, g
+ tmp_1005 = h, i, j
+ tmp_1006 = tmp_1003, tmp_1004
+ tmp_1007 = [tmp_1005, k]
+ tmp_1008 = [tmp_1006, tmp_1007]
+ return tmp_1008
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
def test_local_definition_and_binary_compare(self):
def test_function():
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
index 2d8f922a45..e7baa244b2 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
@@ -29,6 +29,11 @@ from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+# TODO(aqj): Do we need this? Do other builtins fail in similar ways?
+# See b/114389775 for a related bug in pyct
+# These symbols are legal in Python, but don't appear in the namespace.
+_special_symbols = {'range': range}
+
class LiveValueResolver(transformer.Base):
"""Annotates nodes with live values."""
@@ -66,6 +71,8 @@ class LiveValueResolver(transformer.Base):
# If the symbol value is for example a primitive, then it will not
# have a name.
pass
+ elif node.id in _special_symbols:
+ anno.setanno(node, 'live_val', _special_symbols[node.id])
else:
pass
# TODO(mdan): Should we raise an error here?
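
The special case above exists because builtins such as `range` are resolved
through the interpreter's builtin lookup and therefore never appear in the
namespace dict that LiveValueResolver inspects. A minimal plain-Python sketch
of that gap (illustration only, not part of the patch):

    # The namespaces handed to LiveValueResolver are plain dicts; builtins
    # are resolved by the interpreter, so they never appear in the dict.
    namespace = {'x': 1}
    assert 'range' not in namespace      # namespace lookup fails...
    assert list(range(3)) == [0, 1, 2]   # ...yet the symbol is legal Python
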
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 870ce2442b..4c7a538b38 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -52,7 +52,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
Args:
@@ -94,6 +95,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If learner_config is not valid.
@@ -134,7 +136,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -159,7 +162,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
Args:
@@ -201,6 +205,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
head = head_lib.regression_head(
label_name=label_name,
@@ -224,7 +229,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -251,7 +257,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
Args:
@@ -289,6 +296,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
super(GradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=model.model_builder,
@@ -303,7 +311,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -329,7 +338,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
center_bias=False,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained on pairwise data and can be
@@ -377,6 +387,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
+
Raises:
ValueError: If learner_config is not valid.
"""
@@ -395,7 +407,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -444,7 +457,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
- output_leaf_index=False):
+ output_leaf_index=False,
+ num_quantiles=100):
"""Initializes a core version of GradientBoostedDecisionTreeEstimator.
Args:
@@ -474,6 +488,7 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
for example_prediction_result in result_dict:
# access leaf index list by example_prediction_result["leaf_index"]
# which contains one leaf index per tree
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
def _model_fn(features, labels, mode, config):
@@ -493,7 +508,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
- 'override_global_step_value': None
+ 'override_global_step_value': None,
+ 'num_quantiles': num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
@@ -517,7 +533,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
label_keys=None,
logits_modifier_function=None,
center_bias=False,
- output_leaf_index=False):
+ output_leaf_index=False,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained on pairwise data and can be
@@ -552,6 +569,7 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
for result_dict in result_iter:
# access leaf index list by result_dict["leaf_index"]
# which contains one leaf index per tree
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If learner_config is not valid.
@@ -576,7 +594,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
- 'override_global_step_value': None
+ 'override_global_step_value': None,
+ 'num_quantiles': num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
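
Each estimator constructor above gains a `num_quantiles` argument (default
100) that is forwarded through `params` to the model builder. A hedged
construction sketch; the learner-config fields shown are just the minimum
needed for illustration:

    # Sketch: request a coarser quantile sketch than the default of 100.
    from tensorflow.contrib.boosted_trees.estimator_batch.estimator import (
        GradientBoostedDecisionTreeClassifier)
    from tensorflow.contrib.boosted_trees.proto import learner_pb2

    learner_config = learner_pb2.LearnerConfig()
    learner_config.num_classes = 2
    classifier = GradientBoostedDecisionTreeClassifier(
        learner_config=learner_config,
        examples_per_layer=1000,
        num_trees=10,
        num_quantiles=20)  # 20 quantiles instead of the default 100
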
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 04b46c3483..a6e422847d 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -81,6 +81,7 @@ def model_builder(features,
logits_modifier_function = params["logits_modifier_function"]
output_leaf_index = params["output_leaf_index"]
override_global_step_value = params.get("override_global_step_value", None)
+ num_quantiles = params["num_quantiles"]
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -116,7 +117,8 @@ def model_builder(features,
logits_dimension=head.logits_dimension,
features=training_features,
use_core_columns=use_core_libs,
- output_leaf_index=output_leaf_index)
+ output_leaf_index=output_leaf_index,
+ num_quantiles=num_quantiles)
with ops.name_scope("gbdt", "gbdt_optimizer"):
predictions_dict = gbdt_model.predict(mode)
logits = predictions_dict["predictions"]
@@ -237,6 +239,7 @@ def ranking_model_builder(features,
output_leaf_index = params["output_leaf_index"]
ranking_model_pair_keys = params["ranking_model_pair_keys"]
override_global_step_value = params.get("override_global_step_value", None)
+ num_quantiles = params["num_quantiles"]
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -299,7 +302,8 @@ def ranking_model_builder(features,
logits_dimension=head.logits_dimension,
features=main_features,
use_core_columns=use_core_libs,
- output_leaf_index=output_leaf_index)
+ output_leaf_index=output_leaf_index,
+ num_quantiles=num_quantiles)
with ops.name_scope("gbdt", "gbdt_optimizer"):
# Logits for inference.
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index b008c6e534..c7eb2493a8 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -304,7 +304,8 @@ class GradientBoostedDecisionTreeModel(object):
feature_columns=None,
use_core_columns=False,
output_leaf_index=False,
- output_leaf_index_modes=None):
+ output_leaf_index_modes=None,
+ num_quantiles=100):
"""Construct a new GradientBoostedDecisionTreeModel function.
Args:
@@ -327,6 +328,7 @@ class GradientBoostedDecisionTreeModel(object):
output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which
dictates when leaf indices will be outputted. By default, leaf indices
are only outputted in INFER mode.
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: if inputs are not valid.
@@ -399,6 +401,7 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config = learner_config
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
+ self._num_quantiles = num_quantiles
self._max_tree_depth = variables.Variable(
initial_value=self._learner_config.constraints.max_tree_depth)
self._attempted_trees = variables.Variable(
@@ -689,8 +692,8 @@ class GradientBoostedDecisionTreeModel(object):
loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction)
weak_learner_type = constant_op.constant(
self._learner_config.weak_learner_type)
- epsilon = 0.01
- num_quantiles = 100
+ num_quantiles = self._num_quantiles
+ epsilon = 1.0 / num_quantiles
strategy_tensor = constant_op.constant(strategy)
with ops.device(self._get_replica_device_setter(worker_device)):
# Create handlers for dense float columns
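
The handler setup above now derives the quantile sketch accuracy from the
configured number of quantiles instead of pinning epsilon at 0.01, so the
default num_quantiles=100 reproduces the old behavior exactly. A one-line
illustration of the coupling:

    # epsilon = 1.0 / num_quantiles: more quantiles -> tighter sketches.
    for num_quantiles in (20, 100, 200):
      print(num_quantiles, 1.0 / num_quantiles)  # 0.05, 0.01, 0.005
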
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 1ab150d74a..1056894f18 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -229,6 +229,10 @@ class TPUClusterResolver(ClusterResolver):
def get_master(self):
return self.master()
+ def get_job_name(self):
+ if self._shouldResolve():
+ return self._job_name
+
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 34f594f741..b9320e5fef 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -279,7 +279,9 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 9d8e955245..67242fecfe 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -428,10 +428,10 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
@parameterized.named_parameters(
- ("default", None, None),
- ("sequential_calls", 1, None),
- ("parallel_calls", 2, None),
- ("parallel_batches", None, 10),
+ ("Default", None, None),
+ ("SequentialCalls", 1, None),
+ ("ParallelCalls", 2, None),
+ ("ParallelBatches", None, 10),
)
def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
"""Test a dataset that maps a TF function across its input elements."""
@@ -505,8 +505,8 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
@parameterized.named_parameters(
- ("even", False),
- ("uneven", True),
+ ("Even", False),
+ ("Uneven", True),
)
def testMapAndBatchPartialBatch(self, drop_remainder):
iterator = (
@@ -663,7 +663,14 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
for _ in range(3):
sess.run(get_next)
- @parameterized.parameters(0, 5, 10, 90, 95, 99)
+ @parameterized.named_parameters(
+ ("1", 0),
+ ("2", 5),
+ ("3", 10),
+ ("4", 90),
+ ("5", 95),
+ ("6", 99),
+ )
def testMapAndBatchOutOfRangeError(self, threshold):
def raising_py_fn(i):
@@ -689,18 +696,18 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (False, dtypes.bool),
- (-42, dtypes.int8),
- (-42, dtypes.int16),
- (-42, dtypes.int32),
- (-42, dtypes.int64),
- (42, dtypes.uint8),
- (42, dtypes.uint16),
- (42.0, dtypes.float16),
- (42.0, dtypes.float32),
- (42.0, dtypes.float64),
- (b"hello", dtypes.string),
+ @parameterized.named_parameters(
+ ("1", False, dtypes.bool),
+ ("2", -42, dtypes.int8),
+ ("3", -42, dtypes.int16),
+ ("4", -42, dtypes.int32),
+ ("5", -42, dtypes.int64),
+ ("6", 42, dtypes.uint8),
+ ("7", 42, dtypes.uint16),
+ ("8", 42.0, dtypes.float16),
+ ("9", 42.0, dtypes.float32),
+ ("10", 42.0, dtypes.float64),
+ ("11", b"hello", dtypes.string),
)
def testMapAndBatchTypes(self, element, dtype):
def gen():
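
This file (and several below) migrates from parameterized.parameters to
parameterized.named_parameters, which takes a leading case name in each
tuple and produces stable, readable test IDs. A minimal self-contained
sketch of the pattern:

    # Each case tuple starts with a name; the generated test methods are
    # suffixed with it, e.g. testSquareZero and testSquareTwo.
    from absl.testing import parameterized

    class SquareTest(parameterized.TestCase):

      @parameterized.named_parameters(
          ('Zero', 0, 0),
          ('Two', 2, 4),
      )
      def testSquare(self, x, expected):
        self.assertEqual(x * x, expected)
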
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 091eb5ce37..61567bc8d7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -17,7 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import time
+
from tensorflow.contrib.data.python.ops import map_defun
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -25,10 +28,10 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-
class MapDefunTest(test.TestCase):
def testMapDefunSimple(self):
@@ -146,6 +149,105 @@ class MapDefunTest(test.TestCase):
r"indices = 10 is not in \[0, 5\)"):
self.evaluate(map_defun_op)
+ def testMapDefunWithUnspecifiedOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ res = x * 2 + 3
+ return (res, res + 1, res + 2)
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems],
+ [dtypes.int32, dtypes.int32, dtypes.int32],
+ [None, (None,), (2,)])
+ expected = elems * 2 + 3
+ self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected))
+ self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
+ self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
+
+ def testMapDefunWithDifferentOutputShapeEachRun(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ elems = array_ops.placeholder(dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3])
+ self.assertAllEqual(
+ sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
+
+ def testMapDefunWithWrongOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0]
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(r)
+
+ def testMapDefunWithInvalidInput(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2
+
+ c = constant_op.constant(2)
+ with self.assertRaises(ValueError):
+ # Fails at graph construction time for inputs with known shapes.
+ r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0]
+ p = array_ops.placeholder(dtypes.int32)
+ r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(r, feed_dict={p: 0})
+
+
+class MapDefunBenchmark(test.Benchmark):
+
+ def _run(self, op, name=None, num_iters=3000):
+ with session.Session() as sess:
+ # Warm up the session
+ for _ in range(5):
+ sess.run(op)
+ start = time.time()
+ for _ in range(num_iters):
+ sess.run(op)
+ end = time.time()
+ mean_us = (end - start) * 1e6 / num_iters
+ self.report_benchmark(
+ name=name,
+ iters=num_iters,
+ wall_time=mean_us,
+ extras={"examples_per_sec": num_iters / (end - start)})
+
+ def benchmarkDefunVsMapFn(self):
+ """Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
+
+ @function.Defun(dtypes.int32)
+ def defun(x):
+ return array_ops.identity(x)
+
+ def map_fn(x):
+ return array_ops.identity(x)
+
+ base = math_ops.range(100)
+ for input_size in [10, 100, 1000, 10000]:
+ num_iters = 100000 // input_size
+ map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
+ map_fn_op = functional_ops.map_fn(map_fn, base)
+
+ self._run(
+ map_defun_op,
+ "benchmarkMapDefun_size_%d" % input_size,
+ num_iters=num_iters)
+ self._run(
+ map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index 586b4bee5f..6a7ef877f9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -44,22 +44,22 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
for i, fun1 in enumerate(functions):
for j, fun2 in enumerate(functions):
tests.append((
- "test_{}_{}".format(i, j),
+ "Test{}{}".format(i, j),
[fun1, fun2],
))
for k, fun3 in enumerate(functions):
tests.append((
- "test_{}_{}_{}".format(i, j, k),
+ "Test{}{}{}".format(i, j, k),
[fun1, fun2, fun3],
))
swap = lambda x, n: (n, x)
tests.append((
- "swap1",
+ "Swap1",
[lambda x: (x, 42), swap],
))
tests.append((
- "swap2",
+ "Swap2",
[lambda x: (x, 42), swap, swap],
))
return tuple(tests)
@@ -109,13 +109,13 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
for x, fun in enumerate(functions):
for y, predicate in enumerate(filters):
- tests.append(("mixed_{}_{}".format(x, y), fun, predicate))
+ tests.append(("Mixed{}{}".format(x, y), fun, predicate))
# Multi output
- tests.append(("multiOne", lambda x: (x, x),
+ tests.append(("Multi1", lambda x: (x, x),
lambda x, y: constant_op.constant(True)))
tests.append(
- ("multiTwo", lambda x: (x, 2),
+ ("Multi2", lambda x: (x, 2),
lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
return tuple(tests)
@@ -172,17 +172,17 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
identity = lambda x: x
for x, predicate_1 in enumerate(filters):
for y, predicate_2 in enumerate(filters):
- tests.append(("mixed_{}_{}".format(x, y), identity,
+ tests.append(("Mixed{}{}".format(x, y), identity,
[predicate_1, predicate_2]))
for z, predicate_3 in enumerate(filters):
- tests.append(("mixed_{}_{}_{}".format(x, y, z), identity,
+ tests.append(("Mixed{}{}{}".format(x, y, z), identity,
[predicate_1, predicate_2, predicate_3]))
take_all_multiple = lambda x, y: constant_op.constant(True)
# Multi output
- tests.append(("multiOne", lambda x: (x, x),
+ tests.append(("Multi1", lambda x: (x, x),
[take_all_multiple, take_all_multiple]))
- tests.append(("multiTwo", lambda x: (x, 2), [
+ tests.append(("Multi2", lambda x: (x, 2), [
take_all_multiple,
lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
]))
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
index 4881f63ab9..aa89674c6e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
@@ -210,6 +210,7 @@ py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
index ac3892fe81..243f6405a1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
@@ -27,42 +28,38 @@ from tensorflow.python.platform import test
class InterleaveDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
+ dataset_serialization_test_base.DatasetSerializationTestBase,
+ parameterized.TestCase):
- def _build_iterator_graph(self, input_values, cycle_length, block_length):
+ def _build_iterator_graph(self, input_values, cycle_length, block_length,
+ num_parallel_calls):
repeat_count = 2
return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
repeat_count).interleave(
lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
- cycle_length, block_length)
+ cycle_length, block_length, num_parallel_calls)
- def testSerializationCore(self):
+ @parameterized.named_parameters(
+ ("1", 2, 3, None),
+ ("2", 2, 3, 1),
+ ("3", 2, 3, 2),
+ ("4", 1, 3, None),
+ ("5", 1, 3, 1),
+ ("6", 2, 1, None),
+ ("7", 2, 1, 1),
+ ("8", 2, 1, 2),
+ )
+ def testSerializationCore(self, cycle_length, block_length,
+ num_parallel_calls):
input_values = np.array([4, 5, 6], dtype=np.int64)
num_outputs = np.sum(input_values) * 2
- # cycle_length > 1, block_length > 1
- cycle_length = 2
- block_length = 3
# pylint: disable=g-long-lambda
self.run_core_tests(
lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
+ input_values, cycle_length, block_length, num_parallel_calls),
lambda: self._build_iterator_graph(
- input_values, cycle_length * 2, block_length * 1),
+ input_values, cycle_length * 2, block_length, num_parallel_calls),
num_outputs)
- # cycle_length = 1
- cycle_length = 1
- block_length = 3
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
- # block_length = 1
- cycle_length = 2
- block_length = 1
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
# pylint: enable=g-long-lambda
def testSparseCore(self):
@@ -82,5 +79,5 @@ class InterleaveDatasetSerializationTest(
self.run_core_tests(_build_dataset, None, 20)
-if __name__ == '__main__':
+if __name__ == "__main__":
test.main()
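
The serialization test now also parameterizes over num_parallel_calls,
covering the parallel variant of interleave. A small sketch of the call the
test builds (contrib-era signature, with the argument passed positionally
above):

    # Sketch: interleave with two parallel calls fetching from the cycle.
    dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6]).repeat(
        2).interleave(
            lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
            cycle_length=2, block_length=3, num_parallel_calls=2)
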
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 8b2f846494..6b3e8e9f6e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -32,18 +32,18 @@ from tensorflow.python.platform import test
class SlideDatasetTest(test.TestCase, parameterized.TestCase):
- @parameterized.parameters(
- (20, 14, 7, 1),
- (20, 17, 9, 1),
- (20, 14, 14, 1),
- (20, 10, 14, 1),
- (20, 14, 19, 1),
- (20, 4, 1, 2),
- (20, 2, 1, 6),
- (20, 4, 7, 2),
- (20, 2, 7, 6),
- (1, 10, 4, 1),
- (0, 10, 4, 1),
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
)
def testSlideDataset(self, count, window_size, window_shift, window_stride):
"""Tests a dataset that slides a window its input elements."""
@@ -96,18 +96,18 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (20, 14, 7, 1),
- (20, 17, 9, 1),
- (20, 14, 14, 1),
- (20, 10, 14, 1),
- (20, 14, 19, 1),
- (20, 4, 1, 2),
- (20, 2, 1, 6),
- (20, 4, 7, 2),
- (20, 2, 7, 6),
- (1, 10, 4, 1),
- (0, 10, 4, 1),
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
)
def testSlideDatasetDeprecated(self, count, window_size, stride,
window_stride):
@@ -160,10 +160,10 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (14, 0, 3, 1),
- (14, 3, 0, 1),
- (14, 3, 3, 0),
+ @parameterized.named_parameters(
+ ("1", 14, 0, 3, 1),
+ ("2", 14, 3, 0, 1),
+ ("3", 14, 3, 3, 0),
)
def testSlideDatasetInvalid(self, count, window_size, window_shift,
window_stride):
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 0486e2bce2..4b08ec759d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -33,8 +33,17 @@ from tensorflow.python.platform import test
class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
- @parameterized.parameters((1, None), (2, None), (4, None), (8, None),
- (16, None), (4, -1), (4, 0), (4, 1), (4, 4))
+ @parameterized.named_parameters(
+ ("1", 1, None),
+ ("2", 2, None),
+ ("3", 4, None),
+ ("4", 8, None),
+ ("5", 16, None),
+ ("6", 4, -1),
+ ("7", 4, 0),
+ ("8", 4, 1),
+ ("9", 4, 4),
+ )
def testNumThreads(self, num_threads, max_intra_op_parallelism):
def get_thread_id(_):
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 33d95d6754..ff4d9b3260 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -64,15 +64,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
else:
self.assertEqual(xs, ys)
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetFlatMap(self, structure, shape, dtype):
"""Tests windowing by chaining it with flat map.
@@ -97,15 +97,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetBatchDense(self, structure, shape, dtype):
"""Tests batching of dense tensor windows.
@@ -135,10 +135,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([]),),
- (np.int32([1]),),
- (np.int32([1, 2, 3]),),
+ @parameterized.named_parameters(
+ ("1", np.int32([])),
+ ("2", np.int32([1])),
+ ("3", np.int32([1, 2, 3])),
)
def testWindowDatasetBatchDenseDynamicShape(self, shape):
"""Tests batching of dynamically shaped dense tensor windows.
@@ -203,15 +203,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
])
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetBatchSparse(self, structure, shape, dtype):
"""Tests batching of sparse tensor windows.
@@ -243,10 +243,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([]),),
- (np.int32([1]),),
- (np.int32([1, 2, 3]),),
+ @parameterized.named_parameters(
+ ("1", np.int32([])),
+ ("2", np.int32([1])),
+ ("3", np.int32([1, 2, 3])),
)
def testWindowDatasetBatchSparseDynamicShape(self, shape):
"""Tests batching of dynamically shaped sparse tensor windows.
@@ -284,17 +284,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
]))
- @parameterized.parameters(
- (None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
- (None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- (None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
- ((None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ((None, (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
+ ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
+ ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
+ ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+ ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
+ ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("8", (None,
+ (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
)
def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype,
padded_shape):
@@ -329,10 +330,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([[1], [2], [3]]), [-1]),
- (np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- (np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+ @parameterized.named_parameters(
+ ("1", np.int32([[1], [2], [3]]), [-1]),
+ ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+ ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
)
def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape):
"""Tests padded batching of dynamically shaped dense tensor windows.
@@ -361,9 +362,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([[1]]), np.int32([0])),
- (np.int32([[10], [20]]), np.int32([15])),
+ @parameterized.named_parameters(
+ ("1", np.int32([[1]]), np.int32([0])),
+ ("2", np.int32([[10], [20]]), np.int32([15])),
)
def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
"""Tests invalid padded batching of dense tensor windows.
@@ -420,17 +421,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
])
- @parameterized.parameters(
- (None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
- (None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- (None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
- ((None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ((None, (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
+ @parameterized.named_parameters(
+ ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
+ ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
+ ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
+ ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+ ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
+ ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("8", (None,
+ (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
)
def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
padded_shape):
@@ -463,10 +465,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int64([[1], [2], [3]]), [-1]),
- (np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- (np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+ @parameterized.named_parameters(
+ ("1", np.int64([[1], [2], [3]]), [-1]),
+ ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+ ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
)
def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
padded_shape):
@@ -495,9 +497,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int64([[1]]), [0]),
- (np.int64([[10], [20]]), [15]),
+ @parameterized.named_parameters(
+ ("1", np.int64([[1]]), [0]),
+ ("2", np.int64([[10], [20]]), [15]),
)
def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape):
"""Tests invalid padded batching of sparse tensor windows.
diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py
index 54d5cd6da0..3d0d0993c9 100644
--- a/tensorflow/contrib/data/python/ops/map_defun.py
+++ b/tensorflow/contrib/data/python/ops/map_defun.py
@@ -53,6 +53,4 @@ def map_defun(fn, elems, output_dtypes, output_shapes):
elems = [ops.convert_to_tensor(e) for e in elems]
output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
- if not all(s.is_fully_defined() for s in output_shapes):
- raise ValueError("All fn output shapes must be fully defined.")
return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index d39fd57294..3cee3e37a7 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -446,8 +446,7 @@ class TestWithDistributionStrategy(test.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- with self.assertRaisesRegexp(ValueError,
- 'expected input to have 2 dimensions'):
+ with self.assertRaisesRegexp(ValueError, 'expected input to have shape'):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
# Wrong input shape
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 4fb70ec685..6ba83976fc 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -310,7 +310,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def get_host_cpu_device(self, host_id):
if self._tpu_cluster_resolver.get_master() in ('', 'local'):
return '/replica:0/task:0/device:CPU:0'
- return '/job:tpu_worker/task:%d/device:CPU:0' % (host_id,)
+ job_name = self._tpu_cluster_resolver.get_job_name() or 'tpu_worker'
+ return '/job:%s/task:%d/device:CPU:0' % (job_name, host_id)
def configure(self,
session_config=None,
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 77f62df99d..437b3d965d 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -446,6 +446,7 @@ py_library(
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:optimizers",
+ "//tensorflow/python/ops/losses",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py
index 7c49cd00d1..98660bb731 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import training_util
@@ -405,6 +406,7 @@ class RNNClassifier(estimator.Estimator):
weight_column=None,
label_vocabulary=None,
optimizer='Adagrad',
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
input_layer_partitioner=None,
config=None):
"""Initializes a `RNNClassifier` instance.
@@ -454,6 +456,8 @@ class RNNClassifier(estimator.Estimator):
string.
optimizer: An instance of `tf.Optimizer` or string specifying optimizer
type. Defaults to Adagrad optimizer.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+        to reduce the training loss over the batch. Defaults to `SUM_OVER_BATCH_SIZE`.
input_layer_partitioner: Optional. Partitioner for input layer. Defaults
to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: `RunConfig` object to configure the runtime settings.
@@ -467,11 +471,15 @@ class RNNClassifier(estimator.Estimator):
if n_classes == 2:
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
else:
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
- n_classes, weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ n_classes,
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
+
def _model_fn(features, labels, mode, config):
return _rnn_model_fn(
features=features,
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
index 959b40371a..1aebed348d 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
@@ -713,7 +713,7 @@ class RNNClassifierTrainingTest(test.TestCase):
# Uses same checkpoint and examples as testBinaryClassEvaluationMetrics.
# See that test for loss calculation.
- mock_optimizer = self._mock_optimizer(expected_loss=1.119661)
+ mock_optimizer = self._mock_optimizer(expected_loss=0.559831)
sequence_feature_columns = [
seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -748,7 +748,7 @@ class RNNClassifierTrainingTest(test.TestCase):
# Uses same checkpoint and examples as testMultiClassEvaluationMetrics.
# See that test for loss calculation.
- mock_optimizer = self._mock_optimizer(expected_loss=2.662932)
+ mock_optimizer = self._mock_optimizer(expected_loss=1.331465)
sequence_feature_columns = [
seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -812,20 +812,32 @@ class RNNClassifierEvaluationTest(test.TestCase):
# probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]]
# loss = -label * ln(p) - (1 - label) * ln(1 - p)
# = [[0.436326], [0.683335]]
+ # sum_over_batch_size = (0.436326 + 0.683335)/2
expected_metrics = {
- ops.GraphKeys.GLOBAL_STEP: global_step,
- metric_keys.MetricKeys.LOSS: 1.119661,
- metric_keys.MetricKeys.LOSS_MEAN: 0.559831,
- metric_keys.MetricKeys.ACCURACY: 1.0,
- metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262,
- metric_keys.MetricKeys.LABEL_MEAN: 0.5,
- metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
+ ops.GraphKeys.GLOBAL_STEP:
+ global_step,
+ metric_keys.MetricKeys.LOSS:
+ 0.559831,
+ metric_keys.MetricKeys.LOSS_MEAN:
+ 0.559831,
+ metric_keys.MetricKeys.ACCURACY:
+ 1.0,
+ metric_keys.MetricKeys.PREDICTION_MEAN:
+ 0.429262,
+ metric_keys.MetricKeys.LABEL_MEAN:
+ 0.5,
+ metric_keys.MetricKeys.ACCURACY_BASELINE:
+ 0.5,
# With default threshold of 0.5, the model is a perfect classifier.
- metric_keys.MetricKeys.RECALL: 1.0,
- metric_keys.MetricKeys.PRECISION: 1.0,
+ metric_keys.MetricKeys.RECALL:
+ 1.0,
+ metric_keys.MetricKeys.PRECISION:
+ 1.0,
# Positive example is scored above negative, so AUC = 1.0.
- metric_keys.MetricKeys.AUC: 1.0,
- metric_keys.MetricKeys.AUC_PR: 1.0,
+ metric_keys.MetricKeys.AUC:
+ 1.0,
+ metric_keys.MetricKeys.AUC_PR:
+ 1.0,
}
self.assertAllClose(
sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))
@@ -871,9 +883,10 @@ class RNNClassifierEvaluationTest(test.TestCase):
# [0.059494, 0.572639, 0.367866]]
# loss = -1. * log(softmax[label])
# = [[2.105432], [0.557500]]
+ # sum_over_batch_size = (2.105432 + 0.557500)/2
expected_metrics = {
ops.GraphKeys.GLOBAL_STEP: global_step,
- metric_keys.MetricKeys.LOSS: 2.662932,
+ metric_keys.MetricKeys.LOSS: 1.331465,
metric_keys.MetricKeys.LOSS_MEAN: 1.331466,
metric_keys.MetricKeys.ACCURACY: 0.5,
}
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 0ccb4583ab..716bb87e38 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -174,7 +174,7 @@ class FusedConv2DBiasActivationOp : public OpKernel {
// Input bias is a 1-D tensor, with size matching output depth.
const Tensor& bias = context->input(kBias);
- OP_REQUIRES_OK(context, CheckShape(bias, "conv_input"));
+ OP_REQUIRES_OK(context, CheckShape(bias, "bias"));
const Tensor& conv_input_scale_tensor = context->input(kConvInputScale);
const Tensor& side_input_scale_tensor = context->input(kSideInputScale);
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 418b0cf392..61185f65a9 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -403,6 +403,7 @@ py_test(
srcs = ["python/learn/estimators/dnn_test.py"],
shard_count = 4,
srcs_version = "PY2AND3",
+ tags = ["notap"],
deps = [
":learn",
"//tensorflow/contrib/layers:layers_py",
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 0091587bf7..f320b53d94 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -36,10 +36,10 @@ cc_library(
srcs = ["arena_planner.cc"],
hdrs = ["arena_planner.h"],
deps = [
- ":context",
":graph_info",
":memory_planner",
":simple_memory_arena",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -54,6 +54,7 @@ cc_test(
deps = [
":arena_planner",
"//tensorflow/contrib/lite/testing:util",
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_googletest//:gtest",
],
@@ -63,27 +64,27 @@ cc_test(
# TODO(aselle): Resolve problems preventing C99 usage.
cc_library(
name = "context",
- srcs = ["context.c"],
hdrs = ["context.h"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
name = "graph_info",
hdrs = ["graph_info.h"],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
name = "memory_planner",
hdrs = ["memory_planner.h"],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
name = "simple_memory_arena",
srcs = ["simple_memory_arena.cc"],
hdrs = ["simple_memory_arena.h"],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
@@ -91,7 +92,7 @@ cc_library(
hdrs = [
"builtin_op_data.h",
],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
@@ -121,12 +122,12 @@ cc_library(
name = "framework",
srcs = [
"allocation.cc",
- "error_reporter.cc",
"graph_info.cc",
"interpreter.cc",
"model.cc",
- "op_resolver.cc",
+ "mutable_op_resolver.cc",
"optional_debug_tools.cc",
+ "stderr_reporter.cc",
] + select({
"//tensorflow:android": [
"nnapi_delegate.cc",
@@ -149,9 +150,11 @@ cc_library(
"graph_info.h",
"interpreter.h",
"model.h",
+ "mutable_op_resolver.h",
"nnapi_delegate.h",
"op_resolver.h",
"optional_debug_tools.h",
+ "stderr_reporter.h",
],
copts = tflite_copts(),
linkopts = [
@@ -164,14 +167,14 @@ cc_library(
}),
deps = [
":arena_planner",
- ":builtin_op_data",
- ":context",
":graph_info",
":memory_planner",
":schema_fbs_version",
":simple_memory_arena",
":string",
":util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/kernels:eigen_support",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
@@ -210,6 +213,8 @@ cc_test(
deps = [
":framework",
":string_util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/kernels:kernel_util",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
@@ -259,6 +264,8 @@ cc_test(
],
deps = [
":framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
@@ -266,9 +273,9 @@ cc_test(
# Test OpResolver.
cc_test(
- name = "op_resolver_test",
+ name = "mutable_op_resolver_test",
size = "small",
- srcs = ["op_resolver_test.cc"],
+ srcs = ["mutable_op_resolver_test.cc"],
tags = ["no_oss"],
deps = [
":framework",
@@ -277,24 +284,12 @@ cc_test(
],
)
-# Test the C extension API code.
-cc_test(
- name = "context_test",
- size = "small",
- srcs = ["context_test.cc"],
- deps = [
- ":framework",
- "//tensorflow/contrib/lite/testing:util",
- "@com_google_googletest//:gtest",
- ],
-)
-
cc_library(
name = "util",
srcs = ["util.cc"],
hdrs = ["util.h"],
deps = [
- ":context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -304,7 +299,6 @@ cc_test(
srcs = ["util_test.cc"],
tags = ["no_oss"],
deps = [
- ":context",
":util",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc
index 8946261814..21cb1832a7 100644
--- a/tensorflow/contrib/lite/allocation.cc
+++ b/tensorflow/contrib/lite/allocation.cc
@@ -23,8 +23,8 @@ limitations under the License.
#include <cstring>
#include <utility>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h
index 121f3d2646..182bc0977f 100644
--- a/tensorflow/contrib/lite/allocation.h
+++ b/tensorflow/contrib/lite/allocation.h
@@ -20,8 +20,8 @@ limitations under the License.
#include <cstdio>
#include <cstdlib>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/simple_memory_arena.h"
#include "tensorflow/contrib/lite/string.h"
diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h
index 55003cf4e9..382577045b 100644
--- a/tensorflow/contrib/lite/arena_planner.h
+++ b/tensorflow/contrib/lite/arena_planner.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/graph_info.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/simple_memory_arena.h"
@@ -37,8 +37,8 @@ struct AllocationInfo;
// each tensor needs to be allocated and deallocated, and preallocates all the
// necessary memory (the PlanAllocations phase). It then assigns portions of
// this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may
-// share some of the buffer if a tensor B is to be allocated after another tensor
-// A has been deallocated.
+// share some of the buffer if a tensor B is to be allocated after another
+// tensor A has been deallocated.
//
// If dynamic tensors are used the planning steps can be repeated during model
// execution. Since dynamic tensors don't have sizes until after the
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 0246e7fa30..9317e2bb6e 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -49,6 +49,9 @@ def tflite_linkopts_unstripped():
Returns:
a select object with proper linkopts
"""
+
+  # In case you wonder why there's no --icf: the gains were negligible,
+  # and it created potential compatibility problems.
return select({
"//tensorflow:android": [
"-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj.
@@ -56,13 +59,7 @@ def tflite_linkopts_unstripped():
"-Wl,--gc-sections", # Eliminate unused code and data.
"-Wl,--as-needed", # Don't link unused libs.
],
- "//tensorflow:darwin": [],
- "//tensorflow:ios": [],
- "//tensorflow/contrib/lite:mips": [],
- "//tensorflow/contrib/lite:mips64": [],
- "//conditions:default": [
- "-Wl,--icf=all", # Identical code folding.
- ],
+ "//conditions:default": [],
})
def tflite_jni_linkopts_unstripped():
@@ -74,17 +71,15 @@ def tflite_jni_linkopts_unstripped():
Returns:
a select object with proper linkopts
"""
+
+  # In case you wonder why there's no --icf: the gains were negligible,
+  # and it created potential compatibility problems.
return select({
"//tensorflow:android": [
"-Wl,--gc-sections", # Eliminate unused code and data.
"-Wl,--as-needed", # Don't link unused libs.
],
- "//tensorflow:darwin": [],
- "//tensorflow/contrib/lite:mips": [],
- "//tensorflow/contrib/lite:mips64": [],
- "//conditions:default": [
- "-Wl,--icf=all", # Identical code folding.
- ],
+ "//conditions:default": [],
})
def tflite_linkopts():
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index aecd71910c..30901bd0fa 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -12,297 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// Compatibility shim for new location of interface definitions.
+
#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
-#include <stdint.h>
-
-#include "tensorflow/contrib/lite/context.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif // __cplusplus
-
-// TODO(aselle): Consider using "if this then that" for testing.
-
-// Useful placeholder to put in otherwise empty structs to avoid size warnings.
-typedef struct {
- char dummy_;
-} EmptyStructPlaceholder;
-
-// Possible padding types (for convolutions)
-typedef enum {
- kTfLitePaddingUnknown = 0,
- kTfLitePaddingSame,
- kTfLitePaddingValid,
-} TfLitePadding;
-
-typedef struct {
- int width;
- int height;
-} TfLitePaddingValues;
-
-// Possible fused activation functions.
-// TODO(aselle): rename to TfLiteActivation
-typedef enum {
- kTfLiteActNone = 0,
- kTfLiteActRelu,
- kTfLiteActRelu1,
- kTfLiteActRelu6,
- kTfLiteActTanh,
- kTfLiteActSignBit,
- kTfLiteActSigmoid,
-} TfLiteFusedActivation;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
- int dilation_width_factor;
- int dilation_height_factor;
- TfLiteFusedActivation activation;
-} TfLiteConvParams;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
- int filter_width;
- int filter_height;
- TfLiteFusedActivation activation;
- struct {
- TfLitePaddingValues padding;
- } computed;
-} TfLitePoolParams;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
- int depth_multiplier;
- TfLiteFusedActivation activation;
-} TfLiteDepthwiseConvParams;
-
-typedef struct {
- int rank;
- TfLiteFusedActivation activation;
-} TfLiteSVDFParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteRNNParams;
-
-typedef struct {
- bool time_major;
- TfLiteFusedActivation activation;
-} TfLiteSequenceRNNParams;
-
-typedef enum {
- kTfLiteFullyConnectedWeightsFormatDefault = 0,
- kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
-} TfLiteFullyConnectedWeightsFormat;
-
-typedef struct {
- // Parameters for FullyConnected version 1 or above.
- TfLiteFusedActivation activation;
-
- // Parameters for FullyConnected version 2 or above.
- TfLiteFullyConnectedWeightsFormat weights_format;
-} TfLiteFullyConnectedParams;
-
-typedef enum {
- kTfLiteLshProjectionUnknown = 0,
- kTfLiteLshProjectionSparse = 1,
- kTfLiteLshProjectionDense = 2,
-} TfLiteLSHProjectionType;
-
-typedef struct {
- TfLiteLSHProjectionType type;
-} TfLiteLSHProjectionParams;
-
-typedef struct {
- float beta;
-} TfLiteSoftmaxParams;
-
-typedef struct {
- int axis;
- TfLiteFusedActivation activation;
-} TfLiteConcatenationParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteAddParams;
-
-typedef struct {
- EmptyStructPlaceholder placeholder_;
-} TfLiteSpaceToBatchNDParams;
-
-typedef struct {
- EmptyStructPlaceholder placeholder_;
-} TfLiteBatchToSpaceNDParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteMulParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteSubParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteDivParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteL2NormParams;
-
-typedef struct {
- int radius;
- float bias;
- float alpha;
- float beta;
-} TfLiteLocalResponseNormParams;
-
-typedef enum {
- kTfLiteLSTMFullKernel = 0,
- kTfLiteLSTMBasicKernel
-} TfLiteLSTMKernelType;
-
-typedef struct {
- // Parameters for LSTM version 1.
- TfLiteFusedActivation activation;
- float cell_clip;
- float proj_clip;
-
- // Parameters for LSTM version 2.
- // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
- TfLiteLSTMKernelType kernel_type;
-} TfLiteLSTMParams;
-
-typedef struct {
- bool align_corners;
-} TfLiteResizeBilinearParams;
-
-typedef struct {
- EmptyStructPlaceholder placeholder_;
-} TfLitePadParams;
-
-typedef struct {
- EmptyStructPlaceholder placeholder_;
-} TfLitePadV2Params;
-
-typedef struct {
- // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
- // For now we will fix the maximum possible number of dimensions.
- int shape[8];
- int num_dimensions;
-} TfLiteReshapeParams;
-
-typedef struct {
- int ngram_size;
- int max_skip_size;
- bool include_all_ngrams;
-} TfLiteSkipGramParams;
-
-typedef struct {
- int block_size;
-} TfLiteSpaceToDepthParams;
-
-typedef struct {
- TfLiteType in_data_type;
- TfLiteType out_data_type;
-} TfLiteCastParams;
-
-typedef enum {
- kTfLiteCombinerTypeSum = 0,
- kTfLiteCombinerTypeMean = 1,
- kTfLiteCombinerTypeSqrtn = 2,
-} TfLiteCombinerType;
-
-typedef struct {
- TfLiteCombinerType combiner;
-} TfLiteEmbeddingLookupSparseParams;
-
-typedef struct {
- int axis;
-} TfLiteGatherParams;
-
-typedef struct {
- EmptyStructPlaceholder placeholder_;
-} TfLiteTransposeParams;
-
-typedef struct {
- bool keep_dims;
-} TfLiteReducerParams;
-
-typedef struct {
- int num_splits;
-} TfLiteSplitParams;
-
-typedef struct {
- // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
- // For now we will fix the maximum possible number of dimensions.
- int squeeze_dims[8];
- int num_squeeze_dims;
-} TfLiteSqueezeParams;
-
-typedef struct {
- int begin_mask;
- int end_mask;
- int ellipsis_mask;
- int new_axis_mask;
- int shrink_axis_mask;
-} TfLiteStridedSliceParams;
-
-typedef struct {
- TfLiteType output_type;
-} TfLiteArgMaxParams;
-
-typedef struct {
- TfLiteType output_type;
-} TfLiteArgMinParams;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
-} TfLiteTransposeConvParams;
-
-typedef struct {
- bool validate_indices;
-} TfLiteSparseToDenseParams;
-
-typedef struct {
- TfLiteType out_type;
-} TfLiteShapeParams;
-
-typedef struct {
- // Parameters supported by version 1:
- float min;
- float max;
- int num_bits;
-
- // Parameters supported by version 2:
- bool narrow_range;
-} TfLiteFakeQuantParams;
-
-typedef struct {
- int values_count;
- int axis;
-} TfLitePackParams;
-
-typedef struct {
- int axis;
-} TfLiteOneHotParams;
-
-typedef struct {
- int num;
- int axis;
-} TfLiteUnpackParams;
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
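
The file above is now only a compatibility shim: the struct and enum definitions
moved to tensorflow/contrib/lite/c/builtin_op_data.h, but code that includes the
old path keeps compiling. A minimal sketch of what the shim preserves
(hypothetical kernel-side code, not part of this patch):

    // The old include path still resolves to the relocated definitions.
    #include "tensorflow/contrib/lite/builtin_op_data.h"

    static void ConfigureConvSketch(TfLiteConvParams* params) {
      // Same struct layout as before the move; only the defining header moved.
      params->padding = kTfLitePaddingSame;
      params->stride_width = 2;
      params->stride_height = 2;
      params->activation = kTfLiteActRelu;
    }
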
diff --git a/tensorflow/contrib/lite/c/BUILD b/tensorflow/contrib/lite/c/BUILD
new file mode 100644
index 0000000000..663eb63cad
--- /dev/null
+++ b/tensorflow/contrib/lite/c/BUILD
@@ -0,0 +1,39 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "c_api_internal",
+ srcs = ["c_api_internal.c"],
+ hdrs = [
+ "builtin_op_data.h",
+ "c_api_internal.h",
+ ],
+ visibility = [
+ "//tensorflow/contrib/lite:__subpackages__",
+ ],
+)
+
+# Test the C extension API code.
+cc_test(
+ name = "c_api_internal_test",
+ size = "small",
+ srcs = ["c_api_internal_test.cc"],
+ deps = [
+ ":c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "builtin_op_data_test",
+ size = "small",
+ srcs = ["builtin_op_data_test.cc"],
+ copts = ["-Wno-unused-variable"],
+ deps = [
+ ":c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
new file mode 100644
index 0000000000..fa43e6a024
--- /dev/null
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -0,0 +1,298 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
+#define TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
+
+#include <stdint.h>
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// TODO(aselle): Consider using "if this then that" for testing.
+
+// Possible padding types (for convolutions)
+typedef enum {
+ kTfLitePaddingUnknown = 0,
+ kTfLitePaddingSame,
+ kTfLitePaddingValid,
+} TfLitePadding;
+
+typedef struct {
+ int width;
+ int height;
+} TfLitePaddingValues;
+
+// Possible fused activation functions.
+// TODO(aselle): rename to TfLiteActivation
+typedef enum {
+ kTfLiteActNone = 0,
+ kTfLiteActRelu,
+ kTfLiteActRelu1,
+ kTfLiteActRelu6,
+ kTfLiteActTanh,
+ kTfLiteActSignBit,
+ kTfLiteActSigmoid,
+} TfLiteFusedActivation;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int dilation_width_factor;
+ int dilation_height_factor;
+ TfLiteFusedActivation activation;
+} TfLiteConvParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int filter_width;
+ int filter_height;
+ TfLiteFusedActivation activation;
+ struct {
+ TfLitePaddingValues padding;
+ } computed;
+} TfLitePoolParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int depth_multiplier;
+ TfLiteFusedActivation activation;
+} TfLiteDepthwiseConvParams;
+
+typedef struct {
+ int rank;
+ TfLiteFusedActivation activation;
+} TfLiteSVDFParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteRNNParams;
+
+typedef struct {
+ bool time_major;
+ TfLiteFusedActivation activation;
+} TfLiteSequenceRNNParams;
+
+typedef enum {
+ kTfLiteFullyConnectedWeightsFormatDefault = 0,
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
+} TfLiteFullyConnectedWeightsFormat;
+
+typedef struct {
+ // Parameters for FullyConnected version 1 or above.
+ TfLiteFusedActivation activation;
+
+ // Parameters for FullyConnected version 2 or above.
+ TfLiteFullyConnectedWeightsFormat weights_format;
+} TfLiteFullyConnectedParams;
+
+typedef enum {
+ kTfLiteLshProjectionUnknown = 0,
+ kTfLiteLshProjectionSparse = 1,
+ kTfLiteLshProjectionDense = 2,
+} TfLiteLSHProjectionType;
+
+typedef struct {
+ TfLiteLSHProjectionType type;
+} TfLiteLSHProjectionParams;
+
+typedef struct {
+ float beta;
+} TfLiteSoftmaxParams;
+
+typedef struct {
+ int axis;
+ TfLiteFusedActivation activation;
+} TfLiteConcatenationParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteAddParams;
+
+typedef struct {
+} TfLiteSpaceToBatchNDParams;
+
+typedef struct {
+} TfLiteBatchToSpaceNDParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteMulParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteSubParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteDivParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteL2NormParams;
+
+typedef struct {
+ int radius;
+ float bias;
+ float alpha;
+ float beta;
+} TfLiteLocalResponseNormParams;
+
+typedef enum {
+ kTfLiteLSTMFullKernel = 0,
+ kTfLiteLSTMBasicKernel
+} TfLiteLSTMKernelType;
+
+typedef struct {
+ // Parameters for LSTM version 1.
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
+
+ // Parameters for LSTM version 2.
+ // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
+ TfLiteLSTMKernelType kernel_type;
+} TfLiteLSTMParams;
+
+typedef struct {
+ bool align_corners;
+} TfLiteResizeBilinearParams;
+
+typedef struct {
+} TfLitePadParams;
+
+typedef struct {
+} TfLitePadV2Params;
+
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int shape[8];
+ int num_dimensions;
+} TfLiteReshapeParams;
+
+typedef struct {
+ int ngram_size;
+ int max_skip_size;
+ bool include_all_ngrams;
+} TfLiteSkipGramParams;
+
+typedef struct {
+ int block_size;
+} TfLiteSpaceToDepthParams;
+
+typedef struct {
+ TfLiteType in_data_type;
+ TfLiteType out_data_type;
+} TfLiteCastParams;
+
+typedef enum {
+ kTfLiteCombinerTypeSum = 0,
+ kTfLiteCombinerTypeMean = 1,
+ kTfLiteCombinerTypeSqrtn = 2,
+} TfLiteCombinerType;
+
+typedef struct {
+ TfLiteCombinerType combiner;
+} TfLiteEmbeddingLookupSparseParams;
+
+typedef struct {
+ int axis;
+} TfLiteGatherParams;
+
+typedef struct {
+} TfLiteTransposeParams;
+
+typedef struct {
+ bool keep_dims;
+} TfLiteReducerParams;
+
+typedef struct {
+ int num_splits;
+} TfLiteSplitParams;
+
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int squeeze_dims[8];
+ int num_squeeze_dims;
+} TfLiteSqueezeParams;
+
+typedef struct {
+ int begin_mask;
+ int end_mask;
+ int ellipsis_mask;
+ int new_axis_mask;
+ int shrink_axis_mask;
+} TfLiteStridedSliceParams;
+
+typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMaxParams;
+
+typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMinParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+} TfLiteTransposeConvParams;
+
+typedef struct {
+ bool validate_indices;
+} TfLiteSparseToDenseParams;
+
+typedef struct {
+ TfLiteType out_type;
+} TfLiteShapeParams;
+
+typedef struct {
+ // Parameters supported by version 1:
+ float min;
+ float max;
+ int num_bits;
+
+ // Parameters supported by version 2:
+ bool narrow_range;
+} TfLiteFakeQuantParams;
+
+typedef struct {
+ int values_count;
+ int axis;
+} TfLitePackParams;
+
+typedef struct {
+ int axis;
+} TfLiteOneHotParams;
+
+typedef struct {
+ int num;
+ int axis;
+} TfLiteUnpackParams;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
new file mode 100644
index 0000000000..4d0ba75e68
--- /dev/null
+++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+// Builtin op data is just a set of data definitions, so the only meaningful
+// test we can run is whether we can create the structs we expect to find.
+// Testing each struct's members might be possible, but it seems unnecessary
+// until we've locked down the API. The build rule has copts set to ignore the
+// unused variable warning, since this is just a compilation test.
+TEST(IntArray, CanCompileStructs) {
+ TfLitePadding padding = kTfLitePaddingSame;
+ TfLitePaddingValues padding_values;
+ TfLiteFusedActivation fused_activation = kTfLiteActRelu;
+ TfLiteConvParams conv_params;
+ TfLitePoolParams pool_params;
+ TfLiteDepthwiseConvParams depthwise_conv_params;
+ TfLiteSVDFParams svdf_params;
+ TfLiteRNNParams rnn_params;
+ TfLiteSequenceRNNParams sequence_rnn_params;
+ TfLiteFullyConnectedWeightsFormat fully_connected_weights_format =
+ kTfLiteFullyConnectedWeightsFormatDefault;
+ TfLiteFullyConnectedParams fully_connected_params;
+ TfLiteLSHProjectionType projection_type = kTfLiteLshProjectionDense;
+ TfLiteLSHProjectionParams projection_params;
+ TfLiteSoftmaxParams softmax_params;
+ TfLiteConcatenationParams concatenation_params;
+ TfLiteAddParams add_params;
+ TfLiteSpaceToBatchNDParams space_to_batch_nd_params;
+ TfLiteBatchToSpaceNDParams batch_to_space_nd_params;
+ TfLiteMulParams mul_params;
+ TfLiteSubParams sub_params;
+ TfLiteDivParams div_params;
+ TfLiteL2NormParams l2_norm_params;
+ TfLiteLocalResponseNormParams local_response_norm_params;
+ TfLiteLSTMKernelType lstm_kernel_type = kTfLiteLSTMBasicKernel;
+ TfLiteLSTMParams lstm_params;
+ TfLiteResizeBilinearParams resize_bilinear_params;
+ TfLitePadParams pad_params;
+ TfLitePadV2Params pad_v2_params;
+ TfLiteReshapeParams reshape_params;
+ TfLiteSkipGramParams skip_gram_params;
+ TfLiteSpaceToDepthParams space_to_depth_params;
+ TfLiteCastParams cast_params;
+ TfLiteCombinerType combiner_type = kTfLiteCombinerTypeSqrtn;
+ TfLiteEmbeddingLookupSparseParams lookup_sparse_params;
+ TfLiteGatherParams gather_params;
+ TfLiteTransposeParams transpose_params;
+ TfLiteReducerParams reducer_params;
+ TfLiteSplitParams split_params;
+ TfLiteSqueezeParams squeeze_params;
+ TfLiteStridedSliceParams strided_slice_params;
+ TfLiteArgMaxParams arg_max_params;
+ TfLiteArgMinParams arg_min_params;
+ TfLiteTransposeConvParams transpose_conv_params;
+ TfLiteSparseToDenseParams sparse_to_dense_params;
+ TfLiteShapeParams shape_params;
+ TfLiteFakeQuantParams fake_quant_params;
+ TfLitePackParams pack_params;
+ TfLiteOneHotParams one_hot_params;
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
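
The compilation test above only checks that the structs exist. In practice a
kernel receives one of these structs through its node's builtin_data pointer;
the sketch below (hypothetical kernel code, assuming the usual TfLite
cast-from-builtin_data convention) shows that consumption path:

    #include "tensorflow/contrib/lite/c/builtin_op_data.h"

    TfLiteStatus ConvPrepareSketch(TfLiteContext* context, TfLiteNode* node) {
      // The interpreter parses the flatbuffer options into a TfLiteConvParams
      // and hands it to the kernel as opaque builtin_data.
      const TfLiteConvParams* params =
          (const TfLiteConvParams*)(node->builtin_data);
      if (params->activation == kTfLiteActNone) {
        // No fused activation to apply.
      }
      return kTfLiteOk;
    }
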
diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/c/c_api_internal.c
index 7f2aa316f4..1846bad4b7 100644
--- a/tensorflow/contrib/lite/context.c
+++ b/tensorflow/contrib/lite/c/c_api_internal.c
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include <stdio.h>
+#include <stdlib.h>
#include <string.h>
int TfLiteIntArrayGetSizeInBytes(int size) {
@@ -76,7 +77,8 @@ void TfLiteTensorFree(TfLiteTensor* t) {
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
- const void* allocation, bool is_variable, TfLiteTensor* tensor) {
+ const void* allocation, bool is_variable,
+ TfLiteTensor* tensor) {
TfLiteTensorFree(tensor);
tensor->type = type;
tensor->name = name;
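
TfLiteTensorReset, whose signature is re-wrapped above, frees any previous data
and then repopulates every field of the tensor. A hedged sketch of a call site,
assuming only the declarations in c_api_internal.h:

    #include "tensorflow/contrib/lite/c/c_api_internal.h"

    void ResetTensorSketch(TfLiteTensor* tensor) {
      TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
      dims->data[0] = 4;
      TfLiteQuantizationParams quant = {0.0f, 0};  // unquantized
      static float storage[4];
      // Frees the tensor's old data, then installs the new fields; the dims
      // array is handed over to the tensor.
      TfLiteTensorReset(kTfLiteFloat32, "sketch_tensor", dims, quant,
                        (char*)storage, sizeof(storage), kTfLiteMmapRo,
                        /*allocation=*/NULL, /*is_variable=*/false, tensor);
    }
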
diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/contrib/lite/c/c_api_internal.h
new file mode 100644
index 0000000000..48df68a654
--- /dev/null
+++ b/tensorflow/contrib/lite/c/c_api_internal.h
@@ -0,0 +1,491 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file defines a C API for implementing operations in tflite.
+// These operations can be defined using C++, but the interface between
+// the interpreter and the operations is C.
+//
+// Summary of abstractions
+// TF_LITE_ENSURE - Self-sufficient error checking
+// TfLiteStatus - Status reporting
+// TfLiteIntArray - stores tensor shapes (dims),
+// TfLiteContext - allows an op to access the tensors
+// TfLiteTensor - tensor (a multidimensional array)
+// TfLiteNode - a single node or operation
+// TfLiteRegistration - the implementation of a conceptual operation.
+//
+// Some abstractions in this file are created and managed by Interpreter.
+#ifndef TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
+#define TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
+
+// The list of external context types known to TF Lite. This list exists solely
+// to avoid conflicts and to ensure ops can share the external contexts they
+// need. Access to the external contexts is controlled by one of the
+// corresponding support files.
+typedef enum {
+ kTfLiteEigenContext = 0, // include eigen_support.h to use.
+ kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
+ kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support.
+ kTfLiteMaxExternalContexts = 3
+} TfLiteExternalContextType;
+
+// An external context is a collection of information unrelated to the TF Lite
+// framework, but useful to a subset of the ops. TF Lite knows very little
+// about the actual contexts, but it keeps a list of them, and is able to
+// refresh them if configurations like the number of recommended threads
+// change.
+typedef struct {
+ TfLiteExternalContextType type;
+ TfLiteStatus (*Refresh)(struct TfLiteContext* context);
+} TfLiteExternalContext;
+
+// Forward declare so GetNode can use this in Context.
+typedef struct _TfLiteRegistration TfLiteRegistration;
+typedef struct _TfLiteDelegate TfLiteDelegate;
+
+#define kOptionalTensor (-1)
+
+// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
+// indices
+typedef struct {
+ int size;
+// gcc 6.1+ has a bug where flexible members aren't properly handled
+// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
+#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
+ __GNUC_MINOR__ >= 1
+ int data[0];
+#else
+ int data[];
+#endif
+} TfLiteIntArray;
+
+// Given the size (number of elements) in a TfLiteIntArray, calculate its size
+// in bytes.
+int TfLiteIntArrayGetSizeInBytes(int size);
+
+// Create an array of a given `size` (uninitialized entries).
+// This returns a pointer that you must free using TfLiteIntArrayFree().
+TfLiteIntArray* TfLiteIntArrayCreate(int size);
+
+// Check if two arrays are equal. Returns 1 if they are equal, 0 otherwise.
+int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);
+
+// Create a copy of an array passed as `src`.
+// You are expected to free the memory with TfLiteIntArrayFree().
+TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);
+
+// Free memory of array `v`.
+void TfLiteIntArrayFree(TfLiteIntArray* v);
+
+// Since we must not depend on any libraries, define a minimal subset of
+// error macros while avoiding names that have pre-conceived meanings like
+// assert and check.
+
+// Check whether value is true, and if not return kTfLiteError from
+// the current function (and report the error string msg).
+#define TF_LITE_ENSURE_MSG(context, value, msg) \
+ do { \
+ if (!(value)) { \
+ (context)->ReportError((context), __FILE__ " " msg); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+#define TF_LITE_ENSURE(context, a) \
+ do { \
+ if (!(a)) { \
+ (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
+ __LINE__, #a); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_STATUS(a) \
+ do { \
+ if ((a) != kTfLiteOk) { \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a == b` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+// `a` and `b` may be evaluated more than once, so no side effects or
+// extremely expensive computations should be done.
+#define TF_LITE_ENSURE_EQ(context, a, b) \
+ do { \
+ if ((a) != (b)) { \
+ (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
+ __LINE__, #a, #b, (a), (b)); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_OK(context, status) \
+ do { \
+ if ((status) != kTfLiteOk) { \
+ return status; \
+ } \
+ } while (0)
+
+// Single-precision complex data type compatible with the C99 definition.
+typedef struct {
+ float re, im; // real and imaginary parts, respectively.
+} TfLiteComplex64;
+
+// Types supported by tensor
+typedef enum {
+ kTfLiteNoType = 0,
+ kTfLiteFloat32 = 1,
+ kTfLiteInt32 = 2,
+ kTfLiteUInt8 = 3,
+ kTfLiteInt64 = 4,
+ kTfLiteString = 5,
+ kTfLiteBool = 6,
+ kTfLiteInt16 = 7,
+ kTfLiteComplex64 = 8,
+} TfLiteType;
+
+// Parameters for asymmetric quantization. Quantized values can be converted
+// back to float using:
+// real_value = scale * (quantized_value - zero_point);
+typedef struct {
+ float scale;
+ int32_t zero_point;
+} TfLiteQuantizationParams;
+
+// A union of pointers that points to memory for a given tensor.
+typedef union {
+ int* i32;
+ int64_t* i64;
+ float* f;
+ char* raw;
+ const char* raw_const;
+ uint8_t* uint8;
+ bool* b;
+ int16_t* i16;
+ TfLiteComplex64* c64;
+} TfLitePtrUnion;
+
+// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
+// data (or data externally allocated). kTfLiteArenaRw is arena allocated
+// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+typedef enum {
+ kTfLiteMemNone = 0,
+ kTfLiteMmapRo,
+ kTfLiteArenaRw,
+ kTfLiteArenaRwPersistent,
+ kTfLiteDynamic,
+} TfLiteAllocationType;
+
+// The delegates should use zero or positive integers to represent handles.
+// -1 is reserved for the unallocated status.
+typedef int TfLiteBufferHandle;
+const TfLiteBufferHandle kTfLiteNullBufferHandle = -1;
+
+// A tensor in the interpreter system, which is a wrapper around a buffer of
+// data including a dimensionality (or NULL if not currently defined).
+typedef struct {
+ // The data type specification for data stored in `data`. This affects
+ // what member of `data` union should be used.
+ TfLiteType type;
+ // A union of data pointers. The appropriate type should be used for a typed
+ // tensor based on `type`.
+ TfLitePtrUnion data;
+ // A pointer to a structure representing the dimensionality interpretation
+ // that the buffer should have. NOTE: the product of elements of `dims`
+ // and the element datatype size should be equal to `bytes` below.
+ TfLiteIntArray* dims;
+ // Quantization information.
+ TfLiteQuantizationParams params;
+ // How memory is mapped
+ // kTfLiteMmapRo: Memory mapped read only.
+ // i.e. weights
+ // kTfLiteArenaRw: Arena allocated read write memory
+ // (i.e. temporaries, outputs).
+ TfLiteAllocationType allocation_type;
+ // The number of bytes required to store the data of this Tensor. I.e.
+ // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
+ // type is kTfLiteFloat32 and dims = {3, 2} then
+ // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
+ size_t bytes;
+
+ // An opaque pointer to a tflite::MMapAllocation
+ const void* allocation;
+
+ // Null-terminated name of this tensor.
+ const char* name;
+
+ // The delegate which knows how to handle `buffer_handle`.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteDelegate* delegate;
+
+ // An integer buffer handle that can be handled by `delegate`.
+ // The value is valid only when delegate is not null.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteBufferHandle buffer_handle;
+
+  // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
+  // responsible for setting data_is_stale to true.
+  // `delegate->CopyFromBufferHandle` can be called to copy the data from the
+  // delegate buffer.
+  // WARNING: This is an experimental interface that is subject to change.
+ bool data_is_stale;
+
+ // True if the tensor is a variable.
+ bool is_variable;
+} TfLiteTensor;
+
+// Free data memory of tensor `t`;
+void TfLiteTensorDataFree(TfLiteTensor* t);
+
+// Free memory of tensor `t`;
+void TfLiteTensorFree(TfLiteTensor* t);
+
+// Set all of a tensor's fields (and free any previously allocated data).
+void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
+ TfLiteQuantizationParams quantization, char* buffer,
+ size_t size, TfLiteAllocationType allocation_type,
+ const void* allocation, bool is_variable,
+ TfLiteTensor* tensor);
+
+// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
+// types other than kTfLiteDynamic will be ignored.
+void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
+
+// A structure representing an instance of a node.
+// This structure only exhibits the inputs, outputs and user defined data, not
+// other features like the type.
+typedef struct {
+ // Inputs to this node expressed as indices into the simulator's tensors.
+ TfLiteIntArray* inputs;
+
+ // Outputs to this node expressed as indices into the simulator's tensors.
+ TfLiteIntArray* outputs;
+
+  // Temporary tensors used during the computation. This usually contains no
+ // tensors, but ops are allowed to change that if they need scratch space of
+ // any sort.
+ TfLiteIntArray* temporaries;
+
+ // Opaque data provided by the node implementer through `Registration.init`.
+ void* user_data;
+
+ // Opaque data provided to the node if the node is a builtin. This is usually
+ // a structure defined in builtin_op_data.h
+ void* builtin_data;
+
+ // Custom initial data. This is the opaque data provided in the flatbuffer.
+ // WARNING: This is an experimental interface that is subject to change.
+ const void* custom_initial_data;
+ int custom_initial_data_size;
+
+ // The pointer to the delegate. This is non-null only when the node is
+ // created by calling `interpreter.ModifyGraphWithDelegate`.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteDelegate* delegate;
+} TfLiteNode;
+
+typedef struct TfLiteContext {
+ // Number of tensors in the context.
+ size_t tensors_size;
+
+ // The execution plan contains a list of the node indices in execution
+ // order. execution_plan->size is the current number of nodes. And,
+ // execution_plan->data[0] is the first node that needs to be run.
+ // TfLiteDelegates can traverse the current execution plan by iterating
+ // through each member of this array and using GetNodeAndRegistration() to
+ // access details about a node. i.e.
+ // TfLiteIntArray* execution_plan;
+ // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
+ // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
+ // int node_index = execution_plan->data[exec_index];
+ // TfLiteNode* node;
+ // TfLiteRegistration* reg;
+ // context->GetNodeAndRegistration(context, node_index, &node, &reg);
+ // }
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
+ TfLiteIntArray** execution_plan);
+
+ // An array of tensors in the interpreter context (of length `tensors_size`)
+ TfLiteTensor* tensors;
+
+ // opaque full context ptr (an opaque c++ data structure)
+ void* impl_;
+
+ // Request memory pointer be resized. Updates dimensions on the tensor.
+  // NOTE: ResizeTensor takes ownership of new_size.
+ TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
+ TfLiteIntArray* new_size);
+  // Request that an error be reported with the format string msg.
+ void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
+
+ // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If
+ // non-null, the value pointed to by `first_new_tensor_index` will be set to
+ // the index of the first new tensor.
+ TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
+ int* first_new_tensor_index);
+
+ // Get a Tensor node by node_index.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index,
+ TfLiteNode** node,
+ TfLiteRegistration** registration);
+
+ // Replace ops with one or more stub delegate operations. This function
+ // does not take ownership of `nodes_to_replace`.
+ TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)(
+ struct TfLiteContext*, TfLiteRegistration registration,
+ const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
+
+ // Number of threads that are recommended to subsystems like gemmlowp and
+ // eigen.
+ int recommended_num_threads;
+
+ // Access external contexts by type.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
+ TfLiteExternalContextType);
+  // Set the value of an external context. Does not take ownership of the
+ // pointer.
+ // WARNING: This is an experimental interface that is subject to change.
+ void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
+ TfLiteExternalContext*);
+} TfLiteContext;
+
+typedef struct _TfLiteRegistration {
+ // Initializes the op from serialized data.
+ // If a built-in op:
+ // `buffer` is the op's params data (TfLiteLSTMParams*).
+ // `length` is zero.
+ // If custom op:
+ // `buffer` is the op's `custom_options`.
+ // `length` is the size of the buffer.
+ //
+ // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
+ // or an instance of a struct).
+ //
+ // The returned pointer will be stored with the node in the `user_data` field,
+ // accessible within prepare and invoke functions below.
+ // NOTE: if the data is already in the desired format, simply implement this
+ // function to return `nullptr` and implement the free function to be a no-op.
+ void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
+
+ // The pointer `buffer` is the data previously returned by an init invocation.
+ void (*free)(TfLiteContext* context, void* buffer);
+
+ // prepare is called when the inputs this node depends on have been resized.
+ // context->ResizeTensor() can be called to request output tensors to be
+ // resized.
+ //
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
+
+ // Execute the node (should read node->inputs and output to node->outputs).
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+
+ // profiling_string is called during summarization of profiling information
+ // in order to group executions together. Providing a value here will cause a
+  // given op to appear multiple times in the profiling report. This is
+ // particularly useful for custom ops that can perform significantly
+ // different calculations depending on their `user-data`.
+ const char* (*profiling_string)(const TfLiteContext* context,
+ const TfLiteNode* node);
+
+ // Builtin codes. If this kernel refers to a builtin this is the code
+ // of the builtin. This is so we can do marshaling to other frameworks like
+ // NN API.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ int32_t builtin_code;
+
+ // Custom op name. If the op is a builtin, this will be null.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ // WARNING: This is an experimental interface that is subject to change.
+ const char* custom_name;
+
+ // The version of the op.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ int version;
+} TfLiteRegistration;
+
+// WARNING: This is an experimental interface that is subject to change.
+typedef struct _TfLiteDelegate {
+  // Data that the delegate needs to identify itself. This data is owned by
+  // the delegate. The delegate is owned by the user code, so the delegate is
+  // responsible for freeing this data when it is destroyed.
+ void* data_;
+
+ // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
+ // delegate a view of the current graph through TfLiteContext*. It typically
+ // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels()
+ // to ask the TensorFlow lite runtime to create macro-nodes to represent
+ // delegated subgraphs of the original graph.
+ TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate);
+
+ // Copy the data from delegate buffer handle to raw memory.
+ // This can be null if the delegate doesn't use its own buffer.
+ TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
+ TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle,
+ void* data, size_t size);
+
+ // Copy the data from raw memory to delegate buffer handle.
+ // This can be null if the delegate doesn't use its own buffer.
+ TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
+ TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle,
+ void* data, size_t size);
+
+ // Free the delegate buffer handle. Note: this only frees the handle; it
+ // doesn't release the underlying resource (e.g. textures). The resources
+ // are owned by either the application layer or the delegate.
+ // This can be null if the delegate doesn't use its own buffer.
+ void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate,
+ TfLiteBufferHandle* handle);
+} TfLiteDelegate;
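
A hedged sketch of the Prepare flow described above: a toy delegate whose Prepare claims every node in the current execution plan. MyDelegateKernelInvoke is a hypothetical stand-in kernel, not part of this change.

// Kernel that stands in for each delegated subgraph.
static TfLiteStatus MyDelegateKernelInvoke(TfLiteContext* context,
                                           TfLiteNode* node) {
  return kTfLiteOk;  // a real delegate dispatches to its backend here
}

static TfLiteStatus MyDelegatePrepare(TfLiteContext* context,
                                      TfLiteDelegate* delegate) {
  TfLiteIntArray* plan = nullptr;
  TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
  TfLiteRegistration kernel = {nullptr, nullptr, nullptr,
                               MyDelegateKernelInvoke, nullptr,
                               0, "MyDelegateKernel", 1};
  // Collapse the claimed nodes into macro-nodes backed by `kernel`.
  return context->ReplaceSubgraphsWithDelegateKernels(context, kernel, plan,
                                                      delegate);
}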
+
+// WARNING: This is an experimental interface that is subject to change.
+//
+// Currently, TfLiteDelegateParams has to be allocated in a way that it is
+// trivially destructible. It will be stored in the `builtin_data` field of
+// the delegate node's `TfLiteNode`.
+//
+// See the `CreateDelegateParams` function in `interpreter.cc` for details.
+typedef struct {
+ TfLiteDelegate* delegate;
+ TfLiteIntArray* nodes_to_replace;
+ TfLiteIntArray* input_tensors;
+ TfLiteIntArray* output_tensors;
+} TfLiteDelegateParams;
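
Inside the delegate's stand-in kernel, these params arrive through the node itself. A small illustrative accessor, assuming the runtime has populated builtin_data as documented above:

static TfLiteStatus MyDelegateKernelPrepare(TfLiteContext* context,
                                            TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteDelegateParams*>(node->builtin_data);
  // nodes_to_replace lists the original node indices this macro-node covers;
  // input_tensors/output_tensors give its boundary tensors.
  for (int i = 0; i < params->nodes_to_replace->size; ++i) {
    const int original_node_index = params->nodes_to_replace->data[i];
    (void)original_node_index;  // a real delegate would build its plan here
  }
  return kTfLiteOk;
}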
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/c/c_api_internal_test.cc
index 20d6f69a25..af398f3207 100644
--- a/tensorflow/contrib/lite/context_test.cc
+++ b/tensorflow/contrib/lite/c/c_api_internal_test.cc
@@ -13,16 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
// NOTE: this tests only the TfLiteIntArray part of context.
-// most of context.h is provided in the context of using it with interpreter.h
-// and interpreter.cc, so interpreter_test.cc tests context structures more
-// thoroughly.
+// most of c_api_internal.h is provided in the context of using it with
+// interpreter.h and interpreter.cc, so interpreter_test.cc tests context
+// structures more thoroughly.
TEST(IntArray, TestIntArrayCreate) {
TfLiteIntArray* a = TfLiteIntArrayCreate(0);
@@ -69,7 +68,6 @@ TEST(IntArray, TestIntArrayEqual) {
} // namespace tflite
int main(int argc, char** argv) {
- ::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index b23183b743..b86c2819b8 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -12,484 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// This file defines a C API for implementing operations in tflite.
-// These operations can be defined using c++ but the interface between
-// the interpreter and the operations are C.
-//
-// Summary of abstractions
-// TF_LITE_ENSURE - Self-sufficient error checking
-// TfLiteStatus - Status reporting
-// TfLiteIntArray - stores tensor shapes (dims),
-// TfLiteContext - allows an op to access the tensors
-// TfLiteTensor - tensor (a multidimensional array)
-// TfLiteNode - a single node or operation
-// TfLiteRegistration - the implementation of a conceptual operation.
-//
-// Some abstractions in this file are created and managed by Interpreter.
+// Compatibility shim for moved header location.
#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
#define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
-#include <stdbool.h>
-#include <stdint.h>
-#include <stdlib.h>
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
-#ifdef __cplusplus
-extern "C" {
-#endif // __cplusplus
-
-typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
-
-// Forward declarations for use with dependent types.
-struct TfLiteContext;
-struct TfLiteNode;
-struct _TfLiteRegistration;
-struct _TfLiteDelegate;
-
-// The list of external context types known to TF Lite. This list exists solely
-// to avoid conflicts and to ensure ops can share the external contexts they
-// need. Access to the external contexts is controled by one of the
-// corresponding support files.
-typedef enum {
- kTfLiteEigenContext = 0, // include eigen_support.h to use.
- kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
- kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support.
- kTfLiteMaxExternalContexts = 3
-} TfLiteExternalContextType;
-
-// An external context is a collection of information unrelated to the TF Lite
-// framework, but useful to a subset of the ops. TF Lite knows very little
-// about about the actual contexts, but it keeps a list of them, and is able to
-// refresh them if configurations like the number of recommended threads
-// change.
-typedef struct {
- TfLiteExternalContextType type;
- TfLiteStatus (*Refresh)(struct TfLiteContext* context);
-} TfLiteExternalContext;
-
-#define kOptionalTensor (-1)
-
-// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
-// indices
-typedef struct {
- int size;
-// gcc 6.1+ have a bug where flexible members aren't properly handled
-// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
-#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
- __GNUC_MINOR__ >= 1
- int data[0];
-#else
- int data[];
-#endif
-} TfLiteIntArray;
-
-// Given the size (number of elements) in a TfLiteIntArray, calculate its size
-// in bytes.
-int TfLiteIntArrayGetSizeInBytes(int size);
-
-// Create a array of a given `size` (uninitialized entries).
-// This returns a pointer, that you must free using TfLiteIntArrayFree().
-TfLiteIntArray* TfLiteIntArrayCreate(int size);
-
-// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise.
-int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);
-
-// Create a copy of an array passed as `src`.
-// You are expected to free memory with TfLiteIntArrayFree
-TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);
-
-// Free memory of array `v`.
-void TfLiteIntArrayFree(TfLiteIntArray* v);
-
-// Since we must not depend on any libraries, define a minimal subset of
-// error macros while avoiding names that have pre-conceived meanings like
-// assert and check.
-
-// Check whether value is true, and if not return kTfLiteError from
-// the current function (and report the error string msg).
-#define TF_LITE_ENSURE_MSG(context, value, msg) \
- do { \
- if (!(value)) { \
- (context)->ReportError((context), __FILE__ " " msg); \
- return kTfLiteError; \
- } \
- } while (0)
-
-// Check whether the value `a` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-#define TF_LITE_ENSURE(context, a) \
- do { \
- if (!(a)) { \
- (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
- __LINE__, #a); \
- return kTfLiteError; \
- } \
- } while (0)
-
-#define TF_LITE_ENSURE_STATUS(a) \
- do { \
- if ((a) != kTfLiteOk) { \
- return kTfLiteError; \
- } \
- } while (0)
-
-// Check whether the value `a == b` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-// `a` and `b` may be evaluated more than once, so no side effects or
-// extremely expensive computations should be done.
-#define TF_LITE_ENSURE_EQ(context, a, b) \
- do { \
- if ((a) != (b)) { \
- (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
- __LINE__, #a, #b, (a), (b)); \
- return kTfLiteError; \
- } \
- } while (0)
-
-#define TF_LITE_ENSURE_OK(context, status) \
- do { \
- if ((status) != kTfLiteOk) { \
- return status; \
- } \
- } while (0)
-
-// Single-precision complex data type compatible with the C99 definition.
-typedef struct {
- float re, im; // real and imaginary parts, respectively.
-} TfLiteComplex64;
-
-// Types supported by tensor
-typedef enum {
- kTfLiteNoType = 0,
- kTfLiteFloat32 = 1,
- kTfLiteInt32 = 2,
- kTfLiteUInt8 = 3,
- kTfLiteInt64 = 4,
- kTfLiteString = 5,
- kTfLiteBool = 6,
- kTfLiteInt16 = 7,
- kTfLiteComplex64 = 8,
-} TfLiteType;
-
-// Parameters for asymmetric quantization. Quantized values can be converted
-// back to float using:
-// real_value = scale * (quantized_value - zero_point);
-typedef struct {
- float scale;
- int32_t zero_point;
-} TfLiteQuantizationParams;
-
-// A union of pointers that points to memory for a given tensor.
-typedef union {
- int* i32;
- int64_t* i64;
- float* f;
- char* raw;
- const char* raw_const;
- uint8_t* uint8;
- bool* b;
- int16_t* i16;
- TfLiteComplex64* c64;
-} TfLitePtrUnion;
-
-// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
-// data (or data externally allocated). kTfLiteArenaRw is arena allocated
-// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
-typedef enum {
- kTfLiteMemNone = 0,
- kTfLiteMmapRo,
- kTfLiteArenaRw,
- kTfLiteArenaRwPersistent,
- kTfLiteDynamic,
-} TfLiteAllocationType;
-
-// The delegates should use zero or positive integers to represent handles.
-// -1 is reserved from unallocated status.
-typedef int TfLiteBufferHandle;
-const TfLiteBufferHandle kTfLiteNullBufferHandle = -1;
-
-// An tensor in the interpreter system which is a wrapper around a buffer of
-// data including a dimensionality (or NULL if not currently defined).
-typedef struct {
- // The data type specification for data stored in `data`. This affects
- // what member of `data` union should be used.
- TfLiteType type;
- // A union of data pointers. The appropriate type should be used for a typed
- // tensor based on `type`.
- TfLitePtrUnion data;
- // A pointer to a structure representing the dimensionality interpretation
- // that the buffer should have. NOTE: the product of elements of `dims`
- // and the element datatype size should be equal to `bytes` below.
- TfLiteIntArray* dims;
- // Quantization information.
- TfLiteQuantizationParams params;
- // How memory is mapped
- // kTfLiteMmapRo: Memory mapped read only.
- // i.e. weights
- // kTfLiteArenaRw: Arena allocated read write memory
- // (i.e. temporaries, outputs).
- TfLiteAllocationType allocation_type;
- // The number of bytes required to store the data of this Tensor. I.e.
- // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
- // type is kTfLiteFloat32 and dims = {3, 2} then
- // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
- size_t bytes;
-
- // An opaque pointer to a tflite::MMapAllocation
- const void* allocation;
-
- // Null-terminated name of this tensor.
- const char* name;
-
- // The delegate which knows how to handle `buffer_handle`.
- // WARNING: This is an experimental interface that is subject to change.
- struct _TfLiteDelegate* delegate;
-
- // An integer buffer handle that can be handled by `delegate`.
- // The value is valid only when delegate is not null.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteBufferHandle buffer_handle;
-
- // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
- // responsible to set data_is_stale to true.
- // `delegate->CopyFromBufferHandle` can be called to copy the data from
- // delegate buffer.
- // WARNING: This is an // experimental interface that is subject to change.
- bool data_is_stale;
-
- // True if the tensor is a variable.
- bool is_variable;
-} TfLiteTensor;
-
-// Free data memory of tensor `t`;
-void TfLiteTensorDataFree(TfLiteTensor* t);
-
-// Free memory of tensor `t`;
-void TfLiteTensorFree(TfLiteTensor* t);
-
-// Set all of a tensor's fields (and free any previously allocated data).
-void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
- TfLiteQuantizationParams quantization, char* buffer,
- size_t size, TfLiteAllocationType allocation_type,
- const void* allocation, bool is_variable,
- TfLiteTensor* tensor);
-
-// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
-// types other than kTfLiteDynamic will be ignored.
-void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
-
-// A structure representing an instance of a node.
-// This structure only exhibits the inputs, outputs and user defined data, not
-// other features like the type.
-typedef struct TfLiteNode {
- // Inputs to this node expressed as indices into the simulator's tensors.
- TfLiteIntArray* inputs;
-
- // Outputs to this node expressed as indices into the simulator's tensors.
- TfLiteIntArray* outputs;
-
- // Temporary tensors uses during the computations. This usually contains no
- // tensors, but ops are allowed to change that if they need scratch space of
- // any sort.
- TfLiteIntArray* temporaries;
-
- // Opaque data provided by the node implementer through `Registration.init`.
- void* user_data;
-
- // Opaque data provided to the node if the node is a builtin. This is usually
- // a structure defined in builtin_op_data.h
- void* builtin_data;
-
- // Custom initial data. This is the opaque data provided in the flatbuffer.
- // WARNING: This is an experimental interface that is subject to change.
- const void* custom_initial_data;
- int custom_initial_data_size;
-
- // The pointer to the delegate. This is non-null only when the node is
- // created by calling `interpreter.ModifyGraphWithDelegate`.
- // WARNING: This is an experimental interface that is subject to change.
- struct _TfLiteDelegate* delegate;
-} TfLiteNode;
-
-typedef struct TfLiteContext {
- // Number of tensors in the context.
- size_t tensors_size;
-
- // The execution plan contains a list of the node indices in execution
- // order. execution_plan->size is the current number of nodes. And,
- // execution_plan->data[0] is the first node that needs to be run.
- // TfLiteDelegates can traverse the current execution plan by iterating
- // through each member of this array and using GetNodeAndRegistration() to
- // access details about a node. i.e.
- // TfLiteIntArray* execution_plan;
- // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
- // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
- // int node_index = execution_plan->data[exec_index];
- // TfLiteNode* node;
- // TfLiteRegistration* reg;
- // context->GetNodeAndRegistration(context, node_index, &node, &reg);
- // }
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
- TfLiteIntArray** execution_plan);
-
- // An array of tensors in the interpreter context (of length `tensors_size`)
- TfLiteTensor* tensors;
-
- // opaque full context ptr (an opaque c++ data structure)
- void* impl_;
-
- // Request memory pointer be resized. Updates dimensions on the tensor.
- // NOTE: ResizeTensor takes ownership of newSize.
- TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
- TfLiteIntArray* new_size);
- // Request that a error be reported with format string msg.
- void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
-
- // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If
- // non-null, the value pointed to by `first_new_tensor_index` will be set to
- // the index of the first new tensor.
- TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
- int* first_new_tensor_index);
-
- // Get a Tensor node by node_index.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteStatus (*GetNodeAndRegistration)(
- struct TfLiteContext*, int node_index, struct TfLiteNode** node,
- struct _TfLiteRegistration** registration);
-
- // Replace ops with one or more stub delegate operations. This function
- // does not take ownership of `nodes_to_replace`.
- TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)(
- struct TfLiteContext*, struct _TfLiteRegistration registration,
- const TfLiteIntArray* nodes_to_replace, struct _TfLiteDelegate* delegate);
-
- // Number of threads that are recommended to subsystems like gemmlowp and
- // eigen.
- int recommended_num_threads;
-
- // Access external contexts by type.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
- TfLiteExternalContextType);
- // Set the value of a external context. Does not take ownership of the
- // pointer.
- // WARNING: This is an experimental interface that is subject to change.
- void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
- TfLiteExternalContext*);
-} TfLiteContext;
-
-typedef struct _TfLiteRegistration {
- // Initializes the op from serialized data.
- // If a built-in op:
- // `buffer` is the op's params data (TfLiteLSTMParams*).
- // `length` is zero.
- // If custom op:
- // `buffer` is the op's `custom_options`.
- // `length` is the size of the buffer.
- //
- // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
- // or an instance of a struct).
- //
- // The returned pointer will be stored with the node in the `user_data` field,
- // accessible within prepare and invoke functions below.
- // NOTE: if the data is already in the desired format, simply implement this
- // function to return `nullptr` and implement the free function to be a no-op.
- void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
-
- // The pointer `buffer` is the data previously returned by an init invocation.
- void (*free)(TfLiteContext* context, void* buffer);
-
- // prepare is called when the inputs this node depends on have been resized.
- // context->ResizeTensor() can be called to request output tensors to be
- // resized.
- //
- // Returns kTfLiteOk on success.
- TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
-
- // Execute the node (should read node->inputs and output to node->outputs).
- // Returns kTfLiteOk on success.
- TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
-
- // profiling_string is called during summarization of profiling information
- // in order to group executions together. Providing a value here will cause a
- // given op to appear multiple times is the profiling report. This is
- // particularly useful for custom ops that can perform significantly
- // different calculations depending on their `user-data`.
- const char* (*profiling_string)(const TfLiteContext* context,
- const TfLiteNode* node);
-
- // Builtin codes. If this kernel refers to a builtin this is the code
- // of the builtin. This is so we can do marshaling to other frameworks like
- // NN API.
- // Note: It is the responsibility of the registration binder to set this
- // properly.
- int32_t builtin_code;
-
- // Custom op name. If the op is a builtin, this will be null.
- // Note: It is the responsibility of the registration binder to set this
- // properly.
- // WARNING: This is an experimental interface that is subject to change.
- const char* custom_name;
-
- // The version of the op.
- // Note: It is the responsibility of the registration binder to set this
- // properly.
- int version;
-} TfLiteRegistration;
-
-// WARNING: This is an experimental interface that is subject to change.
-typedef struct _TfLiteDelegate {
- // Data that delegate needs to identify itself. This data is owned by the
- // delegate. The delegate is owned in the user code, so the delegate is
- // responsible for doing this when it is destroyed.
- void* data_;
-
- // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
- // delegate a view of the current graph through TfLiteContext*. It typically
- // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels()
- // to ask the TensorFlow lite runtime to create macro-nodes to represent
- // delegated subgraphs of the original graph.
- TfLiteStatus (*Prepare)(struct TfLiteContext* context,
- struct _TfLiteDelegate* delegate);
-
- // Copy the data from delegate buffer handle to raw memory.
- // This can be null if the delegate doesn't use its own buffer.
- TfLiteStatus (*CopyFromBufferHandle)(struct TfLiteContext* context,
- struct _TfLiteDelegate* delegate,
- TfLiteBufferHandle buffer_handle,
- void* data, size_t size);
-
- // Copy the data from raw memory to delegate buffer handle.
- // This can be null if the delegate doesn't use its own buffer.
- TfLiteStatus (*CopyToBufferHandle)(struct TfLiteContext* context,
- struct _TfLiteDelegate* delegate,
- TfLiteBufferHandle buffer_handle,
- void* data, size_t size);
-
- // Free the Delegate Buffer Handle. Note: This only frees the handle, but
- // this doesn't release the underlying resource (e.g. textures). The
- // resources are either owned by application layer or the delegate.
- // This can be null if the delegate doesn't use its own buffer.
- void (*FreeBufferHandle)(struct TfLiteContext* context,
- struct _TfLiteDelegate* delegate,
- TfLiteBufferHandle* handle);
-} TfLiteDelegate;
-
-// WARNING: This is an experimental interface that is subject to change.
-//
-// Currently, TfLiteDelegateParams has to be allocated in a way that it's
-// trivially destructable. It will be stored as `builtin_data` field in
-// `TfLiteNode` of the delegate node.
-//
-// See also the `CreateDelegateParams` function in `interpreter.cc` details.
-typedef struct {
- TfLiteDelegate* delegate;
- TfLiteIntArray* nodes_to_replace;
- TfLiteIntArray* input_tensors;
- TfLiteIntArray* output_tensors;
-} TfLiteDelegateParams;
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h
index abe802e342..ccda4c7393 100644
--- a/tensorflow/contrib/lite/context_util.h
+++ b/tensorflow/contrib/lite/context_util.h
@@ -17,7 +17,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/core/api/BUILD b/tensorflow/contrib/lite/core/api/BUILD
new file mode 100644
index 0000000000..e4500534f3
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/BUILD
@@ -0,0 +1,57 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
+cc_library(
+ name = "api",
+ srcs = [
+ "error_reporter.cc",
+ "flatbuffer_conversions.cc",
+ "op_resolver.cc",
+ ],
+ hdrs = [
+ "error_reporter.h",
+ "flatbuffer_conversions.h",
+ "op_resolver.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ ],
+)
+
+cc_test(
+ name = "error_reporter_test",
+ size = "small",
+ srcs = ["error_reporter_test.cc"],
+ deps = [
+ ":api",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "op_resolver_test",
+ size = "small",
+ srcs = ["op_resolver_test.cc"],
+ deps = [
+ ":api",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "flatbuffer_conversions_test",
+ size = "small",
+ srcs = ["flatbuffer_conversions_test.cc"],
+ deps = [
+ ":api",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/core/api/error_reporter.cc b/tensorflow/contrib/lite/core/api/error_reporter.cc
new file mode 100644
index 0000000000..423f83b1a9
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter.cc
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include <cstdarg>
+
+namespace tflite {
+
+int ErrorReporter::Report(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+// TODO(aselle): Make the name of ReportError on context the same, so
+// we can use the ensure functions w/o a context and w/ a reporter.
+int ErrorReporter::ReportError(void*, const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/error_reporter.h b/tensorflow/contrib/lite/core/api/error_reporter.h
new file mode 100644
index 0000000000..a2f780b003
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
+
+#include <cstdarg>
+
+namespace tflite {
+
+// A functor that reports errors to a supporting system. Invoked similarly
+// to printf.
+//
+// Usage:
+// ErrorReporter foo;
+// foo.Report("test %d", 5);
+// or
+// va_list args;
+// foo.Report("test %d", args); // where args is va_list
+//
+// Subclass ErrorReporter to provide another reporting destination.
+// For example, if you have a GUI program, you might redirect to a buffer
+// that drives a GUI error log box.
+class ErrorReporter {
+ public:
+ virtual ~ErrorReporter() {}
+ virtual int Report(const char* format, va_list args) = 0;
+ int Report(const char* format, ...);
+ int ReportError(void*, const char* format, ...);
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
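
As the header comment suggests, redirecting errors only requires overriding the va_list Report. A minimal illustrative subclass that writes to stderr (TF Lite ships its own default reporter elsewhere; this one is just a sketch):

#include <cstdio>

#include "tensorflow/contrib/lite/core/api/error_reporter.h"

class StderrReporter : public tflite::ErrorReporter {
 public:
  int Report(const char* format, va_list args) override {
    const int result = vfprintf(stderr, format, args);
    fputc('\n', stderr);  // keep one message per line
    return result;
  }
};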
diff --git a/tensorflow/contrib/lite/core/api/error_reporter_test.cc b/tensorflow/contrib/lite/core/api/error_reporter_test.cc
new file mode 100644
index 0000000000..0463eee6be
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter_test.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+
+#include <cstdio>
+#include <cstring>
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+ int Report(const char* format, va_list args) override {
+ vsnprintf(buffer_, kBufferSize, format, args);
+ return 0;
+ }
+ char* GetBuffer() { return buffer_; }
+
+ private:
+ static constexpr int kBufferSize = 256;
+ char buffer_[kBufferSize];
+};
+
+TEST(ErrorReporter, TestReport) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+ reporter->Report("Error: %d", 23);
+ EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23"));
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
new file mode 100644
index 0000000000..1420fbcdc6
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -0,0 +1,622 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+
+#include <cstdlib>
+#include <type_traits>
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+
+namespace tflite {
+
+namespace {
+
+// Copies the contents of the flatbuffer int vector `flat_vector` into the
+// int array `buffer`. `flat_vector` and `buffer` hold the same
+// configuration data for a given operation.
+void FlatBufferIntVectorToArray(int max_size_of_buffer,
+ const flatbuffers::Vector<int32_t>* flat_vector,
+ int* buffer, ErrorReporter* error_reporter) {
+ if (!flat_vector) {
+ error_reporter->Report("Input array not provided for operation.\n");
+ } else {
+ int num_dimensions = flat_vector->Length();
+ if (num_dimensions > max_size_of_buffer / sizeof(int)) {
+ error_reporter->Report(
+ "Found too many dimensions in the operation's input array.\n");
+ } else {
+ for (int i = 0; i < num_dimensions; ++i) {
+ buffer[i] = flat_vector->Get(i);
+ }
+ }
+ }
+}
+
+// Allocate a structure using malloc, but make sure the structure is a POD
+// structure that doesn't require constructors to run. The reason we do this
+// is that the Interpreter's C extension part will take ownership, so
+// destructors will not be run during deallocation.
+template <class T>
+T* MallocPOD() {
+ static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
+ return static_cast<T*>(malloc(sizeof(T)));
+}
+
+} // namespace
+
+TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
+ ErrorReporter* error_reporter) {
+ switch (tensor_type) {
+ case TensorType_FLOAT32:
+ *type = kTfLiteFloat32;
+ break;
+ case TensorType_INT16:
+ *type = kTfLiteInt16;
+ break;
+ case TensorType_INT32:
+ *type = kTfLiteInt32;
+ break;
+ case TensorType_UINT8:
+ *type = kTfLiteUInt8;
+ break;
+ case TensorType_INT64:
+ *type = kTfLiteInt64;
+ break;
+ case TensorType_STRING:
+ *type = kTfLiteString;
+ break;
+ case TensorType_BOOL:
+ *type = kTfLiteBool;
+ break;
+ case TensorType_COMPLEX64:
+ *type = kTfLiteComplex64;
+ break;
+ default:
+ error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
+ EnumNameTensorType(tensor_type), tensor_type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// Parse the appropriate data out of the op.
+//
+// This handles builtin data explicitly, since it is described by flatbuffer
+// schemas. If it returns kTfLiteOk, it passes the data out with
+// `builtin_data`, which needs to be released by calling `free`.
+// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
+TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter, void** builtin_data) {
+ auto parse_padding = [](Padding padding) {
+ switch (padding) {
+ case Padding_SAME:
+ return kTfLitePaddingSame;
+ case Padding_VALID:
+ return kTfLitePaddingValid;
+ }
+ return kTfLitePaddingUnknown;
+ };
+ auto parse_activation = [](ActivationFunctionType activation) {
+ switch (activation) {
+ case ActivationFunctionType_NONE:
+ return kTfLiteActNone;
+ case ActivationFunctionType_RELU:
+ return kTfLiteActRelu;
+ case ActivationFunctionType_RELU_N1_TO_1:
+ return kTfLiteActRelu1;
+ case ActivationFunctionType_RELU6:
+ return kTfLiteActRelu6;
+ case ActivationFunctionType_TANH:
+ return kTfLiteActTanh;
+ case ActivationFunctionType_SIGN_BIT:
+ return kTfLiteActSignBit;
+ }
+ return kTfLiteActNone;
+ };
+ auto parseLSHProjectionType = [](LSHProjectionType type) {
+ switch (type) {
+ case LSHProjectionType_SPARSE:
+ return kTfLiteLshProjectionSparse;
+ case LSHProjectionType_DENSE:
+ return kTfLiteLshProjectionDense;
+ default:
+ return kTfLiteLshProjectionUnknown;
+ }
+ };
+ auto parseCombinerType = [](CombinerType type) {
+ switch (type) {
+ case CombinerType_MEAN:
+ return kTfLiteCombinerTypeMean;
+ case CombinerType_SQRTN:
+ return kTfLiteCombinerTypeSqrtn;
+ case CombinerType_SUM:
+ default:
+ return kTfLiteCombinerTypeSum;
+ }
+ };
+
+ *builtin_data = nullptr;
+ switch (op_type) {
+ case BuiltinOperator_CONV_2D: {
+ TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+ if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+
+ params->dilation_width_factor = conv_params->dilation_w_factor();
+ params->dilation_height_factor = conv_params->dilation_h_factor();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_CAST: {
+ TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
+ if (auto* schema_params = op->builtin_options_as_CastOptions()) {
+ auto in_status =
+ ConvertTensorType(schema_params->in_data_type(),
+ &params->in_data_type, error_reporter);
+ auto out_status =
+ ConvertTensorType(schema_params->out_data_type(),
+ &params->out_data_type, error_reporter);
+ if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
+ free(params);
+ return kTfLiteError;
+ }
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_LSH_PROJECTION: {
+ TfLiteLSHProjectionParams* params =
+ MallocPOD<TfLiteLSHProjectionParams>();
+ if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
+ params->type = parseLSHProjectionType(lshParams->type());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_AVERAGE_POOL_2D:
+ case BuiltinOperator_MAX_POOL_2D:
+ case BuiltinOperator_L2_POOL_2D: {
+ TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
+ if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
+ params->padding = parse_padding(pool_params->padding());
+ params->stride_width = pool_params->stride_w();
+ params->stride_height = pool_params->stride_h();
+ params->filter_width = pool_params->filter_width();
+ params->filter_height = pool_params->filter_height();
+ params->activation =
+ parse_activation(pool_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DEPTHWISE_CONV_2D: {
+ TfLiteDepthwiseConvParams* params =
+ MallocPOD<TfLiteDepthwiseConvParams>();
+ if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->depth_multiplier = conv_params->depth_multiplier();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SVDF: {
+ TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
+ if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
+ params->rank = svdf_params->rank();
+ params->activation =
+ parse_activation(svdf_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
+ TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
+ if (auto* sequence_rnn_params =
+ op->builtin_options_as_SequenceRNNOptions()) {
+ params->activation =
+ parse_activation(sequence_rnn_params->fused_activation_function());
+ params->time_major = sequence_rnn_params->time_major();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RNN: {
+ TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
+ if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
+ params->activation =
+ parse_activation(rnn_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
+ TfLiteEmbeddingLookupSparseParams* params =
+ MallocPOD<TfLiteEmbeddingLookupSparseParams>();
+ if (auto* embedding_params =
+ op->builtin_options_as_EmbeddingLookupSparseOptions()) {
+ params->combiner = parseCombinerType(embedding_params->combiner());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_FULLY_CONNECTED: {
+ TfLiteFullyConnectedParams* params =
+ MallocPOD<TfLiteFullyConnectedParams>();
+ if (auto* fully_connected_params =
+ op->builtin_options_as_FullyConnectedOptions()) {
+ params->activation = parse_activation(
+ fully_connected_params->fused_activation_function());
+ switch (fully_connected_params->weights_format()) {
+ case FullyConnectedOptionsWeightsFormat_DEFAULT:
+ params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
+ break;
+ case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
+ params->weights_format =
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8;
+ break;
+ default:
+ error_reporter->Report("Unhandled fully-connected weights format.");
+ return kTfLiteError;
+ }
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_HASHTABLE_LOOKUP:
+ // no-op.
+ break;
+ case BuiltinOperator_SOFTMAX: {
+ TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
+ if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
+ params->beta = softmax_params->beta();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_CONCATENATION: {
+ TfLiteConcatenationParams* params =
+ MallocPOD<TfLiteConcatenationParams>();
+ if (auto* concatenation_params =
+ op->builtin_options_as_ConcatenationOptions()) {
+ params->activation =
+ parse_activation(concatenation_params->fused_activation_function());
+ params->axis = concatenation_params->axis();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_MUL: {
+ auto* params = MallocPOD<TfLiteMulParams>();
+ if (auto* schema_params = op->builtin_options_as_MulOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ADD: {
+ auto* params = MallocPOD<TfLiteAddParams>();
+ if (auto* schema_params = op->builtin_options_as_AddOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DIV: {
+ auto* params = MallocPOD<TfLiteDivParams>();
+ if (auto* schema_params = op->builtin_options_as_DivOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SUB: {
+ auto* params = MallocPOD<TfLiteSubParams>();
+ if (auto* schema_params = op->builtin_options_as_SubOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_L2_NORMALIZATION: {
+ auto* params = MallocPOD<TfLiteL2NormParams>();
+ if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
+ auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_LocalResponseNormalizationOptions()) {
+ params->radius = schema_params->radius();
+ params->bias = schema_params->bias();
+ params->alpha = schema_params->alpha();
+ params->beta = schema_params->beta();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
+ case BuiltinOperator_LSTM: {
+ TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
+ if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
+ params->activation =
+ parse_activation(lstm_params->fused_activation_function());
+ params->cell_clip = lstm_params->cell_clip();
+ params->proj_clip = lstm_params->proj_clip();
+ switch (lstm_params->kernel_type()) {
+ case LSTMKernelType_FULL:
+ params->kernel_type = kTfLiteLSTMFullKernel;
+ break;
+ case LSTMKernelType_BASIC:
+ params->kernel_type = kTfLiteLSTMBasicKernel;
+ break;
+ }
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RESIZE_BILINEAR: {
+ auto* params = MallocPOD<TfLiteResizeBilinearParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_ResizeBilinearOptions()) {
+ params->align_corners = schema_params->align_corners();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RESHAPE: {
+ auto* params = MallocPOD<TfLiteReshapeParams>();
+ if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
+ auto* new_shape = schema_params->new_shape();
+ FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
+ params->shape, error_reporter);
+ params->num_dimensions = new_shape->Length();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SKIP_GRAM: {
+ TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
+ if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
+ params->ngram_size = skip_gram_params->ngram_size();
+ params->max_skip_size = skip_gram_params->max_skip_size();
+ params->include_all_ngrams = skip_gram_params->include_all_ngrams();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPACE_TO_DEPTH: {
+ auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
+ if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
+ params->block_size = schema_params->block_size();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_GATHER: {
+ TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
+ params->axis = 0;
+ if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
+ params->axis = gather_params->axis();
+ }
+
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_MEAN:
+ case BuiltinOperator_REDUCE_MAX:
+ case BuiltinOperator_REDUCE_MIN:
+ case BuiltinOperator_REDUCE_PROD:
+ case BuiltinOperator_REDUCE_ANY:
+ case BuiltinOperator_SUM: {
+ auto* params = MallocPOD<TfLiteReducerParams>();
+ if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
+ params->keep_dims = schema_params->keep_dims();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPLIT: {
+ auto* params = MallocPOD<TfLiteSplitParams>();
+ if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
+ params->num_splits = schema_params->num_splits();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SQUEEZE: {
+ auto* params = MallocPOD<TfLiteSqueezeParams>();
+ if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
+ const auto& squeeze_dims = schema_params->squeeze_dims();
+ FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
+ params->squeeze_dims, error_reporter);
+ params->num_squeeze_dims = squeeze_dims->Length();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_STRIDED_SLICE: {
+ auto* params = MallocPOD<TfLiteStridedSliceParams>();
+ if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
+ params->begin_mask = schema_params->begin_mask();
+ params->end_mask = schema_params->end_mask();
+ params->ellipsis_mask = schema_params->ellipsis_mask();
+ params->new_axis_mask = schema_params->new_axis_mask();
+ params->shrink_axis_mask = schema_params->shrink_axis_mask();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ARG_MAX: {
+ auto* params = MallocPOD<TfLiteArgMaxParams>();
+ if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
+ ConvertTensorType(schema_params->output_type(), &params->output_type,
+ error_reporter);
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ARG_MIN: {
+ auto* params = MallocPOD<TfLiteArgMinParams>();
+ if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
+ ConvertTensorType(schema_params->output_type(), &params->output_type,
+ error_reporter);
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_TRANSPOSE_CONV: {
+ TfLiteTransposeConvParams* params =
+ MallocPOD<TfLiteTransposeConvParams>();
+ if (auto* transpose_conv_params =
+ op->builtin_options_as_TransposeConvOptions()) {
+ params->padding = parse_padding(transpose_conv_params->padding());
+ params->stride_width = transpose_conv_params->stride_w();
+ params->stride_height = transpose_conv_params->stride_h();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPARSE_TO_DENSE: {
+ TfLiteSparseToDenseParams* params =
+ MallocPOD<TfLiteSparseToDenseParams>();
+ if (auto* sparse_to_dense_params =
+ op->builtin_options_as_SparseToDenseOptions()) {
+ params->validate_indices = sparse_to_dense_params->validate_indices();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SHAPE: {
+ auto* params = MallocPOD<TfLiteShapeParams>();
+ if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
+ ConvertTensorType(schema_params->out_type(), &params->out_type,
+ error_reporter);
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_PACK: {
+ TfLitePackParams* params = MallocPOD<TfLitePackParams>();
+ if (auto* pack_params = op->builtin_options_as_PackOptions()) {
+ params->values_count = pack_params->values_count();
+ params->axis = pack_params->axis();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DELEGATE: {
+ // TODO(ycling): Revisit when supporting saving delegated models.
+ error_reporter->Report("DELEGATE op shouldn't exist in model.");
+ return kTfLiteError;
+ }
+ case BuiltinOperator_FAKE_QUANT: {
+ auto* params = MallocPOD<TfLiteFakeQuantParams>();
+ if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
+ params->min = schema_params->min();
+ params->max = schema_params->max();
+ params->num_bits = schema_params->num_bits();
+ params->narrow_range = schema_params->narrow_range();
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ONE_HOT: {
+ auto* params = MallocPOD<TfLiteOneHotParams>();
+ if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
+ params->axis = schema_params->axis();
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_UNPACK: {
+ TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>();
+ if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) {
+ params->num = unpack_params->num();
+ params->axis = unpack_params->axis();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+
+ // Below are the ops with no builtin_data structure.
+ case BuiltinOperator_BATCH_TO_SPACE_ND:
+ // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
+ // ok for now, since there is no call implementation either.
+ case BuiltinOperator_CALL:
+ case BuiltinOperator_CONCAT_EMBEDDINGS:
+ case BuiltinOperator_CUSTOM:
+ case BuiltinOperator_DEQUANTIZE:
+ case BuiltinOperator_EMBEDDING_LOOKUP:
+ case BuiltinOperator_EQUAL:
+ case BuiltinOperator_EXP:
+ case BuiltinOperator_EXPAND_DIMS:
+ case BuiltinOperator_FLOOR:
+ case BuiltinOperator_GREATER:
+ case BuiltinOperator_GREATER_EQUAL:
+ case BuiltinOperator_LESS:
+ case BuiltinOperator_LESS_EQUAL:
+ case BuiltinOperator_LOG:
+ case BuiltinOperator_LOGISTIC:
+ case BuiltinOperator_LOG_SOFTMAX:
+ case BuiltinOperator_MAXIMUM:
+ case BuiltinOperator_MINIMUM:
+ case BuiltinOperator_NEG:
+ case BuiltinOperator_NOT_EQUAL:
+ case BuiltinOperator_PAD:
+ case BuiltinOperator_PADV2:
+ case BuiltinOperator_PRELU:
+ case BuiltinOperator_RELU:
+ case BuiltinOperator_RELU6:
+ case BuiltinOperator_RELU_N1_TO_1:
+ case BuiltinOperator_RSQRT:
+ case BuiltinOperator_SELECT:
+ case BuiltinOperator_SIN:
+ case BuiltinOperator_SLICE:
+ case BuiltinOperator_SPACE_TO_BATCH_ND:
+ case BuiltinOperator_SQRT:
+ case BuiltinOperator_TANH:
+ case BuiltinOperator_TILE:
+ case BuiltinOperator_TOPK_V2:
+ case BuiltinOperator_TRANSPOSE:
+ case BuiltinOperator_POW:
+ case BuiltinOperator_LOGICAL_OR:
+ case BuiltinOperator_LOGICAL_AND:
+ case BuiltinOperator_LOGICAL_NOT:
+ case BuiltinOperator_FLOOR_DIV:
+ break;
+ }
+ return kTfLiteOk;
+} // NOLINT[readability/fn_size]
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
new file mode 100644
index 0000000000..4dec6f9cfc
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
+
+// These functions transform codes and data structures that are defined in the
+// flatbuffer serialization format into in-memory values that are used by the
+// runtime API and interpreter.
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Parse the appropriate data out of the op.
+//
+// This handles builtin data explicitly, since it is described by flatbuffer
+// schemas. If it returns kTfLiteOk, it passes the data out with
+// `builtin_data`, which is allocated on the heap and needs to be released by
+// calling `free`.
+// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
+TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter, void** builtin_data);
+
+// Converts the tensor data type used in the flat buffer to the representation
+// used by the runtime.
+TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
+ ErrorReporter* error_reporter);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
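
A usage sketch of the contract above, mirroring what the interpreter does while loading a model. It assumes `op` (an Operator*) and `reporter` are already in hand, and that builtin_op_data.h is included for TfLiteConvParams:

void* builtin_data = nullptr;
if (tflite::ParseOpData(op, tflite::BuiltinOperator_CONV_2D, reporter,
                        &builtin_data) == kTfLiteOk &&
    builtin_data != nullptr) {
  auto* params = static_cast<TfLiteConvParams*>(builtin_data);
  // ... hand params to the kernel ...
  free(builtin_data);  // the caller owns the heap allocation
}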
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
new file mode 100644
index 0000000000..b12bdf43b2
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
@@ -0,0 +1,104 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+
+namespace tflite {
+namespace {
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+ MockErrorReporter() : buffer_size_(0) {}
+ int Report(const char* format, va_list args) override {
+ buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args);
+ return buffer_size_;
+ }
+ char* GetBuffer() { return buffer_; }
+ int GetBufferSize() { return buffer_size_; }
+
+ private:
+ static constexpr int kBufferSize = 256;
+ char buffer_[kBufferSize];
+ int buffer_size_;
+};
+
+} // namespace
+
+TEST(FlatbufferConversions, TestParseOpDataConv) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<void> conv_options =
+ CreateConv2DOptions(builder, Padding_SAME, 1, 2,
+ ActivationFunctionType_RELU, 3, 4)
+ .Union();
+ flatbuffers::Offset<Operator> conv_offset = CreateOperatorDirect(
+ builder, 0, nullptr, nullptr, BuiltinOptions_Conv2DOptions, conv_options,
+ nullptr, CustomOptionsFormat_FLEXBUFFERS, nullptr);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const Operator* conv_op = flatbuffers::GetRoot<Operator>(conv_pointer);
+ void* output_data = nullptr;
+ EXPECT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter,
+ &output_data));
+ EXPECT_NE(nullptr, output_data);
+ TfLiteConvParams* params = reinterpret_cast<TfLiteConvParams*>(output_data);
+ EXPECT_EQ(kTfLitePaddingSame, params->padding);
+ EXPECT_EQ(1, params->stride_width);
+ EXPECT_EQ(2, params->stride_height);
+ EXPECT_EQ(kTfLiteActRelu, params->activation);
+ EXPECT_EQ(3, params->dilation_width_factor);
+ EXPECT_EQ(4, params->dilation_height_factor);
+ free(output_data);
+}
+
+TEST(FlatbufferConversions, TestParseOpDataCustom) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<void> null_options;
+ flatbuffers::Offset<Operator> custom_offset = CreateOperatorDirect(
+ builder, 0, nullptr, nullptr, BuiltinOptions_NONE, null_options, nullptr,
+ CustomOptionsFormat_FLEXBUFFERS, nullptr);
+ builder.Finish(custom_offset);
+ void* custom_pointer = builder.GetBufferPointer();
+ const Operator* custom_op = flatbuffers::GetRoot<Operator>(custom_pointer);
+ void* output_data = nullptr;
+ EXPECT_EQ(kTfLiteOk, ParseOpData(custom_op, BuiltinOperator_CUSTOM, reporter,
+ &output_data));
+ EXPECT_EQ(nullptr, output_data);
+}
+
+TEST(FlatbufferConversions, TestConvertTensorType) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+ TfLiteType type;
+ EXPECT_EQ(kTfLiteOk, ConvertTensorType(TensorType_FLOAT32, &type, reporter));
+ EXPECT_EQ(kTfLiteFloat32, type);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/core/api/op_resolver.cc b/tensorflow/contrib/lite/core/api/op_resolver.cc
new file mode 100644
index 0000000000..55ee924843
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver.cc
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+
+namespace tflite {
+
+TfLiteStatus GetRegistrationFromOpCode(
+ const OperatorCode* opcode, const OpResolver& op_resolver,
+ ErrorReporter* error_reporter, const TfLiteRegistration** registration) {
+ TfLiteStatus status = kTfLiteOk;
+ *registration = nullptr;
+ auto builtin_code = opcode->builtin_code();
+ int version = opcode->version();
+
+ if (builtin_code > BuiltinOperator_MAX ||
+ builtin_code < BuiltinOperator_MIN) {
+ error_reporter->Report(
+ "Op builtin_code out of range: %d. Are you using old TFLite binary "
+ "with newer model?",
+ builtin_code);
+ status = kTfLiteError;
+ } else if (builtin_code != BuiltinOperator_CUSTOM) {
+ *registration = op_resolver.FindOp(builtin_code, version);
+ if (*registration == nullptr) {
+ error_reporter->Report(
+ "Didn't find op for builtin opcode '%s' version '%d'\n",
+ EnumNameBuiltinOperator(builtin_code), version);
+ status = kTfLiteError;
+ }
+ } else if (!opcode->custom_code()) {
+ error_reporter->Report(
+ "Operator with CUSTOM builtin_code has no custom_code.\n");
+ status = kTfLiteError;
+ } else {
+ const char* name = opcode->custom_code()->c_str();
+ *registration = op_resolver.FindOp(name, version);
+ if (*registration == nullptr) {
+ error_reporter->Report(
+ "Didn't find custom op for name '%s' with version %d\n", name,
+ version);
+ status = kTfLiteError;
+ }
+ }
+ return status;
+}
+
+} // namespace tflite
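For orientation, a sketch of the call pattern this helper supports during model loading; opcode, resolver, and error_reporter stand in for objects the caller already holds, so these names are illustrative:

  const TfLiteRegistration* registration = nullptr;
  if (GetRegistrationFromOpCode(opcode, resolver, error_reporter,
                                &registration) != kTfLiteOk) {
    // GetRegistrationFromOpCode has already routed a diagnostic through
    // error_reporter, so the caller only needs to propagate the failure.
    return kTfLiteError;
  }
  // registration is now non-null and safe to hand to the interpreter.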
diff --git a/tensorflow/contrib/lite/core/api/op_resolver.h b/tensorflow/contrib/lite/core/api/op_resolver.h
new file mode 100644
index 0000000000..5f5e6b2736
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Abstract interface that returns TfLiteRegistrations given op codes or custom
+// op names. This is the mechanism by which ops referenced in the flatbuffer
+// model are mapped to executable function pointers (TfLiteRegistrations).
+class OpResolver {
+ public:
+ // Finds the op registration for a builtin operator by enum code.
+ virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const = 0;
+ // Finds the op registration of a custom operator by op name.
+ virtual const TfLiteRegistration* FindOp(const char* op,
+ int version) const = 0;
+ virtual ~OpResolver() {}
+};
+
+// Handles the logic for converting between an OperatorCode structure extracted
+// from a flatbuffer and information about a registered operator implementation.
+TfLiteStatus GetRegistrationFromOpCode(const OperatorCode* opcode,
+ const OpResolver& op_resolver,
+ ErrorReporter* error_reporter,
+ const TfLiteRegistration** registration);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
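A minimal sketch of a concrete OpResolver, assuming a single statically registered builtin (SingleOpResolver and kConvRegistration are illustrative, not part of this change):

  #include "tensorflow/contrib/lite/core/api/op_resolver.h"

  namespace {
  // Empty registration; a real one carries init/free/prepare/invoke pointers.
  TfLiteRegistration kConvRegistration = {nullptr, nullptr, nullptr, nullptr};

  class SingleOpResolver : public tflite::OpResolver {
   public:
    const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
                                     int version) const override {
      // Only the CONV_2D builtin is known to this resolver.
      return op == tflite::BuiltinOperator_CONV_2D ? &kConvRegistration
                                                   : nullptr;
    }
    const TfLiteRegistration* FindOp(const char* op,
                                     int version) const override {
      return nullptr;  // No custom ops.
    }
  };
  }  // namespace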
diff --git a/tensorflow/contrib/lite/core/api/op_resolver_test.cc b/tensorflow/contrib/lite/core/api/op_resolver_test.cc
new file mode 100644
index 0000000000..167463110e
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver_test.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+
+#include <cstring>
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace {
+void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
+ // Do nothing.
+ return nullptr;
+}
+
+void MockFree(TfLiteContext* context, void* buffer) {
+ // Do nothing.
+}
+
+TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+class MockOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(BuiltinOperator op,
+ int version) const override {
+ if (op == BuiltinOperator_CONV_2D) {
+ static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
+ MockInvoke};
+ return &r;
+ } else {
+ return nullptr;
+ }
+ }
+ const TfLiteRegistration* FindOp(const char* op, int version) const override {
+ if (strcmp(op, "mock_custom") == 0) {
+ static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
+ MockInvoke};
+ return &r;
+ } else {
+ return nullptr;
+ }
+ }
+};
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+ MockErrorReporter() : buffer_size_(0) {}
+ int Report(const char* format, va_list args) override {
+ buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args);
+ return buffer_size_;
+ }
+ char* GetBuffer() { return buffer_; }
+ int GetBufferSize() { return buffer_size_; }
+
+ private:
+ static constexpr int kBufferSize = 256;
+ char buffer_[kBufferSize];
+ int buffer_size_;
+};
+
+} // namespace
+
+TEST(OpResolver, TestResolver) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+
+ const TfLiteRegistration* registration =
+ resolver->FindOp(BuiltinOperator_CONV_2D, 0);
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+ registration = resolver->FindOp(BuiltinOperator_CAST, 0);
+ EXPECT_EQ(nullptr, registration);
+
+ registration = resolver->FindOp("mock_custom", 0);
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+ registration = resolver->FindOp("nonexistent_custom", 0);
+ EXPECT_EQ(nullptr, registration);
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeConv) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset =
+ CreateOperatorCodeDirect(builder, BuiltinOperator_CONV_2D, nullptr, 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
+ &registration));
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+ EXPECT_EQ(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeCast) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset =
+ CreateOperatorCodeDirect(builder, BuiltinOperator_CAST, nullptr, 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver,
+ reporter, &registration));
+ EXPECT_EQ(nullptr, registration);
+ EXPECT_NE(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeCustom) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect(
+ builder, BuiltinOperator_CUSTOM, "mock_custom", 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
+ &registration));
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+ EXPECT_EQ(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeNonexistentCustom) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect(
+ builder, BuiltinOperator_CUSTOM, "nonexistent_custom", 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver,
+ reporter, &registration));
+ EXPECT_EQ(nullptr, registration);
+ EXPECT_NE(0, mock_reporter.GetBufferSize());
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index b6b2357873..bf5d91899c 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -16,6 +16,7 @@ cc_library(
deps = [
":util",
"//tensorflow/c:c_api_internal",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
@@ -54,6 +55,7 @@ cc_library(
":delegate_data",
":kernel",
":util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:util",
] + select({
@@ -104,6 +106,7 @@ tf_cc_test(
":delegate_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
@@ -117,6 +120,7 @@ cc_library(
":delegate_data",
":util",
"@flatbuffers",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:string",
"//tensorflow/contrib/lite/kernels:kernel_util",
@@ -170,6 +174,7 @@ cc_library(
hdrs = ["util.h"],
deps = [
"//tensorflow/c:c_api_internal",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
index a28329ae7d..aaaa045840 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map.h
+++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <map>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h
index 6d15ba47dc..70f3c15af4 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.h
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
index b3a0ffcec1..def063309f 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc
index 0ee4db1ffb..274c3c082a 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.cc
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/builtin_ops.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
#include "tensorflow/contrib/lite/delegates/eager/util.h"
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/eager/kernel.h
index 100672c82d..2478abccaa 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.h
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
namespace eager {
diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h
index ff500d18f3..930cb99cb9 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.h
+++ b/tensorflow/contrib/lite/delegates/eager/util.h
@@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
#include "tensorflow/c/c_api_internal.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD
index 954955f24b..4e7b2948fb 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/BUILD
+++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD
@@ -13,6 +13,7 @@ cc_library(
deps = [
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:kernel_util",
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
],
@@ -29,6 +30,7 @@ tf_cc_test(
deps = [
":nnapi_delegate",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index 980a1cb4a0..e3eebac4da 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/allocation.h"
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/builtin_ops.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
index 44cca2fd28..4852b76974 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h
index 3c5f805f12..5c20eedc25 100644
--- a/tensorflow/contrib/lite/error_reporter.h
+++ b/tensorflow/contrib/lite/error_reporter.h
@@ -12,43 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// Compatibility shim for moved header location.
#ifndef TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
#define TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
-#include <cstdarg>
-#include "tensorflow/contrib/lite/context.h"
-
-namespace tflite {
-
-// A functor that reports error to supporting system. Invoked similar to
-// printf.
-//
-// Usage:
-// ErrorReporter foo;
-// foo.Report("test %d", 5);
-// or
-// va_list args;
-// foo.Report("test %d", args); // where args is va_list
-//
-// Subclass ErrorReporter to provide another reporting destination.
-// For example, if you have a GUI program, you might redirect to a buffer
-// that drives a GUI error log box.
-class ErrorReporter {
- public:
- virtual ~ErrorReporter();
- virtual int Report(const char* format, va_list args) = 0;
- int Report(const char* format, ...);
- int ReportError(void*, const char* format, ...);
-};
-
-// An error reporter that simply writes the message to stderr.
-struct StderrReporter : public ErrorReporter {
- int Report(const char* format, va_list args) override;
-};
-
-// Return the default error reporter (output to stderr).
-ErrorReporter* DefaultErrorReporter();
-
-} // namespace tflite
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
#endif // TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
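Since the shim re-exports both new headers, code that includes the old path keeps compiling unchanged; a sketch (MyReporter is illustrative):

  #include <cstdarg>
  #include <cstdio>

  #include "tensorflow/contrib/lite/error_reporter.h"  // old path, now a shim

  // Subclassing works exactly as it did before the move.
  class MyReporter : public tflite::ErrorReporter {
   public:
    int Report(const char* format, va_list args) override {
      return vfprintf(stderr, format, args);
    }
  };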
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD
index 8fc07e8eb7..ea4a543252 100644
--- a/tensorflow/contrib/lite/experimental/c/BUILD
+++ b/tensorflow/contrib/lite/experimental/c/BUILD
@@ -78,6 +78,7 @@ cc_test(
data = ["//tensorflow/contrib/lite:testdata/add.bin"],
deps = [
":c_api",
+ "//tensorflow/contrib/lite:context",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc
index a4ab0e8c30..c589cf71ea 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/experimental/c/c_api.h"
+#include <memory>
+
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
@@ -29,12 +31,14 @@ extern "C" {
TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) {
auto model = tflite::FlatBufferModel::BuildFromBuffer(
static_cast<const char*>(model_data), model_size);
- return model ? new TFL_Model{std::move(model)} : nullptr;
+ std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
+ return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr;
}
TFL_Model* TFL_NewModelFromFile(const char* model_path) {
auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
- return model ? new TFL_Model{std::move(model)} : nullptr;
+ std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
+ return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr;
}
void TFL_DeleteModel(TFL_Model* model) { delete model; }
@@ -72,7 +76,7 @@ TFL_Interpreter* TFL_NewInterpreter(
}
}
- return new TFL_Interpreter{std::move(interpreter)};
+ return new TFL_Interpreter{model->impl, std::move(interpreter)};
}
void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; }
@@ -129,6 +133,8 @@ void* TFL_TensorData(const TFL_Tensor* tensor) {
return static_cast<void*>(tensor->data.raw);
}
+const char* TFL_TensorName(const TFL_Tensor* tensor) { return tensor->name; }
+
TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data,
size_t input_data_size) {
if (tensor->bytes != input_data_size) {
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h
index 3757349b55..b429e76870 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api.h
@@ -93,7 +93,8 @@ typedef struct TFL_Interpreter TFL_Interpreter;
// failure.
//
// * `model` must be a valid model instance. The caller retains ownership of the
-// object, and can destroy it immediately after creating the interpreter.
+// object, and can destroy it immediately after creating the interpreter; the
+// interpreter will maintain its own reference to the underlying model data.
// * `optional_options` may be null. The caller retains ownership of the object,
// and can safely destroy it immediately after creating the interpreter.
//
@@ -145,6 +146,11 @@ TFL_CAPI_EXPORT extern int32_t TFL_InterpreterGetOutputTensorCount(
// Returns the tensor associated with the output index.
// REQUIRES: 0 <= output_index < TFL_InterpreterGetOutputTensorCount(interpreter)
+//
+// NOTE: The shape and underlying data buffer for output tensors may not be
+// available until after the output tensor has been both sized and allocated.
+// In general, best practice is to interact with the output tensor *after*
+// calling TFL_InterpreterInvoke().
TFL_CAPI_EXPORT extern const TFL_Tensor* TFL_InterpreterGetOutputTensor(
const TFL_Interpreter* interpreter, int32_t output_index);
@@ -172,12 +178,15 @@ TFL_CAPI_EXPORT extern size_t TFL_TensorByteSize(const TFL_Tensor* tensor);
// Returns a pointer to the underlying data buffer.
//
-// Note: The result may be null if tensors have not yet been allocated, e.g.,
+// NOTE: The result may be null if tensors have not yet been allocated, e.g.,
// if the Tensor has just been created or resized and `TFL_AllocateTensors()`
// has yet to be called, or if the output tensor is dynamically sized and the
// interpreter hasn't been invoked.
TFL_CAPI_EXPORT extern void* TFL_TensorData(const TFL_Tensor* tensor);
+// Returns the (null-terminated) name of the tensor.
+TFL_CAPI_EXPORT extern const char* TFL_TensorName(const TFL_Tensor* tensor);
+
// Copies from the provided input buffer into the tensor's buffer.
// REQUIRES: input_data_size == TFL_TensorByteSize(tensor)
TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyFromBuffer(
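Taken together, the notes above amount to an invoke-then-read pattern for output tensors; a sketch, assuming an interpreter whose first output holds two floats:

  // Invoke first; only afterwards are the output tensor's shape and data
  // buffer guaranteed to be valid.
  TFL_InterpreterInvoke(interpreter);
  const TFL_Tensor* out = TFL_InterpreterGetOutputTensor(interpreter, 0);
  float result[2];  // Illustrative size; match TFL_TensorByteSize(out).
  TFL_TensorCopyToBuffer(out, result, sizeof(result));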
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
index c5c612a4c6..60c2e4e2cd 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
@@ -24,7 +24,8 @@ limitations under the License.
// not be depended on.
struct TFL_Model {
- std::unique_ptr<tflite::FlatBufferModel> impl;
+ // Sharing is safe as FlatBufferModel is const.
+ std::shared_ptr<const tflite::FlatBufferModel> impl;
};
struct TFL_InterpreterOptions {
@@ -35,6 +36,9 @@ struct TFL_InterpreterOptions {
};
struct TFL_Interpreter {
+  // Holding a reference to the (const) model data keeps the interpreter's
+  // lifetime independent of the TFL_Model handle's existence.
+ std::shared_ptr<const tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> impl;
};
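The shared_ptr in both structs is what backs the lifetime guarantee documented in c_api.h: the model handle may be destroyed as soon as the interpreter exists. A sketch using only functions declared in this change (the file path is illustrative):

  TFL_Model* model = TFL_NewModelFromFile("testdata/add.bin");
  TFL_Interpreter* interpreter =
      TFL_NewInterpreter(model, /*optional_options=*/nullptr);
  TFL_DeleteModel(model);  // Safe: the interpreter keeps its own reference.
  // ... allocate tensors, invoke, read outputs ...
  TFL_DeleteInterpreter(interpreter);  // Drops the last model reference.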
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
index a631dae890..649dac8d1a 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
@@ -55,6 +55,8 @@ TEST(CApiSimple, Smoke) {
EXPECT_EQ(TFL_TensorNumDims(input_tensor), 1);
EXPECT_EQ(TFL_TensorDim(input_tensor, 0), 2);
EXPECT_EQ(TFL_TensorByteSize(input_tensor), sizeof(float) * 2);
+ EXPECT_NE(TFL_TensorData(input_tensor), nullptr);
+ EXPECT_STREQ(TFL_TensorName(input_tensor), "input");
std::array<float, 2> input = {1.f, 3.f};
ASSERT_EQ(TFL_TensorCopyFromBuffer(input_tensor, input.data(),
@@ -70,6 +72,8 @@ TEST(CApiSimple, Smoke) {
EXPECT_EQ(TFL_TensorNumDims(output_tensor), 1);
EXPECT_EQ(TFL_TensorDim(output_tensor, 0), 2);
EXPECT_EQ(TFL_TensorByteSize(output_tensor), sizeof(float) * 2);
+ EXPECT_NE(TFL_TensorData(output_tensor), nullptr);
+ EXPECT_STREQ(TFL_TensorName(output_tensor), "output");
std::array<float, 2> output;
ASSERT_EQ(TFL_TensorCopyToBuffer(output_tensor, output.data(),
diff --git a/tensorflow/contrib/lite/experimental/kernels/BUILD b/tensorflow/contrib/lite/experimental/kernels/BUILD
index 9c06c4ebd9..4786cc62f9 100644
--- a/tensorflow/contrib/lite/experimental/kernels/BUILD
+++ b/tensorflow/contrib/lite/experimental/kernels/BUILD
@@ -53,6 +53,7 @@ cc_library(
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels:kernel_util",
@@ -61,8 +62,8 @@ cc_library(
"//tensorflow/contrib/lite/kernels/internal:optimized",
"//tensorflow/contrib/lite/kernels/internal:optimized_base",
"//tensorflow/contrib/lite/kernels/internal:quantization_util",
- "//tensorflow/contrib/lite/kernels/internal:reference",
"//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:tensor",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
"@flatbuffers",
],
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
index 121997dcb2..8442c4d46c 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include <vector>
#include "flatbuffers/flexbuffers.h" // flatbuffers
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h
index 77268d7aeb..8ee83827bb 100644
--- a/tensorflow/contrib/lite/graph_info.h
+++ b/tensorflow/contrib/lite/graph_info.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 5ab53f4c1d..3f8f4d198f 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -21,9 +21,9 @@ limitations under the License.
#include <cstring>
#include "tensorflow/contrib/lite/arena_planner.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/graph_info.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/nnapi_delegate.h"
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 2b1f1819b9..f0cd178c19 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -23,10 +23,11 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/profiling/profiler.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 5bcf0927d8..cdede430e2 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/interpreter.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD
index 06f46fb923..781289ceb2 100644
--- a/tensorflow/contrib/lite/java/ovic/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/BUILD
@@ -35,6 +35,7 @@ java_binary(
"//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt",
],
main_class = "org.tensorflow.ovic.OvicValidator",
+ tags = ["no_oss"],
deps = [
"//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java",
],
@@ -47,6 +48,7 @@ android_library(
"src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
],
manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml",
+ tags = ["no_oss"],
deps = [
"//tensorflow/contrib/lite/java:tensorflowlite",
"//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
@@ -61,6 +63,7 @@ java_library(
"src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
],
javacopts = JAVACOPTS,
+ tags = ["no_oss"],
deps = [
"//tensorflow/contrib/lite/java:libtensorflowlite_jni.so",
"//tensorflow/contrib/lite/java:tensorflowlite_java",
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index 55ca47fed7..06b35d77c8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -20,7 +20,7 @@ limitations under the License.
#include <stdio.h>
#include <time.h>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h"
#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h"
@@ -124,9 +124,9 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
*/
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint num_threads);
+ jclass clazz,
+ jlong handle,
+ jint num_threads);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
index c020f13d9c..2f73128bdf 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_
#include <jni.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#ifdef __cplusplus
extern "C" {
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index b7c5cbf207..40f28aeab4 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -66,7 +66,7 @@ cc_library(
deps = [
":op_macros",
"//tensorflow/contrib/lite:arena_planner",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels/internal:optimized",
],
)
@@ -82,7 +82,7 @@ cc_library(
copts = tflite_copts(),
deps = [
":op_macros",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"@gemmlowp",
],
)
@@ -93,7 +93,7 @@ cc_library(
"activation_functor.h",
],
deps = [
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -113,9 +113,9 @@ cc_library(
"kernel_util.h",
],
deps = [
- "//tensorflow/contrib/lite:builtin_op_data",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels/internal:round",
+ "//tensorflow/contrib/lite/kernels/internal:types",
],
)
@@ -147,6 +147,15 @@ tf_cc_test(
)
cc_library(
+ name = "padding",
+ srcs = [],
+ hdrs = ["padding.h"],
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ],
+)
+
+cc_library(
name = "builtin_op_kernels",
srcs = [
"activations.cc",
@@ -216,7 +225,6 @@ cc_library(
"unpack.cc",
],
hdrs = [
- "padding.h",
],
copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
visibility = ["//visibility:private"],
@@ -225,18 +233,19 @@ cc_library(
":eigen_support",
":kernel_util",
":op_macros",
- "//tensorflow/contrib/lite:builtin_op_data",
+ ":padding",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels/internal:audio_utils",
"//tensorflow/contrib/lite/kernels/internal:kernel_utils",
"//tensorflow/contrib/lite/kernels/internal:optimized",
"//tensorflow/contrib/lite/kernels/internal:optimized_base",
"//tensorflow/contrib/lite/kernels/internal:quantization_util",
- "//tensorflow/contrib/lite/kernels/internal:reference",
"//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:tensor",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
"@farmhash_archive//:farmhash",
"@flatbuffers",
@@ -251,6 +260,7 @@ cc_library(
":builtin_op_kernels",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -757,8 +767,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -774,8 +784,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1044,8 +1054,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1147,8 +1157,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1164,8 +1174,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1181,8 +1191,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1198,8 +1208,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1212,8 +1222,8 @@ tf_cc_test(
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1239,8 +1249,8 @@ tf_cc_test(
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h
index 41ec3cca33..e075dc7054 100644
--- a/tensorflow/contrib/lite/kernels/activation_functor.h
+++ b/tensorflow/contrib/lite/kernels/activation_functor.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <cmath>
#include <cstdlib>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index 5cdd9fc94f..b2d9b84979 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
index af9b5c7013..b4393e8097 100644
--- a/tensorflow/contrib/lite/kernels/add.cc
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/arg_min_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc
index 6e05f5a9b2..b91e348c27 100644
--- a/tensorflow/contrib/lite/kernels/arg_min_max.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
index 1170d84553..44ef587244 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/spectrogram.h"
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
index c5a5c0182f..1aa27602e5 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <stddef.h>
#include <stdint.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
index 4efa9d596d..fe2865dfb9 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 6b8ecdd5c3..541f320138 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index d988ef8b33..2f896c5289 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc
index 8dd48af57f..a7972140ac 100644
--- a/tensorflow/contrib/lite/kernels/cast.cc
+++ b/tensorflow/contrib/lite/kernels/cast.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <algorithm>
#include <complex>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 8b4d778332..4cd96348a2 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
index 605a20ac3e..25ea556d5a 100644
--- a/tensorflow/contrib/lite/kernels/concatenation.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 3ed0cdb131..ab6bdaecaa 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h"
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 21518156b8..347515f289 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc
index 2b0f04489a..3a08f48b00 100644
--- a/tensorflow/contrib/lite/kernels/dequantize.cc
+++ b/tensorflow/contrib/lite/kernels/dequantize.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
index 136697f945..d2906632d7 100644
--- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <numeric>
#include <vector>
#include "flatbuffers/flexbuffers.h" // flatbuffers
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index d7420ddd8e..7945c095b1 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h
index b235829642..feb1543f7b 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.h
+++ b/tensorflow/contrib/lite/kernels/eigen_support.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace EigenForTFLite {
struct ThreadPoolDevice;
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
index e19779ea59..04995d70dd 100644
--- a/tensorflow/contrib/lite/kernels/elementwise.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include <cmath>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index b2dff87e62..fe33f98eb0 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -37,8 +37,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
index d3be36993c..aa75b03990 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
@@ -65,8 +65,8 @@ limitations under the License.
#include <algorithm>
#include <cmath>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc
index ce03cdfe26..673e7be90a 100644
--- a/tensorflow/contrib/lite/kernels/exp.cc
+++ b/tensorflow/contrib/lite/kernels/exp.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc
index ed33012864..fa1140b19c 100644
--- a/tensorflow/contrib/lite/kernels/expand_dims.cc
+++ b/tensorflow/contrib/lite/kernels/expand_dims.cc
@@ -15,8 +15,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
index 50dc860e5a..a3bc1813db 100644
--- a/tensorflow/contrib/lite/kernels/expand_dims_test.cc
+++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc
index 0ef1a50b30..f9bc3747cb 100644
--- a/tensorflow/contrib/lite/kernels/fake_quant.cc
+++ b/tensorflow/contrib/lite/kernels/fake_quant.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc
index f7d5f5146d..59ff77f35b 100644
--- a/tensorflow/contrib/lite/kernels/floor.cc
+++ b/tensorflow/contrib/lite/kernels/floor.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/floor_div.cc b/tensorflow/contrib/lite/kernels/floor_div.cc
index 75cf19a5a7..5d62cd2755 100644
--- a/tensorflow/contrib/lite/kernels/floor_div.cc
+++ b/tensorflow/contrib/lite/kernels/floor_div.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index eaf5a67d67..7a71fcc219 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc
index 2b2a9e6620..badd2de11a 100644
--- a/tensorflow/contrib/lite/kernels/gather.cc
+++ b/tensorflow/contrib/lite/kernels/gather.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc
index 1d4292955c..1b48884e09 100644
--- a/tensorflow/contrib/lite/kernels/gather_test.cc
+++ b/tensorflow/contrib/lite/kernels/gather_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h
index 37af772c68..43cd2b3055 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.h
+++ b/tensorflow/contrib/lite/kernels/gemm_support.h
@@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
#include "public/gemmlowp.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
namespace gemm_support {
diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
index f37c66acb3..c0b3c3c0c5 100644
--- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
@@ -39,8 +39,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/string_util.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index 464163bd78..a6fd4ac2dd 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -163,7 +163,7 @@ cc_library(
":tensor_utils",
"//third_party/eigen3",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -198,7 +198,7 @@ cc_library(
":round",
"//third_party/eigen3",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -220,13 +220,15 @@ cc_library(
"optimized/eigen_spatial_convolutions.h",
"optimized/eigen_tensor_reduced_instantiations_oss.h",
"optimized/multithreaded_conv.h",
+ # FIXME(petewarden) - This should be removed, since it's a header from the
+ # :tensor dependency below.
"tensor.h",
],
deps = [
":optimized_base",
+ ":tensor",
":types",
- "//tensorflow/contrib/lite:builtin_op_data",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//third_party/eigen3",
],
)
@@ -236,7 +238,7 @@ cc_test(
srcs = ["tensor_test.cc"],
tags = ["no_oss"],
deps = [
- ":reference",
+ ":tensor",
"@com_google_googletest//:gtest",
],
)
@@ -296,7 +298,7 @@ cc_library(
":strided_slice_logic",
":types",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -326,7 +328,7 @@ cc_library(
":strided_slice_logic",
":types",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -341,11 +343,27 @@ cc_library(
)
cc_library(
+ name = "tensor",
+ hdrs = [
+ "tensor.h",
+ "tensor_ctypes.h",
+ ],
+ deps = [
+ ":types",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ],
+)
+
+# Deprecated version of :tensor, kept for backwards compatibility.
+cc_library(
name = "reference",
- hdrs = ["tensor.h"],
+ hdrs = [
+ "tensor.h",
+ "tensor_ctypes.h",
+ ],
deps = [
":types",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
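The new :tensor target exposes tensor.h together with the new tensor_ctypes.h, and :reference is kept as a deprecated alias carrying the same headers so existing dependents keep building. A sketch of how a kernel's BUILD rule would pick up the new target (the :my_kernel name is illustrative, not part of this change):

    cc_library(
        name = "my_kernel",  # hypothetical kernel target
        srcs = ["my_kernel.cc"],
        deps = [
            # Prefer :tensor over the deprecated :reference alias.
            "//tensorflow/contrib/lite/kernels/internal:tensor",
            "//tensorflow/contrib/lite/c:c_api_internal",
        ],
    )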
@@ -359,7 +377,7 @@ cc_library(
],
deps = [
":round",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:activation_functor",
"//tensorflow/contrib/lite/kernels:op_macros",
],
@@ -384,7 +402,7 @@ cc_library(
":cpu_check",
":round",
":types",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:activation_functor",
"//tensorflow/contrib/lite/kernels:op_macros",
"@arm_neon_2_x86_sse",
@@ -398,7 +416,7 @@ cc_library(
hdrs = ["kernel_utils.h"],
deps = [
":tensor_utils",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -441,7 +459,7 @@ cc_library(
copts = NEON_FLAGS_IF_APPLICABLE,
deps = [
"//tensorflow/contrib/lite/kernels:activation_functor",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"@arm_neon_2_x86_sse",
"@gemmlowp",
] + select({
@@ -517,7 +535,7 @@ cc_test(
],
deps = [
":tensor_utils",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest_main",
],
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
index eb4d0108bd..e67fee11b8 100644
--- a/tensorflow/contrib/lite/kernels/internal/common.h
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -45,7 +45,7 @@ limitations under the License.
#endif
#endif
-#include "public/gemmlowp.h"
+#include "fixedpoint/fixedpoint.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
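Narrowing the include from public/gemmlowp.h to fixedpoint/fixedpoint.h works because common.h only needs gemmlowp's fixed-point math, not the GEMM entry points. For instance, the quantized-multiplier helpers defined in this header rest on these two fixedpoint.h functions (a sketch of existing usage, not new code in this change):

    // Both helpers live in fixedpoint/fixedpoint.h:
    int32 x = gemmlowp::SaturatingRoundingDoublingHighMul(value, quantized_multiplier);
    int32 y = gemmlowp::RoundingDivideByPOT(x, right_shift);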
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index b9dd40ddf9..56e9367878 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -14,8 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
-#include <algorithm>
-
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index 215ad04add..b5558cce55 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
namespace kernel_utils {
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index 921aae1303..5fb31889fe 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -26,7 +26,7 @@ limitations under the License.
#include <tuple>
#include <type_traits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 70b6994a2b..27418178fd 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index 5ca1b4b76f..630a6bbf29 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -17,7 +17,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index 7e53dc2fa2..f87760a6c3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -17,7 +17,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#if defined(_MSC_VER)
#define __restrict__ __restrict
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index 2a30910c3f..77e60adc18 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <string.h>
#include <algorithm>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index f5b3a84f07..714b1164ee 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -17,7 +17,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#if defined(_MSC_VER)
#define __restrict__ __restrict
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index a027a47726..0abacf85e1 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -3488,8 +3488,7 @@ inline void Gather(const tflite::GatherParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& coords_shape, const int32* coords_data,
const RuntimeShape& output_shape, T* output_data) {
- // TODO(b/80418076): Enable these checks when moving legacy ops to
- // legacy_reference_ops.
+ // Enable these checks when moving legacy ops to legacy_reference_ops.
//
// TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1);
const int input_rank = op_params.input_rank;
@@ -3808,58 +3807,110 @@ inline void Pad(const tflite::PadParams& op_params,
}
template <typename T>
-inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
- int begin_mask, int end_mask, int shrink_axis_mask,
- const std::vector<int>& start_indices,
- const std::vector<int>& stop_indices,
- const std::vector<int>& strides, T* output_data,
- const Dims<4>& output_dims) {
- // Note that the axis orders are reversed for runtime ops, so the indices,
- // strides and masks must be as well too.
- TFLITE_DCHECK_EQ(start_indices.size(), 4);
- TFLITE_DCHECK_EQ(stop_indices.size(), 4);
- TFLITE_DCHECK_EQ(strides.size(), 4);
- const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 3);
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ // Note that the output_shape is not used herein.
+ tflite::StridedSliceParams params_copy = op_params;
+
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ // Reverse and pad to 4 dimensions because that is what the runtime code
+  // requires (i.e. all shapes must be 4D and are given backwards).
+ strided_slice::StridedSlicePadIndices(&params_copy, 4);
+
+ const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0);
const int stop_b =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 3, start_b);
- const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 2);
+ strided_slice::StopForAxis(params_copy, input_shape, 0, start_b);
+ const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1);
const int stop_h =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 2, start_h);
- const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 1);
+ strided_slice::StopForAxis(params_copy, input_shape, 1, start_h);
+ const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2);
const int stop_w =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 1, start_w);
- const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 0);
+ strided_slice::StopForAxis(params_copy, input_shape, 2, start_w);
+ const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3);
const int stop_d =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 0, start_d);
+ strided_slice::StopForAxis(params_copy, input_shape, 3, start_d);
T* out_ptr = output_data;
for (int in_b = start_b;
- !strided_slice::LoopCondition(in_b, stop_b, strides[3]);
- in_b += strides[3]) {
+ !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]);
+ in_b += params_copy.strides[0]) {
for (int in_h = start_h;
- !strided_slice::LoopCondition(in_h, stop_h, strides[2]);
- in_h += strides[2]) {
+ !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]);
+ in_h += params_copy.strides[1]) {
for (int in_w = start_w;
- !strided_slice::LoopCondition(in_w, stop_w, strides[1]);
- in_w += strides[1]) {
- for (int in_d = start_d;
- !strided_slice::LoopCondition(in_d, stop_d, strides[0]);
- in_d += strides[0]) {
- *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]);
+ in_w += params_copy.strides[2]) {
+ for (int in_d = start_d; !strided_slice::LoopCondition(
+ in_d, stop_d, params_copy.strides[3]);
+ in_d += params_copy.strides[3]) {
+ *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)];
}
}
}
}
}
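For reference, a minimal sketch of driving the new signature directly (the values are illustrative; the StridedSliceParams fields are exactly those populated by BuildStridedSliceParams, added further down in strided_slice_logic.h):

    // Slice the first row of a 2x3 float tensor: output is {1, 2, 3}.
    tflite::StridedSliceParams params;
    params.start_indices_count = 2;
    params.stop_indices_count = 2;
    params.strides_count = 2;
    const int start[] = {0, 0}, stop[] = {1, 3}, strides[] = {1, 1};
    for (int i = 0; i < 2; ++i) {
      params.start_indices[i] = start[i];
      params.stop_indices[i] = stop[i];
      params.strides[i] = strides[i];
    }
    params.begin_mask = params.end_mask = 0;
    params.ellipsis_mask = params.new_axis_mask = params.shrink_axis_mask = 0;

    const float input[6] = {1, 2, 3, 4, 5, 6};
    float output[3];
    tflite::StridedSlice(params, tflite::RuntimeShape({2, 3}), input,
                         tflite::RuntimeShape({1, 3}), output);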
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline uint32 LegacyReverseBits32(uint32 n) {
+ n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
+ n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
+ n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
+ return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
+ ((n & 0xFF000000) >> 24));
+}
+
+inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
+ TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+ TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+ std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
+ std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
+ std::reverse(p->strides, p->strides + p->strides_count);
+
+ p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
+ (32 - p->start_indices_count);
+ p->ellipsis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
+ (32 - p->start_indices_count);
+ p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
+ (32 - p->start_indices_count);
+ p->new_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
+ (32 - p->start_indices_count);
+ p->shrink_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
+ (32 - p->start_indices_count);
+}
+
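StridedSliceReverseIndices adapts legacy parameters, whose axes are stored innermost-first, to the forward axis order used above: the index arrays are reversed, and each mask is bit-reversed over 32 bits and then shifted down so only the low start_indices_count bits survive. A small worked example, assuming 4 axes:

    uint32 m = 0x3;  // legacy begin_mask: innermost axes 0 and 1 are masked
    uint32 forward = LegacyReverseBits32(m) >> (32 - 4);
    // forward == 0xC: the same two axes are now the last two, matching the
    // reversed index arrays.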
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+ int begin_mask, int end_mask, int shrink_axis_mask,
+ const std::vector<int>& start_indices,
+ const std::vector<int>& stop_indices,
+ const std::vector<int>& strides, T* output_data,
+ const Dims<4>& output_dims) {
+ TFLITE_DCHECK_EQ(start_indices.size(), 4);
+ auto op_params = strided_slice::BuildStridedSliceParams(
+ begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
+ strides);
+ StridedSliceReverseIndices(&op_params);
+
+ StridedSlice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
index 5994fad5c7..af5db1064c 100644
--- a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
+++ b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <limits>
#include <vector>
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
-
namespace strided_slice {
// Use until std::clamp() is available from C++17.
@@ -32,15 +32,51 @@ inline int Clamp(const int v, const int lo, const int hi) {
return v;
}
+inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
+ int dim_count) {
+  // Add indices and mask bits to fully include extra dimensions.
+ TFLITE_CHECK_LE(dim_count, 4);
+ TFLITE_CHECK_GE(dim_count, p->start_indices_count);
+ TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+ TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+ const int pad_count = dim_count - p->start_indices_count;
+
+  // Pad indices at start: shift every element, including element 0, up by
+  // pad_count.
+  for (int i = p->start_indices_count - 1; i >= 0; --i) {
+ p->strides[i + pad_count] = p->strides[i];
+ p->start_indices[i + pad_count] = p->start_indices[i];
+ p->stop_indices[i + pad_count] = p->stop_indices[i];
+ }
+ for (int i = 0; i < pad_count; ++i) {
+ p->start_indices[i] = 0;
+ p->stop_indices[i] = 0;
+ p->strides[i] = 1;
+ }
+
+ // Pad masks with 0s or 1s as required.
+ p->shrink_axis_mask <<= pad_count;
+ p->ellipsis_mask <<= pad_count;
+ p->new_axis_mask <<= pad_count;
+ p->begin_mask <<= pad_count;
+ p->end_mask <<= pad_count;
+ p->begin_mask |= (1 << pad_count) - 1;
+ p->end_mask |= (1 << pad_count) - 1;
+
+ p->start_indices_count = dim_count;
+ p->stop_indices_count = dim_count;
+ p->strides_count = dim_count;
+}
+
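A worked example of the padding, assuming 2D params being promoted to the 4D form the runtime code requires (values are illustrative):

    tflite::StridedSliceParams p;
    p.start_indices_count = p.stop_indices_count = p.strides_count = 2;
    p.start_indices[0] = 1; p.start_indices[1] = 0;
    p.stop_indices[0] = 2;  p.stop_indices[1] = 3;
    p.strides[0] = 1;       p.strides[1] = 1;
    p.begin_mask = 0x0;
    p.end_mask = 0x2;
    p.ellipsis_mask = p.new_axis_mask = p.shrink_axis_mask = 0;
    tflite::strided_slice::StridedSlicePadIndices(&p, 4);
    // Now: start = {0, 0, 1, 0}, stop = {0, 0, 2, 3}, strides = {1, 1, 1, 1},
    // begin_mask = 0x3 and end_mask = 0xB: the two padded leading axes carry
    // stride 1 and mask bits, so they fully span their size-1 dimensions.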
// Return the index for the first element along that axis. This index will be a
// positive integer between [0, axis_size - 1] that can be used to index
// directly into the data.
-template <typename IntType>
-inline int StartForAxis(int begin_mask,
- std::vector<IntType> const& start_indices,
- std::vector<IntType> const& strides,
- int const* input_shape, int axis) {
- // Begin with the specified index
+inline int StartForAxis(const tflite::StridedSliceParams& params,
+ const RuntimeShape& input_shape, int axis) {
+ const auto begin_mask = params.begin_mask;
+ const auto* start_indices = params.start_indices;
+ const auto* strides = params.strides;
+ // Begin with the specified index.
int start = start_indices[axis];
// begin_mask override
@@ -57,7 +93,7 @@ inline int StartForAxis(int begin_mask,
}
// Handle negative indices
- int axis_size = input_shape[axis];
+ int axis_size = input_shape.Dims(axis);
if (start < 0) {
start += axis_size;
}
@@ -73,11 +109,14 @@ inline int StartForAxis(int begin_mask,
// element. ie. So if you were iterating through all elements of a 1D array of
// size 4, this function would return 4 as the stop, because it is one past the
// "real" indices of 0, 1, 2 & 3.
-template <typename IntType>
-inline int StopForAxis(int end_mask, int shrink_axis_mask,
- std::vector<IntType> const& stop_indices,
- std::vector<IntType> const& strides,
- int const* input_shape, int axis, int start_for_axis) {
+inline int StopForAxis(const tflite::StridedSliceParams& params,
+ const RuntimeShape& input_shape, int axis,
+ int start_for_axis) {
+ const auto end_mask = params.end_mask;
+ const auto shrink_axis_mask = params.shrink_axis_mask;
+ const auto* stop_indices = params.stop_indices;
+ const auto* strides = params.strides;
+
// Begin with the specified index
const bool shrink_axis = shrink_axis_mask & (1 << axis);
int stop = stop_indices[axis];
@@ -103,7 +142,7 @@ inline int StopForAxis(int end_mask, int shrink_axis_mask,
}
// Handle negative indices
- const int axis_size = input_shape[axis];
+ const int axis_size = input_shape.Dims(axis);
if (stop < 0) {
stop += axis_size;
}
@@ -127,6 +166,31 @@ inline bool LoopCondition(int index, int stop, int stride) {
return stride > 0 ? index >= stop : index <= stop;
}
+inline tflite::StridedSliceParams BuildStridedSliceParams(
+ int begin_mask, int end_mask, int shrink_axis_mask,
+ const std::vector<int>& start_indices, const std::vector<int>& stop_indices,
+ const std::vector<int>& strides) {
+ tflite::StridedSliceParams op_params;
+ const int dims_count = start_indices.size();
+
+ op_params.start_indices_count = dims_count;
+ op_params.stop_indices_count = dims_count;
+ op_params.strides_count = dims_count;
+ for (int i = 0; i < dims_count; ++i) {
+ op_params.start_indices[i] = start_indices[i];
+ op_params.stop_indices[i] = stop_indices[i];
+ op_params.strides[i] = strides[i];
+ }
+
+ op_params.begin_mask = begin_mask;
+ op_params.ellipsis_mask = 0;
+ op_params.end_mask = end_mask;
+ op_params.new_axis_mask = 0;
+ op_params.shrink_axis_mask = shrink_axis_mask;
+
+ return op_params;
+}
+
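This is the bridge the legacy StridedSlice wrapper in reference_ops.h uses: the vectors go into the fixed arrays unchanged, and the ellipsis/new-axis masks are zeroed because the legacy API never carried them. A minimal sketch:

    std::vector<int> start = {0, 0, 0, 0};
    std::vector<int> stop = {1, 2, 3, 4};
    std::vector<int> strides = {1, 1, 1, 1};
    auto params = tflite::strided_slice::BuildStridedSliceParams(
        /*begin_mask=*/0, /*end_mask=*/0, /*shrink_axis_mask=*/0, start, stop,
        strides);
    // Legacy callers must still reverse the result (StridedSliceReverseIndices
    // in reference_ops.h) because legacy ops index axes innermost-first.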
} // namespace strided_slice
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index ee2af5b460..13106456df 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -17,44 +17,12 @@ limitations under the License.
#include <complex>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
-template <typename T>
-inline T* GetTensorData(TfLiteTensor* tensor);
-
-template <>
-inline float* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.f : nullptr;
-}
-
-template <>
-inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.uint8 : nullptr;
-}
-
-template <>
-inline int16_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i16 : nullptr;
-}
-
-template <>
-inline int32_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i32 : nullptr;
-}
-
-template <>
-inline int64_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i64 : nullptr;
-}
-
-template <>
-inline bool* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.b : nullptr;
-}
-
template <>
inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) {
return tensor != nullptr
@@ -62,39 +30,6 @@ inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) {
: nullptr;
}
-template <typename T>
-inline const T* GetTensorData(const TfLiteTensor* tensor);
-
-template <>
-inline const float* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.f : nullptr;
-}
-
-template <>
-inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.uint8 : nullptr;
-}
-
-template <>
-inline const int16_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i16 : nullptr;
-}
-
-template <>
-inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i32 : nullptr;
-}
-
-template <>
-inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i64 : nullptr;
-}
-
-template <>
-inline const bool* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.b : nullptr;
-}
-
template <>
inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
return tensor != nullptr
@@ -102,56 +37,14 @@ inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
: nullptr;
}
-inline int RemapDim(int max_dimensions, int d) {
- return max_dimensions - d - 1;
-}
-
-// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
-// even if the original tensors were not 4D. We should consider rewriting them
-// to take a more generic 'shape' object.
-inline Dims<4> GetTensorDims(const int data[], const int size) {
- Dims<4> d;
- for (int i = 0; i < 4; ++i) {
- int src = size - i - 1;
- if (src >= 0) {
- d.sizes[i] = data[src];
- } else {
- d.sizes[i] = 1;
- }
- }
- d.strides[0] = 1;
- for (int i = 1; i < 4; i++) {
- d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
- }
- return d;
-}
-
inline Dims<4> GetTensorDims(std::vector<int32_t> data) {
return GetTensorDims(data.data(), data.size());
}
-inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
- if (tensor == nullptr) {
- return Dims<4>();
- }
-
- auto* dims = tensor->dims;
- return GetTensorDims(dims->data, dims->size);
-}
-
inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
return RuntimeShape(data.size(), data.data());
}
-inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
- if (tensor == nullptr) {
- return RuntimeShape();
- }
-
- auto* dims = tensor->dims;
- return RuntimeShape(dims->size, dims->data);
-}
-
// A list of tensors in a format that can be used by kernels like split and
// concatenation.
template <typename T>
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
new file mode 100644
index 0000000000..77e22a08b4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+template <typename T>
+inline T* GetTensorData(TfLiteTensor* tensor);
+
+template <>
+inline float* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline int16_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
+inline int32_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline int64_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline bool* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
+template <typename T>
+inline const T* GetTensorData(const TfLiteTensor* tensor);
+
+template <>
+inline const float* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline const int16_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
+inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline const bool* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
+inline int RemapDim(int max_dimensions, int d) {
+ return max_dimensions - d - 1;
+}
+
+// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
+// even if the original tensors were not 4D. We should consider rewriting them
+// to take a more generic 'shape' object.
+inline Dims<4> GetTensorDims(const int data[], const int size) {
+ Dims<4> d;
+ for (int i = 0; i < 4; ++i) {
+ int src = size - i - 1;
+ if (src >= 0) {
+ d.sizes[i] = data[src];
+ } else {
+ d.sizes[i] = 1;
+ }
+ }
+ d.strides[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
+ }
+ return d;
+}
+
+inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
+ if (tensor == nullptr) {
+ return Dims<4>();
+ }
+
+ auto* dims = tensor->dims;
+ return GetTensorDims(dims->data, dims->size);
+}
+
+inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
+ if (tensor == nullptr) {
+ return RuntimeShape();
+ }
+
+ TfLiteIntArray* dims = tensor->dims;
+ const int dims_size = dims->size;
+ const int32_t* dims_data = dims->data;
+ return RuntimeShape(dims_size, dims_data);
+}
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
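The new header lets dependents fetch typed buffers and shapes without pulling in the rest of tensor.h. A minimal sketch of use inside a kernel's Eval, assuming the standard TfLite kernel signature (the identity copy is illustrative):

    #include "tensorflow/contrib/lite/c/c_api_internal.h"
    #include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"

    TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      const TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
      TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
      const tflite::RuntimeShape shape = tflite::GetTensorShape(input);
      const float* in = tflite::GetTensorData<float>(input);
      float* out = tflite::GetTensorData<float>(output);
      for (int i = 0; i < shape.FlatSize(); ++i) out[i] = in[i];
      return kTfLiteOk;
    }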
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index 1439bf8c37..b0fe5adf65 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#if defined(_MSC_VER)
#define __restrict__ __restrict
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index dad924fc28..6458af714b 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include <gmock/gmock.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
index ed46cd984f..e9a5fd7a40 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.h
+++ b/tensorflow/contrib/lite/kernels/kernel_util.h
@@ -16,9 +16,10 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
#include <algorithm>
+#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
index 5b3536de0c..e02d7df9ef 100644
--- a/tensorflow/contrib/lite/kernels/l2norm.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc
index 799c1528bd..334d2a2788 100644
--- a/tensorflow/contrib/lite/kernels/local_response_norm.cc
+++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc
index c71f3b4701..f770cb35d1 100644
--- a/tensorflow/contrib/lite/kernels/logical.cc
+++ b/tensorflow/contrib/lite/kernels/logical.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc
index 69523b02cc..9fa1c5f100 100644
--- a/tensorflow/contrib/lite/kernels/lsh_projection.cc
+++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc
@@ -59,8 +59,8 @@ limitations under the License.
#include <limits>
#include <memory>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include <farmhash.h>
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 74dc3f25f9..aaa3ce966e 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum.cc b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
index 0308a3976a..7cb01465ee 100644
--- a/tensorflow/contrib/lite/kernels/maximum_minimum.cc
+++ b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc
index 306f676619..66cf147d75 100644
--- a/tensorflow/contrib/lite/kernels/mfcc.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/mfcc.h"
#include "flatbuffers/flexbuffers.h" // flatbuffers
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h"
#include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 92d8bc8b67..e0aac8a842 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/neg.cc b/tensorflow/contrib/lite/kernels/neg.cc
index 4124c05388..0ddd0644f5 100644
--- a/tensorflow/contrib/lite/kernels/neg.cc
+++ b/tensorflow/contrib/lite/kernels/neg.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/contrib/lite/kernels/one_hot.cc
index 9ff3dca932..910aed6f14 100644
--- a/tensorflow/contrib/lite/kernels/one_hot.cc
+++ b/tensorflow/contrib/lite/kernels/one_hot.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc
index cc326a7d51..4cb98fdd19 100644
--- a/tensorflow/contrib/lite/kernels/pack.cc
+++ b/tensorflow/contrib/lite/kernels/pack.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 3bce05353d..0d939405f6 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h
index 3cb55f19a9..42b6b45d3b 100644
--- a/tensorflow/contrib/lite/kernels/padding.h
+++ b/tensorflow/contrib/lite/kernels/padding.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
index 29a5be0683..6451142391 100644
--- a/tensorflow/contrib/lite/kernels/pooling.cc
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc
index d676de5b1d..1e96cc80b1 100644
--- a/tensorflow/contrib/lite/kernels/pow.cc
+++ b/tensorflow/contrib/lite/kernels/pow.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index ca83797936..d94d821e87 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <limits>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
index 0296152d68..61856ab9de 100644
--- a/tensorflow/contrib/lite/kernels/register.h
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -16,8 +16,9 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
#include <unordered_map>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
namespace tflite {
namespace ops {
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
index 49ba0571e2..f41147b2d6 100644
--- a/tensorflow/contrib/lite/kernels/reshape.cc
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
index dafa3aebab..fb045d15f3 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
index 3cdb5db209..3959502d91 100644
--- a/tensorflow/contrib/lite/kernels/select.cc
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc
index dbcd2ef004..66d4c9e5c1 100644
--- a/tensorflow/contrib/lite/kernels/shape.cc
+++ b/tensorflow/contrib/lite/kernels/shape.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc
index c90a15b3a2..de80a4016e 100644
--- a/tensorflow/contrib/lite/kernels/skip_gram.cc
+++ b/tensorflow/contrib/lite/kernels/skip_gram.cc
@@ -33,8 +33,8 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/string_util.h"
diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc
index 55e16506df..ccfee41b9c 100644
--- a/tensorflow/contrib/lite/kernels/slice.cc
+++ b/tensorflow/contrib/lite/kernels/slice.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <string.h>
#include <cmath>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
index 8332ae32cf..3a10d2e60c 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc
index 9238e879f8..64c56c017b 100644
--- a/tensorflow/contrib/lite/kernels/space_to_depth.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
index fec2a6f0d9..178568e07c 100644
--- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc
index b144486041..719e2dc606 100644
--- a/tensorflow/contrib/lite/kernels/split.cc
+++ b/tensorflow/contrib/lite/kernels/split.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc
index 09a5662fd9..080c51cd18 100644
--- a/tensorflow/contrib/lite/kernels/squeeze.cc
+++ b/tensorflow/contrib/lite/kernels/squeeze.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index bed2117f9a..87ffcc4110 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <cmath>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index 77a1f59689..1be0c83f17 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
index 6ba7959752..9903fd5c35 100644
--- a/tensorflow/contrib/lite/kernels/svdf.cc
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -23,8 +23,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc
index 5181a8f89a..49421eb870 100644
--- a/tensorflow/contrib/lite/kernels/tile.cc
+++ b/tensorflow/contrib/lite/kernels/tile.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc
index 4f78c224e5..e73ca7b750 100644
--- a/tensorflow/contrib/lite/kernels/tile_test.cc
+++ b/tensorflow/contrib/lite/kernels/tile_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc
index 2dd760bbfe..6c38b6739e 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <algorithm>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
index 2abb89b617..16106fdafe 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc
index 800b0563d7..95359962e0 100644
--- a/tensorflow/contrib/lite/kernels/transpose.cc
+++ b/tensorflow/contrib/lite/kernels/transpose.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index a9baa5c698..6f2d98ede8 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index c678f14930..63817bd886 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index 0180c2c498..744ee7c109 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc
index 4998f88b41..9ff06f8331 100644
--- a/tensorflow/contrib/lite/kernels/unpack.cc
+++ b/tensorflow/contrib/lite/kernels/unpack.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/memory_planner.h b/tensorflow/contrib/lite/memory_planner.h
index 0294ec815c..2d4707f849 100644
--- a/tensorflow/contrib/lite/memory_planner.h
+++ b/tensorflow/contrib/lite/memory_planner.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
#define TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/mmap_allocation.cc b/tensorflow/contrib/lite/mmap_allocation.cc
index fa9a3cd1d8..92934d1fd1 100644
--- a/tensorflow/contrib/lite/mmap_allocation.cc
+++ b/tensorflow/contrib/lite/mmap_allocation.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <unistd.h>
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index aa410ab002..241865b3d8 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -20,8 +20,9 @@ limitations under the License.
#include <sys/types.h>
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/contrib/lite/model.h"
#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
@@ -42,41 +43,6 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
const char* kEmptyTensorName = "";
-TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
- ErrorReporter* error_reporter) {
- switch (tensor_type) {
- case TensorType_FLOAT32:
- *type = kTfLiteFloat32;
- break;
- case TensorType_INT16:
- *type = kTfLiteInt16;
- break;
- case TensorType_INT32:
- *type = kTfLiteInt32;
- break;
- case TensorType_UINT8:
- *type = kTfLiteUInt8;
- break;
- case TensorType_INT64:
- *type = kTfLiteInt64;
- break;
- case TensorType_STRING:
- *type = kTfLiteString;
- break;
- case TensorType_BOOL:
- *type = kTfLiteBool;
- break;
- case TensorType_COMPLEX64:
- *type = kTfLiteComplex64;
- break;
- default:
- error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
- EnumNameTensorType(tensor_type), tensor_type);
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
#ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
@@ -198,39 +164,10 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
auto opcodes = model_->operator_codes();
for (const OperatorCode* opcode : *opcodes) {
const TfLiteRegistration* registration = nullptr;
- auto builtin_code = opcode->builtin_code();
- int version = opcode->version();
-
- if (builtin_code > BuiltinOperator_MAX ||
- builtin_code < BuiltinOperator_MIN) {
- error_reporter_->Report(
- "Op builtin_code out or range: %d. Are you using old TFLite binary "
- "with newer model?",
- builtin_code);
- status = kTfLiteError;
- } else if (builtin_code != BuiltinOperator_CUSTOM) {
- registration = op_resolver_.FindOp(builtin_code, version);
- if (registration == nullptr) {
- error_reporter_->Report(
- "Didn't find op for builtin opcode '%s' version '%d'\n",
- EnumNameBuiltinOperator(builtin_code), version);
- status = kTfLiteError;
- }
- } else if (!opcode->custom_code()) {
- error_reporter_->Report(
- "Operator with CUSTOM builtin_code has no custom_code.\n");
- status = kTfLiteError;
- } else {
- const char* name = opcode->custom_code()->c_str();
- registration = op_resolver_.FindOp(name, version);
- flatbuffer_op_index_to_registration_types_.push_back(
- BuiltinOperator_CUSTOM);
- if (registration == nullptr) {
- error_reporter_->Report(
- "Didn't find custom op for name '%s' with version %d\n", name,
- version);
- status = kTfLiteError;
- }
+ status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
+ &registration);
+ if (status != kTfLiteOk) {
+ return status;
}
flatbuffer_op_index_to_registration_.push_back(registration);
}
@@ -247,565 +184,6 @@ std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
return ret;
}
-// Copies the contents from the flatbuffer int vector `flatbuffer` into the
-// int array `buffer`. `flat_vector` and `buffer` represent the same
-// configuration operation for a given operation.
-void FlatBufferIntVectorToArray(int max_size_of_buffer,
- const flatbuffers::Vector<int32_t>* flat_vector,
- int* buffer, ErrorReporter* error_reporter) {
- if (!flat_vector) {
- error_reporter->Report("Input array not provided for operation.\n");
- } else {
- int num_dimensions = flat_vector->Length();
- if (num_dimensions > max_size_of_buffer / sizeof(int)) {
- error_reporter->Report(
- "Found too many dimensions in the operation's input array.\n");
- } else {
- for (int i = 0; i < num_dimensions; ++i) {
- buffer[i] = flat_vector->Get(i);
- }
- }
- }
-}
-
-// Allocate a structure using C malloc, but make sure the structure is a
-// POD structure that doesn't require constructors to run. The reason we do
-// this, is that Interpreter's C extension part will take ownership and wants
-// to use malloc() and free().
-template <class T>
-T* MallocPOD() {
- static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
- return static_cast<T*>(malloc(sizeof(T)));
-}
-
-// Parse the appropriate data out of the op.
-//
-// This handles builtin data explicitly as there are flatbuffer schemas.
-// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
-// need to be released by calling `free`.`
-// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
-TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
- ErrorReporter* error_reporter, void** builtin_data) {
- auto parse_padding = [](Padding padding) {
- switch (padding) {
- case Padding_SAME:
- return kTfLitePaddingSame;
- case Padding_VALID:
- return kTfLitePaddingValid;
- }
- return kTfLitePaddingUnknown;
- };
- auto parse_activation = [](ActivationFunctionType activation) {
- switch (activation) {
- case ActivationFunctionType_NONE:
- return kTfLiteActNone;
- case ActivationFunctionType_RELU:
- return kTfLiteActRelu;
- case ActivationFunctionType_RELU_N1_TO_1:
- return kTfLiteActRelu1;
- case ActivationFunctionType_RELU6:
- return kTfLiteActRelu6;
- case ActivationFunctionType_TANH:
- return kTfLiteActTanh;
- case ActivationFunctionType_SIGN_BIT:
- return kTfLiteActSignBit;
- }
- return kTfLiteActNone;
- };
- auto parseLSHProjectionType = [](LSHProjectionType type) {
- switch (type) {
- case LSHProjectionType_SPARSE:
- return kTfLiteLshProjectionSparse;
- case LSHProjectionType_DENSE:
- return kTfLiteLshProjectionDense;
- default:
- return kTfLiteLshProjectionUnknown;
- }
- };
- auto parseCombinerType = [](CombinerType type) {
- switch (type) {
- case CombinerType_MEAN:
- return kTfLiteCombinerTypeMean;
- case CombinerType_SQRTN:
- return kTfLiteCombinerTypeSqrtn;
- case CombinerType_SUM:
- default:
- return kTfLiteCombinerTypeSum;
- }
- };
-
- *builtin_data = nullptr;
- switch (op_type) {
- case BuiltinOperator_CONV_2D: {
- TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
- if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
- params->padding = parse_padding(conv_params->padding());
- params->stride_width = conv_params->stride_w();
- params->stride_height = conv_params->stride_h();
- params->activation =
- parse_activation(conv_params->fused_activation_function());
-
- params->dilation_width_factor = conv_params->dilation_w_factor();
- params->dilation_height_factor = conv_params->dilation_h_factor();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_CAST: {
- TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
- if (auto* schema_params = op->builtin_options_as_CastOptions()) {
- auto in_status =
- ConvertTensorType(schema_params->in_data_type(),
- &params->in_data_type, error_reporter);
- auto out_status =
- ConvertTensorType(schema_params->out_data_type(),
- &params->out_data_type, error_reporter);
- if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
- free(params);
- return kTfLiteError;
- }
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_LSH_PROJECTION: {
- TfLiteLSHProjectionParams* params =
- MallocPOD<TfLiteLSHProjectionParams>();
- if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
- params->type = parseLSHProjectionType(lshParams->type());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_AVERAGE_POOL_2D:
- case BuiltinOperator_MAX_POOL_2D:
- case BuiltinOperator_L2_POOL_2D: {
- TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
- if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
- params->padding = parse_padding(pool_params->padding());
- params->stride_width = pool_params->stride_w();
- params->stride_height = pool_params->stride_h();
- params->filter_width = pool_params->filter_width();
- params->filter_height = pool_params->filter_height();
- params->activation =
- parse_activation(pool_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_DEPTHWISE_CONV_2D: {
- TfLiteDepthwiseConvParams* params =
- MallocPOD<TfLiteDepthwiseConvParams>();
- if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
- params->padding = parse_padding(conv_params->padding());
- params->stride_width = conv_params->stride_w();
- params->stride_height = conv_params->stride_h();
- params->depth_multiplier = conv_params->depth_multiplier();
- params->activation =
- parse_activation(conv_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SVDF: {
- TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
- if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
- params->rank = svdf_params->rank();
- params->activation =
- parse_activation(svdf_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
- TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
- if (auto* sequence_rnn_params =
- op->builtin_options_as_SequenceRNNOptions()) {
- params->activation =
- parse_activation(sequence_rnn_params->fused_activation_function());
- params->time_major = sequence_rnn_params->time_major();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_RNN: {
- TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
- if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
- params->activation =
- parse_activation(rnn_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
- TfLiteEmbeddingLookupSparseParams* params =
- MallocPOD<TfLiteEmbeddingLookupSparseParams>();
- if (auto* embedding_params =
- op->builtin_options_as_EmbeddingLookupSparseOptions()) {
- params->combiner = parseCombinerType(embedding_params->combiner());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_FULLY_CONNECTED: {
- TfLiteFullyConnectedParams* params =
- MallocPOD<TfLiteFullyConnectedParams>();
- if (auto* fully_connected_params =
- op->builtin_options_as_FullyConnectedOptions()) {
- params->activation = parse_activation(
- fully_connected_params->fused_activation_function());
- switch (fully_connected_params->weights_format()) {
- case FullyConnectedOptionsWeightsFormat_DEFAULT:
- params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
- break;
- case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
- params->weights_format =
- kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8;
- break;
- default:
- error_reporter->Report("Unhandled fully-connected weights format.");
- return kTfLiteError;
- }
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_HASHTABLE_LOOKUP:
- // no-op.
- break;
- case BuiltinOperator_SOFTMAX: {
- TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
- if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
- params->beta = softmax_params->beta();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_CONCATENATION: {
- TfLiteConcatenationParams* params =
- MallocPOD<TfLiteConcatenationParams>();
- if (auto* concatenation_params =
- op->builtin_options_as_ConcatenationOptions()) {
- params->activation =
- parse_activation(concatenation_params->fused_activation_function());
- params->axis = concatenation_params->axis();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_MUL: {
- auto* params = MallocPOD<TfLiteMulParams>();
- if (auto* schema_params = op->builtin_options_as_MulOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ADD: {
- auto* params = MallocPOD<TfLiteAddParams>();
- if (auto* schema_params = op->builtin_options_as_AddOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_DIV: {
- auto* params = MallocPOD<TfLiteDivParams>();
- if (auto* schema_params = op->builtin_options_as_DivOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SUB: {
- auto* params = MallocPOD<TfLiteSubParams>();
- if (auto* schema_params = op->builtin_options_as_SubOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_L2_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteL2NormParams>();
- if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
- if (auto* schema_params =
- op->builtin_options_as_LocalResponseNormalizationOptions()) {
- params->radius = schema_params->radius();
- params->bias = schema_params->bias();
- params->alpha = schema_params->alpha();
- params->beta = schema_params->beta();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
- case BuiltinOperator_LSTM: {
- TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
- if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
- params->activation =
- parse_activation(lstm_params->fused_activation_function());
- params->cell_clip = lstm_params->cell_clip();
- params->proj_clip = lstm_params->proj_clip();
- switch (lstm_params->kernel_type()) {
- case LSTMKernelType_FULL:
- params->kernel_type = kTfLiteLSTMFullKernel;
- break;
- case LSTMKernelType_BASIC:
- params->kernel_type = kTfLiteLSTMBasicKernel;
- break;
- }
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_RESIZE_BILINEAR: {
- auto* params = MallocPOD<TfLiteResizeBilinearParams>();
- if (auto* schema_params =
- op->builtin_options_as_ResizeBilinearOptions()) {
- params->align_corners = schema_params->align_corners();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_RESHAPE: {
- auto* params = MallocPOD<TfLiteReshapeParams>();
- if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
- auto* new_shape = schema_params->new_shape();
- FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
- params->shape, error_reporter);
- params->num_dimensions = new_shape->Length();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SKIP_GRAM: {
- TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
- if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
- params->ngram_size = skip_gram_params->ngram_size();
- params->max_skip_size = skip_gram_params->max_skip_size();
- params->include_all_ngrams = skip_gram_params->include_all_ngrams();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SPACE_TO_DEPTH: {
- auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
- if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
- params->block_size = schema_params->block_size();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_GATHER: {
- TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
- params->axis = 0;
- if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
- params->axis = gather_params->axis();
- }
-
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_MEAN:
- case BuiltinOperator_REDUCE_MAX:
- case BuiltinOperator_REDUCE_MIN:
- case BuiltinOperator_REDUCE_PROD:
- case BuiltinOperator_SUM:
- case BuiltinOperator_REDUCE_ANY: {
- auto* params = MallocPOD<TfLiteReducerParams>();
- if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
- params->keep_dims = schema_params->keep_dims();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SPLIT: {
- auto* params = MallocPOD<TfLiteSplitParams>();
- if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
- params->num_splits = schema_params->num_splits();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SQUEEZE: {
- auto* params = MallocPOD<TfLiteSqueezeParams>();
- if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
- const auto& squeeze_dims = schema_params->squeeze_dims();
- FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
- params->squeeze_dims, error_reporter);
- params->num_squeeze_dims = squeeze_dims->Length();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_STRIDED_SLICE: {
- auto* params = MallocPOD<TfLiteStridedSliceParams>();
- if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
- params->begin_mask = schema_params->begin_mask();
- params->end_mask = schema_params->end_mask();
- params->ellipsis_mask = schema_params->ellipsis_mask();
- params->new_axis_mask = schema_params->new_axis_mask();
- params->shrink_axis_mask = schema_params->shrink_axis_mask();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ARG_MAX: {
- auto* params = MallocPOD<TfLiteArgMaxParams>();
- if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
- ConvertTensorType(schema_params->output_type(), &params->output_type,
- error_reporter);
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ARG_MIN: {
- auto* params = MallocPOD<TfLiteArgMinParams>();
- if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
- ConvertTensorType(schema_params->output_type(), &params->output_type,
- error_reporter);
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_TRANSPOSE_CONV: {
- TfLiteTransposeConvParams* params =
- MallocPOD<TfLiteTransposeConvParams>();
- if (auto* transpose_conv_params =
- op->builtin_options_as_TransposeConvOptions()) {
- params->padding = parse_padding(transpose_conv_params->padding());
- params->stride_width = transpose_conv_params->stride_w();
- params->stride_height = transpose_conv_params->stride_h();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SPARSE_TO_DENSE: {
- TfLiteSparseToDenseParams* params =
- MallocPOD<TfLiteSparseToDenseParams>();
- if (auto* sparse_to_dense_params =
- op->builtin_options_as_SparseToDenseOptions()) {
- params->validate_indices = sparse_to_dense_params->validate_indices();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SHAPE: {
- auto* params = MallocPOD<TfLiteShapeParams>();
- if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
- ConvertTensorType(schema_params->out_type(), &params->out_type,
- error_reporter);
- }
- *builtin_data = static_cast<void*>(params);
- break;
- }
- case BuiltinOperator_PACK: {
- TfLitePackParams* params = MallocPOD<TfLitePackParams>();
- if (auto* pack_params = op->builtin_options_as_PackOptions()) {
- params->values_count = pack_params->values_count();
- params->axis = pack_params->axis();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_DELEGATE: {
- // TODO(ycling): Revisit when supporting saving delegated models.
- error_reporter->Report("DELEGATE op shouldn't exist in model.");
- return kTfLiteError;
- }
- case BuiltinOperator_FAKE_QUANT: {
- auto* params = MallocPOD<TfLiteFakeQuantParams>();
- if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
- params->min = schema_params->min();
- params->max = schema_params->max();
- params->num_bits = schema_params->num_bits();
- params->narrow_range = schema_params->narrow_range();
- }
- *builtin_data = static_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ONE_HOT: {
- auto* params = MallocPOD<TfLiteOneHotParams>();
- if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
- params->axis = schema_params->axis();
- }
- *builtin_data = static_cast<void*>(params);
- break;
- }
- case BuiltinOperator_UNPACK: {
- TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>();
- if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) {
- params->num = unpack_params->num();
- params->axis = unpack_params->axis();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
-
- // Below are the ops with no builtin_data strcture.
- case BuiltinOperator_BATCH_TO_SPACE_ND:
- // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
- // ok for now, since there is no call implementation either.
- case BuiltinOperator_CALL:
- case BuiltinOperator_CONCAT_EMBEDDINGS:
- case BuiltinOperator_CUSTOM:
- case BuiltinOperator_DEQUANTIZE:
- case BuiltinOperator_EMBEDDING_LOOKUP:
- case BuiltinOperator_EQUAL:
- case BuiltinOperator_EXP:
- case BuiltinOperator_EXPAND_DIMS:
- case BuiltinOperator_FLOOR:
- case BuiltinOperator_GREATER:
- case BuiltinOperator_GREATER_EQUAL:
- case BuiltinOperator_LESS:
- case BuiltinOperator_LESS_EQUAL:
- case BuiltinOperator_LOG:
- case BuiltinOperator_LOGISTIC:
- case BuiltinOperator_LOG_SOFTMAX:
- case BuiltinOperator_MAXIMUM:
- case BuiltinOperator_MINIMUM:
- case BuiltinOperator_NEG:
- case BuiltinOperator_NOT_EQUAL:
- case BuiltinOperator_PAD:
- case BuiltinOperator_PADV2:
- case BuiltinOperator_PRELU:
- case BuiltinOperator_RELU:
- case BuiltinOperator_RELU6:
- case BuiltinOperator_RELU_N1_TO_1:
- case BuiltinOperator_RSQRT:
- case BuiltinOperator_SELECT:
- case BuiltinOperator_SIN:
- case BuiltinOperator_SLICE:
- case BuiltinOperator_SPACE_TO_BATCH_ND:
- case BuiltinOperator_SQRT:
- case BuiltinOperator_TANH:
- case BuiltinOperator_TILE:
- case BuiltinOperator_TOPK_V2:
- case BuiltinOperator_TRANSPOSE:
- case BuiltinOperator_POW:
- case BuiltinOperator_LOGICAL_OR:
- case BuiltinOperator_LOGICAL_AND:
- case BuiltinOperator_LOGICAL_NOT:
- case BuiltinOperator_FLOOR_DIV:
- break;
- }
- return kTfLiteOk;
-}
-
} // namespace
TfLiteStatus InterpreterBuilder::ParseNodes(
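Note: the opcode-resolution and builtin-data parsing deleted above now live in the shared core/api layer. A minimal sketch of the new call pattern, assuming GetRegistrationFromOpCode is declared by the core/api headers pulled in above (the wrapper function itself is hypothetical):

    #include "tensorflow/contrib/lite/core/api/error_reporter.h"
    #include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
    #include "tensorflow/contrib/lite/core/api/op_resolver.h"
    #include "tensorflow/contrib/lite/schema/schema_generated.h"

    // Resolve a single flatbuffer opcode to an executable registration,
    // mirroring the loop body in BuildLocalIndexToRegistrationMapping().
    TfLiteStatus ResolveOne(const tflite::OperatorCode* opcode,
                            const tflite::OpResolver& resolver,
                            tflite::ErrorReporter* reporter,
                            const TfLiteRegistration** registration) {
      return tflite::GetRegistrationFromOpCode(opcode, resolver, reporter,
                                               registration);
    }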
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 8bc9ecd7ce..6abdfcd079 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -35,9 +35,10 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_MODEL_H_
#include <memory>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index df4f60d4ad..ec7d46af7c 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/model.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/testing/util.h"
// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
diff --git a/tensorflow/contrib/lite/op_resolver.cc b/tensorflow/contrib/lite/mutable_op_resolver.cc
index f6e435e982..8ee63d2a02 100644
--- a/tensorflow/contrib/lite/op_resolver.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver.cc
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/op_resolver.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/mutable_op_resolver.h b/tensorflow/contrib/lite/mutable_op_resolver.h
new file mode 100644
index 0000000000..c319041e9b
--- /dev/null
+++ b/tensorflow/contrib/lite/mutable_op_resolver.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
+
+#include <unordered_map>
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/util.h"
+
+namespace tflite {
+
+// Some versions of gcc don't support partial specialization at class scope,
+// so these are defined at namespace scope instead.
+namespace op_resolver_hasher {
+template <typename V>
+struct ValueHasher {
+ size_t operator()(const V& v) const { return std::hash<V>()(v); }
+};
+
+template <>
+struct ValueHasher<tflite::BuiltinOperator> {
+ size_t operator()(const tflite::BuiltinOperator& v) const {
+ return std::hash<int>()(static_cast<int>(v));
+ }
+};
+
+template <typename T>
+struct OperatorKeyHasher {
+ size_t operator()(const T& x) const {
+ size_t a = ValueHasher<typename T::first_type>()(x.first);
+ size_t b = ValueHasher<typename T::second_type>()(x.second);
+ return CombineHashes({a, b});
+ }
+};
+} // namespace op_resolver_hasher
+
+// An OpResolver that is mutable; also used as the op resolver in gen_op_registration.
+// A typical usage:
+// MutableOpResolver resolver;
+// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
+// InterpreterBuilder(model, resolver)(&interpreter);
+class MutableOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const override;
+ const TfLiteRegistration* FindOp(const char* op, int version) const override;
+ void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+ void AddCustom(const char* name, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+
+ private:
+ typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
+ typedef std::pair<std::string, int> CustomOperatorKey;
+
+ std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
+ op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
+ builtins_;
+ std::unordered_map<CustomOperatorKey, TfLiteRegistration,
+ op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
+ custom_ops_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
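Note: the header's usage comment, expanded into a compilable sketch; Register_ADD() and Register_CUSTOM_OP() are stand-ins for kernel registration functions provided elsewhere:

    #include <memory>

    #include "tensorflow/contrib/lite/interpreter.h"
    #include "tensorflow/contrib/lite/model.h"
    #include "tensorflow/contrib/lite/mutable_op_resolver.h"

    // Hypothetical kernel hooks; real ones come from the kernels library.
    TfLiteRegistration* Register_ADD();
    TfLiteRegistration* Register_CUSTOM_OP();

    std::unique_ptr<tflite::Interpreter> BuildInterpreter(
        const tflite::FlatBufferModel& model) {
      tflite::MutableOpResolver resolver;
      resolver.AddBuiltin(tflite::BuiltinOperator_ADD, Register_ADD());
      resolver.AddCustom("CustomOp", Register_CUSTOM_OP());

      std::unique_ptr<tflite::Interpreter> interpreter;
      tflite::InterpreterBuilder(model, resolver)(&interpreter);
      return interpreter;
    }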
diff --git a/tensorflow/contrib/lite/op_resolver_test.cc b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
index 10b7e31972..db690eaab9 100644
--- a/tensorflow/contrib/lite/op_resolver_test.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/testing/util.h"
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 484842713d..817486e898 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h
index 2bdb2cc5c8..22359d557e 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/nnapi_delegate.h
@@ -16,8 +16,8 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/interpreter.h"
class ANeuralNetworksModel;
diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/contrib/lite/op_resolver.h
index 9d7e3f2085..e93134cbde 100644
--- a/tensorflow/contrib/lite/op_resolver.h
+++ b/tensorflow/contrib/lite/op_resolver.h
@@ -12,83 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// Compatibility shim for moved header location.
#ifndef TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
#define TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
-#include <unordered_map>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/schema/schema_generated.h"
-#include "tensorflow/contrib/lite/util.h"
-
-namespace tflite {
-
-// Abstract interface that returns TfLiteRegistrations given op codes or custom
-// op names. This is the mechanism that ops being referenced in the flatbuffer
-// model are mapped to executable function pointers (TfLiteRegistrations).
-class OpResolver {
- public:
- // Finds the op registration for a builtin operator by enum code.
- virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
- int version) const = 0;
- // Finds the op registration of a custom operator by op name.
- virtual const TfLiteRegistration* FindOp(const char* op,
- int version) const = 0;
- virtual ~OpResolver() {}
-};
-
-// Some versions of gcc doesn't support partial specialization in class scope,
-// so these are defined in a namescope.
-namespace op_resolver_hasher {
-template <typename V>
-struct ValueHasher {
- size_t operator()(const V& v) const { return std::hash<V>()(v); }
-};
-
-template <>
-struct ValueHasher<tflite::BuiltinOperator> {
- size_t operator()(const tflite::BuiltinOperator& v) const {
- return std::hash<int>()(static_cast<int>(v));
- }
-};
-
-template <typename T>
-struct OperatorKeyHasher {
- size_t operator()(const T& x) const {
- size_t a = ValueHasher<typename T::first_type>()(x.first);
- size_t b = ValueHasher<typename T::second_type>()(x.second);
- return CombineHashes({a, b});
- }
-};
-} // namespace op_resolver_hasher
-
-// An OpResolver that is mutable, also used as the op in gen_op_registration.
-// A typical usage:
-// MutableOpResolver resolver;
-// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
-// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
-// InterpreterBuilder(model, resolver)(&interpreter);
-class MutableOpResolver : public OpResolver {
- public:
- const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
- int version) const override;
- const TfLiteRegistration* FindOp(const char* op, int version) const override;
- void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
- int min_version = 1, int max_version = 1);
- void AddCustom(const char* name, TfLiteRegistration* registration,
- int min_version = 1, int max_version = 1);
-
- private:
- typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
- typedef std::pair<std::string, int> CustomOperatorKey;
-
- std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
- op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
- builtins_;
- std::unordered_map<CustomOperatorKey, TfLiteRegistration,
- op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
- custom_ops_;
-};
-
-} // namespace tflite
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
#endif // TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
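Note: with the shim in place, existing code that includes the old path keeps compiling; both spellings resolve to the same declarations:

    // Old spelling, now a forwarding header:
    #include "tensorflow/contrib/lite/op_resolver.h"
    // Equivalent direct includes after the split:
    // #include "tensorflow/contrib/lite/core/api/op_resolver.h"
    // #include "tensorflow/contrib/lite/mutable_op_resolver.h"

    tflite::MutableOpResolver resolver;  // visible through either path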
diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h
index f738315cf2..45d0d8735e 100644
--- a/tensorflow/contrib/lite/simple_memory_arena.h
+++ b/tensorflow/contrib/lite/simple_memory_arena.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <list>
#include <memory>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/stderr_reporter.cc
index 646913c026..e29a6345fd 100644
--- a/tensorflow/contrib/lite/error_reporter.cc
+++ b/tensorflow/contrib/lite/stderr_reporter.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
#include <cstdarg>
#include <cstdio>
@@ -22,26 +22,6 @@ limitations under the License.
namespace tflite {
-ErrorReporter::~ErrorReporter() {}
-
-int ErrorReporter::Report(const char* format, ...) {
- va_list args;
- va_start(args, format);
- int code = Report(format, args);
- va_end(args);
- return code;
-}
-
-// TODO(aselle): Make the name of ReportError on context the same, so
-// we can use the ensure functions w/o a context and w/ a reporter.
-int ErrorReporter::ReportError(void*, const char* format, ...) {
- va_list args;
- va_start(args, format);
- int code = Report(format, args);
- va_end(args);
- return code;
-}
-
int StderrReporter::Report(const char* format, va_list args) {
#ifdef __ANDROID__
// On Android stderr is not captured for applications, only for code run from
diff --git a/tensorflow/contrib/lite/stderr_reporter.h b/tensorflow/contrib/lite/stderr_reporter.h
new file mode 100644
index 0000000000..c6f4ffbdff
--- /dev/null
+++ b/tensorflow/contrib/lite/stderr_reporter.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
+
+#include <cstdarg>
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+
+namespace tflite {
+
+// An error reporter that simply writes the message to stderr.
+struct StderrReporter : public ErrorReporter {
+ int Report(const char* format, va_list args) override;
+};
+
+// Return the default error reporter (output to stderr).
+ErrorReporter* DefaultErrorReporter();
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
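Note: a small sketch of the split reporter in use; the message and counts are illustrative:

    #include "tensorflow/contrib/lite/stderr_reporter.h"

    void WarnTensorCount(int actual, int expected) {
      // DefaultErrorReporter() hands back a process-wide StderrReporter;
      // the variadic Report() overload lives on the ErrorReporter base,
      // now declared in core/api/error_reporter.h.
      tflite::ErrorReporter* reporter = tflite::DefaultErrorReporter();
      reporter->Report("Got %d tensors, expected %d.", actual, expected);
    }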
diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc
index a316a40b62..b991e999b6 100644
--- a/tensorflow/contrib/lite/string_util.cc
+++ b/tensorflow/contrib/lite/string_util.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h
index 57f129bf5e..d24627b509 100644
--- a/tensorflow/contrib/lite/string_util.h
+++ b/tensorflow/contrib/lite/string_util.h
@@ -42,7 +42,7 @@ limitations under the License.
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/string_util_test.cc b/tensorflow/contrib/lite/string_util_test.cc
index d53fec7512..a583a9184b 100644
--- a/tensorflow/contrib/lite/string_util_test.cc
+++ b/tensorflow/contrib/lite/string_util_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/string_util.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/testing/util.h"
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 0b3a97d4f5..aad1ecaeb6 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -173,7 +173,6 @@ tf_cc_test(
srcs = ["tflite_driver_test.cc"],
data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
tags = [
- "no_oss", # b/112769036
"tflite_not_portable_android",
"tflite_not_portable_ios",
],
@@ -215,6 +214,7 @@ cc_library(
deps = [
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string",
+ "//tensorflow/contrib/lite/core/api",
],
)
diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h
index 8aa639157b..925791d390 100644
--- a/tensorflow/contrib/lite/testing/util.h
+++ b/tensorflow/contrib/lite/testing/util.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <cstdio>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index a75553db84..bea90f1ce8 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -372,6 +372,7 @@ cc_library(
":toco_graphviz_dump_options",
":toco_port",
":types_proto_cc",
+ "//tensorflow/contrib/lite/kernels/internal:types",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_googlesource_code_re2//:re2",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index c25be078ff..f103bb94ae 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1314,12 +1314,16 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
// Compute output shape
for (int axis = 0; axis < num_input_axes; ++axis) {
+ const auto strided_slice_params =
+ tflite::strided_slice::BuildStridedSliceParams(
+ op->begin_mask, op->end_mask, op->shrink_axis_mask,
+ op->start_indices, op->stop_indices, op->strides);
int start_index = tflite::strided_slice::StartForAxis(
- op->begin_mask, op->start_indices, op->strides,
- input_array.shape().dims().data(), axis);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
int stop_index = tflite::strided_slice::StopForAxis(
- op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides,
- input_array.shape().dims().data(), axis, start_index);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
+ start_index);
+
int dim_size =
ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]);
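For reference, the per-axis arithmetic these StartForAxis/StopForAxis helpers implement reduces to the minimal Python sketch below (slice_dims is a hypothetical name; the clamping of out-of-range indices that the real tflite::strided_slice helpers perform is omitted):

    import math

    def slice_dims(dims, begin, end, strides,
                   begin_mask=0, end_mask=0, shrink_axis_mask=0):
        # Per-axis start/stop/size math, mirroring the computation above.
        out = []
        for axis, dim in enumerate(dims):
            stride = strides[axis]
            start = begin[axis]
            if begin_mask & (1 << axis):
                start = 0 if stride > 0 else dim - 1
            elif start < 0:
                start += dim
            if shrink_axis_mask & (1 << axis):
                stop = start + 1          # axis collapses to size 1
            elif end_mask & (1 << axis):
                stop = dim if stride > 0 else -1
            else:
                stop = end[axis] + dim if end[axis] < 0 else end[axis]
            out.append(math.ceil((stop - start) / stride))
        return out

    print(slice_dims([6], begin=[1], end=[5], strides=[2]))  # [2]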
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index 9d8bd4fc39..8853ed87e6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -52,14 +52,18 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>();
std::vector<int> src_coord(num_input_axes);
std::vector<int> stop_for_axis(num_input_axes);
+ const auto strided_slice_params =
+ tflite::strided_slice::BuildStridedSliceParams(
+ op.begin_mask, op.end_mask, op.shrink_axis_mask, op.start_indices,
+ op.stop_indices, op.strides);
+
for (int axis = 0; axis < num_input_axes; axis++) {
- int start = tflite::strided_slice::StartForAxis(
- op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(),
- axis);
- src_coord[axis] = start;
+ int start_index = tflite::strided_slice::StartForAxis(
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
+ src_coord[axis] = start_index;
stop_for_axis[axis] = tflite::strided_slice::StopForAxis(
- op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides,
- input_shape.dims().data(), axis, start);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
+ start_index);
}
// In order to handle any number (N) of dimensions, we copy elements one by
@@ -86,8 +90,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) {
// Reset axis and set carry
src_coord[axis] = tflite::strided_slice::StartForAxis(
- op.begin_mask, op.start_indices, op.strides,
- input_shape.dims().data(), axis);
+ strided_slice_params, ToRuntimeShape(input_shape), axis);
carry = true;
} else {
carry = false;
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index bdeb203024..5f4b8cb66a 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -28,6 +28,7 @@ limitations under the License.
#if TOCO_SUPPORT_PORTABLE_PROTOS
#include "third_party/protobuf/include/google/protobuf/text_format.h"
#endif // TOCO_SUPPORT_PORTABLE_PROTOS
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
@@ -139,6 +140,10 @@ bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1);
// - For the remaining indices [0..i0), d0[i0] == 1.
bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1);
+inline ::tflite::RuntimeShape ToRuntimeShape(const Shape& shape) {
+ return ::tflite::RuntimeShape(shape.dimensions_count(), shape.dims().data());
+}
+
bool IsArrayFullyConnectedWeights(const Model& model, const string& name);
// If there is a wildcard dimension (-1), this may return a negative value.
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
index a66812fe87..98e2835b2e 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
@@ -54,6 +54,7 @@ tf_cc_test(
linkopts = common_linkopts,
linkstatic = 1,
tags = [
+ "no_oss", # b/114307765
"tflite_not_portable_android",
"tflite_not_portable_ios",
],
diff --git a/tensorflow/contrib/lite/tools/make/Makefile b/tensorflow/contrib/lite/tools/make/Makefile
index e30cc1d70e..59bdb10811 100644
--- a/tensorflow/contrib/lite/tools/make/Makefile
+++ b/tensorflow/contrib/lite/tools/make/Makefile
@@ -24,6 +24,21 @@ HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32
TARGET := $(HOST_OS)
TARGET_ARCH := $(HOST_ARCH)
+INCLUDES := \
+-I. \
+-I$(MAKEFILE_DIR)/../../../../../ \
+-I$(MAKEFILE_DIR)/../../../../../../ \
+-I$(MAKEFILE_DIR)/downloads/ \
+-I$(MAKEFILE_DIR)/downloads/eigen \
+-I$(MAKEFILE_DIR)/downloads/gemmlowp \
+-I$(MAKEFILE_DIR)/downloads/neon_2_sse \
+-I$(MAKEFILE_DIR)/downloads/farmhash/src \
+-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
+-I$(OBJDIR)
+# This is at the end so any globally-installed frameworks like protobuf don't
+# override local versions in the source tree.
+INCLUDES += -I/usr/local/include
+
# These are the default libraries needed, but they can be added to or
# overridden by the platform-specific settings in target makefiles.
LIBS := \
@@ -44,55 +59,17 @@ ARFLAGS := -r
TARGET_TOOLCHAIN_PREFIX :=
CC_PREFIX :=
-# These target-specific makefiles should modify or replace options like
-# CXXFLAGS or LIBS to work for a specific targeted architecture. All logic
-# based on platforms or architectures should happen within these files, to
-# keep this main makefile focused on the sources and dependencies.
-include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)
-
-# Where compiled objects are stored.
-GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
-OBJDIR := $(GENDIR)obj/
-BINDIR := $(GENDIR)bin/
-LIBDIR := $(GENDIR)lib/
-
-INCLUDES := \
--I. \
--I$(MAKEFILE_DIR)/../../../../../ \
--I$(MAKEFILE_DIR)/../../../../../../ \
--I$(MAKEFILE_DIR)/downloads/ \
--I$(MAKEFILE_DIR)/downloads/eigen \
--I$(MAKEFILE_DIR)/downloads/gemmlowp \
--I$(MAKEFILE_DIR)/downloads/neon_2_sse \
--I$(MAKEFILE_DIR)/downloads/farmhash/src \
--I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
--I$(OBJDIR)
-# This is at the end so any globally-installed frameworks like protobuf don't
-# override local versions in the source tree.
-INCLUDES += -I/usr/local/include
-
-CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++
-CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc
-AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar
-
# This library is the main target for this makefile. It will contain a minimal
# runtime that can be linked in to other programs.
LIB_NAME := libtensorflow-lite.a
-LIB_PATH := $(LIBDIR)$(LIB_NAME)
-
-# A small example program that shows how to link against the library.
-MINIMAL_PATH := $(BINDIR)minimal
# Benchmark static library and binary
BENCHMARK_LIB_NAME := benchmark-lib.a
BENCHMARK_BINARY_NAME := benchmark_model
-BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME)
-BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME)
+# A small example program that shows how to link against the library.
MINIMAL_SRCS := \
tensorflow/contrib/lite/examples/minimal/minimal.cc
-MINIMAL_OBJS := $(addprefix $(OBJDIR), \
-$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS))))
# What sources we want to compile, must be kept in sync with the main Bazel
# build files.
@@ -105,7 +82,9 @@ PROFILE_SUMMARIZER_SRCS := \
CORE_CC_ALL_SRCS := \
$(wildcard tensorflow/contrib/lite/*.cc) \
-$(wildcard tensorflow/contrib/lite/*.c)
+$(wildcard tensorflow/contrib/lite/*.c) \
+$(wildcard tensorflow/contrib/lite/c/*.c) \
+$(wildcard tensorflow/contrib/lite/core/api/*.cc)
ifneq ($(BUILD_TYPE),micro)
CORE_CC_ALL_SRCS += \
$(wildcard tensorflow/contrib/lite/kernels/*.cc) \
@@ -136,10 +115,6 @@ tensorflow/contrib/lite/nnapi_delegate.cc
endif
# Filter out all the excluded files.
TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
-# File names of the intermediate files target compilation generates.
-TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \
-$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
-LIB_OBJS := $(TF_LITE_CC_OBJS)
# Benchmark sources
BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark
@@ -151,6 +126,40 @@ BENCHMARK_SRCS := $(filter-out \
$(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc), \
$(BENCHMARK_ALL_SRCS))
+# These target-specific makefiles should modify or replace options like
+# CXXFLAGS or LIBS to work for a specific targeted architecture. All logic
+# based on platforms or architectures should happen within these files, to
+# keep this main makefile focused on the sources and dependencies.
+include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)
+
+ALL_SRCS := \
+ $(MINIMAL_SRCS) \
+ $(PROFILER_SRCS) \
+ $(PROFILER_SUMMARY_SRCS) \
+ $(TF_LITE_CC_SRCS) \
+ $(BENCHMARK_SRCS)
+
+# Where compiled objects are stored.
+GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
+OBJDIR := $(GENDIR)obj/
+BINDIR := $(GENDIR)bin/
+LIBDIR := $(GENDIR)lib/
+
+LIB_PATH := $(LIBDIR)$(LIB_NAME)
+BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME)
+BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME)
+MINIMAL_BINARY := $(BINDIR)minimal
+
+CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++
+CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc
+AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar
+
+MINIMAL_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS))))
+
+LIB_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
+
BENCHMARK_OBJS := $(addprefix $(OBJDIR), \
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS))))
@@ -164,7 +173,7 @@ $(OBJDIR)%.o: %.c
$(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@
# The target that's compiled if there's no command-line arguments.
-all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY)
+all: $(LIB_PATH) $(MINIMAL_BINARY) $(BENCHMARK_BINARY)
# The target that's compiled for micro-controllers
micro: $(LIB_PATH)
@@ -178,19 +187,18 @@ $(LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(LIB_OBJS)
@mkdir -p $(dir $@)
$(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS)
-$(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH)
+$(MINIMAL_BINARY): $(MINIMAL_OBJS) $(LIB_PATH)
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) $(INCLUDES) \
- -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \
+ -o $(MINIMAL_BINARY) $(MINIMAL_OBJS) \
$(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS)
-
$(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_OBJS)
@mkdir -p $(dir $@)
$(AR) $(ARFLAGS) $(BENCHMARK_LIB) $(LIB_OBJS) $(BENCHMARK_OBJS)
benchmark_lib: $(BENCHMARK_LIB)
-$(info $(BENCHMARK_BINARY))
+
$(BENCHMARK_BINARY) : $(BENCHMARK_LIB)
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) $(INCLUDES) \
@@ -213,4 +221,4 @@ cleantarget:
$(DEPDIR)/%.d: ;
.PRECIOUS: $(DEPDIR)/%.d
--include $(patsubst %,$(DEPDIR)/%.d,$(basename $(TF_CC_SRCS)))
+-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(ALL_SRCS)))
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
index 692efb9029..b863108aa4 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
@@ -141,6 +141,7 @@ bool IsHybridEvaluationOp(const OperatorT* op, const BuiltinOperator& op_code) {
op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
op_code == BuiltinOperator_RNN ||
+ op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
diff --git a/tensorflow/contrib/lite/tutorials/BUILD b/tensorflow/contrib/lite/tutorials/BUILD
new file mode 100644
index 0000000000..67ff1ea124
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/BUILD
@@ -0,0 +1,20 @@
+# Example Estimator model
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_binary(
+ name = "mnist_tflite",
+ srcs = [
+ "dataset.py",
+ "mnist_tflite.py",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/lite/tutorials/dataset.py b/tensorflow/contrib/lite/tutorials/dataset.py
new file mode 100644
index 0000000000..ba49dfcc9b
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/dataset.py
@@ -0,0 +1,122 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""tf.data.Dataset interface to the MNIST dataset.
+
+This is cloned from
+https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import shutil
+import tempfile
+
+import numpy as np
+from six.moves import urllib
+import tensorflow as tf
+
+
+def read32(bytestream):
+ """Read 4 bytes from bytestream as an unsigned 32-bit integer."""
+ dt = np.dtype(np.uint32).newbyteorder('>')
+ return np.frombuffer(bytestream.read(4), dtype=dt)[0]
+
+
+def check_image_file_header(filename):
+ """Validate that filename corresponds to images for the MNIST dataset."""
+ with tf.gfile.Open(filename, 'rb') as f:
+ magic = read32(f)
+ read32(f) # num_images, unused
+ rows = read32(f)
+ cols = read32(f)
+ if magic != 2051:
+ raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
+ f.name))
+ if rows != 28 or cols != 28:
+ raise ValueError(
+ 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
+ (f.name, rows, cols))
+
+
+def check_labels_file_header(filename):
+ """Validate that filename corresponds to labels for the MNIST dataset."""
+ with tf.gfile.Open(filename, 'rb') as f:
+ magic = read32(f)
+ read32(f) # num_items, unused
+ if magic != 2049:
+ raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
+ f.name))
+
+
+def download(directory, filename):
+ """Download (and unzip) a file from the MNIST dataset if not already done."""
+ filepath = os.path.join(directory, filename)
+ if tf.gfile.Exists(filepath):
+ return filepath
+ if not tf.gfile.Exists(directory):
+ tf.gfile.MakeDirs(directory)
+ # CVDF mirror of http://yann.lecun.com/exdb/mnist/
+ url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
+ _, zipped_filepath = tempfile.mkstemp(suffix='.gz')
+ print('Downloading %s to %s' % (url, zipped_filepath))
+ urllib.request.urlretrieve(url, zipped_filepath)
+ with gzip.open(zipped_filepath, 'rb') as f_in, \
+ tf.gfile.Open(filepath, 'wb') as f_out:
+ shutil.copyfileobj(f_in, f_out)
+ os.remove(zipped_filepath)
+ return filepath
+
+
+def dataset(directory, images_file, labels_file):
+ """Download and parse MNIST dataset."""
+
+ images_file = download(directory, images_file)
+ labels_file = download(directory, labels_file)
+
+ check_image_file_header(images_file)
+ check_labels_file_header(labels_file)
+
+ def decode_image(image):
+ # Normalize from [0, 255] to [0.0, 1.0]
+ image = tf.decode_raw(image, tf.uint8)
+ image = tf.cast(image, tf.float32)
+ image = tf.reshape(image, [784])
+ return image / 255.0
+
+ def decode_label(label):
+ label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]
+ label = tf.reshape(label, []) # label is a scalar
+ return tf.to_int32(label)
+
+ images = tf.data.FixedLengthRecordDataset(
+ images_file, 28 * 28, header_bytes=16).map(decode_image)
+ labels = tf.data.FixedLengthRecordDataset(
+ labels_file, 1, header_bytes=8).map(decode_label)
+ return tf.data.Dataset.zip((images, labels))
+
+
+def train(directory):
+ """tf.data.Dataset object for MNIST training data."""
+ return dataset(directory, 'train-images-idx3-ubyte',
+ 'train-labels-idx1-ubyte')
+
+
+def test(directory):
+ """tf.data.Dataset object for MNIST test data."""
+ return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
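A hypothetical usage sketch for this module (assumes TF 1.x session APIs and network access to the CVDF mirror):

    import tensorflow as tf
    from tensorflow.contrib.lite.tutorials import dataset

    # Batched (image, label) pairs; images are flat [784] float vectors.
    ds = dataset.train('/tmp/mnist_data').batch(32)
    image, label = ds.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        images, labels = sess.run([image, label])
        print(images.shape, labels.shape)  # (32, 784) (32,)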
diff --git a/tensorflow/contrib/lite/tutorials/mnist_tflite.py b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
new file mode 100644
index 0000000000..7b8bf5b5db
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+import tensorflow as tf # pylint: disable=g-bad-import-order
+from tensorflow.contrib.lite.tutorials import dataset
+flags = tf.app.flags
+
+flags.DEFINE_string('data_dir', '/tmp/data_dir',
+ 'Directory where data is stored.')
+flags.DEFINE_string('model_file', '',
+ 'The path to the TFLite flatbuffer model file.')
+
+
+flags = flags.FLAGS
+
+
+def test_image_generator():
+ # Generates an iterator over images
+ with tf.Session() as sess:
+ input_data = dataset.test(
+ flags.data_dir).make_one_shot_iterator().get_next()
+ try:
+ while True:
+ yield sess.run(input_data)
+ except tf.errors.OutOfRangeError:
+ pass
+
+
+def run_eval(interpreter, input_image):
+ """Performs evaluation for input image over specified model.
+
+ Args:
+ interpreter: TFLite interpreter initialized with model to execute.
+ input_image: Image input to the model.
+
+ Returns:
+ output: output tensor of model being executed.
+ """
+
+ # Get input and output tensors.
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ # Test model on the input images.
+ input_image = np.reshape(input_image, input_details[0]['shape'])
+ interpreter.set_tensor(input_details[0]['index'], input_image)
+
+ interpreter.invoke()
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ output = np.squeeze(output_data)
+ return output
+
+
+def main(_):
+ interpreter = tf.contrib.lite.Interpreter(model_path=flags.model_file)
+ interpreter.allocate_tensors()
+ num_correct, total = 0, 0
+ for input_data in test_image_generator():
+ output = run_eval(interpreter, input_data[0])
+ total += 1
+ if output == input_data[1]:
+ num_correct += 1
+ if total % 500 == 0:
+ print('Accuracy after %i images: %f' %
+ (total, float(num_correct) / float(total)))
+
+
+if __name__ == '__main__':
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.app.run(main)
diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h
index f5b208afbb..6d81f844f8 100644
--- a/tensorflow/contrib/lite/util.h
+++ b/tensorflow/contrib/lite/util.h
@@ -22,7 +22,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_UTIL_H_
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/util_test.cc b/tensorflow/contrib/lite/util_test.cc
index 32bf917a59..c5c1709f1d 100644
--- a/tensorflow/contrib/lite/util_test.cc
+++ b/tensorflow/contrib/lite/util_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/util.h"
namespace tflite {
diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt
index 22b11f1c57..7d26429f9c 100644
--- a/tensorflow/contrib/makefile/proto_text_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt
@@ -56,6 +56,7 @@ tensorflow/core/lib/hash/hash.cc
tensorflow/core/lib/hash/crc32c.cc
tensorflow/core/lib/hash/crc32c_accelerate.cc
tensorflow/core/lib/core/threadpool.cc
+tensorflow/core/lib/core/stringpiece.cc
tensorflow/core/lib/core/status.cc
tensorflow/core/lib/core/coding.cc
tensorflow/core/lib/core/arena.cc
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 93e589907e..2e4d61d931 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -159,8 +159,10 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:variables",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
index f026f437dc..f55209ec49 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
@@ -25,7 +25,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -48,12 +47,7 @@ class LazyAdamOptimizer(adam.AdamOptimizer):
may lead to different empirical results.
"""
- def _apply_sparse_shared(self,
- grad,
- var,
- indices,
- scatter_update,
- scatter_sub):
+ def _apply_sparse(self, grad, var):
beta1_power, beta2_power = self._get_beta_accumulators()
beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
@@ -65,51 +59,56 @@ class LazyAdamOptimizer(adam.AdamOptimizer):
# \\(m := beta1 * m + (1 - beta1) * g_t\\)
m = self.get_slot(var, "m")
- m_t = scatter_update(m, indices,
- beta1_t * array_ops.gather(m, indices) +
- (1 - beta1_t) * grad)
+ m_t = state_ops.scatter_update(m, grad.indices,
+ beta1_t * array_ops.gather(m, grad.indices) +
+ (1 - beta1_t) * grad.values,
+ use_locking=self._use_locking)
# \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
v = self.get_slot(var, "v")
- v_t = scatter_update(v, indices,
- beta2_t * array_ops.gather(v, indices) +
- (1 - beta2_t) * math_ops.square(grad))
+ v_t = state_ops.scatter_update(v, grad.indices,
+ beta2_t * array_ops.gather(v, grad.indices) +
+ (1 - beta2_t) * math_ops.square(grad.values),
+ use_locking=self._use_locking)
# \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
- m_t_slice = array_ops.gather(m_t, indices)
- v_t_slice = array_ops.gather(v_t, indices)
+ m_t_slice = array_ops.gather(m_t, grad.indices)
+ v_t_slice = array_ops.gather(v_t, grad.indices)
denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t
- var_update = scatter_sub(var, indices,
- lr * m_t_slice / denominator_slice)
+ var_update = state_ops.scatter_sub(var, grad.indices,
+ lr * m_t_slice / denominator_slice,
+ use_locking=self._use_locking)
return control_flow_ops.group(var_update, m_t, v_t)
- def _apply_sparse(self, grad, var):
- return self._apply_sparse_shared(
- grad.values, var, grad.indices,
- self._scatter_update,
- self._scatter_sub)
-
def _resource_apply_sparse(self, grad, var, indices):
- return self._apply_sparse_shared(
- grad, var, indices,
- self._resource_scatter_update,
- self._resource_scatter_sub)
-
- # Utility functions for updating resource or non-resource variables.
- def _scatter_update(self, x, i, v):
- return state_ops.scatter_update(
- x, i, v, use_locking=self._use_locking)
-
- def _scatter_sub(self, x, i, v):
- return state_ops.scatter_sub(
- x, i, v, use_locking=self._use_locking)
-
- def _resource_scatter_update(self, x, i, v):
- update_op = resource_variable_ops.resource_scatter_update(x.handle, i, v)
- with ops.control_dependencies([update_op]):
- return x.value()
-
- def _resource_scatter_sub(self, x, i, v):
- sub_op = resource_variable_ops.resource_scatter_sub(x.handle, i, v)
- with ops.control_dependencies([sub_op]):
- return x.value()
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
+ beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
+ lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
+ beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
+ beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
+ epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
+ lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
+
+ # \\(m := beta1 * m + (1 - beta1) * g_t\\)
+ m = self.get_slot(var, "m")
+ m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad
+ m_update_op = resource_variable_ops.resource_scatter_update(m.handle,
+ indices,
+ m_t_slice)
+
+ # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
+ v = self.get_slot(var, "v")
+ v_t_slice = (beta2_t * array_ops.gather(v, indices) +
+ (1 - beta2_t) * math_ops.square(grad))
+ v_update_op = resource_variable_ops.resource_scatter_update(v.handle,
+ indices,
+ v_t_slice)
+
+ # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
+ var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
+ var_update_op = resource_variable_ops.resource_scatter_sub(var.handle,
+ indices,
+ var_slice)
+
+ return control_flow_ops.group(var_update_op, m_update_op, v_update_op)
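Stripped of TF plumbing, the net effect of the sparse path above is the one-step numpy sketch below (hyperparameter defaults match AdamOptimizer's; duplicate indices are assumed absent, since scatter_update is last-write-wins):

    import numpy as np

    def lazy_adam_step(var, m, v, grad_values, indices, t,
                       lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        # Bias-corrected step size, as computed from the beta accumulators.
        lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
        # Only the touched rows of the m/v slots are read and written.
        m[indices] = beta1 * m[indices] + (1 - beta1) * grad_values
        v[indices] = beta2 * v[indices] + (1 - beta2) * grad_values**2
        var[indices] -= lr_t * m[indices] / (np.sqrt(v[indices]) + eps)
        return var, m, v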
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
index d3e9e89502..f08ffaa36f 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
@@ -19,12 +19,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.opt.python.training import lazy_adam_optimizer
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -50,9 +53,10 @@ def adam_update_numpy(param,
return param_t, m_t, v_t
-class AdamOptimizerTest(test.TestCase):
+class AdamOptimizerTest(test.TestCase, parameterized.TestCase):
- def doTestSparse(self, use_resource=False):
+ @parameterized.parameters([False, True])
+ def testSparse(self, use_resource):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
# Initialize variables for numpy implementation.
@@ -68,6 +72,7 @@ class AdamOptimizerTest(test.TestCase):
else:
var0 = variables.Variable(var0_np)
var1 = variables.Variable(var1_np)
+
grads0_np_indices = np.array([0, 1], dtype=np.int32)
grads0 = ops.IndexedSlices(
constant_op.constant(grads0_np),
@@ -99,18 +104,17 @@ class AdamOptimizerTest(test.TestCase):
self.assertAllCloseAccordingToType(var0_np, var0.eval())
self.assertAllCloseAccordingToType(var1_np, var1.eval())
- def testSparse(self):
- self.doTestSparse(use_resource=False)
-
- def testResourceSparse(self):
- self.doTestSparse(use_resource=True)
-
- def testSparseDevicePlacement(self):
+ @parameterized.parameters([False, True])
+ def testSparseDevicePlacement(self, use_resource):
for index_dtype in [dtypes.int32, dtypes.int64]:
with self.test_session(force_gpu=test.is_gpu_available()):
# If a GPU is available, tests that all optimizer ops can be placed on
# it (i.e. they have GPU kernels).
- var = variables.Variable([[1.0], [2.0]])
+ if use_resource:
+ var = resource_variable_ops.ResourceVariable([[1.0], [2.0]])
+ else:
+ var = variables.Variable([[1.0], [2.0]])
+
indices = constant_op.constant([0, 1], dtype=index_dtype)
gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
optimizer = lazy_adam_optimizer.LazyAdamOptimizer(3.0)
@@ -118,13 +122,21 @@ class AdamOptimizerTest(test.TestCase):
variables.global_variables_initializer().run()
minimize_op.run()
- def testSparseRepeatedIndices(self):
+ @parameterized.parameters([False, True])
+ def testSparseRepeatedIndices(self, use_resource):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
- repeated_index_update_var = variables.Variable(
- [[1.0], [2.0]], dtype=dtype)
- aggregated_update_var = variables.Variable(
- [[1.0], [2.0]], dtype=dtype)
+ if use_resource:
+ repeated_index_update_var = resource_variable_ops.ResourceVariable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = resource_variable_ops.ResourceVariable(
+ [[1.0], [2.0]], dtype=dtype)
+ else:
+ repeated_index_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+
grad_repeated_index = ops.IndexedSlices(
constant_op.constant(
[0.1, 0.1], shape=[2, 1], dtype=dtype),
@@ -150,6 +162,204 @@ class AdamOptimizerTest(test.TestCase):
self.assertAllClose(aggregated_update_var.eval(),
repeated_index_update_var.eval())
+ def doTestBasic(self, use_resource=False, use_callable_params=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ with self.session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ learning_rate = lambda: 0.001
+ beta1 = lambda: 0.9
+ beta2 = lambda: 0.999
+ epsilon = lambda: 1e-8
+ if not use_callable_params:
+ learning_rate = learning_rate()
+ beta1 = beta1()
+ beta2 = beta2()
+ epsilon = epsilon()
+
+ opt = lazy_adam_optimizer.LazyAdamOptimizer(learning_rate=learning_rate)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+ self.assertIsNotNone(beta1_power)
+ self.assertIsNotNone(beta2_power)
+ self.assertIn(beta1_power, opt_variables)
+ self.assertIn(beta2_power, opt_variables)
+
+ if not context.executing_eagerly():
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta1_power))
+ self.assertAllCloseAccordingToType(0.999**(t + 1),
+ self.evaluate(beta2_power))
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+ if use_resource:
+ self.assertEqual("var0_%d/Adam:0" % (i,),
+ opt.get_slot(var=var0, name="m").name)
+
+ def testBasic(self):
+ with self.test_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(use_resource=True, use_callable_params=True)
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = lazy_adam_optimizer.LazyAdamOptimizer(constant_op.constant(0.001))
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = lazy_adam_optimizer.LazyAdamOptimizer()
+ update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 3 steps of intertwined Adam1 and Adam2.
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ if t % 2 == 0:
+ update1.run()
+ else:
+ update2.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testTwoSessions(self):
+ optimizer = lazy_adam_optimizer.LazyAdamOptimizer()
+
+ with context.eager_mode():
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+ optimizer.apply_gradients([(grads0, var0)])
+
+ g = ops.Graph()
+ with g.as_default():
+ with self.session(graph=g):
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+ optimizer.apply_gradients([(grads0, var0)])
+
+ gg = ops.Graph()
+ with gg.as_default():
+ with self.session(graph=gg):
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+
+ # If the optimizer saves any state not keyed by graph the following line
+ # fails.
+ optimizer.apply_gradients([(grads0, var0)])
+
+ def testSlotsUniqueEager(self):
+ with context.eager_mode():
+ v1 = resource_variable_ops.ResourceVariable(1.)
+ v2 = resource_variable_ops.ResourceVariable(1.)
+ opt = lazy_adam_optimizer.LazyAdamOptimizer(1.)
+ opt.minimize(lambda: v1 + v2)
+ # There should be two non-slot variables, and two unique slot variables
+ # for v1 and v2 respectively.
+ self.assertEqual(6, len(set(opt.variables())))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 499fec4ffa..c59f667f6a 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -22,6 +22,7 @@ py_test(
":common",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python:variable_scope",
@@ -89,7 +90,6 @@ py_library(
":common",
":graph_matcher",
":input_to_ops",
- "//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
@@ -171,7 +171,6 @@ py_library(
":graph_matcher",
":input_to_ops",
":quant_ops",
- "//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index bf648e158e..b27117dd48 100644
--- a/tensorflow/contrib/quantize/python/common.py
+++ b/tensorflow/contrib/quantize/python/common.py
@@ -131,3 +131,29 @@ def DropStringPrefix(s, prefix):
return s[len(prefix):]
else:
return s
+
+
+def RerouteTensor(t0, t1, can_modify=None):
+ """Reroute the end of the tensor t0 to the ends of the tensor t1.
+
+ Args:
+ t0: a tf.Tensor.
+ t1: a tf.Tensor.
+ can_modify: iterable of operations which can be modified. Any operation
+ not listed in can_modify will be left untouched by this function.
+
+ Returns:
+ The number of individual modifications made by the function.
+ """
+ nb_update_inputs = 0
+ consumers = t1.consumers()
+ if can_modify is not None:
+ consumers = [c for c in consumers if c in can_modify]
+ consumers_indices = {}
+ for c in consumers:
+ consumers_indices[c] = [i for i, t in enumerate(c.inputs) if t is t1]
+ for c in consumers:
+ for i in consumers_indices[c]:
+ c._update_input(i, t0) # pylint: disable=protected-access
+ nb_update_inputs += 1
+ return nb_update_inputs
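A minimal graph-mode sketch of what the helper does (TF 1.x APIs assumed):

    import tensorflow as tf
    from tensorflow.contrib.quantize.python import common

    a = tf.constant(1.0, name='a')
    b = tf.constant(2.0, name='b')
    c = tf.constant(3.0, name='c')
    add_ac = tf.add(a, c)

    # Rewire add_ac to read b instead of a; the return value is the
    # number of consumer inputs that were updated.
    changed = common.RerouteTensor(b, a, can_modify=[add_ac.op])
    assert changed == 1
    with tf.Session() as sess:
        print(sess.run(add_ac))  # 5.0, i.e. b + c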
diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py
index 06c62f2d26..2b26302f8a 100644
--- a/tensorflow/contrib/quantize/python/common_test.py
+++ b/tensorflow/contrib/quantize/python/common_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
from tensorflow.contrib.quantize.python import common
from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -62,6 +64,29 @@ class CommonTest(test_util.TensorFlowTestCase):
_, step_val = sess.run([b, quantization_step_tensor])
self.assertEqual(step_val, 2)
+ def testRerouteTensor(self):
+ a = constant_op.constant(1, name='a')
+ b = constant_op.constant(2, name='b')
+ c = constant_op.constant(3, name='c')
+ d = constant_op.constant(4, name='d')
+
+ add_ac = math_ops.add(a, c)
+ add_ad = math_ops.add(a, d)
+
+ # Ensure that before rerouting the inputs are what we think.
+ self._CheckOpHasInputs(add_ac.op, [a, c])
+ self._CheckOpHasInputs(add_ad.op, [a, d])
+
+ # References to tensor a should be replaced with b for all ops in
+ # can_modify. This means add_ac will be changed but add_ad will not.
+ common.RerouteTensor(b, a, can_modify=[add_ac.op])
+ self._CheckOpHasInputs(add_ac.op, [b, c])
+ self._CheckOpHasInputs(add_ad.op, [a, d])
+
+ def _CheckOpHasInputs(self, op, inputs):
+ for i in inputs:
+ self.assertIn(i, op.inputs)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index d9f179bee4..2971b28f45 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import re
-from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
@@ -134,8 +133,8 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
bias_add_tensor = math_ops.add(
new_layer_tensor, bias_tensor, name='add_fold')
- nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
- match.output_tensor)
+ nodes_modified_count = common.RerouteTensor(bias_add_tensor,
+ match.output_tensor)
if nodes_modified_count == 0:
raise ValueError('Folding batch norms failed, %s had no outputs.' %
match.output_tensor.name)
@@ -370,8 +369,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
lambda: match.bn_decay_mean_tensor,
name='freeze_moving_mean')
- graph_editor.reroute_ts(
- [bn_decay_mean_out], [match.bn_decay_mean_tensor],
+ common.RerouteTensor(
+ bn_decay_mean_out,
+ match.bn_decay_mean_tensor,
can_modify=bn_decay_mean_consumers)
bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
@@ -380,8 +380,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
lambda: bn_decay_zero,
lambda: match.bn_decay_var_tensor,
name='freeze_moving_var')
- graph_editor.reroute_ts(
- [bn_decay_var_out], [match.bn_decay_var_tensor],
+ common.RerouteTensor(
+ bn_decay_var_out,
+ match.bn_decay_var_tensor,
can_modify=bn_decay_var_consumers)
correction_recip = utils.smart_cond(
@@ -486,9 +487,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
activation = common.GetEndpointActivationOp(graph, bn)
if activation:
- nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
- [original_op.outputs[0]],
- can_modify=[activation])
+ nodes_modified_count = common.RerouteTensor(
+ folded_op.outputs[0], original_op.outputs[0], can_modify=[activation])
if nodes_modified_count != 1:
raise ValueError('Unexpected inputs to op: %s' % activation.name)
continue
@@ -497,9 +497,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
# operations instead of Relu* above.
add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
- nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
- [original_op.outputs[0]],
- can_modify=[add_bypass])
+ nodes_modified_count = common.RerouteTensor(
+ folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass])
if nodes_modified_count != 1:
raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 2ddbd73ea6..e88db0acd5 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import re
-from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
@@ -592,8 +591,8 @@ def _InsertQuantOp(context,
name=name_prefix + '/delayed_quant')
if consumers:
- tensors_modified_count = graph_editor.reroute_ts(
- [quant], [inputs], can_modify=consumers)
+ tensors_modified_count = common.RerouteTensor(
+ quant, inputs, can_modify=consumers)
# Some operations can have multiple output tensors going to the same
# consumer. Since consumers is a set, we need to ensure that
# tensors_modified_count is greater than or equal to the length of the set
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index 5874245d58..4e67d80558 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -212,6 +212,7 @@ cuda_py_tests(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
+ tags = ["noasan"],
)
tf_custom_op_library(
@@ -279,7 +280,10 @@ cuda_py_tests(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
- tags = ["no_oss"],
+ tags = [
+ "no_oss",
+ "noasan",
+ ],
)
tf_cc_test(
@@ -287,6 +291,7 @@ tf_cc_test(
size = "small",
srcs = ["ops/gru_ops_test.cc"],
data = [":python/ops/_gru_ops.so"],
+ tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
@@ -306,6 +311,7 @@ tf_cc_test(
size = "small",
srcs = ["ops/lstm_ops_test.cc"],
data = [":python/ops/_lstm_ops.so"],
+ tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index f74c95f962..06c481672c 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -97,10 +97,10 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
The default non-peephole implementation is based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The peephole implementation is based on:
@@ -2448,10 +2448,10 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
The default non-peephole implementation is based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The peephole implementation is based on:
@@ -2802,9 +2802,11 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
Training of Deep Neural Networks
The default LSTM implementation is based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
+
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The class uses optional peephole connections, optional cell clipping
and an optional projection layer.
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index db970deff5..0042d37acd 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -134,19 +134,19 @@ def _get_default_head(params, weights_name, output_type, name=None):
weight_column=weights_name,
label_dimension=params.num_outputs,
name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
else:
if params.num_classes == 2:
return core_head_lib.binary_classification_head(
weight_column=weights_name,
name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
else:
return core_head_lib.multi_class_head(
n_classes=params.num_classes,
weight_column=weights_name,
name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
def get_model_fn(params,
graph_builder_class,
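The practical difference between the two reductions, on made-up numbers:

    # Toy illustration of the reduction change (values are assumptions).
    losses = [1.0, 3.0, 0.0]          # per-example weighted losses
    weights = [1.0, 1.0, 0.0]         # third example masked out
    sum_over_batch_size = sum(losses) / len(losses)              # ~1.33
    sum_over_nonzero_weights = sum(losses) / sum(
        1 for w in weights if w > 0)                             # 2.0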
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index 537d94b797..3c0456dc2f 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -33,6 +33,7 @@
@@shard
@@batch_parallel
@@rewrite
+@@outside_compilation
@@CrossShardOptimizer
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 08e0465b71..d8c3872363 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -258,6 +258,8 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
def set_weights(self, weights):
+ # TODO(power): Figure out whether we really need this given there is no
+ # caller for this API yet.
self._opt.set_weights(weights)
def get_weights(self):
@@ -282,9 +284,9 @@ def _valid_name(tensor_name):
def _replicated_optimizer(opt):
"""Wrap the optimizer `opt` with CrossShardOptimizer if applicable."""
- if tpu_function.get_tpu_context().number_of_shards == 1:
- return opt
-
+ # Always wrap `opt` with CrossShardOptimizer, even if we are running on a
+ # single core. This ensures Keras properly tracks and initializes optimizer
+ # variables.
if isinstance(opt, keras_optimizers.TFOptimizer):
return tpu_optimizer.CrossShardOptimizer(opt.optimizer)
else:
@@ -1420,7 +1422,7 @@ class KerasTPUModel(models.Model):
y,
sample_weights,
batch_size)
- self._pipeline_fit_loop(
+ return self._pipeline_fit_loop(
x,
y,
sample_weights=sample_weights,
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 1e21cc5252..c1f90c3963 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -652,13 +652,28 @@ def split_compile_and_replicate(computation,
# TODO(phawkins): consider removing this code. It will
# be less confusing to clients if they knowingly choose to use resource
# variables.
+ # Partitioned variables is not supported (b/112311320).
+ def custom_getter(getter, name, *args, **kwargs):
+ partitioner = kwargs["partitioner"]
+ if partitioner is None:
+ return getter(name, *args, **kwargs)
+ else:
+ raise ValueError(
+ "Partitioned variables are not supported on TPU. Got "
+ "`partitioner` that is {}.".format(partitioner))
+
vscope = variable_scope.get_variable_scope()
+
saved_use_resource = vscope.use_resource
+ saved_custom_getter = vscope.custom_getter
+
vscope.set_use_resource(True)
+ vscope.set_custom_getter(custom_getter)
outputs = computation(*computation_inputs)
vscope.set_use_resource(saved_use_resource)
+ vscope.set_custom_getter(saved_custom_getter)
# If the computation returns `None`, make it an empty tuple.
if outputs is None:
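Outside of the TPU rewrite, the same guard pattern can be sketched with a plain variable scope (illustrative only; names are made up):

    import tensorflow as tf

    def no_partitioner_getter(getter, name, *args, **kwargs):
        # Mirrors the guard above: refuse any partitioned variable.
        if kwargs.get('partitioner') is not None:
            raise ValueError('Partitioned variables are not supported here.')
        return getter(name, *args, **kwargs)

    with tf.variable_scope('scope', custom_getter=no_partitioner_getter):
        v = tf.get_variable('v', shape=[4])   # fine
        # tf.get_variable('w', shape=[4],
        #                 partitioner=tf.fixed_size_partitioner(2))  # raises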
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
index ad3dce1784..d4951b156c 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -63,7 +63,7 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
}
CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0);
RdmaChannel* rc = rdma_mgr_->FindChannel(src_name);
- string key(std::move(parsed.FullKey().ToString()));
+ string key(parsed.FullKey());
string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
Device* dst_dev;
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index c06fea130f..79ad3b8e54 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -702,6 +702,21 @@ cc_library(
)
cc_library(
+ name = "feature_util",
+ srcs = ["example/feature_util.cc"],
+ hdrs = [
+ "example/feature_util.h",
+ "platform/types.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":core_stringpiece",
+ ":platform_protobuf",
+ ":protos_all_cc",
+ ],
+)
+
+cc_library(
name = "abi",
srcs = ["platform/abi.cc"],
hdrs = ["platform/abi.h"],
@@ -1339,6 +1354,7 @@ cc_library(
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
"//tensorflow/core/kernels:mkl_softmax_op",
+ "//tensorflow/core/kernels:mkl_transpose_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
"//tensorflow/core/kernels:mkl_aggregate_ops",
]) + if_cuda([
@@ -3712,6 +3728,7 @@ tf_cc_test_mkl(
":core_cpu_internal",
":framework",
":framework_internal",
+ ":lib",
":test",
":test_main",
":testlib",
diff --git a/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt
new file mode 100644
index 0000000000..27bc4013c3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ParallelInterleaveDatasetV2"
+ visibility: HIDDEN
+ attr {
+ name: "f"
+ description: <<END
+A function mapping elements of `input_dataset`, concatenated with
+`other_arguments`, to a Dataset variant that contains elements matching
+`output_types` and `output_shapes`.
+END
+ }
+ summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
index 8cef243aee..30fd97a0d7 100644
--- a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
@@ -9,7 +9,7 @@ END
in_arg {
name: "pattern"
description: <<END
-A 1-D string tensor of the regular expression to match the input.
+A scalar string tensor containing the regular expression to match the input.
END
}
out_arg {
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
index 35f55fe106..d33a36ce06 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
index 70a07d9b4c..afdc39da96 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
index b2e3eece38..026b5b3991 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
index 7bac02e23d..a168eed87f 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
index a73306a892..876b860824 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
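
The sorted-segment contract these descriptions share is easy to pin down with a tiny worked example. The following is an illustrative C++ sketch of the documented semantics, not the TensorFlow kernel; the helper name and plain std::vector types are inventions for the sketch:

    // SegmentSum semantics: segment_ids is 1-D with one sorted id per row of
    // data, and output[i] is the sum of the rows labelled i.
    #include <iostream>
    #include <vector>

    std::vector<int> SegmentSumSketch(const std::vector<int>& data,
                                      const std::vector<int>& segment_ids) {
      // Ids are sorted, so the last id determines the number of segments.
      std::vector<int> out(segment_ids.back() + 1, 0);
      for (size_t j = 0; j < data.size(); ++j) out[segment_ids[j]] += data[j];
      return out;
    }

    int main() {
      // data = [5, 1, 7, 2, 3], segment_ids = [0, 0, 1, 2, 2] -> [6, 7, 5]
      for (int v : SegmentSumSketch({5, 1, 7, 2, 3}, {0, 0, 1, 2, 2}))
        std::cout << v << " ";
    }
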
diff --git a/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt
new file mode 100644
index 0000000000..6d9d9908ca
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "StaticRegexFullMatch"
+ in_arg {
+ name: "input"
+ description: <<END
+A string tensor of the text to be processed.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A bool tensor with the same shape as `input`.
+END
+ }
+ attr {
+ name: "pattern"
+ description: "The regular expression to match the input."
+ }
+ summary: "Check if the input matches the regex pattern."
+ description: <<END
+The input is a string tensor of any shape. The pattern is the
+regular expression to be matched with every element of the input tensor.
+The boolean values (True or False) of the output tensor indicate
+if the input matches the regex pattern provided.
+
+The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax).
+END
+ visibility: HIDDEN
+}
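
Since the description defers to the re2 syntax, here is a minimal standalone sketch of the elementwise contract using RE2 directly rather than the op; the sample pattern and inputs are made up:

    // One pattern, compiled once, applied to every input element.
    #include <iostream>
    #include <string>
    #include "re2/re2.h"

    int main() {
      const RE2 pattern("[a-z]+@[a-z]+\\.com");
      for (const std::string& s : {"a@b.com", "nope"})
        std::cout << s << " -> " << RE2::FullMatch(s, pattern) << "\n";
    }
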
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
index 907c6d2022..7a60e4387a 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
@@ -3,15 +3,15 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
-END
+A tensor whose shape is a prefix of `data.shape`.
+END
}
out_arg {
name: "output"
description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has the same shape as `data`, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension of size
+`num_segments`.
END
}
summary: "Computes the maximum along segments of a tensor."
@@ -24,13 +24,16 @@ This operator is similar to the unsorted segment sum operator found
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
Instead of computing the sum over segments, it computes the maximum such that:
-\\(output_i = \max_j data_j\\) where max is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
+that `segment_ids[j...] == i`.
If the maximum is empty for a given segment ID `i`, it outputs the smallest
possible value for the specific numeric type,
`output[i] = numeric_limits<T>::lowest()`.
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
+
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
</div>
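
A compact sketch of the semantics documented above for the common 1-D `segment_ids` case (plain C++, not the kernel; names are illustrative), covering both the empty-segment default and the negative-id drop rule:

    #include <algorithm>
    #include <iostream>
    #include <limits>
    #include <vector>

    std::vector<float> UnsortedSegmentMaxSketch(
        const std::vector<float>& data, const std::vector<int>& segment_ids,
        int num_segments) {
      // Empty segments report the smallest representable value.
      std::vector<float> out(num_segments,
                             std::numeric_limits<float>::lowest());
      for (size_t j = 0; j < data.size(); ++j) {
        const int i = segment_ids[j];
        if (i < 0) continue;  // negative ids are dropped from the result
        out[i] = std::max(out[i], data[j]);
      }
      return out;
    }

    int main() {
      // data = [1, 3, 2, 5], segment_ids = [0, 0, 1, -1] -> [3, 2]
      for (float v : UnsortedSegmentMaxSketch({1, 3, 2, 5}, {0, 0, 1, -1}, 2))
        std::cout << v << " ";
    }
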
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
index 37dd973b23..7e139ddf4d 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
@@ -3,15 +3,15 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
+A tensor whose shape is a prefix of `data.shape`.
END
}
out_arg {
name: "output"
description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has the same shape as `data`, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension of size
+`num_segments`.
END
}
summary: "Computes the minimum along segments of a tensor."
@@ -24,11 +24,14 @@ This operator is similar to the unsorted segment sum operator found
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
Instead of computing the sum over segments, it computes the minimum such that:
-\\(output_i = \min_j data_j\\) where min is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \min_{j...} data[j...]\\) where min is over tuples `j...` such
+that `segment_ids[j...] == i`.
If the minimum is empty for a given segment ID `i`, it outputs the largest
possible value for the specific numeric type,
`output[i] = numeric_limits<T>::max()`.
+
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
index efbc023705..9c8ea3b620 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
@@ -3,15 +3,15 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
+A tensor whose shape is a prefix of `data.shape`.
END
}
out_arg {
name: "output"
description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has the same shape as `data`, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension of size
+`num_segments`.
END
}
summary: "Computes the product along segments of a tensor."
@@ -25,9 +25,12 @@ This operator is similar to the unsorted segment sum operator found
Instead of computing the sum over segments, it computes the product of all
entries belonging to a segment such that:
-\\(output_i = \prod_j data_j\\) where the product is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples
+`j...` such that `segment_ids[j...] == i`.
If there is no entry for a given segment ID `i`, it outputs 1.
+
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
index a8874950eb..7e5d9265c2 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
@@ -21,7 +21,7 @@ Read
for an explanation of segments.
Computes a tensor such that
-\\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
+\\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`
need not be sorted and need not cover all values in the full
range of valid values.
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 46bb8d92f8..1c9b69721d 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -615,11 +615,14 @@ void PruneFunctionBody(Graph* g) {
std::unordered_set<const Node*> nodes;
for (auto n : g->nodes()) {
// NOTE(mrry): "_Retval" nodes are stateful, and so will be added
- // to the seed set of `nodes`.
+ // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we
+ // specifically exclude them as seeds, to avoid unconditionally executing
+ // unused argument nodes (e.g. in a function like `lambda x, y: y`).
// TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
// still needed. It would be preferable to prune entire loops and/or
// conditionals if they are not used in the graph.
- if (n->IsControlFlow() || n->op_def().is_stateful()) {
+ if (n->IsControlFlow() ||
+ (n->op_def().is_stateful() && n->type_string() != kArgOp)) {
nodes.insert(n);
}
}
@@ -925,29 +928,18 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
}
DCHECK(run_opts.runner != nullptr);
- Executor::Args* exec_args = new Executor::Args;
+ Executor::Args exec_args;
// Inherit the step_id from the caller.
- exec_args->step_id = run_opts.step_id;
- exec_args->rendezvous = run_opts.rendezvous;
- exec_args->stats_collector = run_opts.stats_collector;
- exec_args->cancellation_manager = run_opts.cancellation_manager;
- exec_args->collective_executor = run_opts.collective_executor;
- exec_args->step_container = run_opts.step_container;
- exec_args->runner = *run_opts.runner;
- exec_args->call_frame = frame;
-
- item->exec->RunAsync(
- // Executor args
- *exec_args,
- // Done callback.
- std::bind(
- [item, frame, exec_args](DoneCallback done,
- // Start unbound arguments.
- const Status& status) {
- delete exec_args;
- done(status);
- },
- std::move(done), std::placeholders::_1));
+ exec_args.step_id = run_opts.step_id;
+ exec_args.rendezvous = run_opts.rendezvous;
+ exec_args.stats_collector = run_opts.stats_collector;
+ exec_args.cancellation_manager = run_opts.cancellation_manager;
+ exec_args.collective_executor = run_opts.collective_executor;
+ exec_args.step_container = run_opts.step_container;
+ exec_args.runner = *run_opts.runner;
+ exec_args.call_frame = frame;
+
+ item->exec->RunAsync(exec_args, std::move(done));
}
bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
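
The rewrite above is safe because the executor consumes the argument struct synchronously, so a stack-allocated `Executor::Args` outlives its use and the heap allocation plus delete-inside-the-done-callback wrapper becomes unnecessary. A reduced sketch of the pattern with hypothetical stand-in types, not the TensorFlow API:

    #include <functional>
    #include <iostream>
    #include <utility>

    struct Args { int step_id = 0; };

    void RunAsync(const Args& args, std::function<void()> done) {
      std::cout << "running step " << args.step_id << "\n";  // reads args now
      done();  // the caller's stack-allocated Args is still alive here
    }

    void Run(std::function<void()> done) {
      Args exec_args;          // no `new`, so no custom deleting callback
      exec_args.step_id = 42;
      RunAsync(exec_args, std::move(done));
    }

    int main() { Run([] { std::cout << "done\n"; }); }
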
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 120f480198..7bab9be9a6 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -802,9 +802,9 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
// Name
"SquareAndAddOneWithStatefulNodes",
// Args
- {"x: int32"},
+ {"x: int32", "y: float32"},
// Return values
- {"y: int32"},
+ {"z: int32"},
// Attrs
{},
// Nodes
@@ -822,12 +822,13 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
"RandomUniform",
{"shape"},
{{"T", T}, {"dtype", DT_FLOAT}}},
- // y = Add<T>(a, o)
- {{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
+ // z = Add<T>(a, o)
+ {{"z"}, "Add", {"a", "o"}, {{"T", T}}}});
Init({stateful_func});
auto x = test::AsTensor<int32>({1, 2, 3, 4});
- Tensor y;
+ auto y = test::AsTensor<float>({1.0, 2.0, 3.0, 4.0});
+ Tensor z;
FunctionLibraryRuntime::Handle handle;
TF_CHECK_OK(
@@ -837,18 +838,19 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
StepStatsCollector stats_collector(&stats);
FunctionLibraryRuntime::Options opts;
opts.stats_collector = &stats_collector;
- TF_CHECK_OK(Run(flr0_, handle, opts, {x}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle, opts, {x, y}, {&z}));
TF_CHECK_OK(flr0_->ReleaseHandle(handle));
TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {},
- {x}, {&y}));
- test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({2, 5, 10, 17}));
+ {x, y}, {&z}));
+ test::ExpectTensorEqual<int>(z, test::AsTensor<int32>({2, 5, 10, 17}));
stats_collector.FinalizeAndSwap(&stats);
- // Note that we do not expect the nodes named "x1", "x2", or "x3" to execute.
+ // Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to
+ // execute.
std::set<string> expected_node_names(
- {"_SOURCE", "shape", "x", "o", "a", "keep_me", "y", "y_RetVal"});
+ {"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"});
std::set<string> executed_node_names;
for (const auto& node_stats : stats.dev_stats()[0].node_stats()) {
executed_node_names.insert(node_stats.node_name());
diff --git a/tensorflow/core/common_runtime/tracing_device.h b/tensorflow/core/common_runtime/tracing_device.h
index 39215efa35..e1b163074f 100644
--- a/tensorflow/core/common_runtime/tracing_device.h
+++ b/tensorflow/core/common_runtime/tracing_device.h
@@ -35,8 +35,11 @@ class TracingDevice : public Device {
: Device(env, attributes) {}
void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
+ const tracing::TraceCollector* trace_collector =
+ tracing::GetTraceCollector();
if (TF_PREDICT_FALSE(
- tracing::GetTraceCollector() ||
+ (trace_collector &&
+ trace_collector->IsEnabled(op_kernel->IsExpensive())) ||
tracing::GetEventCollector(tracing::EventCategory::kCompute))) {
const string& op_name = op_kernel->name();
tracing::ScopedActivity activity(op_name, op_kernel->type_string(),
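
The change caches the collector pointer and gates tracing on `IsEnabled(op_kernel->IsExpensive())`, so cheap ops can skip the tracing setup entirely. A stripped-down sketch of that gating; the `Collector` type and its expensive-only policy are assumptions for illustration, not the real TraceCollector interface:

    #include <cstdio>

    struct Collector {
      bool IsEnabled(bool is_expensive) const { return is_expensive; }
    };

    const Collector* GetCollector() {  // may be nullptr when tracing is off
      static Collector collector;
      return &collector;
    }

    void Compute(const char* op_name, bool is_expensive) {
      const Collector* c = GetCollector();  // fetched once, not per check
      if (c && c->IsEnabled(is_expensive)) {
        std::printf("tracing %s\n", op_name);
      }
    }

    int main() {
      Compute("matmul", true);     // traced
      Compute("identity", false);  // skipped by the cost gate
    }
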
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 38863db1cc..6994dec3b5 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -693,6 +693,7 @@ uint64 DebugFileIO::diskBytesUsed = 0;
mutex DebugFileIO::bytes_mu(LINKER_INITIALIZED);
bool DebugFileIO::requestDiskByteUsage(uint64 bytes) {
+ mutex_lock l(bytes_mu);
if (globalDiskBytesLimit == 0) {
const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT");
if (env_tfdbg_disk_bytes_limit == nullptr ||
@@ -707,7 +708,6 @@ bool DebugFileIO::requestDiskByteUsage(uint64 bytes) {
if (bytes == 0) {
return true;
}
- mutex_lock l(bytes_mu);
if (diskBytesUsed + bytes < globalDiskBytesLimit) {
diskBytesUsed += bytes;
return true;
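
Hoisting the `mutex_lock` matters because the lazy initialization of `globalDiskBytesLimit` previously ran outside the lock, so two threads could both observe the limit as 0 and race to initialize it. A reduced sketch of the corrected shape, using plain std::mutex stand-ins rather than the TF types:

    #include <mutex>

    std::mutex mu;
    unsigned long long global_limit = 0;  // lazily initialized under mu
    unsigned long long used = 0;          // also guarded by mu

    bool RequestBytes(unsigned long long bytes) {
      std::lock_guard<std::mutex> l(mu);  // taken before the lazy init
      if (global_limit == 0) {
        global_limit = 100ULL << 20;      // assumed default cap for the sketch
      }
      if (bytes == 0) return true;
      if (used + bytes < global_limit) {
        used += bytes;
        return true;
      }
      return false;
    }
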
diff --git a/tensorflow/core/framework/dataset_stateful_op_whitelist.h b/tensorflow/core/framework/dataset_stateful_op_whitelist.h
index 21c21723d0..74bd39cb61 100644
--- a/tensorflow/core/framework/dataset_stateful_op_whitelist.h
+++ b/tensorflow/core/framework/dataset_stateful_op_whitelist.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
+#include <unordered_set>
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@@ -24,27 +25,26 @@ namespace data {
// See below macro for usage details.
class WhitelistedStatefulOpRegistry {
public:
- Status Add(StringPiece op_name) {
- op_names_.insert(op_name);
+ Status Add(string op_name) {
+ op_names_.insert(std::move(op_name));
return Status::OK();
}
- bool Contains(StringPiece op_name) {
- return op_names_.find(op_name) != op_names_.end();
- }
+ bool Contains(const string& op_name) { return op_names_.count(op_name); }
static WhitelistedStatefulOpRegistry* Global() {
- static WhitelistedStatefulOpRegistry* reg =
- new WhitelistedStatefulOpRegistry;
+ static auto* reg = new WhitelistedStatefulOpRegistry;
return reg;
}
private:
- WhitelistedStatefulOpRegistry() {}
- WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy);
+ WhitelistedStatefulOpRegistry() = default;
+ WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy) =
+ delete;
WhitelistedStatefulOpRegistry operator=(
- WhitelistedStatefulOpRegistry const& copy);
- std::set<StringPiece> op_names_;
+ WhitelistedStatefulOpRegistry const& copy) = delete;
+
+ std::unordered_set<string> op_names_;
};
} // namespace data
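
The rewritten registry is the usual leaky function-local singleton, now keyed by owned strings so callers no longer have to keep the name's backing storage alive (the old `std::set<StringPiece>` silently depended on that). A self-contained sketch of the same pattern with generic names, not the TF class:

    #include <string>
    #include <unordered_set>
    #include <utility>

    class NameRegistry {
     public:
      static NameRegistry* Global() {
        static auto* reg = new NameRegistry;  // never deleted; safe at exit
        return reg;
      }
      void Add(std::string name) { names_.insert(std::move(name)); }
      bool Contains(const std::string& name) const {
        return names_.count(name) > 0;
      }

     private:
      NameRegistry() = default;
      NameRegistry(const NameRegistry&) = delete;
      NameRegistry& operator=(const NameRegistry&) = delete;

      std::unordered_set<std::string> names_;  // owns its keys
    };
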
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 6710ff9df3..d24e7e8ee4 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -429,18 +429,22 @@ class SymbolicShapeRefiner {
// perform shape inference on the function body.
//
// Propagate shape information of final function body node
- // to function node `node`.
+ // to function node `function_node`.
//
- // In the event of an error, UpdateNode will simply set `node`'s
+ // In the event of an error, UpdateNode will simply set `function_node`'s
// output shape to be Unknown.
- Status UpdateFunction(const NodeDef* node) {
- auto it = fun_to_grappler_function_item_.find(node->op());
+ Status UpdateFunction(const NodeDef* function_node) {
+ auto it = fun_to_grappler_function_item_.find(function_node->op());
if (it == fun_to_grappler_function_item_.end()) {
return errors::InvalidArgument(
- node->op(), " was not previously added to SymbolicShapeRefiner.");
+ function_node->op(),
+ " was not previously added to SymbolicShapeRefiner.");
}
- GrapplerFunctionItem& grappler_function_item = it->second;
+ // Copy (not reference) so that changes we make here (e.g., replacing
+ // Placeholder with Const) don't affect one in
+ // fun_to_grappler_function_item_.
+ GrapplerFunctionItem grappler_function_item = it->second;
GraphView gv(&grappler_function_item.graph);
// Forward shapes from function input nodes to argument nodes.
@@ -453,7 +457,7 @@ class SymbolicShapeRefiner {
"supported.");
}
NodeDef* fun_node = gv.GetNode(fun_input.input_name);
- const string& input = node->input(i);
+ const string& input = function_node->input(i);
const string& node_name = NodeName(input);
if (IsControlInput(input)) {
@@ -478,16 +482,35 @@ class SymbolicShapeRefiner {
TensorShapeProto proto;
const auto& handle = input_inference_context->output(output_port_num);
input_inference_context->ShapeHandleToProto(handle, &proto);
+ // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
+ for (int i = 0; i < proto.dim_size(); i++) {
+ if (proto.dim(i).size() < -1) {
+ proto.mutable_dim(i)->set_size(-1);
+ }
+ }
*attr_output_shape.mutable_shape() = proto;
(*fun_node->mutable_attr())["shape"] = attr_output_shape;
}
+ // Replace input Placeholders with Consts, if values are known. Note that
+ // we don't check for errors here, as that is done in the loop above.
+ for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
+ const string& input = function_node->input(i);
+ const string& node_name = NodeName(input);
+ NodeDef* input_node = graph_.GetNode(node_name);
+ // TODO(dyoon): also use Const when output_tensors_as_shape is available.
+ if (IsConstant(*input_node)) {
+ TF_CHECK_OK(
+ ReplaceInputWithConst(*input_node, i, &grappler_function_item));
+ }
+ }
+
// Perform inference on function body.
GraphProperties gp(grappler_function_item);
TF_RETURN_IF_ERROR(gp.InferStatically(true));
// Add return nodes for output shapes.
- auto ic = GetContext(node);
+ auto ic = GetContext(function_node);
int output = 0;
for (auto const& out_arg : grappler_function_item.outputs()) {
if (out_arg.output_tensors.size() > 1) {
@@ -505,8 +528,9 @@ class SymbolicShapeRefiner {
const NodeDef* retnode = gv.GetNode(node_name);
if (retnode == nullptr) {
- return errors::FailedPrecondition("Unable to find return node ",
- node_name, " for ", node->name());
+ return errors::FailedPrecondition(
+ "Unable to find return function_node ", node_name, " for ",
+ function_node->name());
}
auto output_properties = gp.GetOutputProperties(retnode->name());
@@ -671,11 +695,13 @@ class SymbolicShapeRefiner {
// true, as the updates to the call node will have changed, even if it's
// the same function being called twice with the same input shapes.
// Example: simple_function.pbtxt
- if (UpdateFunction(node).ok()) {
+ auto s = UpdateFunction(node);
+ if (s.ok()) {
return Status::OK();
} else {
VLOG(1) << "UpdateFunction failed for " << node->op()
- << ". Defaulting to ShapeUnknown.";
+ << ". Defaulting to ShapeUnknown.\n"
+ << s.ToString();
}
}
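
The switch from `GrapplerFunctionItem&` to a by-value copy is the load-bearing part of this hunk: the item comes out of `fun_to_grappler_function_item_`, and the Placeholder-to-Const replacement would otherwise mutate the cached entry seen by every later caller. A minimal sketch of the pitfall with toy types:

    #include <iostream>
    #include <map>
    #include <string>

    std::map<std::string, std::string> cache = {{"f", "Placeholder"}};

    void SpecializeByReference() {
      std::string& item = cache["f"];  // reference into the cache
      item = "Const";                  // BUG: later lookups now see "Const"
    }

    void SpecializeByCopy() {
      std::string item = cache["f"];   // private copy; the cache is untouched
      item = "Const";                  // safe to mutate
    }

    int main() {
      SpecializeByCopy();
      std::cout << cache["f"] << "\n";  // still "Placeholder"
    }
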
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 8938b7c32e..3ec68a4e59 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -785,7 +785,58 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
EXPECT_EQ("float: [128,256]", PropToString(prop));
}
-TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) {
+TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
+ FunctionDefLibrary library;
+ // This function is simply out = Fill(shape, value). Fill requires actual
+ // values in its shape input, not just the shape of it, to infer the
+ // output shape; hence this function is used to exercise shape inference
+ // with a const input.
+ *library.add_function() = FunctionDefHelper::Create(
+ // Name
+ "MyFillFunc",
+ // Inputs
+ {"shape: int32", "value: float"},
+ // Outputs
+ {"out: float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"a"},
+ "Fill",
+ {"shape", "value"},
+ {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
+ },
+ // Returns
+ {{"out", "a:output:0"}});
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
+ Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4});
+ Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
+ auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
+ s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+ auto _value = tensorflow::ops::AsNodeOut(s, value);
+ TF_CHECK_OK(
+ builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyFillFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
+ EXPECT_FALSE(out_prop0.shape().unknown_rank());
+ EXPECT_EQ(4, out_prop0.shape().dim_size());
+ EXPECT_EQ(1, out_prop0.shape().dim(0).size());
+ EXPECT_EQ(2, out_prop0.shape().dim(1).size());
+ EXPECT_EQ(3, out_prop0.shape().dim(2).size());
+ EXPECT_EQ(4, out_prop0.shape().dim(3).size());
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
// Create a graph with a function that takes a scalar value, so that we use
// a Placeholder with a scalar shape as input to the function shape inference.
// Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of
@@ -818,7 +869,7 @@ TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) {
// MyFunc output shouldn't be unknown rank.
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically(false));
+ TF_CHECK_OK(properties.InferStatically(true));
const auto out_props = properties.GetOutputProperties("MyFunc");
const OpInfo::TensorProperties out_prop0 = out_props[0];
EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 65947ddce5..11ce121cba 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1121,11 +1121,8 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
NodeDef* tail = node;
- // TODO(rmlarsen): Enable after debugging breakage in Bayesflow.
- if (ctx().opt_level == RewriterConfig::AGGRESSIVE) {
- tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
- *ctx().nodes_to_preserve);
- }
+ tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
+ *ctx().nodes_to_preserve);
NodeDef* first_transpose;
TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 5fd34efeb1..a5fd33d28b 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -156,7 +156,7 @@ Status MetaOptimizer::InitializeOptimizers(
optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
}
- return Status::OK();
+ return InitializeCustomGraphOptimizers(optimizers);
}
Status MetaOptimizer::InitializeOptimizersByName(
@@ -180,6 +180,11 @@ Status MetaOptimizer::InitializeOptimizersByName(
VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
}
}
+ return InitializeCustomGraphOptimizers(optimizers);
+}
+
+Status MetaOptimizer::InitializeCustomGraphOptimizers(
+ std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
for (const auto& optimizer_config : cfg_.custom_optimizers()) {
auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
optimizer_config.name());
@@ -208,7 +213,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
}
std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
- if (cfg_.optimizers().empty() && cfg_.custom_optimizers().empty()) {
+ if (cfg_.optimizers().empty()) {
TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
} else {
TF_RETURN_IF_ERROR(InitializeOptimizersByName(&optimizers));
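
After this change, custom graph optimizers from the config are appended in both initialization paths, and the relaxed `cfg_.optimizers().empty()` test means configuring only custom optimizers no longer suppresses the default set. A schematic of the resulting dispatch with toy types, not the grappler API:

    #include <iostream>
    #include <string>
    #include <vector>

    using Optimizers = std::vector<std::string>;

    void InitCustom(const Optimizers& custom, Optimizers* out) {
      out->insert(out->end(), custom.begin(), custom.end());
    }

    Optimizers Init(const Optimizers& named, const Optimizers& custom) {
      Optimizers opts;
      if (named.empty()) {
        opts.push_back("default-suite");  // stands in for the built-in set
      } else {
        opts = named;                     // explicitly named optimizers
      }
      InitCustom(custom, &opts);          // now appended in both branches
      return opts;
    }

    int main() {
      for (const auto& o : Init({}, {"TestGraphOptimizer"}))
        std::cout << o << "\n";  // default-suite, then the custom optimizer
    }
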
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 151a54cbdf..831c5e37c0 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -52,6 +52,9 @@ class MetaOptimizer : public GraphOptimizer {
// Initialize active optimizers from RewriterConfig optimizer names.
Status InitializeOptimizersByName(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
+ // Initialize active optimizers from RewriterConfig.custom_optimizers.
+ Status InitializeCustomGraphOptimizers(
+ std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
// Run optimization pass over a single GrapplerItem. Meta optimizer might run
// multiple such passes: 1) for the main graph 2) for the function library
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 9a03c7dfef..e74e0f7501 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -64,6 +64,13 @@ bool TestOptimizer::optimized_;
REGISTER_GRAPH_OPTIMIZER(TestOptimizer);
+class TestGraphOptimizer : public TestOptimizer {
+ public:
+ string name() const override { return "test_graph_optimizer"; }
+};
+
+REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer);
+
class MetaOptimizerTest : public GrapplerTest {};
TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
@@ -83,6 +90,27 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
EXPECT_TRUE(TestOptimizer::IsOptimized());
}
+TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ TestOptimizer::SetOptimized(false);
+ TestGraphOptimizer::SetOptimized(false);
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("TestOptimizer");
+ auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
+ customGraphOptimizer->set_name("TestGraphOptimizer");
+ rewriter_config.set_min_graph_nodes(-1);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestOptimizer::IsOptimized());
+ EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
+}
+
TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
GrapplerItem item;
@@ -98,6 +126,24 @@ TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
TF_EXPECT_OK(status);
}
+TEST_F(MetaOptimizerTest, RunToggleOptimizersAndCustomGraphOptimizerTwice) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ RewriterConfig rewriter_config;
+ auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
+ customGraphOptimizer->set_name("TestGraphOptimizer");
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+ rewriter_config.set_min_graph_nodes(-1);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
+}
+
TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
using test::function::NDef;
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 25063ac823..972fb9efa9 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -643,14 +643,7 @@ cc_library(
":split_v_op",
":strided_slice_op",
":tile_ops",
- ] + if_mkl(
- [
- ":mkl_transpose_op",
- ],
- [
- ":transpose_op",
- ],
- ) + [
+ ":transpose_op",
":unique_op",
":unpack_op",
":unravel_index_op",
@@ -893,24 +886,13 @@ tf_kernel_library(
deps = ARRAY_DEPS,
)
-if_mkl(
- [tf_mkl_kernel_library(
- name = "mkl_transpose_op",
- srcs = [
- "mkl_transpose_op.cc",
- "transpose_op.cc",
- ],
- hdrs = ["transpose_op.h"],
- deps = ARRAY_DEPS + mkl_deps(),
- )],
- [tf_kernel_library(
- name = "transpose_op",
- srcs = [
- "transpose_op.cc",
- ],
- hdrs = ["transpose_op.h"],
- deps = ARRAY_DEPS,
- )],
+tf_kernel_library(
+ name = "transpose_op",
+ srcs = [
+ "transpose_op.cc",
+ ],
+ hdrs = ["transpose_op.h"],
+ deps = ARRAY_DEPS + if_mkl([":mkl_transpose_op"]),
)
tf_kernel_library(
@@ -6351,6 +6333,15 @@ tf_mkl_kernel_library(
deps = NN_DEPS + mkl_deps() + [":cwise_op"],
)
+tf_mkl_kernel_library(
+ name = "mkl_transpose_op",
+ srcs = [
+ "mkl_transpose_op.cc",
+ ],
+ hdrs = ["transpose_op.h"],
+ deps = ARRAY_DEPS + mkl_deps(),
+)
+
# NOTE(lespeholt): This rule is deprecated, please use:
# tensorflow/core/util/batch_util.h
cc_library(
diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h
index a7836896c7..390db8fe5a 100644
--- a/tensorflow/core/kernels/conditional_accumulator.h
+++ b/tensorflow/core/kernels/conditional_accumulator.h
@@ -51,9 +51,11 @@ class ConditionalAccumulator
// dtype: The datatype of the gradients to be accumulated.
// shape: The shape of the accumulated gradients.
// name: A name to use for the ConditionalAccumulator.
+ // reduction_type: The reduction type, i.e., MEAN or SUM.
ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape,
- const string& name)
- : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {}
+ const string& name, const string& reduction_type)
+ : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name,
+ reduction_type) {}
~ConditionalAccumulator() override{};
protected:
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.cc b/tensorflow/core/kernels/conditional_accumulator_base.cc
index 90593c56b8..292cf0cd64 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_base.cc
@@ -14,12 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/conditional_accumulator_base.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
ConditionalAccumulatorBase::ConditionalAccumulatorBase(
- const DataType& dtype, const PartialTensorShape& shape, const string& name)
- : dtype_(dtype), shape_(shape), name_(name) {
+ const DataType& dtype, const PartialTensorShape& shape, const string& name,
+ const string& reduction_type)
+ : dtype_(dtype),
+ shape_(shape),
+ name_(name),
+ reduction_type_(reduction_type) {
counter_ = 0;
current_global_step_ = 0;
}
@@ -190,7 +195,9 @@ bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx,
current_global_step_++;
// Average the accumulated gradient
- DivideAccumGradByCounter(ctx);
+ if (reduction_type_ == "MEAN") {
+ DivideAccumGradByCounter(ctx);
+ }
// Set output for accumulated gradient tensor
bool successful_set_output = SetOutput(ctx);
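
With `reduction_type` threaded through the accumulator, taking the gradient divides by the accumulation counter only for MEAN; SUM returns the raw accumulated value. A toy sketch of that branch as a free function, not the accumulator class:

    #include <iostream>
    #include <string>
    #include <vector>

    std::vector<float> TakeGradSketch(std::vector<float> accum, int counter,
                                      const std::string& reduction_type) {
      if (reduction_type == "MEAN") {
        for (float& v : accum) v /= counter;  // average over applied grads
      }
      return accum;  // "SUM": the accumulated value, unchanged
    }

    int main() {
      // Two gradients, {1, 3} and {3, 5}, were accumulated into {4, 8}.
      for (float v : TakeGradSketch({4.f, 8.f}, 2, "MEAN"))
        std::cout << v << " ";  // prints: 2 4
    }
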
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h
index b7b7482a00..4a5ec6f0fb 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base.h
@@ -52,7 +52,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
// name: A name to use for the ConditionalAccumulator.
ConditionalAccumulatorBase(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name);
+ const string& name, const string& reduction_type);
typedef AsyncOpKernel::DoneCallback DoneCallback;
@@ -125,6 +125,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
const DataType dtype_;
const PartialTensorShape shape_;
const string name_;
+ const string reduction_type_;
mutex mu_;
int counter_ GUARDED_BY(mu_);
int64 current_global_step_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h
index 012a0dcc12..ca24d690f8 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base_op.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h
@@ -51,6 +51,8 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
&accumulator_handle_, nullptr));
OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("reduction_type", &reduction_type_));
}
void Compute(OpKernelContext* ctx) override {
@@ -81,6 +83,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
DataType dtype_;
PartialTensorShape shape_;
ContainerInfo cinfo_;
+ string reduction_type_;
private:
Status SetAccumulatorHandle(OpKernelContext* ctx)
diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc
index e13bf8a4c6..52ac51a9b6 100644
--- a/tensorflow/core/kernels/conditional_accumulator_op.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_op.cc
@@ -34,7 +34,8 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
Creator GetCreator() const override {
return [this](ConditionalAccumulatorBase** ret) {
ConditionalAccumulator<Device, T>* accumulator =
- new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name());
+ new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(),
+ reduction_type_);
*ret = accumulator;
return Status::OK();
};
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 306486b96a..af301e2b42 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -28,9 +28,7 @@ namespace {
class MapDatasetOp : public UnaryDatasetOpKernel {
public:
- explicit MapDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
@@ -186,7 +184,6 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList func_;
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index 3c562fc7f3..b87d61ee44 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -18,7 +18,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/batch_util.h"
#include "tensorflow/core/util/reffed_status_callback.h"
@@ -60,26 +62,43 @@ class MapDefunOp : public AsyncOpKernel {
~MapDefunOp() override {}
+ Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) {
+ // Validates inputs and gets the size of their leading dimension.
+ *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ if (ctx->input(i).dims() == 0) {
+ return errors::InvalidArgument(
+ "All inputs must have rank at least 1. Input ", i,
+ " has a rank of 0.");
+ } else if (ctx->input(i).dim_size(0) != *batch_size) {
+ return errors::InvalidArgument(
+ "All inputs must have the same dimension 0. Input ", i,
+ " has leading dimension ", ctx->input(i).dim_size(0),
+ ", while all previous inputs have leading dimension ", batch_size);
+ }
+ }
+ return Status::OK();
+ }
+
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- int64 batch_size = ctx->input(0).dim_size(0);
+ int64 batch_size;
+ OP_REQUIRES_OK_ASYNC(ctx, GetInputBatchSize(ctx, &batch_size), done);
+
// Inputs
auto* args = new std::vector<Tensor>;
auto* arg_shapes = new std::vector<TensorShape>;
+
+ // Create a copy because every `Compute` may have different output shapes.
+ auto* output_shapes = new std::vector<PartialTensorShape>(output_shapes_);
arg_shapes->reserve(ctx->num_inputs());
args->reserve(ctx->num_inputs());
+ auto* mu = new mutex;
+
for (size_t i = 0; i < ctx->num_inputs(); ++i) {
args->push_back(ctx->input(i));
arg_shapes->push_back(ctx->input(i).shape());
arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension
- OP_REQUIRES_ASYNC(
- ctx, batch_size == ctx->input(i).dim_size(0),
- errors::InvalidArgument(
- "All inputs must have the same dimension 0. Input ", i,
- " has leading dimension ", ctx->input(i).dim_size(0),
- ", while all previous inputs have leading dimension ", batch_size,
- "."),
- done);
}
// Outputs
@@ -87,10 +106,14 @@ class MapDefunOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done);
for (size_t i = 0; i < output_types().size(); ++i) {
- Tensor* out = nullptr;
- TensorShape output_shape = output_shapes_.at(i);
- output_shape.InsertDim(0, batch_size);
- OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), done);
+ if (output_shapes_.at(i).IsFullyDefined()) {
+ Tensor* out = nullptr;
+ TensorShape output_shape;
+ output_shapes_.at(i).AsTensorShape(&output_shape);
+ output_shape.InsertDim(0, batch_size);
+ OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out),
+ done);
+ }
}
SetRunOptions(ctx, &opts_, false);
@@ -98,15 +121,19 @@ class MapDefunOp : public AsyncOpKernel {
// Run loop
StatusCallback callback = std::bind(
[](OpKernelContext* ctx, std::vector<Tensor>* args,
- std::vector<TensorShape>* arg_shapes, OpOutputList* output,
- DoneCallback& done, const Status& status) {
+ std::vector<TensorShape>* arg_shapes,
+ std::vector<PartialTensorShape>* output_shapes, OpOutputList* output,
+ mutex* mu, DoneCallback& done, const Status& status) {
delete args;
delete arg_shapes;
delete output;
+ delete output_shapes;
+ delete mu;
ctx->SetStatus(status);
done();
},
- ctx, args, arg_shapes, output, std::move(done), std::placeholders::_1);
+ ctx, args, arg_shapes, output_shapes, output, mu, std::move(done),
+ std::placeholders::_1);
auto* refcounted = new ReffedStatusCallback(std::move(callback));
@@ -114,9 +141,11 @@ class MapDefunOp : public AsyncOpKernel {
// Start from i = 1 because refcounted is initialized with refcount = 1
refcounted->Ref();
}
+
for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) {
- auto* call_frame =
- new MapFunctionCallFrame(*args, *arg_shapes, output, this, i);
+ auto* call_frame = new MapFunctionCallFrame(
+ *args, *arg_shapes, output_shapes, mu, output, this, i,
+ static_cast<size_t>(batch_size));
CancellationManager* c_mgr = new CancellationManager;
opts_.cancellation_manager = c_mgr;
ctx->function_library()->Run(
@@ -133,18 +162,23 @@ class MapDefunOp : public AsyncOpKernel {
private:
FunctionLibraryRuntime::Handle func_handle_;
FunctionLibraryRuntime::Options opts_;
- std::vector<TensorShape> output_shapes_;
+ std::vector<PartialTensorShape> output_shapes_;
class MapFunctionCallFrame : public CallFrameInterface {
public:
MapFunctionCallFrame(const std::vector<Tensor>& args,
const std::vector<TensorShape>& arg_shapes,
- OpOutputList* output, OpKernel* kernel, size_t iter)
+ std::vector<PartialTensorShape>* output_shapes,
+ mutex* output_shapes_mutex, OpOutputList* output,
+ OpKernel* kernel, size_t iter, size_t batch_size)
: args_(args),
arg_shapes_(arg_shapes),
+ output_shapes_(output_shapes),
+ output_shapes_mutex_(output_shapes_mutex),
output_(output),
kernel_(kernel),
- iter_(iter) {}
+ iter_(iter),
+ batch_size_(batch_size) {}
~MapFunctionCallFrame() override {}
@@ -182,15 +216,37 @@ class MapDefunOp : public AsyncOpKernel {
"output: ",
index);
}
+ { // Locking scope
+ mutex_lock l(*output_shapes_mutex_);
+ if (!output_shapes_->at(index).IsCompatibleWith(val.shape())) {
+ return errors::InvalidArgument(
+ "Mismatch in function retval shape, ", val.shape(),
+ ", and expected output shape,",
+ output_shapes_->at(index).DebugString(), ".");
+ }
+ if (!output_shapes_->at(index).IsFullyDefined()) {
+ // Given val, we have new information about the output shape at
+ // this index. Store the shape and allocate the output accordingly.
+ output_shapes_->at(index) = val.shape();
+
+ Tensor* out = nullptr;
+ TensorShape actual_shape = val.shape();
+ actual_shape.InsertDim(0, batch_size_);
+ TF_RETURN_IF_ERROR(output_->allocate(index, actual_shape, &out));
+ }
+ }
return batch_util::CopyElementToSlice(val, (*output_)[index], iter_);
}
private:
const std::vector<Tensor>& args_;
const std::vector<TensorShape>& arg_shapes_;
+ std::vector<PartialTensorShape>* output_shapes_;
+ mutex* output_shapes_mutex_;
OpOutputList* output_;
const OpKernel* kernel_;
const size_t iter_;
+ const size_t batch_size_;
};
};
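
The retval path above defers allocation for outputs whose declared shape is only partially defined: the first invocation that produces a fully defined value fixes the shape under the shared mutex and allocates the batched output, and later invocations just copy their slice in. A condensed sketch of the idea using std containers instead of Tensor/OpOutputList:

    #include <algorithm>
    #include <mutex>
    #include <vector>

    struct LazyBatchedOutput {
      std::mutex mu;
      size_t element_size = 0;    // 0 means "shape not yet known"
      std::vector<float> buffer;  // batch_size * element_size once allocated

      void SetRetval(size_t iter, const std::vector<float>& val,
                     size_t batch_size) {
        size_t es;
        {
          std::lock_guard<std::mutex> l(mu);
          if (element_size == 0) {      // first fully defined value wins
            element_size = val.size();  // fixes the per-element shape
            buffer.resize(batch_size * element_size);
          }
          es = element_size;
        }
        std::copy(val.begin(), val.end(), buffer.begin() + iter * es);
      }
    };
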
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index f8287cf0e3..640f1565b7 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <deque>
+#include <utility>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -21,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
@@ -34,8 +36,7 @@ namespace {
class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
@@ -125,6 +126,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector& output_dtypes() const override {
return output_types_;
}
+
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
@@ -1058,7 +1060,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList interleave_func_;
@@ -1067,6 +1068,593 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
+// The motivation for creating an alternative implementation of parallel
+// interleave is to decouple the degree of parallelism from the cycle length.
+// This makes it possible to change the degree of parallelism (e.g. through
+// auto-tuning) without changing the cycle length (which would change the order
+// in which elements are produced).
+//
+// Furthermore, this class favors modularity over extended functionality. In
+// particular, it refrains from implementing configurable buffering of output
+// elements and prefetching of input iterators, relying on other parts of
+// tf.data to provide this functionality if necessary.
+//
+// The above design choices were made with automated optimizations in mind,
+// isolating the degree of parallelism as the single tunable knob of this
+// implementation.
+class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
+ public:
+ explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ OpInputList inputs;
+ OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
+
+ int64 cycle_length = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument(ctx, "cycle_length", &cycle_length));
+ OP_REQUIRES(ctx, cycle_length > 0,
+ errors::InvalidArgument("`cycle_length` must be > 0"));
+
+ int64 block_length = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument(ctx, "block_length", &block_length));
+ OP_REQUIRES(ctx, block_length > 0,
+ errors::InvalidArgument("`block_length` must be > 0"));
+
+ int64 num_parallel_calls;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
+ &num_parallel_calls));
+ OP_REQUIRES(ctx, num_parallel_calls > 0,
+ errors::InvalidArgument(
+ "num_parallel_calls must be greater than zero."));
+ OP_REQUIRES(
+ ctx, num_parallel_calls <= cycle_length,
+ errors::InvalidArgument(
+ "num_parallel_calls must less than or equal to cycle_length."));
+
+ // TODO(b/114267189): Use `other_arguments(inputs.begin(), inputs.end());`.
+ std::vector<Tensor> other_arguments;
+ other_arguments.reserve(inputs.size());
+ for (const Tensor& t : inputs) {
+ other_arguments.push_back(t);
+ }
+ std::unique_ptr<CapturedFunction> captured_func;
+ OP_REQUIRES_OK(
+ ctx, CapturedFunction::Create(
+ interleave_func_, std::move(other_arguments), &captured_func));
+
+ *output = new Dataset(ctx, input, interleave_func_,
+ std::move(captured_func), cycle_length, block_length,
+ num_parallel_calls, output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
+ int64 block_length, int64 num_parallel_calls,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ interleave_func_(func),
+ captured_func_(std::move(captured_func)),
+ cycle_length_(cycle_length),
+ block_length_(block_length),
+ num_parallel_calls_(num_parallel_calls),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::ParallelInterleaveV2")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "ParallelInterleaveDatasetV2Op::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
+ Node* input_node;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
+ Node* cycle_length_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
+ Node* block_length_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
+ Node* num_parallel_calls_node;
+ TF_RETURN_IF_ERROR(
+ b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
+ DataTypeVector other_arguments_types;
+ other_arguments_types.reserve(captured_func_->captured_inputs().size());
+ std::vector<Node*> other_arguments;
+ other_arguments.reserve(captured_func_->captured_inputs().size());
+ for (const Tensor& t : captured_func_->captured_inputs()) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ other_arguments.emplace_back(node);
+ other_arguments_types.emplace_back(t.dtype());
+ }
+ AttrValue f;
+ b->BuildAttrValue(interleave_func_, &f);
+ AttrValue other_arguments_types_attr;
+ b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this,
+ {{0, input_node},
+ {2, cycle_length_node},
+ {3, block_length_node},
+ {4, num_parallel_calls_node}},
+ {{1, other_arguments}},
+ {{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ args_list_(params.dataset->cycle_length_),
+ current_elements_(params.dataset->cycle_length_),
+ element_in_use_(params.dataset->cycle_length_, false),
+ thread_pool_(new thread::ThreadPool(
+ Env::Default(), ThreadOptions(), "parallel_interleave",
+ dataset()->cycle_length_ /* num_threads */,
+ false /* low_latency_hint */)) {}
+
+ ~Iterator() override {
+ mutex_lock l(mu_);
+ // Cancel the runner thread.
+ cancelled_ = true;
+ cond_var_.notify_all();
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
+ }
+ }
+
+ Status Initialize(IteratorContext* ctx) override {
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ std::shared_ptr<InvocationResult> result;
+ do {
+ {
+ mutex_lock l(mu_);
+ EnsureRunnerThreadStarted(ctx);
+ while (invocation_results_.empty() &&
+ (!end_of_input_ || num_open_ > 0)) {
+ cond_var_.wait(l);
+ }
+ if (!invocation_results_.empty()) {
+ std::swap(result, invocation_results_.front());
+ invocation_results_.pop_front();
+ } else {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ }
+ cond_var_.notify_all();
+ result->notification.WaitForNotification();
+ } while (result->skip);
+
+ if (result->status.ok()) {
+ *out_tensors = std::move(result->return_values);
+ }
+ *end_of_sequence = false;
+ return result->status;
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
+ }
+ CHECK_EQ(num_calls_, 0);
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name("invocation_results.size"), invocation_results_.size()));
+ for (size_t i = 0; i < invocation_results_.size(); i++) {
+ std::shared_ptr<InvocationResult> result = invocation_results_[i];
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ result->return_values.size()));
+ for (size_t j = 0; j < result->return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(
+ strings::StrCat("invocation_results[", i, "][", j, "]")),
+ result->return_values[j]));
+ }
+ if (result->skip) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].skip")),
+ ""));
+ }
+ }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("cycle_index"), cycle_index_));
+ if (end_of_input_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("end_of_input"), ""));
+ }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("num_open"), num_open_));
+ TF_RETURN_IF_ERROR(WriteCurrentElements(writer));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ int64 invocation_results_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name("invocation_results.size"), &invocation_results_size));
+ for (size_t i = 0; i < invocation_results_size; i++) {
+ std::shared_ptr<InvocationResult> result(new InvocationResult());
+ invocation_results_.push_back(result);
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ size_t num_return_values;
+ {
+ int64 size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ &size));
+ num_return_values = static_cast<size_t>(size);
+ if (num_return_values != size) {
+ return errors::InvalidArgument(strings::StrCat(
+ full_name(
+ strings::StrCat("invocation_results[", i, "].size")),
+ ": ", size, " is not a valid value of type size_t."));
+ }
+ }
+ result->return_values.reserve(num_return_values);
+ for (size_t j = 0; j < num_return_values; j++) {
+ result->return_values.emplace_back();
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(full_name(strings::StrCat(
+ "invocation_results[", i, "][", j, "]")),
+ &result->return_values.back()));
+ }
+ result->skip = reader->Contains(
+ full_name(strings::StrCat("invocation_results[", i, "].skip")));
+ result->notification.Notify();
+ }
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("cycle_index"), &cycle_index_));
+ if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("num_open"), &num_open_));
+ TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
+ return Status::OK();
+ }
+
+ private:
+ struct InvocationResult {
+ Notification notification; // used for coordination with the consumer
+ Status status; // the invocation status
+ std::vector<Tensor> return_values; // the invocation result values
+ bool skip; // if set the result should be skipped
+ };
+
+ void EnsureRunnerThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!runner_thread_) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ runner_thread_.reset(ctx->env()->StartThread(
+ {}, "runner_thread",
+ [this, new_ctx]() { RunnerThread(new_ctx); }));
+ }
+ }
+
+ // Fetches up to `results.size()` outputs from the cycle element at
+ // position `cycle_index`.
+ //
+ // If end of input is encountered, the `skip` field of the invocation
+ // result is used to identify results that should be skipped.
+ void FetchOutputs(
+ const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
+ const std::vector<std::shared_ptr<InvocationResult>>& results)
+ LOCKS_EXCLUDED(mu_) {
+ bool end_of_input = false;
+ for (auto& result : results) {
+ if (!end_of_input) {
+ result->status = current_elements_[cycle_index]->GetNext(
+ ctx.get(), &result->return_values, &end_of_input);
+ }
+ if (end_of_input) {
+ result->skip = true;
+ }
+ result->notification.Notify();
+ if (!result->status.ok()) {
+ break;
+ }
+ }
+
+ // Release the ownership of the cycle element iterator, closing the
+ // iterator if end of input was encountered.
+ {
+ if (end_of_input) {
+ current_elements_[cycle_index].reset();
+ }
+ mutex_lock l(mu_);
+ element_in_use_[cycle_index] = false;
+ num_calls_--;
+ if (end_of_input) {
+ args_list_[cycle_index].clear();
+ num_open_--;
+ }
+ }
+ cond_var_.notify_all();
+ }
+
+ int64 MaxInvocationResults() {
+ return dataset()->cycle_length_ * dataset()->block_length_;
+ }
+
+ // Method responsible for 1) creating iterators out of input elements, 2)
+ // determining the order in which elements are fetched from the iterators,
+ // and 3) scheduling the fetching of the elements on a thread pool.
+ //
+ // This method runs in the `runner_thread` background thread.
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ while (true) {
+ {
+ mutex_lock l(mu_);
+ // Wait until this thread is cancelled, the end of input has been
+ // reached, or the cycle element at the `cycle_index_` position is
+ // not in use and there is space in the `invocation_results_` queue.
+ while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
+ (element_in_use_[cycle_index_] ||
+ num_calls_ >= dataset()->num_parallel_calls_ ||
+ invocation_results_.size() >= MaxInvocationResults())) {
+ cond_var_.wait(l);
+ }
+
+ if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
+ return;
+ }
+
+ while (!element_in_use_[cycle_index_] &&
+ (!end_of_input_ || num_open_ > 0) &&
+ num_calls_ < dataset()->num_parallel_calls_ &&
+ invocation_results_.size() < MaxInvocationResults()) {
+ if (!current_elements_[cycle_index_]) {
+ // Try to create a new iterator from the next input element.
+ Status status = input_impl_->GetNext(
+ ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ if (!end_of_input_) {
+ Status status = MakeIteratorFromInputElement(
+ ctx.get(), args_list_[cycle_index_], cycle_index_,
+ dataset()->captured_func_.get(), prefix(),
+ &current_elements_[cycle_index_]);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ ++num_open_;
+ }
+ }
+ if (current_elements_[cycle_index_]) {
+ // Pre-allocate invocation results for outputs to be fetched
+ // and then fetch the outputs asynchronously.
+ std::vector<std::shared_ptr<InvocationResult>> results;
+ results.reserve(dataset()->block_length_);
+ for (int i = 0; i < dataset()->block_length_; ++i) {
+ invocation_results_.emplace_back(new InvocationResult());
+ results.push_back(invocation_results_.back());
+ }
+ num_calls_++;
+ element_in_use_[cycle_index_] = true;
+ thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
+ ctx, cycle_index_,
+ std::move(results)));
+ }
+ cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
+ }
+ }
+ cond_var_.notify_all();
+ }
+ }
+
+ Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ CodeKey(index), static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+ status.error_message()));
+ }
+ return Status::OK();
+ }
+
+ Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
+ }
+ return Status::OK();
+ }
+
+ string CodeKey(size_t index) {
+ return full_name(
+ strings::StrCat("invocation_results[", index, "].code"));
+ }
+
+ string ErrorMessageKey(size_t index) {
+ return full_name(
+ strings::StrCat("invocation_results[", index, "].error_message"));
+ }
+
+ Status WriteCurrentElements(IteratorStateWriter* writer)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ for (int idx = 0; idx < current_elements_.size(); idx++) {
+ if (current_elements_[idx]) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("args_size[", idx, "]")),
+ args_list_[idx].size()));
+ for (int i = 0; i < args_list_[idx].size(); i++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+ args_list_[idx][i]));
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ Status ReadCurrentElements(IteratorContext* ctx,
+ IteratorStateReader* reader)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ for (int idx = 0; idx < current_elements_.size(); idx++) {
+ if (reader->Contains(
+ full_name(strings::StrCat("args_size[", idx, "]")))) {
+ int64 args_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("args_size[", idx, "]")),
+ &args_size));
+ args_list_[idx].resize(args_size);
+ for (int i = 0; i < args_size; i++) {
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+ &args_list_[idx][i]));
+ }
+ TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
+ ctx, args_list_[idx], idx, dataset()->captured_func_.get(),
+ prefix(), &current_elements_[idx]));
+ TF_RETURN_IF_ERROR(
+ RestoreInput(ctx, reader, current_elements_[idx]));
+ } else {
+ current_elements_[idx].reset();
+ }
+ }
+ return Status::OK();
+ }
+
+ // Used for coordination between the main thread, the runner thread, and
+ // the worker threads.
+ mutex mu_;
+
+ // Used for coordination between the main thread, the runner thread, and
+ // the worker threads. In particular, the runner thread should only
+ // schedule new calls when the number of in-flight calls is less than the
+ // user-specified level of parallelism, there are slots available in the
+ // `invocation_results_` buffer, the current cycle element is not in use,
+ // and there are elements left to be fetched.
+ condition_variable cond_var_;
+
+ // Iterator for input elements.
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+
+ // Identifies current cycle element.
+ int64 cycle_index_ = 0;
+
+ // Arguments for creating an iterator for cycle elements.
+ std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
+
+ // Iterators for the current cycle elements. Concurrent access is
+ // protected by `element_in_use_`.
+ std::vector<std::unique_ptr<IteratorBase>> current_elements_;
+
+ // Identifies cycle elements that are in use by worker threads.
+ std::vector<bool> element_in_use_ GUARDED_BY(mu_);
+
+ // Buffer for storing the invocation results.
+ std::deque<std::shared_ptr<InvocationResult>> invocation_results_
+ GUARDED_BY(mu_);
+
+ // Identifies whether end of input has been reached.
+ bool end_of_input_ GUARDED_BY(mu_) = false;
+
+ // Identifies the number of open iterators.
+ int64 num_open_ GUARDED_BY(mu_) = 0;
+
+ // Identifies the number of outstanding calls.
+ int64 num_calls_ GUARDED_BY(mu_) = 0;
+
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+
+ // Identifies whether background activity should be cancelled.
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ };
+
+ const DatasetBase* const input_;
+ const NameAttrList interleave_func_;
+ const std::unique_ptr<CapturedFunction> captured_func_;
+ const int64 cycle_length_;
+ const int64 block_length_;
+ const int64 num_parallel_calls_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ NameAttrList interleave_func_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
+ ParallelInterleaveDatasetV2Op);
+
} // namespace
} // namespace data
} // namespace tensorflow
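The runner thread above blocks on `cond_var_` until three things hold at once: the current cycle element is free, in-flight calls are below `num_parallel_calls_`, and the result queue is below `MaxInvocationResults()`. A minimal standalone sketch of that gating pattern follows; the class, names, and caps are illustrative only, not the dataset op's API.

#include <condition_variable>
#include <mutex>

// Sketch: a producer may only schedule work while the number of in-flight
// calls and buffered results stay under their caps, mirroring the wait
// predicate at the top of RunnerThread.
class Gate {
 public:
  Gate(int max_calls, int max_results)
      : max_calls_(max_calls), max_results_(max_results) {}

  // Runner side: block until scheduling is allowed; false once cancelled.
  bool WaitToSchedule() {
    std::unique_lock<std::mutex> l(mu_);
    cond_var_.wait(l, [this] {
      return cancelled_ ||
             (num_calls_ < max_calls_ && num_results_ < max_results_);
    });
    if (cancelled_) return false;
    ++num_calls_;
    ++num_results_;  // a result slot is reserved up front, as in the op
    return true;
  }

  // Worker side: a scheduled call finished producing its result.
  void CallDone() {
    { std::lock_guard<std::mutex> l(mu_); --num_calls_; }
    cond_var_.notify_all();
  }

  // Consumer side: a buffered result was handed to the caller.
  void ResultConsumed() {
    { std::lock_guard<std::mutex> l(mu_); --num_results_; }
    cond_var_.notify_all();
  }

  void Cancel() {
    { std::lock_guard<std::mutex> l(mu_); cancelled_ = true; }
    cond_var_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable cond_var_;
  const int max_calls_;
  const int max_results_;
  int num_calls_ = 0;
  int num_results_ = 0;
  bool cancelled_ = false;
};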
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index ac5ed286ee..a0cb179eb8 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -33,11 +33,12 @@ namespace {
class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ParallelMapDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+ &use_inter_op_parallelism_));
}
protected:
@@ -60,10 +61,12 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ func_, std::move(other_arguments),
+ use_inter_op_parallelism_, &captured_func));
*output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
- output_shapes_, std::move(captured_func));
+ output_shapes_, use_inter_op_parallelism_,
+ std::move(captured_func));
}
private:
@@ -73,6 +76,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const NameAttrList& func, int32 num_parallel_calls,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
+ bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction> captured_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
@@ -80,6 +84,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
num_parallel_calls_(num_parallel_calls),
output_types_(output_types),
output_shapes_(output_shapes),
+ use_inter_op_parallelism_(use_inter_op_parallelism),
captured_func_(std::move(captured_func)) {
input_->Ref();
}
@@ -92,12 +97,27 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return captured_func_->Instantiate(ctx);
};
- auto map_func = [this](IteratorContext* ctx,
- std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done));
- };
+ ParallelMapIteratorFunction map_func;
+ if (use_inter_op_parallelism_) {
+ map_func = [this](IteratorContext* ctx,
+ std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ captured_func_->RunAsync(ctx, std::move(input_element), result,
+ std::move(done));
+ };
+ } else {
+ map_func = [this](IteratorContext* ctx,
+ std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ (*ctx->runner())(std::bind(
+ [this, ctx, result](std::vector<Tensor>& input_element,
+ StatusCallback& done) {
+ captured_func_->RunAsync(ctx, std::move(input_element), result,
+ std::move(done));
+ },
+ std::move(input_element), std::move(done)));
+ };
+ }
return NewParallelMapIterator(
{this, strings::StrCat(prefix, "::ParallelMap")}, input_,
@@ -167,12 +187,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const int32 num_parallel_calls_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
+ const bool use_inter_op_parallelism_;
const std::unique_ptr<CapturedFunction> captured_func_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
+ bool use_inter_op_parallelism_;
NameAttrList func_;
};
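The change above selects between two execution strategies for the mapped function: when `use_inter_op_parallelism` is true the captured function runs as before and may fan out to the inter-op pool itself; when false, the work is re-dispatched through the iterator's runner so all user computation shares one executor. A self-contained sketch of that dispatch choice, with hypothetical names standing in for the TensorFlow types:

#include <functional>
#include <iostream>
#include <utility>
#include <vector>

using Runner = std::function<void(std::function<void()>)>;
using MapFunc = std::function<void(std::vector<int>)>;

// Chooses how mapped work is executed: directly, or funneled through a
// single runner when inter-op parallelism is disabled.
MapFunc MakeMapFunc(bool use_inter_op_parallelism, Runner runner) {
  if (use_inter_op_parallelism) {
    return [](std::vector<int> element) {
      // Run in place; any parallelism comes from the callee.
      std::cout << "inline, size=" << element.size() << "\n";
    };
  }
  return [runner](std::vector<int> element) {
    // Defer to the runner so all user work shares one executor.
    runner([element = std::move(element)] {
      std::cout << "via runner, size=" << element.size() << "\n";
    });
  };
}

int main() {
  Runner inline_runner = [](std::function<void()> fn) { fn(); };
  MakeMapFunc(true, inline_runner)({1, 2, 3});
  MakeMapFunc(false, inline_runner)({4, 5});
}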
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index a7a2935195..baf448e572 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -209,6 +209,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
if (s.ok()) {
*out_tensors = std::move(buffer_.front().value);
}
+ auto_tuner_.RecordConsumption(buffer_.size());
buffer_.pop_front();
*end_of_sequence = false;
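`RecordConsumption(buffer_.size())` is invoked just before the front element is popped, so the tuner observes how full the buffer was at every consumption. A sketch of one plausible policy is below: grow the buffer limit when the consumer keeps finding it empty. The doubling-on-empty heuristic and the class are assumptions for illustration; the actual PrefetchAutotuner policy lives elsewhere in the codebase.

#include <algorithm>
#include <cstddef>

// Sketch of a consumption-driven buffer autotuner.
class BufferAutotuner {
 public:
  BufferAutotuner(std::size_t initial_limit, std::size_t max_limit)
      : limit_(initial_limit), max_limit_(max_limit) {}

  std::size_t limit() const { return limit_; }

  // Called with the buffer size observed at each GetNext().
  void RecordConsumption(std::size_t current_buffer_size) {
    if (current_buffer_size == 0) {
      ++empty_hits_;
      if (empty_hits_ >= kEmptyHitsBeforeGrowth && limit_ < max_limit_) {
        // Producer cannot keep up; allow more prefetching, clamped at max.
        limit_ = std::min(limit_ * 2, max_limit_);
        empty_hits_ = 0;
      }
    } else {
      empty_hits_ = 0;  // buffer had data; no pressure to grow
    }
  }

 private:
  static constexpr int kEmptyHitsBeforeGrowth = 4;
  std::size_t limit_;
  const std::size_t max_limit_;
  int empty_hits_ = 0;
};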
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
index b01db91720..fb2a4cc8ef 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -247,8 +247,8 @@ class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
data.shaped<T, 2>({indices_vec.dimension(0), slice_size});
if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
- T* merged_base = &merged_flat(0, 0);
- const T* data_base = &data_flat(0, 0);
+ T* merged_base = merged_flat.data();
+ const T* data_base = data_flat.data();
for (int i = 0; i < indices_vec.size(); i++) {
int32 index = internal::SubtleMustCopy(indices_vec(i));
OP_REQUIRES(
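The switch to `.data()` matters for empty tensors: `&merged_flat(0, 0)` indexes element (0, 0) before taking its address, which is undefined behavior when there are no elements, whereas `.data()` returns the underlying pointer without touching any element. The same distinction, shown with a standard container:

#include <vector>

int main() {
  std::vector<int> v;   // empty
  int* p1 = v.data();   // fine: may be null, never dereferenced
  // int* p2 = &v[0];   // undefined behavior: operator[] on an empty vector
  (void)p1;
}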
diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
index 3b34f650b6..ec949ddc84 100644
--- a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
+++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
@@ -48,8 +48,10 @@ void SpatialConvolution(int iters, int num_threads,
benchmark.SpatialConvolution(input_dims, filter_dims);
- auto output_size = input_dims.TotalSize();
- auto flops = output_size * (input_depth * filter_height * filter_width);
+ auto num_computed_elements =
+ (input_dims.TotalSize() / input_depth) * filter_count;
+ auto flops =
+ num_computed_elements * (input_depth * filter_height * filter_width);
::tensorflow::testing::ItemsProcessed(flops * iters);
}
@@ -75,8 +77,9 @@ void SpatialConvolutionBackwardInput(int iters, int num_threads,
benchmark.SpatialConvolutionBackwardInput(input_dims, filter_dims);
- auto output_size = input_dims.TotalSize();
- auto flops = output_size * (input_depth * filter_height * filter_width);
+ auto num_computed_elements = input_dims.TotalSize();
+ auto flops =
+ num_computed_elements * (input_depth * filter_height * filter_width);
::tensorflow::testing::ItemsProcessed(flops * iters);
}
@@ -102,8 +105,9 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads,
benchmark.SpatialConvolutionBackwardKernel(input_dims, filter_dims);
- auto filter_size = filter_dims.TotalSize();
- auto flops = filter_size * (input_batches * input_height * input_width);
+ auto num_computed_elements = filter_dims.TotalSize();
+ auto flops =
+ num_computed_elements * (input_batches * input_height * input_width);
::tensorflow::testing::ItemsProcessed(flops * iters);
}
@@ -266,8 +270,9 @@ void CuboidConvolution(int iters, int num_threads,
benchmark.CuboidConvolution(input_dims, filter_dims);
- auto output_size = input_dims.TotalSize();
- auto flops = output_size *
+ auto num_computed_elements =
+ (input_dims.TotalSize() / input_depth) * filter_count;
+ auto flops = num_computed_elements *
(input_depth * filter_height * filter_width * filter_planes);
::tensorflow::testing::ItemsProcessed(flops * iters);
}
@@ -295,8 +300,8 @@ void CuboidConvolutionBackwardInput(int iters, int num_threads,
benchmark.CuboidConvolutionBackwardInput(input_dims, filter_dims);
- auto output_size = input_dims.TotalSize();
- auto flops = output_size *
+ auto num_computed_elements = input_dims.TotalSize();
+ auto flops = num_computed_elements *
(input_depth * filter_height * filter_width * filter_planes);
::tensorflow::testing::ItemsProcessed(flops * iters);
}
@@ -324,9 +329,9 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads,
benchmark.CuboidConvolutionBackwardKernel(input_dims, filter_dims);
- auto filter_size = filter_dims.TotalSize();
- auto flops =
- filter_size * (input_batches * input_height * input_width * input_planes);
+ auto num_computed_elements = filter_dims.TotalSize();
+ auto flops = num_computed_elements *
+ (input_batches * input_height * input_width * input_planes);
::tensorflow::testing::ItemsProcessed(flops * iters);
}
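The corrected benchmarks count FLOPs per computed output element rather than per input element: for a forward spatial convolution the output has `(input_dims.TotalSize() / input_depth) * filter_count` elements, each costing `input_depth * filter_height * filter_width` multiply-adds. A worked check with small, made-up dimensions (stride 1 and same-size output assumed, so the spatial extent carries over from the input):

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical benchmark dimensions, not taken from the test file.
  const std::int64_t batch = 8, height = 32, width = 32;
  const std::int64_t input_depth = 16, filter_count = 64;
  const std::int64_t filter_height = 3, filter_width = 3;

  const std::int64_t input_total = batch * height * width * input_depth;
  const std::int64_t num_computed_elements =
      (input_total / input_depth) * filter_count;  // 8*32*32*64 = 524288
  const std::int64_t flops =
      num_computed_elements * (input_depth * filter_height * filter_width);

  std::cout << "outputs=" << num_computed_elements  // 524288
            << " flops=" << flops << "\n";          // 524288 * 144 = 75497472
}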
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index 2e8d9c623c..a495758861 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -50,7 +50,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {}
size_t size() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return table_.size();
}
@@ -60,7 +60,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
const auto key_values = key.flat<K>();
auto value_values = value->flat<V>();
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
value_values(i) = gtl::FindWithDefault(
table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
@@ -95,7 +95,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
int64 size = table_.size();
Tensor* keys;
@@ -125,7 +125,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
int64 MemoryUsed() const override {
int64 ret = 0;
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (unsigned i = 0; i < table_.bucket_count(); ++i) {
size_t bucket_size = table_.bucket_size(i);
if (bucket_size == 0) {
@@ -138,7 +138,6 @@ class MutableHashTableOfScalars final : public LookupInterface {
}
private:
- // TODO(andreasst): consider using a read/write lock or a concurrent map
mutable mutex mu_;
std::unordered_map<K, V> table_ GUARDED_BY(mu_);
};
@@ -158,7 +157,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
}
size_t size() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return table_.size();
}
@@ -169,7 +168,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
auto value_values = value->flat_inner_dims<V, 2>();
int64 value_dim = value_shape_.dim_size(0);
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
ValueArray* value_vec =
gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i)));
@@ -219,7 +218,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
int64 size = table_.size();
int64 value_dim = value_shape_.dim_size(0);
@@ -254,7 +253,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
int64 MemoryUsed() const override {
int64 ret = 0;
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (unsigned i = 0; i < table_.bucket_count(); ++i) {
size_t bucket_size = table_.bucket_size(i);
if (bucket_size == 0) {
@@ -268,7 +267,6 @@ class MutableHashTableOfTensors final : public LookupInterface {
private:
TensorShape value_shape_;
- // TODO(andreasst): consider using a read/write lock or a concurrent map
mutable mutex mu_;
typedef gtl::InlinedVector<V, 4> ValueArray;
std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_);
@@ -335,7 +333,7 @@ class MutableDenseHashTable final : public LookupInterface {
}
size_t size() const override LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return num_entries_;
}
@@ -355,7 +353,7 @@ class MutableDenseHashTable final : public LookupInterface {
auto value_matrix = value->shaped<V, 2>({num_elements, value_size});
const auto default_flat = default_value.flat<V>();
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
const auto key_buckets_matrix =
key_buckets_.AccessTensor(ctx)->template matrix<K>();
const auto value_buckets_matrix =
@@ -451,7 +449,7 @@ class MutableDenseHashTable final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx);
Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx);
TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor));
@@ -493,7 +491,7 @@ class MutableDenseHashTable final : public LookupInterface {
TensorShape value_shape() const override { return value_shape_; }
int64 MemoryUsed() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() +
value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes();
}
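The lookup-table change converts the read-only paths (size, Find, ExportValues, MemoryUsed) from exclusive `mutex_lock` to `tf_shared_lock`, so concurrent lookups no longer serialize against each other while mutating paths still take the mutex exclusively. A minimal sketch of the same reader-writer pattern using only the standard library (C++17):

#include <shared_mutex>
#include <string>
#include <unordered_map>

// Shared lock for reads, exclusive lock for writes, as in the patched tables.
class Table {
 public:
  // Many Find() calls may hold the shared lock at the same time.
  bool Find(const std::string& key, int* value) const {
    std::shared_lock<std::shared_mutex> l(mu_);
    auto it = map_.find(key);
    if (it == map_.end()) return false;
    *value = it->second;
    return true;
  }

  // Insert() excludes readers and other writers.
  void Insert(const std::string& key, int value) {
    std::unique_lock<std::shared_mutex> l(mu_);
    map_[key] = value;
  }

 private:
  mutable std::shared_mutex mu_;
  std::unordered_map<std::string, int> map_;
};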
diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc
index bdc3b5778f..dd89597369 100644
--- a/tensorflow/core/kernels/map_stage_op.cc
+++ b/tensorflow/core/kernels/map_stage_op.cc
@@ -410,8 +410,9 @@ class StagingMap : public ResourceBase {
copy_or_move_tensors(&it->second, *key, *indices, tuple));
// Remove entry if all the values have been consumed
- if (!std::any_of(it->second.begin(), it->second.end(),
- std::mem_fn(&OptionalTensor::has_value))) {
+ if (!std::any_of(
+ it->second.begin(), it->second.end(),
+ [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
map_.erase(it);
}
@@ -444,8 +445,9 @@ class StagingMap : public ResourceBase {
*key = it->first;
// Remove entry if all the values have been consumed
- if (!std::any_of(it->second.begin(), it->second.end(),
- std::mem_fn(&OptionalTensor::has_value))) {
+ if (!std::any_of(
+ it->second.begin(), it->second.end(),
+ [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
map_.erase(it);
}
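The `std::any_of` call now spells the predicate out as a lambda instead of `std::mem_fn(&OptionalTensor::has_value)`, which keeps the call site from depending on the exact declaration of that member (its ref-qualifiers, noexcept-ness, or which optional implementation provides it). The same erase-when-fully-consumed check, sketched over `std::optional`:

#include <algorithm>
#include <optional>
#include <vector>

// True once every slot of the staged tuple has been consumed.
bool AllConsumed(const std::vector<std::optional<int>>& tuple) {
  return !std::any_of(
      tuple.begin(), tuple.end(),
      [](const std::optional<int>& t) { return t.has_value(); });
}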
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 9b10c3f3d6..184e0cb003 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -1083,7 +1083,7 @@ class MklConvOp : public OpKernel {
#endif
// Register 2D operations
-#define REGISTER_MKL_CPU(T) \
+#define REGISTER_MKL_CPU_2D(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
@@ -1100,16 +1100,16 @@ class MklConvOp : public OpKernel {
.Label(mkl_op_registry::kMklOpLabel), \
MklDummyOp<CPUDevice, T>);
-TF_CALL_float(REGISTER_MKL_CPU);
+TF_CALL_float(REGISTER_MKL_CPU_2D);
// Register 3D operations
-#define REGISTER_MKL_CPU(T) \
+#define REGISTER_MKL_CPU_3D(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv3D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
MklConvOp<CPUDevice, T, false>);
-TF_CALL_float(REGISTER_MKL_CPU);
+TF_CALL_float(REGISTER_MKL_CPU_3D);
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
index ec6d241e17..5398e6113f 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
@@ -34,11 +34,11 @@ using mkldnn::prop_kind;
template <typename T>
void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
- if (fwdParams.alg_kind != pooling_max && fwdParams.alg_kind != pooling_avg &&
- fwdParams.alg_kind != pooling_avg_include_padding &&
- fwdParams.alg_kind != pooling_avg_exclude_padding) {
- assert("Pooling algorithm kind is not supported\n");
- }
+ DCHECK(fwdParams.alg_kind == pooling_max ||
+ fwdParams.alg_kind == pooling_avg ||
+ fwdParams.alg_kind == pooling_avg_include_padding ||
+ fwdParams.alg_kind == pooling_avg_exclude_padding)
+ << "Pooling algorithm kind is not supported";
context_.alg_kind = fwdParams.alg_kind;
// create memory desc
@@ -102,7 +102,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
if (context_.alg_kind == pooling_max) { // max pooling must have ws
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(ws_data);
}
context_.fwd_stream->submit(context_.fwd_primitives);
@@ -111,7 +111,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
context_.src_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData);
if (context_.alg_kind == pooling_max) { // max pooling must have ws
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(DummyData);
}
}
@@ -120,11 +120,11 @@ template class MklPoolingFwdPrimitive<float>;
template <typename T>
void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
- if (bwdParams.alg_kind != pooling_max && bwdParams.alg_kind != pooling_avg &&
- bwdParams.alg_kind != pooling_avg_include_padding &&
- bwdParams.alg_kind != pooling_avg_exclude_padding) {
- assert("Pooling algorithm kind is not supported\n");
- }
+ DCHECK(bwdParams.alg_kind == pooling_max ||
+ bwdParams.alg_kind == pooling_avg ||
+ bwdParams.alg_kind == pooling_avg_include_padding ||
+ bwdParams.alg_kind == pooling_avg_exclude_padding)
+ << "Pooling algorithm kind is not supported";
context_.alg_kind = bwdParams.alg_kind;
// check whether it is 2d or 3d
@@ -190,7 +190,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
if (context_.alg_kind == pooling_max) {
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
}
@@ -199,7 +199,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
context_.diff_dst_mem->set_data_handle(DummyData);
context_.diff_src_mem->set_data_handle(DummyData);
if (context_.alg_kind == pooling_max) {
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(DummyData);
}
}
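The `assert` replacements above fix a silent no-op: `assert("Pooling algorithm kind is not supported\n")` never fires, because a string literal decays to a non-null pointer and is therefore always truthy, even in debug builds. The `DCHECK` form evaluates the actual condition and attaches the message. A minimal demonstration of the pitfall:

#include <cassert>

int main() {
  bool supported = false;

  // Always passes regardless of `supported`: the literal is a non-null
  // pointer, so the asserted expression is unconditionally true.
  assert("Pooling algorithm kind is not supported");

  // What was intended: assert the condition, keep the message alongside.
  // This one fires in debug builds (NDEBUG not defined).
  assert(supported && "Pooling algorithm kind is not supported");

  (void)supported;
  return 0;
}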
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index f4cfc48af5..84385356e1 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -40,7 +40,6 @@ using mkldnn::memory;
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#endif
-#include "tensorflow/core/platform/default/logging.h"
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index 04d8a1bdeb..cfab529662 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -88,6 +88,7 @@ class MklSoftmaxOp : public OpKernel {
break;
default:
OP_REQUIRES_OK(context, errors::Aborted("Input dims must be >= 1 and <= 5"));
+ return;
}
// Create softmax memory for src, dst: both are defined in mkl_util.h,
// they are wrapper
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index 5d9257e20b..81ce6d6e95 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -75,28 +75,28 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
}
// Return intersection-over-union overlap between boxes i and j
-static inline float IOUGreaterThanThreshold(
- typename TTypes<float, 2>::ConstTensor boxes, int i, int j,
- float iou_threshold) {
- const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2));
- const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3));
- const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2));
- const float xmax_i = std::max<float>(boxes(i, 1), boxes(i, 3));
- const float ymin_j = std::min<float>(boxes(j, 0), boxes(j, 2));
- const float xmin_j = std::min<float>(boxes(j, 1), boxes(j, 3));
- const float ymax_j = std::max<float>(boxes(j, 0), boxes(j, 2));
- const float xmax_j = std::max<float>(boxes(j, 1), boxes(j, 3));
- const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
- const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
- if (area_i <= 0 || area_j <= 0) return 0.0;
- const float intersection_ymin = std::max<float>(ymin_i, ymin_j);
- const float intersection_xmin = std::max<float>(xmin_i, xmin_j);
- const float intersection_ymax = std::min<float>(ymax_i, ymax_j);
- const float intersection_xmax = std::min<float>(xmax_i, xmax_j);
- const float intersection_area =
- std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
- std::max<float>(intersection_xmax - intersection_xmin, 0.0);
- const float iou = intersection_area / (area_i + area_j - intersection_area);
+template <typename T>
+static inline bool IOUGreaterThanThreshold(
+ typename TTypes<T, 2>::ConstTensor boxes, int i, int j, T iou_threshold) {
+ const T ymin_i = std::min<T>(boxes(i, 0), boxes(i, 2));
+ const T xmin_i = std::min<T>(boxes(i, 1), boxes(i, 3));
+ const T ymax_i = std::max<T>(boxes(i, 0), boxes(i, 2));
+ const T xmax_i = std::max<T>(boxes(i, 1), boxes(i, 3));
+ const T ymin_j = std::min<T>(boxes(j, 0), boxes(j, 2));
+ const T xmin_j = std::min<T>(boxes(j, 1), boxes(j, 3));
+ const T ymax_j = std::max<T>(boxes(j, 0), boxes(j, 2));
+ const T xmax_j = std::max<T>(boxes(j, 1), boxes(j, 3));
+ const T area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
+ const T area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
+ if (area_i <= static_cast<T>(0) || area_j <= static_cast<T>(0)) return false;
+ const T intersection_ymin = std::max<T>(ymin_i, ymin_j);
+ const T intersection_xmin = std::max<T>(xmin_i, xmin_j);
+ const T intersection_ymax = std::min<T>(ymax_i, ymax_j);
+ const T intersection_xmax = std::min<T>(xmax_i, xmax_j);
+ const T intersection_area =
+ std::max<T>(intersection_ymax - intersection_ymin, static_cast<T>(0.0)) *
+ std::max<T>(intersection_xmax - intersection_xmin, static_cast<T>(0.0));
+ const T iou = intersection_area / (area_i + area_j - intersection_area);
return iou > iou_threshold;
}
@@ -106,11 +106,13 @@ static inline bool OverlapsGreaterThanThreshold(
return overlaps(i, j) > overlap_threshold;
}
+template <typename T>
static inline std::function<bool(int, int)> CreateIOUSuppressCheckFn(
const Tensor& boxes, float threshold) {
- typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
- return std::bind(&IOUGreaterThanThreshold, boxes_data, std::placeholders::_1,
- std::placeholders::_2, threshold);
+ typename TTypes<T, 2>::ConstTensor boxes_data = boxes.tensor<T, 2>();
+ return std::bind(&IOUGreaterThanThreshold<T>, boxes_data,
+ std::placeholders::_1, std::placeholders::_2,
+ static_cast<T>(threshold));
}
static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
@@ -121,6 +123,7 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
std::placeholders::_1, std::placeholders::_2, threshold);
}
+template <typename T>
void DoNonMaxSuppressionOp(
OpKernelContext* context, const Tensor& scores, int num_boxes,
const Tensor& max_output_size, const float score_threshold,
@@ -128,13 +131,13 @@ void DoNonMaxSuppressionOp(
bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) {
const int output_size = max_output_size.scalar<int>()();
- std::vector<float> scores_data(num_boxes);
- std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
+ std::vector<T> scores_data(num_boxes);
+ std::copy_n(scores.flat<T>().data(), num_boxes, scores_data.begin());
// Data structure for selection candidate in NMS.
struct Candidate {
int box_index;
- float score;
+ T score;
};
auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
@@ -143,13 +146,13 @@ void DoNonMaxSuppressionOp(
std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)>
candidate_priority_queue(cmp);
for (int i = 0; i < scores_data.size(); ++i) {
- if (scores_data[i] > score_threshold) {
+ if (static_cast<float>(scores_data[i]) > score_threshold) {
candidate_priority_queue.emplace(Candidate({i, scores_data[i]}));
}
}
std::vector<int> selected;
- std::vector<float> selected_scores;
+ std::vector<T> selected_scores;
Candidate next_candidate;
while (selected.size() < output_size && !candidate_priority_queue.empty()) {
@@ -176,7 +179,7 @@ void DoNonMaxSuppressionOp(
int num_valid_outputs = selected.size();
if (pad_to_max_output_size) {
selected.resize(output_size, 0);
- selected_scores.resize(output_size, 0);
+ selected_scores.resize(output_size, static_cast<T>(0));
}
if (ptr_num_valid_outputs) {
*ptr_num_valid_outputs = num_valid_outputs;
@@ -221,18 +224,19 @@ class NonMaxSuppressionOp : public OpKernel {
if (!context->status().ok()) {
return;
}
- auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_);
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn<float>(boxes, iou_threshold_);
const float score_threshold_val = std::numeric_limits<float>::lowest();
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
private:
float iou_threshold_;
};
-template <typename Device>
+template <typename Device, typename T>
class NonMaxSuppressionV2Op : public OpKernel {
public:
explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
@@ -264,11 +268,12 @@ class NonMaxSuppressionV2Op : public OpKernel {
if (!context->status().ok()) {
return;
}
- auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val);
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val);
const float score_threshold_val = std::numeric_limits<float>::lowest();
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
};
@@ -325,7 +330,7 @@ class NonMaxSuppressionV3V4Base : public OpKernel {
float score_threshold_val_;
};
-template <typename Device>
+template <typename Device, typename T>
class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base {
public:
explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
@@ -334,14 +339,14 @@ class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base {
protected:
void DoComputeAndPostProcess(OpKernelContext* context) override {
auto suppress_check_fn =
- CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+ CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_);
- DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
- score_threshold_val_, suppress_check_fn);
+ DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn);
}
};
-template <typename Device>
+template <typename Device, typename T>
class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base {
public:
explicit NonMaxSuppressionV4Op(OpKernelConstruction* context)
@@ -353,12 +358,12 @@ class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base {
protected:
void DoComputeAndPostProcess(OpKernelContext* context) override {
auto suppress_check_fn =
- CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+ CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_);
int num_valid_outputs;
- DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
- score_threshold_val_, suppress_check_fn,
- pad_to_max_output_size_, &num_valid_outputs);
+ DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn,
+ pad_to_max_output_size_, &num_valid_outputs);
// Allocate scalar output tensor for number of indices computed.
Tensor* num_outputs_t = nullptr;
@@ -413,22 +418,37 @@ class NonMaxSuppressionWithOverlapsOp : public OpKernel {
auto suppress_check_fn =
CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val);
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
};
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
NonMaxSuppressionOp<CPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
- NonMaxSuppressionV2Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_CPU),
+ NonMaxSuppressionV2Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")
+ .TypeConstraint<Eigen::half>("T")
+ .Device(DEVICE_CPU),
+ NonMaxSuppressionV2Op<CPUDevice, Eigen::half>);
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU),
- NonMaxSuppressionV3Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionV3").TypeConstraint<float>("T").Device(DEVICE_CPU),
+ NonMaxSuppressionV3Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3")
+ .TypeConstraint<Eigen::half>("T")
+ .Device(DEVICE_CPU),
+ NonMaxSuppressionV3Op<CPUDevice, Eigen::half>);
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU),
- NonMaxSuppressionV4Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionV4").TypeConstraint<float>("T").Device(DEVICE_CPU),
+ NonMaxSuppressionV4Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4")
+ .TypeConstraint<Eigen::half>("T")
+ .Device(DEVICE_CPU),
+ NonMaxSuppressionV4Op<CPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 876a1704c7..7bb403290d 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function.h"
@@ -104,13 +105,6 @@ class PartitionedCallOp : public AsyncOpKernel {
for (auto d : lib->device_mgr()->ListDevices()) {
device_set.AddDevice(d);
}
- Placer placer(graph.get(), &device_set);
- OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
-
- std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
- OP_REQUIRES_OK_ASYNC(
- ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
- done);
// The FunctionLibraryRuntime's library cannot be mutated from within
// an OpKernel, so functions are instantiated in an overlay library.
@@ -124,6 +118,47 @@ class PartitionedCallOp : public AsyncOpKernel {
new FunctionLibraryDefinition(*lib->GetFunctionLibraryDefinition());
overlay_libs_.emplace(lib, overlay_lib);
+ GraphOptimizationPassOptions optimization_options;
+ // TODO(akshayka): Thread SessionOptions (if any) into this kernel, or
+ // make it possible to specify the relevant options via attributes.
+ SessionOptions session_options;
+ session_options.env = ctx->env();
+ optimization_options.session_options = &session_options;
+ optimization_options.graph = &graph;
+ optimization_options.flib_def = overlay_lib;
+ optimization_options.device_set = &device_set;
+ Placer placer(graph.get(), &device_set);
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::PRE_PLACEMENT, optimization_options),
+ done);
+ OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_PLACEMENT, optimization_options),
+ done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
+ optimization_options),
+ done);
+
+ std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
+ done);
+ optimization_options.graph = nullptr;
+ optimization_options.device_set = nullptr;
+ optimization_options.partition_graphs = &subgraphs;
+ OP_REQUIRES_OK_ASYNC(ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_PARTITIONING,
+ optimization_options),
+ done);
+
auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>();
for (const auto& pair : subgraphs) {
// TODO(akshayka): Fail gracefully if the set of devices corresponds
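The hunk above threads the registered graph-optimization groupings through PartitionedCallOp in the order a Session would run them: PRE_PLACEMENT before the placer, then POST_PLACEMENT and POST_REWRITE_FOR_EXEC, and finally POST_PARTITIONING over the per-device subgraphs. A schematic of that sequencing; the enum and driver here are illustrative, not the OptimizationPassRegistry API:

#include <iostream>

enum class Grouping {
  kPrePlacement,
  kPostPlacement,
  kPostRewriteForExec,
  kPostPartitioning,
};

// Schematic: run one grouping's passes (stubbed as a print).
void RunGrouping(Grouping g) {
  std::cout << "running grouping " << static_cast<int>(g) << "\n";
}

int main() {
  RunGrouping(Grouping::kPrePlacement);
  std::cout << "placer.Run()\n";
  RunGrouping(Grouping::kPostPlacement);
  RunGrouping(Grouping::kPostRewriteForExec);
  std::cout << "PartitionHelper(...)\n";
  RunGrouping(Grouping::kPostPartitioning);  // over per-device subgraphs
}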
diff --git a/tensorflow/core/kernels/regex_full_match_op.cc b/tensorflow/core/kernels/regex_full_match_op.cc
index 5863a2c8e4..7edaaad8f7 100644
--- a/tensorflow/core/kernels/regex_full_match_op.cc
+++ b/tensorflow/core/kernels/regex_full_match_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -56,4 +57,36 @@ class RegexFullMatchOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
RegexFullMatchOp);
+class StaticRegexFullMatchOp : public OpKernel {
+ public:
+ explicit StaticRegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string pattern;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
+ re_ = MakeUnique<RE2>(pattern);
+ OP_REQUIRES(ctx, re_->ok(),
+ errors::InvalidArgument("Invalid pattern: ", pattern,
+ ", error: ", re_->error()));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+ const auto& input_flat = input_tensor->flat<string>();
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<bool>();
+ for (size_t i = 0; i < input_flat.size(); ++i) {
+ output_flat(i) = RE2::FullMatch(input_flat(i), *re_);
+ }
+ }
+
+ private:
+ std::unique_ptr<RE2> re_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StaticRegexFullMatch").Device(DEVICE_CPU),
+ StaticRegexFullMatchOp);
+
} // namespace tensorflow
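Because `pattern` is an attr on `StaticRegexFullMatch`, it is constant for the life of the kernel, so the RE2 is compiled once in the constructor and each `Compute` only calls `RE2::FullMatch` per element; the non-static op must rebuild its RE2 from an input tensor. A sketch of that compile-once shape, assuming the RE2 library is linked as it is in TensorFlow:

#include <memory>
#include <string>
#include <vector>

#include "re2/re2.h"

// Compile once, match many times: the pattern is fixed up front.
class StaticMatcher {
 public:
  explicit StaticMatcher(const std::string& pattern)
      : re_(std::make_unique<RE2>(pattern)) {}

  bool ok() const { return re_->ok(); }

  std::vector<bool> MatchAll(const std::vector<std::string>& inputs) const {
    std::vector<bool> out;
    out.reserve(inputs.size());
    for (const auto& s : inputs) {
      out.push_back(RE2::FullMatch(s, *re_));  // no recompilation per element
    }
    return out;
  }

 private:
  std::unique_ptr<RE2> re_;
};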
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator.h b/tensorflow/core/kernels/sparse_conditional_accumulator.h
index 11149c4d16..a4453bd7ab 100644
--- a/tensorflow/core/kernels/sparse_conditional_accumulator.h
+++ b/tensorflow/core/kernels/sparse_conditional_accumulator.h
@@ -50,10 +50,10 @@ class SparseConditionalAccumulator
public:
SparseConditionalAccumulator(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name)
+ const string& name, const string& reduction_type)
: TypedConditionalAccumulatorBase<
std::tuple<const Tensor*, const Tensor*, const Tensor*>>(
- dtype, shape, name) {
+ dtype, shape, name, reduction_type) {
accum_idx_vec_ = nullptr;
count_element_ = nullptr;
accum_val_ = nullptr;
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
index 80bc1f1934..1e542a26a7 100644
--- a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
+++ b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
@@ -34,8 +34,8 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
Creator GetCreator() const override {
return [this](ConditionalAccumulatorBase** ret) {
SparseConditionalAccumulator<Device, T>* accumulator =
- new SparseConditionalAccumulator<Device, T>(dtype_, shape_,
- cinfo_.name());
+ new SparseConditionalAccumulator<Device, T>(
+ dtype_, shape_, cinfo_.name(), reduction_type_);
*ret = accumulator;
return Status::OK();
};
diff --git a/tensorflow/core/kernels/typed_conditional_accumulator_base.h b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
index 9dedb618f9..ca341e511e 100644
--- a/tensorflow/core/kernels/typed_conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
@@ -35,8 +35,9 @@ class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
public:
TypedConditionalAccumulatorBase(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name)
- : ConditionalAccumulatorBase(dtype, shape, name) {}
+ const string& name,
+ const string& reduction_type)
+ : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {}
/**
* Attempts to add a gradient to the accumulator. An ApplyGrad attempt is
diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc
new file mode 100644
index 0000000000..4c488066e4
--- /dev/null
+++ b/tensorflow/core/lib/core/stringpiece.cc
@@ -0,0 +1,54 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+#include <algorithm>
+#include <iostream>
+
+namespace tensorflow {
+
+std::ostream& operator<<(std::ostream& o, StringPiece piece) {
+ o.write(piece.data(), piece.size());
+ return o;
+}
+
+size_t StringPiece::find(char c, size_t pos) const {
+ if (pos >= size_) {
+ return npos;
+ }
+ const char* result =
+ reinterpret_cast<const char*>(memchr(data_ + pos, c, size_ - pos));
+ return result != nullptr ? result - data_ : npos;
+}
+
+// Search range is [0..pos] inclusive. If pos == npos, search everything.
+size_t StringPiece::rfind(char c, size_t pos) const {
+ if (size_ == 0) return npos;
+ for (const char* p = data_ + std::min(pos, size_ - 1); p >= data_; p--) {
+ if (*p == c) {
+ return p - data_;
+ }
+ }
+ return npos;
+}
+
+StringPiece StringPiece::substr(size_t pos, size_t n) const {
+ if (pos > size_) pos = size_;
+ if (n > size_ - pos) n = size_ - pos;
+ return StringPiece(data_ + pos, n);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
index e7b17c9b36..02dded42c1 100644
--- a/tensorflow/core/lib/core/stringpiece.h
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -31,13 +31,124 @@ limitations under the License.
#include <string.h>
#include <iosfwd>
#include <string>
-#include "absl/strings/string_view.h"
+#include <type_traits>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-// Deprecated: please use absl::string_view directly.
-using StringPiece = absl::string_view;
+class StringPiece {
+ public:
+ typedef size_t size_type;
+
+ // Create an empty slice.
+ StringPiece() : data_(nullptr), size_(0) {}
+
+ // Create a slice that refers to d[0,n-1].
+ StringPiece(const char* d, size_t n) : data_(d), size_(n) {}
+
+ // Create a slice that refers to the contents of "s"
+ StringPiece(const string& s) : data_(s.data()), size_(s.size()) {}
+
+ // Create a slice that refers to s[0,strlen(s)-1]
+ StringPiece(const char* s) : data_(s), size_(strlen(s)) {}
+
+ // Return a pointer to the beginning of the referenced data
+ const char* data() const { return data_; }
+
+ // Return the length (in bytes) of the referenced data
+ size_t size() const { return size_; }
+
+ // Return true iff the length of the referenced data is zero
+ bool empty() const { return size_ == 0; }
+
+ typedef const char* const_iterator;
+ typedef const char* iterator;
+ iterator begin() const { return data_; }
+ iterator end() const { return data_ + size_; }
+
+ static const size_t npos = size_type(-1);
+
+ // Return the ith byte in the referenced data.
+ // REQUIRES: n < size()
+ char operator[](size_t n) const {
+ assert(n < size());
+ return data_[n];
+ }
+
+ // Drop the first "n" bytes from this slice.
+ void remove_prefix(size_t n) {
+ assert(n <= size());
+ data_ += n;
+ size_ -= n;
+ }
+
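+ // Drop the last "n" bytes from this slice.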
+ void remove_suffix(size_t n) {
+ assert(size_ >= n);
+ size_ -= n;
+ }
+
+ size_t find(char c, size_t pos = 0) const;
+ size_t rfind(char c, size_t pos = npos) const;
+
+ StringPiece substr(size_t pos, size_t n = npos) const;
+
+ // Three-way comparison. Returns value:
+ // < 0 iff "*this" < "b",
+ // == 0 iff "*this" == "b",
+ // > 0 iff "*this" > "b"
+ int compare(StringPiece b) const;
+
+ // Converts to various kinds of strings, including `std::basic_string`.
+ template <typename S>
+ explicit operator S() const {
+ static_assert(
+ std::is_same<char, typename S::value_type>::value,
+ "Type mismatch: S must be a string with character type char.");
+ static_assert(
+ std::is_same<std::char_traits<char>, typename S::traits_type>::value,
+ "Type mismatch: S must be a string with traits type "
+ "std::char_traits<char>.");
+ if (!data()) return {};
+ return S(data(), size());
+ }
+
+ private:
+ const char* data_;
+ size_t size_;
+
+ // Intentionally copyable
+};
+
+inline bool operator==(StringPiece x, StringPiece y) {
+ return ((x.size() == y.size()) &&
+ (memcmp(x.data(), y.data(), x.size()) == 0));
+}
+
+inline bool operator!=(StringPiece x, StringPiece y) { return !(x == y); }
+
+inline bool operator<(StringPiece x, StringPiece y) { return x.compare(y) < 0; }
+inline bool operator>(StringPiece x, StringPiece y) { return x.compare(y) > 0; }
+inline bool operator<=(StringPiece x, StringPiece y) {
+ return x.compare(y) <= 0;
+}
+inline bool operator>=(StringPiece x, StringPiece y) {
+ return x.compare(y) >= 0;
+}
+
+inline int StringPiece::compare(StringPiece b) const {
+ const size_t min_len = (size_ < b.size_) ? size_ : b.size_;
+ int r = memcmp(data_, b.data_, min_len);
+ if (r == 0) {
+ if (size_ < b.size_)
+ r = -1;
+ else if (size_ > b.size_)
+ r = +1;
+ }
+ return r;
+}
+
+// allow StringPiece to be logged
+extern std::ostream& operator<<(std::ostream& o, tensorflow::StringPiece piece);
} // namespace tensorflow
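This header replaces the `absl::string_view` alias with a self-contained StringPiece that keeps the familiar observable behavior: pointer-based `find`/`rfind`, length-clamping `substr`, and lexicographic `compare`. The same calls expressed with `std::string_view` (one difference to note: for `pos > size()`, this StringPiece clamps `substr` while `std::string_view::substr` throws, so the example stays within bounds):

#include <iostream>
#include <string_view>

int main() {
  std::string_view s("tensorflow/core/lib");

  std::cout << s.find('/') << "\n";       // 10
  std::cout << s.rfind('/') << "\n";      // 15
  std::cout << s.substr(11) << "\n";      // "core/lib"
  std::cout << s.substr(16, 99) << "\n";  // "lib" (length clamped)
  std::cout << (s.compare("tensorflow") > 0)  // longer string compares greater
            << "\n";                          // 1
}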
diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h
index 2f6afa5487..6a2bf66d12 100644
--- a/tensorflow/core/lib/io/record_writer.h
+++ b/tensorflow/core/lib/io/record_writer.h
@@ -41,7 +41,7 @@ class RecordWriterOptions {
// Options specific to zlib compression.
#if !defined(IS_SLIM_BUILD)
- ZlibCompressionOptions zlib_options;
+ tensorflow::io::ZlibCompressionOptions zlib_options;
#endif // IS_SLIM_BUILD
};
diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h
index a620f59447..351b6f5de3 100644
--- a/tensorflow/core/lib/strings/strcat.h
+++ b/tensorflow/core/lib/strings/strcat.h
@@ -124,9 +124,6 @@ class AlphaNum {
AlphaNum(const StringPiece &pc) : piece_(pc) {} // NOLINT(runtime/explicit)
AlphaNum(const tensorflow::string &str) // NOLINT(runtime/explicit)
: piece_(str) {}
- template <typename A>
- AlphaNum(const std::basic_string<char, std::char_traits<char>, A> &str)
- : piece_(str) {} // NOLINT(runtime/explicit)
StringPiece::size_type size() const { return piece_.size(); }
const char *data() const { return piece_.data(); }
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 9836f784ab..c32d6f84f5 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -13070,6 +13070,71 @@ op {
is_stateful: true
}
op {
+ name: "ConditionalAccumulator"
+ output_arg {
+ name: "handle"
+ type: DT_STRING
+ is_ref: true
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "Conj"
input_arg {
name: "input"
@@ -37080,6 +37145,54 @@ op {
}
}
op {
+ name: "ParallelInterleaveDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "cycle_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "block_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ParallelMapDataset"
input_arg {
name: "input_dataset"
@@ -37161,6 +37274,53 @@ op {
}
}
op {
+ name: "ParallelMapDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "ParameterizedTruncatedNormal"
input_arg {
name: "shape"
@@ -64543,6 +64703,71 @@ op {
is_stateful: true
}
op {
+ name: "SparseConditionalAccumulator"
+ output_arg {
+ name: "handle"
+ type: DT_STRING
+ is_ref: true
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "SparseCross"
input_arg {
name: "indices"
@@ -69336,6 +69561,21 @@ op {
}
}
op {
+ name: "StaticRegexFullMatch"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_BOOL
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+}
+op {
name: "StaticRegexReplace"
input_arg {
name: "input"
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index eed0bce174..ffab8ad661 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -419,6 +419,7 @@ REGISTER_OP("ConditionalAccumulator")
.Attr("shape: shape")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
+ .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(2));
@@ -456,6 +457,7 @@ REGISTER_OP("SparseConditionalAccumulator")
.Attr("shape: shape")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
+ .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(2));
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 1a5ad8f421..9d2b3af51d 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -210,6 +210,7 @@ REGISTER_OP("ParallelMapDataset")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
+ .Attr("use_inter_op_parallelism: bool = true")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("MapAndBatchDataset")
@@ -326,6 +327,19 @@ REGISTER_OP("ParallelInterleaveDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ParallelInterleaveDatasetV2")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("cycle_length: int64")
+ .Input("block_length: int64")
+ .Input("num_parallel_calls: int64")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("GroupByReducerDataset")
.Input("input_dataset: variant")
.Input("key_func_other_arguments: Tkey_func_other_arguments")
@@ -867,7 +881,7 @@ REGISTER_OP("MapDefun")
.Attr("output_shapes: list(shape) >= 1")
.Attr("f: func")
.SetShapeFn([](shape_inference::InferenceContext* c) {
- std::vector<TensorShape> output_shapes;
+ std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
@@ -877,6 +891,10 @@ REGISTER_OP("MapDefun")
int64 dim_zero = -1;
for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) {
+ if (c->Rank(c->input(i)) == 0) {
+ return errors::InvalidArgument(
+ "Inputs must have rank at least 1. Input ", i, " has rank of 0");
+ }
auto dim_handle = c->Dim(c->input(i), 0);
if (c->ValueKnown(dim_handle)) {
if (dim_zero == -1) {
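
The added check rejects scalar inputs, since `MapDefun` derives its batch dimension from dim 0 of every input. A rough Python rendering of the rule (hypothetical helper; shapes are tuples with `None` for unknown dims):

```python
def check_map_defun_inputs(input_shapes):
    """Mirrors MapDefun's new rank and dim-0 validation (sketch only)."""
    dim_zero = None
    for i, shape in enumerate(input_shapes):
        if len(shape) == 0:
            raise ValueError(
                "Inputs must have rank at least 1. Input %d has rank of 0" % i)
        if shape[0] is not None and dim_zero is None:
            dim_zero = shape[0]
    return dim_zero  # shared batch dimension, if statically known

check_map_defun_inputs([(3, 2), (3,)])  # ok, returns 3
# check_map_defun_inputs([()])          # raises ValueError
```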
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 11ca0bd259..5427275284 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -683,11 +683,12 @@ REGISTER_OP("NonMaxSuppression")
});
REGISTER_OP("NonMaxSuppressionV2")
- .Input("boxes: float")
- .Input("scores: float")
+ .Input("boxes: T")
+ .Input("scores: T")
.Input("max_output_size: int32")
.Input("iou_threshold: float")
.Output("selected_indices: int32")
+ .Attr("T: {half, float} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
// Get inputs and validate ranks.
ShapeHandle boxes;
@@ -711,22 +712,24 @@ REGISTER_OP("NonMaxSuppressionV2")
});
REGISTER_OP("NonMaxSuppressionV3")
- .Input("boxes: float")
- .Input("scores: float")
+ .Input("boxes: T")
+ .Input("scores: T")
.Input("max_output_size: int32")
.Input("iou_threshold: float")
.Input("score_threshold: float")
.Output("selected_indices: int32")
+ .Attr("T: {half, float} = DT_FLOAT")
.SetShapeFn(NMSShapeFn);
REGISTER_OP("NonMaxSuppressionV4")
- .Input("boxes: float")
- .Input("scores: float")
+ .Input("boxes: T")
+ .Input("scores: T")
.Input("max_output_size: int32")
.Input("iou_threshold: float")
.Input("score_threshold: float")
.Output("selected_indices: int32")
.Output("valid_outputs: int32")
+ .Attr("T: {half, float} = DT_FLOAT")
.Attr("pad_to_max_output_size: bool = false")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(NMSShapeFn(c));
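
With the new `T: {half, float}` attr, the NMS family accepts `float16` boxes and scores while the threshold inputs stay `float`. A sketch, assuming the Python wrapper infers `T` from its inputs:

```python
import tensorflow as tf

boxes = tf.constant([[0.0, 0.0, 1.0, 1.0],
                     [0.0, 0.0, 0.9, 0.9]], dtype=tf.float16)
scores = tf.constant([0.9, 0.8], dtype=tf.float16)
selected = tf.image.non_max_suppression(
    boxes, scores, max_output_size=1, iou_threshold=0.5)

with tf.Session() as sess:
    print(sess.run(selected))  # e.g. [0]
```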
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 28b25fdeae..aeb03c5952 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -5592,6 +5592,19 @@ op {
s: ""
}
}
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
is_stateful: true
}
op {
@@ -18199,6 +18212,54 @@ op {
}
}
op {
+ name: "ParallelInterleaveDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "cycle_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "block_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ParallelMapDataset"
input_arg {
name: "input_dataset"
@@ -18237,6 +18298,13 @@ op {
has_minimum: true
minimum: 1
}
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
}
op {
name: "ParameterizedTruncatedNormal"
@@ -29617,6 +29685,19 @@ op {
s: ""
}
}
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
is_stateful: true
}
op {
@@ -32115,6 +32196,21 @@ op {
}
}
op {
+ name: "StaticRegexFullMatch"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_BOOL
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+}
+op {
name: "StaticRegexReplace"
input_arg {
name: "input"
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 7aa1e71809..ef8b15dc8a 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -56,6 +56,12 @@ REGISTER_OP("RegexFullMatch")
return Status::OK();
});
+REGISTER_OP("StaticRegexFullMatch")
+ .Input("input: string")
+ .Attr("pattern: string")
+ .Output("output: bool")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
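Unlike `RegexFullMatch`, the static variant takes `pattern` as an attr rather than an input, so the regular expression can be compiled once at graph-construction time; the trade-off is that the pattern must be a Python-level constant. A sketch via the generated wrapper (module path assumed):

```python
import tensorflow as tf
from tensorflow.python.ops import gen_string_ops  # generated wrapper, assumed

strings = tf.constant(["tensorflow", "torch"])
matched = gen_string_ops.static_regex_full_match(strings, pattern="tensor.*")

with tf.Session() as sess:
    print(sess.run(matched))  # [ True False]
```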
REGISTER_OP("StringToHashBucketFast")
.Input("input: string")
.Output("output: int64")
diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc
index ccddf1eafc..0389149469 100644
--- a/tensorflow/core/platform/default/device_tracer.cc
+++ b/tensorflow/core/platform/default/device_tracer.cc
@@ -321,6 +321,11 @@ class DeviceTracerImpl : public DeviceTracer,
return nullptr;
}
+ bool IsEnabled(bool is_expensive) const override {
+ // We don't do anything with 'Activities', so we are never 'enabled'.
+ return false;
+ }
+
protected:
// This callback is used exclusively by CUPTIManager.
friend class CUPTIManager;
diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h
index e5851f1dfe..9974bbbb4e 100644
--- a/tensorflow/core/platform/tracing.h
+++ b/tensorflow/core/platform/tracing.h
@@ -155,6 +155,10 @@ class TraceCollector {
StringPiece name_part1, StringPiece name_part2,
bool is_expensive) const = 0;
+ // Returns true if activity handle tracking is enabled for an op of the
+ // given expensiveness.
+ virtual bool IsEnabled(bool is_expensive) const = 0;
+
protected:
static string ConcatenateNames(StringPiece first, StringPiece second);
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 5ebd409b15..e755c37039 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3401,56 +3401,59 @@ func BoostedTreesCenterBias(scope *Scope, tree_ensemble_handle tf.Output, mean_g
return op.Output(0)
}
-// Computes the mean along sparse segments of a tensor.
-//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Runs multiple additive regression ensemble predictors on input instances and
//
-// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
-// dimension, selecting a subset of dimension 0, specified by `indices`.
+// computes the update to cached logits. It is designed to be used during training.
+// It traverses the trees starting from cached tree id and cached node id and
+// calculates the updates to be pushed to the cache.
//
// Arguments:
//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting
+// tree of prediction.
+// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting
+// node of prediction.
+// bucketized_features: A list of rank 1 Tensors containing bucket id for each
+// feature.
+// logits_dimension: scalar, dimension of the logits, to be used for partial logits
+// shape.
//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `k`, the number of segments.
-func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+// Returns Rank 2 Tensor containing logits update (with respect to cached
+// values stored) for each example. Rank 1 Tensor containing new tree ids for each example. Rank 1 Tensor containing new node ids in the new tree_ids.
+func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"logits_dimension": logits_dimension}
opspec := tf.OpSpec{
- Type: "SparseSegmentMean",
+ Type: "BoostedTreesTrainingPredict",
Input: []tf.Input{
- data, indices, segment_ids,
+ tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features),
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ return op.Output(0), op.Output(1), op.Output(2)
}
-// Pop the element at the top of the stack.
+// Serializes the tree ensemble to a proto.
//
// Arguments:
-// handle: The handle to a stack.
-// elem_type: The type of the elem that is popped.
+// tree_ensemble_handle: Handle to the tree ensemble.
//
-// Returns The tensor that is popped from the top of the stack.
-func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) {
+// Returns Stamp token of the tree ensemble resource. Serialized proto of the ensemble.
+func BoostedTreesSerializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, tree_ensemble_serialized tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"elem_type": elem_type}
opspec := tf.OpSpec{
- Type: "StackPopV2",
+ Type: "BoostedTreesSerializeEnsemble",
Input: []tf.Input{
- handle,
+ tree_ensemble_handle,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ return op.Output(0), op.Output(1)
}
// Computes the sum along sparse segments of a tensor.
@@ -8159,47 +8162,6 @@ func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...
return op.Output(0)
}
-// RandomPoissonAttr is an optional argument to RandomPoisson.
-type RandomPoissonAttr func(optionalAttr)
-
-// RandomPoissonSeed sets the optional seed attribute to value.
-// If not specified, defaults to 0
-func RandomPoissonSeed(value int64) RandomPoissonAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomPoissonSeed2 sets the optional seed2 attribute to value.
-// If not specified, defaults to 0
-func RandomPoissonSeed2(value int64) RandomPoissonAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Use RandomPoissonV2 instead.
-//
-// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2
-func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RandomPoisson",
- Input: []tf.Input{
- shape, rate,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns the element-wise sum of a list of tensors.
//
// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
@@ -8348,6 +8310,377 @@ func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...Or
return op.Output(0)
}
+// Returns the truth value of (x > y) element-wise.
+//
+// *NOTE*: `Greater` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Greater",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp.
+type ResourceSparseApplyRMSPropAttr func(optionalAttr)
+
+// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var, ms, and mom tensors is protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' according to the RMSProp algorithm.
+//
+// Note that in dense implementation of this algorithm, ms and mom will
+// update even if the grad is zero, but in this sparse implementation, ms
+// and mom will not update in iterations during which the grad is zero.
+//
+// mean_square = decay * mean_square + (1-decay) * gradient ** 2
+// Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
+//
+// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
+// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
+// var <- var - mom
+//
+// Arguments:
+// var_: Should be from a Variable().
+// ms: Should be from a Variable().
+// mom: Should be from a Variable().
+// lr: Scaling factor. Must be a scalar.
+// rho: Decay rate. Must be a scalar.
+//
+// epsilon: Ridge term. Must be a scalar.
+// grad: The gradient.
+// indices: A vector of indices into the first dimension of var, ms and mom.
+//
+// Returns the created operation.
+func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceSparseApplyRMSProp",
+ Input: []tf.Input{
+ var_, ms, mom, lr, rho, momentum, epsilon, grad, indices,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox.
+type SampleDistortedBoundingBoxAttr func(optionalAttr)
+
+// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to non-zero, the random number
+// generator is seeded by the given `seed`. Otherwise, it is seeded by a random
+// seed.
+// If not specified, defaults to 0
+func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value.
+//
+// value: The cropped area of the image must contain at least this
+// fraction of any bounding box supplied. The value of this parameter should be
+// non-negative. In the case of 0, the cropped area does not need to overlap
+// any of the bounding boxes supplied.
+// If not specified, defaults to 0.1
+func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["min_object_covered"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value.
+//
+// value: The cropped area of the image must have an aspect ratio =
+// width / height within this range.
+// If not specified, defaults to <f:0.75 f:1.33 >
+func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["aspect_ratio_range"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value.
+//
+// value: The cropped area of the image must contain a fraction of the
+// supplied image within this range.
+// If not specified, defaults to <f:0.05 f:1 >
+func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["area_range"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value.
+//
+// value: Number of attempts at generating a cropped region of the image
+// that satisfies the specified constraints. After `max_attempts` failures, return the entire
+// image.
+// If not specified, defaults to 100
+func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["max_attempts"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value.
+//
+// value: Controls behavior if no bounding boxes supplied.
+// If true, assume an implicit bounding box covering the whole input. If false,
+// raise an error.
+// If not specified, defaults to false
+func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["use_image_if_no_bounding_boxes"] = value
+ }
+}
+
+// Generate a single randomly distorted bounding box for an image.
+//
+// Bounding box annotations are often supplied in addition to ground-truth labels
+// in image recognition or object localization tasks. A common technique for
+// training such a system is to randomly distort an image while preserving
+// its content, i.e. *data augmentation*. This Op outputs a randomly distorted
+// localization of an object, i.e. bounding box, given an `image_size`,
+// `bounding_boxes` and a series of constraints.
+//
+// The output of this Op is a single bounding box that may be used to crop the
+// original image. The output is returned as 3 tensors: `begin`, `size` and
+// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the
+// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize
+// what the bounding box looks like.
+//
+// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The
+// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
+// height of the underlying image.
+//
+// For example,
+//
+// ```python
+// # Generate a single distorted bounding box.
+// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
+// tf.shape(image),
+// bounding_boxes=bounding_boxes)
+//
+// # Draw the bounding box in an image summary.
+// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
+// bbox_for_draw)
+// tf.summary.image('images_with_box', image_with_box)
+//
+// # Employ the bounding box to distort the image.
+// distorted_image = tf.slice(image, begin, size)
+// ```
+//
+// Note that if no bounding box information is available, setting
+// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit
+// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is
+// false and no bounding boxes are supplied, an error is raised.
+//
+// Arguments:
+// image_size: 1-D, containing `[height, width, channels]`.
+// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes
+// associated with the image.
+//
+// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to
+// `tf.slice`. 1-D, containing `[target_height, target_width, -1]`. Provide as input to
+// `tf.slice`. 3-D with shape `[1, 1, 4]` containing the distorted bounding box.
+// Provide as input to `tf.image.draw_bounding_boxes`.
+func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "SampleDistortedBoundingBox",
+ Input: []tf.Input{
+ image_size, bounding_boxes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// Computes sigmoid of `x` element-wise.
+//
+// Specifically, `y = 1 / (1 + exp(-x))`.
+func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Sigmoid",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// FusedBatchNormAttr is an optional argument to FusedBatchNorm.
+type FusedBatchNormAttr func(optionalAttr)
+
+// FusedBatchNormEpsilon sets the optional epsilon attribute to value.
+//
+// value: A small float number added to the variance of x.
+// If not specified, defaults to 0.0001
+func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr {
+ return func(m optionalAttr) {
+ m["epsilon"] = value
+ }
+}
+
+// FusedBatchNormDataFormat sets the optional data_format attribute to value.
+//
+// value: The data format for x and y. Either "NHWC" (default) or "NCHW".
+// If not specified, defaults to "NHWC"
+func FusedBatchNormDataFormat(value string) FusedBatchNormAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// FusedBatchNormIsTraining sets the optional is_training attribute to value.
+//
+// value: A bool value to indicate the operation is for training (default)
+// or inference.
+// If not specified, defaults to true
+func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr {
+ return func(m optionalAttr) {
+ m["is_training"] = value
+ }
+}
+
+// Batch normalization.
+//
+// Note that the sizes of the 4D Tensors are defined by either "NHWC" or "NCHW".
+// The size of 1D Tensors matches the dimension C of the 4D Tensors.
+//
+// Arguments:
+// x: A 4D Tensor for input data.
+// scale: A 1D Tensor for scaling factor, to scale the normalized x.
+// offset: A 1D Tensor for offset, to shift to the normalized x.
+// mean: A 1D Tensor for population mean. Used for inference only;
+// must be empty for training.
+// variance: A 1D Tensor for population variance. Used for inference only;
+// must be empty for training.
+//
+// Returns A 4D Tensor for output data. A 1D Tensor for the computed batch mean, to be used by TensorFlow
+// to compute the running mean. A 1D Tensor for the computed batch variance, to be used by
+// TensorFlow to compute the running variance. A 1D Tensor for the computed batch mean, to be reused
+// in the gradient computation. A 1D Tensor for the computed batch variance (inverted variance
+// in the cuDNN case), to be reused in the gradient computation.
+func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "FusedBatchNorm",
+ Input: []tf.Input{
+ x, scale, offset, mean, variance,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
+}
+
+// RandomStandardNormalAttr is an optional argument to RandomStandardNormal.
+type RandomStandardNormalAttr func(optionalAttr)
+
+// RandomStandardNormalSeed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomStandardNormalSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Outputs random values from a normal distribution.
+//
+// The generated values will have mean 0 and standard deviation 1.
+//
+// Arguments:
+// shape: The shape of the output tensor.
+// dtype: The type of the output.
+//
+// Returns A tensor of the specified shape filled with random normal values.
+func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomStandardNormal",
+ Input: []tf.Input{
+ shape,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl.
type ResourceApplyFtrlAttr func(optionalAttr)
@@ -12357,235 +12690,6 @@ func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.
return values
}
-// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp.
-type ResourceSparseApplyRMSPropAttr func(optionalAttr)
-
-// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var, ms, and mom tensors is protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Update '*var' according to the RMSProp algorithm.
-//
-// Note that in dense implementation of this algorithm, ms and mom will
-// update even if the grad is zero, but in this sparse implementation, ms
-// and mom will not update in iterations during which the grad is zero.
-//
-// mean_square = decay * mean_square + (1-decay) * gradient ** 2
-// Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
-//
-// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
-// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
-// var <- var - mom
-//
-// Arguments:
-// var_: Should be from a Variable().
-// ms: Should be from a Variable().
-// mom: Should be from a Variable().
-// lr: Scaling factor. Must be a scalar.
-// rho: Decay rate. Must be a scalar.
-//
-// epsilon: Ridge term. Must be a scalar.
-// grad: The gradient.
-// indices: A vector of indices into the first dimension of var, ms and mom.
-//
-// Returns the created operation.
-func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceSparseApplyRMSProp",
- Input: []tf.Input{
- var_, ms, mom, lr, rho, momentum, epsilon, grad, indices,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// Returns the truth value of (x > y) element-wise.
-//
-// *NOTE*: `Greater` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Greater",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox.
-type SampleDistortedBoundingBoxAttr func(optionalAttr)
-
-// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to non-zero, the random number
-// generator is seeded by the given `seed`. Otherwise, it is seeded by a random
-// seed.
-// If not specified, defaults to 0
-func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value.
-//
-// value: The cropped area of the image must contain at least this
-// fraction of any bounding box supplied. The value of this parameter should be
-// non-negative. In the case of 0, the cropped area does not need to overlap
-// any of the bounding boxes supplied.
-// If not specified, defaults to 0.1
-func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["min_object_covered"] = value
- }
-}
-
-// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value.
-//
-// value: The cropped area of the image must have an aspect ratio =
-// width / height within this range.
-// If not specified, defaults to <f:0.75 f:1.33 >
-func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["aspect_ratio_range"] = value
- }
-}
-
-// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value.
-//
-// value: The cropped area of the image must contain a fraction of the
-// supplied image within this range.
-// If not specified, defaults to <f:0.05 f:1 >
-func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["area_range"] = value
- }
-}
-
-// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value.
-//
-// value: Number of attempts at generating a cropped region of the image
-// of the specified constraints. After `max_attempts` failures, return the entire
-// image.
-// If not specified, defaults to 100
-func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["max_attempts"] = value
- }
-}
-
-// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value.
-//
-// value: Controls behavior if no bounding boxes supplied.
-// If true, assume an implicit bounding box covering the whole input. If false,
-// raise an error.
-// If not specified, defaults to false
-func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["use_image_if_no_bounding_boxes"] = value
- }
-}
-
-// Generate a single randomly distorted bounding box for an image.
-//
-// Bounding box annotations are often supplied in addition to ground-truth labels
-// in image recognition or object localization tasks. A common technique for
-// training such a system is to randomly distort an image while preserving
-// its content, i.e. *data augmentation*. This Op outputs a randomly distorted
-// localization of an object, i.e. bounding box, given an `image_size`,
-// `bounding_boxes` and a series of constraints.
-//
-// The output of this Op is a single bounding box that may be used to crop the
-// original image. The output is returned as 3 tensors: `begin`, `size` and
-// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the
-// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize
-// what the bounding box looks like.
-//
-// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The
-// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
-// height of the underlying image.
-//
-// For example,
-//
-// ```python
-// # Generate a single distorted bounding box.
-// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
-// tf.shape(image),
-// bounding_boxes=bounding_boxes)
-//
-// # Draw the bounding box in an image summary.
-// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
-// bbox_for_draw)
-// tf.summary.image('images_with_box', image_with_box)
-//
-// # Employ the bounding box to distort the image.
-// distorted_image = tf.slice(image, begin, size)
-// ```
-//
-// Note that if no bounding box information is available, setting
-// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit
-// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is
-// false and no bounding boxes are supplied, an error is raised.
-//
-// Arguments:
-// image_size: 1-D, containing `[height, width, channels]`.
-// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes
-// associated with the image.
-//
-// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to
-// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to
-// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box.
-// Provide as input to `tf.image.draw_bounding_boxes`.
-func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "SampleDistortedBoundingBox",
- Input: []tf.Input{
- image_size, bounding_boxes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// LRNAttr is an optional argument to LRN.
type LRNAttr func(optionalAttr)
@@ -14396,6 +14500,47 @@ func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
+// RandomPoissonAttr is an optional argument to RandomPoisson.
+type RandomPoissonAttr func(optionalAttr)
+
+// RandomPoissonSeed sets the optional seed attribute to value.
+// If not specified, defaults to 0
+func RandomPoissonSeed(value int64) RandomPoissonAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomPoissonSeed2 sets the optional seed2 attribute to value.
+// If not specified, defaults to 0
+func RandomPoissonSeed2(value int64) RandomPoissonAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Use RandomPoissonV2 instead.
+//
+// DEPRECATED at GraphDef version 25: Replaced by RandomPoissonV2
+func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomPoisson",
+ Input: []tf.Input{
+ shape, rate,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler.
type LogUniformCandidateSamplerAttr func(optionalAttr)
@@ -16136,148 +16281,6 @@ func ResourceScatterMul(scope *Scope, resource tf.Output, indices tf.Output, upd
return scope.AddOperation(opspec)
}
-// Computes sigmoid of `x` element-wise.
-//
-// Specifically, `y = 1 / (1 + exp(-x))`.
-func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Sigmoid",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// FusedBatchNormAttr is an optional argument to FusedBatchNorm.
-type FusedBatchNormAttr func(optionalAttr)
-
-// FusedBatchNormEpsilon sets the optional epsilon attribute to value.
-//
-// value: A small float number added to the variance of x.
-// If not specified, defaults to 0.0001
-func FusedBatchNormEpsilon(value float32) FusedBatchNormAttr {
- return func(m optionalAttr) {
- m["epsilon"] = value
- }
-}
-
-// FusedBatchNormDataFormat sets the optional data_format attribute to value.
-//
-// value: The data format for x and y. Either "NHWC" (default) or "NCHW".
-// If not specified, defaults to "NHWC"
-func FusedBatchNormDataFormat(value string) FusedBatchNormAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// FusedBatchNormIsTraining sets the optional is_training attribute to value.
-//
-// value: A bool value to indicate the operation is for training (default)
-// or inference.
-// If not specified, defaults to true
-func FusedBatchNormIsTraining(value bool) FusedBatchNormAttr {
- return func(m optionalAttr) {
- m["is_training"] = value
- }
-}
-
-// Batch normalization.
-//
-// Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW".
-// The size of 1D Tensors matches the dimension C of the 4D Tensors.
-//
-// Arguments:
-// x: A 4D Tensor for input data.
-// scale: A 1D Tensor for scaling factor, to scale the normalized x.
-// offset: A 1D Tensor for offset, to shift to the normalized x.
-// mean: A 1D Tensor for population mean. Used for inference only;
-// must be empty for training.
-// variance: A 1D Tensor for population variance. Used for inference only;
-// must be empty for training.
-//
-// Returns A 4D Tensor for output data.A 1D Tensor for the computed batch mean, to be used by TensorFlow
-// to compute the running mean.A 1D Tensor for the computed batch variance, to be used by
-// TensorFlow to compute the running variance.A 1D Tensor for the computed batch mean, to be reused
-// in the gradient computation.A 1D Tensor for the computed batch variance (inverted variance
-// in the cuDNN case), to be reused in the gradient computation.
-func FusedBatchNorm(scope *Scope, x tf.Output, scale tf.Output, offset tf.Output, mean tf.Output, variance tf.Output, optional ...FusedBatchNormAttr) (y tf.Output, batch_mean tf.Output, batch_variance tf.Output, reserve_space_1 tf.Output, reserve_space_2 tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "FusedBatchNorm",
- Input: []tf.Input{
- x, scale, offset, mean, variance,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
-}
-
-// RandomStandardNormalAttr is an optional argument to RandomStandardNormal.
-type RandomStandardNormalAttr func(optionalAttr)
-
-// RandomStandardNormalSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomStandardNormalSeed(value int64) RandomStandardNormalAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomStandardNormalSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomStandardNormalSeed2(value int64) RandomStandardNormalAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Outputs random values from a normal distribution.
-//
-// The generated values will have mean 0 and standard deviation 1.
-//
-// Arguments:
-// shape: The shape of the output tensor.
-// dtype: The type of the output.
-//
-// Returns A tensor of the specified shape filled with random normal values.
-func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...RandomStandardNormalAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RandomStandardNormal",
- Input: []tf.Input{
- shape,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Component-wise divides a SparseTensor by a dense Tensor.
//
// *Limitation*: this Op only broadcasts the dense side to the sparse side, but not
@@ -17427,26 +17430,6 @@ func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (i
return op.Output(0)
}
-// Serializes the tree ensemble to a proto.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-//
-// Returns Stamp token of the tree ensemble resource.Serialized proto of the ensemble.
-func BoostedTreesSerializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, tree_ensemble_serialized tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesSerializeEnsemble",
- Input: []tf.Input{
- tree_ensemble_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
// StageSizeAttr is an optional argument to StageSize.
type StageSizeAttr func(optionalAttr)
@@ -20376,6 +20359,58 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf
return op.Output(0)
}
+// Computes the mean along sparse segments of a tensor.
+//
+// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
+// segments.
+//
+// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
+// dimension, selecting a subset of dimension 0, specified by `indices`.
+//
+// Arguments:
+//
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `k`, the number of segments.
+func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentMean",
+ Input: []tf.Input{
+ data, indices, segment_ids,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Pop the element at the top of the stack.
+//
+// Arguments:
+// handle: The handle to a stack.
+// elem_type: The type of the elem that is popped.
+//
+// Returns The tensor that is popped from the top of the stack.
+func StackPopV2(scope *Scope, handle tf.Output, elem_type tf.DataType) (elem tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"elem_type": elem_type}
+ opspec := tf.OpSpec{
+ Type: "StackPopV2",
+ Input: []tf.Input{
+ handle,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes hyperbolic cosine of x element-wise.
func Cosh(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
@@ -31743,54 +31778,6 @@ func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true
return op.Output(0), op.Output(1), op.Output(2)
}
-// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
-type WholeFileReaderV2Attr func(optionalAttr)
-
-// WholeFileReaderV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this reader is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this reader is named in the given bucket
-// with this shared_name. Otherwise, the node name is used instead.
-// If not specified, defaults to ""
-func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// A Reader that outputs the entire contents of a file as a value.
-//
-// To use, enqueue filenames in a Queue. The output of ReaderRead will
-// be a filename (key) and the contents of that file (value).
-//
-// Returns The handle to reference the Reader.
-func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "WholeFileReaderV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Transforms a tf.Example proto (as a string) into typed tensors.
//
// Arguments:
@@ -31861,60 +31848,73 @@ func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.
return sparse_indices, sparse_values, sparse_shapes, dense_values
}
-// Deserializes a serialized tree ensemble config and replaces current tree
+// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
+type WholeFileReaderV2Attr func(optionalAttr)
+
+// WholeFileReaderV2Container sets the optional container attribute to value.
//
-// ensemble.
+// value: If non-empty, this reader is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-// stamp_token: Token to use as the new value of the resource stamp.
-// tree_ensemble_serialized: Serialized proto of the ensemble.
+// value: If non-empty, this reader is named in the given bucket
+// with this shared_name. Otherwise, the node name is used instead.
+// If not specified, defaults to ""
+func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// A Reader that outputs the entire contents of a file as a value.
//
-// Returns the created operation.
-func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
+// To use, enqueue filenames in a Queue. The output of ReaderRead will
+// be a filename (key) and the contents of that file (value).
+//
+// Returns The handle to reference the Reader.
+func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "BoostedTreesDeserializeEnsemble",
- Input: []tf.Input{
- tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
- },
+ Type: "WholeFileReaderV2",
+
+ Attrs: attrs,
}
- return scope.AddOperation(opspec)
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
-// Runs multiple additive regression ensemble predictors on input instances and
+// Deserializes a serialized tree ensemble config and replaces current tree
//
-// computes the update to cached logits. It is designed to be used during training.
-// It traverses the trees starting from cached tree id and cached node id and
-// calculates the updates to be pushed to the cache.
+// ensemble.
//
// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble.
+// stamp_token: Token to use as the new value of the resource stamp.
+// tree_ensemble_serialized: Serialized proto of the ensemble.
//
-// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting
-// tree of prediction.
-// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting
-// node of prediction.
-// bucketized_features: A list of rank 1 Tensors containing bucket id for each
-// feature.
-// logits_dimension: scalar, dimension of the logits, to be used for partial logits
-// shape.
-//
-// Returns Rank 2 Tensor containing logits update (with respect to cached
-// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids.
-func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) {
+// Returns the created operation.
+func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"logits_dimension": logits_dimension}
opspec := tf.OpSpec{
- Type: "BoostedTreesTrainingPredict",
+ Type: "BoostedTreesDeserializeEnsemble",
Input: []tf.Input{
- tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features),
+ tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
},
- Attrs: attrs,
}
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
+ return scope.AddOperation(opspec)
}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index ba9c6a2320..19729813a1 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -78,6 +78,7 @@ py_library(
"//tensorflow:__pkg__",
"//tensorflow/python/tools:__pkg__",
"//tensorflow/python/tools/api/generator:__pkg__",
+ "//tensorflow/tools/api/tests:__pkg__",
],
deps = [
":array_ops",
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 586f4c6936..7a3fc27592 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 5)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 7)
@tf_export("compat.forward_compatible")
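
Bumping the horizon flips `forward_compatible` for callers gating newly added graph constructs:

```python
import tensorflow as tf

# With the horizon at 2018-09-07 this is True; before the bump
# (horizon 2018-09-05) it was False.
if tf.compat.forward_compatible(2018, 9, 6):
    pass  # safe to emit ops expected to exist from 2018-09-06 on
```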
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 23c98247bf..631b87a718 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -137,6 +137,8 @@ tf_py_test(
size = "small",
srcs = ["interleave_dataset_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -154,6 +156,7 @@ tf_py_test(
size = "small",
srcs = ["map_dataset_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
index 7dbf7268d7..a35cee594a 100644
--- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
@@ -19,8 +19,10 @@ from __future__ import print_function
import itertools
+from absl.testing import parameterized
+import numpy as np
+
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
@@ -28,7 +30,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class InterleaveDatasetTest(test.TestCase):
+class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
def _interleave(self, lists, cycle_length, block_length):
num_open = 0
@@ -97,84 +99,85 @@ class InterleaveDatasetTest(test.TestCase):
expected_elements, self._interleave(input_lists, 7, 2)):
self.assertEqual(expected, produced)
- def testInterleaveDataset(self):
- input_values = array_ops.placeholder(dtypes.int64, shape=[None])
- cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
- block_length = array_ops.placeholder(dtypes.int64, shape=[])
-
- repeat_count = 2
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(input_values)
- .repeat(repeat_count)
- .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
- cycle_length, block_length))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- next_element = iterator.get_next()
+ @parameterized.named_parameters(
+ ("1", np.int64([4, 5, 6]), 1, 3, None),
+ ("2", np.int64([4, 5, 6]), 1, 3, 1),
+ ("3", np.int64([4, 5, 6]), 2, 1, None),
+ ("4", np.int64([4, 5, 6]), 2, 1, 1),
+ ("5", np.int64([4, 5, 6]), 2, 1, 2),
+ ("6", np.int64([4, 5, 6]), 2, 3, None),
+ ("7", np.int64([4, 5, 6]), 2, 3, 1),
+ ("8", np.int64([4, 5, 6]), 2, 3, 2),
+ ("9", np.int64([4, 5, 6]), 7, 2, None),
+ ("10", np.int64([4, 5, 6]), 7, 2, 1),
+ ("11", np.int64([4, 5, 6]), 7, 2, 3),
+ ("12", np.int64([4, 5, 6]), 7, 2, 5),
+ ("13", np.int64([4, 5, 6]), 7, 2, 7),
+ ("14", np.int64([]), 2, 3, None),
+ ("15", np.int64([0, 0, 0]), 2, 3, None),
+ ("16", np.int64([4, 0, 6]), 2, 3, None),
+ ("17", np.int64([4, 0, 6]), 2, 3, 1),
+ ("18", np.int64([4, 0, 6]), 2, 3, 2),
+ )
+ def testInterleaveDataset(self, input_values, cycle_length, block_length,
+ num_parallel_calls):
+ count = 2
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
+ count).interleave(
+ lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
+ cycle_length, block_length, num_parallel_calls)
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ def repeat(values, count):
+ result = []
+ for value in values:
+ result.append([value] * value)
+ return result * count
with self.test_session() as sess:
- # Cycle length 1 acts like `Dataset.flat_map()`.
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 1, block_length: 3})
-
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3):
- self.assertEqual(expected_element, sess.run(next_element))
-
- # Cycle length > 1.
- # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5,
- # 6, 5, 6, 5, 6, 5, 6, 5]
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 2, block_length: 1})
for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1):
- self.assertEqual(expected_element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Cycle length > 1 and block length > 1.
- # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5,
- # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6]
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 2, block_length: 3})
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3):
- self.assertEqual(expected_element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Cycle length > len(input_values) * repeat_count.
- # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4,
- # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6]
- sess.run(init_op, feed_dict={input_values: [4, 5, 6],
- cycle_length: 7, block_length: 2})
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2):
- self.assertEqual(expected_element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Empty input.
- sess.run(init_op, feed_dict={input_values: [],
- cycle_length: 2, block_length: 3})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ repeat(input_values, count), cycle_length, block_length):
+ self.assertEqual(expected_element, sess.run(get_next))
+
+ for _ in range(2):
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, None),
+ ("2", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, 1),
+ ("3", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, None),
+ ("4", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 1),
+ ("5", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 2),
+ ("6", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, None),
+ ("7", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 1),
+ ("8", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 2),
+ ("9", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, None),
+ ("10", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 1),
+ ("11", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 3),
+ ("12", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 5),
+ ("13", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 7),
+ )
+ def testInterleaveErrorDataset(self,
+ input_values,
+ cycle_length,
+ block_length,
+ num_parallel_calls):
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
+ lambda x: array_ops.check_numerics(x, "message")).interleave(
+ dataset_ops.Dataset.from_tensors, cycle_length, block_length,
+ num_parallel_calls)
+ get_next = dataset.make_one_shot_iterator().get_next()
- # Non-empty input leading to empty output.
- sess.run(init_op, feed_dict={input_values: [0, 0, 0],
- cycle_length: 2, block_length: 3})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- # Mixture of non-empty and empty interleaved datasets.
- sess.run(init_op, feed_dict={input_values: [4, 0, 6],
- cycle_length: 2, block_length: 3})
- for expected_element in self._interleave(
- [[4] * 4, [], [6] * 6] * repeat_count, 2, 3):
- self.assertEqual(expected_element, sess.run(next_element))
+ with self.test_session() as sess:
+ for value in input_values:
+ if np.isnan(value):
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+ else:
+ self.assertEqual(value, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ sess.run(get_next)
def testSparse(self):
@@ -201,20 +204,6 @@ class InterleaveDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testEmptyInput(self):
- iterator = (
- dataset_ops.Dataset.from_tensor_slices([])
- .repeat(None)
- .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index df2c9b170a..fde785be6e 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -22,6 +22,7 @@ import threading
import time
import warnings
+from absl.testing import parameterized
import numpy as np
from tensorflow.core.framework import attr_value_pb2
@@ -46,7 +47,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
-class MapDatasetTest(test.TestCase):
+class MapDatasetTest(test.TestCase, parameterized.TestCase):
def _buildMapDataset(self, components, count):
def _map_fn(x, y, z):
@@ -705,6 +706,35 @@ class MapDatasetTest(test.TestCase):
with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
sess.run(iterator.initializer)
+# pylint: disable=g-long-lambda
+ @parameterized.named_parameters(
+ ("Map", lambda dataset, func:
+ dataset_ops.MapDataset(dataset, func, use_inter_op_parallelism=False)),
+ ("ParallelMap", lambda dataset, func:
+ dataset_ops.ParallelMapDataset(dataset, func, num_parallel_calls=1,
+ use_inter_op_parallelism=False)),
+ )
+ def testNoInterOpParallelism(self, make_dataset_fn):
+ dataset = dataset_ops.Dataset.from_tensors(0)
+
+ def _get_tid():
+ return np.int64(threading.current_thread().ident)
+
+ def _map_fn(_):
+ tids = []
+ for _ in range(10):
+ tids.append(script_ops.py_func(_get_tid, [], dtypes.int64))
+ return tids
+
+ dataset = make_dataset_fn(dataset, _map_fn)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ tids = sess.run(get_next)
+ self.assertTrue(all(tids[0] == tid for tid in tids))
+# pylint: enable=g-long-lambda
+
class MapDatasetBenchmark(test.Benchmark):
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 6205ee392e..c985e00dd1 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1019,7 +1019,11 @@ class Dataset(object):
"""
return FlatMapDataset(self, map_func)
- def interleave(self, map_func, cycle_length, block_length=1):
+ def interleave(self,
+ map_func,
+ cycle_length,
+ block_length=1,
+ num_parallel_calls=None):
"""Maps `map_func` across this dataset, and interleaves the results.
For example, you can use `Dataset.interleave()` to process many input files
@@ -1082,11 +1086,19 @@ class Dataset(object):
processed concurrently.
block_length: The number of consecutive elements to produce from each
input element before cycling to another input element.
+ num_parallel_calls: (Optional.) If specified, the implementation creates
+ a threadpool, which is used to fetch inputs from cycle elements
+ asynchronously and in parallel. The default behavior is to fetch inputs
+ from cycle elements synchronously with no parallelism.
Returns:
Dataset: A `Dataset`.
"""
- return InterleaveDataset(self, map_func, cycle_length, block_length)
+ if num_parallel_calls is None:
+ return InterleaveDataset(self, map_func, cycle_length, block_length)
+ else:
+ return ParallelInterleaveDataset(self, map_func, cycle_length,
+ block_length, num_parallel_calls)
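For orientation, a minimal sketch of the new argument in use; the file names and the `tf.data.TextLineDataset` choice are illustrative assumptions, not part of this change:

```python
import tensorflow as tf

# Read lines from several files in an interleaved order. With
# num_parallel_calls set, inputs are fetched from the open cycle elements by a
# threadpool; leaving it as None keeps the old synchronous behavior.
filenames = tf.data.Dataset.from_tensor_slices(["a.txt", "b.txt", "c.txt"])
lines = filenames.interleave(
    tf.data.TextLineDataset,
    cycle_length=2,       # two files open at a time
    block_length=4,       # four consecutive lines from each before switching
    num_parallel_calls=2)
```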
def filter(self, predicate):
"""Filters this dataset according to `predicate`.
@@ -2245,9 +2257,14 @@ class MapDataset(Dataset):
class ParallelMapDataset(MapDataset):
"""A `Dataset` that maps a function over elements in its input in parallel."""
- def __init__(self, input_dataset, map_func, num_parallel_calls):
+ def __init__(self,
+ input_dataset,
+ map_func,
+ num_parallel_calls,
+ use_inter_op_parallelism=True):
"""See `Dataset.map()` for details."""
- super(ParallelMapDataset, self).__init__(input_dataset, map_func)
+ super(ParallelMapDataset, self).__init__(input_dataset, map_func,
+ use_inter_op_parallelism)
self._num_parallel_calls = ops.convert_to_tensor(
num_parallel_calls, dtype=dtypes.int32, name="num_parallel_calls")
@@ -2260,6 +2277,7 @@ class ParallelMapDataset(MapDataset):
self._map_func.captured_inputs,
f=self._map_func,
num_parallel_calls=self._num_parallel_calls,
+ use_inter_op_parallelism=self._use_inter_op_parallelism,
**flat_structure(self))
# pylint: enable=protected-access
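A rough sketch of what the new flag means, constructing the internal class directly as the map test above does (internal API; the lambda is an illustrative placeholder):

```python
from tensorflow.python.data.ops import dataset_ops

dataset = dataset_ops.Dataset.from_tensors(0)
# With use_inter_op_parallelism=False the map function runs on the calling
# thread instead of being handed to the inter-op threadpool; the constructor
# simply forwards the flag to the underlying dataset op.
dataset = dataset_ops.ParallelMapDataset(
    dataset, lambda x: x + 1, num_parallel_calls=1,
    use_inter_op_parallelism=False)
```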
@@ -2330,6 +2348,36 @@ class InterleaveDataset(FlatMapDataset):
return "Dataset.interleave()"
+class ParallelInterleaveDataset(FlatMapDataset):
+  """A `Dataset` that maps a function over its input and interleaves the result."""
+
+ def __init__(self, input_dataset, map_func, cycle_length, block_length,
+ num_parallel_calls):
+ """See `Dataset.interleave()` for details."""
+ super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func)
+ self._cycle_length = ops.convert_to_tensor(
+ cycle_length, dtype=dtypes.int64, name="cycle_length")
+ self._block_length = ops.convert_to_tensor(
+ block_length, dtype=dtypes.int64, name="block_length")
+ self._num_parallel_calls = ops.convert_to_tensor(
+ num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.parallel_interleave_dataset_v2(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._map_func.captured_inputs, # pylint: disable=protected-access
+ self._cycle_length,
+ self._block_length,
+ self._num_parallel_calls,
+ f=self._map_func, # pylint: disable=protected-access
+ **flat_structure(self))
+
+ def _transformation_name(self):
+ return "Dataset.interleave()"
+
+
class FilterDataset(Dataset):
"""A `Dataset` that filters its input according to a predicate function."""
diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py
index 9d621fcd30..3a5d1f0adf 100644
--- a/tensorflow/python/data/util/nest.py
+++ b/tensorflow/python/data/util/nest.py
@@ -96,37 +96,12 @@ def _yield_value(iterable):
yield value
-def is_sequence(seq):
- """Returns a true if `seq` is a Sequence or dict (except strings/lists).
+# See the swig file (../../util/util.i) for documentation.
+is_sequence = _pywrap_tensorflow.IsSequenceForData
- NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`,
- which *does* treat a Python list as a sequence. For ergonomic
- reasons, `tf.data` users would prefer to treat lists as
- implicit `tf.Tensor` objects, and dicts as (nested) sequences.
- Args:
- seq: an input sequence.
-
- Returns:
- True if the sequence is a not a string or list and is a
- collections.Sequence.
- """
- return _pywrap_tensorflow.IsSequenceForData(seq)
-
-
-def flatten(nest):
- """Returns a flat sequence from a given nested structure.
-
- If `nest` is not a sequence, this returns a single-element list: `[nest]`.
-
- Args:
- nest: an arbitrarily nested structure or a scalar object.
- Note, numpy arrays are considered scalars.
-
- Returns:
- A Python list, the flattened version of the input.
- """
- return _pywrap_tensorflow.FlattenForData(nest)
+# See the swig file (../../util/util.i) for documentation.
+flatten = _pywrap_tensorflow.FlattenForData
def assert_same_structure(nest1, nest2, check_types=True):
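The swig-backed replacements keep the `tf.data`-specific semantics that the deleted docstrings spelled out: lists are treated as implicit `tf.Tensor` objects, while dicts (and tuples) are nested structures. A small sketch of the expected behavior, under that reading:

```python
from tensorflow.python.data.util import nest

nest.is_sequence([1, 2, 3])          # False: a list is an implicit tensor
nest.is_sequence({"a": 1})           # True: dicts are (nested) sequences
nest.flatten({"a": 1, "b": (2, 3)})  # [1, 2, 3]: only dicts/tuples flatten
```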
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 9891068056..be392c7a0f 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -216,9 +216,7 @@ def implicit_val_and_grad(f):
"function was being computed.")
sources = [v.handle for v in variables]
- grad = imperative_grad.imperative_grad(_default_vspace,
- this_tape,
- nest.flatten(end_node),
+ grad = imperative_grad.imperative_grad(this_tape, nest.flatten(end_node),
sources)
return end_node, list(zip(grad, variables))
@@ -537,8 +535,8 @@ def make_vjp(f, params=None, persistent=True):
if dy is not None:
dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)]
return imperative_grad.imperative_grad(
- _default_vspace, this_tape, nest.flatten(result), sources,
- output_gradients=dy)
+ this_tape, nest.flatten(result), sources, output_gradients=dy)
+
return result, vjp
return decorated
@@ -631,9 +629,9 @@ def _ones(shape, dtype):
_default_vspace = imperative_grad.VSpace(
num_elements_fn=_num_elements,
aggregate_fn=_aggregate_grads,
- tensor_id=ops.tensor_id,
zeros=_zeros,
ones=_ones)
+pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
def _handle_or_self(x):
@@ -695,19 +693,57 @@ class GradientTape(object):
del g # Drop the reference to the tape
```
+ By default GradientTape will automatically watch any trainable variables that
+  are accessed inside the context. If you want fine-grained control over which
+  variables are watched, you can disable automatic tracking by passing
+ `watch_accessed_variables=False` to the tape constructor:
+
+ ```python
+ with tf.GradientTape(watch_accessed_variables=False) as tape:
+ tape.watch(variable_a)
+ y = variable_a ** 2 # Gradients will be available for `variable_a`.
+    z = variable_b ** 3  # No gradients will be available since `variable_b` is
+ # not being watched.
+ ```
+
+  Note that when using models with `watch_accessed_variables=False` you should
+  ensure that your variables exist by the time you watch them. Otherwise it is
+  quite easy to make your first iteration not have any gradients:
+
+ ```python
+ a = tf.keras.layers.Dense(32)
+ b = tf.keras.layers.Dense(32)
+
+ with tf.GradientTape(watch_accessed_variables=False) as tape:
+ tape.watch(a.variables) # Since `a.build` has not been called at this point
+ # `a.variables` will return an empty list and the
+ # tape will not be watching anything.
+ result = b(a(inputs))
+ tape.gradient(result, a.variables) # The result of this computation will be
+ # a list of `None`s since a's variables
+ # are not being watched.
+ ```
+
Note that only tensors with real or complex dtypes are differentiable.
"""
- def __init__(self, persistent=False):
+ def __init__(self, persistent=False, watch_accessed_variables=True):
"""Creates a new GradientTape.
Args:
persistent: Boolean controlling whether a persistent gradient tape
is created. False by default, which means at most one call can
be made to the gradient() method on this object.
+ watch_accessed_variables: Boolean controlling whether the tape will
+ automatically `watch` any (trainable) variables accessed while the tape
+        is active. Defaults to True, meaning gradients can be requested from any
+        result computed in the tape that derives from reading a trainable
+        `Variable`. If False, users must explicitly `watch` any `Variable`s they
+        want to request gradients from.
"""
self._tape = None
self._persistent = persistent
+ self._watch_accessed_variables = watch_accessed_variables
self._recording = False
context.context().start_step()
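One way to sidestep the empty-`a.variables` pitfall described in the class docstring is to create the variables before the tape needs to watch them, for example by building the layer up front (a sketch; `input_dim` and `inputs` are hypothetical):

```python
a = tf.keras.layers.Dense(32)
b = tf.keras.layers.Dense(32)
a.build((None, input_dim))  # variables now exist, so they can be watched

with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(a.variables)   # watched before the variable reads are recorded
  result = b(a(inputs))
tape.gradient(result, a.variables)  # real gradients, not a list of `None`s
```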
@@ -721,15 +757,15 @@ class GradientTape(object):
if self._recording:
self._pop_tape()
- def _push_tape(self, existing_tape=False):
+ def _push_tape(self):
if self._recording:
raise ValueError("Tape is already recording.")
- if existing_tape:
- if self._tape is None:
- raise ValueError("There is no existing tape.")
- tape.push_tape(self._tape)
+ if self._tape is None:
+ self._tape = tape.push_new_tape(
+ persistent=self._persistent,
+ watch_accessed_variables=self._watch_accessed_variables)
else:
- self._tape = tape.push_new_tape(persistent=self._persistent)
+ tape.push_tape(self._tape)
self._recording = True
def _pop_tape(self):
@@ -748,7 +784,13 @@ class GradientTape(object):
tensor: a Tensor or list of Tensors.
"""
for t in nest.flatten(tensor):
- tape.watch(self._tape, _handle_or_self(t))
+ if hasattr(t, "handle"):
+        # There are many variable-like objects, and all of them currently have a
+        # `handle` attribute that points to a tensor. If this changes, the
+        # internals of watch_variable need to change as well.
+ tape.watch_variable(self._tape, t)
+ else:
+ tape.watch(self._tape, t)
@tf_contextlib.contextmanager
def stop_recording(self):
@@ -780,7 +822,7 @@ class GradientTape(object):
try:
yield
finally:
- self._push_tape(existing_tape=True)
+ self._push_tape()
def reset(self):
"""Clears all information stored in this tape.
@@ -814,6 +856,7 @@ class GradientTape(object):
```
"""
self._pop_tape()
+ self._tape = None
self._push_tape()
def watched_variables(self):
@@ -865,7 +908,9 @@ class GradientTape(object):
for x in nest.flatten(output_gradients)]
flat_grad = imperative_grad.imperative_grad(
- _default_vspace, self._tape, nest.flatten(target), flat_sources,
+ self._tape,
+ nest.flatten(target),
+ flat_sources,
output_gradients=output_gradients)
if not self._persistent:
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 6673178ee7..f938ed5df8 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -474,6 +474,18 @@ class BackpropTest(test.TestCase):
self.assertEqual(backprop.implicit_grad(f)()[0][0], None)
@test_util.assert_no_new_tensors
+ def testGradientTapeReEnterContext(self):
+ g = backprop.GradientTape()
+ with g:
+ x = constant_op.constant(3.0)
+ g.watch(x)
+ y = 2*x
+ with g:
+ z = 2*y
+ grad = g.gradient(target=z, sources=[x])
+ self.assertEqual(self.evaluate(grad), [4.0])
+
+ @test_util.assert_no_new_tensors
@test_util.run_in_graph_and_eager_modes
def testGradientTapeRepeatedSource(self):
with backprop.GradientTape(persistent=False) as g:
@@ -956,6 +968,60 @@ class BackpropTest(test.TestCase):
self.assertAllEqual(grad1, grad2)
+ @test_util.run_in_graph_and_eager_modes
+ def testSelectivelyWatchVariables(self):
+ x1 = resource_variable_ops.ResourceVariable(1.0)
+ x2 = resource_variable_ops.ResourceVariable(1.0)
+ with backprop.GradientTape(watch_accessed_variables=False) as tape:
+ tape.watch(x2)
+ y = x1**2
+ z = x2**3
+ self.assertTupleEqual(tape.watched_variables(), (x2,))
+ dy, dz = tape.gradient([y, z], [x1, x2])
+ self.evaluate([x1.initializer, x2.initializer])
+ self.assertIsNone(dy)
+ self.assertEqual(self.evaluate(dz), 3.0)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDifferentiatingScalarCache(self):
+    # In the following test, if x2 = x1 (i.e. the objects are exactly the
+    # same), then y is essentially 2*x1, and dy/dx1 = 2.
+ # When we had a pure scalar cache in eager, this would be the case. This
+ # test prevents us from going back to that case.
+ with backprop.GradientTape(persistent=False) as g:
+ x1 = constant_op.constant(3.0)
+ x2 = constant_op.constant(3.0)
+ g.watch(x1)
+ g.watch(x2)
+ y = x1 + x2
+ grad = g.gradient(target=y, sources=[x1])
+ self.assertEqual(self.evaluate(grad), [1.0])
+
+ def testVariablesAndConstantsProduceTheSameGradients(self):
+
+ # In the following test, differentiating [y, z] against [a, b] gives:
+ # (dy/da + dz/da, dy/db + dz/db).
+ # If a and b are the same constant, dz/da will not be 0 (which it should
+ # be).
+    # This is solved by using variables, since doing a read_value on a
+    # variable produces a new tensor and a corresponding TensorHandle rather
+    # than reusing the same tensor (which would happen if we were using a
+    # cache and reusing EagerTensor objects).
+ def get_grads(a, b):
+ with backprop.GradientTape() as tape:
+ tape.watch([a, b])
+ y = a**3
+ z = b**2
+ return tape.gradient([y, z], [a, b])
+
+ gradients_constants = get_grads(
+ constant_op.constant(2.0), constant_op.constant(2.0))
+ gradients_variables = get_grads(
+ resource_variable_ops.ResourceVariable(2.0),
+ resource_variable_ops.ResourceVariable(2.0))
+ self.assertAllEqual(gradients_constants, gradients_variables)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index a2e8422671..3bdaf0b214 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -175,6 +175,11 @@ class MicroBenchmarks(test.Benchmark):
self._run(func, 30000)
+ def benchmark_create_constant(self):
+ func = lambda: constant_op.constant(3.0)
+
+ self._run(func, 30000)
+
def benchmark_create_float_tensor_from_list_CPU(self):
self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, CPU)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index d56c1457e0..03f12139f6 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -519,7 +519,7 @@ class Function(object):
for v in self._func_graph.variables:
if v.trainable:
- tape.watch_variable(v)
+ tape.variable_accessed(v)
captures = self._resolve_captured_inputs()
tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 3c79099d87..37a9957cea 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -27,7 +27,6 @@ from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
-from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -616,7 +615,6 @@ class FunctionTest(test.TestCase):
@function.defun
def g(x):
- tape.watch_variable(x)
y = math_ops.add(x, three)
f(y)
@@ -630,7 +628,6 @@ class FunctionTest(test.TestCase):
return math_ops.add(x, three)
def g(x):
- tape.watch_variable(three)
return f(x)
g = backprop.implicit_grad(g)(constant_op.constant(1.0))[0][0]
@@ -1427,14 +1424,14 @@ class FunctionTest(test.TestCase):
grad_t, = backprop.gradients_function(sq, [0])(t)
self.assertAllEqual(grad_t, [[6, 6], [14, 14]])
- with backprop.GradientTape(persistent=True) as gtape:
- gtape.watch(t)
+ with backprop.GradientTape(persistent=True) as tape:
+ tape.watch(t)
one = matmul(t, b=t, transpose_a=True)
two = matmul(b=t, a=t, transpose_a=True)
three = matmul(a=t, b=t, transpose_a=True)
for output in [one, two, three]:
- self.assertAllEqual(gtape.gradient(output, t), [[6, 6], [14, 14]])
+ self.assertAllEqual(tape.gradient(output, t), [[6, 6], [14, 14]])
def testGradientInFunctionWithKeywordArguments(self):
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 000152855d..5f027d107c 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -24,12 +24,10 @@ from tensorflow.python import pywrap_tensorflow
VSpace = collections.namedtuple(
- "VSpace",
- ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones"])
+ "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])
def imperative_grad(
- vspace,
tape,
target,
sources,
@@ -41,7 +39,6 @@ def imperative_grad(
gradients for all sources.
Args:
- vspace: the vector space in which to differentiate.
tape: the gradient tape which stores the trace.
target: either a Tensor or list of Tensors to be differentiated.
sources: list of Tensors for which we want gradients
@@ -60,4 +57,7 @@ def imperative_grad(
computation of target.
"""
return pywrap_tensorflow.TFE_Py_TapeGradient(
- tape._tape, vspace, target, sources, output_gradients) # pylint: disable=protected-access
+ tape._tape, # pylint: disable=protected-access
+ target,
+ sources,
+ output_gradients)
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 86fbd24d68..f34ce6af79 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
+#include "structmember.h" // NOLINT // For PyMemberDef
+
// forward declare
struct EagerTensor;
@@ -325,12 +327,36 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
PyObject* context = nullptr;
PyObject* device = nullptr;
PyObject* dtype = Py_None;
- const char* kwlist[] = {"value", "context", "device", "dtype", nullptr};
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|O",
+ PyObject* other_value = nullptr;
+ const char* kwlist[] = {"value", "context", "device",
+ "dtype", "other_value", nullptr};
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OO",
const_cast<char**>(kwlist), &value, &context,
- &device, &dtype)) {
+ &device, &dtype, &other_value)) {
return -1;
}
+
+ if (other_value != nullptr) {
+ if (!EagerTensor_CheckExact(other_value)) {
+ PyErr_SetString(PyExc_TypeError,
+ tensorflow::strings::StrCat(
+ "Expecting an EagerTensor for other_value, got ",
+ Py_TYPE(other_value)->tp_name)
+ .c_str());
+
+ return -1;
+ }
+ EagerTensor* other = reinterpret_cast<EagerTensor*>(other_value);
+ self->handle =
+ TFE_TensorHandleCopySharingTensor(other->handle, self->status);
+
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ return -1;
+ }
+
+ return 0;
+ }
+
// Extract dtype
int desired_dtype = -1;
if (dtype != Py_None) {
@@ -619,6 +645,15 @@ static PyGetSetDef EagerTensor_getseters[] = {
{nullptr} /* Sentinel */
};
+#if PY_MAJOR_VERSION < 3
+// Only used for Python 2, since Python 3 seems to set the __dict__ correctly.
+static PyMemberDef EagerTensor_members[] = {
+ {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
+ READONLY},
+ {nullptr},
+};
+#endif
+
static PyMethodDef EagerTensor_methods[] = {
{"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
PyDoc_STR("_numpy")},
@@ -693,7 +728,7 @@ static PyTypeObject _EagerTensorType = {
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
EagerTensor_methods, /* tp_methods */
- nullptr, /* tp_members */
+ EagerTensor_members, /* tp_members */
EagerTensor_getseters, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
@@ -829,7 +864,7 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
}
EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
#else
- _EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class);
+ _EagerTensorType.tp_base = base_class_type;
if (PyType_Ready(&_EagerTensorType) < 0) {
if (PyErr_Occurred()) return nullptr;
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 16f8c3c917..f1b4042ec9 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -59,6 +59,10 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e);
// This function is not thread-safe.
PyObject* TFE_Py_RegisterResourceVariableType(PyObject* e);
+// Registers e as the VSpace to use.
+// `vspace` must be an imperative_grad.py:VSpace named tuple.
+PyObject* TFE_Py_RegisterVSpace(PyObject* e);
+
// Registers e as the Exception to be raised when the conditions of
// TFE_Py_FastPathExecute_C have not been met. When this exception is set, it
// is a signal to the calling code that it should fall back to the safer (and
@@ -124,9 +128,10 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class);
// To unset the profiler, pass Py_None as the value of `profiler`.
PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler);
-// Creates a new tape and adds it to the active set. `persistent` must be a
-// PyBool_Type, i.e either Py_True or Py_False
-PyObject* TFE_Py_TapeSetNew(PyObject* persistent);
+// Creates a new tape and adds it to the active set. `persistent` and
+// `watch_accessed_variables` must be `PyBool_Type` (`Py_True` or `Py_False`).
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
+ PyObject* watch_accessed_variables);
// Removes the passed tape from the set of active tapes.
void TFE_Py_TapeSetRemove(PyObject* tape);
@@ -158,18 +163,20 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
PyObject* input_tensor_ids,
PyObject* backward_function);
+// Notifies all tapes that a variable has been accessed.
+void TFE_Py_TapeVariableAccessed(PyObject* variable);
+
// Watches the given variable object on the given tape.
-void TFE_Py_TapeSetWatchVariable(PyObject* variable);
+void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable);
// Computes a gradient based on information recorded on the tape.`tape` must
-// have been produced by TFE_Py_NewTape. `vspace` must be a
-// imperative_grad.py:VSpace named tuple. `target` and `sources` must be python
+// have been produced by TFE_Py_NewTape. `target` and `sources` must be python
// lists of Tensor objects. `output_gradients` is either None or a python list
// of either Tensor or None, and if not None should have the same length as
// target.
-PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
- PyObject* target, PyObject* sources,
- PyObject* output_gradients, TF_Status* status);
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
+ PyObject* sources, PyObject* output_gradients,
+ TF_Status* status);
// Execute a tensorflow operation assuming that all provided inputs are
// correctly formatted (i.e. EagerTensors). If it doesn't find EagerTensors,
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 0a33a04dcb..1ed814258b 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -892,9 +892,10 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
class GradientTape
: public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
public:
- explicit GradientTape(bool persistent)
+ explicit GradientTape(bool persistent, bool watch_accessed_variables)
: tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
- persistent) {}
+ persistent),
+ watch_accessed_variables_(watch_accessed_variables) {}
virtual ~GradientTape() {
for (const IdAndVariable& v : watched_variables_) {
@@ -902,6 +903,12 @@ class GradientTape
}
}
+ void VariableAccessed(PyObject* v) {
+ if (watch_accessed_variables_) {
+ WatchVariable(v);
+ }
+ }
+
void WatchVariable(PyObject* v) {
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
@@ -951,6 +958,7 @@ class GradientTape
}
};
+ bool watch_accessed_variables_;
tensorflow::mutex watched_variables_mu_;
std::set<IdAndVariable, CompareById> watched_variables_
GUARDED_BY(watched_variables_mu_);
@@ -1056,11 +1064,13 @@ void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
-PyObject* TFE_Py_TapeSetNew(PyObject* persistent) {
+PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
+ PyObject* watch_accessed_variables) {
TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
- tape->tape = new GradientTape(persistent == Py_True);
+ tape->tape = new GradientTape(persistent == Py_True,
+ watch_accessed_variables == Py_True);
Py_INCREF(tape);
GetTapeSet()->insert(reinterpret_cast<TFE_Py_Tape*>(tape));
return reinterpret_cast<PyObject*>(tape);
@@ -1233,13 +1243,20 @@ std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
return list;
}
-void TFE_Py_TapeSetWatchVariable(PyObject* variable) {
+void TFE_Py_TapeVariableAccessed(PyObject* variable) {
if (*ThreadTapeIsStopped()) {
return;
}
for (TFE_Py_Tape* tape : SafeTapeSet()) {
- tape->tape->WatchVariable(variable);
+ tape->tape->VariableAccessed(variable);
+ }
+}
+
+void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
+ if (*ThreadTapeIsStopped()) {
+ return;
}
+ reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
}
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
@@ -1348,7 +1365,9 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
class PyVSpace
: public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> {
public:
- explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {}
+ explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
+ Py_INCREF(py_vspace_);
+ }
tensorflow::Status Initialize() {
num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
@@ -1376,6 +1395,8 @@ class PyVSpace
Py_XDECREF(aggregate_fn_);
Py_XDECREF(zeros_);
Py_XDECREF(ones_);
+
+ Py_DECREF(py_vspace_);
}
tensorflow::int64 NumElements(PyObject* tensor) const final {
@@ -1491,6 +1512,22 @@ class PyVSpace
PyObject* zeros_;
PyObject* ones_;
};
+PyVSpace* py_vspace = nullptr;
+
+PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
+ if (py_vspace != nullptr) {
+ delete py_vspace;
+ }
+
+ py_vspace = new PyVSpace(e);
+ auto status = py_vspace->Initialize();
+ if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+    delete py_vspace;
+    py_vspace = nullptr;
+    return nullptr;
+ }
+
+ Py_RETURN_NONE;
+}
std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
@@ -1507,9 +1544,9 @@ std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
return list;
}
-PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
- PyObject* target, PyObject* sources,
- PyObject* output_gradients, TF_Status* status) {
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
+ PyObject* sources, PyObject* output_gradients,
+ TF_Status* status) {
TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
if (!tape_obj->tape->IsPersistent()) {
auto* tape_set = GetTapeSet();
@@ -1524,10 +1561,6 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
return nullptr;
}
}
- PyVSpace c_vspace(vspace);
- if (!c_vspace.Initialize().ok()) {
- return nullptr;
- }
std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
if (PyErr_Occurred()) {
@@ -1551,7 +1584,7 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
}
std::vector<PyObject*> result;
status->status = tape_obj->tape->ComputeGradient(
- c_vspace, target_vec, sources_vec, outgrad_vec, &result);
+ *py_vspace, target_vec, sources_vec, outgrad_vec, &result);
if (!status->status.ok()) {
if (PyErr_Occurred()) {
// Do not propagate the erroneous status as that would swallow the
@@ -1893,14 +1926,14 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
Py_RETURN_NONE;
}
-void MaybeWatchVariable(PyObject* input) {
+void MaybeNotifyVariableAccessed(PyObject* input) {
DCHECK(CheckResourceVariable(input));
DCHECK(PyObject_HasAttrString(input, "_trainable"));
tensorflow::Safe_PyObjectPtr trainable(
PyObject_GetAttrString(input, "_trainable"));
if (trainable.get() == Py_False) return;
- TFE_Py_TapeSetWatchVariable(input);
+ TFE_Py_TapeVariableAccessed(input);
}
bool CastTensor(const FastPathOpExecInfo& op_exec_info,
@@ -1931,7 +1964,7 @@ bool CastTensor(const FastPathOpExecInfo& op_exec_info,
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
PyObject* input, tensorflow::Safe_PyObjectPtr* output,
TF_Status* status) {
- MaybeWatchVariable(input);
+ MaybeNotifyVariableAccessed(input);
TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 6eb62afec4..399d90223c 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -33,9 +33,10 @@ class Tape(object):
return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape)
-def push_new_tape(persistent=False):
+def push_new_tape(persistent=False, watch_accessed_variables=True):
"""Pushes a new tape onto the tape stack."""
- tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent)
+ tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent,
+ watch_accessed_variables)
return Tape(tape)
@@ -49,13 +50,14 @@ def watch(tape, tensor):
pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access
-def watch_variable(variable):
- """Marks this variable to be watched by all tapes in the stack.
+def watch_variable(tape, variable):
+ """Marks this variable to be watched by the given tape."""
+ pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, variable) # pylint: disable=protected-access
- Args:
- variable: variable to be watched.
- """
- pywrap_tensorflow.TFE_Py_TapeSetWatchVariable(variable)
+
+def variable_accessed(variable):
+ """Notifies all tapes in the stack that a variable has been accessed."""
+ pywrap_tensorflow.TFE_Py_TapeVariableAccessed(variable)
def pop_tape(tape):
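After this split the two entry points differ in scope: `variable_accessed` broadcasts to every active tape, and only tapes created with `watch_accessed_variables=True` react, while `watch_variable` targets one tape unconditionally. A rough sketch of the distinction (internal API; `v` is a hypothetical resource variable):

```python
from tensorflow.python.eager import tape as tape_lib

t = tape_lib.push_new_tape(watch_accessed_variables=False)
tape_lib.variable_accessed(v)  # broadcast; `t` opted out, so it ignores this
tape_lib.watch_variable(t, v)  # explicit; `t` watches `v` regardless
tape_lib.pop_tape(t)
```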
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index 32742a9b96..344a9b25bd 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
def _create_tensor(value, device=None, dtype=None):
@@ -333,6 +334,19 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
"but tensor at index 2 has rank 0"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testTensorDir(self):
+ t = array_ops.zeros(1)
+ t.test_attr = "Test"
+
+ instance_dir = dir(t)
+ type_dir = dir(ops.EagerTensor)
+
+    # Monkey-patched attributes should show up in dir(t).
+ self.assertIn("test_attr", instance_dir)
+ instance_dir.remove("test_attr")
+ self.assertEqual(instance_dir, type_dir)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index cf8e18b216..00da335fef 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -687,6 +687,7 @@ py_test(
"manual", # b/112769036, b/113907597
"no_oss", # b/112769036, b/113907597
"no_windows",
+ "noasan", # b/114304340
"nomsan",
"notsan", # b/67510291
],
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index d104c961d3..19f18015e4 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -1000,8 +1000,11 @@ class BoostedTreesClassifier(estimator.Estimator):
bucketized_feature_2 = bucketized_column(
numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+ # Need to see a large portion of the data before we can build a layer, for
+  # example half of the data: n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE
classifier = estimator.BoostedTreesClassifier(
feature_columns=[bucketized_feature_1, bucketized_feature_2],
+ n_batches_per_layer=n_batches_per_layer,
n_trees=100,
... <some other params>
)
@@ -1024,7 +1027,8 @@ class BoostedTreesClassifier(estimator.Estimator):
the model. All items in the set should be instances of classes derived
from `FeatureColumn`.
n_batches_per_layer: the number of batches to collect statistics per
- layer.
+      layer. The total number of batches is the total number of examples
+      divided by the batch size.
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator
to continue training a previously saved model.
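As a concrete reading of the docstring above, with hypothetical numbers: 10,000 examples at a batch size of 100 give 100 batches in total, so requiring half the data per layer means:

```python
NUM_EXAMPLES = 10000
BATCH_SIZE = 100
total_batches = NUM_EXAMPLES // BATCH_SIZE    # 100 batches over the data
n_batches_per_layer = total_batches // 2      # 50: half the data per layer
```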
@@ -1138,8 +1142,11 @@ class BoostedTreesRegressor(estimator.Estimator):
bucketized_feature_2 = bucketized_column(
numeric_column('feature_2'), BUCKET_BOUNDARIES_2)
+ # Need to see a large portion of the data before we can build a layer, for
+  # example half of the data: n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE
regressor = estimator.BoostedTreesRegressor(
feature_columns=[bucketized_feature_1, bucketized_feature_2],
+ n_batches_per_layer=n_batches_per_layer,
n_trees=100,
... <some other params>
)
@@ -1162,7 +1169,8 @@ class BoostedTreesRegressor(estimator.Estimator):
the model. All items in the set should be instances of classes derived
from `FeatureColumn`.
n_batches_per_layer: the number of batches to collect statistics per
- layer.
+      layer. The total number of batches is the total number of examples
+      divided by the batch size.
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator
to continue training a previously saved model.
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 290c4604ce..7e5a0c80a7 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -26,20 +26,23 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import keras as keras_lib
+from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config as run_config_lib
-from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.optimizers import SGD
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import rmsprop
from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
try:
@@ -90,6 +93,15 @@ def simple_subclassed_model():
return SimpleModel()
+def gen_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=False):
+ def input_fn():
+ ds = dataset_ops.Dataset.from_tensor_slices((x, y) if y is not None else x)
+ if shuffle:
+ ds = ds.shuffle(1000)
+ return ds.repeat(num_epochs).batch(batch_size)
+ return input_fn
+
+
def get_resource_for_simple_model(model_type='sequential',
is_evaluate=False,):
if model_type == 'sequential':
@@ -117,19 +129,19 @@ def get_resource_for_simple_model(model_type='sequential',
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
- train_input_fn = numpy_io.numpy_input_fn(
+ train_input_fn = gen_input_fn(
x=randomize_io_type(x_train, input_name),
y=randomize_io_type(y_train, output_name),
shuffle=False,
num_epochs=None,
batch_size=16)
- evaluate_input_fn = numpy_io.numpy_input_fn(
+ evaluate_input_fn = gen_input_fn(
x=randomize_io_type(x_test, input_name),
y=randomize_io_type(y_test, output_name),
num_epochs=1, shuffle=False)
- predict_input_fn = numpy_io.numpy_input_fn(
+ predict_input_fn = gen_input_fn(
x=randomize_io_type(x_test, input_name), num_epochs=1, shuffle=False)
inference_input_fn = evaluate_input_fn if is_evaluate else predict_input_fn
@@ -203,7 +215,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer='rmsprop',
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
before_eval_results = est_keras.evaluate(
@@ -228,7 +240,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
metrics=['mse', keras.metrics.categorical_accuracy])
my_hook = MyHook()
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
before_eval_results = est_keras.evaluate(
@@ -252,7 +264,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
my_hook = MyHook()
- with self.test_session():
+ with self.cached_session():
keras_model.fit(x_train, y_train, epochs=1)
keras_est = keras_lib.model_to_estimator(
@@ -274,7 +286,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model,
config=self._config)
@@ -297,7 +309,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
@@ -316,7 +328,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
# Create state
keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE),
np.random.random((10, _NUM_CLASS)))
@@ -343,7 +355,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
x_test, y_test), _, eval_input_fn = get_resource_for_simple_model(
model_type='functional', is_evaluate=True)
- with self.test_session():
+ with self.cached_session():
metrics = [
'binary_accuracy', 'binary_crossentropy', 'categorical_accuracy',
'categorical_crossentropy', 'cosine_proximity', 'hinge',
@@ -357,7 +369,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.fit(x_train, y_train, epochs=1)
keras_eval = keras_model.evaluate(x_test, y_test, batch_size=32)
- with self.test_session():
+ with self.cached_session():
keras_est = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
est_eval = keras_est.evaluate(input_fn=eval_input_fn)
@@ -385,7 +397,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
x_test, _), _, pred_input_fn = get_resource_for_simple_model(
model_type='sequential', is_evaluate=False)
- with self.test_session():
+ with self.cached_session():
keras_model.compile(
loss='categorical_crossentropy',
optimizer='adam',
@@ -393,7 +405,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.fit(x_train, y_train, epochs=1)
keras_pred = [np.argmax(y) for y in keras_model.predict(x_test)]
- with self.test_session():
+ with self.cached_session():
keras_est = keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
est_pred = [
@@ -439,7 +451,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
output_dict = {'dense_2': c_test, 'dense_3': d_test}
return input_dict, output_dict
- with self.test_session():
+ with self.cached_session():
model = multi_inputs_multi_outputs_model()
est_keras = keras_lib.model_to_estimator(
keras_model=model, config=self._config)
@@ -456,7 +468,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
x_test, _), _, pred_input_fn = get_resource_for_simple_model(
model_type='functional', is_evaluate=False)
- with self.test_session():
+ with self.cached_session():
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
@@ -466,7 +478,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
fname = os.path.join(self._base_dir, 'keras_model.h5')
keras.models.save_model(keras_model, fname)
- with self.test_session():
+ with self.cached_session():
keras_est = keras_lib.model_to_estimator(
keras_model_path=fname, config=self._config)
est_pred = [
@@ -479,19 +491,19 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, 'Either'):
keras_lib.model_to_estimator()
- with self.test_session():
+ with self.cached_session():
keras_model = simple_sequential_model()
with self.assertRaisesRegexp(ValueError, 'not both'):
keras_lib.model_to_estimator(
keras_model=keras_model,
keras_model_path=tempfile.mkdtemp(dir=self._base_dir))
- with self.test_session():
+ with self.cached_session():
keras_model = simple_sequential_model()
with self.assertRaisesRegexp(ValueError, 'compiled'):
keras_lib.model_to_estimator(keras_model=keras_model)
- with self.test_session():
+ with self.cached_session():
keras_model = simple_sequential_model()
with self.assertRaisesRegexp(ValueError, 'not a local path'):
keras_lib.model_to_estimator(
@@ -516,10 +528,10 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
model = simple_functional_model()
model.compile(
loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=model, config=self._config)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(KeyError,
'Difference: .*invalid_input_name'):
est_keras.train(input_fn=invald_input_name_input_fn, steps=100)
@@ -547,20 +559,20 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
y_train = keras.utils.to_categorical(y_train, 2)
input_name = keras_model.input_names[0]
output_name = keras_model.output_names[0]
- train_input_fn = numpy_io.numpy_input_fn(
+ train_input_fn = gen_input_fn(
x=randomize_io_type(x_train, input_name),
y=randomize_io_type(y_train, output_name),
shuffle=False,
num_epochs=None,
batch_size=16)
with self.assertRaisesRegexp(ValueError, 'relu6'):
- with self.test_session():
+ with self.cached_session():
est = keras_lib.model_to_estimator(
keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir))
est.train(input_fn=train_input_fn, steps=1)
- with self.test_session():
+ with self.cached_session():
est = keras_lib.model_to_estimator(
keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir),
@@ -586,7 +598,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
}
})
with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
- with self.test_session():
+ with self.cached_session():
keras_lib.model_to_estimator(
keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir))
@@ -602,7 +614,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
self._config._session_config = sess_config
- with self.test_session():
+ with self.cached_session():
keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
self.assertEqual(
@@ -618,7 +630,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer='rmsprop',
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, model_dir=self._base_dir,
config=run_config_lib.RunConfig())
@@ -629,7 +641,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
self.assertEqual(self._base_dir, est_keras._config.model_dir)
self.assertEqual(self._base_dir, est_keras._model_dir)
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, model_dir=self._base_dir,
config=None)
@@ -648,7 +660,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer='rmsprop',
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model,
@@ -663,7 +675,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
optimizer='rmsprop',
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in '
'constructor and `RunConfig`'):
keras_lib.model_to_estimator(
@@ -676,7 +688,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
metrics=['mse', keras.metrics.categorical_accuracy])
- with self.test_session():
+ with self.cached_session():
keras_model.train_on_batch(
np.random.random((10,) + _INPUT_SIZE),
np.random.random((10, _NUM_CLASS)))
@@ -690,6 +702,32 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
+ def assert_increasing_global_step(self, optimizer):
+ keras_model, _, _, train_input_fn, _ = get_resource_for_simple_model(
+ model_type='sequential', is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=optimizer,
+ metrics=['mse', keras.metrics.categorical_accuracy])
+ with self.cached_session() as sess:
+ keras_model_fn = keras_lib._create_keras_model_fn(keras_model)
+ global_step = training_util.create_global_step()
+ features, labels = train_input_fn().make_one_shot_iterator().get_next()
+ spec = keras_model_fn(features, labels, mode=model_fn_lib.ModeKeys.TRAIN)
+
+ sess.run(variables.global_variables_initializer())
+ sess.run(variables.local_variables_initializer())
+
+ self.assertEqual(global_step.eval(), 0) # Sanity check
+ sess.run(spec.train_op)
+ self.assertEqual(global_step.eval(), 1)
+
+ def test_model_fn_increments_global_step_tf_optimizer(self):
+ self.assert_increasing_global_step(rmsprop.RMSPropOptimizer(1e-3))
+
+ def test_model_fn_increments_global_step_keras_optimizer(self):
+ self.assert_increasing_global_step('rmsprop')
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 1017d4ba47..ac53a84eef 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -12,6 +12,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":feature_column",
+ ":feature_column_v2",
"//tensorflow/python:util",
],
)
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index aa66ed77e9..28c5c82d2c 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -385,6 +385,10 @@ class FeatureLayer(Layer):
'You can wrap a categorical column with an '
'embedding_column or indicator_column. Given: {}'.format(column))
+ @property
+ def _is_feature_layer(self):
+ return True
+
def build(self, _):
for column in sorted(self._feature_columns, key=lambda x: x.name):
if isinstance(column, SharedEmbeddingColumn):
@@ -409,7 +413,13 @@ class FeatureLayer(Layer):
A `Tensor` which represents input layer of a model. Its shape
is (batch_size, first_layer_dimension) and its dtype is `float32`.
first_layer_dimension is determined based on given `feature_columns`.
+
+ Raises:
+ ValueError: If features are not a dictionary.
"""
+ if not isinstance(features, dict):
+      raise ValueError(
+          'We expected a dictionary here. Instead we got: {}'.format(features))
transformation_cache = FeatureTransformationCache(features)
output_tensors = []
ordered_columns = []
@@ -431,6 +441,12 @@ class FeatureLayer(Layer):
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
+ def compute_output_shape(self, input_shape):
+ total_elements = 0
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ total_elements += column.variable_shape.num_elements()
+ return (input_shape[0], total_elements)
+
def linear_model(features,
feature_columns,
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 6b343ecf3e..58168e0f9e 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -2786,6 +2786,21 @@ class FeatureLayerTest(test.TestCase):
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
+ def test_compute_output_shape(self):
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2', shape=4)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1., 2.], [5., 6.]],
+ 'price2': [[3., 4., 5., 6.], [7., 8., 9., 10.]]
+ }
+ feature_layer = FeatureLayer([price1, price2])
+ self.assertEqual((None, 6), feature_layer.compute_output_shape((None,)))
+ net = feature_layer(features)
+ with _initialized_session():
+ self.assertAllClose(
+ [[1., 2., 3., 4., 5., 6.], [5., 6., 7., 8., 9., 10.]], net.eval())
+
def test_raises_if_shape_mismatch(self):
price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index eca34ac26e..4b2706d4cf 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -105,7 +105,8 @@ def convert_to_eager_tensor(value, ctx, dtype=None):
scalar_cache = ctx.scalar_cache()
tensor = scalar_cache.get(cache_key, None)
if tensor is not None:
- return tensor
+ return ops.EagerTensor(
+ value, context=handle, device=device, dtype=dtype, other_value=tensor)
t = ops.EagerTensor(value, context=handle, device=device, dtype=dtype)
scalar_cache[cache_key] = t
return t
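The reason for returning a fresh `EagerTensor` that shares the cached tensor's underlying buffer (the new `other_value` path) rather than the cached object itself is that callers now get distinct Python objects with distinct tensor ids; id-keyed machinery such as the gradient tape can then no longer conflate two independent constants, which is what `testDifferentiatingScalarCache` in `backprop_test.py` above guards against. Schematically:

```python
x1 = constant_op.constant(3.0)
x2 = constant_op.constant(3.0)
# Under the old pure scalar cache, x1 is x2 could be True, so a tape saw a
# single tensor id and d(x1 + x2)/dx1 came out as 2 instead of 1. The copy
# keeps the cache's speed (shared buffer) while giving each use its own id.
assert x1 is not x2
```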
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index 11b681d544..3c2a736fb9 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -606,8 +606,8 @@ class TensorShape(object):
slice.
Raises:
- ValueError: If `key` is a slice, and any of its elements are negative, or
- if `self` is completely unknown and the step is set.
+ ValueError: If `key` is a slice and `self` is completely unknown and
+ the step is set.
"""
if self._dims is not None:
if isinstance(key, slice):
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 0925598e33..4bece9e25e 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -465,29 +465,31 @@ def assert_no_new_pyobjects_executing_eagerly(f):
f(self, **kwargs)
gc.collect()
previous_count = len(gc.get_objects())
- collection_sizes_before = {
- collection: len(ops.get_collection(collection))
- for collection in ops.get_default_graph().collections
- }
+ if ops.has_default_graph():
+ collection_sizes_before = {
+ collection: len(ops.get_collection(collection))
+ for collection in ops.get_default_graph().collections
+ }
for _ in range(3):
f(self, **kwargs)
# Note that gc.get_objects misses anything that isn't subject to garbage
# collection (C types). Collections are a common source of leaks, so we
# test for collection sizes explicitly.
- for collection_key in ops.get_default_graph().collections:
- collection = ops.get_collection(collection_key)
- size_before = collection_sizes_before.get(collection_key, 0)
- if len(collection) > size_before:
- raise AssertionError(
- ("Collection %s increased in size from "
- "%d to %d (current items %s).") % (collection_key, size_before,
- len(collection), collection))
- # Make sure our collection checks don't show up as leaked memory by
- # removing references to temporary variables.
- del collection
- del collection_key
- del size_before
- del collection_sizes_before
+ if ops.has_default_graph():
+ for collection_key in ops.get_default_graph().collections:
+ collection = ops.get_collection(collection_key)
+ size_before = collection_sizes_before.get(collection_key, 0)
+ if len(collection) > size_before:
+ raise AssertionError(
+ ("Collection %s increased in size from "
+ "%d to %d (current items %s).") %
+ (collection_key, size_before, len(collection), collection))
+ # Make sure our collection checks don't show up as leaked memory by
+ # removing references to temporary variables.
+ del collection
+ del collection_key
+ del size_before
+ del collection_sizes_before
gc.collect()
# There should be no new Python objects hanging around.
new_count = len(gc.get_objects())
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 7246341519..290e182a79 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -700,6 +700,20 @@ py_test(
)
py_test(
+ name = "feature_columns_integration_test",
+ size = "small",
+ srcs = ["engine/feature_columns_integration_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["notsan"],
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/feature_column:feature_column_py",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "training_eager_test",
size = "medium",
srcs = ["engine/training_eager_test.py"],
diff --git a/tensorflow/python/keras/engine/feature_columns_integration_test.py b/tensorflow/python/keras/engine/feature_columns_integration_test.py
new file mode 100644
index 0000000000..e0478ee357
--- /dev/null
+++ b/tensorflow/python/keras/engine/feature_columns_integration_test.py
@@ -0,0 +1,237 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests specific to Feature Columns integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.framework import test_util as tf_test_util
+from tensorflow.python.keras import metrics as metrics_module
+from tensorflow.python.platform import test
+from tensorflow.python.training import rmsprop
+
+
+class TestDNNModel(keras.models.Model):
+
+ def __init__(self, feature_columns, units, name=None, **kwargs):
+ super(TestDNNModel, self).__init__(name=name, **kwargs)
+ self._input_layer = fc.FeatureLayer(feature_columns, name='input_layer')
+ self._dense_layer = keras.layers.Dense(units, name='dense_layer')
+
+ def call(self, features):
+ net = self._input_layer(features)
+ net = self._dense_layer(net)
+ return net
+
+
+class FeatureColumnsIntegrationTest(test.TestCase):
+ """Most Sequential model API tests are covered in `training_test.py`.
+
+ """
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_sequential_model(self):
+ columns = [fc.numeric_column('a')]
+ model = keras.models.Sequential([
+ fc.FeatureLayer(columns),
+ keras.layers.Dense(64, activation='relu'),
+ keras.layers.Dense(20, activation='softmax')
+ ])
+ model.compile(
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+
+ x = {'a': np.random.random((10, 1))}
+ y = np.random.randint(20, size=(10, 1))
+ y = keras.utils.to_categorical(y, num_classes=20)
+ model.fit(x, y, epochs=1, batch_size=5)
+ model.fit(x, y, epochs=1, batch_size=5)
+ model.evaluate(x, y, batch_size=5)
+ model.predict(x, batch_size=5)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_sequential_model_with_ds_input(self):
+ columns = [fc.numeric_column('a')]
+ model = keras.models.Sequential([
+ fc.FeatureLayer(columns),
+ keras.layers.Dense(64, activation='relu'),
+ keras.layers.Dense(20, activation='softmax')
+ ])
+ model.compile(
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+
+ y = np.random.randint(20, size=(100, 1))
+ y = keras.utils.to_categorical(y, num_classes=20)
+ x = {'a': np.random.random((100, 1))}
+ ds1 = dataset_ops.Dataset.from_tensor_slices(x)
+ ds2 = dataset_ops.Dataset.from_tensor_slices(y)
+ ds = dataset_ops.Dataset.zip((ds1, ds2)).batch(5)
+ model.fit(ds, steps_per_epoch=1)
+ model.fit(ds, steps_per_epoch=1)
+ model.evaluate(ds, steps=1)
+ model.predict(ds, steps=1)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_subclassed_model_with_feature_columns(self):
+ col_a = fc.numeric_column('a')
+ col_b = fc.numeric_column('b')
+
+ dnn_model = TestDNNModel([col_a, col_b], 20)
+
+ dnn_model.compile(
+ optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.001),
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+
+ x = {'a': np.random.random((10, 1)), 'b': np.random.random((10, 1))}
+ y = np.random.randint(20, size=(10, 1))
+ y = keras.utils.to_categorical(y, num_classes=20)
+ dnn_model.fit(x=x, y=y, epochs=1, batch_size=5)
+ dnn_model.fit(x=x, y=y, epochs=1, batch_size=5)
+ dnn_model.evaluate(x=x, y=y, batch_size=5)
+ dnn_model.predict(x=x, batch_size=5)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_subclassed_model_with_feature_columns_with_ds_input(self):
+ col_a = fc.numeric_column('a')
+ col_b = fc.numeric_column('b')
+
+ dnn_model = TestDNNModel([col_a, col_b], 20)
+
+ dnn_model.compile(
+ optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.001),
+ loss='categorical_crossentropy',
+ metrics=['accuracy'])
+
+ y = np.random.randint(20, size=(100, 1))
+ y = keras.utils.to_categorical(y, num_classes=20)
+ x = {'a': np.random.random((100, 1)), 'b': np.random.random((100, 1))}
+ ds1 = dataset_ops.Dataset.from_tensor_slices(x)
+ ds2 = dataset_ops.Dataset.from_tensor_slices(y)
+ ds = dataset_ops.Dataset.zip((ds1, ds2)).batch(5)
+ dnn_model.fit(ds, steps_per_epoch=1)
+ dnn_model.fit(ds, steps_per_epoch=1)
+ dnn_model.evaluate(ds, steps=1)
+ dnn_model.predict(ds, steps=1)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def DISABLED_test_function_model_feature_layer_input(self):
+ col_a = fc.numeric_column('a')
+ col_b = fc.numeric_column('b')
+
+ feature_layer = fc.FeatureLayer([col_a, col_b], name='fc')
+ dense = keras.layers.Dense(4)
+
+    # This seems problematic... We probably need something for FeatureLayer
+ # the way Input is for InputLayer.
+ output = dense(feature_layer)
+
+ model = keras.models.Model([feature_layer], [output])
+
+ optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ loss_weights = [1., 0.5]
+ model.compile(
+ optimizer,
+ loss,
+ metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
+ loss_weights=loss_weights)
+
+ data = ({'a': np.arange(10), 'b': np.arange(10)}, np.arange(10, 20))
+ print(model.fit(*data, epochs=1))
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def DISABLED_test_function_model_multiple_feature_layer_inputs(self):
+ col_a = fc.numeric_column('a')
+ col_b = fc.numeric_column('b')
+ col_c = fc.numeric_column('c')
+
+ fc1 = fc.FeatureLayer([col_a, col_b], name='fc1')
+ fc2 = fc.FeatureLayer([col_b, col_c], name='fc2')
+ dense = keras.layers.Dense(4)
+
+    # This seems problematic... We probably need something for FeatureLayer
+ # the way Input is for InputLayer.
+ output = dense(fc1) + dense(fc2)
+
+ model = keras.models.Model([fc1, fc2], [output])
+
+ optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ loss_weights = [1., 0.5]
+ model.compile(
+ optimizer,
+ loss,
+ metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
+ loss_weights=loss_weights)
+
+ data_list = ([{
+ 'a': np.arange(10),
+ 'b': np.arange(10)
+ }, {
+ 'b': np.arange(10),
+ 'c': np.arange(10)
+ }], np.arange(10, 100))
+ print(model.fit(*data_list, epochs=1))
+
+ data_bloated_list = ([{
+ 'a': np.arange(10),
+ 'b': np.arange(10),
+ 'c': np.arange(10)
+ }, {
+ 'a': np.arange(10),
+ 'b': np.arange(10),
+ 'c': np.arange(10)
+ }], np.arange(10, 100))
+ print(model.fit(*data_bloated_list, epochs=1))
+
+ data_dict = ({
+ 'fc1': {
+ 'a': np.arange(10),
+ 'b': np.arange(10)
+ },
+ 'fc2': {
+ 'b': np.arange(10),
+ 'c': np.arange(10)
+ }
+ }, np.arange(10, 100))
+ print(model.fit(*data_dict, epochs=1))
+
+ data_bloated_dict = ({
+ 'fc1': {
+ 'a': np.arange(10),
+ 'b': np.arange(10),
+ 'c': np.arange(10)
+ },
+ 'fc2': {
+ 'a': np.arange(10),
+ 'b': np.arange(10),
+ 'c': np.arange(10)
+ }
+ }, np.arange(10, 100))
+ print(model.fit(*data_bloated_dict, epochs=1))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 966b446f22..d224dfffdd 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -45,6 +45,7 @@ from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -862,7 +863,8 @@ class Model(Network):
Fraction of the training data to be used as validation data.
Returns:
- A tuple of 3 lists: input arrays, target arrays, sample-weight arrays.
+      A tuple of 3 elements: inputs (arrays or a dict, depending on whether
+      `x` was a dict), target arrays, and sample-weight arrays.
If the model's input and targets are symbolic, these lists are empty
(since the model takes no user-provided data, instead the data comes
from the symbolic inputs/targets).
@@ -928,11 +930,16 @@ class Model(Network):
'Make sure that your dataset can generate '
'required number of samples.')
- if not isinstance(next_element, (list, tuple)) or len(next_element) != 2:
- raise ValueError('Please provide model inputs as a list or tuple of 2 '
- 'elements: input and target pair. '
- 'Received %s' % next_element)
- x, y = next_element
+ if (not isinstance(next_element, (list, tuple)) or
+ len(next_element) not in [2, 3]):
+ raise ValueError(
+            'Please provide model inputs as a list or tuple of 2 or 3 '
+            'elements: (input, target) or (input, target, sample_weights). '
+            'Received %s' % next_element)
+ if len(next_element) == 2:
+ x, y = next_element
+ else:
+ x, y, sample_weight = next_element
x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
class_weight, batch_size)
return x, y, sample_weights
@@ -948,6 +955,7 @@ class Model(Network):
all_inputs = []
is_build_called = False
is_compile_called = False
+ dict_inputs = False
if not self.inputs:
# We need to use `x` to set the model inputs.
# We type-check that `x` and `y` are either single arrays
@@ -959,7 +967,9 @@ class Model(Network):
'array or a list of arrays. You passed: x=' + str(x))
all_inputs += list(x)
elif isinstance(x, dict):
- raise ValueError('Please do not pass a dictionary as model inputs.')
+ dict_inputs = True
+ keys = sorted(x.keys())
+ all_inputs = [x[k] for k in keys]
else:
if not isinstance(x, np.ndarray) and not tensor_util.is_tensor(x):
raise ValueError('Please provide as model inputs either a single '
@@ -972,6 +982,8 @@ class Model(Network):
if not self.inputs:
is_build_called = True
self._set_inputs(x)
+ else:
+ dict_inputs = isinstance(self.inputs, dict)
if y is not None:
if not self.optimizer:
@@ -1124,6 +1136,10 @@ class Model(Network):
'a number of samples that can be '
'divided by the batch size. Found: ' +
str(x[0].shape[0]) + ' samples')
+
+ # If dictionary inputs were provided, we return a dictionary as well.
+ if dict_inputs:
+ x = dict(zip(feed_input_names, x))
return x, y, sample_weights
@checkpointable.no_automatic_dependency_tracking
@@ -1146,6 +1162,9 @@ class Model(Network):
training: Boolean or None. Only relevant in symbolic mode. Specifies
whether to build the model's graph in inference mode (False), training
mode (True), or using the Keras learning phase (None).
+ Raises:
+ ValueError: If dict inputs are passed to a Sequential Model where the
+ first layer isn't FeatureLayer.
"""
call_convention = getattr(
self,
@@ -1162,6 +1181,14 @@ class Model(Network):
if tensor_util.is_tensor(inputs):
input_shape = (None,) + tuple(inputs.get_shape().as_list()[1:])
self.build(input_shape=input_shape)
+ elif isinstance(inputs, dict):
+ # We assert that the first layer is a FeatureLayer.
+ if not training_utils.is_feature_layer(self.layers[0]):
+        raise ValueError('Passing a dictionary input to a Sequential Model '
+                         'which does not have FeatureLayer as the first '
+                         'layer is an error.')
+ input_shape = (None,)
+ self.build(input_shape=input_shape)
else:
input_shape = (None,) + inputs.shape[1:]
self.build(input_shape=input_shape)
@@ -1189,36 +1216,22 @@ class Model(Network):
assert context.executing_eagerly()
if self.inputs:
raise ValueError('Model inputs are already set.')
+
# On-the-fly setting of model inputs/outputs as DeferredTensors,
# to keep track of number of inputs and outputs and their ndim.
- if isinstance(inputs, (list, tuple)):
- if tensor_util.is_tensor(inputs[0]):
- dummy_output_values = self.call(
- training_utils.cast_if_floating_dtype(inputs))
- else:
- dummy_output_values = self.call(
- [ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs])
- dummy_input_values = list(inputs)
- else:
- if tensor_util.is_tensor(inputs):
- dummy_output_values = self.call(
- training_utils.cast_if_floating_dtype(inputs))
- else:
- dummy_output_values = self.call(
- ops.convert_to_tensor(inputs, dtype=K.floatx()))
- dummy_input_values = [inputs]
- if isinstance(dummy_output_values, (list, tuple)):
- dummy_output_values = list(dummy_output_values)
- else:
- dummy_output_values = [dummy_output_values]
+ model_inputs = training_utils.ModelInputs(inputs)
+ dummy_input_values = model_inputs.get_input_values()
+ dummy_output_values = self.call(dummy_input_values)
+
+ self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+ self.input_names = model_inputs.get_input_names()
+
+ dummy_output_values = nest.flatten(dummy_output_values)
self.outputs = [
- base_layer.DeferredTensor(shape=(None for _ in v.shape),
- dtype=v.dtype) for v in dummy_output_values]
- self.inputs = [
- base_layer.DeferredTensor(shape=(None for _ in v.shape),
- dtype=v.dtype) for v in dummy_input_values]
- self.input_names = [
- 'input_%d' % (i + 1) for i in range(len(dummy_input_values))]
+        base_layer.DeferredTensor(
+            shape=(None for _ in v.shape), dtype=v.dtype)
+ for v in dummy_output_values
+ ]
self.output_names = [
'output_%d' % (i + 1) for i in range(len(dummy_output_values))]
self.built = True
@@ -1248,58 +1261,29 @@ class Model(Network):
# On-the-fly setting of symbolic model inputs (either by using the tensor
# provided, or by creating a placeholder if Numpy data was provided).
- self.inputs = []
- self.input_names = []
+ model_inputs = training_utils.ModelInputs(inputs)
+ dummy_input_values = model_inputs.get_symbolic_inputs()
+ self.inputs = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+ self.input_names = model_inputs.get_input_names()
+
self._feed_inputs = []
self._feed_input_names = []
self._feed_input_shapes = []
- if isinstance(inputs, (list, tuple)):
- inputs = list(inputs)
- else:
- inputs = [inputs]
-
- for i, v in enumerate(inputs):
- name = 'input_%d' % (i + 1)
- self.input_names.append(name)
- if isinstance(v, list):
- v = np.asarray(v)
- if v.ndim == 1:
- v = np.expand_dims(v, 1)
- if isinstance(v, (np.ndarray)):
- # We fix the placeholder shape except the batch size.
- # This is suboptimal, but it is the best we can do with the info
- # we have. The user should call `model._set_inputs(placeholders)`
- # to specify custom placeholders if the need arises.
- shape = (None,) + v.shape[1:]
- placeholder = K.placeholder(shape=shape, name=name)
- self.inputs.append(placeholder)
- self._feed_inputs.append(placeholder)
- self._feed_input_names.append(name)
- self._feed_input_shapes.append(shape)
- else:
- # Assumed tensor - TODO(fchollet) additional type check?
- self.inputs.append(v)
- if K.is_placeholder(v):
- self._feed_inputs.append(v)
- self._feed_input_names.append(name)
- self._feed_input_shapes.append(K.int_shape(v))
+
+ for k, v in model_inputs.as_dict():
+ if K.is_placeholder(v):
+ self._feed_inputs.append(v)
+ self._feed_input_names.append(k)
+ self._feed_input_shapes.append(K.int_shape(v))
if outputs is None:
# Obtain symbolic outputs by calling the model.
- if len(self.inputs) == 1:
- if self._expects_training_arg:
- outputs = self.call(self.inputs[0], training=training)
- else:
- outputs = self.call(self.inputs[0])
+ if self._expects_training_arg:
+ outputs = self.call(dummy_input_values, training=training)
else:
- if self._expects_training_arg:
- outputs = self.call(self.inputs, training=training)
- else:
- outputs = self.call(self.inputs)
- if isinstance(outputs, (list, tuple)):
- outputs = list(outputs)
- else:
- outputs = [outputs]
+ outputs = self.call(dummy_input_values)
+
+ outputs = nest.flatten(outputs)
self.outputs = outputs
self.output_names = [
'output_%d' % (i + 1) for i in range(len(self.outputs))]
@@ -1331,7 +1315,8 @@ class Model(Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- - A `tf.data` dataset or a dataset iterator.
+ - A `tf.data` dataset or a dataset iterator. Should return a tuple
+ of either (inputs, targets) or (inputs, targets, sample_weights).
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
@@ -1396,7 +1381,8 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
- supported when `x` is a dataset or a dataset iterator.
+          supported when `x` is a dataset or a dataset iterator. Instead,
+          provide the sample_weights as the third element of `x`.
initial_epoch: Integer.
Epoch at which to start training
(useful for resuming a previous training run).
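
The net effect of the training.py changes for users: `fit`/`evaluate`/`predict` now accept datasets that yield an (inputs, targets, sample_weights) triple, and dictionary inputs are no longer rejected outright. A minimal sketch mirroring the new test_dataset_with_sample_weights further below (the model shape and sizes here are made up):

import numpy as np

from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops

model = keras.models.Sequential(
    [keras.layers.Dense(4, input_shape=(3,))])
model.compile('rmsprop', 'mse')

inputs = np.zeros((10, 3), np.float32)
targets = np.zeros((10, 4), np.float32)
sample_weights = np.ones((10,), np.float32)

# The third tuple element is now understood as per-sample weights.
ds = dataset_ops.Dataset.from_tensor_slices(
    (inputs, targets, sample_weights)).repeat(100).batch(5)
model.fit(ds, epochs=1, steps_per_epoch=2)
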
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index e2c458c65f..95b864bef0 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -55,7 +55,7 @@ def fit_loop(model,
Arguments:
model: Keras Model instance.
- inputs: List of input arrays.
+ inputs: Either a list of arrays or a dictionary.
targets: List of target arrays.
sample_weights: Optional list of sample weight arrays.
batch_size: Integer batch size or None if unknown.
@@ -88,6 +88,7 @@ def fit_loop(model,
sample_weights = sample_weights or []
val_sample_weights = val_sample_weights or []
+ inputs = training_utils.ModelInputs(inputs).as_list()
if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = inputs + targets + sample_weights + [1]
else:
@@ -262,6 +263,7 @@ def predict_loop(model, inputs, batch_size=32, verbose=0, steps=None):
model._make_predict_function()
f = model.predict_function
+ inputs = training_utils.ModelInputs(inputs).as_list()
if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = inputs + [0]
else:
@@ -368,6 +370,7 @@ def test_loop(model,
f = model.test_function
sample_weights = sample_weights or []
+ inputs = training_utils.ModelInputs(inputs).as_list()
if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
ins = inputs + targets + sample_weights + [0]
else:
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index e440e02bfb..939732cd67 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -70,7 +70,8 @@ def fit_loop(
# TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged.
if current_strategy.__class__.__name__ == 'TPUStrategy':
return _experimental_fit_loop(
- model, iterator, epochs, initial_epoch, steps_per_epoch)
+ model, iterator, epochs, verbose, callbacks, initial_epoch,
+ steps_per_epoch)
clone_model_on_towers(
model, current_strategy, make_callback_model=True)
@@ -201,6 +202,8 @@ def _experimental_fit_loop(
model,
iterator,
epochs=100,
+ verbose=1,
+ callbacks=None,
initial_epoch=0,
steps_per_epoch=None):
"""fit function when using TPU DistributionStrategy for training.
@@ -209,6 +212,8 @@ def _experimental_fit_loop(
model: Keras Model instance.
iterator: Iterator that returns inputs and targets
epochs: Number of times to iterate over the data
+ verbose: Verbosity mode, 0, 1 or 2
+ callbacks: List of callbacks to be called during training
initial_epoch: Epoch at which to start training
(useful for resuming a previous training run)
steps_per_epoch: Total number of steps (batches of samples)
@@ -225,7 +230,6 @@ def _experimental_fit_loop(
# TODO(priyag): Add validation that shapes are fully defined for TPU case.
- # TODO(priyag, sourabhbajaj): This should be moved into a callback instead.
K.get_session().run(current_strategy.initialize())
def _per_device_train_function(model):
@@ -298,19 +302,35 @@ def _experimental_fit_loop(
assert steps_per_epoch is not None
- # TODO(priyag, sourabhbajaj): Add callbacks support.
+ # TODO(sourabhbajaj): Convert this into a proper validation function
+ if callbacks:
+ raise NotImplementedError(
+ 'Callbacks are not supported with TPUStrategy right now.')
+
+ callbacks = cbks.configure_callbacks(
+ callbacks,
+ model,
+ do_validation=False,
+ val_inputs=None,
+ val_targets=None,
+ epochs=epochs,
+ steps_per_epoch=steps_per_epoch,
+ verbose=verbose)
+ # TODO(priyag, sourabhbajaj): Add callbacks support for per step callback
+ # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run
# TODO(priyag, sourabhbajaj): Add validation.
+ callbacks.on_train_begin()
for epoch in range(initial_epoch, epochs):
- for step_index in range(
- 0, steps_per_epoch, current_strategy.steps_per_run):
+ callbacks.on_epoch_begin(epoch)
+ epoch_logs = {}
+ for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run):
+      # TODO(sourabhbajaj): Add the size parameter in batch_logs once
+      # callbacks are fixed, as we need to replace size with a combination
+      # of steps_per_run and batch_size.
+ batch_logs = {'batch': step_index}
+ callbacks.on_batch_begin(step_index, batch_logs)
try:
- _, outs = K.get_session().run([train_op, output_tensors])
- # TODO(priyag, sourabhbajaj): Remove this logging in favor of proper
- # summaries through callbacks.
- print('Epoch: {}, step_index: {}, loss: {}'.format(
- epoch, step_index, outs['loss']))
- for label, out in outs.items():
- print(label, ': ', out)
+ _, outputs = K.get_session().run([train_op, output_tensors])
except errors.OutOfRangeError:
logging.warning('Your dataset iterator ran out of data; '
'interrupting training. Make sure that your dataset '
@@ -319,6 +339,16 @@ def _experimental_fit_loop(
steps_per_epoch * epochs)
break
+ batch_logs.update(outputs)
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callbacks.model.stop_training:
+ break
+
+ callbacks.on_epoch_end(epoch, epoch_logs)
+ if callbacks.model.stop_training:
+ break
+ callbacks.on_train_end()
+
# Copy the weights back from the replicated model to the original model.
with current_strategy.scope():
updated_weights = current_strategy.unwrap(
@@ -326,8 +356,7 @@ def _experimental_fit_loop(
model.set_weights(updated_weights)
K.get_session().run(current_strategy.finalize())
-
- # TODO(priyag, sourabhbajaj): Return history.
+ return model.history
def test_loop(model, iterator, verbose=0, steps=None):
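
The control flow added to _experimental_fit_loop reduces to the standard Keras callback protocol wrapped around the per-run session call. A condensed sketch of that shape, where `run_steps` stands in for the `K.get_session().run([train_op, output_tensors])` call above and all names are illustrative:

def fit_loop_sketch(callbacks, epochs, steps_per_epoch, steps_per_run,
                    run_steps):
  # run_steps(step_index) executes steps_per_run training steps and
  # returns the outputs dict that feeds batch_logs.
  callbacks.on_train_begin()
  for epoch in range(epochs):
    callbacks.on_epoch_begin(epoch)
    epoch_logs = {}
    for step_index in range(0, steps_per_epoch, steps_per_run):
      batch_logs = {'batch': step_index}
      callbacks.on_batch_begin(step_index, batch_logs)
      batch_logs.update(run_steps(step_index))
      callbacks.on_batch_end(step_index, batch_logs)
      if callbacks.model.stop_training:
        break
    callbacks.on_epoch_end(epoch, epoch_logs)
    if callbacks.model.stop_training:
      break
  callbacks.on_train_end()
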
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index 1e377149b6..939a7f2356 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -67,7 +67,8 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False):
Arguments:
model: The model on which metrics are being calculated.
- inputs: List of input arrays.
+ inputs: Either a dictionary of inputs to the model or a list of input
+ arrays.
targets: List of target arrays.
sample_weights: Optional list of sample weight arrays.
training: Whether the model should be run in inference or training mode.
@@ -82,7 +83,7 @@ def _model_loss(model, inputs, targets, sample_weights=None, training=False):
kwargs = {}
if model._expects_training_arg:
kwargs['training'] = training
- if len(inputs) == 1:
+ if len(inputs) == 1 and not isinstance(inputs, dict):
inputs = inputs[0]
if model._compute_output_and_mask_jointly:
@@ -369,6 +370,8 @@ def iterator_test_loop(model, inputs, steps, verbose=0):
# Get current step size.
if isinstance(x, list):
step_size = x[0].get_shape().as_list()[0]
+ elif isinstance(x, dict):
+ step_size = list(x.values())[0].get_shape().as_list()[0]
else:
step_size = x.get_shape().as_list()[0]
@@ -417,11 +420,12 @@ def iterator_predict_loop(model, inputs, steps, verbose=0):
"""
assert isinstance(inputs, iterator_ops.EagerIterator)
if not isinstance(inputs.output_shapes,
- (list, tuple)) or len(inputs.output_shapes) > 2:
+ (list, tuple)) or len(inputs.output_shapes) > 3:
raise ValueError(
- 'Please provide data as a list or tuple of 1 or 2 elements '
- ' - input or input and target pair. Received %s. We do not use the '
- '`target` value here.' % inputs.output_shapes)
+        'Please provide data as a list or tuple of 1, 2, or 3 elements '
+        '- `(input)`, `(input, target)`, or `(input, target, '
+        'sample_weights)`. Received %s. We do not use the `target` or '
+        '`sample_weights` values here.' % inputs.output_shapes)
outs = []
if verbose == 1:
progbar = generic_utils.Progbar(target=steps)
@@ -444,10 +448,13 @@ def iterator_predict_loop(model, inputs, steps, verbose=0):
x, _, _ = model._standardize_user_data(x)
x = training_utils.cast_if_floating_dtype(x)
+ if isinstance(x, list) and len(x) == 1:
+ x = x[0]
+
if model._expects_training_arg:
- batch_outs = model.call(x[0] if len(x) == 1 else x, training=False)
+ batch_outs = model.call(x, training=False)
else:
- batch_outs = model.call(x[0] if len(x) == 1 else x)
+ batch_outs = model.call(x)
if not isinstance(batch_outs, list):
batch_outs = [batch_outs]
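
The extra isinstance check in _model_loss matters because a single-entry dict also has length one, but it must not be unwrapped to its bare value; features are looked up by name. A tiny self-contained reproduction of the distinction:

def maybe_unwrap(inputs):
  # A one-element list is unwrapped to the bare array, but a one-key dict
  # must stay a dict so the model can look features up by name.
  if len(inputs) == 1 and not isinstance(inputs, dict):
    return inputs[0]
  return inputs

print(maybe_unwrap([1.0]))       # 1.0
print(maybe_unwrap({'a': 1.0}))  # {'a': 1.0}
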
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index bf5c7fd7f8..1d0d113e40 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -481,8 +481,8 @@ class LossWeightingTest(test.TestCase):
num_hidden=10, num_classes=num_classes, input_dim=input_dim)
model.compile(
loss='categorical_crossentropy',
- metrics=['acc'],
- weighted_metrics=['mae'],
+ metrics=['acc', metrics_module.CategoricalAccuracy()],
+ weighted_metrics=['mae', metrics_module.CategoricalAccuracy()],
optimizer=RMSPropOptimizer(learning_rate=learning_rate))
np.random.seed(1337)
@@ -536,6 +536,25 @@ class LossWeightingTest(test.TestCase):
self.assertLess(score[0], ref_score[0])
@tf_test_util.run_in_graph_and_eager_modes
+ def test_sequential_model_fails_with_dict_inputs(self):
+ num_classes = 5
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=num_classes)
+ model.compile(
+ RMSPropOptimizer(learning_rate=0.001),
+ metrics=['acc'],
+ weighted_metrics=['mae'],
+ loss='categorical_crossentropy')
+
+ x = {'dense_input': np.random.random((10, 1))}
+ y = np.random.randint(num_classes, size=(10, 1))
+
+ with self.assertRaisesRegexp(
+        ValueError, 'Passing a dictionary input to a Sequential Model which '
+        'does not have FeatureLayer as the first layer is an error'):
+ model.fit(x, y, batch_size=5, epochs=1)
+
+ @tf_test_util.run_in_graph_and_eager_modes
def test_sample_weights(self):
num_classes = 5
batch_size = 5
@@ -550,8 +569,8 @@ class LossWeightingTest(test.TestCase):
num_hidden=10, num_classes=num_classes, input_dim=input_dim)
model.compile(
RMSPropOptimizer(learning_rate=learning_rate),
- metrics=['acc'],
- weighted_metrics=['mae'],
+ metrics=['acc', metrics_module.CategoricalAccuracy()],
+ weighted_metrics=['mae', metrics_module.CategoricalAccuracy()],
loss='categorical_crossentropy')
np.random.seed(43)
@@ -679,8 +698,8 @@ class LossWeightingTest(test.TestCase):
model.compile(
RMSPropOptimizer(learning_rate=learning_rate),
loss='binary_crossentropy',
- metrics=['acc'],
- weighted_metrics=['mae'],
+ metrics=['acc', metrics_module.CategoricalAccuracy()],
+ weighted_metrics=['mae', metrics_module.CategoricalAccuracy()],
sample_weight_mode='temporal')
model.fit(
@@ -2097,6 +2116,43 @@ class TestTrainingWithDataset(test.TestCase):
'you should specify the `steps` argument'):
model.predict(dataset, verbose=0)
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_dataset_with_sample_weights(self):
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ metrics = ['mae', metrics_module.CategoricalAccuracy()]
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((10, 3), np.float32)
+ targets = np.zeros((10, 4), np.float32)
+ sample_weights = np.ones((10), np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets,
+ sample_weights))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(dataset, steps=2, verbose=1)
+ model.predict(dataset, steps=2)
+ model.train_on_batch(dataset)
+ model.predict_on_batch(dataset)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_dataset_with_sparse_labels(self):
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ loss = 'sparse_categorical_crossentropy'
+ model.compile(optimizer, loss)
+
+ inputs = np.zeros((10, 3))
+ targets = np.random.randint(0, 4, size=10, dtype=np.int32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+
def test_dataset_input_shape_validation(self):
with self.test_session():
model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
@@ -2108,8 +2164,10 @@ class TestTrainingWithDataset(test.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- with self.assertRaisesRegexp(ValueError,
- r'expected (.*?) to have 2 dimensions'):
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'expected (.*?) to have shape \(3,\) but got array with shape \(1,\)'
+ ):
model.train_on_batch(dataset)
# Wrong input shape
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index f94697c913..898e9223cb 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -22,18 +22,22 @@ import copy
import math
import numpy as np
+import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
+from tensorflow.python.keras.engine import base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.util import nest
def _map_nested(data, func):
@@ -210,10 +214,11 @@ def check_num_samples(ins,
def standardize_single_array(x):
if x is None:
return None
- elif tensor_util.is_tensor(x):
- return x
- elif x.ndim == 1:
- x = np.expand_dims(x, 1)
+ if x.shape is not None and len(x.shape) == 1:
+ if tensor_util.is_tensor(x):
+ return array_ops.expand_dims(x, axis=1)
+ else:
+ return np.expand_dims(x, 1)
return x
@@ -245,7 +250,8 @@ def standardize_input_data(data,
ValueError: in case of improperly formatted user-provided data.
"""
if not names:
- if data is not None and hasattr(data, '__len__') and len(data):
+ if (data is not None and hasattr(data, '__len__') and len(data) and
+ not isinstance(data, dict)):
raise ValueError('Error when checking model ' + exception_prefix + ': '
'expected no data, but got:', data)
return []
@@ -341,7 +347,7 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type):
Raises:
ValueError: In case of invalid user-provided argument.
"""
- if x_weight is None or len(x_weight) == 0: # pylint: disable=g-explicit-length-test
+ if x_weight is None or (isinstance(x_weight, list) and len(x_weight) == 0): # pylint: disable=g-explicit-length-test
return [None for _ in output_names]
if len(output_names) == 1:
if isinstance(x_weight, list) and len(x_weight) == 1:
@@ -675,7 +681,8 @@ def standardize_weights(y,
'Expected sample_weight with rank '
'less than or equal to ' + str(len(y.shape)))
- if y.shape[:sample_weight.ndim] != sample_weight.shape:
+ if (not tensor_util.is_tensor(sample_weight) and
+ y.shape[:sample_weight.ndim] != sample_weight.shape):
raise ValueError(
'Found a sample_weight array with shape ' + str(sample_weight.shape) +
' for an input with shape ' + str(y.shape) + '. '
@@ -717,6 +724,8 @@ def has_symbolic_tensors(ls):
def has_tensors(ls):
if isinstance(ls, (list, tuple)):
return any(tensor_util.is_tensor(v) for v in ls)
+ if isinstance(ls, dict):
+ return any(tensor_util.is_tensor(v) for _, v in six.iteritems(ls))
return tensor_util.is_tensor(ls)
@@ -777,7 +786,9 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None):
'Received: %s' % (x, y))
if sample_weight is not None:
raise ValueError('`sample_weight` argument is not supported when input '
- '`x` is a dataset or a dataset iterator. '
+                     '`x` is a dataset or a dataset iterator. Instead, you '
+                     'can provide sample_weight as the third element of '
+                     'your dataset, i.e. (inputs, targets, sample_weight). '
'Received: x=%s, sample_weight=%s' % (x, sample_weight))
if validation_split is not None and validation_split != 0.0:
raise ValueError(
@@ -825,6 +836,12 @@ def check_steps_argument(input_data, steps, steps_name):
return False
+def cast_single_tensor(x):
+ if tensor_util.is_tensor(x) and x.dtype.is_floating:
+ return math_ops.cast(x, dtype=K.floatx())
+ return x
+
+
def cast_if_floating_dtype(x):
"""Casts the given data tensors to the default floating point type.
@@ -842,13 +859,7 @@ def cast_if_floating_dtype(x):
raise RuntimeError(
'Please provide tensors for casting, got: {x}'.format(x=x))
- if isinstance(x, (list, tuple)):
- return [
- math_ops.cast(val, dtype=K.floatx())
- if tensor_util.is_tensor(val) and val.dtype.is_floating else val
- for val in x
- ]
- return math_ops.cast(x, dtype=K.floatx()) if x.dtype.is_floating else x
+ return nest.map_structure(cast_single_tensor, x)
def get_output_sample_weight_and_mode(skip_target_weighing_indices,
@@ -929,3 +940,103 @@ def prepare_sample_weights(output_names, sample_weight_mode,
sample_weights.append(weight)
sample_weight_modes.append(mode)
return sample_weights, sample_weight_modes
+
+
+# TODO(rohanj): This is a hack to avoid depending on feature_column and
+# creating a cyclic dependency. Figure out a cleaner solution.
+def is_feature_layer(layer):
+ """Returns whether `layer` is a FeatureLayer or not."""
+ return getattr(layer, '_is_feature_layer', False)
+
+
+class ModelInputs(object):
+ """Encapsulates model inputs.
+
+ Allows for transforming model inputs while keeping the same structure.
+ """
+
+ def __init__(self, inputs):
+ self._inputs = inputs
+ self._is_dict = isinstance(self._inputs, dict)
+ self._is_single_input = not isinstance(self._inputs, (list, tuple, dict))
+ self._flattened_inputs = []
+ self._input_names = []
+ if isinstance(self._inputs, dict):
+ for k in sorted(self._inputs.keys()):
+ self._flattened_inputs.append(self._inputs[k])
+ self._input_names.append(k)
+ else:
+ self._flattened_inputs = nest.flatten(self._inputs)
+ self._input_names = [
+ 'input_%d' % (i + 1) for i in range(len(self._flattened_inputs))
+ ]
+ assert len(self._input_names) == len(self._flattened_inputs)
+
+ def get_input_names(self):
+ """Returns keys to name inputs by.
+
+    If the inputs were provided as a list, tuple or single entry, we make up
+    keys of the form 'input_%d'. For the dictionary case, we return a sorted
+    list of the dictionary's keys.
+ """
+ return self._input_names
+
+ def _get(self, return_single_as_list=False):
+ """Returns provided inputs, potentially transformed.
+
+    Inputs are returned in the same format they were provided, i.e. lists
+    are returned as lists, single entries as single entries (unless
+    `return_single_as_list` is true), and dictionaries as dictionaries.
+
+ Args:
+ return_single_as_list: Returns a list of size 1 for single entry case.
+ """
+ if self._is_dict:
+ return dict(zip(self._input_names, self._flattened_inputs))
+ if self._is_single_input and not return_single_as_list:
+ return self._flattened_inputs[0]
+ return self._flattened_inputs
+
+ def get_input_values(self):
+ """Returns input values passed in."""
+ if context.executing_eagerly():
+ for i in range(len(self._flattened_inputs)):
+ v = self._flattened_inputs[i]
+ if tensor_util.is_tensor(v):
+ v = cast_single_tensor(v)
+ else:
+ v = ops.convert_to_tensor(v, dtype=K.floatx())
+ self._flattened_inputs[i] = v
+ return self._get(return_single_as_list=False)
+
+ def get_symbolic_inputs(self, return_single_as_list=False):
+ """Returns inputs to be set as self.inputs for a model."""
+ for i in range(len(self._flattened_inputs)):
+ k = self._input_names[i]
+ v = self._flattened_inputs[i]
+ if context.executing_eagerly():
+ v = base_layer.DeferredTensor(
+ shape=(None for _ in v.shape), dtype=v.dtype)
+ else:
+ if isinstance(v, list):
+ v = np.asarray(v)
+ if v.ndim == 1:
+ v = np.expand_dims(v, 1)
+      if isinstance(v, np.ndarray):
+ # We fix the placeholder shape except the batch size.
+ # This is suboptimal, but it is the best we can do with the info
+ # we have. The user should call `model._set_inputs(placeholders)`
+ # to specify custom placeholders if the need arises.
+ shape = (None,) + v.shape[1:]
+ v = K.placeholder(shape=shape, name=k)
+ self._flattened_inputs[i] = v
+
+ return self._get(return_single_as_list)
+
+ def as_dict(self):
+ """An iterable over a dictionary version of inputs."""
+ for i in range(len(self._flattened_inputs)):
+ yield self._input_names[i], self._flattened_inputs[i]
+
+ def as_list(self):
+ """Returning the inputs as a list."""
+ return self._flattened_inputs
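
ModelInputs is the normalization point that the fit/predict/test loops earlier in this patch now share. A short usage sketch, assuming TF 1.x graph mode at this commit; the behavior matches the new ModelInputsTest further below:

import numpy as np

from tensorflow.python.keras.engine import training_utils

# Dict inputs: names are the sorted dict keys.
mi = training_utils.ModelInputs({'b': np.ones(10), 'a': np.ones(20)})
print(mi.get_input_names())  # ['a', 'b']
print(len(mi.as_list()))     # 2, flattened in name order

# Single input: a made-up name, and optional list wrapping.
mi = training_utils.ModelInputs(np.ones(10))
print(mi.get_input_names())                               # ['input_1']
x = mi.get_symbolic_inputs()                              # one placeholder
xs = mi.get_symbolic_inputs(return_single_as_list=True)   # [placeholder]
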
diff --git a/tensorflow/python/keras/engine/training_utils_test.py b/tensorflow/python/keras/engine/training_utils_test.py
index 297a1ae494..e777cb6db3 100644
--- a/tensorflow/python/keras/engine/training_utils_test.py
+++ b/tensorflow/python/keras/engine/training_utils_test.py
@@ -20,8 +20,11 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.platform import test
@@ -146,5 +149,91 @@ class TrainingUtilTest(test.TestCase):
self.assertEquals(any_true, False)
+class ModelInputsTest(test.TestCase):
+
+ def test_single_thing(self):
+ a = np.ones(10)
+ model_inputs = training_utils.ModelInputs(a)
+ self.assertEquals(['input_1'], model_inputs.get_input_names())
+ vals = model_inputs.get_input_values()
+ self.assertAllEqual(np.ones(10), vals)
+ self.assertFalse(tensor_util.is_tensor(vals))
+ vals = model_inputs.get_symbolic_inputs()
+ self.assertTrue(tensor_util.is_tensor(vals))
+ vals = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+ self.assertEquals(1, len(vals))
+ self.assertTrue(tensor_util.is_tensor(vals[0]))
+
+ def test_single_thing_eager(self):
+ with context.eager_mode():
+ a = np.ones(10)
+ model_inputs = training_utils.ModelInputs(a)
+ self.assertEquals(['input_1'], model_inputs.get_input_names())
+ vals = model_inputs.get_input_values()
+ self.assertAllEqual(np.ones(10), vals)
+ self.assertTrue(tensor_util.is_tensor(vals))
+ vals = model_inputs.get_symbolic_inputs()
+ self.assertTrue(isinstance(vals, base_layer.DeferredTensor))
+ vals = model_inputs.get_symbolic_inputs(return_single_as_list=True)
+ self.assertEquals(1, len(vals))
+ self.assertTrue(isinstance(vals[0], base_layer.DeferredTensor))
+
+ def test_list(self):
+ a = [np.ones(10), np.ones(20)]
+ model_inputs = training_utils.ModelInputs(a)
+ self.assertEquals(['input_1', 'input_2'], model_inputs.get_input_names())
+ vals = model_inputs.get_input_values()
+ self.assertEqual(2, len(vals))
+ self.assertAllEqual(np.ones(10), vals[0])
+ self.assertAllEqual(np.ones(20), vals[1])
+ self.assertFalse(tensor_util.is_tensor(vals[0]))
+ self.assertFalse(tensor_util.is_tensor(vals[1]))
+ vals = model_inputs.get_symbolic_inputs()
+ self.assertTrue(tensor_util.is_tensor(vals[0]))
+ self.assertTrue(tensor_util.is_tensor(vals[1]))
+
+ def test_list_eager(self):
+ with context.eager_mode():
+ a = [np.ones(10), np.ones(20)]
+ model_inputs = training_utils.ModelInputs(a)
+ self.assertEquals(['input_1', 'input_2'], model_inputs.get_input_names())
+ vals = model_inputs.get_input_values()
+ self.assertEqual(2, len(vals))
+ self.assertAllEqual(np.ones(10), vals[0])
+ self.assertAllEqual(np.ones(20), vals[1])
+ self.assertTrue(tensor_util.is_tensor(vals[0]))
+ self.assertTrue(tensor_util.is_tensor(vals[1]))
+ vals = model_inputs.get_symbolic_inputs()
+ self.assertTrue(isinstance(vals[0], base_layer.DeferredTensor))
+ self.assertTrue(isinstance(vals[1], base_layer.DeferredTensor))
+
+ def test_dict(self):
+ a = {'b': np.ones(10), 'a': np.ones(20)}
+ model_inputs = training_utils.ModelInputs(a)
+ self.assertEquals(['a', 'b'], model_inputs.get_input_names())
+ vals = model_inputs.get_input_values()
+ self.assertAllEqual(np.ones(20), vals['a'])
+ self.assertAllEqual(np.ones(10), vals['b'])
+ self.assertFalse(tensor_util.is_tensor(vals['a']))
+ self.assertFalse(tensor_util.is_tensor(vals['b']))
+ vals = model_inputs.get_symbolic_inputs()
+ self.assertTrue(tensor_util.is_tensor(vals['a']))
+ self.assertTrue(tensor_util.is_tensor(vals['b']))
+
+ def test_dict_eager(self):
+ with context.eager_mode():
+ a = {'b': np.ones(10), 'a': np.ones(20)}
+ model_inputs = training_utils.ModelInputs(a)
+ self.assertEquals(['a', 'b'], model_inputs.get_input_names())
+ vals = model_inputs.get_input_values()
+ self.assertAllEqual(np.ones(20), vals['a'])
+ self.assertAllEqual(np.ones(10), vals['b'])
+ self.assertTrue(tensor_util.is_tensor(vals['a']))
+ self.assertTrue(tensor_util.is_tensor(vals['b']))
+ vals = model_inputs.get_symbolic_inputs()
+ self.assertTrue(isinstance(vals['a'], base_layer.DeferredTensor))
+ self.assertTrue(isinstance(vals['b'], base_layer.DeferredTensor))
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 81c760b1f6..473d8cd95b 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -22,7 +22,10 @@ from __future__ import print_function
from abc import ABCMeta
from abc import abstractmethod
+import functools
+import sys
import types
+import weakref
import six
from tensorflow.python.eager import context
@@ -137,6 +140,21 @@ def result_wrapper(result_fn):
return tf_decorator.make_decorator(result_fn, decorated)
+def weakmethod(method):
+ """Creates a weak reference to the bound method."""
+
+ cls = method.im_class
+ func = method.im_func
+ instance_ref = weakref.ref(method.im_self)
+
+ @functools.wraps(method)
+ def inner(*args, **kwargs):
+ return func.__get__(instance_ref(), cls)(*args, **kwargs)
+
+ del method
+ return inner
+
+
def safe_div(numerator, denominator):
"""Divides two tensors element-wise, returning 0 if the denominator is <= 0.
@@ -318,14 +336,27 @@ class Metric(Layer):
def __new__(cls, *args, **kwargs):
obj = super(Metric, cls).__new__(cls)
- # TODO(psv): Fix reference cycle issue here.
-
- # Converting update_state_fn() into a graph function, so that
- # we can return a single op that performs all of the variable updates.
- defuned_update_state_fn = function.defun(obj.update_state)
- obj.update_state = types.MethodType(
- update_state_wrapper(defuned_update_state_fn), obj)
- obj.result = types.MethodType(result_wrapper(obj.result), obj)
+
+ if sys.version_info < (3,):
+      # Wrap methods with `weakmethod` to remove the binding and hold the
+      # instance only weakly. This breaks the reference cycle that is
+      # otherwise created here; it is not an issue in Python 3.
+ if context.executing_eagerly():
+ update_state = weakmethod(obj.update_state)
+ else:
+ update_state = function.defun(obj.update_state)
+ obj.update_state = weakmethod(
+ types.MethodType(update_state_wrapper(update_state), obj))
+ result = weakmethod(obj.result)
+ obj.result = weakmethod(types.MethodType(result_wrapper(result), obj))
+ else:
+ # Converting update_state_fn() into a graph function, so that
+ # we can return a single op that performs all of the variable updates.
+ defuned_update_state_fn = function.defun(obj.update_state)
+ obj.update_state = types.MethodType(
+ update_state_wrapper(defuned_update_state_fn), obj)
+ obj.result = types.MethodType(result_wrapper(obj.result), obj)
+
return obj
def __call__(self, *args, **kwargs):
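
The Python 2 branch exists because binding the wrapped update_state/result back onto the instance creates a reference cycle (instance -> bound method -> instance), which keeps metric objects alive until the cycle collector runs and trips the no-new-garbage eager assertions enabled in metrics_test.py below. A minimal Python 2 reproduction of the cycle itself (Counter is a made-up class):

# Python 2 sketch: storing a bound method on its own instance makes the
# instance reachable from itself via the method's im_self attribute.
class Counter(object):

  def update_state(self):
    pass

obj = Counter()
obj.update_state = obj.update_state      # instance -> method -> instance
assert obj.update_state.im_self is obj   # the cycle, explicitly (Py2 attrs)
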
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 779c08c42d..4195ea18ad 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -212,7 +212,7 @@ class KerasMetricsTest(test.TestCase):
self.assertAllClose(
val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)
- @test_util.run_in_graph_and_eager_modes
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def test_mean(self):
m = metrics.Mean(name='my_mean')
@@ -394,7 +394,7 @@ class KerasMetricsTest(test.TestCase):
self.assertTrue(acc_obj.stateful)
self.assertEqual(len(acc_obj.variables), 2)
self.assertEqual(acc_obj.dtype, dtypes.float32)
- self.evaluate(variables.global_variables_initializer())
+ self.evaluate(variables.variables_initializer(acc_obj.variables))
# verify that correct value is returned
update_op = acc_obj.update_state([[0, 0, 1], [0, 1, 0]],
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index c3b7301eba..f0733a9105 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -414,10 +414,10 @@ def clone_and_build_model(
this argument must be set to `True` (default `False`). To restore the
original model, use the function
`in_place_subclassed_model_state_restoration(model)`.
- optimizer_iterations: An iterations variable to pass to the optimizer if
- the model uses a TFOptimizer, and if the clone is compiled. This is used
- when a Keras model is cloned into an Estimator model function, because
- Estimators create their own global step variable.
+ optimizer_iterations: An iterations variable that will be incremented by the
+ optimizer if the clone is compiled. This argument is used when a Keras
+ model is cloned into an Estimator model function, because Estimators
+ create their own global step variable.
Returns:
Clone of the model.
@@ -458,6 +458,8 @@ def clone_and_build_model(
else:
optimizer_config = model.optimizer.get_config()
optimizer = model.optimizer.__class__.from_config(optimizer_config)
+ if optimizer_iterations is not None:
+ optimizer.iterations = optimizer_iterations
clone.compile(
optimizer,
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index 1d0f56f3c8..c550caeb80 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -25,7 +25,9 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
+from tensorflow.python.keras import backend as K
from tensorflow.python.keras import metrics
from tensorflow.python.keras import models
from tensorflow.python.ops import random_ops
@@ -51,7 +53,7 @@ class TestModel(keras.Model):
class TestModelCloning(test.TestCase):
def test_clone_sequential_model(self):
- with self.test_session():
+ with self.cached_session():
val_a = np.random.random((10, 4))
val_out = np.random.random((10, 4))
@@ -64,7 +66,7 @@ class TestModelCloning(test.TestCase):
# Everything should work in a new session.
keras.backend.clear_session()
- with self.test_session():
+ with self.cached_session():
# With placeholder creation
new_model = keras.models.clone_model(model)
# update ops from batch norm needs to be included
@@ -89,7 +91,7 @@ class TestModelCloning(test.TestCase):
new_model.train_on_batch(None, val_out)
def test_clone_functional_model(self):
- with self.test_session():
+ with self.cached_session():
val_a = np.random.random((10, 4))
val_b = np.random.random((10, 4))
val_out = np.random.random((10, 4))
@@ -110,7 +112,7 @@ class TestModelCloning(test.TestCase):
# Everything should work in a new session.
keras.backend.clear_session()
- with self.test_session():
+ with self.cached_session():
# With placeholder creation
new_model = keras.models.clone_model(model)
self.assertEquals(len(new_model.get_updates_for(new_model.inputs)), 2)
@@ -137,7 +139,7 @@ class TestModelCloning(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_clone_functional_model_with_masking(self):
- with self.test_session():
+ with self.cached_session():
x = np.array([[[1], [1]], [[0], [0]]])
inputs = keras.Input((2, 1))
outputs = keras.layers.Masking(mask_value=0)(inputs)
@@ -238,7 +240,7 @@ class TestModelDeepCopy(test.TestCase):
class TestCloneAndBuildModel(test.TestCase):
def test_clone_and_build_non_compiled_model(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.random((10, 4))
out = np.random.random((10, 4))
@@ -251,7 +253,7 @@ class TestCloneAndBuildModel(test.TestCase):
# Everything should work in a new session.
keras.backend.clear_session()
- with self.test_session():
+ with self.cached_session():
# With placeholder creation
new_model = models.clone_and_build_model(model, compile_clone=True)
with self.assertRaisesRegexp(RuntimeError, 'must compile'):
@@ -289,7 +291,7 @@ class TestCloneAndBuildModel(test.TestCase):
# Everything should work in a new session.
keras.backend.clear_session()
- with self.test_session():
+ with self.cached_session():
# With placeholder creation
new_model = models.clone_and_build_model(
model, compile_clone=True, in_place_reset=is_subclassed)
@@ -316,7 +318,7 @@ class TestCloneAndBuildModel(test.TestCase):
new_model.evaluate(inp, out)
def test_clone_and_build_compiled_sequential_model(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(4, input_shape=(4,)))
model.add(keras.layers.BatchNormalization())
@@ -328,7 +330,7 @@ class TestCloneAndBuildModel(test.TestCase):
self._clone_and_build_test_helper(model)
def test_clone_and_build_functional_model(self):
- with self.test_session():
+ with self.cached_session():
input_a = keras.Input(shape=(4,))
dense_1 = keras.layers.Dense(4,)
dense_2 = keras.layers.Dense(4,)
@@ -358,12 +360,42 @@ class TestCloneAndBuildModel(test.TestCase):
out = self.layer2(out)
return out
- with self.test_session():
+ with self.cached_session():
model = SubclassedModel()
model.compile('rmsprop', 'mse',
metrics=['acc', metrics.categorical_accuracy])
self._clone_and_build_test_helper(model, True)
+ def assert_optimizer_iterations_increases(self, optimizer):
+ with self.cached_session():
+ input_a = keras.Input(shape=(4,))
+ dense_1 = keras.layers.Dense(4,)
+ dense_2 = keras.layers.Dense(4,)
+
+ x_a = dense_1(input_a)
+ x_a = keras.layers.Dropout(0.5)(x_a)
+ x_a = keras.layers.BatchNormalization()(x_a)
+ x_a = dense_2(x_a)
+ model = keras.models.Model(input_a, x_a)
+ model.compile(optimizer, 'mse',
+ metrics=['acc', metrics.categorical_accuracy])
+
+ global_step = keras.backend.variable(123, dtype=dtypes.int64)
+ clone_model = models.clone_and_build_model(
+ model, compile_clone=True, optimizer_iterations=global_step)
+
+ inp = np.random.random((10, 4))
+ out = np.random.random((10, 4))
+ clone_model.train_on_batch(inp, out)
+
+ self.assertEqual(K.eval(global_step), 124)
+
+ def test_replace_tf_optimizer_iterations_variable(self):
+ self.assert_optimizer_iterations_increases(adam.AdamOptimizer(0.01))
+
+ def test_replace_keras_optimizer_iterations_variable(self):
+ self.assert_optimizer_iterations_increases('adam')
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 3026c7755a..0403211d92 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -622,6 +622,7 @@ cuda_py_test(
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
],
+ tags = ["notap"],
)
cuda_py_test(
@@ -779,6 +780,7 @@ tf_py_test(
size = "small",
srcs = ["regex_full_match_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
@@ -1634,6 +1636,7 @@ cuda_py_test(
srcs = ["functional_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
index 400d38b936..de52a70cc0 100644
--- a/tensorflow/python/kernel_tests/clip_ops_test.py
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.platform import test
@@ -158,13 +159,19 @@ class ClipTest(test.TestCase):
ans = clip_ops.clip_by_norm(x, clip_norm)
tf_ans = ans.eval()
- clip_tensor = constant_op.constant(4.0)
ans = clip_ops.clip_by_norm(x, clip_norm)
tf_ans_tensor = ans.eval()
self.assertAllClose(np_ans, tf_ans)
self.assertAllClose(np_ans, tf_ans_tensor)
+ def testClipByNormGradientZeros(self):
+ with self.test_session(use_gpu=True):
+ x = array_ops.zeros([3])
+ b = clip_ops.clip_by_norm(x, 1.)
+ grad, = gradients_impl.gradients(b, x)
+ self.assertAllEqual(grad.eval(), [1., 1., 1.])
+
def testClipByNormBadShape(self):
with self.test_session(use_gpu=True):
x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3, 1])
diff --git a/tensorflow/python/kernel_tests/conditional_accumulator_test.py b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
index 7570523495..86802664d1 100644
--- a/tensorflow/python/kernel_tests/conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
@@ -42,14 +42,22 @@ class ConditionalAccumulatorTest(test.TestCase):
with ops.Graph().as_default():
q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q")
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
- self.assertProtoEquals("""
+ self.assertProtoEquals(
+ """
name:'Q' op:'ConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { unknown_rank: true} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
+ attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
+ def testConstructorWithInvalidArg(self):
+ with ops.Graph().as_default():
+ with self.assertRaises(ValueError):
+ data_flow_ops.ConditionalAccumulator(
+ dtypes_lib.float32, name="Q", reduction_type="Invalid")
+
def testConstructorWithShape(self):
with ops.Graph().as_default():
q = data_flow_ops.ConditionalAccumulator(
@@ -57,7 +65,8 @@ class ConditionalAccumulatorTest(test.TestCase):
name="Q",
shape=tensor_shape.TensorShape([1, 5, 2, 8]))
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
- self.assertProtoEquals("""
+ self.assertProtoEquals(
+ """
name:'Q' op:'ConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { dim {size: 1 }
@@ -67,6 +76,7 @@ class ConditionalAccumulatorTest(test.TestCase):
} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
+ attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
def testAccumulatorSizeEmpty(self):
@@ -237,12 +247,11 @@ class ConditionalAccumulatorTest(test.TestCase):
extract_t.op.run()
self.assertEqual(q.num_accumulated().eval(), 0)
- def testAccumulatorTakeGrad(self):
+ def testAccumulatorTakeGradMean(self):
with self.test_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
elems = [10.0, 20.0]
- elems_ave = sum(elems) / len(elems)
accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
takeg_t = q.take_grad(1)
@@ -251,7 +260,7 @@ class ConditionalAccumulatorTest(test.TestCase):
accum_op.run()
val = takeg_t.eval()
- self.assertEqual(elems_ave, val)
+ self.assertEqual(15.0, val)
accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
takeg_t = q.take_grad(constant_op.constant(1))
@@ -260,7 +269,42 @@ class ConditionalAccumulatorTest(test.TestCase):
accum_op.run()
val = takeg_t.eval()
- self.assertEqual(elems_ave, val)
+ self.assertEqual(15.0, val)
+
+ def testAccumulatorTakeGradSum(self):
+ with self.test_session():
+ q = data_flow_ops.ConditionalAccumulator(
+ dtypes_lib.float32,
+ name="Q",
+ shape=tensor_shape.TensorShape([1]),
+ reduction_type="SUM")
+ elems = [10.0, 20.0]
+
+ accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
+ takeg_t = q.take_grad(1)
+
+ for accum_op in accum_ops:
+ accum_op.run()
+
+ val = takeg_t.eval()
+ self.assertEqual(30.0, val)
+
+ accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
+ takeg_t = q.take_grad(constant_op.constant(1))
+
+ for accum_op in accum_ops:
+ accum_op.run()
+
+ val = takeg_t.eval()
+ self.assertEqual(30.0, val)
+
+ def testAccumulatorTakeGradInvalidReductionType(self):
+ with self.assertRaises(ValueError):
+ data_flow_ops.ConditionalAccumulator(
+ dtypes_lib.float32,
+ name="Q",
+ shape=tensor_shape.TensorShape([1]),
+ reduction_type="Invalid")
def testAccumulatorInvalidTakeGrad(self):
with self.test_session():
@@ -277,7 +321,7 @@ class ConditionalAccumulatorTest(test.TestCase):
with self.assertRaises(errors_impl.InvalidArgumentError):
takeg_t.eval()
- def testAccumulatorRepeatedTakeGrad(self):
+ def testAccumulatorRepeatedTakeGradMean(self):
with self.test_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -304,6 +348,36 @@ class ConditionalAccumulatorTest(test.TestCase):
val = takeg_t.eval()
self.assertEqual(elems_ave + 0.0, val)
+ def testAccumulatorRepeatedTakeGradSum(self):
+ with self.test_session():
+ q = data_flow_ops.ConditionalAccumulator(
+ dtypes_lib.float32,
+ name="Q",
+ shape=tensor_shape.TensorShape([1]),
+ reduction_type="SUM")
+
+ elems = [10.0, 20.0]
+ elems_sum = 30.0
+ accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
+ takeg_t = q.take_grad(1)
+
+ for accum_op in accum_ops:
+ accum_op.run()
+
+ val = takeg_t.eval()
+ self.assertEqual(elems_sum, val)
+
+ elems = [20.0, 30.0]
+ elems_sum = 50.0
+ accum_ops = [q.apply_grad((x,), local_step=1) for x in elems]
+ takeg_t = q.take_grad(1)
+
+ for accum_op in accum_ops:
+ accum_op.run()
+
+ val = takeg_t.eval()
+ self.assertEqual(elems_sum, val)
+
def testAccumulatorIncrementGlobalStep(self):
with self.test_session():
q = data_flow_ops.ConditionalAccumulator(
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index c4d4ce780b..49b9569e2b 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -104,6 +104,27 @@ class DynamicStitchTestBase(object):
# Dimension 0 is max(flatten(indices))+1.
self.assertEqual([8, 2], stitched_t.get_shape().as_list())
+ def testZeroSizeTensor(self):
+ with self.test_session(use_gpu=True):
+ indices = [
+ constant_op.constant([0, 4, 7]),
+ constant_op.constant([1, 6]),
+ constant_op.constant([2, 3, 5]),
+ array_ops.zeros([0], dtype=dtypes.int32)
+ ]
+ data = [
+ constant_op.constant([[0, 1], [40, 41], [70, 71]]),
+ constant_op.constant([[10, 11], [60, 61]]),
+ constant_op.constant([[20, 21], [30, 31], [50, 51]]),
+ array_ops.zeros([0, 2], dtype=dtypes.int32)
+ ]
+ stitched_t = self.stitch_op(indices, data)
+ stitched_val = stitched_t.eval()
+ self.assertAllEqual([[0, 1], [10, 11], [20, 21], [30, 31], [40, 41],
+ [50, 51], [60, 61], [70, 71]], stitched_val)
+ # Dimension 0 is max(flatten(indices))+1.
+ self.assertEqual([8, 2], stitched_t.get_shape().as_list())
+
def testHigherRank(self):
with self.test_session(use_gpu=True) as sess:
indices = [
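
The new testZeroSizeTensor above exercises a case that previously failed: a zero-size indices/data pair mixed with non-empty pairs. A minimal sketch of the now-supported behavior, assuming the public TF 1.x graph-mode API (values are illustrative):

    import tensorflow as tf

    # One empty indices/data pair alongside populated ones; the empty pair
    # contributes nothing to the stitched result.
    indices = [tf.constant([0, 1]), tf.zeros([0], dtype=tf.int32)]
    data = [tf.constant([[1, 2], [3, 4]]), tf.zeros([0, 2], dtype=tf.int32)]
    stitched = tf.dynamic_stitch(indices, data)
    with tf.Session() as sess:
        print(sess.run(stitched))  # [[1 2] [3 4]]
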
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 3ddb5e06c9..e39daf1371 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import iterator_ops
@@ -738,6 +739,40 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(Run(sess, 20.), 210.)
self.assertAllEqual(Run(sess, 100.), 5050.)
+ def testWhileLowering(self):
+
+ def Run(n, fetch_by_name):
+ for use_gpu in (True, False):
+ with ops.Graph().as_default() as g:
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Cond(n, unused_x):
+ return n > 0
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Body(n, x):
+ return n - 1, x + n
+
+ # outputs: [0, n*(n+1)/2]
+ outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while")
+
+ # `outputs` is the list of output tensors of the While op. We
+ # arbitrarily choose the 0th tensor to get the While op and set the
+ # lowering attribute on it.
+ outputs[0].op._set_attr("_lower_using_switch_merge",
+ attr_value_pb2.AttrValue(b=True))
+ if not fetch_by_name:
+ fetch = outputs[1]
+ else:
+ fetch = "my_while:1"
+ with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ return sess.run(fetch)
+
+ self.assertAllEqual(Run(20., False), 210.)
+ self.assertAllEqual(Run(20., True), 210.)
+ self.assertAllEqual(Run(100., False), 5050.)
+ self.assertAllEqual(Run(100., True), 5050.)
+
def testWhileError(self):
for use_gpu in (True, False):
with ops.Graph().as_default() as g:
diff --git a/tensorflow/python/kernel_tests/regex_full_match_op_test.py b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
index 5daae1b79b..7bd8c3ca27 100644
--- a/tensorflow/python/kernel_tests/regex_full_match_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
@@ -18,37 +18,77 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class RegexFullMatchOpTest(test.TestCase):
+@parameterized.parameters(
+ (gen_string_ops.regex_full_match),
+ (gen_string_ops.static_regex_full_match))
+class RegexFullMatchOpVariantsTest(test.TestCase, parameterized.TestCase):
- def testRegexFullMatch(self):
+ def testRegexFullMatch(self, op):
values = ["abaaba", "abcdabcde"]
with self.test_session():
- input_vector = constant_op.constant(values, dtypes.string)
- matched = string_ops.regex_full_match(input_vector, "a.*a").eval()
+ input_tensor = constant_op.constant(values, dtypes.string)
+ matched = op(input_tensor, "a.*a").eval()
self.assertAllEqual([True, False], matched)
- def testEmptyMatch(self):
+ def testRegexFullMatchTwoDims(self, op):
+ values = [["abaaba", "abcdabcde"], ["acdcba", "ebcda"]]
+ with self.test_session():
+ input_tensor = constant_op.constant(values, dtypes.string)
+ matched = op(input_tensor, "a.*a").eval()
+ self.assertAllEqual([[True, False], [True, False]], matched)
+
+ def testEmptyMatch(self, op):
values = ["abc", "1"]
with self.test_session():
- input_vector = constant_op.constant(values, dtypes.string)
- matched = string_ops.regex_full_match(input_vector, "").eval()
+ input_tensor = constant_op.constant(values, dtypes.string)
+ matched = op(input_tensor, "").eval()
self.assertAllEqual([False, False], matched)
- def testInvalidPattern(self):
+ def testInvalidPattern(self, op):
values = ["abc", "1"]
with self.test_session():
- input_vector = constant_op.constant(values, dtypes.string)
+ input_tensor = constant_op.constant(values, dtypes.string)
invalid_pattern = "A["
- matched = string_ops.regex_full_match(input_vector, invalid_pattern)
+ matched = op(input_tensor, invalid_pattern)
with self.assertRaisesOpError("Invalid pattern"):
matched.eval()
+class RegexFullMatchOpTest(test.TestCase):
+
+ def testRegexFullMatchDelegation(self):
+ with compat.forward_compatibility_horizon(2018, 11, 1):
+ with self.test_session():
+ input_tensor = constant_op.constant("foo", dtypes.string)
+ pattern = "[a-z]"
+ op = string_ops.regex_full_match(input_tensor, pattern)
+ self.assertTrue(op.name.startswith("RegexFullMatch"), op.name)
+
+ pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
+ op_tensor = string_ops.regex_full_match(input_tensor, pattern_tensor)
+ self.assertTrue(op_tensor.name.startswith("RegexFullMatch"), op_tensor.name)
+
+ def testStaticRegexFullMatchDelegation(self):
+ with compat.forward_compatibility_horizon(2018, 11, 20):
+ with self.test_session():
+ input_tensor = constant_op.constant("foo", dtypes.string)
+ pattern = "[a-z]*"
+ op = string_ops.regex_full_match(input_tensor, pattern)
+ self.assertTrue(op.name.startswith("StaticRegexFullMatch"), op.name)
+
+ pattern_tensor = constant_op.constant("[a-z]*", dtypes.string)
+ op_vec = string_ops.regex_full_match(input_tensor, pattern_tensor)
+ self.assertTrue(op_vec.name.startswith("RegexFullMatch"), op_vec.name)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
index d749843410..3bb5e899fe 100644
--- a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
@@ -61,14 +61,22 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q")
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
- self.assertProtoEquals("""
+ self.assertProtoEquals(
+ """
name:'Q' op:'SparseConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { unknown_rank: true} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
+ attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
+ def testConstructorWithInvalidArg(self):
+ with ops.Graph().as_default():
+ with self.assertRaises(ValueError):
+ data_flow_ops.SparseConditionalAccumulator(
+ dtypes_lib.float32, name="Q", reduction_type="Invalid")
+
def testConstructorWithShape(self):
with ops.Graph().as_default():
q = data_flow_ops.SparseConditionalAccumulator(
@@ -76,7 +84,8 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
name="Q",
shape=tensor_shape.TensorShape([1, 5, 2, 8]))
self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor))
- self.assertProtoEquals("""
+ self.assertProtoEquals(
+ """
name:'Q' op:'SparseConditionalAccumulator'
attr { key: 'dtype' value { type: DT_FLOAT } }
attr { key: 'shape' value { shape { dim {size: 1 }
@@ -86,6 +95,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
} } }
attr { key: 'container' value { s: '' } }
attr { key: 'shared_name' value { s: '' } }
+ attr { key: 'reduction_type' value {s: 'MEAN'} }
""", q.accumulator_ref.op.node_def)
def testAccumulatorSizeEmpty(self):
@@ -164,7 +174,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
result = sess.run(accums[i].take_indexed_slices_grad(1))
self._assertEqual_indexedslices(expected_tensors[i], result)
- def testAccumulatorTakeGrad(self):
+ def testAccumulatorTakeGradMean(self):
with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=())
@@ -180,9 +190,34 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
takeg_t = q.take_indexed_slices_grad(1)
val = sess.run(takeg_t)
- self.assertAllEqual(val.indices, [0, 1, 2])
- self.assertAllEqual(val.values, [[0.5, 0.5], [0, 2], [3, 0]])
- self.assertAllEqual(val.dense_shape, [-1, 2])
+ self.assertAllEqual([0, 1, 2], val.indices)
+ self.assertAllEqual([[0.5, 0.5], [0, 2], [3, 0]], val.values)
+ self.assertAllEqual([-1, 2], val.dense_shape)
+
+ def testAccumulatorTakeGradSum(self):
+ with self.test_session() as sess:
+ q = data_flow_ops.SparseConditionalAccumulator(
+ dtypes_lib.float32, name="Q", shape=(), reduction_type="SUM")
+
+ grad_indexed_slices = ops.IndexedSlices(
+ indices=[0, 1], values=np.array([[1, 0], [0, 2]]).astype(np.float32))
+ accum_op = q.apply_indexed_slices_grad(grad_indexed_slices)
+ accum_op.run()
+ accum_op = q.apply_grad([0, 2],
+ np.array([[0, 1], [3, 0]]).astype(np.float32),
+ [3, 2])
+ accum_op.run()
+
+ takeg_t = q.take_indexed_slices_grad(1)
+ val = sess.run(takeg_t)
+ self.assertAllEqual([0, 1, 2], val.indices)
+ self.assertAllEqual([[1, 1], [0, 2], [3, 0]], val.values)
+ self.assertAllEqual([-1, 2], val.dense_shape)
+
+ def testAccumulatorTakeGradInvalidReductionType(self):
+ with self.assertRaises(ValueError):
+ data_flow_ops.SparseConditionalAccumulator(
+ dtypes_lib.float32, name="Q", shape=(), reduction_type="Invalid")
def testAccumulatorRepeatedTakeGrad(self):
with self.test_session() as sess:
@@ -222,7 +257,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self.assertAllEqual(val.values, [[5, 5], [0, 20], [30, 0]])
self.assertAllEqual(val.dense_shape, [-1, 2])
- def testParallelApplyGrad(self):
+ def testParallelApplyGradMean(self):
with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
@@ -253,6 +288,40 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
val, sess)
+ def testParallelApplyGradSum(self):
+ with self.test_session() as sess:
+ q = data_flow_ops.SparseConditionalAccumulator(
+ dtypes_lib.float32,
+ name="Q",
+ shape=tensor_shape.TensorShape([2, 2]),
+ reduction_type="SUM")
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ accum_ops = []
+ for x in elems:
+ x = _indexedslice(np.array([[x, 0], [0, x]]).astype(np.float32))
+ accum_ops.append(q.apply_indexed_slices_grad(x, local_step=0))
+ takeg_t = q.take_indexed_slices_grad(1)
+
+ def apply_indexed_slices_grad(accum_op):
+ sess.run(accum_op)
+
+ threads = [
+ self.checkedThread(target=apply_indexed_slices_grad, args=(o,))
+ for o in accum_ops
+ ]
+
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ val = sess.run(takeg_t)
+
+ expected_val = 550.0
+ self._assertEqual_nparray(
+ np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
+ val, sess)
+
def testParallelTakeGrad(self):
with self.test_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
diff --git a/tensorflow/python/lib/io/py_record_reader.cc b/tensorflow/python/lib/io/py_record_reader.cc
index 9500fc6a7c..07ce071845 100644
--- a/tensorflow/python/lib/io/py_record_reader.cc
+++ b/tensorflow/python/lib/io/py_record_reader.cc
@@ -30,6 +30,8 @@ namespace io {
PyRecordReader::PyRecordReader() {}
+// NOTE(sethtroisi): At this time PyRecordReader doesn't benefit from taking
+// RecordReaderOptions; if this changes, the API can be updated then.
PyRecordReader* PyRecordReader::New(const string& filename, uint64 start_offset,
const string& compression_type_string,
TF_Status* out_status) {
diff --git a/tensorflow/python/lib/io/py_record_writer.cc b/tensorflow/python/lib/io/py_record_writer.cc
index e4e5268b0f..faf20df868 100644
--- a/tensorflow/python/lib/io/py_record_writer.cc
+++ b/tensorflow/python/lib/io/py_record_writer.cc
@@ -28,7 +28,7 @@ namespace io {
PyRecordWriter::PyRecordWriter() {}
PyRecordWriter* PyRecordWriter::New(const string& filename,
- const string& compression_type_string,
+ const io::RecordWriterOptions& options,
TF_Status* out_status) {
std::unique_ptr<WritableFile> file;
Status s = Env::Default()->NewWritableFile(filename, &file);
@@ -38,10 +38,6 @@ PyRecordWriter* PyRecordWriter::New(const string& filename,
}
PyRecordWriter* writer = new PyRecordWriter;
writer->file_ = std::move(file);
-
- RecordWriterOptions options =
- RecordWriterOptions::CreateRecordWriterOptions(compression_type_string);
-
writer->writer_.reset(new RecordWriter(writer->file_.get(), options));
return writer;
}
diff --git a/tensorflow/python/lib/io/py_record_writer.h b/tensorflow/python/lib/io/py_record_writer.h
index 61a4960ee6..9b0792c6db 100644
--- a/tensorflow/python/lib/io/py_record_writer.h
+++ b/tensorflow/python/lib/io/py_record_writer.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -36,10 +37,8 @@ class RecordWriter;
// by multiple threads.
class PyRecordWriter {
public:
- // TODO(vrv): make this take a shared proto to configure
- // the compression options.
static PyRecordWriter* New(const string& filename,
- const string& compression_type_string,
+ const io::RecordWriterOptions& compression_options,
TF_Status* out_status);
~PyRecordWriter();
diff --git a/tensorflow/python/lib/io/py_record_writer.i b/tensorflow/python/lib/io/py_record_writer.i
index 3181c9afce..b2c2bda5dd 100644
--- a/tensorflow/python/lib/io/py_record_writer.i
+++ b/tensorflow/python/lib/io/py_record_writer.i
@@ -18,6 +18,11 @@ limitations under the License.
%include "tensorflow/python/platform/base.i"
%include "tensorflow/python/lib/core/strings.i"
+// Define int8 explicitly instead of including "stdint.i", since "stdint.h"
+// and "stdint.i" disagree on the definition of int64_t.
+typedef signed char int8;
+%{ typedef signed char int8; %}
+
%feature("except") tensorflow::io::PyRecordWriter::New {
// Let other threads run while we write
Py_BEGIN_ALLOW_THREADS
@@ -26,6 +31,7 @@ limitations under the License.
}
%newobject tensorflow::io::PyRecordWriter::New;
+%newobject tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions;
%feature("except") tensorflow::io::PyRecordWriter::WriteRecord {
// Let other threads run while we write
@@ -35,6 +41,8 @@ limitations under the License.
}
%{
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/python/lib/io/py_record_writer.h"
%}
@@ -48,7 +56,21 @@ limitations under the License.
%unignore tensorflow::io::PyRecordWriter::Flush;
%unignore tensorflow::io::PyRecordWriter::Close;
%unignore tensorflow::io::PyRecordWriter::New;
+%unignore tensorflow::io::ZlibCompressionOptions;
+%unignore tensorflow::io::ZlibCompressionOptions::flush_mode;
+%unignore tensorflow::io::ZlibCompressionOptions::input_buffer_size;
+%unignore tensorflow::io::ZlibCompressionOptions::output_buffer_size;
+%unignore tensorflow::io::ZlibCompressionOptions::window_bits;
+%unignore tensorflow::io::ZlibCompressionOptions::compression_level;
+%unignore tensorflow::io::ZlibCompressionOptions::compression_method;
+%unignore tensorflow::io::ZlibCompressionOptions::mem_level;
+%unignore tensorflow::io::ZlibCompressionOptions::compression_strategy;
+%unignore tensorflow::io::RecordWriterOptions;
+%unignore tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions;
+%unignore tensorflow::io::RecordWriterOptions::zlib_options;
+%include "tensorflow/core/lib/io/record_writer.h"
+%include "tensorflow/core/lib/io/zlib_compression_options.h"
%include "tensorflow/python/lib/io/py_record_writer.h"
%unignoreall
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index 2b3e986f6b..cce71a2bab 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -33,8 +33,6 @@ class TFRecordCompressionType(object):
GZIP = 2
-# NOTE(vrv): This will eventually be converted into a proto. to match
-# the interface used by the C++ RecordWriter.
@tf_export("python_io.TFRecordOptions")
class TFRecordOptions(object):
"""Options used for manipulating TFRecord files."""
@@ -44,14 +42,105 @@ class TFRecordOptions(object):
TFRecordCompressionType.NONE: ""
}
- def __init__(self, compression_type):
+ def __init__(self,
+ compression_type=None,
+ flush_mode=None,
+ input_buffer_size=None,
+ output_buffer_size=None,
+ window_bits=None,
+ compression_level=None,
+ compression_method=None,
+ mem_level=None,
+ compression_strategy=None):
+ # pylint: disable=line-too-long
+ """Creates a `TFRecordOptions` instance.
+
+ Options only affect TFRecordWriter when compression_type is not `None`.
+ Documentation, details, and defaults can be found in
+ [`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h)
+ and in the [zlib manual](http://www.zlib.net/manual.html).
+ Leaving an option as `None` allows C++ to set a reasonable default.
+
+ Args:
+ compression_type: `TFRecordCompressionType` or `None`.
+ flush_mode: flush mode or `None`. Default: Z_NO_FLUSH.
+ input_buffer_size: int or `None`.
+ output_buffer_size: int or `None`.
+ window_bits: int or `None`.
+ compression_level: 0 to 9, or `None`.
+ compression_method: compression method or `None`.
+ mem_level: 1 to 9, or `None`.
+ compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY.
+
+ Returns:
+ A `TFRecordOptions` object.
+
+ Raises:
+ ValueError: If compression_type is invalid.
+ """
+ # pylint: enable=line-too-long
+ # Check compression_type is valid, but for backwards compatibility don't
+ # immediately convert to a string.
+ self.get_compression_type_string(compression_type)
self.compression_type = compression_type
+ self.flush_mode = flush_mode
+ self.input_buffer_size = input_buffer_size
+ self.output_buffer_size = output_buffer_size
+ self.window_bits = window_bits
+ self.compression_level = compression_level
+ self.compression_method = compression_method
+ self.mem_level = mem_level
+ self.compression_strategy = compression_strategy
@classmethod
def get_compression_type_string(cls, options):
+ """Convert various option types to a unified string.
+
+ Args:
+ options: `TFRecordOptions`, `TFRecordCompressionType`, or string.
+
+ Returns:
+ Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`).
+
+ Raises:
+ ValueError: If compression_type is invalid.
+ """
if not options:
return ""
- return cls.compression_type_map[options.compression_type]
+ elif isinstance(options, TFRecordOptions):
+ return cls.get_compression_type_string(options.compression_type)
+ elif isinstance(options, TFRecordCompressionType):
+ return cls.compression_type_map[options]
+ elif options in TFRecordOptions.compression_type_map:
+ return cls.compression_type_map[options]
+ elif options in TFRecordOptions.compression_type_map.values():
+ return options
+ else:
+ raise ValueError('Not a valid compression_type: "{}"'.format(options))
+
+ def _as_record_writer_options(self):
+ """Convert to RecordWriterOptions for use with PyRecordWriter."""
+ options = pywrap_tensorflow.RecordWriterOptions_CreateRecordWriterOptions(
+ compat.as_bytes(
+ self.get_compression_type_string(self.compression_type)))
+
+ if self.flush_mode is not None:
+ options.zlib_options.flush_mode = self.flush_mode
+ if self.input_buffer_size is not None:
+ options.zlib_options.input_buffer_size = self.input_buffer_size
+ if self.output_buffer_size is not None:
+ options.zlib_options.output_buffer_size = self.output_buffer_size
+ if self.window_bits is not None:
+ options.zlib_options.window_bits = self.window_bits
+ if self.compression_level is not None:
+ options.zlib_options.compression_level = self.compression_level
+ if self.compression_method is not None:
+ options.zlib_options.compression_method = self.compression_method
+ if self.mem_level is not None:
+ options.zlib_options.mem_level = self.mem_level
+ if self.compression_strategy is not None:
+ options.zlib_options.compression_strategy = self.compression_strategy
+ return options
@tf_export("python_io.tf_record_iterator")
@@ -100,16 +189,21 @@ class TFRecordWriter(object):
Args:
path: The path to the TFRecords file.
- options: (optional) A TFRecordOptions object.
+ options: (optional) String specifying compression type,
+ `TFRecordCompressionType`, or `TFRecordOptions` object.
Raises:
IOError: If `path` cannot be opened for writing.
+ ValueError: If a valid compression_type can't be determined from `options`.
"""
- compression_type = TFRecordOptions.get_compression_type_string(options)
+ if not isinstance(options, TFRecordOptions):
+ options = TFRecordOptions(compression_type=options)
with errors.raise_exception_on_not_ok_status() as status:
+ # pylint: disable=protected-access
self._writer = pywrap_tensorflow.PyRecordWriter_New(
- compat.as_bytes(path), compat.as_bytes(compression_type), status)
+ compat.as_bytes(path), options._as_record_writer_options(), status)
+ # pylint: enable=protected-access
def __enter__(self):
"""Enter a `with` block."""
diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py
index b853b64ae4..def8fe23e5 100644
--- a/tensorflow/python/lib/io/tf_record_test.py
+++ b/tensorflow/python/lib/io/tf_record_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import gzip
import os
+import random
+import string
import zlib
import six
@@ -131,9 +133,6 @@ class TFCompressionTestCase(test.TestCase):
class TFRecordWriterTest(TFCompressionTestCase):
- def setUp(self):
- super(TFRecordWriterTest, self).setUp()
-
def _AssertFilesEqual(self, a, b, equal):
for an, bn in zip(a, b):
with open(an, "rb") as af, open(bn, "rb") as bf:
@@ -142,6 +141,37 @@ class TFRecordWriterTest(TFCompressionTestCase):
else:
self.assertNotEqual(af.read(), bf.read())
+ def _CompressionSizeDelta(self, records, options_a, options_b):
+ """Validate compression with options_a and options_b and return size delta.
+
+ Compress records with options_a and options_b. Uncompress both compressed
+ files and assert that the contents match the original records. Finally
+ return the size of the file compressed with options_a minus the size of
+ the file compressed with options_b.
+
+ Args:
+ records: The records to compress
+ options_a: First set of options to compress with, the baseline for size.
+ options_b: Second set of options to compress with.
+
+ Returns:
+ The size of the file compressed with options_a minus the size of the file
+ compressed with options_b. A negative value means options_a compressed
+ better than options_b; a positive value means options_b compressed better.
+ """
+
+ fn_a = self._WriteRecordsToFile(records, "tfrecord_a", options=options_a)
+ test_a = list(tf_record.tf_record_iterator(fn_a, options=options_a))
+ self.assertEqual(records, test_a, options_a)
+
+ fn_b = self._WriteRecordsToFile(records, "tfrecord_b", options=options_b)
+ test_b = list(tf_record.tf_record_iterator(fn_b, options=options_b))
+ self.assertEqual(records, test_b, options_b)
+
+ # Negative delta => options_a compressed better than options_b.
+ return os.path.getsize(fn_a) - os.path.getsize(fn_b)
+
def testWriteReadZLibFiles(self):
# Write uncompressed then compress manually.
options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE)
@@ -188,6 +218,76 @@ class TFRecordWriterTest(TFCompressionTestCase):
]
self._AssertFilesEqual(uncompressed_files, files, True)
+ def testNoCompressionType(self):
+ self.assertEqual(
+ "",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions()))
+
+ self.assertEqual(
+ "",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions("")))
+
+ with self.assertRaises(ValueError):
+ tf_record.TFRecordOptions(5)
+
+ with self.assertRaises(ValueError):
+ tf_record.TFRecordOptions("BZ2")
+
+ def testZlibCompressionType(self):
+ zlib_t = tf_record.TFRecordCompressionType.ZLIB
+
+ self.assertEqual(
+ "ZLIB",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions("ZLIB")))
+
+ self.assertEqual(
+ "ZLIB",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions(zlib_t)))
+
+ self.assertEqual(
+ "ZLIB",
+ tf_record.TFRecordOptions.get_compression_type_string(
+ tf_record.TFRecordOptions(tf_record.TFRecordOptions(zlib_t))))
+
+ def testCompressionOptions(self):
+ # Create records with a mix of random and repeated data to test compression.
+ rnd = random.Random(123)
+ random_record = compat.as_bytes(
+ "".join(rnd.choice(string.digits) for _ in range(10000)))
+ repeated_record = compat.as_bytes(_TEXT)
+ for _ in range(10000):
+ start_i = rnd.randint(0, len(_TEXT))
+ length = rnd.randint(10, 200)
+ repeated_record += _TEXT[start_i:start_i + length]
+ records = [random_record, repeated_record, random_record]
+
+ tests = [
+ ("compression_level", 2, -1), # Lower compression is worse.
+ ("compression_level", 6, 0), # Default compression_level is equal.
+ ("flush_mode", zlib.Z_FULL_FLUSH, 1), # A few less bytes.
+ ("flush_mode", zlib.Z_NO_FLUSH, 0), # NO_FLUSH is the default.
+ ("input_buffer_size", 4096, 0), # Increases time not size.
+ ("output_buffer_size", 4096, 0), # Increases time not size.
+ ("window_bits", 8, -1), # Smaller than default window increases size.
+ ("compression_strategy", zlib.Z_HUFFMAN_ONLY, -1), # Worse.
+ ("compression_strategy", zlib.Z_FILTERED, -1), # Worse.
+ ]
+
+ compression_type = tf_record.TFRecordCompressionType.ZLIB
+ options_a = tf_record.TFRecordOptions(compression_type)
+ for prop, value, delta_sign in tests:
+ options_b = tf_record.TFRecordOptions(
+ compression_type=compression_type, **{prop: value})
+ delta = self._CompressionSizeDelta(records, options_a, options_b)
+ self.assertTrue(
+ delta == 0 if delta_sign == 0 else delta // delta_sign > 0,
+ "Setting {} = {}, file was {} smaller didn't match sign of {}".format(
+ prop, value, delta, delta_sign))
+
class TFRecordWriterZlibTest(TFCompressionTestCase):
@@ -318,6 +418,7 @@ class TFRecordIteratorTest(TFCompressionTestCase):
for _ in tf_record.tf_record_iterator(fn_truncated):
pass
+
class TFRecordWriterCloseAndFlushTests(test.TestCase):
def setUp(self, compression_type=TFRecordCompressionType.NONE):
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 21ccbc6c33..c8b883350d 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1275,7 +1275,7 @@ unique_with_counts.__doc__ = gen_array_ops.unique_with_counts.__doc__
def split(value, num_or_size_splits, axis=0, num=None, name="split"):
"""Splits a tensor into sub tensors.
- If `num_or_size_splits` is an integer type, `num_split`, then splits `value`
+ If `num_or_size_splits` is an integer type, then `value` is split
along dimension `axis` into `num_split` smaller tensors.
Requires that `num_split` evenly divides `value.shape[axis]`.
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index 78b395a6c1..29468431b3 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -144,7 +144,11 @@ def clip_by_norm(t, clip_norm, axes=None, name=None):
t = ops.convert_to_tensor(t, name="t")
# Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
- l2norm = math_ops.sqrt(math_ops.reduce_sum(t * t, axes, keepdims=True))
+ l2sum = math_ops.reduce_sum(t * t, axes, keepdims=True)
+ pred = l2sum > 0
+ # Two-tap tf.where trick to bypass NaN gradients
+ l2sum_safe = array_ops.where(pred, l2sum, array_ops.ones_like(l2sum))
+ l2norm = array_ops.where(pred, math_ops.sqrt(l2sum_safe), l2sum)
intermediate = t * clip_norm
# Assert that the shape is compatible with the initial shape,
# to prevent unintentional broadcasting.
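
The two-tap tf.where trick above is worth spelling out: sqrt has an unbounded derivative at zero, so differentiating sqrt(reduce_sum(t * t)) yields NaN gradients for an all-zero input. Evaluating sqrt only on a value known to be positive keeps both where branches finite. A minimal sketch of the pattern, assuming TF 1.x graph mode:

    import tensorflow as tf

    x = tf.zeros([3])
    l2sum = tf.reduce_sum(x * x, keepdims=True)
    pred = l2sum > 0
    # Both branches of each where() are finite, so the gradient through the
    # untaken branch is zero rather than NaN.
    l2sum_safe = tf.where(pred, l2sum, tf.ones_like(l2sum))
    l2norm = tf.where(pred, tf.sqrt(l2sum_safe), l2sum)
    grad, = tf.gradients(l2norm, x)
    with tf.Session() as sess:
        print(sess.run(grad))  # [0. 0. 0.] instead of [nan nan nan]
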
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 7af2ca56be..69c0fcbbee 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -1229,7 +1229,8 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
dtype,
shape=None,
shared_name=None,
- name="conditional_accumulator"):
+ name="conditional_accumulator",
+ reduction_type="MEAN"):
"""Creates a new ConditionalAccumulator.
Args:
@@ -1238,9 +1239,14 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
shared_name: Optional. If non-empty, this accumulator will be shared under
the given name across multiple sessions.
name: Optional name for the accumulator.
+ reduction_type: Reduction type to use when taking the gradient, either
+ "MEAN" (the default) or "SUM".
"""
accumulator_ref = gen_data_flow_ops.conditional_accumulator(
- dtype=dtype, shape=shape, shared_name=shared_name, name=name)
+ dtype=dtype,
+ shape=shape,
+ shared_name=shared_name,
+ name=name,
+ reduction_type=reduction_type)
super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)
def apply_grad(self, grad, local_step=0, name=None):
@@ -1312,15 +1318,21 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase):
shared_name: Optional. If non-empty, this accumulator will be shared under
the given name across multiple sessions.
name: Optional name for the accumulator.
+ reduction_type: Reduction type to use when taking the gradient, either
+ "MEAN" (the default) or "SUM".
"""
def __init__(self,
dtype,
shape=None,
shared_name=None,
- name="sparse_conditional_accumulator"):
+ name="sparse_conditional_accumulator",
+ reduction_type="MEAN"):
accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
- dtype=dtype, shape=shape, shared_name=shared_name, name=name)
+ dtype=dtype,
+ shape=shape,
+ shared_name=shared_name,
+ name=name,
+ reduction_type=reduction_type)
super(SparseConditionalAccumulator, self).__init__(dtype, shape,
accumulator_ref)
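
To make the new argument concrete: the accumulator now either averages or sums the gradients applied since the last take_grad. A hedged sketch using the public TF 1.x names, mirroring the tests above:

    import tensorflow as tf

    q = tf.ConditionalAccumulator(tf.float32, shape=[1], reduction_type="SUM")
    accum_ops = [q.apply_grad((x,), local_step=0) for x in [10.0, 20.0]]
    take_op = q.take_grad(1)
    with tf.Session() as sess:
        for op in accum_ops:
            sess.run(op)
        print(sess.run(take_op))  # [30.] with SUM; the default MEAN gives [15.]
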
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 12356944f8..de260f3140 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -330,6 +330,8 @@ def _random_flip(image, flip_index, seed, scope_name):
lambda: image,
name=scope
)
+ if isinstance(result, tuple):
+ result = result[0] # TODO(b/111124878) remove this logic (CondV2).
return fix_image_flip_shape(image, result)
elif shape.ndims == 4:
uniform_random = random_ops.random_uniform(
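
The tuple check above guards against CondV2, where cond() can return a 1-tuple instead of a bare tensor; without unwrapping, fix_image_flip_shape would receive the wrong type. User-visible behavior is unchanged, as in this sketch (TF 1.x assumed):

    import tensorflow as tf

    image = tf.zeros([8, 8, 3])
    flipped = tf.image.random_flip_left_right(image, seed=7)
    # The guard keeps the static shape intact regardless of which cond()
    # implementation runs underneath.
    print(flipped.shape)  # (8, 8, 3)
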
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index f7502c4018..795e6bbc3e 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -3657,6 +3657,47 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
scores = constant_op.constant([0.9])
image_ops.non_max_suppression(boxes, scores, 3, [[0.5]])
+ def testDataTypes(self):
+ # Test case for GitHub issue 20199.
+ boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
+ [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
+ scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3]
+ max_output_size_np = 3
+ iou_threshold_np = 0.5
+ # Note: There are multiple versions of non_max_suppression v2, v3, v4.
+ # gen_image_ops.non_max_suppression_v2:
+ for dtype in [np.float16, np.float32]:
+ with self.test_session():
+ boxes = constant_op.constant(boxes_np, dtype=dtype)
+ scores = constant_op.constant(scores_np, dtype=dtype)
+ max_output_size = constant_op.constant(max_output_size_np)
+ iou_threshold = constant_op.constant(iou_threshold_np)
+ selected_indices = gen_image_ops.non_max_suppression_v2(
+ boxes, scores, max_output_size, iou_threshold).eval()
+ self.assertAllClose(selected_indices, [3, 0, 5])
+ # image_ops.non_max_suppression = gen_image_ops.non_max_suppression_v3.
+ for dtype in [np.float16, np.float32]:
+ with self.test_session():
+ boxes = constant_op.constant(boxes_np, dtype=dtype)
+ scores = constant_op.constant(scores_np, dtype=dtype)
+ max_output_size = constant_op.constant(max_output_size_np)
+ iou_threshold = constant_op.constant(iou_threshold_np)
+ selected_indices = image_ops.non_max_suppression(
+ boxes, scores, max_output_size, iou_threshold).eval()
+ self.assertAllClose(selected_indices, [3, 0, 5])
+ # gen_image_ops.non_max_suppression_v4.
+ score_threshold = float('-inf')
+ for dtype in [np.float16, np.float32]:
+ with self.test_session():
+ boxes = constant_op.constant(boxes_np, dtype=dtype)
+ scores = constant_op.constant(scores_np, dtype=dtype)
+ max_output_size = constant_op.constant(max_output_size_np)
+ iou_threshold = constant_op.constant(iou_threshold_np)
+ selected_indices, _ = gen_image_ops.non_max_suppression_v4(
+ boxes, scores, max_output_size, iou_threshold, score_threshold)
+ selected_indices = selected_indices.eval()
+ self.assertAllClose(selected_indices, [3, 0, 5])
+
class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 9b0ab00c7a..33e7a5533b 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2571,7 +2571,7 @@ def _unsorted_segment_N(data, segment_ids, num_segments):
@tf_export("unsorted_segment_mean")
def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
- r""" Computes the mean along segments of a tensor.
+ r"""Computes the mean along segments of a tensor.
Read [the section on
segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
@@ -2582,17 +2582,26 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
Instead of computing the sum over segments, it computes the mean of all
entries belonging to a segment such that:
- \\(output_i = 1/N_i \sum data_j\\) where the sum is over `j` such
- that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences
- of id \\i\\.
+ \\(output_i = 1/N_i \sum_{j...} data[j...]\\) where the sum is over tuples
+ `j...` such that `segment_ids[j...] == i` with \\(N_i\\) being the number of
+ occurrences of id \\(i\\).
If there is no entry for a given segment ID `i`, it outputs 0.
- segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
- first dimension.
+ If the given segment ID `i` is negative, the value is dropped and will not
+ be added to the sum of the segment.
- output: Has same shape as data, except for dimension 0 which
- has size `num_segments`.
+ Args:
+ data: A `Tensor` with floating point or complex dtype.
+ segment_ids: An integer tensor whose shape is a prefix of `data.shape`.
+ num_segments: An integer scalar `Tensor`. The number of distinct
+ segment IDs.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has same shape as data, except for the first `segment_ids.rank`
+ dimensions, which are replaced with a single dimension which has size
+ `num_segments`.
"""
with ops.name_scope(name, "UnsortedSegmentMean"):
data = ops.convert_to_tensor(data)
@@ -2615,20 +2624,29 @@ def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
Additionally to computing the sum over segments, it divides the results by
sqrt(N).
- \\(output_i = 1/sqrt(N_i) \sum data_j\\) where the sum is over `j` such
- that `segment_ids[j] == i` with \\N_i\\ being the number of occurrences
- of id \\i\\.
+ \\(output_i = 1/\sqrt{N_i} \sum_{j...} data[j...]\\) where the sum is over
+ tuples `j...` such that `segment_ids[j...] == i` with \\(N_i\\) being the
+ number of occurrences of id \\(i\\).
If there is no entry for a given segment ID `i`, it outputs 0.
Note that this op only supports floating point and complex dtypes,
due to tf.sqrt only supporting these types.
- segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
- first dimension.
+ If the given segment ID `i` is negative, the value is dropped and will not
+ be added to the sum of the segment.
- output: Has same shape as data, except for dimension 0 which
- has size `num_segments`.
+ Args:
+ data: A `Tensor` with floating point or complex dtype.
+ segment_ids: An integer tensor whose shape is a prefix of `data.shape`.
+ num_segments: An integer scalar `Tensor`. The number of distinct
+ segment IDs.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has same shape as data, except for the first `segment_ids.rank`
+ dimensions, which are replaced with a single dimension which has size
+ `num_segments`.
"""
with ops.name_scope(name, "UnsortedSegmentSqrtN"):
data = ops.convert_to_tensor(data)
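
A small worked example of the rewritten docstrings, assuming the TF 1.x API (note the dropped negative segment id):

    import tensorflow as tf

    data = tf.constant([1.0, 2.0, 3.0, 4.0])
    segment_ids = tf.constant([0, 0, 1, -1])  # the negative id drops 4.0
    mean = tf.unsorted_segment_mean(data, segment_ids, num_segments=2)
    sqrt_n = tf.unsorted_segment_sqrt_n(data, segment_ids, num_segments=2)
    with tf.Session() as sess:
        print(sess.run(mean))    # [1.5, 3.0]: mean of {1, 2} and of {3}
        print(sess.run(sqrt_n))  # [2.1213..., 3.0]: sums divided by sqrt(N_i)
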
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 4800352ac2..55c2eb5fa4 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -750,7 +750,7 @@ class ResourceVariable(variables.RefVariable):
def _read_variable_op(self):
if self.trainable:
- tape.watch_variable(self)
+ tape.variable_accessed(self)
result = gen_resource_variable_ops.read_variable_op(self._handle,
self._dtype)
if not context.executing_eagerly():
@@ -781,7 +781,7 @@ class ResourceVariable(variables.RefVariable):
"""Reads the value of this variable sparsely, using `gather`."""
with ops.name_scope("Gather" if name is None else name) as name:
if self.trainable:
- tape.watch_variable(self)
+ tape.variable_accessed(self)
value = gen_resource_variable_ops.resource_gather(
self._handle, indices, dtype=self._dtype, name=name)
return array_ops.identity(value)
@@ -949,12 +949,12 @@ class ResourceVariable(variables.RefVariable):
def _lazy_read(self, op):
if self.trainable:
- tape.watch_variable(self)
+ tape.variable_accessed(self)
return _UnreadVariable(
handle=self._handle, dtype=self.dtype, shape=self._shape,
in_graph_mode=self._in_graph_mode,
deleter=self._handle_deleter if not self._in_graph_mode else None,
- parent_op=op, parent_name=self._handle_name, unique_id=self._unique_id)
+ parent_op=op, unique_id=self._unique_id)
def assign(self, value, use_locking=None, name=None, read_value=True):
"""Assigns a new value to this variable.
@@ -1293,8 +1293,7 @@ class _UnreadVariable(ResourceVariable):
"""
def __init__(self, handle, dtype, # pylint: disable=super-init-not-called
- shape, in_graph_mode, deleter, parent_op, parent_name,
- unique_id):
+ shape, in_graph_mode, deleter, parent_op, unique_id):
# We do not call super init on purpose.
self._trainable = False
self._save_slice_info = None
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index fa13568596..c11c9ccaae 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -428,7 +428,7 @@ class BasicRNNCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % inputs_shape)
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
self._kernel = self.add_variable(
@@ -525,7 +525,7 @@ class GRUCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % inputs_shape)
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
self._gate_kernel = self.add_variable(
@@ -705,7 +705,7 @@ class BasicLSTMCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % inputs_shape)
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
h_depth = self._num_units
@@ -783,10 +783,10 @@ class LSTMCell(LayerRNNCell):
The default non-peephole implementation is based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The peephole implementation is based on:
@@ -908,7 +908,7 @@ class LSTMCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % inputs_shape)
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
h_depth = self._num_units if self._num_proj is None else self._num_proj
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index c832ba4e2a..29fefbe3a5 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -41,12 +41,41 @@ from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
+
+# pylint: disable=redefined-builtin
+def regex_full_match(input, pattern, name=None):
+ r"""Match elements of `input` with regex `pattern`.
+
+ Args:
+ input: string `Tensor`, the source strings to process.
+ pattern: string or scalar string `Tensor`, regular expression to use,
+ see more details at https://github.com/google/re2/wiki/Syntax
+ name: Name of the op.
+
+ Returns:
+ bool `Tensor` of the same shape as `input` with match results.
+ """
+ # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
+ if not compat.forward_compatible(2018, 11, 10):
+ return gen_string_ops.regex_full_match(
+ input=input, pattern=pattern, name=name)
+ if isinstance(pattern, util_compat.bytes_or_text_types):
+ # When `pattern` is static through the life of the op we can
+ # use a version which performs the expensive regex compilation once at
+ # creation time.
+ return gen_string_ops.static_regex_full_match(
+ input=input, pattern=pattern, name=name)
+ return gen_string_ops.regex_full_match(
+ input=input, pattern=pattern, name=name)
+
+regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
+
# Expose regex_full_match in strings namespace
tf_export("strings.regex_full_match")(regex_full_match)
def regex_replace(source, pattern, rewrite, replace_global=True):
- r"""Replace elements of `source` matching regex `pattern with `rewrite`.
+ r"""Replace elements of `source` matching regex `pattern` with `rewrite`.
Args:
source: string `Tensor`, the source strings to process.
@@ -128,6 +157,7 @@ def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=inv
shape.set_shape([2])
return sparse_tensor.SparseTensor(indices, values, shape)
+
@tf_export("strings.split")
def string_split_v2(source, sep=None, maxsplit=-1):
"""Split elements of `source` based on `sep` into a `SparseTensor`.
@@ -170,7 +200,7 @@ def string_split_v2(source, sep=None, maxsplit=-1):
second column corresponds to the index of the split component in this row.
"""
if sep is None:
- sep = ''
+ sep = ""
sep = ops.convert_to_tensor(sep, dtype=dtypes.string)
source = ops.convert_to_tensor(source, dtype=dtypes.string)
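
To illustrate the delegation in regex_full_match above: once past the forward-compatibility horizon, a plain Python pattern selects StaticRegexFullMatch, which compiles the RE2 once at kernel-construction time, while a tensor pattern keeps the original per-call RegexFullMatch op. A sketch, assuming TF 1.x:

    import tensorflow as tf

    values = tf.constant(["abaaba", "abcdabcde"])
    # Static Python-string pattern: eligible for the one-time-compilation kernel.
    static_match = tf.strings.regex_full_match(values, "a.*a")
    # Tensor pattern: stays on RegexFullMatch, compiled per execution.
    dynamic_match = tf.strings.regex_full_match(values, tf.constant("a.*a"))
    with tf.Session() as sess:
        print(sess.run([static_match, dynamic_match]))  # both [True, False]
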
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index a31861ae40..be8f425481 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -52,9 +52,10 @@ limitations under the License.
%rename("%s") TFE_Py_TapeSetShouldRecord;
%rename("%s") TFE_Py_TapeSetDeleteTrace;
%rename("%s") TFE_Py_TapeSetRecordOperation;
-%rename("%s") TFE_Py_TapeSetWatchVariable;
%rename("%s") TFE_Py_TapeGradient;
+%rename("%s") TFE_Py_TapeVariableAccessed;
%rename("%s") TFE_Py_TapeWatch;
+%rename("%s") TFE_Py_TapeWatchVariable;
%rename("%s") TFE_Py_TapeWatchedVariables;
%rename("%s") TFE_NewContextOptions;
%rename("%s") TFE_ContextOptionsSetConfig;
@@ -65,6 +66,7 @@ limitations under the License.
%rename("%s") TFE_Py_TensorShapeOnDevice;
%rename("%s") TFE_ContextStartStep;
%rename("%s") TFE_ContextEndStep;
+%rename("%s") TFE_Py_RegisterVSpace;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 7a37eda5ea..c9bc33e218 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -225,6 +225,7 @@ py_library(
":signature_constants",
":utils",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:util",
],
)
diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py
index f8ad788f77..37f927f381 100644
--- a/tensorflow/python/saved_model/signature_def_utils_impl.py
+++ b/tensorflow/python/saved_model/signature_def_utils_impl.py
@@ -21,9 +21,7 @@ from __future__ import print_function
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.util.tf_export import tf_export
@@ -316,80 +314,3 @@ def _is_valid_classification_signature(signature_def):
return True
-
-def _get_shapes_from_tensor_info_dict(tensor_info_dict):
- """Returns a map of keys to TensorShape objects.
-
- Args:
- tensor_info_dict: map with TensorInfo proto as values.
-
- Returns:
- Map with corresponding TensorShape objects as values.
- """
- return {
- key: tensor_shape.TensorShape(tensor_info.tensor_shape)
- for key, tensor_info in tensor_info_dict.items()
- }
-
-
-def _get_types_from_tensor_info_dict(tensor_info_dict):
- """Returns a map of keys to DType objects.
-
- Args:
- tensor_info_dict: map with TensorInfo proto as values.
-
- Returns:
- Map with corresponding DType objects as values.
- """
- return {
- key: dtypes.DType(tensor_info.dtype)
- for key, tensor_info in tensor_info_dict.items()
- }
-
-
-def get_signature_def_input_shapes(signature):
- """Returns map of parameter names to their shapes.
-
- Args:
- signature: SignatureDef proto.
-
- Returns:
- Map from string to TensorShape objects.
- """
- return _get_shapes_from_tensor_info_dict(signature.inputs)
-
-
-def get_signature_def_input_types(signature):
- """Returns map of output names to their types.
-
- Args:
- signature: SignatureDef proto.
-
- Returns:
- Map from string to DType objects.
- """
- return _get_types_from_tensor_info_dict(signature.inputs)
-
-
-def get_signature_def_output_shapes(signature):
- """Returns map of output names to their shapes.
-
- Args:
- signature: SignatureDef proto.
-
- Returns:
- Map from string to TensorShape objects.
- """
- return _get_shapes_from_tensor_info_dict(signature.outputs)
-
-
-def get_signature_def_output_types(signature):
- """Returns map of output names to their types.
-
- Args:
- signature: SignatureDef proto.
-
- Returns:
- Map from string to DType objects.
- """
- return _get_types_from_tensor_info_dict(signature.outputs)
diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py
index ebc5450633..18c55d8d33 100644
--- a/tensorflow/python/saved_model/signature_def_utils_test.py
+++ b/tensorflow/python/saved_model/signature_def_utils_test.py
@@ -275,44 +275,6 @@ class SignatureDefUtilsTest(test.TestCase):
self.assertEqual(method_name, signature_def.method_name)
self.assertEqual(3, len(signature_def.outputs))
- def testGetShapeAndTypes(self):
- inputs = {
- "input-1": constant_op.constant(["a", "b"]),
- "input-2": array_ops.placeholder(dtypes.float32, [10, 11]),
- }
- outputs = {
- "output-1": array_ops.placeholder(dtypes.float32, [10, 32]),
- "output-2": constant_op.constant([["b"]]),
- }
- signature_def = _make_signature(inputs, outputs)
- self.assertEqual(
- signature_def_utils_impl.get_signature_def_input_shapes(signature_def),
- {"input-1": [2], "input-2": [10, 11]})
- self.assertEqual(
- signature_def_utils_impl.get_signature_def_output_shapes(signature_def),
- {"output-1": [10, 32], "output-2": [1, 1]})
- self.assertEqual(
- signature_def_utils_impl.get_signature_def_input_types(signature_def),
- {"input-1": dtypes.string, "input-2": dtypes.float32})
- self.assertEqual(
- signature_def_utils_impl.get_signature_def_output_types(signature_def),
- {"output-1": dtypes.float32, "output-2": dtypes.string})
-
- def testGetNonFullySpecifiedShapes(self):
- outputs = {
- "output-1": array_ops.placeholder(dtypes.float32, [None, 10, None]),
- "output-2": array_ops.sparse_placeholder(dtypes.float32),
- }
- signature_def = _make_signature({}, outputs)
- shapes = signature_def_utils_impl.get_signature_def_output_shapes(
- signature_def)
- self.assertEqual(len(shapes), 2)
- # Must compare shapes with as_list() since 2 equivalent non-fully defined
- # shapes are not equal to each other.
- self.assertEqual(shapes["output-1"].as_list(), [None, 10, None])
- # Must compare `dims` since its an unknown shape.
- self.assertEqual(shapes["output-2"].dims, None)
-
def _assertValidSignature(self, inputs, outputs, method_name):
signature_def = signature_def_utils_impl.build_signature_def(
inputs, outputs, method_name)
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 01d43e09d1..1c1a1a54cd 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -137,6 +137,7 @@ py_test(
size = "small",
srcs = ["strip_unused_test.py"],
srcs_version = "PY2AND3",
+ tags = ["notap"],
deps = [
":strip_unused_lib",
"//tensorflow/core:protos_all_py",
diff --git a/tensorflow/python/tools/api/generator/api_gen.bzl b/tensorflow/python/tools/api/generator/api_gen.bzl
index 2810d83bd2..271cf2afaf 100644
--- a/tensorflow/python/tools/api/generator/api_gen.bzl
+++ b/tensorflow/python/tools/api/generator/api_gen.bzl
@@ -12,10 +12,15 @@ ESTIMATOR_API_INIT_FILES = [
# END GENERATED ESTIMATOR FILES
]
+def get_compat_files(
+ file_paths,
+ compat_api_version):
+ """Prepends compat/v<compat_api_version> to file_paths."""
+ return ["compat/v%d/%s" % (compat_api_version, f) for f in file_paths]
+
def gen_api_init_files(
name,
output_files = TENSORFLOW_API_INIT_FILES,
- compat_output_files = {},
root_init_template = None,
srcs = [],
api_name = "tensorflow",
@@ -23,7 +28,8 @@ def gen_api_init_files(
compat_api_versions = [],
package = "tensorflow.python",
package_dep = "//tensorflow/python:no_contrib",
- output_package = "tensorflow"):
+ output_package = "tensorflow",
+ output_dir = ""):
"""Creates API directory structure and __init__.py files.
Creates a genrule that generates a directory structure with __init__.py
@@ -37,8 +43,6 @@ def gen_api_init_files(
tf_export. For e.g. if an op is decorated with
@tf_export('module1.module2', 'module3'). Then, output_files should
include module1/module2/__init__.py and module3/__init__.py.
- compat_output_files: Dictionary mapping each compat_api_version to the
- set of __init__.py file paths that should be generated for that version.
root_init_template: Python init file that should be used as template for
root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this
template will be replaced with root imports collected by this genrule.
@@ -53,14 +57,16 @@ def gen_api_init_files(
process
package_dep: Python library target containing your package.
output_package: Package where generated API will be added to.
+ output_dir: Subdirectory to output API to.
+ If non-empty, must end with '/'.
"""
root_init_template_flag = ""
if root_init_template:
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
- api_gen_binary_target = "create_" + package + "_api"
+ api_gen_binary_target = ("create_" + package + "_api_%d") % api_version
native.py_binary(
- name = "create_" + package + "_api",
+ name = api_gen_binary_target,
srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
main = "//tensorflow/python/tools/api/generator:create_python_api.py",
srcs_version = "PY2AND3",
@@ -72,14 +78,9 @@ def gen_api_init_files(
],
)
- all_output_files = list(output_files)
+ all_output_files = ["%s%s" % (output_dir, f) for f in output_files]
compat_api_version_flags = ""
for compat_api_version in compat_api_versions:
- compat_files = compat_output_files.get(compat_api_version, [])
- all_output_files.extend([
- "compat/v%d/%s" % (compat_api_version, f)
- for f in compat_files
- ])
compat_api_version_flags += " --compat_apiversion=%d" % compat_api_version
native.genrule(
@@ -87,12 +88,15 @@ def gen_api_init_files(
outs = all_output_files,
cmd = (
"$(location :" + api_gen_binary_target + ") " +
- root_init_template_flag + " --apidir=$(@D) --apiname=" +
- api_name + " --apiversion=" + str(api_version) +
+ root_init_template_flag + " --apidir=$(@D)" + output_dir +
+ " --apiname=" + api_name + " --apiversion=" + str(api_version) +
compat_api_version_flags + " --package=" + package +
" --output_package=" + output_package + " $(OUTS)"
),
srcs = srcs,
tools = [":" + api_gen_binary_target],
- visibility = ["//tensorflow:__pkg__"],
+ visibility = [
+ "//tensorflow:__pkg__",
+ "//tensorflow/tools/api/tests:__pkg__",
+ ],
)
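
The new get_compat_files helper is plain list arithmetic, so its effect can be checked in ordinary Python; the file names below are hypothetical stand-ins for generated __init__.py paths:

def get_compat_files(file_paths, compat_api_version):
    """Prepends compat/v<compat_api_version> to file_paths."""
    return ["compat/v%d/%s" % (compat_api_version, f) for f in file_paths]

print(get_compat_files(["__init__.py", "math/__init__.py"], 1))
# -> ['compat/v1/__init__.py', 'compat/v1/math/__init__.py']
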
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 6716c79f87..c5289564fe 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -546,7 +546,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
input_examples = preprocess_input_examples_arg_string(input_examples_str)
for input_tensor_key, (filename, variable_name) in inputs.items():
- data = np.load(file_io.FileIO(filename, mode='r'))
+ data = np.load(file_io.FileIO(filename, mode='rb'))
# When a variable_name key is specified for the input file
if variable_name:
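
The mode change matters because np.load consumes raw bytes from the file object; a text-mode handle mangles or rejects the binary .npy payload. A standalone illustration in plain NumPy (the temporary path is chosen just for the example):

import os
import tempfile

import numpy as np

path = os.path.join(tempfile.mkdtemp(), "x.npy")
np.save(path, np.arange(4))
with open(path, "rb") as f:   # binary mode, as in the fix above
    print(np.load(f))         # -> [0 1 2 3]
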
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index 76625624e4..3bd4bd75bd 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -1025,7 +1025,7 @@ class ProfilerHook(session_run_hook.SessionRunHook):
def before_run(self, run_context):
self._request_summary = (
- self._next_step is None or
+ self._next_step is not None and
self._timer.should_trigger_for_step(self._next_step))
requests = {"global_step": self._global_step_tensor}
opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
@@ -1035,6 +1035,10 @@ class ProfilerHook(session_run_hook.SessionRunHook):
def after_run(self, run_context, run_values):
stale_global_step = run_values.results["global_step"]
+ if self._next_step is None:
+ # Update the timer so that it does not activate until N steps or seconds
+ # have passed.
+ self._timer.update_last_triggered_step(stale_global_step)
global_step = stale_global_step + 1
if self._request_summary:
global_step = run_context.session.run(self._global_step_tensor)
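
A minimal sketch of the schedule this change enforces, in plain Python rather than the real SessionRunHook/timer API (all names below are hypothetical): the first step primes the timer instead of writing a trace, so traces appear only every save_steps steps afterwards.

class ProfilerScheduleSketch(object):
    def __init__(self, save_steps):
        self._save_steps = save_steps
        self._last_triggered = None  # step of the last trace (or the priming)
        self._next_step = None       # set after the first run, as in the hook
        self.trace_steps = []

    def run_one_step(self, global_step):
        # before_run: with the fix, a trace can only be requested once
        # _next_step is known, i.e. never on the very first call.
        request = (self._next_step is not None and
                   self._last_triggered is not None and
                   global_step - self._last_triggered >= self._save_steps)
        # after_run: on the first call, prime the timer instead of saving.
        if self._next_step is None:
            self._last_triggered = global_step
        if request:
            self.trace_steps.append(global_step)
            self._last_triggered = global_step
        self._next_step = global_step + 1

sketch = ProfilerScheduleSketch(save_steps=2)
for step in range(1, 7):
    sketch.run_one_step(step)
print(sketch.trace_steps)  # -> [3, 5]: nothing on step 1, then every 2 steps
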
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index b49a871a56..fe8a3e9062 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -1454,52 +1454,50 @@ class ProfilerHookTest(test.TestCase):
with self.assertRaises(ValueError):
basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None)
- def test_save_secs_saves_in_first_step(self):
+ def test_save_secs_does_not_save_in_first_step(self):
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
save_secs=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
sess.run(self.train_op)
- self.assertEqual(1, self._count_timeline_files())
+ self.assertEqual(0, self._count_timeline_files())
@test.mock.patch.object(time, 'time')
def test_save_secs_saves_periodically(self, mock_time):
# Pick a fixed start time.
- current_time = 1484863632.320497
+ current_time = 1484863632.
with self.graph.as_default():
mock_time.return_value = current_time
hook = basic_session_run_hooks.ProfilerHook(
save_secs=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
- sess.run(self.train_op) # Saved.
- self.assertEqual(1, self._count_timeline_files())
sess.run(self.train_op) # Not saved.
- self.assertEqual(1, self._count_timeline_files())
+ self.assertEqual(0, self._count_timeline_files())
# Simulate 2.5 seconds of sleep.
mock_time.return_value = current_time + 2.5
sess.run(self.train_op) # Saved.
+ self.assertEqual(1, self._count_timeline_files())
# Pretend some small amount of time has passed.
- mock_time.return_value = current_time + 0.1
+ mock_time.return_value = current_time + 2.6
sess.run(self.train_op) # Not saved.
# Edge test just before we should save the timeline.
- mock_time.return_value = current_time + 1.9
+ mock_time.return_value = current_time + 4.4
sess.run(self.train_op) # Not saved.
- self.assertEqual(2, self._count_timeline_files())
+ self.assertEqual(1, self._count_timeline_files())
mock_time.return_value = current_time + 4.5
sess.run(self.train_op) # Saved.
- self.assertEqual(3, self._count_timeline_files())
+ self.assertEqual(2, self._count_timeline_files())
- def test_save_steps_saves_in_first_step(self):
+ def test_save_steps_does_not_save_in_first_step(self):
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
- save_secs=2, output_dir=self.output_dir)
+ save_steps=1, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
- sess.run(self.train_op) # Saved.
sess.run(self.train_op) # Not saved.
- self.assertEqual(1, self._count_timeline_files())
+ self.assertEqual(0, self._count_timeline_files())
def test_save_steps_saves_periodically(self):
with self.graph.as_default():
@@ -1507,6 +1505,8 @@ class ProfilerHookTest(test.TestCase):
save_steps=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
self.assertEqual(0, self._count_timeline_files())
+ sess.run(self.train_op) # Not saved.
+ self.assertEqual(0, self._count_timeline_files())
sess.run(self.train_op) # Saved.
self.assertEqual(1, self._count_timeline_files())
sess.run(self.train_op) # Not saved.
@@ -1515,20 +1515,19 @@ class ProfilerHookTest(test.TestCase):
self.assertEqual(2, self._count_timeline_files())
sess.run(self.train_op) # Not saved.
self.assertEqual(2, self._count_timeline_files())
- sess.run(self.train_op) # Saved.
- self.assertEqual(3, self._count_timeline_files())
- def test_run_metadata_saves_in_first_step(self):
+ def test_run_metadata_saves(self):
writer_cache.FileWriterCache.clear()
fake_summary_writer.FakeSummaryWriter.install()
fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
with self.graph.as_default():
hook = basic_session_run_hooks.ProfilerHook(
- save_secs=2, output_dir=self.output_dir)
+ save_steps=1, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
+ sess.run(self.train_op) # Not saved.
sess.run(self.train_op) # Saved.
self.assertEqual(
- list(fake_writer._added_run_metadata.keys()), ['step_1'])
+ list(fake_writer._added_run_metadata.keys()), ['step_2'])
fake_summary_writer.FakeSummaryWriter.uninstall()
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 9189d8f3e8..095a90ddd4 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -17,11 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import collections
import functools
import json
import weakref
+import six
+
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -91,7 +94,45 @@ class CheckpointInitialValue(ops.Tensor):
return self._checkpoint_position
-class PythonStringStateSaveable(saveable_object.SaveableObject):
+class NoRestoreSaveable(saveable_object.SaveableObject):
+ """Embeds a tensor in a checkpoint with no restore ops."""
+
+ def __init__(self, tensor, name, dtype=None):
+ spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype)
+ super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ return control_flow_ops.no_op()
+
+
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateSaveable(saveable_object.SaveableObject):
+ """An interface for saving/restoring volatile Python state."""
+
+ @abc.abstractmethod
+ def feed_dict_additions(self):
+ """When running a graph, indicates fresh state to feed.
+
+ Returns:
+ A dictionary mapping `Tensor`s to current Python state.
+ """
+ pass
+
+ @abc.abstractmethod
+ def freeze(self):
+ """Create a new `SaveableObject` which freezes current state as a constant.
+
+ Used when executing eagerly to embed the current state as a constant, or
+ when creating a static tf.train.Saver with the frozen current Python state.
+
+ Returns:
+ A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
+ no Python state associated with it).
+ """
+ pass
+
+
+class PythonStringStateSaveable(PythonStateSaveable):
"""Saves Python state in a checkpoint."""
def __init__(self, name, state_callback, restore_callback=None):
@@ -104,19 +145,26 @@ class PythonStringStateSaveable(saveable_object.SaveableObject):
restore_callback: A function taking a Python string, used to restore
state. Optional; defaults to doing nothing.
"""
+ self._state_callback = state_callback
self._restore_callback = restore_callback
- if context.executing_eagerly():
- self._save_string = (
- lambda: constant_op.constant(state_callback(), dtype=dtypes.string))
- else:
+ with ops.device("/cpu:0"):
self._save_string = constant_op.constant("", dtype=dtypes.string)
- self.feed_dict_additions = (
- lambda: {self._save_string: state_callback()})
spec = saveable_object.SaveSpec(
self._save_string, "", name, dtype=dtypes.string)
super(PythonStringStateSaveable, self).__init__(
self._save_string, [spec], name)
+ def feed_dict_additions(self):
+ """When running a graph, indicates fresh state to feed."""
+ return {self._save_string: self._state_callback()}
+
+ def freeze(self):
+ """Create a frozen `SaveableObject` which saves the current state."""
+ return NoRestoreSaveable(
+ tensor=self._state_callback,
+ dtype=dtypes.string,
+ name=self.name)
+
def python_restore(self, restored_strings):
"""Called to restore Python state."""
if self._restore_callback:
@@ -309,7 +357,7 @@ class _CheckpointPosition(object):
if self._checkpoint.saveable_object_cache is not None:
self._checkpoint.saveable_object_cache.setdefault(
self.checkpointable, {})[serialized_tensor.name] = [saveable]
- if isinstance(saveable, PythonStringStateSaveable):
+ if isinstance(saveable, PythonStateSaveable):
python_saveables.append(saveable)
else:
named_saveables[serialized_tensor.checkpoint_key] = saveable
@@ -819,7 +867,7 @@ class CheckpointableBase(object):
def _state_callback():
dereferenced_self = weak_self()
if dereferenced_self:
- return json.dumps(self,
+ return json.dumps(dereferenced_self,
default=serialization.get_json_type,
sort_keys=True).encode("utf8")
else:
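
The split introduced here, reduced to a dependency-free sketch (class names below are hypothetical): while graph building, fresh Python state is fed each time the save ops run; when saving eagerly or building a static saver, freeze() snapshots the state up front.

class FrozenStateSketch(object):
    """Stands in for NoRestoreSaveable: a constant with no restore op."""

    def __init__(self, value):
        self.value = value


class PythonStateSketch(object):
    """Stands in for PythonStringStateSaveable."""

    def __init__(self, state_callback):
        self._state_callback = state_callback

    def feed_dict_additions(self):
        # Graph building: feed the current state when the save ops run.
        return {"save_string_placeholder": self._state_callback()}

    def freeze(self):
        # Eager or static save: embed the current state as a constant.
        return FrozenStateSketch(self._state_callback())


state = {"epoch": 3}
saveable = PythonStateSketch(lambda: str(state))
print(saveable.feed_dict_additions())  # evaluated at save time, stays fresh
state["epoch"] = 4
print(saveable.freeze().value)         # snapshot taken at freeze time
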
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 13dddd37ac..56c4043d9d 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -32,7 +32,6 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
@@ -557,7 +556,14 @@ def _serialize_checkpointables(
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
named_saveables = []
- feed_additions = {}
+ if saveables_cache is None:
+ # No SaveableObject caching. Either we're executing eagerly, or building a
+ # static save which is specialized to the current Python state.
+ feed_additions = None
+ else:
+ # If we are caching SaveableObjects, we need to build up a feed_dict with
+ # functions computing volatile Python state to be saved with the checkpoint.
+ feed_additions = {}
for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
assert node_ids[checkpointable] == checkpoint_id
object_proto = object_graph_proto.nodes.add()
@@ -616,18 +622,25 @@ def _serialize_checkpointables(
for saveable in saveables:
if hasattr(saveable, "full_name"):
attribute.full_name = saveable.full_name
- saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None)
- if saveable_feed_dict_fn is not None:
- saveable_feed_dict = saveable_feed_dict_fn() # pylint: disable=not-callable
- for new_feed_key in saveable_feed_dict.keys():
- if new_feed_key in feed_additions:
- raise AssertionError(
- ("The object %s tried to feed a value for the Tensor %s "
- "when saving, but another object is already feeding a "
- "value.")
- % (checkpointable, new_feed_key))
- feed_additions.update(saveable_feed_dict)
- named_saveables.extend(saveables)
+ if isinstance(saveable, base.PythonStateSaveable):
+ if feed_additions is None:
+ assert saveables_cache is None
+ # If we're not caching saveables, then we're either executing
+ # eagerly or building a static save/restore (e.g. for a
+ # SavedModel). In either case, we should embed the current Python
+ # state in the graph rather than relying on a feed dict.
+ saveable = saveable.freeze()
+ else:
+ saveable_feed_dict = saveable.feed_dict_additions()
+ for new_feed_key in saveable_feed_dict.keys():
+ if new_feed_key in feed_additions:
+ raise AssertionError(
+ ("The object %s tried to feed a value for the Tensor %s "
+ "when saving, but another object is already feeding a "
+ "value.")
+ % (checkpointable, new_feed_key))
+ feed_additions.update(saveable_feed_dict)
+ named_saveables.append(saveable)
for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access
child_proto = object_proto.children.add()
@@ -827,16 +840,6 @@ def capture_dependencies(template):
yield
-class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
-
- def __init__(self, tensor, name):
- spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name)
- super(_NoRestoreSaveable, self).__init__(tensor, [spec], name)
-
- def restore(self, restored_tensors, restored_shapes):
- return control_flow_ops.no_op()
-
-
class _LoadStatus(object):
"""Abstract base for load status callbacks."""
@@ -1241,6 +1244,78 @@ class CheckpointableSaver(object):
else:
return self._root_checkpointable_ref
+ def _gather_saveables(
+ self, object_graph_tensor=None, saveable_object_cache=None):
+ """Wraps _serialize_object_graph to include the object graph proto."""
+ assert ((object_graph_tensor is None and saveable_object_cache is None)
+ or (object_graph_tensor is not None
+ and saveable_object_cache is not None))
+ (named_saveable_objects, graph_proto,
+ feed_additions) = _serialize_object_graph(
+ self._root_checkpointable,
+ saveables_cache=saveable_object_cache)
+ if object_graph_tensor is None:
+ with ops.device("/cpu:0"):
+ object_graph_tensor = constant_op.constant(
+ graph_proto.SerializeToString(), dtype=dtypes.string)
+ else:
+ feed_additions.update(
+ {object_graph_tensor: graph_proto.SerializeToString()})
+ assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
+ named_saveable_objects.append(
+ base.NoRestoreSaveable(
+ tensor=object_graph_tensor,
+ name=base.OBJECT_GRAPH_PROTO_KEY))
+ return named_saveable_objects, graph_proto, feed_additions
+
+ def freeze(self):
+ """Creates a `tf.train.Saver` with the current object graph frozen."""
+ named_saveable_objects, _, _ = self._gather_saveables(
+ object_graph_tensor=None, saveable_object_cache=None)
+ return saver_lib.Saver(
+ var_list=named_saveable_objects, max_to_keep=None)
+
+ def _prepare_save(self,
+ object_graph_tensor=None,
+ saveable_object_cache=None):
+ """Create or retrieve save ops.
+
+ When graph building, `saveable_object_cache` will typically be non-`None`,
+ meaning that existing `SaveableObject`s are re-used across calls to
+ `_prepare_save` even if the object graph has grown. This avoids
+ unnecessarily re-creating save ops.
+
+ Args:
+ object_graph_tensor: A `Tensor` to which the current object graph will be
+ fed.
+ saveable_object_cache: A dictionary; if specified, used to cache
+ `SaveableObject`s.
+
+ Returns:
+ A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s
+ to feed when running save ops. The feed dict contains the current object
+ graph and any Python state to be saved in the checkpoint.
+ """
+ (named_saveable_objects, graph_proto,
+ feed_additions) = self._gather_saveables(
+ object_graph_tensor=object_graph_tensor,
+ saveable_object_cache=saveable_object_cache)
+ if (self._last_save_object_graph != graph_proto
+ # When executing eagerly, we need to re-create SaveableObjects each time
+ # save() is called so they pick up new Tensors passed to their
+ # constructors. That means the Saver needs to be copied with a new
+ # var_list.
+ or context.executing_eagerly()):
+ if self._last_save_object_graph is not None:
+ self._last_save_saver = _copy_saver_with_new_var_list(
+ old_saver=self._last_save_saver,
+ new_var_list=named_saveable_objects)
+ else:
+ self._last_save_saver = saver_lib.Saver(
+ var_list=named_saveable_objects, max_to_keep=None)
+ self._last_save_object_graph = graph_proto
+ return self._last_save_saver, feed_additions
+
def save(self, file_prefix, checkpoint_number=None, session=None):
"""Save a training checkpoint.
@@ -1263,44 +1338,29 @@ class CheckpointableSaver(object):
Returns:
The full path to the checkpoint.
"""
- named_variables, graph_proto, feed_additions = _serialize_object_graph(
- self._root_checkpointable,
- saveables_cache=self._saveable_object_cache)
- if not context.executing_eagerly():
- if session is None:
- session = ops.get_default_session()
+ feed_additions = {}
+ graph_building = not context.executing_eagerly()
+ if graph_building:
if self._object_graph_feed_tensor is None:
with ops.device("/cpu:0"):
self._object_graph_feed_tensor = constant_op.constant(
"", dtype=dtypes.string)
object_graph_tensor = self._object_graph_feed_tensor
- feed_additions.update(
- {object_graph_tensor: graph_proto.SerializeToString()})
else:
+ object_graph_tensor = None
+
+ saver, new_feed_additions = self._prepare_save(
+ object_graph_tensor=object_graph_tensor,
+ saveable_object_cache=self._saveable_object_cache)
+ if new_feed_additions:
+ feed_additions.update(new_feed_additions)
+ if not graph_building:
session = None
- with ops.device("/cpu:0"):
- object_graph_tensor = constant_op.constant(
- graph_proto.SerializeToString(), dtype=dtypes.string)
- assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables
- named_variables.append(
- _NoRestoreSaveable(
- tensor=object_graph_tensor,
- name=base.OBJECT_GRAPH_PROTO_KEY))
- if (self._last_save_object_graph != graph_proto
- # When executing eagerly, we need to re-create SaveableObjects each time
- # save() is called so they pick up new Tensors passed to their
- # constructors. That means the Saver needs to be copied with a new
- # var_list.
- or context.executing_eagerly()):
- if self._last_save_object_graph is not None:
- self._last_save_saver = _copy_saver_with_new_var_list(
- old_saver=self._last_save_saver, new_var_list=named_variables)
- else:
- self._last_save_saver = saver_lib.Saver(
- var_list=named_variables, max_to_keep=None)
- self._last_save_object_graph = graph_proto
+ elif session is None:
+ session = ops.get_default_session()
+
with ops.device("/cpu:0"):
- save_path = self._last_save_saver.save(
+ save_path = saver.save(
sess=_SessionWithFeedDictAdditions(
session=session, feed_additions=feed_additions),
save_path=file_prefix,
@@ -1422,6 +1482,30 @@ class CheckpointableSaver(object):
return load_status
+def frozen_saver(root_checkpointable):
+ """Creates a static `tf.train.Saver` from a checkpointable object.
+
+ The returned `Saver` saves object-based checkpoints, but these checkpoints
+ will no longer reflect structural changes to the object graph, only changes to
+ the values of `Variable`s added as dependencies of the root object before
+ `freeze` was called.
+
+ `restore` works on the returned `Saver`, but requires that the object graph of
+ the checkpoint being loaded exactly matches the object graph when `freeze` was
+ called. This is in contrast to the object-based restore performed by
+ `tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's
+ object graph and the current Python object graph.
+
+ Args:
+ root_checkpointable: A checkpointable object to save.
+
+ Returns:
+ A `tf.train.Saver` which saves object-based checkpoints for the object graph
+ frozen at the time `frozen_saver` was called.
+ """
+ return CheckpointableSaver(root_checkpointable).freeze()
+
+
@tf_export("train.Checkpoint")
class Checkpoint(tracking.Checkpointable):
"""Groups checkpointable objects, saving and restoring them.
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index bef4bf2a16..0d32d21426 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -560,6 +560,46 @@ class CheckpointingTests(test.TestCase):
self.evaluate(root.save_counter))
@test_util.run_in_graph_and_eager_modes
+ def testFreezing(self):
+ with self.cached_session(use_gpu=True) as session:
+ # Save an object-based checkpoint using a frozen saver
+ directory = self.get_temp_dir()
+ prefix = os.path.join(directory, "ckpt")
+ v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+ checkpoint = checkpointable_utils.Checkpoint(v=v)
+ self.evaluate(v.assign(3))
+ # Create the save counter so assert_consumed doesn't complain about it not
+ # existing in the checkpoint on restore.
+ self.evaluate(checkpoint.save_counter.assign(12))
+ saver = checkpointable_utils.frozen_saver(checkpoint)
+ save_path = saver.save(session, prefix)
+ self.evaluate(v.assign(10))
+ # Use the frozen saver to restore the same object graph
+ saver.restore(session, save_path)
+ self.assertEqual(3, self.evaluate(v))
+
+ # Restore using another frozen saver on an identical object graph
+ del v, checkpoint, saver
+ v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+ checkpoint = checkpointable_utils.Checkpoint(v=v)
+ saver = checkpointable_utils.frozen_saver(checkpoint)
+ saver.restore(session, save_path)
+ self.assertEqual(3, self.evaluate(v))
+
+ # Restore as an object-based checkpoint
+ del v, checkpoint, saver
+ checkpoint = checkpointable_utils.Checkpoint()
+ status = checkpoint.restore(save_path)
+ v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+ if context.executing_eagerly():
+ self.assertEqual(12, self.evaluate(checkpoint.save_counter))
+ self.assertEqual(0, self.evaluate(v))
+ checkpoint.v = v
+ status.assert_consumed().run_restore_ops()
+ self.assertEqual(3, self.evaluate(v))
+ self.assertEqual(12, self.evaluate(checkpoint.save_counter))
+
+ @test_util.run_in_graph_and_eager_modes
def testCustomNumbering(self):
directory = self.get_temp_dir()
prefix = os.path.join(directory, "ckpt")
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 6d336ac39d..104a615636 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -104,9 +104,36 @@ Raises:
%unignore tensorflow::swig::Flatten;
%noexception tensorflow::swig::Flatten;
+%feature("docstring") tensorflow::swig::IsSequenceForData
+"""Returns a true if `seq` is a Sequence or dict (except strings/lists).
+
+NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`,
+which *does* treat a Python list as a sequence. For ergonomic
+reasons, `tf.data` users would prefer to treat lists as
+implicit `tf.Tensor` objects, and dicts as (nested) sequences.
+
+Args:
+ seq: an input sequence.
+
+Returns:
+ True if the sequence is not a string or list and is a
+ collections.Sequence.
+"""
%unignore tensorflow::swig::IsSequenceForData;
%noexception tensorflow::swig::IsSequenceForData;
+%feature("docstring") tensorflow::swig::FlattenForData
+"""Returns a flat sequence from a given nested structure.
+
+If `nest` is not a sequence, this returns a single-element list: `[nest]`.
+
+Args:
+ nest: an arbitrarily nested structure or a scalar object.
+ Note, numpy arrays are considered scalars.
+
+Returns:
+ A Python list, the flattened version of the input.
+"""
%unignore tensorflow::swig::FlattenForData;
%noexception tensorflow::swig::FlattenForData;
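
The semantics the two docstrings describe can be exercised through the tf.data nest wrapper that sits on top of these SWIG entry points (module path as of this branch; a sketch, assuming dict keys are flattened in sorted order):

from tensorflow.python.data.util import nest as data_nest

print(data_nest.is_sequence([1, 2, 3]))          # False: lists act as implicit tensors
print(data_nest.is_sequence({"a": 1}))           # True: dicts are (nested) sequences
print(data_nest.flatten({"b": 3, "a": [1, 2]}))  # -> [[1, 2], 3]
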
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 207f22c931..3c533c7f99 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -3275,6 +3275,26 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl(
"This configuration potentially produces incorrect results.");
}());
+ // Zero out the result buffer for strided conv backward filter for NHWC
+ // layouts. cuDNN 7.1.4 and 7.2 have a non-deterministic bug if the buffer is not
+ // zeroed.
+ //
+ // The wrong result caused by the bug is flaky: it can take up to 20 runs
+ // to produce a mismatch.
+ //
+ // TODO(timshen): add a nvbugs link.
+ if (CUDNN_VERSION >= 7100 &&
+ algorithm_config.algorithm().algo_id() ==
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 &&
+ cudnn_type == CUDNN_DATA_HALF &&
+ input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
+ filter_descriptor.layout() == dnn::FilterLayout::kOutputYXInput &&
+ output_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
+ (convolution_descriptor.vertical_filter_stride() > 1 ||
+ convolution_descriptor.horizontal_filter_stride() > 1)) {
+ stream->ThenMemZero(backward_filter_data, backward_filter_data->size());
+ }
+
RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
cudnn.handle(),
/*alpha=*/alpha,
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt
index d23b3bd0ca..15e0ab76b6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], "
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], "
}
member_method {
name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
index cbf655498c..2f4257a66a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], "
}
member_method {
name: "gradient"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt
index 2260279ad2..39ff336c4f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], "
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
}
member_method {
name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
index 834f0954d5..87745420ee 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
@@ -60,7 +60,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 4d854a4cee..6dd46365b0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
index 601f095a60..35b7105eba 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
index 587829a4c0..8ae370af98 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt
index 0853716023..614ba42d3e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.python_io.-t-f-record-options.pbtxt
@@ -8,7 +8,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'compression_type\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'compression_type\', \'flush_mode\', \'input_buffer_size\', \'output_buffer_size\', \'window_bits\', \'compression_level\', \'compression_method\', \'mem_level\', \'compression_strategy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "get_compression_type_string"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt
index d23b3bd0ca..15e0ab76b6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], "
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], "
}
member_method {
name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
index cbf655498c..2f4257a66a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'persistent\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], "
}
member_method {
name: "gradient"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt
index 2260279ad2..39ff336c4f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt
@@ -17,7 +17,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], "
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
}
member_method {
name: "apply_grad"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index 834f0954d5..87745420ee 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -60,7 +60,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 4d854a4cee..6dd46365b0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index 601f095a60..35b7105eba 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index 587829a4c0..8ae370af98 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -61,7 +61,7 @@ tf_class {
}
member_method {
name: "interleave"
- argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "list_files"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt
index 0853716023..614ba42d3e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.python_io.-t-f-record-options.pbtxt
@@ -8,7 +8,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'compression_type\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'compression_type\', \'flush_mode\', \'input_buffer_size\', \'output_buffer_size\', \'window_bits\', \'compression_level\', \'compression_method\', \'mem_level\', \'compression_strategy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "get_compression_type_string"
diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD
index 8764409e4d..4efa4a9651 100644
--- a/tensorflow/tools/api/tests/BUILD
+++ b/tensorflow/tools/api/tests/BUILD
@@ -15,7 +15,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
py_test(
name = "api_compatibility_test",
- srcs = ["api_compatibility_test.py"],
+ srcs = [
+ "api_compatibility_test.py",
+ "//tensorflow:tf_python_api_gen_v2",
+ ],
data = [
"//tensorflow/tools/api/golden:api_golden_v1",
"//tensorflow/tools/api/golden:api_golden_v2",
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 43d19bc99c..99bed5714f 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -34,6 +34,7 @@ import sys
import unittest
import tensorflow as tf
+from tensorflow._api import v2 as tf_v2
from google.protobuf import message
from google.protobuf import text_format
@@ -232,14 +233,14 @@ class ApiCompatibilityTest(test.TestCase):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
- traverse.traverse(tf.compat.v1, visitor)
+ traverse.traverse(tf_v2.compat.v1, visitor)
def testNoSubclassOfMessageV2(self):
if not hasattr(tf.compat, 'v2'):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
- traverse.traverse(tf.compat.v2, visitor)
+ traverse.traverse(tf_v2, visitor)
def _checkBackwardsCompatibility(
self, root, golden_file_pattern, api_version,
@@ -300,27 +301,24 @@ class ApiCompatibilityTest(test.TestCase):
sys.version_info.major == 2,
'API compatibility test goldens are generated using python2.')
def testAPIBackwardsCompatibilityV1(self):
- if not hasattr(tf.compat, 'v1'):
- return
api_version = 1
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
self._checkBackwardsCompatibility(
- tf.compat.v1, golden_file_pattern, api_version)
+ tf_v2.compat.v1, golden_file_pattern, api_version)
@unittest.skipUnless(
sys.version_info.major == 2,
'API compatibility test goldens are generated using python2.')
def testAPIBackwardsCompatibilityV2(self):
- if not hasattr(tf.compat, 'v2'):
- return
api_version = 2
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
self._checkBackwardsCompatibility(
- tf.compat.v2, golden_file_pattern, api_version)
+ tf_v2, golden_file_pattern, api_version,
+ additional_private_map={'tf.compat': ['v1']})
if __name__ == '__main__':
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu b/tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu
new file mode 100644
index 0000000000..08dc026328
--- /dev/null
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.gcc.gpu
@@ -0,0 +1,43 @@
+# To push a new version, run:
+# $ docker build -f Dockerfile.rbe.gcc.gpu \
+# --tag "gcr.io/asci-toolchain/nosla-nvidia-gcc" .
+# $ docker push gcr.io/asci-toolchain/nosla-nvidia-gcc
+FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
+
+LABEL maintainer="Manuel Klimek <klimek@google.com>"
+
+# TODO(b/110903506): Fix the nvidia docker image by providing a link to the
+# SONAME of libcuda.so. Alternatively, consider using gold or lld which do not
+# run into the same problem - that will only work once the tensorflow build does
+# not link to libcuda from generators anymore.
+# https://github.com/NVIDIA/nvidia-docker/issues/775
+RUN ln -s libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
+
+# TODO(klimek): Once the TODO in tensorflow's configure.py to correctly find
+# libnccl is resolved, delete this block.
+RUN ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so \
+ && ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so.2
+
+# TODO(b/110903506): Fix tensorflow to not require the use of LD_LIBRARY_PATH.
+# The stubs/libcuda.so is not meant to be used at runtime. The correct way to
+# pass the path to bfd-ld is to pass -Wl,-rpath-link=/usr/local/cuda/lib64/stubs
+# to all binaries transitively depending on libcuda. Optimally the tensorflow
+# build would do that internally.
+ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs
+
+# Copy and run the install scripts.
+COPY install/*.sh /install/
+ARG DEBIAN_FRONTEND=noninteractive
+RUN /install/install_bootstrap_deb_packages.sh
+RUN add-apt-repository -y ppa:openjdk-r/ppa && \
+ add-apt-repository -y ppa:george-edison55/cmake-3.x
+RUN /install/install_deb_packages.sh
+RUN /install/install_pip_packages.sh
+RUN /install/install_golang.sh
+
+# Install nccl2.
+RUN apt-get update && apt-get install -y \
+ libnccl2 \
+ libnccl-dev \
+ && rm -rf /var/lib/apt/lists/*
+
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index 1d7d9df72f..c8472102cb 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -86,7 +86,7 @@
# When set, overrides TF_BUILD_IS_OPT and TF_BUILD_MAVX
# options, as this will replace the two.
# TF_SKIP_CONTRIB_TESTS:
-# If set to any non-empty or non-0 value, will skipp running
+# If set to any non-empty or non-0 value, will skip running
# contrib tests.
# TF_NIGHTLY:
# If this run is being used to build the tf_nightly pip
@@ -131,7 +131,13 @@ BAZEL_CMD="bazel test"
BAZEL_BUILD_ONLY_CMD="bazel build"
BAZEL_CLEAN_CMD="bazel clean"
-DEFAULT_BAZEL_CONFIGS=""
+# Default flags:
+# --test_summary=detailed: Tell us more about which targets are being built
+# --keep_going: Don't stop at the first failure; tell us all the failures
+# --build_tests_only: Don't build targets depended on by tests if the test is
+# disabled. Also saves some compilation time. Otherwise,
+# tries to build everything.
+DEFAULT_BAZEL_CONFIGS="--test_summary=detailed --build_tests_only --keep_going"
PIP_CMD="${CI_BUILD_DIR}/builds/pip.sh"
PIP_TEST_TUTORIALS_FLAG="--test_tutorials"
@@ -148,9 +154,7 @@ EXTRA_PARAMS=""
BAZEL_TARGET="//tensorflow/... -//tensorflow/compiler/..."
if [[ -n "$TF_SKIP_CONTRIB_TESTS" ]]; then
- BAZEL_TARGET="$BAZEL_TARGET -//tensorflow/contrib/..."
-else
- BAZEL_TARGET="${BAZEL_TARGET} //tensorflow/contrib/lite/..."
+ BAZEL_TARGET="${BAZEL_TARGET} -//tensorflow/contrib/..."
fi
TUT_TEST_DATA_DIR="/tmp/tf_tutorial_test_data"
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index af478eded4..a9ae715c6a 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -119,6 +119,8 @@ pip2 install keras_applications==1.0.5 --no-deps
pip3 install keras_applications==1.0.5 --no-deps
pip2 install keras_preprocessing==1.0.3 --no-deps
pip3 install keras_preprocessing==1.0.3 --no-deps
+pip2 install --upgrade h5py==2.8.0
+pip3 install --upgrade h5py==2.8.0
# Install last working version of setuptools.
pip2 install --upgrade setuptools==39.1.0
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index 93ea0c3db6..37e6b51f66 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -87,6 +87,7 @@ pip3.5 install --upgrade setuptools==39.1.0
# Keras
pip3.5 install keras_applications==1.0.5
pip3.5 install keras_preprocessing==1.0.3
+pip3.5 install --upgrade h5py==2.8.0
# Install last working version of setuptools.
pip3.5 install --upgrade setuptools==39.1.0
diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
index 7a9eef7c64..7520ff74cb 100755
--- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
@@ -99,6 +99,7 @@ pip3 install --upgrade termcolor
# Install last working version of setuptools.
pip3 install --upgrade setuptools==39.1.0
+pip3 install --upgrade h5py==2.8.0
# Keras
pip3 install keras_applications==1.0.5
diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
index 333a89d3f5..c18f0d6e69 100644
--- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
@@ -53,7 +53,7 @@ export PATH="/c/${PYTHON_BASE_PATH}/Scripts:$PATH"
# Setting default values to CUDA related environment variables
export TF_CUDA_VERSION=${TF_CUDA_VERSION:-9.0}
-export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7.0}
+export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7}
export TF_CUDA_COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES:-3.7}
export CUDA_TOOLKIT_PATH=${CUDA_TOOLKIT_PATH:-"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${TF_CUDA_VERSION}"}
export CUDNN_INSTALL_PATH=${CUDNN_INSTALL_PATH:-"C:/tools/cuda"}
diff --git a/tensorflow/tools/dockerfiles/README.md b/tensorflow/tools/dockerfiles/README.md
index c484c162cb..d64db35afb 100644
--- a/tensorflow/tools/dockerfiles/README.md
+++ b/tensorflow/tools/dockerfiles/README.md
@@ -2,8 +2,8 @@
This directory houses TensorFlow's Dockerfiles. **DO NOT EDIT THE DOCKERFILES
MANUALLY!** They are maintained by `assembler.py`, which builds Dockerfiles from
-the files in `partials/` and the rules in `spec.yml`. See [the Maintaining
-section](#maintaining) for more information.
+the files in `partials/` and the rules in `spec.yml`. See [the Contributing
+section](#contributing) for more information.
## Building
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index 483921fc2f..1cd9cb7ca9 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -36,23 +36,6 @@ from tensorflow.tools.docs import pretty_docs
from tensorflow.tools.docs import py_guide_parser
-def _is_free_function(py_object, full_name, index):
- """Check if input is a free function (and not a class- or static method)."""
- if not tf_inspect.isfunction(py_object):
- return False
-
- # Static methods are functions to tf_inspect (in 2.7), so check if the parent
- # is a class. If there is no parent, it's not a function.
- if '.' not in full_name:
- return False
-
- parent_name = full_name.rsplit('.', 1)[0]
- if tf_inspect.isclass(index[parent_name]):
- return False
-
- return True
-
-
def write_docs(output_dir,
parser_config,
yaml_toc,
@@ -109,7 +92,7 @@ def write_docs(output_dir,
# Methods and some routines are documented only as part of their class.
if not (tf_inspect.ismodule(py_object) or tf_inspect.isclass(py_object) or
- _is_free_function(py_object, full_name, parser_config.index)):
+ parser.is_free_function(py_object, full_name, parser_config.index)):
continue
sitepath = os.path.join('api_docs/python',
@@ -548,6 +531,13 @@ class DocGenerator(object):
help='The path from the site-root to api_docs '
'directory for this project')
+ self.argument_parser.add_argument(
+ '--api_cache_out_path',
+ type=str,
+ default=None,
+ help='Path to store a json-serialized api-index, so links can be '
+ 'inserted into docs without rebuilding the api_docs')
+
def add_output_dir_argument(self):
self.argument_parser.add_argument(
'--output_dir',
@@ -648,6 +638,9 @@ class DocGenerator(object):
visitor = self.run_extraction()
reference_resolver = self.make_reference_resolver(visitor, doc_index)
+ if getattr(flags, 'api_cache_out_path', None):
+ reference_resolver.to_json_file(flags.api_cache_out_path)
+
# Build the guide_index for the api_docs back links.
root_title = getattr(flags, 'root_title', 'TensorFlow')
guide_index = _build_guide_index(
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index 549056c6c4..a6159fa692 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -35,6 +35,28 @@ from tensorflow.python.util import tf_inspect
from tensorflow.tools.docs import doc_controls
+def is_free_function(py_object, full_name, index):
+ """Check if input is a free function (and not a class- or static method).
+
+ Args:
+ py_object: The object in question.
+ full_name: The full name of the object, like `tf.module.symbol`.
+ index: The {full_name:py_object} dictionary for the public API.
+
+ Returns:
+ True if the object is a stand-alone function, and not part of a class
+ definition.
+ """
+ if not tf_inspect.isfunction(py_object):
+ return False
+
+ parent_name = full_name.rsplit('.', 1)[0]
+ if tf_inspect.isclass(index[parent_name]):
+ return False
+
+ return True
+
+
# A regular expression capturing a python identifier.
IDENTIFIER_RE = r'[a-zA-Z_]\w*'
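
A quick check of is_free_function against a toy index, assuming the module is importable as in these docs tests (the dotted names are made up; `types` just supplies a real module object for the parent entry):

import types

from tensorflow.tools.docs import parser


class Toy(object):

  def method(self):
    pass


def free():
  pass


index = {
    'm': types,
    'm.free': free,
    'm.Toy': Toy,
    'm.Toy.method': Toy.method,
}
print(parser.is_free_function(free, 'm.free', index))              # True
print(parser.is_free_function(Toy.method, 'm.Toy.method', index))  # False
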
@@ -74,7 +96,7 @@ class _Errors(object):
return self._errors == other._errors # pylint: disable=protected-access
-def documentation_path(full_name):
+def documentation_path(full_name, is_fragment=False):
"""Returns the file path for the documentation for the given API symbol.
Given the fully qualified name of a library symbol, compute the path to which
@@ -84,12 +106,22 @@ def documentation_path(full_name):
Args:
full_name: Fully qualified name of a library symbol.
-
+ is_fragment: If `False` produce a direct markdown link (`tf.a.b.c` -->
+ `tf/a/b/c.md`). If `True` produce a fragment link: `tf.a.b.c` -->
+ `tf/a/b.md#c`.
Returns:
The file path to which to write the documentation for `full_name`.
"""
- dirs = full_name.split('.')
- return os.path.join(*dirs) + '.md'
+ parts = full_name.split('.')
+ if is_fragment:
+ parts, fragment = parts[:-1], parts[-1]
+
+ result = os.path.join(*parts) + '.md'
+
+ if is_fragment:
+ result = result + '#' + fragment
+
+ return result
def _get_raw_docstring(py_object):
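
What the new is_fragment switch produces, using documentation_path as defined above (POSIX path separators assumed):

from tensorflow.tools.docs import parser

print(parser.documentation_path('tf.a.b.c'))                    # tf/a/b/c.md
print(parser.documentation_path('tf.a.b.c', is_fragment=True))  # tf/a/b.md#c
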
@@ -136,8 +168,7 @@ class ReferenceResolver(object):
doc.
"""
- def __init__(self, duplicate_of, doc_index, is_class, is_module,
- py_module_names):
+ def __init__(self, duplicate_of, doc_index, is_fragment, py_module_names):
"""Initializes a Reference Resolver.
Args:
@@ -145,15 +176,15 @@ class ReferenceResolver(object):
symbols.
doc_index: A `dict` mapping symbol name strings to objects with `url`
and `title` fields. Used to resolve @{$doc} references in docstrings.
- is_class: A map from full names to bool for each symbol.
- is_module: A map from full names to bool for each symbol.
+ is_fragment: A map from full names to bool for each symbol. If True the
+ object lives at a page fragment `tf.a.b.c` --> `tf/a/b#c`. If False
+ object has a page to itself: `tf.a.b.c` --> `tf/a/b/c`.
py_module_names: A list of string names of Python modules.
"""
self._duplicate_of = duplicate_of
self._doc_index = doc_index
- self._is_class = is_class
- self._is_module = is_module
- self._all_names = set(is_class.keys())
+ self._is_fragment = is_fragment
+ self._all_names = set(is_fragment.keys())
self._py_module_names = py_module_names
self.current_doc_full_name = None
@@ -180,21 +211,18 @@ class ReferenceResolver(object):
Returns:
an instance of `ReferenceResolver`.
"""
- is_class = {
- name: tf_inspect.isclass(visitor.index[name])
- for name, obj in visitor.index.items()
- }
+ is_fragment = {}
+ for name, obj in visitor.index.items():
+ has_page = (
+ tf_inspect.isclass(obj) or tf_inspect.ismodule(obj) or
+ is_free_function(obj, name, visitor.index))
- is_module = {
- name: tf_inspect.ismodule(visitor.index[name])
- for name, obj in visitor.index.items()
- }
+ is_fragment[name] = not has_page
return cls(
duplicate_of=visitor.duplicate_of,
doc_index=doc_index,
- is_class=is_class,
- is_module=is_module,
+ is_fragment=is_fragment,
**kwargs)
@classmethod
@@ -210,6 +238,10 @@ class ReferenceResolver(object):
Args:
filepath: The file path to write the json to.
"""
+ try:
+ os.makedirs(os.path.dirname(filepath))
+ except OSError:
+ pass
json_dict = {}
for key, value in self.__dict__.items():
# Drop these two fields. `_doc_index` is not serializable. `_all_names` is
@@ -223,7 +255,7 @@ class ReferenceResolver(object):
json_dict[key.lstrip('_')] = value
with open(filepath, 'w') as f:
- json.dump(json_dict, f)
+ json.dump(json_dict, f, indent=2, sort_keys=True)
def replace_references(self, string, relative_path_to_root):
"""Replace "@{symbol}" references with links to symbol's documentation page.
@@ -339,19 +371,7 @@ class ReferenceResolver(object):
raise TFDocsError(
'Cannot make link to "%s": Not in index.' % master_name)
- # If this is a member of a class, link to the class page with an anchor.
- ref_path = None
- if not (self._is_class[master_name] or self._is_module[master_name]):
- idents = master_name.split('.')
- if len(idents) > 1:
- class_name = '.'.join(idents[:-1])
- assert class_name in self._all_names
- if self._is_class[class_name]:
- ref_path = documentation_path(class_name) + '#%s' % idents[-1]
-
- if not ref_path:
- ref_path = documentation_path(master_name)
-
+ ref_path = documentation_path(master_name, self._is_fragment[master_name])
return os.path.join(relative_path_to_root, ref_path)
def _one_ref(self, match, relative_path_to_root):
diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py
index 71e96afa10..8a41796fb9 100644
--- a/tensorflow/tools/docs/parser_test.py
+++ b/tensorflow/tools/docs/parser_test.py
@@ -28,6 +28,12 @@ from tensorflow.python.util import tf_inspect
from tensorflow.tools.docs import doc_controls
from tensorflow.tools.docs import parser
+# The test needs a real module. `types.ModuleType()` doesn't work, as the result
+# is a `builtin` module. Using "parser" here is arbitrary. The tests don't
+# depend on the module contents. At this point in the process the public api
+# has already been extracted.
+test_module = parser
+
def test_function(unused_arg, unused_kwarg='default'):
"""Docstring for test function."""
@@ -334,15 +340,16 @@ class ParserTest(googletest.TestCase):
self.assertEqual('my_method', page_info.methods[0].short_name)
def test_docs_for_module(self):
- # Get the current module.
- module = sys.modules[__name__]
index = {
- 'TestModule': module,
- 'TestModule.test_function': test_function,
+ 'TestModule':
+ test_module,
+ 'TestModule.test_function':
+ test_function,
'TestModule.test_function_with_args_kwargs':
- test_function_with_args_kwargs,
- 'TestModule.TestClass': TestClass,
+ test_function_with_args_kwargs,
+ 'TestModule.TestClass':
+ TestClass,
}
visitor = DummyVisitor(index=index, duplicate_of={})
@@ -365,11 +372,13 @@ class ParserTest(googletest.TestCase):
base_dir='/')
page_info = parser.docs_for_object(
- full_name='TestModule', py_object=module, parser_config=parser_config)
+ full_name='TestModule',
+ py_object=test_module,
+ parser_config=parser_config)
# Make sure the brief docstring is present
- self.assertEqual(tf_inspect.getdoc(module).split('\n')[0],
- page_info.doc.brief)
+ self.assertEqual(
+ tf_inspect.getdoc(test_module).split('\n')[0], page_info.doc.brief)
# Make sure that the members are there
funcs = {f_info.obj for f_info in page_info.functions}
@@ -378,8 +387,9 @@ class ParserTest(googletest.TestCase):
classes = {cls_info.obj for cls_info in page_info.classes}
self.assertEqual({TestClass}, classes)
- # Make sure this file is contained as the definition location.
- self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path)
+ # Make sure the module's file is contained as the definition location.
+ self.assertEqual(
+ os.path.relpath(test_module.__file__, '/'), page_info.defined_in.path)
def test_docs_for_function(self):
index = {
@@ -495,6 +505,7 @@ class ParserTest(googletest.TestCase):
duplicate_of = {'tf.third': 'tf.fourth'}
index = {
+ 'tf': test_module,
'tf.fancy': test_function_with_fancy_docstring,
'tf.reference': HasOneMember,
'tf.reference.foo': HasOneMember.foo,
@@ -521,20 +532,18 @@ class ParserTest(googletest.TestCase):
'NumPy has nothing as awesome as this function.\n')
def test_generate_index(self):
- module = sys.modules[__name__]
index = {
- 'TestModule': module,
- 'test_function': test_function,
- 'TestModule.test_function': test_function,
- 'TestModule.TestClass': TestClass,
- 'TestModule.TestClass.a_method': TestClass.a_method,
- 'TestModule.TestClass.a_property': TestClass.a_property,
- 'TestModule.TestClass.ChildClass': TestClass.ChildClass,
- }
- duplicate_of = {
- 'TestModule.test_function': 'test_function'
+ 'tf': test_module,
+ 'tf.TestModule': test_module,
+ 'tf.test_function': test_function,
+ 'tf.TestModule.test_function': test_function,
+ 'tf.TestModule.TestClass': TestClass,
+ 'tf.TestModule.TestClass.a_method': TestClass.a_method,
+ 'tf.TestModule.TestClass.a_property': TestClass.a_property,
+ 'tf.TestModule.TestClass.ChildClass': TestClass.ChildClass,
}
+ duplicate_of = {'tf.TestModule.test_function': 'tf.test_function'}
visitor = DummyVisitor(index=index, duplicate_of=duplicate_of)
@@ -553,7 +562,7 @@ class ParserTest(googletest.TestCase):
self.assertIn('TestModule.test_function', docs)
# Check the leading <code> tag to make sure it's included top-level.
# This depends on formatting, but should be stable.
- self.assertIn('<code>test_function', docs)
+ self.assertIn('<code>tf.test_function', docs)
def test_argspec_for_functools_partial(self):
# pylint: disable=unused-argument
@@ -665,22 +674,18 @@ class ParserTest(googletest.TestCase):
duplicate_of = {'AClass': ['AClass2']}
doc_index = {'doc': you_cant_serialize_this}
- is_class = {
+ is_fragment = {
'tf': False,
- 'tf.AClass': True,
- 'tf.AClass2': True,
- 'tf.function': False
- }
- is_module = {
- 'tf': True,
+ 'tf.VERSION': True,
'tf.AClass': False,
+ 'tf.AClass.method': True,
'tf.AClass2': False,
'tf.function': False
}
py_module_names = ['tf', 'tfdbg']
- resolver = parser.ReferenceResolver(duplicate_of, doc_index, is_class,
- is_module, py_module_names)
+ resolver = parser.ReferenceResolver(duplicate_of, doc_index, is_fragment,
+ py_module_names)
outdir = googletest.GetTempDir()
@@ -692,6 +697,23 @@ class ParserTest(googletest.TestCase):
# There are no __slots__, so all fields are visible in __dict__.
self.assertEqual(resolver.__dict__, resolver2.__dict__)
+ def testIsFreeFunction(self):
+
+ result = parser.is_free_function(test_function, 'test_module.test_function',
+ {'test_module': test_module})
+ self.assertTrue(result)
+
+ result = parser.is_free_function(test_function, 'TestClass.test_function',
+ {'TestClass': TestClass})
+ self.assertFalse(result)
+
+ result = parser.is_free_function(TestClass, 'TestClass', {})
+ self.assertFalse(result)
+
+ result = parser.is_free_function(test_module, 'test_module', {})
+ self.assertFalse(result)
+
+
RELU_DOC = """Computes rectified linear: `max(features, 0)`
Args:
diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py
index 448f246e0e..1a3e79621f 100644
--- a/tensorflow/tools/docs/pretty_docs.py
+++ b/tensorflow/tools/docs/pretty_docs.py
@@ -255,8 +255,9 @@ def _build_module_page(page_info):
# at least for basic types.
parts.append('## Other Members\n\n')
+ h3 = '<h3 id="{short_name}"><code>{short_name}</code></h3>\n\n'
for item in page_info.other_members:
- parts.append('`{short_name}`\n\n'.format(**item._asdict()))
+ parts.append(h3.format(**item._asdict()))
return ''.join(parts)
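With the `h3` template, each "other member" gets a linkable heading instead of inline code; extra keys from `item._asdict()` are simply ignored by `str.format`. For a member named `VERSION`:

    h3 = '<h3 id="{short_name}"><code>{short_name}</code></h3>\n\n'
    print(h3.format(short_name='VERSION'))
    # <h3 id="VERSION"><code>VERSION</code></h3>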
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 61419f25ae..3102239a19 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -167,17 +167,21 @@ class InstallHeaders(Command):
# directories for -I
install_dir = re.sub('/google/protobuf_archive/src', '', install_dir)
- # Copy eigen code into tensorflow/include.
+ # Copy external code headers into tensorflow/include.
# A symlink would do, but the wheel file that gets created ignores
# symlinks within the directory hierarchy.
# NOTE(keveman): Figure out how to customize bdist_wheel package so
# we can do the symlink.
- if 'tensorflow/include/external/eigen_archive/' in install_dir:
- extra_dir = install_dir.replace(
- 'tensorflow/include/external/eigen_archive', '')
- if not os.path.exists(extra_dir):
- self.mkpath(extra_dir)
- self.copy_file(header, extra_dir)
+ external_header_locations = [
+ 'tensorflow/include/external/eigen_archive/',
+ 'tensorflow/include/external/com_google_absl/',
+ ]
+ for location in external_header_locations:
+ if location in install_dir:
+ extra_dir = install_dir.replace(location, '')
+ if not os.path.exists(extra_dir):
+ self.mkpath(extra_dir)
+ self.copy_file(header, extra_dir)
if not os.path.exists(install_dir):
self.mkpath(install_dir)
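The generalized loop strips whichever external-repo prefix matches, so the headers are also staged at their canonical include paths. With a hypothetical Abseil header location:

    # Hypothetical install path for an Abseil header:
    install_dir = 'include/tensorflow/include/external/com_google_absl/absl/base'
    location = 'tensorflow/include/external/com_google_absl/'
    extra_dir = install_dir.replace(location, '')
    print(extra_dir)  # include/absl/base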
@@ -227,6 +231,8 @@ headers = (list(find_files('*.h', 'tensorflow/core')) +
list(find_files('*.h', 'tensorflow/stream_executor')) +
list(find_files('*.h', 'google/protobuf_archive/src')) +
list(find_files('*', 'third_party/eigen3')) +
+ list(find_files('*.h',
+ 'tensorflow/include/external/com_google_absl')) +
list(find_files('*', 'tensorflow/include/external/eigen_archive')))
setup(
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 2bf867c7e1..0ff695d9f8 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -106,11 +106,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "com_google_absl",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/c075ad321696fa5072e097f0a51e4fe76a6fe13e.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/c075ad321696fa5072e097f0a51e4fe76a6fe13e.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/fb462224c058487763f263b7995d70efd0242c17.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/fb462224c058487763f263b7995d70efd0242c17.tar.gz",
],
- sha256 = "cb4e11259742954f88802be6f33c1007c16502d90d68e8898b5e5084264ca8a9",
- strip_prefix = "abseil-cpp-c075ad321696fa5072e097f0a51e4fe76a6fe13e",
+ sha256 = "f4f34f90083d5259f9a1a4067749d842599748d8ca03c1d9fe723124a7045c63",
+ strip_prefix = "abseil-cpp-fb462224c058487763f263b7995d70efd0242c17",
build_file = clean_dep("//third_party:com_google_absl.BUILD"),
)
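When bumping a `tf_http_archive` pin like this, the `sha256` must be recomputed for the tarball at the new commit. A sketch (assuming the GitHub archive bytes are stable, as the pinned hash itself presumes):

    import hashlib
    import urllib.request

    url = ('https://github.com/abseil/abseil-cpp/archive/'
           'fb462224c058487763f263b7995d70efd0242c17.tar.gz')
    data = urllib.request.urlopen(url).read()
    print(hashlib.sha256(data).hexdigest())
    # Should match the sha256 value pinned above.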
@@ -491,11 +491,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/dc6d9ec3646865125d057b6f515b4543df79920a.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/dc6d9ec3646865125d057b6f515b4543df79920a.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/738b5f5028ef39cbb023967f80fa2e5dd568556b.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/738b5f5028ef39cbb023967f80fa2e5dd568556b.tar.gz",
],
- sha256 = "c7252290a113f694cccbb4b325c67b56f3aa6f5b3044524302c0e79db2da7e2a",
- strip_prefix = "llvm-dc6d9ec3646865125d057b6f515b4543df79920a",
+ sha256 = "2bda8dd724ab432c162fb6eace259ccf8a97f13cb627336611bff68da2f33ec2",
+ strip_prefix = "llvm-738b5f5028ef39cbb023967f80fa2e5dd568556b",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
diff --git a/third_party/gpus/cuda/remote.BUILD.tpl b/third_party/gpus/cuda/remote.BUILD.tpl
index f774def5e6..100c7bb7c4 100644
--- a/third_party/gpus/cuda/remote.BUILD.tpl
+++ b/third_party/gpus/cuda/remote.BUILD.tpl
@@ -75,6 +75,11 @@ alias(
)
alias(
+ name = "cudnn_header",
+ actual = "%{remote_cuda_repo}/cuda:cudnn_header",
+)
+
+alias(
name = "cufft",
actual = "%{remote_cuda_repo}/cuda:cufft",
)
diff --git a/third_party/llvm/llvm.autogenerated.BUILD b/third_party/llvm/llvm.autogenerated.BUILD
index 0ac27e26a4..776935739a 100644
--- a/third_party/llvm/llvm.autogenerated.BUILD
+++ b/third_party/llvm/llvm.autogenerated.BUILD
@@ -109,16 +109,23 @@ template_rule(
)
# A common library that all LLVM targets depend on.
+# TODO(b/113996071): We need to glob all potentially #included files and stage
+# them here because LLVM's build files are not strict headers clean, and remote
+# build execution requires all inputs to be depended upon.
cc_library(
name = "config",
- hdrs = [
+ hdrs = glob([
+ "**/*.h",
+ "**/*.def",
+ "**/*.inc.cpp",
+ ]) + [
"include/llvm/Config/AsmParsers.def",
"include/llvm/Config/AsmPrinters.def",
"include/llvm/Config/Disassemblers.def",
"include/llvm/Config/Targets.def",
- "include/llvm/Config/abi-breaking.h",
"include/llvm/Config/config.h",
"include/llvm/Config/llvm-config.h",
+ "include/llvm/Config/abi-breaking.h",
],
defines = llvm_defines,
includes = ["include"],
diff --git a/third_party/nccl/BUILD b/third_party/nccl/BUILD
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/third_party/nccl/BUILD
diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl
index 5d1ebf0686..ce9447096e 100644
--- a/third_party/nccl/nccl_configure.bzl
+++ b/third_party/nccl/nccl_configure.bzl
@@ -16,6 +16,7 @@ load(
_NCCL_INSTALL_PATH = "NCCL_INSTALL_PATH"
_TF_NCCL_VERSION = "TF_NCCL_VERSION"
+_TF_NCCL_CONFIG_REPO = "TF_NCCL_CONFIG_REPO"
_DEFINE_NCCL_MAJOR = "#define NCCL_MAJOR"
_DEFINE_NCCL_MINOR = "#define NCCL_MINOR"
@@ -48,25 +49,8 @@ alias(
"""
# A local build links dynamically, so the license should not be included.
-_NCCL_LOCAL_BUILD_TEMPLATE = """
-filegroup(
- name = "LICENSE",
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "nccl",
- srcs = ["nccl/lib/libnccl.so.%s"],
- hdrs = ["nccl/include/nccl.h"],
- include_prefix = "third_party/nccl",
- strip_include_prefix = "nccl/include",
- deps = [
- "@local_config_cuda//cuda:cuda_headers",
- ],
- visibility = ["//visibility:public"],
-)
-"""
-
+_NCCL_REMOTE_BUILD_TEMPLATE = Label("//third_party/nccl:remote.BUILD.tpl")
+_NCCL_LOCAL_BUILD_TEMPLATE = Label("//third_party/nccl:system.BUILD.tpl")
def _find_nccl_header(repository_ctx, nccl_install_path):
"""Finds the NCCL header on the system.
@@ -137,6 +121,13 @@ def _nccl_configure_impl(repository_ctx):
repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT)
return
+ if _TF_NCCL_CONFIG_REPO in repository_ctx.os.environ:
+ # Forward to the pre-configured remote repository.
+ repository_ctx.template("BUILD", _NCCL_REMOTE_BUILD_TEMPLATE, {
+ "%{target}": repository_ctx.os.environ[_TF_NCCL_CONFIG_REPO],
+ })
+ return
+
nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip()
if matches_version("1", nccl_version):
# Alias to GitHub target from @nccl_archive.
@@ -148,8 +139,10 @@ def _nccl_configure_impl(repository_ctx):
# Create target for locally installed NCCL.
nccl_install_path = repository_ctx.os.environ[_NCCL_INSTALL_PATH].strip()
_check_nccl_version(repository_ctx, nccl_install_path, nccl_version)
- repository_ctx.symlink(nccl_install_path, "nccl")
- repository_ctx.file("BUILD", _NCCL_LOCAL_BUILD_TEMPLATE % nccl_version)
+ repository_ctx.template("BUILD", _NCCL_LOCAL_BUILD_TEMPLATE, {
+ "%{version}": nccl_version,
+ "%{install_path}": nccl_install_path,
+ })
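Both branches now rely on `repository_ctx.template` to fill the `%{...}` placeholders in the checked-in `.tpl` files. In effect (a plain-Python approximation of the Starlark substitution, not the real API):

    def expand_template(text, substitutions):
        # Approximates repository_ctx.template's placeholder replacement.
        for placeholder, value in substitutions.items():
            text = text.replace(placeholder, value)
        return text

    print(expand_template('srcs = ["libnccl.so.%{version}"]',
                          {'%{version}': '2.2'}))
    # srcs = ["libnccl.so.2.2"]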
nccl_configure = repository_rule(
diff --git a/third_party/nccl/remote.BUILD.tpl b/third_party/nccl/remote.BUILD.tpl
new file mode 100644
index 0000000000..d66fc5563d
--- /dev/null
+++ b/third_party/nccl/remote.BUILD.tpl
@@ -0,0 +1,6 @@
+licenses(["restricted"])
+
+package(default_visibility = ["//visibility:public"])
+
+alias(name = "LICENSE", actual = "%{target}:LICENSE")
+alias(name = "nccl", actual = "%{target}:nccl")
diff --git a/third_party/nccl/system.BUILD.tpl b/third_party/nccl/system.BUILD.tpl
new file mode 100644
index 0000000000..7ca835dedf
--- /dev/null
+++ b/third_party/nccl/system.BUILD.tpl
@@ -0,0 +1,26 @@
+filegroup(
+ name = "LICENSE",
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "nccl",
+ srcs = ["libnccl.so.%{version}"],
+ hdrs = ["nccl.h"],
+ include_prefix = "third_party/nccl",
+ deps = [
+ "@local_config_cuda//cuda:cuda_headers",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+genrule(
+ name = "nccl-files",
+ outs = [
+ "libnccl.so.%{version}",
+ "nccl.h",
+ ],
+ cmd = """cp "%{install_path}/include/nccl.h" "$(@D)/nccl.h" &&
+ cp "%{install_path}/lib/libnccl.so.%{version}" "$(@D)/libnccl.so.%{version}" """,
+)
+