aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Mingxing Tan <tanmingxing@google.com>2018-06-21 12:03:01 -0700
committerGravatar Mingxing Tan <tanmingxing@google.com>2018-06-21 12:03:01 -0700
commitba86a8ed1e2b1617f40f25ad0107e8448e9e0848 (patch)
treeb9fed8c18eab093ec13279e2195d4137c7f4ada1 /tensorflow
parent9d2d40079c273e8de8644136b452715c0146b907 (diff)
parent7b4080564c268a54a5c0b877b28e67faaadff268 (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/BUILD18
-rw-r--r--tensorflow/api_template.__init__.py3
-rw-r--r--tensorflow/compiler/tests/sort_ops_test.py29
-rw-r--r--tensorflow/compiler/tf2xla/kernels/topk_op.cc99
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc119
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc68
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h70
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc44
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.h31
-rw-r--r--tensorflow/compiler/xla/shape_tree.h22
-rw-r--r--tensorflow/compiler/xla/tests/convert_test.cc21
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc18
-rw-r--r--tensorflow/contrib/BUILD1
-rw-r--r--tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb11
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops.py4
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt1
-rw-r--r--tensorflow/contrib/cmake/tf_tests.cmake2
-rw-r--r--tensorflow/contrib/data/__init__.py13
-rw-r--r--tensorflow/contrib/data/kernels/threadpool_dataset_op.cc27
-rw-r--r--tensorflow/contrib/data/ops/dataset_ops.cc3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD21
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py (renamed from tensorflow/contrib/data/python/ops/iterator_ops_test.py)0
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py59
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD20
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py14
-rw-r--r--tensorflow/contrib/data/python/ops/error_ops.py6
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py14
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py6
-rw-r--r--tensorflow/contrib/data/python/ops/optimization.py6
-rw-r--r--tensorflow/contrib/data/python/ops/stats_ops.py11
-rw-r--r--tensorflow/contrib/data/python/ops/threadpool.py12
-rw-r--r--tensorflow/contrib/data/python/ops/unique.py6
-rw-r--r--tensorflow/contrib/distribute/python/BUILD1
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops.py8
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py6
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py6
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py42
-rw-r--r--tensorflow/contrib/distribute/python/monitor_test.py4
-rw-r--r--tensorflow/contrib/distribute/python/optimizer_v2_test.py4
-rw-r--r--tensorflow/contrib/distribute/python/step_fn_test.py4
-rw-r--r--tensorflow/contrib/distribute/python/strategy_test_lib.py10
-rw-r--r--tensorflow/contrib/distribute/python/values.py69
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py22
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py15
-rw-r--r--tensorflow/contrib/eager/python/examples/BUILD2
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/BUILD59
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/config.py72
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/ops.py71
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/ops_test.py59
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/sagan.py232
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/sagan_test.py101
-rw-r--r--tensorflow/contrib/eager/python/tfe.py1
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py20
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc93
-rw-r--r--tensorflow/contrib/lite/kernels/activations_test.cc147
-rw-r--r--tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc64
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h282
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h390
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h290
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h354
-rw-r--r--tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc62
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h48
-rw-r--r--tensorflow/contrib/lite/kernels/log_softmax_test.cc7
-rw-r--r--tensorflow/contrib/lite/kernels/mul.cc118
-rw-r--r--tensorflow/contrib/lite/kernels/mul_test.cc40
-rw-r--r--tensorflow/contrib/lite/kernels/pooling.cc57
-rw-r--r--tensorflow/contrib/lite/kernels/softmax_test.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/split.cc12
-rw-r--r--tensorflow/contrib/lite/toco/dump_graphviz.cc2
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc36
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc16
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc14
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc20
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc60
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc25
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc24
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc5
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc36
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc4
-rw-r--r--tensorflow/contrib/lite/toco/model.h111
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc6
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.h2
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc4
-rw-r--r--tensorflow/contrib/lite/toco/tflite/import.cc2
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc46
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc42
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc4
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc78
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h2
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util_test.cc6
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/BUILD11
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc52
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_model.h22
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc57
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_params.h101
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc54
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h11
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc64
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/command_line_flags.h27
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc43
-rwxr-xr-xtensorflow/contrib/makefile/build_all_android.sh8
-rwxr-xr-xtensorflow/contrib/makefile/build_all_ios.sh8
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py76
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py38
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Acos.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Acosh.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Add.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_AsString.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Asin.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Asinh.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Atan.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Atan2.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Atanh.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt5
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Cos.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Cosh.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Cross.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Diag.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Equal.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Exp.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FFT.pbtxt5
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Floor.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Greater.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_GreaterEqual.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt5
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Less.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_LessEqual.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Log.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Log1p.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_LogicalAnd.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_LogicalNot.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_LogicalOr.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Maximum.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Minimum.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_NotEqual.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Qr.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt8
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Rint.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Sin.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Sinh.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Softplus.pbtxt3
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Softsign.pbtxt3
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Substr.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Tan.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Tile.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt10
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc24
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc93
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h10
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc119
-rw-r--r--tensorflow/core/kernels/batching_util/BUILD3
-rw-r--r--tensorflow/core/kernels/conv_2d.h2
-rw-r--r--tensorflow/core/kernels/cwise_op_equal_to_1.cc4
-rw-r--r--tensorflow/core/kernels/cwise_op_greater.cc4
-rw-r--r--tensorflow/core/kernels/cwise_op_greater_equal.cc4
-rw-r--r--tensorflow/core/kernels/cwise_op_less.cc7
-rw-r--r--tensorflow/core/kernels/cwise_op_less_equal.cc7
-rw-r--r--tensorflow/core/kernels/cwise_op_not_equal_to_1.cc4
-rw-r--r--tensorflow/python/BUILD12
-rw-r--r--tensorflow/python/__init__.py1
-rw-r--r--tensorflow/python/client/tf_session.i6
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py37
-rw-r--r--tensorflow/python/estimator/BUILD410
-rw-r--r--tensorflow/python/estimator/__init__.py25
-rw-r--r--tensorflow/python/estimator/api/BUILD1
-rw-r--r--tensorflow/python/estimator/keras.py2
-rw-r--r--tensorflow/python/framework/test_util.py6
-rw-r--r--tensorflow/python/framework/test_util_test.py8
-rwxr-xr-xtensorflow/python/keras/BUILD2
-rw-r--r--tensorflow/python/keras/__init__.py1
-rw-r--r--tensorflow/python/keras/backend.py188
-rw-r--r--tensorflow/python/keras/backend_test.py130
-rw-r--r--tensorflow/python/keras/estimator/__init__.py46
-rw-r--r--tensorflow/python/keras/layers/local.py20
-rw-r--r--tensorflow/python/keras/layers/wrappers.py16
-rw-r--r--tensorflow/python/keras/layers/wrappers_test.py56
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py5
-rw-r--r--tensorflow/python/kernel_tests/distributions/BUILD4
-rw-r--r--tensorflow/python/kernel_tests/distributions/bernoulli_test.py11
-rw-r--r--tensorflow/python/kernel_tests/distributions/categorical_test.py10
-rw-r--r--tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py18
-rw-r--r--tensorflow/python/kernel_tests/distributions/exponential_test.py10
-rw-r--r--tensorflow/python/kernel_tests/distributions/gamma_test.py85
-rw-r--r--tensorflow/python/kernel_tests/distributions/laplace_test.py13
-rw-r--r--tensorflow/python/kernel_tests/distributions/multinomial_test.py16
-rw-r--r--tensorflow/python/kernel_tests/distributions/normal_test.py13
-rw-r--r--tensorflow/python/kernel_tests/distributions/uniform_test.py13
-rw-r--r--tensorflow/python/kernel_tests/random/BUILD4
-rw-r--r--tensorflow/python/ops/array_ops.py9
-rw-r--r--tensorflow/python/ops/control_flow_ops.py6
-rw-r--r--tensorflow/python/training/optimizer.py12
-rw-r--r--tensorflow/tools/api/generator/BUILD30
-rw-r--r--tensorflow/tools/api/generator/api_gen.bzl64
-rw-r--r--tensorflow/tools/api/generator/doc_srcs.py2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.debugging.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt7
-rw-r--r--tensorflow/tools/api/golden/tensorflow.image.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.io.pbtxt39
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/tensorflow.manip.pbtxt28
-rw-r--r--tensorflow/tools/api/golden/tensorflow.math.pbtxt216
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt16
-rw-r--r--tensorflow/tools/api/golden/tensorflow.quantization.pbtxt35
-rw-r--r--tensorflow/tools/api/golden/tensorflow.strings.pbtxt32
-rw-r--r--tensorflow/workspace.bzl8
297 files changed, 5647 insertions, 2100 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 4e212e96dc..a15d033013 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -404,6 +404,7 @@ config_setting(
package_group(
name = "internal",
packages = [
+ "-//third_party/tensorflow/python/estimator",
"//learning/meta_rank/...",
"//tensorflow/...",
"//tensorflow_fold/llgtm/...",
@@ -578,11 +579,20 @@ gen_api_init_files(
py_library(
name = "tensorflow_py",
- srcs = [
- ":tensorflow_python_api_gen",
- "//tensorflow/python/estimator/api:estimator_python_api_gen",
+ srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tensorflow_py_no_contrib",
+ "//tensorflow/contrib:contrib_py",
+ "//tensorflow/python/estimator:estimator_py",
],
+)
+
+py_library(
+ name = "tensorflow_py_no_contrib",
+ srcs = [":tensorflow_python_api_gen"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
- deps = ["//tensorflow/python"],
+ deps = ["//tensorflow/python:no_contrib"],
)
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 9662d7b478..779f65d5b1 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -20,7 +20,6 @@ from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
-# API IMPORTS PLACEHOLDER
try:
import os # pylint: disable=g-import-not-at-top
@@ -37,6 +36,8 @@ try:
except (ImportError, AttributeError):
print('tf.estimator package not installed.')
+# API IMPORTS PLACEHOLDER
+
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 370085c1e2..8ae579abda 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -81,7 +81,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
def testTopKZeros(self):
"""Tests that positive and negative zeros sort correctly."""
- # Requires Sort HLO, which is not implemented on CPU or GPU.
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
if self.device in ["XLA_CPU", "XLA_GPU"]:
return
@@ -99,7 +99,32 @@ class XlaSortOpTest(xla_test.XLATestCase):
{p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)})
self.assertAllEqual(
np.array([3., 0., 0., 0.], dtype=bfloat16), results[0])
- self.assertEqual(set([0, 2, 3, 6]), set(results[1]))
+ self.assertEqual(list([3, 0, 1, 2]), list(results[1]))
+
+ def testTopKInfinities(self):
+ """Tests that positive and negative infinity sort correctly."""
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ # Only bfloat16 is implemented.
+ bfloat16 = dtypes.bfloat16.as_numpy_dtype
+ if bfloat16 not in self.numeric_types:
+ return
+
+ with self.test_session() as sess:
+ p = array_ops.placeholder(dtypes.bfloat16)
+ with self.test_scope():
+ topk = nn_ops.top_k(p, k=6)
+ results = sess.run(topk, {
+ p: np.array(
+ [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16)
+ })
+ self.assertAllEqual(
+ np.array(
+ [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")],
+ dtype=bfloat16), results[0])
+ self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1]))
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
index 703e13e089..cbe3c8aaff 100644
--- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
@@ -61,42 +61,89 @@ class TopKOp : public XlaOpKernel {
if (input_shape.dim_size(0) < k) {
k = input_shape.dim_size(0);
}
- const xla::XlaOp input = context->Input(0);
- xla::XlaOp iota;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(b, DT_INT32, n, &iota));
+ const xla::XlaOp input_bf16 = context->Input(0);
+ xla::XlaOp iota_s32;
+ OP_REQUIRES_OK(context, XlaHelpers::Iota(b, DT_INT32, n, &iota_s32));
// TODO(b/73891930): add a key-value sort to HLO, rather than using
// bit-packing tricks here.
- // TODO(b/73891930): this implementation will convert Infs to NaNs. A
- // key-value sort would avoid this; for now, it is no worse than, say, the
- // CPU backend in fast-math mode.
+
+ xla::XlaOp zero = b->ConstantR0<int32>(0);
+
+ // max can either be 0x7FFFFFFF or 0x8000000. Neither choice is totally
+ // ideal. The implications of the choice are:
+ //
+ // 0x7FFFFFFF
+ // 1. +0.0 > -0.0
+ // 2. The elements of the inputs and outputs are bitwise identical.
+ // 3. The sort is unstable since a later +0.0 will appear before an earlier
+ // -0.0.
+ //
+ // 0x8000000
+ // 1. +0.0 == -0.0
+ // 2. All -0.0 in the input are replaced with +0.0 in the output.
+ // 3. The sort is stable.
+ xla::XlaOp max = b->ConstantR0<int32>(0x80000000);
+ xla::XlaOp index_mask = b->ConstantR0<int32>(0x0000FFFF);
+ xla::XlaOp value_mask = b->ConstantR0<int32>(0xFFFF0000);
+
+ // Convert to from bf16 to f32. The lower 16-bits are zero due to the
+ // definition of bf16.
+ xla::XlaOp input_f32 = b->ConvertElementType(input_bf16, xla::F32);
+
+ // Negate the input to reverse sort it. The lower 16-bits are zero, because
+ // negating a float is just inverting the high-bit.
+ xla::XlaOp negative_input_f32 = b->Neg(input_f32);
+
+ // Convert to a sign magnitude integer. The lower 16-bits are zero, since
+ // bitcast convert doesn't change any bits.
+ xla::XlaOp negative_input_sm32 =
+ b->BitcastConvertType(negative_input_f32, xla::S32);
+
+ // Convert from sign magnitude integer to two's complement integer. The
+ // lower 16-bits are zero on both sides of the select. On the false side,
+ // the value is unchanged, and on the true side, the lower 16-bits of max
+ // are all zero, so the lower 16-bits of the result of the subtraction will
+ // also be zero.
+ xla::XlaOp negative_input_s32 =
+ b->Select(b->Lt(negative_input_sm32, zero),
+ b->Sub(max, negative_input_sm32), negative_input_sm32);
+
+ // In order for the Or with iota_s32 to to work properly, the lower 16-bits
+ // of negative_input_32 must be zero.
// Pack elements as:
// * upper 16 bits are the value
// * lower 16 bits are the index.
- xla::XlaOp packed = b->BitcastConvertType(
- b->Or(b->BitcastConvertType(b->ConvertElementType(input, xla::F32),
- xla::S32),
- iota),
- xla::F32);
+ xla::XlaOp packed_s32 = b->Or(negative_input_s32, iota_s32);
// TODO(phawkins): use a more efficient algorithm that does not require a
// full sort.
- xla::XlaOp sorted = b->Slice(b->Rev(b->Sort(packed), {0}),
- /*start_indices=*/{0},
- /*limit_indices=*/{k},
- /*strides=*/{1});
-
- // Unpack the value/index
- xla::XlaOp x = b->BitcastConvertType(sorted, xla::S32);
- xla::XlaOp indices = b->And(x, b->ConstantR0<int32>(0x0000FFFF));
- xla::XlaOp values = b->ConvertElementType(
- b->BitcastConvertType(b->And(x, b->ConstantR0<int32>(0xFFFF0000)),
- xla::F32),
- xla::BF16);
-
- context->SetOutput(0, values);
- context->SetOutput(1, indices);
+ xla::XlaOp sorted_s32 = b->Slice(b->Sort(packed_s32),
+ /*start_indices=*/{0},
+ /*limit_indices=*/{k},
+ /*strides=*/{1});
+
+ // Unpack the value/index.
+ xla::XlaOp indices_s32 = b->And(sorted_s32, index_mask);
+ xla::XlaOp negative_values_s32 = b->And(sorted_s32, value_mask);
+
+ // Convert from two's complement integer to sign magnitude integer.
+ xla::XlaOp negative_values_sm32 =
+ b->Select(b->Lt(negative_values_s32, zero),
+ b->Sub(max, negative_values_s32), negative_values_s32);
+
+ xla::XlaOp negative_values_f32 =
+ b->BitcastConvertType(negative_values_sm32, xla::F32);
+
+ // Negate the values to get back the original inputs.
+ xla::XlaOp values_f32 = b->Neg(negative_values_f32);
+
+ // Convert from f32 to bf16.
+ xla::XlaOp values_bf16 = b->ConvertElementType(values_f32, xla::BF16);
+
+ context->SetOutput(0, values_bf16);
+ context->SetOutput(1, indices_s32);
}
private:
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
index e303999c63..d420863b85 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -137,7 +137,7 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
}
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
- const ShapeIndex& shape_index,
+ ShapeIndexView shape_index,
llvm::Value* ir_value) {
llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_);
@@ -158,7 +158,7 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
llvm::Value* ir_value,
- const ShapeIndex& shape_index) {
+ ShapeIndexView shape_index) {
VLOG(2) << "Binding " << hlo.ToString();
const Shape& hlo_shape = hlo.shape();
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
index 3d34311b43..a86e6e78c6 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
@@ -51,7 +51,7 @@ class HloToIrBindings {
// Rebinds the given HLO to the LLVM IR value that represent its address.
void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value,
- const ShapeIndex& shape_index = {});
+ ShapeIndexView shape_index = {});
// Unbinds all IR values that's defined in an LLVM function, e.g., function
// arguments and stack variables. Global variables will be kept in bindings_.
@@ -71,7 +71,7 @@ class HloToIrBindings {
// A helper method that returns the base pointer of the IrArray containing the
// output of "inst".at the given ShapeIndex.
llvm::Value* GetBasePointer(const HloInstruction& hlo,
- const ShapeIndex& shape_index = {}) const {
+ ShapeIndexView shape_index = {}) const {
auto it = base_ptrs_.find(&hlo);
CHECK(it != base_ptrs_.end()) << hlo.ToString();
return it->second.element(shape_index);
@@ -97,7 +97,7 @@ class HloToIrBindings {
// Returns an llvm typed ir representation of 'ir_value' based on 'hlo' shape.
llvm::Value* GetTypedIrValue(const HloInstruction& hlo,
- const ShapeIndex& shape_index,
+ ShapeIndexView shape_index,
llvm::Value* ir_value);
const BufferAssignment* buffer_assignment_;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index a94119b0e9..f6f0a45124 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2710,8 +2710,9 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
// If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
// repeating the literal 4 or 2 times, so long as the destination buffer is
// an even multiple of 32 bits long.
+ const Shape& output_shape = ShapeUtil::GetSubshape(hlo->shape(), index);
if ((num_bytes == 1 || num_bytes == 2) &&
- ShapeUtil::ByteSizeOf(hlo->shape()) % 4 == 0) {
+ ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
uint16 pattern16;
if (num_bytes == 1) {
uint8 b = literal_bytes.front();
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index d541776f00..9a4a1541ca 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -69,6 +70,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
// In that case, the operand of the reduce needs to have the same shape
// as the other tuple operands, but also we need to compare the output
// shapes of the reduces.
+ // TODO(tjoerg): Allow differences in fp precision.
auto* element_instr_1 = get_element_instr(instr1);
auto* element_instr_2 = get_element_instr(instr2);
if (element_instr_1->opcode() == HloOpcode::kReduce &&
@@ -147,5 +149,122 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1,
return instr1->fusion_kind() != HloInstruction::FusionKind::kLoop;
}
+bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
+ bool changed = false;
+ RecomputeReachability();
+
+ tensorflow::gtl::FlatSet<HloInstruction*> to_fuse;
+ // Keep a list of the instructions to fuse after making all the fusion
+ // decisions. We first aggressively add instructions to potential_fusion_list,
+ // then filter out instructions that will no longer be fusable because of
+ // reachability change. This avoids recalculating reachability on a large set
+ // of instructions.
+ std::vector<std::pair<HloInstruction*, HloInstruction*>>
+ potential_fusion_list;
+ std::vector<std::pair<HloInstruction*, HloInstruction*>> fusion_list;
+ std::vector<HloInstruction*> instrs_to_update_reachability;
+
+ // For each reduce or reduce multi-output fusion, try to fuse it with loop
+ // fusion operands.
+ for (HloInstruction* consumer : computation()->MakeInstructionPostOrder()) {
+ if (consumer->user_count() == 0) {
+ continue;
+ }
+ if (!IsReduction(consumer)) {
+ continue;
+ }
+ // TODO(b/110517657): Lowering multi-output reduce fusions with bfloat16
+ // output element types is not supported on GPU. However, bfloat16 is used
+ // in shared tests.
+ if (consumer->shape().element_type() == PrimitiveType::BF16) {
+ continue;
+ }
+
+ auto consumer_operands = consumer->operands();
+ for (size_t i = 0; i < consumer_operands.size(); ++i) {
+ HloInstruction* producer = consumer_operands[i];
+ if (!producer->IsFusable()) {
+ continue;
+ }
+ const bool is_loop_fusion =
+ producer->opcode() == HloOpcode::kFusion &&
+ producer->fusion_kind() == HloInstruction::FusionKind::kLoop;
+ if (!is_loop_fusion) {
+ continue;
+ }
+ if (!ShapesCompatibleForFusion(producer, consumer)) {
+ continue;
+ }
+ // If we have already decided to fuse this producer, skip it.
+ if (ContainsKey(to_fuse, producer)) {
+ continue;
+ }
+ // Do not fuse a producer if the other operands of the fusion are
+ // reachable from the producer, this would create a cycle.
+ if (std::any_of(consumer_operands.begin(), consumer_operands.end(),
+ [&](HloInstruction* operand) {
+ return producer != operand &&
+ reachability()->IsReachable(producer, operand);
+ })) {
+ continue;
+ }
+ to_fuse.insert(producer);
+ potential_fusion_list.emplace_back(producer, consumer);
+ instrs_to_update_reachability.push_back(producer);
+ instrs_to_update_reachability.push_back(consumer);
+ break;
+ }
+ }
+
+ // Filter out pairs that will no longer be fusable because of reachability
+ // change.
+ for (auto& fusion_pair : potential_fusion_list) {
+ HloInstruction* producer = fusion_pair.first;
+ HloInstruction* consumer = fusion_pair.second;
+ bool fusable = true;
+ for (size_t i = 0; i < consumer->operand_count(); ++i) {
+ if (producer != consumer->operand(i) &&
+ reachability()->IsReachable(producer, consumer->operand(i))) {
+ fusable = false;
+ break;
+ }
+ }
+ if (fusable) {
+ UpdateReachability(producer, consumer, instrs_to_update_reachability);
+ fusion_list.push_back(fusion_pair);
+ }
+ }
+
+ for (auto fusions_to_create : fusion_list) {
+ HloInstruction* producer = fusions_to_create.first;
+ HloInstruction* consumer = fusions_to_create.second;
+ if (consumer->opcode() != HloOpcode::kFusion) {
+ // Fusing with a reduce (fusion) always results in an input fusion.
+ HloInstruction* input_fusion =
+ computation()->AddInstruction(HloInstruction::CreateFusion(
+ consumer->shape(), HloInstruction::FusionKind::kInput, consumer));
+ VLOG(2) << "Fuse producer " << producer->name() << " and its consumer "
+ << consumer->name() << " into " << input_fusion->name();
+ TF_CHECK_OK(computation()->ReplaceInstruction(consumer, input_fusion));
+ if (producer->opcode() == HloOpcode::kFusion) {
+ input_fusion->MergeFusionInstructionIntoMultiOutput(producer);
+ } else {
+ input_fusion->FuseInstructionIntoMultiOutput(producer);
+ }
+ } else {
+ VLOG(2) << "Fuse producer " << producer->name() << " into its consumer "
+ << consumer->name();
+
+ if (producer->opcode() == HloOpcode::kFusion) {
+ consumer->MergeFusionInstructionIntoMultiOutput(producer);
+ } else {
+ consumer->FuseInstructionIntoMultiOutput(producer);
+ }
+ }
+ changed = true;
+ }
+ return changed;
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
index 16db0e0f02..67ca5d49ee 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
@@ -45,6 +45,9 @@ class GpuMultiOutputFusion : public MultiOutputFusion {
// Test if it's legal to fuse instr1 and instr2 into one fusion instruction.
bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) override;
+
+ // Fuse loop fusions into reduce fusions.
+ bool DoProducerConsumerMultiOutputFusion() override;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
index 5e7ceb7976..bca2779464 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -255,5 +255,73 @@ TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) {
op::Tuple(op::Multiply(), op::Divide()));
}
+TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_add {
+ p0.1 = f32[2,2,2]{2,1,0} parameter(0)
+ p1.1 = f32[2,2,2]{2,1,0} parameter(1)
+ ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
+ }
+
+ ENTRY reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ p1 = f32[2,2,2]{2,1,0} parameter(1)
+ c0 = f32[] constant(0)
+ add = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add
+ reduce = f32[2,2]{1,0} reduce(add, c0), dimensions={2}, to_apply=scalar_add_computation
+ ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, add)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement()));
+ const HloInstruction* fusion = root->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Add()));
+}
+
+TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_select {
+ p1.1 = f32[2,2,2]{2,1,0} parameter(1)
+ c0 = f32[] constant(0)
+ broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={}
+ greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast)
+ p0.1 = f32[2,2,2]{2,1,0} parameter(0)
+ ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast)
+ }
+
+ fused_reduce {
+ p0.2 = f32[2,2,2]{2,1,0} parameter(0)
+ c1 = f32[] constant(0)
+ r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add_computation
+ mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2)
+ r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation
+ ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
+ }
+
+ ENTRY reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ p1 = f32[2,2,2]{2,1,0} parameter(1)
+ select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
+ fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce
+ gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0
+ gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1
+ ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
+ op::GetTupleElement()));
+ const HloInstruction* fusion = root->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Reduce(), op::Select()));
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index bc7340aa03..7e97eacf35 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1378,6 +1378,44 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
+ template <typename NativeT,
+ typename std::enable_if<
+ !is_complex_t<NativeT>::value &&
+ !std::is_same<NativeT, bool>::value>::type* = nullptr>
+ Status HandleSort(HloInstruction* sort) {
+ TF_RET_CHECK(ShapeUtil::Rank(sort->shape()) == 1)
+ << "Sort is only supported for R1 shapes";
+
+ auto arg = sort->operand(0);
+ const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
+ VLOG(3) << "HandleSort arg_literal: " << arg_literal.ToString();
+ const auto& arg_data = arg_literal.data<ReturnT>();
+
+ std::vector<ReturnT> return_data(arg_data.begin(), arg_data.end());
+ std::sort(return_data.begin(), return_data.end(),
+ [](const ReturnT& a, const ReturnT& b) {
+ return SafeLess<ReturnT>(a, b);
+ });
+ auto result_literal = MakeUnique<Literal>(sort->shape());
+ result_literal->PopulateR1(
+ tensorflow::gtl::ArraySlice<ReturnT>(return_data));
+ VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
+ parent_->evaluated_[sort] = std::move(result_literal);
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value ||
+ std::is_same<NativeT, bool>::value>::type* =
+ nullptr>
+ Status HandleSort(HloInstruction* sort) {
+ return InvalidArgument("Unsupported type for Sort");
+ }
+
+ Status HandleSort(HloInstruction* sort) override {
+ return HandleSort<ReturnT>(sort);
+ }
+
Status HandleReduce(HloInstruction* reduce) override {
auto arg = reduce->operand(0);
auto init_value = reduce->operand(1);
@@ -2118,6 +2156,38 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return rhs_unsigned >= lhs_size_unsigned;
}
+ // It's UB to use std::sort with std::less<float>, because of NaNs. Define
+ // "safe" less functions which are actually strict weak orders.
+ template <typename NativeT,
+ typename std::enable_if<std::is_integral<NativeT>::value>::type* =
+ nullptr>
+ static bool SafeLess(const NativeT& a, const NativeT& b) {
+ return a < b;
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_floating_point<NativeT>::value ||
+ std::is_same<NativeT, bfloat16>::value>::type* = nullptr>
+ static bool SafeLess(const NativeT& a, const NativeT& b) {
+ if (std::isnan(b)) {
+ return !std::isnan(a);
+ } else {
+ return a < b;
+ }
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_same<NativeT, Eigen::half>::value>::type* = nullptr>
+ static bool SafeLess(const NativeT& a, const NativeT& b) {
+ if (Eigen::half_impl::isnan(b)) {
+ return !Eigen::half_impl::isnan(a);
+ } else {
+ return a < b;
+ }
+ }
+
HloEvaluator* parent_;
};
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 9fb15df7c2..268b4727bc 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -100,6 +100,29 @@ bool HloSharding::UsesDevice(int64 device) const {
std::find(devices.begin(), devices.end(), device) != devices.end();
}
+std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
+ int64 element_count = 1;
+ std::map<int64, int64> device_map;
+ if (IsTuple()) {
+ for (auto& tuple_element_sharding : tuple_elements()) {
+ auto unique_device = tuple_element_sharding.UniqueDevice();
+ if (unique_device.ok()) {
+ device_map[unique_device.ValueOrDie()] += 1;
+ }
+ }
+ element_count = tuple_elements().size();
+ } else {
+ auto unique_device = UniqueDevice();
+ if (unique_device.ok()) {
+ device_map[unique_device.ValueOrDie()] += 1;
+ }
+ }
+ if (count != nullptr) {
+ *count = element_count;
+ }
+ return device_map;
+}
+
std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
CHECK(!ShapeUtil::IsTuple(tile_shape_));
CHECK(!maximal_);
@@ -439,6 +462,27 @@ tensorflow::gtl::optional<HloSharding> HloSharding::ExtractSingleSharding()
return tuple_elements_.front();
}
+size_t HloSharding::Hash() const {
+ if (!tuple_) {
+ size_t h = 0;
+ for (const auto& element : tuple_elements_) {
+ h = tensorflow::Hash64Combine(h, element.Hash());
+ }
+ return h;
+ }
+ if (replicated_) {
+ return 0;
+ }
+ size_t h = 0;
+ for (uint32 v : tile_assignment_) {
+ h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
+ }
+ for (uint32 v : tile_shape_.dimensions()) {
+ h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
+ }
+ return h;
+}
+
std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
out << sharding.ToString();
return out;
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index 1e843481c3..34324d2058 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -19,7 +19,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
+#include <map>
#include <string>
+#include <vector>
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -118,6 +120,14 @@ class HloSharding {
// Returns true if the sharding defines an operation on the given device.
bool UsesDevice(int64 device) const;
+ // Retrieves a histogram of the devices used by the sharding. The returned
+ // map has the device number as key, and the occurrence count as value.
+ // If a sharding does not have a device, it will not be included in the
+ // histogram. The count argument, if not nullptr, will receive the total
+ // number of elements this sharding is made of (one for array, N leaves for
+ // tuples).
+ std::map<int64, int64> UsedDevices(int64* count) const;
+
// Returns the tile that should be executed on the given device.
// REQUIRES: !IsTuple()
std::vector<int64> TileIndexForDevice(int64 device) const;
@@ -179,26 +189,7 @@ class HloSharding {
}
bool operator!=(const HloSharding& other) const { return !(*this == other); }
- size_t Hash() const {
- if (!tuple_) {
- size_t h = 0;
- for (const auto& element : tuple_elements_) {
- h = tensorflow::Hash64Combine(h, element.Hash());
- }
- return h;
- }
- if (replicated_) {
- return 0;
- }
- size_t h = 0;
- for (uint32 v : tile_assignment_) {
- h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
- }
- for (uint32 v : tile_shape_.dimensions()) {
- h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
- }
- return h;
- }
+ size_t Hash() const;
struct Hasher {
size_t operator()(const HloSharding& sharding) const {
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 18e54d23c2..4aacc87b78 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -105,8 +105,8 @@ class ShapeTree {
// Returns the data element associated with the array in the shape at the
// given index (see ShapeUtil::GetSubshape for how indexes are defined).
- const T& element(const ShapeIndex& index) const;
- T* mutable_element(const ShapeIndex& index);
+ const T& element(ShapeIndexView index) const;
+ T* mutable_element(ShapeIndexView index);
// Return the shape represented with this ShapeTree.
const Shape& shape() const { return *shape_; }
@@ -125,7 +125,7 @@ class ShapeTree {
// Returns true if the node at the given index is a leaf node (an array
// shape).
- bool IsLeaf(const ShapeIndex& index) const { return Lookup(index)->is_leaf; }
+ bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; }
ShapeTree(const ShapeTree&) = default;
ShapeTree& operator=(const ShapeTree&) = default;
@@ -211,12 +211,12 @@ class ShapeTree {
// Returns an iterator pointing to the given ShapeIndex.
// REQUIRES: index must exist in the ShapeTree.
- iterator find(const ShapeIndex& index) {
+ iterator find(ShapeIndexView index) {
Node* element = Lookup(index);
return iterator(&nodes_, typename std::vector<Node>::iterator(element),
/*iterate_leaves_only=*/false);
}
- const_iterator find(const ShapeIndex& index) const {
+ const_iterator find(ShapeIndexView index) const {
Node* element = Lookup(index);
return iterator(&nodes_,
typename std::vector<Node>::const_iterator(element),
@@ -285,8 +285,8 @@ class ShapeTree {
static Status ForEachMutableHelper(const Fn& func, std::vector<Node>* nodes);
// Return the tree node at the given index.
- Node* Lookup(const ShapeIndex& index);
- const Node* Lookup(const ShapeIndex& index) const;
+ Node* Lookup(ShapeIndexView index);
+ const Node* Lookup(ShapeIndexView index) const;
// The nodes in this shape tree.
std::vector<Node> nodes_;
@@ -463,17 +463,17 @@ ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
}
template <typename T>
-const T& ShapeTree<T>::element(const ShapeIndex& index) const {
+const T& ShapeTree<T>::element(ShapeIndexView index) const {
return Lookup(index)->data.second;
}
template <typename T>
-T* ShapeTree<T>::mutable_element(const ShapeIndex& index) {
+T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
return &Lookup(index)->data.second;
}
template <typename T>
-internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(const ShapeIndex& index) {
+internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
Node* node = &nodes_[0];
for (const int64 i : index) {
CHECK_GE(i, 0);
@@ -485,7 +485,7 @@ internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(const ShapeIndex& index) {
template <typename T>
const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(
- const ShapeIndex& index) const {
+ ShapeIndexView index) const {
return const_cast<ShapeTree*>(this)->Lookup(index);
}
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 722d882471..3a885b4389 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -461,5 +461,26 @@ XLA_TEST_F(ConvertTest, ConvertS64U64) {
ComputeAndCompareR1<uint64>(&builder, unsigned_x, {});
}
+XLA_TEST_F(ConvertTest, ConvertBF16F32) {
+ XlaBuilder builder(TestName());
+
+ std::vector<bfloat16> all_bfloats(1 << 16);
+ for (int i = 0; i < all_bfloats.size(); ++i) {
+ all_bfloats[i].value = i;
+ }
+
+ std::vector<uint32> expected(all_bfloats.size());
+ for (int i = 0; i < expected.size(); ++i) {
+ expected[i] = (1U << 16) * i;
+ }
+
+ // Exhaustively test all bf16 to f32 conversions.
+ xla::XlaOp all_bfloats_bf16 = builder.ConstantR1<bfloat16>(all_bfloats);
+ xla::XlaOp all_bfloats_f32 =
+ builder.ConvertElementType(all_bfloats_bf16, F32);
+ xla::XlaOp all_bfloats_u32 = builder.BitcastConvertType(all_bfloats_f32, U32);
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index dd7c541733..000535a982 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -270,14 +270,22 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
switch (use->opcode()) {
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
- if (needs_index != nullptr &&
- !ShapeUtil::Equal(needs_index->shape(), use->shape())) {
- return Unimplemented(
- "Conflicting operand generation slice index constraints\n");
+ if (needs_index != nullptr) {
+ auto needs_index_shape = needs_index->shape();
+ auto use_shape = use->shape();
+ if (needs_index->opcode() == HloOpcode::kDynamicSlice) {
+ needs_index_shape = needs_index->operand(0)->shape();
+ }
+ if (use->opcode() == HloOpcode::kDynamicSlice) {
+ use_shape = use->operand(0)->shape();
+ }
+ if (!ShapeUtil::Equal(needs_index_shape, use_shape)) {
+ return Unimplemented(
+ "Conflicting operand generation slice index constraints\n");
+ }
}
needs_index = use;
break;
-
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
needs_constant = use;
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 7d44a054a8..fffab5a795 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -114,6 +114,7 @@ py_library(
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:util",
+ "//tensorflow/python/estimator:estimator_py",
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([
"//tensorflow/contrib/tensorrt:init_py",
]) + select({
diff --git a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb
index d62390494b..0702273fac 100644
--- a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb
+++ b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb
@@ -570,7 +570,7 @@
" autograph.utils.set_element_type(numbers, tf.int32)\n",
" for i in range(n):\n",
" numbers.append(i)\n",
- " return numbers.stack() # Stack the list so that it can be used as a Tensor\n",
+ " return autograph.stack(numbers) # Stack the list so that it can be used as a Tensor\n",
"\n",
"\n",
"tf_f = autograph.to_graph(f)\n",
@@ -648,7 +648,7 @@
" if not is_prime:\n",
" continue\n",
" primes.append(i)\n",
- " all_primes = primes.stack()\n",
+ " all_primes = autograph.stack(primes)\n",
"\n",
" print('The prime numbers less than', n, 'are:')\n",
" print(all_primes)\n",
@@ -953,8 +953,9 @@
" train_accuracies.append(step_train_accuracy)\n",
" test_accuracies.append(step_test_accuracy)\n",
" i += 1\n",
- " return (train_losses.stack(), test_losses.stack(), train_accuracies.stack(),\n",
- " test_accuracies.stack())"
+ " return (autograph.stack(train_losses), autograph.stack(test_losses),\n",
+ " autograph.stack(train_accuracies),\n",
+ " autograph.stack(test_accuracies))"
],
"execution_count": 0,
"outputs": []
@@ -1236,7 +1237,7 @@
" cell_output, (state, output) = cell.call(ch, (state, output))\n",
" hidden_outputs.append(cell_output)\n",
" i += 1\n",
- " hidden_outputs = hidden_outputs.stack()\n",
+ " hidden_outputs = autograph.stack(hidden_outputs)\n",
" if training:\n",
" hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n",
" return hidden_outputs\n",
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py
index 012a51f711..47b80bdf4a 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops.py
@@ -119,10 +119,6 @@ def batch_function(num_batch_threads,
raise ValueError("All arguments to functions decorated with "
"`batch_function` are supposed to be Tensors; "
"found %s" % repr(a))
- for inp in computation.captured_inputs:
- print("inp: %s" % inp)
- for op in inp.consumers():
- print("op: %s" % op)
return gen_batch_ops.batch_function(
num_batch_threads=num_batch_threads,
max_batch_size=max_batch_size,
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index 8a45858ae4..d530572e91 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -35,6 +35,7 @@ tensorflow/python/keras
tensorflow/python/keras/applications
tensorflow/python/keras/datasets
tensorflow/python/keras/engine
+tensorflow/python/keras/estimator
tensorflow/python/keras/layers
tensorflow/python/keras/preprocessing
tensorflow/python/keras/utils
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 38573f86ef..eb9482dc25 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -229,8 +229,6 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/debug/cli/profile_analyzer_cli_test.py"
# Windows does not have the curses library and uses readline.
"${tensorflow_source_dir}/tensorflow/python/debug/cli/curses_ui_test.py"
- # Bug in shape inference (b/110283809)
- "${tensorflow_source_dir}/tensorflow/python/kernel_tests/random/random_ops_test.py"
# TFDBG grpc:// mode is not yet available on Windows.
"${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py"
"${tensorflow_source_dir}/tensorflow/python/debug/lib/grpc_large_data_test.py"
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 99699cd6d6..2a4cf877f0 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -25,7 +25,10 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@Counter
@@CheckpointInputPipelineHook
@@CsvDataset
+@@RandomDataset
+@@Reducer
@@SqlDataset
+@@TFRecordWriter
@@assert_element_shape
@@batch_and_drop_remainder
@@ -33,12 +36,15 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@choose_from_datasets
@@dense_to_sparse_batch
@@enumerate_dataset
+
+@@get_single_element
@@group_by_reducer
@@group_by_window
@@ignore_errors
@@make_batched_features_dataset
@@make_csv_dataset
@@make_saveable_from_iterator
+
@@map_and_batch
@@padded_batch_and_drop_remainder
@@parallel_interleave
@@ -51,8 +57,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@sliding_window_batch
@@sloppy_interleave
@@unbatch
-
-@@get_single_element
+@@unique
"""
from __future__ import absolute_import
@@ -74,6 +79,7 @@ from tensorflow.contrib.data.python.ops.get_single_element import get_single_ele
from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length
from tensorflow.contrib.data.python.ops.grouping import group_by_reducer
from tensorflow.contrib.data.python.ops.grouping import group_by_window
+from tensorflow.contrib.data.python.ops.grouping import Reducer
from tensorflow.contrib.data.python.ops.interleave_ops import choose_from_datasets
from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave
from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets
@@ -81,6 +87,7 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
+from tensorflow.contrib.data.python.ops.random_ops import RandomDataset
from tensorflow.contrib.data.python.ops.readers import CsvDataset
from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
from tensorflow.contrib.data.python.ops.readers import make_csv_dataset
@@ -90,6 +97,8 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.scan_ops import scan
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
+from tensorflow.contrib.data.python.ops.unique import unique
+from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
index 3dfc3741c2..141706f393 100644
--- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
namespace {
@@ -24,19 +25,32 @@ namespace {
class ThreadPoolResource : public ResourceBase {
public:
ThreadPoolResource(Env* env, const ThreadOptions& thread_options,
- const string& name, int num_threads, bool low_latency_hint)
- : thread_pool_(env, thread_options, name, num_threads, low_latency_hint) {
- }
+ const string& name, int num_threads, bool low_latency_hint,
+ int max_intra_op_parallelism)
+ : thread_pool_(env, thread_options, name, num_threads, low_latency_hint),
+ max_intra_op_parallelism_(max_intra_op_parallelism) {}
// Schedules fn() for execution in the pool of threads.
void Schedule(std::function<void()> fn) {
- thread_pool_.Schedule(std::move(fn));
+ if (max_intra_op_parallelism_ < 0) {
+ thread_pool_.Schedule(std::move(fn));
+ } else {
+ thread_pool_.Schedule(std::bind(
+ [this](std::function<void()> bound_fn) {
+ // TODO(mrry): Consider moving this thread-local configuration to
+ // the threads themselves.
+ ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_);
+ bound_fn();
+ },
+ std::move(fn)));
+ }
}
string DebugString() override { return "ThreadPoolResource"; }
private:
thread::ThreadPool thread_pool_;
+ const int max_intra_op_parallelism_;
};
// Creates a handle to a ThreadPool resource. Note that we don't use
@@ -48,6 +62,8 @@ class ThreadPoolHandleOp : public OpKernel {
explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism",
+ &max_intra_op_parallelism_));
OP_REQUIRES(
ctx, num_threads_ > 0,
errors::InvalidArgument("`num_threads` must be greater than zero."));
@@ -78,7 +94,7 @@ class ThreadPoolHandleOp : public OpKernel {
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new ThreadPoolResource(
ctx->env(), {}, display_name_,
- num_threads_,
+ num_threads_, max_intra_op_parallelism_,
false /* low_latency_hint */);
return Status::OK();
}));
@@ -95,6 +111,7 @@ class ThreadPoolHandleOp : public OpKernel {
bool initialized_ GUARDED_BY(mu_) = false;
string display_name_;
int num_threads_;
+ int max_intra_op_parallelism_;
};
class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index f271d269ab..f48e96509a 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -158,6 +158,7 @@ REGISTER_OP("ThreadPoolHandle")
.Output("handle: resource")
.SetShapeFn(shape_inference::ScalarShape)
.Attr("num_threads: int")
+ .Attr("max_intra_op_parallelism: int = 1")
.Attr("display_name: string")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
@@ -166,6 +167,8 @@ Creates a custom thread pool with the given number of threads.
handle: A resource that can be consumed by one or more ThreadPoolDataset ops.
num_threads: The number of threads in the thread pool.
+max_intra_op_parallelism: The maximum degree of parallelism to use within
+ operations that execute on this threadpool.
display_name: A human-readable name for the threads that may be visible in
some visualizations.
)doc");
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index ed1542d03f..d81654e039 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -158,6 +158,26 @@ py_test(
)
py_test(
+ name = "iterator_ops_test",
+ size = "small",
+ srcs = ["iterator_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:iterator_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:model_fn",
+ ],
+)
+
+py_test(
name = "map_dataset_op_test",
size = "medium",
srcs = ["map_dataset_op_test.py"],
@@ -425,6 +445,7 @@ py_test(
"//tensorflow/python:script_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index fe618cdce6..9b1857de1a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -33,8 +33,8 @@ class DirectedInterleaveDatasetTest(test.TestCase):
input_datasets = [
dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
]
- dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset,
- input_datasets)
+ dataset = interleave_ops._DirectedInterleaveDataset(selector_dataset,
+ input_datasets)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 30a993b1f7..30a993b1f7 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 9167cb3379..0486e2bce2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import threading
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.ops import threadpool
@@ -30,9 +31,11 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class OverrideThreadpoolDatasetTest(test.TestCase):
+class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
- def testNumThreads(self):
+ @parameterized.parameters((1, None), (2, None), (4, None), (8, None),
+ (16, None), (4, -1), (4, 0), (4, 1), (4, 4))
+ def testNumThreads(self, num_threads, max_intra_op_parallelism):
def get_thread_id(_):
# Python creates a dummy thread object to represent the current
@@ -42,35 +45,35 @@ class OverrideThreadpoolDatasetTest(test.TestCase):
# identifier that maps one-to-one with the underlying OS thread.
return np.array(threading.current_thread().ident).astype(np.int64)
- for num_threads in [1, 2, 4, 8, 16]:
+ dataset = (
+ dataset_ops.Dataset.range(1000).map(
+ lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
+ num_parallel_calls=32).apply(unique.unique()))
- dataset = (
- dataset_ops.Dataset.range(1000).map(
- lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
- num_parallel_calls=32).apply(unique.unique()))
+ dataset = threadpool.override_threadpool(
+ dataset,
+ threadpool.PrivateThreadPool(
+ num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name="private_thread_pool_%d" % num_threads))
- dataset = threadpool.override_threadpool(
- dataset,
- threadpool.PrivateThreadPool(
- num_threads, display_name="private_thread_pool_%d" % num_threads))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.test_session() as sess:
- sess.run(iterator.initializer)
- thread_ids = []
- try:
- while True:
- thread_ids.append(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
- self.assertEqual(len(thread_ids), len(set(thread_ids)))
- self.assertGreater(len(thread_ids), 0)
- # NOTE(mrry): We don't control the thread pool scheduling, and
- # so cannot guarantee that all of the threads in the pool will
- # perform work.
- self.assertLessEqual(len(thread_ids), num_threads)
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ thread_ids = []
+ try:
+ while True:
+ thread_ids.append(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+ self.assertEqual(len(thread_ids), len(set(thread_ids)))
+ self.assertGreater(len(thread_ids), 0)
+ # NOTE(mrry): We don't control the thread pool scheduling, and
+ # so cannot guarantee that all of the threads in the pool will
+ # perform work.
+ self.assertLessEqual(len(thread_ids), num_threads)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 33b7a75046..0240814562 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -49,26 +49,6 @@ py_library(
],
)
-py_test(
- name = "iterator_ops_test",
- size = "small",
- srcs = ["iterator_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":iterator_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/estimator",
- "//tensorflow/python/estimator:model_fn",
- ],
-)
-
py_library(
name = "random_ops",
srcs = [
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 052618e08c..5708d47c20 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -77,17 +77,17 @@ def dense_to_sparse_batch(batch_size, row_shape):
"""
def _apply_fn(dataset):
- return DenseToSparseBatchDataset(dataset, batch_size, row_shape)
+ return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)
return _apply_fn
-class UnbatchDataset(dataset_ops.Dataset):
+class _UnbatchDataset(dataset_ops.Dataset):
"""A dataset that splits the elements of its input into multiple elements."""
def __init__(self, input_dataset):
"""See `unbatch()` for more details."""
- super(UnbatchDataset, self).__init__()
+ super(_UnbatchDataset, self).__init__()
flat_shapes = nest.flatten(input_dataset.output_shapes)
if any(s.ndims == 0 for s in flat_shapes):
raise ValueError("Cannot unbatch an input with scalar components.")
@@ -144,7 +144,7 @@ def unbatch():
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
if not sparse.any_sparse(dataset.output_classes):
- return UnbatchDataset(dataset)
+ return _UnbatchDataset(dataset)
# NOTE(mrry): We must ensure that any SparseTensors in `dataset`
# are normalized to the rank-1 dense representation, so that the
@@ -170,7 +170,7 @@ def unbatch():
dataset.output_shapes,
dataset.output_classes,
allow_unsafe_cast=True)
- return UnbatchDataset(restructured_dataset)
+ return _UnbatchDataset(restructured_dataset)
return _apply_fn
@@ -298,12 +298,12 @@ def padded_batch_and_drop_remainder(batch_size,
return _apply_fn
-class DenseToSparseBatchDataset(dataset_ops.Dataset):
+class _DenseToSparseBatchDataset(dataset_ops.Dataset):
"""A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
def __init__(self, input_dataset, batch_size, row_shape):
"""See `Dataset.dense_to_sparse_batch()` for more details."""
- super(DenseToSparseBatchDataset, self).__init__()
+ super(_DenseToSparseBatchDataset, self).__init__()
if not isinstance(input_dataset.output_types, dtypes.DType):
raise TypeError("DenseToSparseDataset requires an input whose elements "
"have a single component, whereas the input has %r." %
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index 5f5513849c..d46d96c461 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -46,17 +46,17 @@ def ignore_errors():
"""
def _apply_fn(dataset):
- return IgnoreErrorsDataset(dataset)
+ return _IgnoreErrorsDataset(dataset)
return _apply_fn
-class IgnoreErrorsDataset(dataset_ops.Dataset):
+class _IgnoreErrorsDataset(dataset_ops.Dataset):
"""A `Dataset` that silently ignores errors when computing its input."""
def __init__(self, input_dataset):
"""See `Dataset.ignore_errors()` for details."""
- super(IgnoreErrorsDataset, self).__init__()
+ super(_IgnoreErrorsDataset, self).__init__()
self._input_dataset = input_dataset
def _as_variant_tensor(self):
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 4068a2ffa5..348884e9fa 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -55,7 +55,7 @@ def group_by_reducer(key_func, reducer):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- return GroupByReducerDataset(dataset, key_func, reducer)
+ return _GroupByReducerDataset(dataset, key_func, reducer)
return _apply_fn
@@ -113,8 +113,8 @@ def group_by_window(key_func,
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- return GroupByWindowDataset(dataset, key_func, reduce_func,
- window_size_func)
+ return _GroupByWindowDataset(dataset, key_func, reduce_func,
+ window_size_func)
return _apply_fn
@@ -254,12 +254,12 @@ class _VariantDataset(dataset_ops.Dataset):
return self._output_types
-class GroupByReducerDataset(dataset_ops.Dataset):
+class _GroupByReducerDataset(dataset_ops.Dataset):
"""A `Dataset` that groups its input and performs a reduction."""
def __init__(self, input_dataset, key_func, reducer):
"""See `group_by_reducer()` for details."""
- super(GroupByReducerDataset, self).__init__()
+ super(_GroupByReducerDataset, self).__init__()
self._input_dataset = input_dataset
@@ -388,12 +388,12 @@ class GroupByReducerDataset(dataset_ops.Dataset):
**dataset_ops.flat_structure(self))
-class GroupByWindowDataset(dataset_ops.Dataset):
+class _GroupByWindowDataset(dataset_ops.Dataset):
"""A `Dataset` that groups its input and performs a windowed reduction."""
def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
"""See `group_by_window()` for details."""
- super(GroupByWindowDataset, self).__init__()
+ super(_GroupByWindowDataset, self).__init__()
self._input_dataset = input_dataset
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 70153ac575..bcc959594a 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -153,7 +153,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
return _apply_fn
-class DirectedInterleaveDataset(dataset_ops.Dataset):
+class _DirectedInterleaveDataset(dataset_ops.Dataset):
"""A substitute for `Dataset.interleave()` on a fixed list of datasets."""
def __init__(self, selector_input, data_inputs):
@@ -236,7 +236,7 @@ def sample_from_datasets(datasets, weights=None, seed=None):
selector_input = dataset_ops.Dataset.zip(
(logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset)
- return DirectedInterleaveDataset(selector_input, datasets)
+ return _DirectedInterleaveDataset(selector_input, datasets)
def choose_from_datasets(datasets, choice_dataset):
@@ -280,4 +280,4 @@ def choose_from_datasets(datasets, choice_dataset):
and choice_dataset.output_classes == ops.Tensor):
raise TypeError("`choice_dataset` must be a dataset of scalar "
"`tf.int64` tensors.")
- return DirectedInterleaveDataset(choice_dataset, datasets)
+ return _DirectedInterleaveDataset(choice_dataset, datasets)
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 2ca3805d66..cf89657226 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -39,17 +39,17 @@ def optimize(optimizations=None):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- return OptimizeDataset(dataset, optimizations)
+ return _OptimizeDataset(dataset, optimizations)
return _apply_fn
-class OptimizeDataset(dataset_ops.Dataset):
+class _OptimizeDataset(dataset_ops.Dataset):
"""A `Dataset` that acts as an identity, and applies optimizations."""
def __init__(self, input_dataset, optimizations):
"""See `optimize()` for details."""
- super(OptimizeDataset, self).__init__()
+ super(_OptimizeDataset, self).__init__()
self._input_dataset = input_dataset
if optimizations is None:
optimizations = []
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 3c82a03df1..97931f75bd 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -23,6 +23,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
class StatsAggregator(object):
"""A stateful resource that aggregates statistics from one or more iterators.
@@ -110,7 +112,8 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset):
return self._input_dataset.output_classes
-# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`.
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def set_stats_aggregator(stats_aggregator):
"""Set the given stats_aggregator for aggregating the input dataset stats.
@@ -128,6 +131,8 @@ def set_stats_aggregator(stats_aggregator):
return _apply_fn
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def bytes_produced_stats(tag):
"""Records the number of bytes produced by each element of the input dataset.
@@ -150,6 +155,8 @@ def bytes_produced_stats(tag):
return _apply_fn
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def latency_stats(tag):
"""Records the latency of producing each element of the input dataset.
@@ -171,6 +178,8 @@ def latency_stats(tag):
return _apply_fn
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def feature_stats(tag):
"""Records the features stats from `Example` records of the input dataset.
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index bb49604d4d..9af1e784ff 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -37,22 +37,28 @@ def _generate_shared_name(prefix):
return "{}{}".format(prefix, uid)
+# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
class PrivateThreadPool(object):
"""A stateful resource that represents a private thread pool."""
- def __init__(self, num_threads, display_name=None):
+ def __init__(self, num_threads, display_name=None,
+ max_intra_op_parallelism=1):
"""Creates a `PrivateThreadPool` with the given number of threads."""
if context.executing_eagerly():
shared_name = _generate_shared_name("privatethreadpool")
self._resource = gen_dataset_ops.thread_pool_handle(
num_threads=num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name,
shared_name=shared_name)
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._resource, handle_device=context.context().device_name)
else:
self._resource = gen_dataset_ops.thread_pool_handle(
- num_threads=num_threads, display_name=display_name)
+ num_threads=num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name=display_name)
class _ThreadPoolDataset(dataset_ops.Dataset):
@@ -82,6 +88,8 @@ class _ThreadPoolDataset(dataset_ops.Dataset):
return self._input_dataset.output_classes
+# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def override_threadpool(dataset, thread_pool):
"""Returns a new dataset that uses the given thread pool for its operations.
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index 4ce6ddede8..e0ce0a4ef1 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -42,17 +42,17 @@ def unique():
"""
def _apply_fn(dataset):
- return UniqueDataset(dataset)
+ return _UniqueDataset(dataset)
return _apply_fn
-class UniqueDataset(dataset_ops.Dataset):
+class _UniqueDataset(dataset_ops.Dataset):
"""A `Dataset` contains the unique elements from its input."""
def __init__(self, input_dataset):
"""See `unique()` for details."""
- super(UniqueDataset, self).__init__()
+ super(_UniqueDataset, self).__init__()
self._input_dataset = input_dataset
if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
dtypes.string):
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 9dfb8552f1..eba0dd0ea3 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -587,7 +587,6 @@ cuda_py_test(
],
tags = [
"multi_and_single_gpu",
- "noguitar",
"notsan",
],
)
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index f8ae8b9712..1009c3c012 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -536,7 +536,7 @@ class AllReduceCrossTowerOps(CrossTowerOps):
destinations = per_device_values[0].devices
grouped = _group_value_by_device(per_device_values)
- device_grad_packs, self._tensor_packer = _pack_tensors(
+ device_grad_packs, tensor_packer = _pack_tensors(
grouped, self._num_packs, self._agg_small_grads_max_bytes,
self._agg_small_grads_max_group)
@@ -554,7 +554,7 @@ class AllReduceCrossTowerOps(CrossTowerOps):
cross_tower_utils.aggregate_gradients_using_hierarchical_copy(
destinations, device_grad_packs))
- reduced = _unpack_tensors(reduced, self._tensor_packer)
+ reduced = _unpack_tensors(reduced, tensor_packer)
return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices,
method_string)
@@ -665,13 +665,13 @@ class MultiWorkerAllReduce(AllReduceCrossTowerOps):
(this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size(
spec_tuple.limit, remaining_grads)
if this_grads:
- device_grad_packs, self._tensor_packer = _pack_tensors(
+ device_grad_packs, tensor_packer = _pack_tensors(
this_grads, self._num_packs, self._agg_small_grads_max_bytes,
self._agg_small_grads_max_group)
range_agg_grads = cross_tower_utils.sum_gradients_all_reduce(
self._worker_devices, device_grad_packs, len(self._worker_devices),
spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker))
- range_agg_grads = _unpack_tensors(range_agg_grads, self._tensor_packer)
+ range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer)
if not aggregated_grads:
aggregated_grads = range_agg_grads
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index 75754e3fe3..aeeb9553e6 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -89,7 +89,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
run_step()
weights.append(self.evaluate(layer.kernel))
- biases.append(self.evaluate(distribution.fetch(layer.bias)))
+ biases.append(self.evaluate(layer.bias))
if is_tpu:
with self.test_session() as sess:
@@ -254,7 +254,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
for _ in range(10):
run_step()
- moving_means = self.evaluate(distribution.fetch(batchnorm.moving_mean))
+ moving_means = self.evaluate(batchnorm.moving_mean)
# We make sure that the moving_mean is updated as if the sample mean is
# calculated over all towers.
@@ -345,7 +345,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
v = all_vars[0]
self.assertTrue(all([v is vi for vi in all_vars[1:]]))
- weight = numpy.squeeze(self.evaluate(distribution.fetch(v)))
+ weight = numpy.squeeze(self.evaluate(v))
# Our model is:
# predict = x * w
# loss = (predict - y)^2
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index dc270ac540..d8668b398f 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -31,7 +31,6 @@ from tensorflow.python.eager import tape
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import coordinator
from tensorflow.python.training import device_util
@@ -286,8 +285,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def map(self, map_over, fn, *args, **kwargs):
# TODO(josh11b): In eager mode, use one thread per device.
index = {}
- i = 0
- for m in map_over:
+ for i, m in enumerate(map_over):
d = self._devices[i % len(self._devices)]
with ops.device(d):
l = index.get(d, [])
@@ -349,7 +347,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def read_var(self, tower_local_var):
"""Read the aggregate value of a tower-local variable."""
if isinstance(tower_local_var, values.TowerLocalVariable):
- return math_ops.add_n(self.unwrap(tower_local_var))
+ return tower_local_var._get_cross_tower() # pylint: disable=protected-access
assert isinstance(tower_local_var, values.Mirrored)
return array_ops.identity(tower_local_var.get())
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 7b41cfe064..cb150692de 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -337,6 +337,8 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
all_v_sum = {}
all_v_mean = {}
+ components_sum = {}
+ components_mean = {}
def model_fn(device_id):
tower_context = distribute_lib.get_tower_context()
@@ -350,21 +352,33 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
v_mean.assign(6.0 * device_id)]
all_v_sum[device_id] = v_sum
all_v_mean[device_id] = v_mean
- return updates, v_sum, v_mean
+ c_sum = v_sum.get()
+ c_mean = v_mean.get()
+ components_sum[device_id] = c_sum
+ components_mean[device_id] = c_mean
+ self.assertIsNot(v_sum, c_sum)
+ self.assertIsNot(v_mean, c_mean)
+ return updates, v_sum, v_mean, c_sum, c_mean
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
with dist.scope():
# Create "sum" and "mean" versions of TowerLocalVariables.
- ret_ops, ret_v_sum, ret_v_mean = dist.call_for_each_tower(
- model_fn, dist.worker_device_index, run_concurrently=False)
+ ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = (
+ dist.call_for_each_tower(
+ model_fn, dist.worker_device_index, run_concurrently=False))
# Should see the same wrapping instance in all towers.
self.assertIs(all_v_sum[0], ret_v_sum)
self.assertIs(all_v_mean[0], ret_v_mean)
- for i in range(1, dist.num_towers):
- self.assertIs(all_v_sum[0], all_v_sum[1])
- self.assertIs(all_v_mean[0], all_v_mean[1])
+ self.assertIs(all_v_sum[0], all_v_sum[1])
+ self.assertIs(all_v_mean[0], all_v_mean[1])
+
+ # Regroup should recover the same wrapper.
+ self.assertIs(ret_v_sum, regrouped_sum)
+ self.assertIs(ret_v_mean, regrouped_mean)
+ self.assertIsNot(components_sum[0], components_sum[1])
+ self.assertIsNot(components_mean[0], components_mean[1])
# Apply updates
self.evaluate(variables.global_variables_initializer())
@@ -385,14 +399,13 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
# Without get(device), should return the value you get by
# applying the reduction across all towers (whether you use
- # fetch(), get(), or nothing).
- self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum)))
- self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean)))
+ # read_var(), get(), or nothing).
+ self.assertEqual(expected_sum, self.evaluate(dist.read_var(ret_v_sum)))
+ self.assertEqual(expected_mean, self.evaluate(dist.read_var(ret_v_mean)))
self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get()))
self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get()))
- if not context.executing_eagerly():
- self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
- self.assertEqual(expected_mean, self.evaluate(ret_v_mean))
+ self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
+ self.assertEqual(expected_mean, self.evaluate(ret_v_mean))
# NOTE(priyag): Names and name scopes are ignored in eager, hence we are not
# testing this in eager mode.
@@ -557,14 +570,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
# the individual values before running the update ops.
self.assertEquals(1.0, self.evaluate(
ret_v_sum.get(dist._devices[0]).read_value()))
- self.assertEquals(2.0, self.evaluate(dist.read_var(ret_v_sum)))
+ self.assertEquals(2.0, self.evaluate(ret_v_sum))
+
# Apply updates.
self.evaluate(update_ops)
# Assert that the aggregated value of the tower local vars is the sum of
# the individual values after running the update ops.
self.assertEquals(5.0, self.evaluate(
ret_v_sum.get(dist._devices[0]).read_value()))
- self.assertEquals(10.0, self.evaluate(dist.read_var(ret_v_sum)))
+ self.assertEquals(10.0, self.evaluate(ret_v_sum))
if __name__ == "__main__":
diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py
index 4fdb9bf69b..2892ce4394 100644
--- a/tensorflow/contrib/distribute/python/monitor_test.py
+++ b/tensorflow/contrib/distribute/python/monitor_test.py
@@ -52,11 +52,11 @@ class MonitorTest(test.TestCase, parameterized.TestCase):
self.assertEqual(1, len(layer.trainable_variables))
mirrored_weight_variable = layer.trainable_variables[0]
- start_error = self.evaluate(distribution.fetch(mirrored_weight_variable))
+ start_error = self.evaluate(mirrored_weight_variable)
start_error = abs(numpy.array(start_error) - 1)
monitor.run_steps(9)
- end_error = self.evaluate(distribution.fetch(mirrored_weight_variable))
+ end_error = self.evaluate(mirrored_weight_variable)
end_error = abs(numpy.array(end_error) - 1)
self.assertGreaterEqual(start_error, end_error)
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index abd3a65ac4..a2d736e422 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -59,8 +59,8 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
for _ in range(10):
run_step()
- weights.append(self.evaluate(distribution.fetch(layer.kernel)))
- biases.append(self.evaluate(distribution.fetch(layer.bias)))
+ weights.append(self.evaluate(layer.kernel))
+ biases.append(self.evaluate(layer.bias))
error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
index 75c5ec9659..2ee94d8f70 100644
--- a/tensorflow/contrib/distribute/python/step_fn_test.py
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -50,8 +50,8 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
for _ in range(10):
run_step()
- weights.append(self.evaluate(distribution.fetch(layer.kernel)))
- biases.append(self.evaluate(distribution.fetch(layer.bias)))
+ weights.append(self.evaluate(layer.kernel))
+ biases.append(self.evaluate(layer.bias))
error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index 2b4ad9f146..d2fe8b3b1e 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -106,13 +106,13 @@ class DistributionTestBase(test.TestCase):
before_list = []
after_list = []
for g, v in g_v:
- fetched = d.fetch(v)
+ fetched = d.read_var(v)
before_list.append(fetched)
# control_dependencies irrelevant but harmless in eager execution
with ops.control_dependencies([fetched]):
g = d.reduce("sum", g, destinations=v)
with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
- after_list.append(d.fetch(v))
+ after_list.append(d.read_var(v))
return before_list, after_list
for i in range(10):
@@ -159,12 +159,12 @@ class DistributionTestBase(test.TestCase):
before_list = []
after_list = []
for g, v in g_v:
- fetched = d.fetch(v)
+ fetched = d.read_var(v)
before_list.append(fetched)
with ops.control_dependencies([fetched]):
g = d.reduce("sum", g, destinations=v)
with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
- after_list.append(d.fetch(v))
+ after_list.append(d.read_var(v))
return before_list, after_list
before_out, after_out = step()
@@ -184,7 +184,7 @@ class DistributionTestBase(test.TestCase):
with d.scope():
map_in = [constant_op.constant(i) for i in range(10)]
map_out = d.map(map_in, lambda x, y: x * y, 2)
- observed = d.fetch(d.reduce("sum", map_out))
+ observed = d.reduce("sum", map_out)
expected = 90 # 2 * (0 + 1 + ... + 9)
self.assertEqual(expected, observed.numpy())
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index aca544b7e7..9a48928a95 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -26,7 +26,6 @@ import weakref
import six
-from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.distribute.python import input_ops
from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
@@ -43,7 +42,7 @@ from tensorflow.python.util import nest
# pylint: disable=line-too-long
-# TODO(josh11b): Should device values be strings or DeviceSpec objects
+# TODO(josh11b): Should device values be strings or DeviceSpec objects?
# Not sure DeviceSpec objects are usable as a dict key.
class DistributedValues(object):
"""Holds a map from device to values. Either PerDevice or Mirrored."""
@@ -163,9 +162,16 @@ class PerDevice(DistributedValues):
pass
-class Mirrored(DistributedValues):
+# Note that unlike PerDevice, Mirrored values inherit from
+# DistributedDelegate and so can be used directly in cross-tower mode.
+class Mirrored(DistributedDelegate):
"""Holds a map from device to values which are kept in sync."""
- pass
+
+ def _get_cross_tower(self):
+ device = device_util.canonicalize(device_util.current())
+ if device in self._index:
+ return self._index[device]
+ return list(self._index.values())[0]
def _assign_on_device(device, variable, tensor):
@@ -186,6 +192,10 @@ class DistributedVariable(DistributedDelegate):
# Child class must set self._primary_var before calling
# super(...).__init__(index).
self._common_name = self._primary_var.name.split(":")[0]
+ # Use a weakref to make it easy to map from the contained values
+ # to the container without introducing a reference cycle.
+ for v in six.itervalues(index):
+ v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
super(DistributedVariable, self).__init__(index)
@property
@@ -281,10 +291,6 @@ class MirroredVariable(DistributedVariable, Mirrored,
"""Holds a map from device to variables whose values are kept in sync."""
def __init__(self, index, primary_var):
- # Use a weakref to make it easy to map from the contained values
- # to the container without introducing a reference cycle.
- for v in six.itervalues(index):
- v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._primary_var = primary_var
super(MirroredVariable, self).__init__(index)
@@ -353,7 +359,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
# We use a callable so that we don't have to evaluate this expression
# in the case where we are trying to restore instead of save.
def tensor():
- return distribute_lib.get_distribution_strategy().fetch(
+ return distribute_lib.get_distribution_strategy().read_var(
tower_local_variable)
spec = saver.BaseSaverBuilder.SaveSpec(
tensor=tensor,
@@ -492,40 +498,40 @@ def regroup(per_device, wrap_class=PerDevice):
same_id = False
break
# Consider three cases where same_id is true:
- # * If v0 is a MirroredVariable (and same_id means it is the same
- # across all devices), we want to return it. We check
- # MirroredVariable specifically since it can look like it
- # has a _mirrored_container member since its members do.
- # * If v0 is a member of a mirrored variable, in which case
- # hasattr(v0, "_mirrored_container") is true, we want to
- # return the MirroredVariable that contains it using the
- # _mirrored_container logic below. This case can trigger
+ # * If v0 is a DistributedVariable (a MirroredVariable or
+ # TowerLocalVariable, and same_id means it is the same across all
+ # devices), we want to return it. We check DistributedVariable
+ # specifically since it can look like it has a
+ # _distributed_container member since its members do.
+ # * If v0 is a member of a distributed variable, in which case
+ # hasattr(v0, "_distributed_container") is true, we want to
+ # return the DistributedVariable that contains it using the
+ # _distributed_container logic below. This case can trigger
# same_id when there is only one device.
# * In any other situation, same_id means we return v0.
- if same_id and (isinstance(v0, MirroredVariable) or
- not hasattr(v0, "_mirrored_container")):
+ if same_id and (isinstance(v0, DistributedVariable) or
+ not hasattr(v0, "_distributed_container")):
return v0
# Detect the case where each device has a parallel component of the
- # same MirroredVariable. In this case we want to return the
- # containing MirroredVariable, after a bunch of sanity checking.
- # In particular, each component should have the same container,
- # and the devices of the variables should match the keys of the
- # per-device dictionary.
- # TODO(josh11b): Do we need similar logic for TowerLocalVariables?
- if hasattr(v0, "_mirrored_container"):
+ # same MirroredVariable (or TowerLocalVariable). In this case we
+ # want to return the containing MirroredVariable, after a bunch of
+ # sanity checking. In particular, each component should have the
+ # same container, and the devices of the variables should match the
+ # keys of the per-device dictionary.
+ if hasattr(v0, "_distributed_container"):
# pylint: disable=protected-access
assert not isinstance(v0, MirroredVariable), (
"ids = %s, items = %s" % ([id(v[1]) for v in items], items))
assert _devices_match(v0.device, items[0][0]), (
"v0.device = %s, items = %s" % (v0.device, items))
- mirrored_container = v0._mirrored_container()
- assert mirrored_container is not None
+ distributed_container = v0._distributed_container()
+ assert distributed_container is not None
for d, v in items[1:]:
assert _devices_match(v.device, d), (
"v.device = %s, d = %s, items = %s" % (v.device, d, items))
- assert mirrored_container is v._mirrored_container()
- return mirrored_container
+ assert distributed_container is v._distributed_container()
+ return distributed_container
# pylint: enable=protected-access
return wrap_class(per_device)
@@ -607,8 +613,7 @@ class PerDeviceDataset(object):
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
- self._dataset = dataset.apply(
- batching.batch_and_drop_remainder(len(devices)))
+ self._dataset = dataset.batch(len(devices), drop_remainder=True)
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
index e281e81bdf..d1ce273499 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
@@ -61,6 +61,28 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
atol=0.,
rtol=1e-7)
+ def testNoBatchStaticJacobian(self):
+ x = np.eye(2)
+ bijector = bijectors.CholeskyOuterProduct()
+
+ # The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4.
+ self.assertAllClose(
+ np.log(4),
+ self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=2)))
+
+ def testNoBatchDynamicJacobian(self):
+ x = np.eye(2)
+ bijector = bijectors.CholeskyOuterProduct()
+ x_pl = array_ops.placeholder(dtypes.float32)
+
+ with self.test_session():
+ log_det_jacobian = bijector.forward_log_det_jacobian(x_pl, event_ndims=2)
+
+ # The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4.
+ self.assertAllClose(
+ np.log(4),
+ log_det_jacobian.eval({x_pl: x}))
+
def testNoBatchStatic(self):
x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y)
y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T)
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
index 8267ee7df8..3e1e4fc829 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
@@ -182,7 +182,20 @@ class CholeskyOuterProduct(bijector.Bijector):
axis=-1)
fldj = p_float * np.log(2.) + sum_weighted_log_diag
- return fldj
+ # We finally need to undo adding an extra column in non-scalar cases
+ # where there is a single matrix as input.
+ if x.get_shape().ndims is not None:
+ if x.get_shape().ndims == 2:
+ fldj = array_ops.squeeze(fldj, axis=-1)
+ return fldj
+
+ shape = array_ops.shape(fldj)
+ maybe_squeeze_shape = array_ops.concat([
+ shape[:-1],
+ distribution_util.pick_vector(
+ math_ops.equal(array_ops.rank(x), 2),
+ np.array([], dtype=np.int32), shape[-1:])], 0)
+ return array_ops.reshape(fldj, maybe_squeeze_shape)
def _make_columnar(self, x):
"""Ensures non-scalar input has at least one column.
diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index 6f02c90368..12155a459c 100644
--- a/tensorflow/contrib/eager/python/examples/BUILD
+++ b/tensorflow/contrib/eager/python/examples/BUILD
@@ -15,6 +15,8 @@ py_library(
"//tensorflow/contrib/eager/python/examples/revnet:config",
"//tensorflow/contrib/eager/python/examples/rnn_colorbot",
"//tensorflow/contrib/eager/python/examples/rnn_ptb",
+ "//tensorflow/contrib/eager/python/examples/sagan",
+ "//tensorflow/contrib/eager/python/examples/sagan:config",
"//tensorflow/contrib/eager/python/examples/spinn:data",
],
)
diff --git a/tensorflow/contrib/eager/python/examples/sagan/BUILD b/tensorflow/contrib/eager/python/examples/sagan/BUILD
new file mode 100644
index 0000000000..b470a41d81
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/sagan/BUILD
@@ -0,0 +1,59 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+# Model
+py_library(
+ name = "config",
+ srcs = ["config.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "ops",
+ srcs = ["ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "sagan",
+ srcs = ["sagan.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+# Tests
+cuda_py_test(
+ name = "ops_test",
+ size = "small",
+ srcs = ["ops_test.py"],
+ additional_deps = [
+ ":ops",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+cuda_py_test(
+ name = "sagan_test",
+ size = "large",
+ srcs = ["sagan_test.py"],
+ additional_deps = [
+ ":config",
+ ":sagan",
+ "//tensorflow:tensorflow_py",
+ ],
+ tags = [
+ "optonly",
+ ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/sagan/config.py b/tensorflow/contrib/eager/python/examples/sagan/config.py
new file mode 100644
index 0000000000..1967bbd867
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/sagan/config.py
@@ -0,0 +1,72 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Self-attention generative adversarial network with eager execution.
+
+Configuration in format of tf.contrib.training.HParams.
+Supports default 128x128 ImageNet.
+
+Reference [Self-Attention Generative Adversarial
+Networks](https://arxiv.org/pdf/1805.08318.pdf)
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+tfe = tf.contrib.eager
+
+
+def get_hparams_imagenet():
+ """Configurations to train SAGAN on 128x128 ImageNet dataset."""
+ config = tf.contrib.training.HParams()
+ if tf.test.is_gpu_available():
+ config.add_hparam("image_shape", (3, 128, 128))
+ config.add_hparam("data_format", "channels_first")
+ config.add_hparam("g_init_shape", (512, 4, 4))
+ else:
+ config.add_hparam("image_shape", (128, 128, 3))
+ config.add_hparam("data_format", "channels_last")
+ config.add_hparam("g_init_shape", (4, 4, 512))
+
+ config.add_hparam("latent_dim", 128)
+ config.add_hparam("update_g_once_every", 1)
+ config.add_hparam("batch_size", 64)
+ config.add_hparam("d_init_filters", 32)
+ config.add_hparam("num_upsamples", 5)
+ # (512, 4, 4) -> (3, 128, 128)
+ return config
+
+
+def get_hparams_mock():
+ """Configurations of smaller networks for testing."""
+ config = tf.contrib.training.HParams()
+ if tf.test.is_gpu_available():
+ config.add_hparam("image_shape", (3, 16, 16))
+ config.add_hparam("data_format", "channels_first")
+ config.add_hparam("g_init_shape", (32, 2, 2))
+ else:
+ config.add_hparam("image_shape", (16, 16, 3))
+ config.add_hparam("data_format", "channels_last")
+ config.add_hparam("g_init_shape", (2, 2, 32))
+
+ config.add_hparam("latent_dim", 16)
+ config.add_hparam("update_g_once_every", 1)
+ config.add_hparam("batch_size", 2)
+ config.add_hparam("d_init_filters", 4)
+ config.add_hparam("num_upsamples", 3)
+ # (32, 2, 2) -> (3, 16, 16)
+ return config
diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops.py b/tensorflow/contrib/eager/python/examples/sagan/ops.py
new file mode 100644
index 0000000000..9a03cab1d1
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/sagan/ops.py
@@ -0,0 +1,71 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Self-attention generative adversarial network with eager execution.
+
+Auxiliary operations.
+
+Reference [Self-Attention Generative Adversarial
+Networks](https://arxiv.org/pdf/1805.08318.pdf)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def flatten_hw(x, data_format="channels_first"):
+ """Flatten the input tensor across height and width dimensions."""
+ if data_format == "channels_last":
+ x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first`
+
+ old_shape = tf.shape(x)
+ new_shape = [old_shape[0], old_shape[2] * old_shape[3], old_shape[1]]
+
+ return tf.reshape(x, new_shape)
+
+
+def broaden_hw(x, h, w, c, data_format="channels_first"):
+ """Broaden dimension so that output has height and width."""
+ if data_format == "channels_first":
+ shape = [-1, c, h, w]
+ else:
+ shape = [-1, h, w, c]
+
+ return tf.reshape(x, shape)
+
+
+class BroadenHW(tf.keras.layers.Layer):
+ """Wrapper class so that `broaden_hw` can be used in `tf.keras.Sequential`."""
+
+ def __init__(self, h, w, c, data_format="channels_first"):
+ super(BroadenHW, self).__init__()
+ self.h = h
+ self.w = w
+ self.c = c
+ self.data_format = data_format
+
+ def call(self, x):
+ return broaden_hw(
+ x, h=self.h, w=self.w, c=self.c, data_format=self.data_format)
+
+ def compute_output_shape(self, input_shape):
+ input_shape = tf.TensorShape(input_shape).as_list()
+ if self.data_format == "channels_first":
+ output_shape = (input_shape[0], self.c, self.h, self.w)
+ else:
+ output_shape = (input_shape[0], self.h, self.w, self.c)
+
+ return tf.TensorShape(output_shape)
diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py
new file mode 100644
index 0000000000..3454985904
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py
@@ -0,0 +1,59 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for auxiliary operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.sagan import ops
+
+
+class OpsTest(tf.test.TestCase):
+
+ def test_flatten_hw(self):
+ """Test `flatten_hw` function with mock object."""
+
+ batch_size = 1
+ # Default NCHW format
+ if tf.test.is_gpu_available():
+ x = tf.random_normal(shape=(batch_size, 3, 4, 4))
+ y = ops.flatten_hw(x, data_format="channels_first")
+ self.assertEqual(y.shape, (batch_size, 4 * 4, 3))
+
+ # NHWC format
+ x = tf.random_normal(shape=(batch_size, 4, 4, 3))
+ y = ops.flatten_hw(x, data_format="channels_last")
+ self.assertEqual(y.shape, (batch_size, 4 * 4, 3))
+
+ def test_broaden_hw(self):
+ """Test `broaden_hw` function with mock object."""
+
+ batch_size = 1
+ # NHWC format
+ x = tf.random_normal(shape=[batch_size, 4 * 4 * 16])
+ y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_last")
+ self.assertEqual(y.shape, (batch_size, 4, 4, 16))
+
+ # Default NCHW format
+ if tf.test.is_gpu_available():
+ y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_first")
+ self.assertEqual(y.shape, (batch_size, 16, 4, 4))
+
+
+if __name__ == "__main__":
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py
new file mode 100644
index 0000000000..561be36c91
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/sagan/sagan.py
@@ -0,0 +1,232 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Self-attention generative adversarial network with eager execution.
+
+Code for main model.
+
+Reference [Self-Attention Generative Adversarial
+Networks](https://arxiv.org/pdf/1805.08318.pdf)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.sagan import ops
+tfe = tf.contrib.eager
+
+
+class SelfAttentionModule(tf.keras.Model):
+ """Self-attention module composed of convolutional layers."""
+
+ def __init__(self,
+ attention_features,
+ original_features,
+ data_format="channels_first"):
+ """Initialize the module.
+
+ Args:
+ attention_features: Number of filters for the attention computation.
+ original_features: Number of filters of the original Tensor.
+ data_format: Either 'channels_first' or 'channels_last'
+ """
+ super(SelfAttentionModule, self).__init__()
+ self.data_format = data_format
+ # Matrix multiplication implemented as 2D Convolution
+ self.f = tf.keras.layers.Conv2D(
+ filters=attention_features,
+ kernel_size=1,
+ strides=(1, 1),
+ data_format=data_format)
+ self.g = tf.keras.layers.Conv2D(
+ filters=attention_features,
+ kernel_size=1,
+ strides=(1, 1),
+ data_format=data_format)
+ self.h = tf.keras.layers.Conv2D(
+ filters=original_features,
+ kernel_size=1,
+ strides=(1, 1),
+ data_format=data_format)
+ self.scale = tfe.Variable(0., trainable=True)
+
+ def call(self, x):
+ f = self.f(x)
+ g = self.g(x)
+ h = self.h(x)
+
+ f_flatten = ops.flatten_hw(f, data_format=self.data_format)
+ g_flatten = ops.flatten_hw(g, data_format=self.data_format)
+ h_flatten = ops.flatten_hw(h, data_format=self.data_format)
+
+ s = tf.matmul(g_flatten, f_flatten, transpose_b=True)
+ b = tf.nn.softmax(s, axis=-1)
+ o = tf.matmul(b, h_flatten)
+ y = self.scale * tf.reshape(o, tf.shape(x)) + x
+
+ return y
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+
+class SAGAN(tf.contrib.checkpoint.Checkpointable):
+ """Self-attention generative adversarial network."""
+
+ def __init__(self, config):
+ """Initialize the model.
+
+ Args:
+ config: tf.contrib.training.HParams object; specifies hyperparameters
+ """
+ super(SAGAN, self).__init__()
+ self.config = config
+ self.generator = self._construct_generator()
+ self.discriminator = self._construct_discriminator()
+
+ def _construct_generator(self):
+ """Construct generator."""
+ # TODO(lxuechen): Add spectral normalization for WGAN
+ axis = 1 if self.config.data_format == "channels_first" else 3
+
+ generator = tf.keras.Sequential()
+ generator.add(
+ tf.keras.layers.InputLayer(input_shape=(self.config.latent_dim,)))
+ generator.add(
+ tf.keras.layers.Dense(
+ units=np.prod(self.config.g_init_shape), activation=tf.nn.relu))
+
+ if self.config.data_format == "channels_first":
+ c, h, w = self.config.g_init_shape
+ else:
+ h, w, c = self.config.g_init_shape
+
+ # Reshape to NHWC/NCHW
+ generator.add(
+ ops.BroadenHW(h=h, w=w, c=c, data_format=self.config.data_format))
+
+ filters_list = [c // 2**p for p in range(1, self.config.num_upsamples + 1)]
+ filters_list[-1] = 3 # Standard RGB images
+
+ for filters in filters_list[:len(filters_list) // 2]:
+ generator.add(
+ tf.keras.layers.Conv2DTranspose(
+ filters=filters,
+ kernel_size=4,
+ strides=(2, 2),
+ use_bias=False,
+ padding="SAME",
+ data_format=self.config.data_format))
+ generator.add(tf.keras.layers.BatchNormalization(axis=axis))
+ generator.add(tf.keras.layers.Activation("relu"))
+
+ # pylint: disable=undefined-loop-variable
+ generator.add(
+ SelfAttentionModule(
+ original_features=filters,
+ attention_features=filters // 8,
+ data_format=self.config.data_format))
+ # pylint: enable=undefined-loop-variable
+
+ for filters in filters_list[len(filters_list) // 2:]:
+ generator.add(
+ tf.keras.layers.Conv2DTranspose(
+ filters=filters,
+ kernel_size=4,
+ strides=(2, 2),
+ use_bias=False,
+ padding="SAME",
+ data_format=self.config.data_format))
+ if filters == 3:
+ # Assume Image rescaled to [-1, 1]
+ generator.add(tf.keras.layers.Activation("tanh"))
+ else:
+ generator.add(tf.keras.layers.BatchNormalization(axis=axis))
+ generator.add(tf.keras.layers.Activation("relu"))
+
+ return generator
+
+ def _construct_discriminator(self):
+ """Construct discriminator."""
+ # TODO(lxuechen): Add spectral normalization for WGAN
+ discriminator = tf.keras.Sequential()
+ discriminator.add(
+ tf.keras.layers.InputLayer(input_shape=self.config.image_shape))
+
+ filters_list = [
+ self.config.d_init_filters * 2**p
+ for p in range(self.config.num_upsamples)
+ ]
+
+ for filters in filters_list[:(len(filters_list) + 1) // 2]:
+ discriminator.add(
+ tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=4,
+ strides=(2, 2),
+ padding="SAME",
+ data_format=self.config.data_format))
+ discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1))
+
+ # pylint: disable=undefined-loop-variable
+ discriminator.add(
+ SelfAttentionModule(
+ original_features=filters,
+ attention_features=filters // 8,
+ data_format=self.config.data_format))
+ # pylint: enable=undefined-loop-variable
+
+ for filters in filters_list[(len(filters_list) + 1) // 2:]:
+ discriminator.add(
+ tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=4,
+ strides=(2, 2),
+ padding="SAME",
+ data_format=self.config.data_format))
+ discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1))
+
+ discriminator.add(tf.keras.layers.Flatten())
+ discriminator.add(tf.keras.layers.Dense(units=1))
+
+ return discriminator
+
+ def compute_loss_and_grads(self, real_images, noise, training=True):
+ """Compute loss and gradients for both generator and discriminator."""
+ # TODO(lxuechen): Add gradient penalty for discriminator
+ with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
+ real_logits = self.discriminator(real_images, training=training)
+
+ fake_images = self.generator.call(noise, training=training)
+ fake_logits = self.discriminator.call(fake_images, training=training)
+
+ g_loss = self.compute_g_loss(fake_logits)
+ d_loss = self.compute_d_loss(fake_logits, real_logits)
+
+ g_grads = g_tape.gradient(g_loss, self.generator.trainable_variables)
+ d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_variables)
+
+ return g_loss, d_loss, g_grads, d_grads
+
+ def compute_g_loss(self, fake_logits):
+ return -tf.reduce_mean(fake_logits) # Hinge loss
+
+ def compute_d_loss(self, fake_logits, real_logits):
+ # Hinge loss
+ real_loss = tf.reduce_mean(tf.nn.relu(1. - real_logits))
+ fake_loss = tf.reduce_mean(tf.nn.relu(1. + fake_logits))
+ return real_loss + fake_loss
diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py
new file mode 100644
index 0000000000..1834594510
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for self-attention generative adversarial network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.sagan import config as config_
+from tensorflow.contrib.eager.python.examples.sagan import sagan
+tfe = tf.contrib.eager
+
+
+class SAGANTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(SAGANTest, self).setUp()
+ config = config_.get_hparams_mock()
+ self.noise_shape = (config.batch_size, config.latent_dim)
+ self.logits_shape = (config.batch_size, 1)
+ self.images_shape = (config.batch_size,) + config.image_shape
+
+ self.model = sagan.SAGAN(config=config)
+ self.noise = tf.random_normal(shape=self.noise_shape)
+ self.real_images = tf.random_normal(shape=self.images_shape)
+ self.config = config
+
+ def tearDown(self):
+ del self.model
+ del self.noise
+ del self.real_images
+ super(SAGANTest, self).tearDown()
+
+ def test_generator_call(self):
+ """Test `generator.__call__` function."""
+ fake_images = self.model.generator(self.noise, training=False)
+ self.assertEqual(fake_images.shape, self.images_shape)
+
+ def test_generator_call_defun(self):
+ """Test `generator.__call__` function with defun."""
+ call_ = tfe.defun(self.model.generator.__call__)
+ fake_images = call_(self.noise, training=False)
+ self.assertEqual(fake_images.shape, self.images_shape)
+
+ def test_discriminator_call(self):
+ """Test `discriminator.__call__` function."""
+ real_logits = self.model.discriminator(self.real_images)
+ self.assertEqual(real_logits.shape, self.logits_shape)
+
+ def test_discriminator_call_defun(self):
+ """Test `discriminator.__call__` function with defun."""
+ call_ = tfe.defun(self.model.discriminator.__call__)
+ real_logits = call_(self.real_images)
+ self.assertEqual(real_logits.shape, self.logits_shape)
+
+ def test_compute_loss_and_grads(self):
+ """Test `compute_loss_and_grads` function."""
+ g_loss, d_loss, g_grads, d_grads = self.model.compute_loss_and_grads(
+ self.real_images, self.noise, training=False)
+ self.assertEqual(g_loss.shape, ())
+ self.assertEqual(d_loss.shape, ())
+ self.assertTrue(isinstance(g_grads, list))
+ self.assertTrue(isinstance(d_grads, list))
+ g_vars = self.model.generator.trainable_variables
+ d_vars = self.model.discriminator.trainable_variables
+
+ self.assertEqual(len(g_grads), len(g_vars))
+ self.assertEqual(len(d_grads), len(d_vars))
+
+ def test_compute_loss_and_grads_defun(self):
+ """Test `compute_loss_and_grads` function with defun."""
+ compute_loss_and_grads = tfe.defun(self.model.compute_loss_and_grads)
+ g_loss, d_loss, g_grads, d_grads = compute_loss_and_grads(
+ self.real_images, self.noise, training=False)
+ self.assertEqual(g_loss.shape, ())
+ self.assertEqual(d_loss.shape, ())
+ self.assertTrue(isinstance(g_grads, list))
+ self.assertTrue(isinstance(d_grads, list))
+ g_vars = self.model.generator.trainable_variables
+ d_vars = self.model.discriminator.trainable_variables
+
+ self.assertEqual(len(g_grads), len(g_vars))
+ self.assertEqual(len(d_grads), len(d_vars))
+
+
+if __name__ == "__main__":
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index fee9db46fa..113aa7967c 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -68,6 +68,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@async_clear_error
@@run_test_in_graph_and_eager_modes
+@@run_all_tests_in_graph_and_eager_modes
@@DEVICE_PLACEMENT_EXPLICIT
@@DEVICE_PLACEMENT_WARN
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
index a955e21b72..4d62ac65ff 100644
--- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
@@ -21,8 +21,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -35,13 +33,6 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
-def NoMemoryOptimizationConfig():
- config = config_pb2.ConfigProto()
- config.graph_options.rewrite_options.memory_optimization = (
- rewriter_config_pb2.RewriterConfig.OFF)
- return config
-
-
def GetShrunkInceptionShapes(shrink=10):
"""Iterator for smaller versions of convolution shapes in 2015 Inception.
@@ -202,8 +193,7 @@ class FusedConv2DBiasActivationTest(test.TestCase):
# This is to guarantee that there is always negative values after
# bias add so that we can test whether relu works correctly.
x3 = bias
- # TODO(b/79323979): re-enable memory optimization after this bug is fixed.
- with self.test_session(use_gpu=True, config=NoMemoryOptimizationConfig()):
+ with self.test_session(use_gpu=True):
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype)
t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype)
fused_t2 = t2
@@ -251,9 +241,7 @@ class FusedConv2DBiasActivationTest(test.TestCase):
x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32)
def _SetupVal(data_format, use_gpu):
- # TODO(b/79323979): re-enable memory optimization after this bug is fixed.
- with self.test_session(
- use_gpu=use_gpu, config=NoMemoryOptimizationConfig()):
+ with self.test_session(use_gpu=use_gpu):
t1 = constant_op.constant(x1, shape=tensor_in_sizes)
t2 = constant_op.constant(x2, shape=filter_in_sizes)
t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]])
@@ -877,9 +865,7 @@ class FusedConvInt8Tests(test.TestCase):
conv_input_scale, conv_input, kernel, padding_type, strides,
side_input_scale, side_input, biases)
- # TODO(b/79323979): re-enable memory optimization after this bug is fixed.
- with self.test_session(
- use_gpu=True, config=NoMemoryOptimizationConfig()) as sess:
+ with self.test_session(use_gpu=True) as sess:
actual_y, expected_y = sess.run([actual, expected])
tf_logging.info("actual_y = ", actual_y)
tf_logging.info("expected_y = ", expected_y)
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index add36b46c0..99f81c4a8a 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -84,6 +84,38 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
&data->input_left_shift);
data->input_range_radius =
CalculateInputRadius(kInputIntegerBits, data->input_left_shift);
+ } else if (input->type == kTfLiteInt16) {
+ static constexpr int kInputIntegerBits = 3;
+ static constexpr int kOutputFractionalBits = 15;
+
+ // These operators are implemented in fixed-point arithmetic,
+ // which intrinsically wants symmetric ranges (zero_point==0)
+ // and power-of-two scales (power-of-two is abbreviated below as POT).
+ // While more general support would be possible by means of rescaling,
+ // that would add some overhead and some loss of accuracy and wouldn't
+ // be used at the moment as current quantized LSTM applications are
+ // happy with symmetric, power-of-two-scales quantization. So we just
+ // implement that narrow case only for now.
+
+ TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+
+ int input_scale_log2_rounded;
+ TF_LITE_ENSURE(context,
+ CheckedLog2(input->params.scale, &input_scale_log2_rounded));
+
+ int output_scale_log2_rounded;
+ TF_LITE_ENSURE(
+ context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
+ TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
+ -kOutputFractionalBits);
+
+ data->input_left_shift =
+ (15 - kInputIntegerBits) + input_scale_log2_rounded;
+ // Support for shifts is limited until we have a parameterized version of
+ // SaturatingRoundingMultiplyByPOT().
+ TF_LITE_ENSURE(context, data->input_left_shift >= 0);
+ TF_LITE_ENSURE(context, data->input_left_shift <= 1);
}
return context->ResizeTensor(context, output,
@@ -114,6 +146,30 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
&data->input_left_shift);
data->input_range_radius =
CalculateInputRadius(kInputIntegerBits, data->input_left_shift);
+ } else if (input->type == kTfLiteInt16) {
+ static constexpr int kInputIntegerBits = 3;
+ static constexpr int kOutputFractionalBits = 15;
+
+ // See comments in TanhPrepare about requiring zero_point==0
+ // and a power-of-two ("POT") scale.
+
+ TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+
+ int input_scale_log2_rounded;
+ TF_LITE_ENSURE(context,
+ CheckedLog2(input->params.scale, &input_scale_log2_rounded));
+
+ int output_scale_log2_rounded;
+ TF_LITE_ENSURE(
+ context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
+ TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
+ -kOutputFractionalBits);
+
+ data->input_left_shift =
+ (15 - kInputIntegerBits) + input_scale_log2_rounded;
+ // The int16 logistic implementation does not support shifting of the input.
+ TF_LITE_ENSURE_EQ(context, data->input_left_shift, 0);
}
return context->ResizeTensor(context, output,
@@ -250,12 +306,19 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
for (; in < in_end; in++, out++) *out = std::tanh(*in);
return kTfLiteOk;
} break;
+ case kTfLiteInt16: {
+ optimized_ops::Tanh(GetTensorData<int16_t>(input), GetTensorShape(input),
+ data->input_left_shift,
+ GetTensorData<int16_t>(output),
+ GetTensorShape(output));
+ return kTfLiteOk;
+ } break;
case kTfLiteUInt8: {
- optimized_ops::Tanh(GetTensorData<uint8_t>(input), GetTensorDims(input),
+ optimized_ops::Tanh(GetTensorData<uint8_t>(input), GetTensorShape(input),
input->params.zero_point, data->input_range_radius,
data->input_multiplier, data->input_left_shift,
GetTensorData<uint8_t>(output),
- GetTensorDims(output));
+ GetTensorShape(output));
return kTfLiteOk;
} break;
default:
@@ -280,12 +343,18 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
for (; in < in_end; in++, out++) *out = 1.f / (1.f + std::exp(-*in));
break;
}
+ case kTfLiteInt16: {
+ optimized_ops::Logistic(
+ GetTensorData<int16>(input), GetTensorShape(input),
+ GetTensorData<int16_t>(output), GetTensorShape(output));
+ break;
+ }
case kTfLiteUInt8: {
optimized_ops::Logistic(
- GetTensorData<uint8_t>(input), GetTensorDims(input),
+ GetTensorData<uint8_t>(input), GetTensorShape(input),
input->params.zero_point, data->input_range_radius,
data->input_multiplier, data->input_left_shift,
- GetTensorData<uint8_t>(output), GetTensorDims(output));
+ GetTensorData<uint8_t>(output), GetTensorShape(output));
break;
}
default:
@@ -341,26 +410,26 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1];
optimized_ops::Softmax(GetTensorData<uint8_t>(input),
- GetTensorDims({batch_size, 1, 1, input_size}),
+ GetTensorShape({batch_size, 1, 1, input_size}),
data->input_multiplier, data->input_left_shift,
data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorDims({batch_size, 1, 1, input_size}));
+ GetTensorShape({batch_size, 1, 1, input_size}));
}
// Takes a 4D tensor and perform softmax along the forth dimension.
void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params) {
- optimized_ops::Softmax(GetTensorData<float>(input), GetTensorDims(input),
+ optimized_ops::Softmax(GetTensorData<float>(input), GetTensorShape(input),
params->beta, GetTensorData<float>(output),
- GetTensorDims(output));
+ GetTensorShape(output));
}
void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
- optimized_ops::Softmax(GetTensorData<uint8_t>(input), GetTensorDims(input),
+ optimized_ops::Softmax(GetTensorData<uint8_t>(input), GetTensorShape(input),
data->input_multiplier, data->input_left_shift,
data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorDims(output));
+ GetTensorShape(output));
}
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
@@ -415,8 +484,8 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
switch (input->type) {
case kTfLiteFloat32:
optimized_ops::LogSoftmax(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(output), GetTensorDims(output));
+ GetTensorData<float>(input), GetTensorShape(input),
+ GetTensorData<float>(output), GetTensorShape(output));
return kTfLiteOk;
default:
context->ReportError(context, "Only float32 supported currently., got %d",
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index 50a84edd47..587e1303da 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -75,23 +75,42 @@ class FloatActivationsOpModel : public BaseActivationsOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
-// TODO(ahentz): I don't quite understand the tradeoffs in the quantized
-// implementation of sigmoid and software, but a tolerance of twice the output
-// scale seems reasonable. We might want to change this if we have a better
-// theoretical bound.
+// Our fixed-point math function implementations have roughly 12 bits of
+// accuracy, when specialized to 16-bit fixed-point arithmetic.
+// That is purely an implementation compromise, it would have been possible
+// to get closer to 16 bits of accuracy but that would be more expensive,
+// and not needed for our purposes as ultimately the output is either
+// immediately down-quantized to 8 bits, or will typically be at the output
+// of the surrounding LSTM cell.
+// So we can require roughly 2^-12 accuracy when the output is 16-bit, and
+// we can more or less expect the full 2^-8 accuracy when the output is 8-bit.
+//
+// However, the representable output interval is often [-1, 1] (it has to be
+// for tanh, and even for logistic, when we implement it in fixed-point, we
+// typically have to do so on such a symmetric interval, e.g. ARM NEON only
+// has signed fixed-point arithmetic (SQRDMULH)). As the width of [-1, 1]
+// is 2, our representable values are often diluted by a factor of 2, whence
+// the factor of 2 below.
const float kQuantizedTolerance = 2 * (1. / 256);
+const float kQuantizedToleranceInt16 = 2 * (1. / 4096);
class QuantizedActivationsOpModel : public BaseActivationsOpModel {
public:
using BaseActivationsOpModel::BaseActivationsOpModel;
+ template <typename T>
void SetInput(std::initializer_list<float> data) {
- QuantizeAndPopulate<uint8_t>(input_, data);
+ QuantizeAndPopulate<T>(input_, data);
}
- std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ template <typename T>
+
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+ template <typename T>
std::vector<float> GetDequantizedOutput() {
- return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
- GetScale(output_), GetZeroPoint(output_));
+ return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
+ GetZeroPoint(output_));
}
};
@@ -152,24 +171,47 @@ TEST(FloatActivationsOpTest, Tanh) {
}
TEST(QuantizedActivationsOpTest, Tanh) {
+ const float kMin = -1;
+ const float kMax = 127.f / 128.f;
QuantizedActivationsOpModel m(
BuiltinOperator_TANH,
- /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -8, 8},
- /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, -1, 1});
- m.SetInput({
+ /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+ /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, kMin, kMax});
+ m.SetInput<uint8_t>({
0, -6, 2, 4, //
-4, -2, 8, 1, //
});
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(),
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{
0.0, -0.999987, 0.964027, 0.999329, //
- -0.996078, -0.96402, 0.99999, 0.76159, //
+ -0.999329, -0.96402, 0.99999, 0.76159, //
},
- 4 * (1. / 256))));
- EXPECT_THAT(m.GetOutput(),
- ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 226}));
+ kQuantizedTolerance)));
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 225}));
+}
+
+TEST(QuantizedActivationsOpTest, TanhInt16) {
+ const float kMin = -1;
+ const float kMax = 32767.f / 32768.f;
+ QuantizedActivationsOpModel m(
+ BuiltinOperator_TANH,
+ /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+ /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax});
+ m.SetInput<int16_t>({
+ 0, -6, 2, 4, //
+ -4, -2, 8, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.0, -0.999987, 0.964027, 0.999329, //
+ -0.999329, -0.96402, 0.99999, 0.76159, //
+ },
+ kQuantizedToleranceInt16)));
}
TEST(FloatActivationsOpTest, Sigmoid) {
@@ -190,22 +232,43 @@ TEST(QuantizedActivationsOpTest, Sigmoid) {
QuantizedActivationsOpModel m(
BuiltinOperator_LOGISTIC,
/*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10});
- m.SetInput({
+ m.SetInput<uint8_t>({
0, -6, 2, 4, //
3, -2, 10, 1, //
});
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(),
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{
0.5, 0.002473, 0.880797, 0.982014, //
0.952574, 0.119203, 0.999955, 0.731059, //
},
kQuantizedTolerance)));
- EXPECT_THAT(m.GetOutput(),
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188}));
}
+TEST(QuantizedActivationsOpTest, SigmoidInt16) {
+ const float kMin = -1;
+ const float kMax = 32767.f / 32768.f;
+ QuantizedActivationsOpModel m(
+ BuiltinOperator_LOGISTIC,
+ /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+ /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax});
+ m.SetInput<int16_t>({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.5, 0.002473, 0.880797, 0.982014, //
+ 0.952574, 0.119203, 0.999955, 0.731059, //
+ },
+ kQuantizedToleranceInt16)));
+}
+
TEST(FloatActivationsOpTest, Softmax4D) {
FloatActivationsOpModel m(0.1,
/*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}});
@@ -241,12 +304,12 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
QuantizedActivationsOpModel m(
0.1,
/*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10});
- m.SetInput({
+ m.SetInput<uint8_t>({
0, -6, 2, 4, // depth = 0
3, -2, 10, 1, // depth = 1
});
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(),
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{
.23463, .12877, .28658, .35003, //
@@ -258,21 +321,22 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
QuantizedActivationsOpModel m2(
0.1,
/*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10});
- m2.SetInput({
+ m2.SetInput<uint8_t>({
0, -6, //
2, 4, //
3, -2, //
10, 1, //
});
m2.Invoke();
- EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
- {
- 0.645656, 0.354344, //
- 0.450166, 0.549834, //
- 0.622459, 0.377541, //
- 0.710949, 0.28905, //
- },
- kQuantizedTolerance)));
+ EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
}
TEST(FloatActivationsOpTest, Softmax2D) {
@@ -309,12 +373,12 @@ TEST(FloatActivationsOpTest, Softmax2D) {
TEST(QuantizedActivationsOpTest, Softmax2D) {
QuantizedActivationsOpModel m(0.1,
/*input=*/{TensorType_UINT8, {2, 4}, -10, 10});
- m.SetInput({
+ m.SetInput<uint8_t>({
0, -6, 2, 4, //
3, -2, 10, 1, //
});
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(),
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{
.23463, .12877, .28658, .35003, //
@@ -325,21 +389,22 @@ TEST(QuantizedActivationsOpTest, Softmax2D) {
// Same input, but a different shape.
QuantizedActivationsOpModel m2(0.1,
/*input=*/{TensorType_UINT8, {4, 2}, -10, 10});
- m2.SetInput({
+ m2.SetInput<uint8_t>({
0, -6, //
2, 4, //
3, -2, //
10, 1, //
});
m2.Invoke();
- EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
- {
- 0.645656, 0.354344, //
- 0.450166, 0.549834, //
- 0.622459, 0.377541, //
- 0.710949, 0.28905, //
- },
- kQuantizedTolerance)));
+ EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
}
// This contains the same test values as the Softmax test, but reference answer
diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
index e786f785ab..d2f1103e14 100644
--- a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
@@ -32,19 +32,21 @@ namespace tflite {
namespace {
void RunLogSoftmaxFloatReference(const uint8* input_data,
- const Dims<4>& dims_common, int32 input_offset,
- const double input_scale, int stride,
- float beta, uint8* reference_output_data) {
- const int ref_buffer_size = RequiredBufferSizeForDims(dims_common);
+ const RuntimeShape& shape_common,
+ int32 input_offset, const double input_scale,
+ int stride, float beta,
+ uint8* reference_output_data) {
+ const int ref_buffer_size = shape_common.FlatSize();
std::vector<float> reference_dequant_data(ref_buffer_size);
std::vector<float> reference_output_float_data(ref_buffer_size);
// Reference data generated via Dequant of input into float, and then applying
// float LogSoftmax.
- reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale,
- reference_dequant_data.data(), dims_common);
- optimized_ops::LogSoftmax(reference_dequant_data.data(), dims_common,
- reference_output_float_data.data(), dims_common);
+ reference_ops::Dequantize(
+ input_data, ToRuntimeDims(shape_common), input_offset, input_scale,
+ reference_dequant_data.data(), ToRuntimeDims(shape_common));
+ optimized_ops::LogSoftmax(reference_dequant_data.data(), shape_common,
+ reference_output_float_data.data(), shape_common);
// Work with quantized scaling for LogSoftmax, under which 255 represents 0,
// and -16 gets nudged up to 0.
for (int i = 0; i < ref_buffer_size; i++) {
@@ -55,9 +57,9 @@ void RunLogSoftmaxFloatReference(const uint8* input_data,
}
void CheckOutputData(const uint8* test_output, const uint8* reference_output,
- const Dims<4>& dims_common, const string& check_label,
- bool be_exacting) {
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ const RuntimeShape& shape_common,
+ const string& check_label, bool be_exacting) {
+ const int buffer_size = shape_common.FlatSize();
// While calculating some metrics in floating point, we work with quantized
// scaling.
std::vector<int> diff(buffer_size);
@@ -99,15 +101,15 @@ void CheckOutputData(const uint8* test_output, const uint8* reference_output,
// Runs the LogSoftmax and compares against the float reference implementation
// and the quantized reference implementation.
-void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common,
- int32 input_offset, const double input_scale,
- int stride, float beta) {
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+void RunOneLogSoftmaxTest(const uint8* input_data,
+ const RuntimeShape& shape_common, int32 input_offset,
+ const double input_scale, int stride, float beta) {
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> optimized_logsoftmax_output(buffer_size);
std::vector<uint8> reference_float_logsoftmax_output(buffer_size);
std::vector<uint8> reference_quant_logsoftmax_output(buffer_size);
- RunLogSoftmaxFloatReference(input_data, dims_common, input_offset,
+ RunLogSoftmaxFloatReference(input_data, shape_common, input_offset,
input_scale, stride, beta,
reference_float_logsoftmax_output.data());
@@ -126,23 +128,23 @@ void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common,
const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
input_beta_left_shift);
- optimized_ops::LogSoftmax(input_data, dims_common, input_beta_multiplier,
+ optimized_ops::LogSoftmax(input_data, shape_common, input_beta_multiplier,
input_beta_left_shift, reverse_scaling_divisor,
reverse_scaling_right_shift, diff_min,
- optimized_logsoftmax_output.data(), dims_common);
+ optimized_logsoftmax_output.data(), shape_common);
reference_ops::LogSoftmax(
- input_data, dims_common, input_beta_multiplier, input_beta_left_shift,
+ input_data, shape_common, input_beta_multiplier, input_beta_left_shift,
reverse_scaling_divisor, reverse_scaling_right_shift, diff_min,
- reference_quant_logsoftmax_output.data(), dims_common);
+ reference_quant_logsoftmax_output.data(), shape_common);
CheckOutputData(optimized_logsoftmax_output.data(),
- reference_float_logsoftmax_output.data(), dims_common,
+ reference_float_logsoftmax_output.data(), shape_common,
"Optimized vs float reference", false);
CheckOutputData(optimized_logsoftmax_output.data(),
- reference_quant_logsoftmax_output.data(), dims_common,
+ reference_quant_logsoftmax_output.data(), shape_common,
"Optimized vs quant reference", true);
CheckOutputData(reference_quant_logsoftmax_output.data(),
- reference_float_logsoftmax_output.data(), dims_common,
+ reference_float_logsoftmax_output.data(), shape_common,
"Quant reference vs float reference", false);
}
@@ -165,13 +167,13 @@ bool TryOneUniformLogSoftmax() {
const int32 input_offset = UniformRandomInt(-256, 0);
static constexpr float beta = 1.0f;
- Dims<4> dims_common =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> input_data(buffer_size);
FillRandom(&input_data);
- RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset,
+ RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset,
input_scale, stride, beta);
return true;
}
@@ -203,14 +205,14 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) {
const int middle_min = UniformRandomInt(0, 255);
const int sides_max = UniformRandomInt(0, middle_min);
- Dims<4> dims_common =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> input_data(buffer_size);
FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min,
sides_max);
- RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset,
+ RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset,
input_scale, stride, beta);
return true;
}
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
index c0dda4acf1..7816752132 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -26,6 +26,10 @@ limitations under the License.
namespace tflite {
namespace optimized_ops {
+// Unoptimized reference ops:
+using reference_ops::Relu1;
+using reference_ops::Relu6;
+
inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
return RuntimeShape(
{dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
@@ -34,15 +38,285 @@ inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- return L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
}
inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
int32 input_zero_point, uint8* output_data,
const Dims<4>& output_dims) {
- return L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
- output_data, DimsToShape(output_dims));
+ L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
+ output_data, DimsToShape(output_dims));
+}
+
+inline void Relu(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Relu(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height,
+ pad_width, pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, DimsToShape(output_dims));
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height,
+ pad_width, pad_height, filter_width, filter_height,
+ output_activation_min, output_activation_max, output_data,
+ DimsToShape(output_dims));
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height,
+ pad_width, pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, DimsToShape(output_dims));
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height,
+ pad_width, pad_height, filter_width, filter_height,
+ output_activation_min, output_activation_max, output_data,
+ DimsToShape(output_dims));
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Pool(input_data, DimsToShape(input_dims), stride_width, stride_height,
+ pad_width, pad_height, filter_width, filter_height,
+ output_activation_min, output_activation_max, output_data,
+ DimsToShape(output_dims));
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+ float beta, float* output_data,
+ const Dims<4>& output_dims) {
+ Softmax(input_data, DimsToShape(input_dims), beta, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const Dims<4>& output_dims) {
+ Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
+ input_beta_left_shift, diff_min, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ LogSoftmax(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const Dims<4>& output_dims) {
+ LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier,
+ input_left_shift, reverse_scaling_divisor,
+ reverse_scaling_right_shift, diff_min, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), input_zero_point,
+ input_range_radius, input_multiplier, input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), input_zero_point,
+ input_range_radius, input_multiplier, input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
+ int input_left_shift, int16* output_data,
+ const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
+ DimsToShape(output_dims));
}
} // namespace optimized_ops
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 107e95ea6e..868269477e 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -85,6 +85,12 @@ using VectorMap = typename std::conditional<
Eigen::Dynamic, 1>>,
Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, 1>>>::type;
+template <typename Scalar>
+VectorMap<Scalar> MapAsVector(Scalar* data, const RuntimeShape& shape) {
+ const int size = shape.FlatSize();
+ return VectorMap<Scalar>(data, size, 1);
+}
+
template <typename Scalar, int N>
VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
const int size = FlatSize(dims);
@@ -101,6 +107,23 @@ using MatrixMap = typename std::conditional<
Eigen::Dynamic, Eigen::Dynamic>>,
Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
+template <typename Scalar>
+MatrixMap<Scalar> MapAsMatrixWithLastDimAsRows(Scalar* data,
+ const RuntimeShape& shape) {
+ const int dims_count = shape.DimensionsCount();
+ const int rows = shape.Dims(dims_count - 1);
+ const int cols = FlatSizeSkipDim(shape, dims_count - 1);
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar>
+MatrixMap<Scalar> MapAsMatrixWithFirstDimAsCols(Scalar* data,
+ const RuntimeShape& shape) {
+ const int cols = shape.Dims(0);
+ const int rows = FlatSizeSkipDim(shape, 0);
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
const Dims<N>& dims) {
@@ -2343,12 +2366,12 @@ void GlobalBatchNormalization(const float* input_data,
}
}
-inline void Relu(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void Relu(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Relu (not fused)");
- const auto input = MapAsVector(input_data, input_dims);
- auto output = MapAsVector(output_data, output_dims);
+ const auto input = MapAsVector(input_data, input_shape);
+ auto output = MapAsVector(output_data, output_shape);
output = input.cwiseMax(0.0f);
}
@@ -3739,23 +3762,25 @@ inline int NodeOffset(int b, int h, int w, int height, int width) {
return (b * height + h) * width + w;
}
-inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int kwidth, int kheight,
- float output_activation_min,
+inline void AveragePool(const float* input_data,
+ const RuntimeShape& input_shape, int stride_width,
+ int stride_height, int pad_width, int pad_height,
+ int kwidth, int kheight, float output_activation_min,
float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("AveragePool");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
// TODO(benoitjacob) make this a proper reference impl without Eigen!
- const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
// TODO(benoitjacob) get rid of the dynamic memory allocation here!
Eigen::VectorXf out_count(out_mat.cols());
out_count.setZero();
@@ -3793,9 +3818,9 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
for (int y = 0; y < output_height; ++y) {
for (int x = 0; x < output_width; ++x) {
for (int c = 0; c < depth; ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ output_data[Offset(output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- output_data[Offset(output_dims, c, x, y, b)],
+ output_data[Offset(output_shape, b, y, x, c)],
output_activation_min, output_activation_max);
}
}
@@ -3803,44 +3828,23 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int kwidth, int kheight, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, kwidth, kheight, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width,
- int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
+inline void AveragePool(const uint8* input_data,
+ const RuntimeShape& input_shape, int stride_width,
+ int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height,
int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("AveragePool/8bit");
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -3860,11 +3864,12 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
uint16 acc[kAccBufferMaxSize];
memset(acc, 0, depth * sizeof(acc[0]));
const uint8* input_ptr =
- input_data + input_dims.strides[1] * in_x_origin +
- input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch;
+ input_data +
+ depth * (in_x_origin +
+ input_width * (in_y_origin + input_height * batch));
for (int fy = filter_y_start; fy < filter_y_end; fy++) {
- const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] +
- filter_x_start * input_dims.strides[1];
+ const uint8* input_row_ptr =
+ input_ptr + depth * (fy * input_width + filter_x_start);
for (int fx = filter_x_start; fx < filter_x_end; fx++) {
int channel = 0;
#ifdef USE_NEON
@@ -3895,7 +3900,7 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
}
}
uint8* output_ptr =
- output_data + Offset(output_dims, 0, out_x, out_y, batch);
+ output_data + Offset(output_shape, batch, out_y, out_x, 0);
int channel = 0;
#ifdef USE_NEON
#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
@@ -3936,54 +3941,23 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width,
- int filter_height, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
+inline void MaxPool(const float* input_data, const RuntimeShape& input_shape,
int stride_width, int stride_height, int pad_width,
int pad_height, int kwidth, int kheight,
float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("MaxPool");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-
- const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+
+ const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
// Prefill the output to minimum representable float value
out_mat.setConstant(std::numeric_limits<float>::lowest());
for (int b = 0; b < batches; ++b) {
@@ -4016,9 +3990,9 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
for (int y = 0; y < output_height; ++y) {
for (int x = 0; x < output_width; ++x) {
for (int c = 0; c < depth; ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ output_data[Offset(output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- output_data[Offset(output_dims, c, x, y, b)],
+ output_data[Offset(output_shape, b, y, x, c)],
output_activation_min, output_activation_max);
}
}
@@ -4026,41 +4000,21 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int kwidth, int kheight, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, kwidth, kheight, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape,
int stride_width, int stride_height, int pad_width,
int pad_height, int filter_width, int filter_height,
int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
+ uint8* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("MaxPool/8bit");
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -4078,11 +4032,12 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
uint8 acc[kAccBufferMaxSize];
memset(acc, 0, depth * sizeof(acc[0]));
const uint8* input_ptr =
- input_data + input_dims.strides[1] * in_x_origin +
- input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch;
+ input_data +
+ depth * (in_x_origin +
+ input_width * (in_y_origin + input_height * batch));
for (int fy = filter_y_start; fy < filter_y_end; fy++) {
- const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] +
- filter_x_start * input_dims.strides[1];
+ const uint8* input_row_ptr =
+ input_ptr + depth * (fy * input_width + filter_x_start);
for (int fx = filter_x_start; fx < filter_x_end; fx++) {
int channel = 0;
#ifdef USE_NEON
@@ -4108,7 +4063,7 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
}
}
uint8* output_ptr =
- output_data + Offset(output_dims, 0, out_x, out_y, batch);
+ output_data + Offset(output_shape, batch, out_y, out_x, 0);
int channel = 0;
#ifdef USE_NEON
for (; channel <= depth - 16; channel += 16) {
@@ -4135,53 +4090,23 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
+inline void L2Pool(const float* input_data, const RuntimeShape& input_shape,
int stride_width, int stride_height, int pad_width,
int pad_height, int filter_width, int filter_height,
float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("L2Pool");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
// Actually carry out L2 Pool. Code is written in forward mode: we go through
// the input values once, and write to all the pooled regions that it maps to.
- const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Eigen::VectorXf in_square(in_mat.rows());
Eigen::VectorXf out_count(out_mat.cols());
out_count.setZero();
@@ -4223,28 +4148,6 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
(out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void L2Pool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
inline void LocalResponseNormalization(const float* input_data,
const Dims<4>& input_dims, int range,
float bias, float alpha, float beta,
@@ -4290,14 +4193,14 @@ inline void LocalResponseNormalization(const float* input_data,
}
}
-inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
float beta, float* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Softmax");
- MatchingFlatSize(input_dims, output_dims);
+ MatchingFlatSize(input_shape, output_shape);
- const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
// Compute the exponential first, removing the max coefficient for numerical
// stability.
out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta;
@@ -4309,10 +4212,10 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims,
out_mat.array().rowwise() *= scale;
}
-inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_beta_multiplier, int32 input_beta_left_shift,
int diff_min, uint8* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -4326,8 +4229,11 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int b = 0; b < outer_size; ++b) {
const uint8* input_data_ptr = input_data + b * depth;
@@ -4517,11 +4423,14 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
// TODO(myenik): This is the same as the reference implementation, not actually
// optimized yet.
-inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("LogSoftmax");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
const float* block_input_data = input_data + i * depth;
@@ -4662,11 +4571,11 @@ log_x_for_x_greater_than_or_equal_to_1(
}
// Currently just a copy of the reference code.
-inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_multiplier, int32 input_left_shift,
int32 reverse_scaling_divisor,
int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const Dims<4>& output_dims) {
+ uint8* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8");
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
@@ -4681,8 +4590,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
const uint8* block_input_data = input_data + i * depth;
@@ -4746,21 +4658,21 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void Logistic(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Logistic");
- auto input_map = MapAsVector(input_data, input_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input_map = MapAsVector(input_data, input_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
output_map.array() =
input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
}
-inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_zero_point, int32 input_range_radius,
int32 input_multiplier, int input_left_shift,
- uint8* output_data, const Dims<4>& output_dims) {
+ uint8* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Logistic/Uint8");
- const int size = MatchingFlatSize(input_dims, output_dims);
+ const int size = MatchingFlatSize(input_shape, output_shape);
int c = 0;
#ifdef USE_NEON
@@ -4892,10 +4804,10 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
- int16* output_data, const Dims<4>& output_dims) {
+inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
+ int16* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Logistic/Int16");
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
}
@@ -4952,21 +4864,21 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void Tanh(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Tanh");
- auto input_map = MapAsVector(input_data, input_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input_map = MapAsVector(input_data, input_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = input_map.array().tanh();
}
-inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_zero_point, int32 input_range_radius,
int32 input_multiplier, int input_left_shift,
- uint8* output_data, const Dims<4>& output_dims) {
+ uint8* output_data, const RuntimeShape& output_shape) {
// Note that this is almost the exact same code as in Logistic().
gemmlowp::ScopedProfilingLabel label("Tanh");
- const int size = MatchingFlatSize(input_dims, output_dims);
+ const int size = MatchingFlatSize(input_shape, output_shape);
int c = 0;
int32_t output_zero_point = 128;
@@ -5107,16 +5019,16 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
int input_left_shift, int16* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Tanh/Int16");
// Support for shifts is limited until we have a parameterized version of
// SaturatingRoundingMultiplyByPOT().
TFLITE_DCHECK_GE(input_left_shift, 0);
TFLITE_DCHECK_LE(input_left_shift, 1);
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
int c = 0;
const int16* input_data_ptr = input_data;
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index 6f5f6a3e6f..878b2441b4 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -34,15 +34,297 @@ inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
- return L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
- DimsToShape(output_dims));
+ L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
}
inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
int32 input_zero_point, uint8* output_data,
const Dims<4>& output_dims) {
- return L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
- output_data, DimsToShape(output_dims));
+ L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
+ output_data, DimsToShape(output_dims));
+}
+
+inline void Relu(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Relu(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Relu1(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Relu1(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Relu6(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Relu6(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height,
+ pad_width, pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, DimsToShape(output_dims));
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height,
+ pad_width, pad_height, filter_width, filter_height,
+ output_activation_min, output_activation_max, output_data,
+ DimsToShape(output_dims));
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height,
+ pad_width, pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, DimsToShape(output_dims));
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height,
+ pad_width, pad_height, filter_width, filter_height,
+ output_activation_min, output_activation_max, output_data,
+ DimsToShape(output_dims));
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Pool(input_data, DimsToShape(input_dims), stride_width, stride_height,
+ pad_width, pad_height, filter_width, filter_height,
+ output_activation_min, output_activation_max, output_data,
+ DimsToShape(output_dims));
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+ float beta, float* output_data,
+ const Dims<4>& output_dims) {
+ Softmax(input_data, DimsToShape(input_dims), beta, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const Dims<4>& output_dims) {
+ Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
+ input_beta_left_shift, diff_min, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ LogSoftmax(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const Dims<4>& output_dims) {
+ LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier,
+ input_left_shift, reverse_scaling_divisor,
+ reverse_scaling_right_shift, diff_min, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), input_zero_point,
+ input_range_radius, input_multiplier, input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), input_zero_point,
+ input_range_radius, input_multiplier, input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
+ int input_left_shift, int16* output_data,
+ const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
+ DimsToShape(output_dims));
}
} // namespace reference_ops
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 483bd37ef9..89ec0eb266 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -914,9 +914,9 @@ void GlobalBatchNormalization(const float* input_data,
}
}
-inline void Relu(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input_dims, output_dims);
+inline void Relu(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
const float lower = 0;
@@ -925,9 +925,10 @@ inline void Relu(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Relu1(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input_dims, output_dims);
+inline void Relu1(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)");
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
const float upper = 1;
@@ -937,9 +938,10 @@ inline void Relu1(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Relu6(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input_dims, output_dims);
+inline void Relu6(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)");
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
const float upper = 6;
@@ -2257,18 +2259,21 @@ inline int NodeOffset(int b, int h, int w, int height, int width) {
return (b * height + h) * width + w;
}
-inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
+inline void AveragePool(const float* input_data,
+ const RuntimeShape& input_shape, int stride_width,
+ int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height,
float output_activation_min,
float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ const RuntimeShape& output_shape) {
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -2292,12 +2297,12 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
const int in_x = in_x_origin + filter_x;
const int in_y = in_y_origin + filter_y;
total +=
- input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)];
filter_count++;
}
}
const float average = total / filter_count;
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
ActivationFunctionWithMinMax(average, output_activation_min,
output_activation_max);
}
@@ -2306,42 +2311,22 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width,
- int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
+inline void AveragePool(const uint8* input_data,
+ const RuntimeShape& input_shape, int stride_width,
+ int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height,
int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -2364,14 +2349,15 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
++filter_x) {
const int in_x = in_x_origin + filter_x;
const int in_y = in_y_origin + filter_y;
- acc += input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ acc +=
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)];
filter_count++;
}
}
acc = (acc + filter_count / 2) / filter_count;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
static_cast<uint8>(acc);
}
}
@@ -2379,50 +2365,19 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width,
- int filter_height, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
+inline void L2Pool(const float* input_data, const RuntimeShape& input_shape,
int stride_width, int stride_height, int pad_width,
int pad_height, int filter_width, int filter_height,
float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ float* output_data, const RuntimeShape& output_shape) {
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -2446,13 +2401,13 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
const int in_x = in_x_origin + filter_x;
const int in_y = in_y_origin + filter_y;
const float val =
- input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)];
sum_squares += val * val;
filter_count++;
}
}
const float l2pool_result = std::sqrt(sum_squares / filter_count);
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
ActivationFunctionWithMinMax(l2pool_result, output_activation_min,
output_activation_max);
}
@@ -2461,40 +2416,19 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void L2Pool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
+inline void MaxPool(const float* input_data, const RuntimeShape& input_shape,
int stride_width, int stride_height, int pad_width,
int pad_height, int filter_width, int filter_height,
float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ float* output_data, const RuntimeShape& output_shape) {
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -2518,10 +2452,10 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
const int in_y = in_y_origin + filter_y;
max = std::max(
max,
- input_data[Offset(input_dims, channel, in_x, in_y, batch)]);
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)]);
}
}
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
ActivationFunctionWithMinMax(max, output_activation_min,
output_activation_max);
}
@@ -2530,42 +2464,22 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+inline void MaxPool(const uint8* input_data, const RuntimeShape& input_shape,
int stride_width, int stride_height, int pad_width,
int pad_height, int filter_width, int filter_height,
int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
+ uint8* output_data, const RuntimeShape& output_shape) {
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
TFLITE_DCHECK_GE(output_activation_min, 0);
TFLITE_DCHECK_LE(output_activation_max, 255);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -2589,12 +2503,12 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
const int in_y = in_y_origin + filter_y;
max = std::max(
max,
- input_data[Offset(input_dims, channel, in_x, in_y, batch)]);
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)]);
}
}
max = std::max<uint8>(max, output_activation_min);
max = std::min<uint8>(max, output_activation_max);
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
static_cast<uint8>(max);
}
}
@@ -2602,38 +2516,6 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
inline void LocalResponseNormalization(const float* input_data,
const Dims<4>& input_dims, int range,
float bias, float alpha, float beta,
@@ -2657,11 +2539,14 @@ inline void LocalResponseNormalization(const float* input_data,
}
}
-inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
float beta, float* output_data,
- const Dims<4>& output_dims) {
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const RuntimeShape& output_shape) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
// Find max element value which we'll use to ensure numerical stability
@@ -2686,10 +2571,10 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_beta_multiplier, int32 input_beta_left_shift,
int diff_min, uint8* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -2702,8 +2587,11 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
uint8 max_in_row = 0;
@@ -2764,10 +2652,13 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
// Find max element value which we'll use to ensure numerical stability
@@ -2907,11 +2798,11 @@ log_x_for_x_greater_than_or_equal_to_1(
input_val);
}
-inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_multiplier, int32 input_left_shift,
int32 reverse_scaling_divisor,
int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const Dims<4>& output_dims) {
+ uint8* output_data, const RuntimeShape& output_shape) {
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -2925,8 +2816,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
uint8 max_in_row = 0;
@@ -2990,9 +2884,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Logistic(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
float val = input_data[i];
@@ -3001,11 +2895,11 @@ inline void Logistic(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_zero_point, int32 input_range_radius,
int32 input_multiplier, int input_left_shift,
- uint8* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ uint8* output_data, const RuntimeShape& output_shape) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
const uint8 input_val_u8 = input_data[i];
@@ -3039,9 +2933,9 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
- int16* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
+ int16* output_data, const RuntimeShape& output_shape) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -3057,9 +2951,9 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Tanh(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
float val = input_data[i];
@@ -3068,12 +2962,12 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_zero_point, int32 input_range_radius,
int32 input_multiplier, int input_left_shift,
- uint8* output_data, const Dims<4>& output_dims) {
+ uint8* output_data, const RuntimeShape& output_shape) {
const int32 output_zero_point = 128;
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
const uint8 input_val_u8 = input_data[i];
@@ -3108,15 +3002,15 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
int input_left_shift, int16* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
// Support for shifts is limited until we have a parameterized version of
// SaturatingRoundingMultiplyByPOT().
TFLITE_DCHECK_GE(input_left_shift, 0);
TFLITE_DCHECK_LE(input_left_shift, 1);
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
// F0 uses 0 integer bits, range [-1, 1].
// This is the return type of math functions such as tanh, logistic,
diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
index d781a7b642..a7dad3c14e 100644
--- a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
@@ -32,19 +32,21 @@ namespace tflite {
namespace {
void RunSoftmaxFloatReference(const uint8* input_data,
- const Dims<4>& dims_common, int32 input_offset,
- const double input_scale, int stride, float beta,
+ const RuntimeShape& shape_common,
+ int32 input_offset, const double input_scale,
+ int stride, float beta,
uint8* reference_output_data) {
- const int ref_buffer_size = RequiredBufferSizeForDims(dims_common);
+ const int ref_buffer_size = shape_common.FlatSize();
std::vector<float> reference_dequant_data(ref_buffer_size);
std::vector<float> reference_output_float_data(ref_buffer_size);
// Reference data generated via Dequant of input into float, and then applying
// float Softmax.
- reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale,
- reference_dequant_data.data(), dims_common);
- optimized_ops::Softmax(reference_dequant_data.data(), dims_common, beta,
- reference_output_float_data.data(), dims_common);
+ reference_ops::Dequantize(
+ input_data, ToRuntimeDims(shape_common), input_offset, input_scale,
+ reference_dequant_data.data(), ToRuntimeDims(shape_common));
+ optimized_ops::Softmax(reference_dequant_data.data(), shape_common, beta,
+ reference_output_float_data.data(), shape_common);
// Work with quantized scaling for Softmax, under which 256 represents 1, but
// we limit this to 255.
for (int i = 0; i < ref_buffer_size; i++) {
@@ -55,9 +57,9 @@ void RunSoftmaxFloatReference(const uint8* input_data,
}
void CheckOutputData(const uint8* test_output, const uint8* reference_output,
- const Dims<4>& dims_common, const string& check_label,
- bool be_exacting) {
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ const RuntimeShape& shape_common,
+ const string& check_label, bool be_exacting) {
+ const int buffer_size = shape_common.FlatSize();
// While calculating some metrics in floating point, we work with quantized
// scaling.
std::vector<int> diff(buffer_size);
@@ -91,15 +93,15 @@ void CheckOutputData(const uint8* test_output, const uint8* reference_output,
// Runs the Softmax and compares against the float reference implementation and
// the quantized reference implementation.
-void RunOneSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common,
- int32 input_offset, const double input_scale, int stride,
- float beta) {
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+void RunOneSoftmaxTest(const uint8* input_data,
+ const RuntimeShape& shape_common, int32 input_offset,
+ const double input_scale, int stride, float beta) {
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> optimized_softmax_output(buffer_size);
std::vector<uint8> reference_float_softmax_output(buffer_size);
std::vector<uint8> reference_quant_softmax_output(buffer_size);
- RunSoftmaxFloatReference(input_data, dims_common, input_offset, input_scale,
+ RunSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale,
stride, beta, reference_float_softmax_output.data());
int32 input_beta_multiplier;
@@ -113,21 +115,21 @@ void RunOneSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common,
const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
input_beta_left_shift);
- optimized_ops::Softmax(input_data, dims_common, input_beta_multiplier,
+ optimized_ops::Softmax(input_data, shape_common, input_beta_multiplier,
input_beta_left_shift, diff_min,
- optimized_softmax_output.data(), dims_common);
- reference_ops::Softmax(input_data, dims_common, input_beta_multiplier,
+ optimized_softmax_output.data(), shape_common);
+ reference_ops::Softmax(input_data, shape_common, input_beta_multiplier,
input_beta_left_shift, diff_min,
- reference_quant_softmax_output.data(), dims_common);
+ reference_quant_softmax_output.data(), shape_common);
CheckOutputData(optimized_softmax_output.data(),
- reference_float_softmax_output.data(), dims_common,
+ reference_float_softmax_output.data(), shape_common,
"Optimized vs float reference", false);
CheckOutputData(optimized_softmax_output.data(),
- reference_quant_softmax_output.data(), dims_common,
+ reference_quant_softmax_output.data(), shape_common,
"Optimized vs quant reference", true);
CheckOutputData(reference_quant_softmax_output.data(),
- reference_float_softmax_output.data(), dims_common,
+ reference_float_softmax_output.data(), shape_common,
"Quant reference vs float reference", false);
}
@@ -150,13 +152,13 @@ bool TryOneUniformSoftmax() {
const int32 input_offset = UniformRandomInt(-256, 0);
const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10);
- Dims<4> dims_common =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> input_data(buffer_size);
FillRandom(&input_data);
- RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale,
+ RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale,
stride, beta);
return true;
}
@@ -188,14 +190,14 @@ bool TryOneSkyscraperSoftmax(bool small_depth) {
const int middle_min = UniformRandomInt(0, 255);
const int sides_max = UniformRandomInt(0, middle_min);
- Dims<4> dims_common =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> input_data(buffer_size);
FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min,
sides_max);
- RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale,
+ RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale,
stride, beta);
return true;
}
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 64f4881a46..707d2d261a 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -296,6 +296,50 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
// Flat size calculation, checking that dimensions match with one or more other
// arrays.
+inline int MatchingFlatSize(const RuntimeShape& shape,
+ const RuntimeShape& check_shape_0) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ return shape.FlatSize();
+}
+
+inline int MatchingFlatSize(const RuntimeShape& shape,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ return MatchingFlatSize(shape, check_shape_1);
+}
+
+inline int MatchingFlatSize(const RuntimeShape& shape,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ return MatchingFlatSize(shape, check_shape_1, check_shape_2);
+}
+
+inline int MatchingFlatSize(const RuntimeShape& shape,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2,
+ const RuntimeShape& check_shape_3) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
+}
+
+// Flat size calculation, checking that dimensions match with one or more other
+// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
for (int i = 0; i < N; ++i) {
@@ -320,7 +364,7 @@ inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
for (int i = 0; i < N; ++i) {
TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
}
- return FlatSize(dims, check_dims_1, check_dims_2);
+ return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}
template <int N>
@@ -331,7 +375,7 @@ inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
for (int i = 0; i < N; ++i) {
TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
}
- return FlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
+ return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}
// Data is required to be contiguous, and so many operators can use either the
diff --git a/tensorflow/contrib/lite/kernels/log_softmax_test.cc b/tensorflow/contrib/lite/kernels/log_softmax_test.cc
index 62820a2f51..9a8d35e82c 100644
--- a/tensorflow/contrib/lite/kernels/log_softmax_test.cc
+++ b/tensorflow/contrib/lite/kernels/log_softmax_test.cc
@@ -90,10 +90,9 @@ TEST(LogSoftmaxOpTest, CompareWithTFmini) {
m.Invoke();
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
- static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
- {1, 0, 0, input_size}};
- tflite::reference_ops::LogSoftmax(input_buffer, input_dims,
- output_buffer.get(), input_dims);
+ auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
+ tflite::reference_ops::LogSoftmax(input_buffer, input_shape,
+ output_buffer.get(), input_shape);
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index b69a221447..9e01b73c49 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -39,6 +39,14 @@ constexpr int kOutputTensor = 0;
struct OpData {
bool requires_broadcast;
+
+ // Parameters used in the quantized paths where the output is 8bit
+ int32 output_activation_min;
+ int32 output_activation_max;
+
+ // Parameters used in all quantized paths
+ int32_t output_multiplier;
+ int output_shift;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -52,6 +60,7 @@ void Free(TfLiteContext* context, void* buffer) {
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
@@ -62,7 +71,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
- output->type = input2->type;
data->requires_broadcast = !HaveSameShapes(input1, input2);
@@ -74,6 +82,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_size = TfLiteIntArrayCopy(input1->dims);
}
+ if (output->type == kTfLiteUInt8) {
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+
+ if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
+ double real_multiplier =
+ input1->params.scale * input2->params.scale / output->params.scale;
+ QuantizeMultiplierSmallerThanOneExp(
+ real_multiplier, &data->output_multiplier, &data->output_shift);
+ data->output_shift *= -1;
+ }
+
return context->ResizeTensor(context, output, output_size);
}
@@ -107,42 +129,60 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
}
template <KernelType kernel_type>
-void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteMulParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- auto input1_offset = -input1->params.zero_point;
- auto input2_offset = -input2->params.zero_point;
- auto output_offset = output->params.zero_point;
-
- int32_t output_multiplier;
- int output_shift;
-
- double real_multiplier =
- input1->params.scale * input2->params.scale / output->params.scale;
- QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
- &output_shift);
- output_shift *= -1;
-
- int32 output_activation_min, output_activation_max;
- CalculateActivationRangeUint8(params->activation, output,
- &output_activation_min, &output_activation_max);
-
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), input2_offset, output_offset, \
- output_multiplier, output_shift, output_activation_min, \
- output_activation_max, GetTensorData<uint8_t>(output), \
+TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteMulParams* params, const OpData* data,
+ const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+ if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 &&
+ output->type == kTfLiteUInt8) {
+#define TF_LITE_MUL(type, opname) \
+ type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
+ -input1->params.zero_point, GetTensorData<uint8_t>(input2), \
+ GetTensorDims(input2), -input2->params.zero_point, \
+ output->params.zero_point, data->output_multiplier, \
+ data->output_shift, data->output_activation_min, \
+ data->output_activation_max, GetTensorData<uint8_t>(output), \
GetTensorDims(output));
- // The quantized version of Mul doesn't support activations, so we
- // always use BroadcastMul.
- if (kernel_type == kReference) {
- TF_LITE_MUL(reference_ops, BroadcastMul);
+ // The quantized version of Mul doesn't support activations, so we
+ // always use BroadcastMul.
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops, BroadcastMul);
+ } else {
+ TF_LITE_MUL(optimized_ops, BroadcastMul);
+ }
+#undef TF_LITE_MUL
+ } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
+ output->type == kTfLiteInt16) {
+#define TF_LITE_MUL(type, opname) \
+ type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
+ GetTensorData<int16_t>(input2), GetTensorDims(input2), \
+ GetTensorData<int16_t>(output), GetTensorDims(output));
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops, Mul);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul);
+ }
+#undef TF_LITE_MUL
+ } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
+ output->type == kTfLiteUInt8) {
+#define TF_LITE_MUL(type, opname) \
+ type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
+ GetTensorData<int16_t>(input2), GetTensorDims(input2), \
+ output->params.zero_point, data->output_activation_min, \
+ data->output_activation_max, GetTensorData<uint8_t>(output), \
+ GetTensorDims(output));
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops, Mul);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul);
+ }
+#undef TF_LITE_MUL
} else {
- TF_LITE_MUL(optimized_ops, BroadcastMul);
+ context->ReportError(
+ context, "Unsupported combination of input and output types in Mul.");
+ return kTfLiteError;
}
-#undef TF_LITE_MUL
+ return kTfLiteOk;
}
template <KernelType kernel_type>
@@ -156,12 +196,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (output->type == kTfLiteFloat32) {
EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
- } else if (output->type == kTfLiteUInt8) {
- EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
- output);
+ } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
+ TF_LITE_ENSURE_OK(
+ context, EvalQuantized<kernel_type>(context, node, params, data, input1,
+ input2, output));
} else {
context->ReportError(
- context, "Mul only supports FLOAT32 and quantized UINT8 now, got %d.",
+ context,
+ "Mul only supports FLOAT32 and quantized UINT8 and INT16 now, got %d.",
output->type);
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc
index f1a30f8263..43d56e50d2 100644
--- a/tensorflow/contrib/lite/kernels/mul_test.cc
+++ b/tensorflow/contrib/lite/kernels/mul_test.cc
@@ -58,6 +58,9 @@ class FloatMulOpModel : public BaseMulOpModel {
const float kQuantizedStep = 2.0 / 255.0;
const float kQuantizedTolerance =
2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep;
+const float kQuantizedStepInt16 = 2.0 / 32767.0;
+const float kQuantizedToleranceInt16 =
+ 2.0 * kQuantizedStepInt16 + kQuantizedStepInt16 * kQuantizedStepInt16;
class QuantizedMulOpModel : public BaseMulOpModel {
public:
@@ -67,6 +70,11 @@ class QuantizedMulOpModel : public BaseMulOpModel {
return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
GetScale(output_), GetZeroPoint(output_));
}
+
+ std::vector<float> GetDequantizedOutputInt16() {
+ return Dequantize<int16_t>(ExtractVector<int16_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
};
TEST(FloatMulOpTest, NoActivation) {
@@ -138,6 +146,38 @@ TEST(QuantizedMulOpTest, NoActivation) {
kQuantizedTolerance)));
}
+TEST(QuantizedMulOpTest, NoActivationInt16) {
+ const float kMin = -1.f;
+ const float kMax = 32767.f / 32768.f;
+ QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_INT16, {}, kMin, kMax},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<int16_t>(m.input1(), {-0.8, 0.2, 0.9, 0.7});
+ m.QuantizeAndPopulate<int16_t>(m.input2(), {0.6, 0.4, 0.9, 0.8});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutputInt16(),
+ ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56},
+ kQuantizedToleranceInt16)));
+}
+
+TEST(QuantizedMulOpTest, NoActivationInt16WithUint8Output) {
+ const float kMinInt16 = -1.f;
+ const float kMaxInt16 = 32767.f / 32768.f;
+ const float kMinUint8 = -1.f;
+ const float kMaxUint8 = 127.f / 128.f;
+ QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16},
+ {TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16},
+ {TensorType_UINT8, {}, kMinUint8, kMaxUint8},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<int16_t>(m.input1(), {-0.8, 0.2, 0.9, 0.7});
+ m.QuantizeAndPopulate<int16_t>(m.input2(), {0.6, 0.4, 0.9, 0.8});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56},
+ kQuantizedTolerance)));
+}
+
// for quantized Mul, the error shouldn't exceed 2*step
float GetTolerance(int min, int max) {
float kQuantizedStep = (max - min) / 255.0;
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
index 311e9b8399..41771e60bc 100644
--- a/tensorflow/contrib/lite/kernels/pooling.cc
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -126,12 +126,13 @@ void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node,
float activation_min, activation_max;
CalculateActivationRangeFloat(params->activation, &activation_min,
&activation_max);
-#define TF_LITE_AVERAGE_POOL(type) \
- type::AveragePool( \
- GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
- params->stride_height, data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_AVERAGE_POOL(type) \
+ type::AveragePool(GetTensorData<float>(input), GetTensorShape(input), \
+ params->stride_width, params->stride_height, \
+ data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, \
+ activation_min, activation_max, \
+ GetTensorData<float>(output), GetTensorShape(output))
if (kernel_type == kReference) {
TF_LITE_AVERAGE_POOL(reference_ops);
} else {
@@ -148,13 +149,13 @@ void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t activation_max;
CalculateActivationRangeUint8(params->activation, output, &activation_min,
&activation_max);
-#define TF_LITE_AVERAGE_POOL(type) \
- type::AveragePool(GetTensorData<uint8_t>(input), GetTensorDims(input), \
- params->stride_width, params->stride_height, \
- data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, \
- activation_min, activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output))
+#define TF_LITE_AVERAGE_POOL(type) \
+ type::AveragePool(GetTensorData<uint8_t>(input), GetTensorShape(input), \
+ params->stride_width, params->stride_height, \
+ data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, \
+ activation_min, activation_max, \
+ GetTensorData<uint8_t>(output), GetTensorShape(output))
if (kernel_type == kReference) {
TF_LITE_AVERAGE_POOL(reference_ops);
} else {
@@ -170,12 +171,13 @@ void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
float activation_min, activation_max;
CalculateActivationRangeFloat(params->activation, &activation_min,
&activation_max);
-#define TF_LITE_MAX_POOL(type) \
- type::MaxPool( \
- GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
- params->stride_height, data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_MAX_POOL(type) \
+ type::MaxPool(GetTensorData<float>(input), GetTensorShape(input), \
+ params->stride_width, params->stride_height, \
+ data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, activation_min, \
+ activation_max, GetTensorData<float>(output), \
+ GetTensorShape(output))
if (kernel_type == kReference) {
TF_LITE_MAX_POOL(reference_ops);
} else {
@@ -193,12 +195,12 @@ void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
CalculateActivationRangeUint8(params->activation, output, &activation_min,
&activation_max);
#define TF_LITE_MAX_POOL(type) \
- type::MaxPool(GetTensorData<uint8_t>(input), GetTensorDims(input), \
+ type::MaxPool(GetTensorData<uint8_t>(input), GetTensorShape(input), \
params->stride_width, params->stride_height, \
data->padding.width, data->padding.height, \
params->filter_width, params->filter_height, activation_min, \
activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output))
+ GetTensorShape(output))
if (kernel_type == kReference) {
TF_LITE_MAX_POOL(reference_ops);
} else {
@@ -214,12 +216,13 @@ void L2EvalFloat(TfLiteContext* context, TfLiteNode* node,
float activation_min, activation_max;
CalculateActivationRangeFloat(params->activation, &activation_min,
&activation_max);
-#define TF_LITE_L2_POOL(type) \
- type::L2Pool( \
- GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
- params->stride_height, data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_L2_POOL(type) \
+ type::L2Pool(GetTensorData<float>(input), GetTensorShape(input), \
+ params->stride_width, params->stride_height, \
+ data->padding.width, data->padding.height, \
+ params->filter_width, params->filter_height, activation_min, \
+ activation_max, GetTensorData<float>(output), \
+ GetTensorShape(output))
if (kernel_type == kReference) {
TF_LITE_L2_POOL(reference_ops);
} else {
diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc
index 6c5338ff0f..727822f6be 100644
--- a/tensorflow/contrib/lite/kernels/softmax_test.cc
+++ b/tensorflow/contrib/lite/kernels/softmax_test.cc
@@ -92,10 +92,9 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) {
m.Invoke();
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
- static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
- {1, 0, 0, input_size}};
- tflite::reference_ops::Softmax(input_buffer, input_dims, beta,
- output_buffer.get(), input_dims);
+ auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
+ tflite::reference_ops::Softmax(input_buffer, input_shape, beta,
+ output_buffer.get(), input_shape);
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
@@ -120,10 +119,9 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) {
m.Invoke();
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
- static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
- {1, 0, 0, input_size}};
- tflite::reference_ops::Softmax(input_buffer, input_dims, beta,
- output_buffer.get(), input_dims);
+ auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
+ tflite::reference_ops::Softmax(input_buffer, input_shape, beta,
+ output_buffer.get(), input_shape);
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc
index 43387df9ce..b144486041 100644
--- a/tensorflow/contrib/lite/kernels/split.cc
+++ b/tensorflow/contrib/lite/kernels/split.cc
@@ -76,8 +76,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits);
auto input_type = op_context.input->type;
- TF_LITE_ENSURE(context,
- input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8);
+ TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 ||
+ input_type == kTfLiteUInt8 ||
+ input_type == kTfLiteInt16);
for (int i = 0; i < NumOutputs(node); ++i) {
GetOutput(context, node, i)->type = input_type;
}
@@ -137,9 +138,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_SPLIT(uint8_t);
break;
}
+ case kTfLiteInt16: {
+ TF_LITE_SPLIT(int16_t);
+ break;
+ }
default:
context->ReportError(
- context, "Only float32 and uint8 are currently supported, got %d.",
+ context,
+ "Only float32, uint8 and int16 are currently supported, got %d.",
op_context.input->type);
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
index 878bda36ef..6877fb237c 100644
--- a/tensorflow/contrib/lite/toco/dump_graphviz.cc
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -227,7 +227,7 @@ NodeProperties GetPropertiesForArray(const Model& model,
NodeProperties GetPropertiesForOperator(const Operator& op) {
NodeProperties node_properties;
- if (op.type == OperatorType::kTensorFlowUnsupported) {
+ if (op.type == OperatorType::kUnsupported) {
node_properties.label =
static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
} else {
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index afc6d5df20..6b78f1c05e 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -735,8 +735,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
GraphDef* tensorflow_graph) {
string softmax_input;
Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
- if (providing_op != nullptr &&
- providing_op->type == OperatorType::kTensorFlowReshape) {
+ if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
softmax_input = src_op.inputs[0];
} else {
// Insert a reshape operator that reduces the dimensions down to the 2 that
@@ -776,8 +775,7 @@ void ConvertLogSoftmaxOperator(const Model& model,
GraphDef* tensorflow_graph) {
string softmax_input;
Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
- if (providing_op != nullptr &&
- providing_op->type == OperatorType::kTensorFlowReshape) {
+ if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
softmax_input = src_op.inputs[0];
} else {
// Insert a reshape operator that reduces the dimensions down to the 2 that
@@ -1855,24 +1853,24 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertConcatenationOperator(
model, static_cast<const ConcatenationOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowReshape) {
+ } else if (src_op.type == OperatorType::kReshape) {
ConvertTensorFlowReshapeOperator(
model, static_cast<const TensorFlowReshapeOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kL2Pool) {
ConvertL2PoolOperator(static_cast<const L2PoolOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowSquare) {
+ } else if (src_op.type == OperatorType::kSquare) {
ConvertSquareOperator(static_cast<const TensorFlowSquareOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowSqrt) {
+ } else if (src_op.type == OperatorType::kSqrt) {
ConvertSqrtOperator(static_cast<const TensorFlowSqrtOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowRsqrt) {
+ } else if (src_op.type == OperatorType::kRsqrt) {
ConvertRsqrtOperator(model,
static_cast<const TensorFlowRsqrtOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowSplit) {
+ } else if (src_op.type == OperatorType::kSplit) {
ConvertSplitOperator(model,
static_cast<const TensorFlowSplitOperator&>(src_op),
tensorflow_graph);
@@ -1916,11 +1914,11 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kSub) {
ConvertSubOperator(model, static_cast<const SubOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowMinimum) {
+ } else if (src_op.type == OperatorType::kMinimum) {
ConvertTensorFlowMinimumOperator(
model, static_cast<const TensorFlowMinimumOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowMaximum) {
+ } else if (src_op.type == OperatorType::kMaximum) {
ConvertTensorFlowMaximumOperator(
model, static_cast<const TensorFlowMaximumOperator&>(src_op),
tensorflow_graph);
@@ -1939,7 +1937,7 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kTranspose) {
ConvertTransposeOperator(
model, static_cast<const TransposeOperator&>(src_op), tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowShape) {
+ } else if (src_op.type == OperatorType::kShape) {
ConvertTensorFlowShapeOperator(
model, static_cast<const TensorFlowShapeOperator&>(src_op),
tensorflow_graph);
@@ -1970,22 +1968,22 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertRandomUniformOperator(
model, static_cast<const RandomUniformOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowEqual) {
+ } else if (src_op.type == OperatorType::kEqual) {
ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowNotEqual) {
+ } else if (src_op.type == OperatorType::kNotEqual) {
ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowGreater) {
+ } else if (src_op.type == OperatorType::kGreater) {
ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) {
+ } else if (src_op.type == OperatorType::kGreaterEqual) {
ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowLess) {
+ } else if (src_op.type == OperatorType::kLess) {
ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowLessEqual) {
+ } else if (src_op.type == OperatorType::kLessEqual) {
ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph);
} else if (src_op.type == OperatorType::kSelect) {
ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowTile) {
+ } else if (src_op.type == OperatorType::kTile) {
ConvertTileOperator(model,
static_cast<const TensorFlowTileOperator&>(src_op),
tensorflow_graph);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
index 5ab399206b..b689be0792 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
@@ -23,7 +23,7 @@ namespace toco {
bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
auto tile_it = model->operators.begin() + op_index;
- if (tile_it->get()->type != OperatorType::kTensorFlowTile) {
+ if (tile_it->get()->type != OperatorType::kTile) {
return false;
}
auto* tile_op = static_cast<TransposeOperator*>(tile_it->get());
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
index 498c864bde..2c7ffe4884 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
@@ -111,7 +111,7 @@ bool DequantizeArray(const string& array_name,
auto* op_outputting_array = GetOpWithOutput(*model, array_name);
if (op_outputting_array) {
- if (op_outputting_array->type == OperatorType::kTensorFlowReshape) {
+ if (op_outputting_array->type == OperatorType::kReshape) {
return true;
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index bda6dce22b..82a4308ecb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -353,7 +353,7 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
changed = HardcodeMinMaxForConcatenation(model, op);
break;
- case OperatorType::kTensorFlowSplit:
+ case OperatorType::kSplit:
changed = HardcodeMinMaxForSplit(model, op);
break;
@@ -366,7 +366,7 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
case OperatorType::kSlice:
case OperatorType::kStridedSlice:
case OperatorType::kSqueeze:
- case OperatorType::kTensorFlowReshape:
+ case OperatorType::kReshape:
case OperatorType::kPad:
case OperatorType::kGather:
case OperatorType::kTranspose:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
index 419a0776a6..b78efd7fc3 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
@@ -44,10 +44,9 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
const auto* div_or_mul_op = div_it->get();
OperatorType expected_op_type_producing_div_or_mul_input;
if (div_or_mul_op->type == OperatorType::kDiv) {
- expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt;
+ expected_op_type_producing_div_or_mul_input = OperatorType::kSqrt;
} else if (div_or_mul_op->type == OperatorType::kMul) {
- expected_op_type_producing_div_or_mul_input =
- OperatorType::kTensorFlowRsqrt;
+ expected_op_type_producing_div_or_mul_input = OperatorType::kRsqrt;
} else {
return false;
}
@@ -75,8 +74,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
Operator* add_op = nullptr;
Operator* op_producing_add_input = nullptr;
if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd ||
- op_producing_sqrt_or_rsqrt_input->type ==
- OperatorType::kTensorFlowMaximum) {
+ op_producing_sqrt_or_rsqrt_input->type == OperatorType::kMaximum) {
add_op = op_producing_sqrt_or_rsqrt_input;
bool add_can_be_removed = false;
CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2);
@@ -113,7 +111,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
Operator* sum_op =
add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input;
- if (sum_op->type != OperatorType::kTensorFlowSum) {
+ if (sum_op->type != OperatorType::kSum) {
AddMessageF(
"Giving up trying to identify L2Normalization subgraph: "
"expected Sum op, got %s",
@@ -122,7 +120,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
}
Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]);
- if (square_op->type != OperatorType::kTensorFlowSquare) {
+ if (square_op->type != OperatorType::kSquare) {
AddMessageF(
"Giving up trying to identify L2Normalization subgraph: "
"expected Square op, got %s",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
index f69400b82f..705e73779b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
@@ -41,7 +41,7 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
const auto sqrt_it = model->operators.begin() + op_index;
const auto* sqrt_op = sqrt_it->get();
- if (sqrt_op->type != OperatorType::kTensorFlowSqrt) {
+ if (sqrt_op->type != OperatorType::kSqrt) {
return false;
}
@@ -72,7 +72,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
square_op = GetOpWithOutput(*model, avpool_op->inputs[0]);
CHECK_EQ(square_op->inputs.size(), 1);
- if (square_op->type != OperatorType::kTensorFlowSquare) {
+ if (square_op->type != OperatorType::kSquare) {
AddMessageF(
"Giving up trying to identify L2Pool subgraph: "
"expected Square op, got %s",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
index e9842524c8..910e38a6ba 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
@@ -266,26 +266,26 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
// State remember "information" activation function
Operator* fc_output_split;
- if (!MatchOperatorInputs(*state_info_tanh, *model,
- OperatorType::kTensorFlowSplit, &fc_output_split)) {
+ if (!MatchOperatorInputs(*state_info_tanh, *model, OperatorType::kSplit,
+ &fc_output_split)) {
return false;
}
// State remember gate activation function
Operator* tmp;
- if (!MatchOperatorInputs(*state_remember_sig, *model,
- OperatorType::kTensorFlowSplit, &tmp) ||
+ if (!MatchOperatorInputs(*state_remember_sig, *model, OperatorType::kSplit,
+ &tmp) ||
(tmp != fc_output_split)) {
return false;
}
// State forget gate activation function
- if (!MatchOperatorInputs(*state_forget_sig, *model,
- OperatorType::kTensorFlowSplit, &tmp) ||
+ if (!MatchOperatorInputs(*state_forget_sig, *model, OperatorType::kSplit,
+ &tmp) ||
(tmp != fc_output_split)) {
return false;
}
// Fully connected output activation function
- if (!MatchOperatorInputs(*fc_output_sig, *model,
- OperatorType::kTensorFlowSplit, &tmp) ||
+ if (!MatchOperatorInputs(*fc_output_sig, *model, OperatorType::kSplit,
+ &tmp) ||
(tmp != fc_output_split)) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
index bddb563206..94820a0166 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
@@ -60,24 +60,22 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
// Follow sequences of min+max and max+min. First get the leading op.
const auto op_it = model->operators.begin() + op_index;
const auto* op_0 = op_it->get();
- if (op_0->type != OperatorType::kTensorFlowMinimum &&
- op_0->type != OperatorType::kTensorFlowMaximum) {
+ if (op_0->type != OperatorType::kMinimum &&
+ op_0->type != OperatorType::kMaximum) {
return false;
}
// Get the paired op and ensure it's the counter to the first.
const auto* op_1 = GetOpWithInput(*model, op_0->outputs[0]);
if (!op_1 ||
- (op_1->type != OperatorType::kTensorFlowMinimum &&
- op_1->type != OperatorType::kTensorFlowMaximum) ||
+ (op_1->type != OperatorType::kMinimum &&
+ op_1->type != OperatorType::kMaximum) ||
op_0->type == op_1->type) {
return false;
}
- const auto* min_op =
- op_0->type == OperatorType::kTensorFlowMinimum ? op_0 : op_1;
- const auto* max_op =
- op_0->type == OperatorType::kTensorFlowMaximum ? op_0 : op_1;
+ const auto* min_op = op_0->type == OperatorType::kMinimum ? op_0 : op_1;
+ const auto* max_op = op_0->type == OperatorType::kMaximum ? op_0 : op_1;
if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) {
return false;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
index 5065004093..95bc7f7d4b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
@@ -106,7 +106,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
std::size_t op_index) {
auto it = model->operators.begin() + op_index;
auto* reshape_op = ConvertOperator<TensorFlowReshapeOperator*>(
- it->get(), OperatorType::kTensorFlowReshape);
+ it->get(), OperatorType::kReshape);
if (reshape_op == nullptr) {
return false;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 92d283ca2c..27a1049eaf 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -56,22 +56,22 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// These operators unconditionally produce float outputs
SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
break;
- case OperatorType::kTensorFlowLess:
- case OperatorType::kTensorFlowLessEqual:
- case OperatorType::kTensorFlowGreater:
- case OperatorType::kTensorFlowGreaterEqual:
- case OperatorType::kTensorFlowEqual:
- case OperatorType::kTensorFlowNotEqual:
+ case OperatorType::kLess:
+ case OperatorType::kLessEqual:
+ case OperatorType::kGreater:
+ case OperatorType::kGreaterEqual:
+ case OperatorType::kEqual:
+ case OperatorType::kNotEqual:
// These operators unconditionally produce bool outputs
SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
break;
case OperatorType::kRank:
- case OperatorType::kTensorFlowShape:
+ case OperatorType::kShape:
// These operators only produce int32 outputs.
SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
break;
- case OperatorType::kTensorFlowSplit:
- case OperatorType::kTensorFlowConcat:
+ case OperatorType::kSplit:
+ case OperatorType::kConcat:
case OperatorType::kFill: {
// These operators produce an output with the same type as their 2nd input
CHECK_GE(op->inputs.size(), 2);
@@ -135,7 +135,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
model->GetArray(op->outputs[1]).data_type = ArrayDataType ::kInt32;
break;
}
- case OperatorType::kTensorFlowUnsupported: {
+ case OperatorType::kUnsupported: {
auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
// Some output tensors from the op could be eliminated by optimization.
// This can make unsupported_op->output_data_types have more elements than
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
index 77c0886811..e25125b429 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -90,8 +90,8 @@ void ChangeArrayDataType(GraphTransformation* transformation, Array* array,
bool DoesOpBlockBackwardPropagation(const Operator& op) {
switch (op.type) {
case OperatorType::kConcatenation:
- case OperatorType::kTensorFlowConcat:
- case OperatorType::kTensorFlowConcatV2:
+ case OperatorType::kConcat:
+ case OperatorType::kConcatV2:
// Concat shouldn't block propagation, but we do expect that all inputs
// have the same range.
return false;
@@ -100,10 +100,10 @@ bool DoesOpBlockBackwardPropagation(const Operator& op) {
// FakeQuant so make sure we move across them.
case OperatorType::kGather:
// Gathers need their parameters changed to the appropriate data type.
- case OperatorType::kTensorFlowReshape:
+ case OperatorType::kReshape:
case OperatorType::kTranspose:
case OperatorType::kSelect:
- case OperatorType::kTensorFlowTile:
+ case OperatorType::kTile:
// Reshapes and transposes don't change values.
return false;
default:
@@ -121,11 +121,11 @@ bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) {
// Ignore gather indices.
return input_index != 0;
break;
- case OperatorType::kTensorFlowReshape:
+ case OperatorType::kReshape:
case OperatorType::kTranspose:
// Ignore reshape/transpose shapes/dimensions.
return input_index != 0;
- case OperatorType::kTensorFlowTile:
+ case OperatorType::kTile:
// Ignore tile multiples.
return input_index != 0;
default:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index beda187f13..c61da203c6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -572,11 +572,11 @@ void ProcessAddNOperator(Model* model, Operator* op) {
bool KeepDims(const Operator& op) {
switch (op.type) {
- case OperatorType::kTensorFlowMin:
+ case OperatorType::kMin: // Reduction Min
return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
- case OperatorType::kTensorFlowMax:
+ case OperatorType::kMax: // Reduction Max
return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
- case OperatorType::kTensorFlowSum:
+ case OperatorType::kSum:
return static_cast<const TensorFlowSumOperator&>(op).keep_dims;
case OperatorType::kMean:
return static_cast<const MeanOperator&>(op).keep_dims;
@@ -1577,14 +1577,14 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kLogistic:
case OperatorType::kTanh:
case OperatorType::kLocalResponseNormalization:
- case OperatorType::kTensorFlowIdentity:
+ case OperatorType::kIdentity:
case OperatorType::kFakeQuant:
case OperatorType::kNeg:
- case OperatorType::kTensorFlowRsqrt:
- case OperatorType::kTensorFlowSqrt:
- case OperatorType::kTensorFlowSquare:
- case OperatorType::kTensorFlowAll:
- case OperatorType::kTensorFlowAssert:
+ case OperatorType::kRsqrt:
+ case OperatorType::kSqrt:
+ case OperatorType::kSquare:
+ case OperatorType::kAll:
+ case OperatorType::kAssert:
case OperatorType::kCast:
case OperatorType::kFloor:
case OperatorType::kExp:
@@ -1603,14 +1603,14 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kDiv:
case OperatorType::kFloorDiv:
case OperatorType::kFloorMod:
- case OperatorType::kTensorFlowLess:
- case OperatorType::kTensorFlowLessEqual:
- case OperatorType::kTensorFlowGreater:
- case OperatorType::kTensorFlowMaximum:
- case OperatorType::kTensorFlowMinimum:
- case OperatorType::kTensorFlowGreaterEqual:
- case OperatorType::kTensorFlowEqual:
- case OperatorType::kTensorFlowNotEqual:
+ case OperatorType::kLess:
+ case OperatorType::kLessEqual:
+ case OperatorType::kGreater:
+ case OperatorType::kMaximum: // Element-wise Maximum
+ case OperatorType::kMinimum: // Element-wise Minimum
+ case OperatorType::kGreaterEqual:
+ case OperatorType::kEqual:
+ case OperatorType::kNotEqual:
ProcessSimpleBinaryOperator(model, op);
break;
case OperatorType::kAddN:
@@ -1643,7 +1643,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
ProcessFullyConnectedOperator(model,
static_cast<FullyConnectedOperator*>(op));
break;
- case OperatorType::kTensorFlowReshape:
+ case OperatorType::kReshape:
ProcessTensorFlowReshapeOperator(
model, static_cast<TensorFlowReshapeOperator*>(op));
break;
@@ -1656,9 +1656,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kL2Pool:
ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
break;
- case OperatorType::kTensorFlowMin:
- case OperatorType::kTensorFlowMax:
- case OperatorType::kTensorFlowSum:
+ case OperatorType::kMin: // Reduction Min
+ case OperatorType::kMax: // Reduction Max
+ case OperatorType::kSum:
case OperatorType::kMean:
ProcessTensorFlowReductionOperator(model, op);
break;
@@ -1669,26 +1669,26 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
break;
- case OperatorType::kTensorFlowSwitch:
+ case OperatorType::kSwitch:
// We can't know the sizes of the outputs until we have resolved the
// predicate, and once we have resolved the predicate, the whole
// Switch node will get resolved away.
// See ResolveTensorFlowSwitch.
break;
- case OperatorType::kTensorFlowMerge:
+ case OperatorType::kMerge:
// No need to bother resolving TensorFlow Merge ops: other graph
// transformations will remove them anyway.
// See ResolveTensorFlowMerge.
break;
- case OperatorType::kTensorFlowSplit:
+ case OperatorType::kSplit:
ProcessTensorFlowSplitOperator(model,
static_cast<TensorFlowSplitOperator*>(op));
break;
case OperatorType::kSqueeze:
ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
break;
- case OperatorType::kTensorFlowConcat:
- case OperatorType::kTensorFlowConcatV2:
+ case OperatorType::kConcat:
+ case OperatorType::kConcatV2:
// Unimplemented, hopefully another graph transformation will
// drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
// will resolve this node to a DepthConcatenation, or else we have
@@ -1704,7 +1704,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kRank:
ProcessRankOperator(model, static_cast<RankOperator*>(op));
break;
- case OperatorType::kTensorFlowShape:
+ case OperatorType::kShape:
ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
break;
case OperatorType::kStack:
@@ -1725,7 +1725,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
break;
case OperatorType::kBatchMatMul:
- case OperatorType::kTensorFlowMatMul:
+ case OperatorType::kMatMul:
// MatMul operators are converted to FullyConnected, after which their
// shapes are propagated.
break;
@@ -1750,7 +1750,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kArgMax:
ProcessArgMaxOperator(model, static_cast<ArgMaxOperator*>(op));
break;
- case OperatorType::kTensorFlowUnsupported:
+ case OperatorType::kUnsupported:
break;
case OperatorType::kSvdf:
ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
@@ -1772,7 +1772,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
ProcessSparseToDenseOperator(model,
static_cast<SparseToDenseOperator*>(op));
break;
- case OperatorType::kTensorFlowTile:
+ case OperatorType::kTile:
ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
break;
default:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index eca2c701f8..1c61b8cb36 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -33,7 +33,7 @@ namespace {
bool SupportsQuantization(const Operator& op) {
auto type = op.type;
- if (type == OperatorType::kTensorFlowUnsupported) {
+ if (type == OperatorType::kUnsupported) {
auto* unsupported = static_cast<const TensorFlowUnsupportedOperator*>(&op);
return unsupported->quantized;
}
@@ -42,15 +42,13 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kConcatenation ||
type == OperatorType::kL2Normalization || type == OperatorType::kAdd ||
type == OperatorType::kAveragePool || type == OperatorType::kMaxPool ||
- type == OperatorType::kTensorFlowMinimum ||
- type == OperatorType::kTensorFlowMaximum ||
+ type == OperatorType::kMinimum || type == OperatorType::kMaximum ||
type == OperatorType::kLogistic || type == OperatorType::kSoftmax ||
type == OperatorType::kLogSoftmax || type == OperatorType::kSlice ||
type == OperatorType::kResizeBilinear ||
- type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub ||
+ type == OperatorType::kSplit || type == OperatorType::kSub ||
type == OperatorType::kSqueeze || type == OperatorType::kPad ||
- type == OperatorType::kPadV2 ||
- type == OperatorType::kTensorFlowReshape ||
+ type == OperatorType::kPadV2 || type == OperatorType::kReshape ||
type == OperatorType::kTanh || type == OperatorType::kMul ||
type == OperatorType::kSpaceToBatchND ||
type == OperatorType::kSpaceToDepth ||
@@ -58,11 +56,10 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kDepthToSpace ||
type == OperatorType::kLstmCell || type == OperatorType::kGather ||
type == OperatorType::kTranspose || type == OperatorType::kMean ||
- type == OperatorType::kTensorFlowGreater ||
- type == OperatorType::kTensorFlowGreaterEqual ||
- type == OperatorType::kTensorFlowLess ||
- type == OperatorType::kTensorFlowLessEqual ||
- type == OperatorType::kSelect || type == OperatorType::kArgMax;
+ type == OperatorType::kGreater ||
+ type == OperatorType::kGreaterEqual || type == OperatorType::kLess ||
+ type == OperatorType::kLessEqual || type == OperatorType::kSelect ||
+ type == OperatorType::kArgMax;
}
const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
@@ -330,12 +327,12 @@ bool ChooseQuantizationForOperatorOutput(
}
if ((op.type == OperatorType::kDepthToSpace) ||
(op.type == OperatorType::kSpaceToDepth) ||
- (op.type == OperatorType::kTensorFlowReshape) ||
- (op.type == OperatorType::kTensorFlowSplit) ||
+ (op.type == OperatorType::kReshape) ||
+ (op.type == OperatorType::kSplit) ||
(op.type == OperatorType::kConcatenation &&
model->flags.change_concat_input_ranges())) {
int data_input_index = 0;
- if (op.type == OperatorType::kTensorFlowSplit) {
+ if (op.type == OperatorType::kSplit) {
data_input_index = 1;
}
// Copying and rearrangement ops should preserve the quantization parameters
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
index 35a0c46532..73ad326299 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
@@ -26,7 +26,7 @@ namespace toco {
bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) {
const auto assert_it = model->operators.begin() + op_index;
const auto* assert_op = assert_it->get();
- if (assert_op->type != OperatorType::kTensorFlowAssert) {
+ if (assert_op->type != OperatorType::kAssert) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
index 404269bbfd..7ec7752f25 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
@@ -28,7 +28,7 @@ namespace toco {
bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) {
const auto passthru_it = model->operators.begin() + op_index;
const auto* passthru_op = passthru_it->get();
- if (passthru_op->type != OperatorType::kTensorFlowIdentity) {
+ if (passthru_op->type != OperatorType::kIdentity) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
index a950fe6442..9f5d8b9450 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
@@ -97,7 +97,7 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
"Cannot remove %s, neither its main input nor its output may be "
"discarded",
LogName(*passthru_op));
- if (passthru_op->type != OperatorType::kTensorFlowReshape &&
+ if (passthru_op->type != OperatorType::kReshape &&
model->GetArray(main_input_name).has_shape()) {
// We can't remove either array but we can remove the op. Converting it to
// a reshape gives us some hope of later on fixing that (either in the
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc
index eaee1c662b..142c876b15 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc
@@ -47,11 +47,11 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model,
double clamp_min;
double clamp_max;
switch (op_type) {
- case OperatorType::kTensorFlowMinimum:
+ case OperatorType::kMinimum: // Element-wise Minimum
clamp_min = -std::numeric_limits<double>::infinity();
clamp_max = clamp_value;
break;
- case OperatorType::kTensorFlowMaximum:
+ case OperatorType::kMaximum: // Element-wise Maximum
clamp_min = clamp_value;
clamp_max = std::numeric_limits<double>::infinity();
break;
@@ -72,8 +72,8 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model,
bool RemoveTrivialQuantizedMinMax::Run(Model* model, std::size_t op_index) {
const auto it = model->operators.begin() + op_index;
auto* op = it->get();
- if ((op->type != OperatorType::kTensorFlowMinimum &&
- op->type != OperatorType::kTensorFlowMaximum) ||
+ if ((op->type != OperatorType::kMinimum &&
+ op->type != OperatorType::kMaximum) ||
op->inputs.size() != 2) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
index e28d8cf01e..404f27e067 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
@@ -30,7 +30,7 @@ namespace {
bool IsReshapeTrivial(const Model& model, const Operator& op,
RemoveTrivialReshape* transformation) {
- CHECK(op.type == OperatorType::kTensorFlowReshape);
+ CHECK(op.type == OperatorType::kReshape);
// One way in which a reshape can be trivial is if its
// output shape is == its input shape
@@ -58,7 +58,7 @@ bool IsReshapeTrivial(const Model& model, const Operator& op,
// is only consumed by another reshape.
if (CountOpsWithInput(model, op.outputs[0]) == 1) {
const auto* next_op = GetOpWithInput(model, op.outputs[0]);
- if (next_op->type == OperatorType::kTensorFlowReshape) {
+ if (next_op->type == OperatorType::kReshape) {
transformation->AddMessageF(
"%s is trivial because its output is only consumed by another "
"Reshape op %s",
@@ -75,7 +75,7 @@ bool IsReshapeTrivial(const Model& model, const Operator& op,
bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) {
const auto reshape_it = model->operators.begin() + op_index;
auto* reshape_op = reshape_it->get();
- if (reshape_op->type != OperatorType::kTensorFlowReshape) {
+ if (reshape_op->type != OperatorType::kReshape) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
index 1956ab2d20..dde91234a8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
@@ -48,7 +48,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
for (const auto& rnn_state : model->flags.rnn_states()) {
if (output == rnn_state.state_array()) {
CHECK(op->type == OperatorType::kFill ||
- op->type == OperatorType::kTensorFlowIdentity);
+ op->type == OperatorType::kIdentity);
found_output_as_rnn_state_array = true;
break;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc
index 9f5b7920cb..550de83018 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc
@@ -37,8 +37,8 @@ bool IsElementwiseOperator(OperatorType optype) {
case OperatorType::kRelu1:
case OperatorType::kRelu6:
case OperatorType::kTanh:
- case OperatorType::kTensorFlowSqrt:
- case OperatorType::kTensorFlowSquare:
+ case OperatorType::kSqrt:
+ case OperatorType::kSquare:
return true;
default:
return false;
@@ -51,7 +51,7 @@ bool IsMoveOperator(OperatorType optype) {
case OperatorType::kExpandDims:
case OperatorType::kSpaceToDepth:
case OperatorType::kSqueeze:
- case OperatorType::kTensorFlowReshape:
+ case OperatorType::kReshape:
case OperatorType::kTranspose:
return true;
default:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc
index 9e7fe1b1cc..c907a597cb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc
@@ -123,8 +123,8 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
}
TensorFlowReshapeOperator* reshape_op =
- ConvertOperator<TensorFlowReshapeOperator*>(
- reshape_it->get(), OperatorType::kTensorFlowReshape);
+ ConvertOperator<TensorFlowReshapeOperator*>(reshape_it->get(),
+ OperatorType::kReshape);
if (reshape_op == nullptr) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
index 6e78653fad..f7e5aa6609 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -145,17 +145,17 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model,
outval = floor(val0 / val1);
} else if (binary_op->type == OperatorType::kFloorMod) {
outval = val0 - (floor(val0 / val1) * val1);
- } else if (binary_op->type == OperatorType::kTensorFlowMinimum) {
+ } else if (binary_op->type == OperatorType::kMinimum) {
outval = std::min(val0, val1);
- } else if (binary_op->type == OperatorType::kTensorFlowMaximum) {
+ } else if (binary_op->type == OperatorType::kMaximum) {
outval = std::max(val0, val1);
- } else if (binary_op->type == OperatorType::kTensorFlowLess) {
+ } else if (binary_op->type == OperatorType::kLess) {
outval = val0 < val1;
- } else if (binary_op->type == OperatorType::kTensorFlowLessEqual) {
+ } else if (binary_op->type == OperatorType::kLessEqual) {
outval = val0 <= val1;
- } else if (binary_op->type == OperatorType::kTensorFlowGreater) {
+ } else if (binary_op->type == OperatorType::kGreater) {
outval = val0 > val1;
- } else if (binary_op->type == OperatorType::kTensorFlowGreaterEqual) {
+ } else if (binary_op->type == OperatorType::kGreaterEqual) {
outval = val0 >= val1;
} else {
LOG(FATAL) << "should not get here";
@@ -198,12 +198,12 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
binary_op->type != OperatorType::kDiv &&
binary_op->type != OperatorType::kFloorDiv &&
binary_op->type != OperatorType::kFloorMod &&
- binary_op->type != OperatorType::kTensorFlowMinimum &&
- binary_op->type != OperatorType::kTensorFlowMaximum &&
- binary_op->type != OperatorType::kTensorFlowLess &&
- binary_op->type != OperatorType::kTensorFlowLessEqual &&
- binary_op->type != OperatorType::kTensorFlowGreater &&
- binary_op->type != OperatorType::kTensorFlowGreaterEqual) {
+ binary_op->type != OperatorType::kMinimum &&
+ binary_op->type != OperatorType::kMaximum &&
+ binary_op->type != OperatorType::kLess &&
+ binary_op->type != OperatorType::kLessEqual &&
+ binary_op->type != OperatorType::kGreater &&
+ binary_op->type != OperatorType::kGreaterEqual) {
return false;
}
CHECK_EQ(binary_op->inputs.size(), 2);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
index 7e7ad383e7..41562ab393 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
@@ -25,7 +25,7 @@ namespace toco {
bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
- if (base_op->type != OperatorType::kTensorFlowReshape) {
+ if (base_op->type != OperatorType::kReshape) {
return false;
}
const auto* op = static_cast<const TensorFlowReshapeOperator*>(base_op);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
index 9ea01acd05..8a0e3e8995 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
@@ -22,8 +22,7 @@ namespace toco {
bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
const auto it = model->operators.begin() + op_index;
const auto* op = it->get();
- if (!(op->type == OperatorType::kTensorFlowShape ||
- op->type == OperatorType::kRank)) {
+ if (!(op->type == OperatorType::kShape || op->type == OperatorType::kRank)) {
return false;
}
@@ -48,7 +47,7 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
// Compute the output
CHECK(!output_array.buffer);
auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kInt32>();
- if (op->type == OperatorType::kTensorFlowShape) {
+ if (op->type == OperatorType::kShape) {
// Copy the input shape into the output buffer.
output_buffer.data = input_array.shape().dims();
} else if (op->type == OperatorType::kRank) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index f6c8f79d8d..f89ef85fdb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -53,13 +53,13 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
case OperatorType::kCast:
case OperatorType::kLog:
case OperatorType::kNeg:
- case OperatorType::kTensorFlowRsqrt:
- case OperatorType::kTensorFlowSqrt:
- case OperatorType::kTensorFlowSquare:
- case OperatorType::kTensorFlowSum:
- case OperatorType::kTensorFlowMin:
- case OperatorType::kTensorFlowMax:
- case OperatorType::kTensorFlowReshape:
+ case OperatorType::kRsqrt:
+ case OperatorType::kSqrt:
+ case OperatorType::kSquare:
+ case OperatorType::kSum:
+ case OperatorType::kMin: // Reduction Min
+ case OperatorType::kMax: // Reduction Max
+ case OperatorType::kReshape:
case OperatorType::kRelu6:
case OperatorType::kRelu1:
case OperatorType::kRelu:
@@ -103,7 +103,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
// The min-max is only copied for ops that copy data without arithmetic.
// In future trivial transpose, etc, can be handled here.
- if (unary_op->type == OperatorType::kTensorFlowReshape) {
+ if (unary_op->type == OperatorType::kReshape) {
CopyMinMaxFromFirstInput(*unary_op, model);
}
@@ -164,10 +164,10 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
output_float_data[i] = outval;
}
- } else if (unary_op->type == OperatorType::kTensorFlowReshape) {
+ } else if (unary_op->type == OperatorType::kReshape) {
CHECK(input_buffer_size == output_buffer_size);
output_float_data = *input_float_data;
- } else if (unary_op->type == OperatorType::kTensorFlowSum) {
+ } else if (unary_op->type == OperatorType::kSum) {
CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs";
if (!IsConstantParameterArray(*model, unary_op->inputs[1])) {
AddMessageF("Axis input is non-constant");
@@ -196,7 +196,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
output_float_data[i] = sum;
}
- } else if (unary_op->type == OperatorType::kTensorFlowMin) {
+ } else if (unary_op->type == OperatorType::kMin) {
// At the moment only full reduction across all dimensions is supported.
// TODO(starka): Output should not be padded.
for (int i = 0; i < output_dims_count; i++) {
@@ -207,7 +207,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
min = std::min(min, (*input_float_data)[i]);
}
output_float_data[0] = min;
- } else if (unary_op->type == OperatorType::kTensorFlowMax) {
+ } else if (unary_op->type == OperatorType::kMax) {
// At the moment only full reduction across all dimensions is supported.
// TODO(starka): Output should not be padded.
for (int i = 0; i < output_dims_count; i++) {
@@ -220,9 +220,9 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
output_float_data[0] = max;
} else if (unary_op->type == OperatorType::kNeg ||
unary_op->type == OperatorType::kLog ||
- unary_op->type == OperatorType::kTensorFlowRsqrt ||
- unary_op->type == OperatorType::kTensorFlowSqrt ||
- unary_op->type == OperatorType::kTensorFlowSquare) {
+ unary_op->type == OperatorType::kRsqrt ||
+ unary_op->type == OperatorType::kSqrt ||
+ unary_op->type == OperatorType::kSquare) {
// Element-wise ops. Should have perfectly matching sizes here.
for (int i = 0; i < output_dims_count; i++) {
CHECK_EQ(output_shape.dims(i), input_shape.dims(i));
@@ -235,11 +235,11 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
outval = -val;
} else if (unary_op->type == OperatorType::kLog) {
outval = std::log(val);
- } else if (unary_op->type == OperatorType::kTensorFlowRsqrt) {
+ } else if (unary_op->type == OperatorType::kRsqrt) {
outval = 1.0f / std::sqrt(val);
- } else if (unary_op->type == OperatorType::kTensorFlowSqrt) {
+ } else if (unary_op->type == OperatorType::kSqrt) {
outval = std::sqrt(val);
- } else if (unary_op->type == OperatorType::kTensorFlowSquare) {
+ } else if (unary_op->type == OperatorType::kSquare) {
outval = val * val;
} else {
LOG(FATAL) << "should not get here.";
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
index 2e063e3554..b615c9a545 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
@@ -28,7 +28,7 @@ namespace toco {
bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) {
const auto reshape_it = model->operators.begin() + op_index;
auto* reshape_op = reshape_it->get();
- if (reshape_op->type != OperatorType::kTensorFlowReshape) {
+ if (reshape_op->type != OperatorType::kReshape) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc
index dd3e73635a..e8bb85704e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc
@@ -36,7 +36,7 @@ bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) {
// If the output is consumed by a reshape op, it's a trivial squeeze.
if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) {
const auto* next_op = GetOpWithInput(*model, squeeze_op->outputs[0]);
- if (next_op->type == OperatorType::kTensorFlowReshape) {
+ if (next_op->type == OperatorType::kReshape) {
AddMessageF(
"%s is trivial because its output is only consumed by a "
"Reshape op",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
index 5c0c1e3478..fa5ee89933 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
@@ -28,8 +28,8 @@ namespace toco {
bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
auto concat_it = model->operators.begin() + op_index;
const auto* tf_concat_op = concat_it->get();
- if (tf_concat_op->type != OperatorType::kTensorFlowConcat &&
- tf_concat_op->type != OperatorType::kTensorFlowConcatV2) {
+ if (tf_concat_op->type != OperatorType::kConcat &&
+ tf_concat_op->type != OperatorType::kConcatV2) {
return false;
}
@@ -38,7 +38,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
// of inputs: in Concat,the axis is the first input, while in
// ConcatV2, it is the last input.
std::size_t axis_pos = 0;
- if (tf_concat_op->type == OperatorType::kTensorFlowConcatV2) {
+ if (tf_concat_op->type == OperatorType::kConcatV2) {
axis_pos = tf_concat_op->inputs.size() - 1;
}
const string axis_name = tf_concat_op->inputs[axis_pos];
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index 2a236d3f98..d496f5ae5e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -26,7 +26,7 @@ namespace toco {
bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
auto matmul_it = model->operators.begin() + op_index;
- if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) {
+ if (matmul_it->get()->type != OperatorType::kMatMul) {
return false;
}
const auto* matmul_op =
@@ -97,7 +97,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
// MatMul op as a FullyConnected. However, TensorFlow skips the Reshape ops if
// the input doesn't need reshaping, so we can't just match (Reshape, MatMul)
// pairs.
- if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) {
+ if (previous_op && previous_op->type == OperatorType::kReshape) {
AddMessageF("Combining %s and %s into %s", LogName(*previous_op),
LogName(*matmul_op), LogName(*fc_op));
const auto& previous_op_output = previous_op->outputs[0];
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
index 38e0005890..4edffe3d48 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
@@ -27,7 +27,7 @@ namespace toco {
bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
const auto merge_it = model->operators.begin() + op_index;
const auto* merge_op = merge_it->get();
- if (merge_op->type != OperatorType::kTensorFlowMerge) {
+ if (merge_op->type != OperatorType::kMerge) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
index a418073441..da8e7a2d1c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
@@ -27,7 +27,7 @@ namespace toco {
bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
const auto switch_it = model->operators.begin() + op_index;
const auto* switch_op = switch_it->get();
- if (switch_op->type != OperatorType::kTensorFlowSwitch) {
+ if (switch_op->type != OperatorType::kSwitch) {
return false;
}
@@ -92,7 +92,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
if (*input_it == switch_op->outputs[nonselected_output_index]) {
// Let us guard our assumption that only Merge nodes consume the outputs
// of Switch nodes:
- CHECK(other_op->type == OperatorType::kTensorFlowMerge);
+ CHECK(other_op->type == OperatorType::kMerge);
input_it = other_op->inputs.erase(input_it);
} else {
++input_it;
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 2585cff56e..ef170b3884 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -96,38 +96,38 @@ enum class OperatorType : uint8 {
// Special operators used for importing TensorFlow nodes.
// The general intent is to have some graph transformation either
// drop them or rewrite them as general-purpose operators.
- kTensorFlowAll,
- kTensorFlowAssert,
- kTensorFlowConcat,
- kTensorFlowConcatV2,
- kTensorFlowGreater,
- kTensorFlowGreaterEqual,
- kTensorFlowIdentity,
- kTensorFlowLess,
- kTensorFlowLessEqual,
- kTensorFlowMax,
- kTensorFlowMaximum,
- kTensorFlowMin,
- kTensorFlowMinimum,
- kTensorFlowMatMul,
- kTensorFlowMerge,
+ kAll,
+ kAssert,
+ kConcat,
+ kConcatV2,
+ kGreater,
+ kGreaterEqual,
+ kIdentity,
+ kLess,
+ kLessEqual,
+ kMax, // Reduction Max
+ kMaximum, // Element-wise Maximum
+ kMin, // Reduction Min
+ kMinimum, // Element-wise Minimum
+ kMatMul,
+ kMerge,
kNeg,
- kTensorFlowReshape,
- kTensorFlowRsqrt,
- kTensorFlowShape,
- kTensorFlowSplit,
- kTensorFlowSqrt,
- kTensorFlowSquare,
- kTensorFlowSum,
- kTensorFlowSwitch,
- kTensorFlowTile,
+ kReshape,
+ kRsqrt,
+ kShape,
+ kSplit,
+ kSqrt,
+ kSquare,
+ kSum,
+ kSwitch,
+ kTile,
kTranspose,
kTopK_V2,
kDynamicPartition,
kDynamicStitch,
// An unsupported TF operation. It's only needed to be able to represent TF
// graph internally and is expected to be dropped by graph transformations.
- kTensorFlowUnsupported,
+ kUnsupported,
// Finally, TensorFlow uses different conventions for axes ordering,
// see AxesOrder, and this cannot always be resolved at the time of importing
// nodes, as TensorFlow parameters may be constant-expression subgraphs
@@ -136,8 +136,8 @@ enum class OperatorType : uint8 {
kReorderAxes,
kSelect,
kSparseToDense,
- kTensorFlowEqual,
- kTensorFlowNotEqual,
+ kEqual,
+ kNotEqual,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -801,7 +801,7 @@ struct DivOperator : Operator {
//
// TensorFlow equivalent: Identity
struct TensorFlowIdentityOperator : Operator {
- TensorFlowIdentityOperator() : Operator(OperatorType::kTensorFlowIdentity) {}
+ TensorFlowIdentityOperator() : Operator(OperatorType::kIdentity) {}
};
// Batch matrix multiplication operator. This comes from the (deprecated)
@@ -827,7 +827,7 @@ struct BatchMatMulOperator : Operator {
//
// TensorFlow equivalent: MatMul
struct TensorFlowMatMulOperator : Operator {
- TensorFlowMatMulOperator() : Operator(OperatorType::kTensorFlowMatMul) {}
+ TensorFlowMatMulOperator() : Operator(OperatorType::kMatMul) {}
};
// Padding operator. Pads a tensor with zeros.
@@ -961,7 +961,7 @@ struct StridedSliceOperator : Operator {
// TensorFlow equivalent: Reshape --- except that we only support a special case
// here, where the output shape is a matrix (2D) shape.
struct TensorFlowReshapeOperator : Operator {
- TensorFlowReshapeOperator() : Operator(OperatorType::kTensorFlowReshape) {}
+ TensorFlowReshapeOperator() : Operator(OperatorType::kReshape) {}
std::vector<int> shape;
};
@@ -1131,7 +1131,7 @@ struct SelectOperator : Operator {
//
// TensorFlow equivalent: Rsqrt
struct TensorFlowRsqrtOperator : Operator {
- TensorFlowRsqrtOperator() : Operator(OperatorType::kTensorFlowRsqrt) {}
+ TensorFlowRsqrtOperator() : Operator(OperatorType::kRsqrt) {}
};
// Stacks a list of rank-R tensors into one rank-(R+1) tensor.
@@ -1159,7 +1159,7 @@ struct StackOperator : Operator {
//
// TensorFlow equivalent: Shape.
struct TensorFlowShapeOperator : Operator {
- TensorFlowShapeOperator() : Operator(OperatorType::kTensorFlowShape) {}
+ TensorFlowShapeOperator() : Operator(OperatorType::kShape) {}
ArrayDataType output_data_type = ArrayDataType::kInt32;
};
@@ -1170,7 +1170,7 @@ struct TensorFlowShapeOperator : Operator {
//
// TensorFlow equivalent: Sqrt
struct TensorFlowSqrtOperator : Operator {
- TensorFlowSqrtOperator() : Operator(OperatorType::kTensorFlowSqrt) {}
+ TensorFlowSqrtOperator() : Operator(OperatorType::kSqrt) {}
};
// Element-wise square (x*x) operator.
@@ -1180,7 +1180,7 @@ struct TensorFlowSqrtOperator : Operator {
//
// TensorFlow equivalent: Square
struct TensorFlowSquareOperator : Operator {
- TensorFlowSquareOperator() : Operator(OperatorType::kTensorFlowSquare) {}
+ TensorFlowSquareOperator() : Operator(OperatorType::kSquare) {}
};
// Transposes a tensor.
@@ -1215,7 +1215,7 @@ struct SubOperator : Operator {
//
// TensorFlow equivalent: Sum
struct TensorFlowSumOperator : Operator {
- TensorFlowSumOperator() : Operator(OperatorType::kTensorFlowSum) {}
+ TensorFlowSumOperator() : Operator(OperatorType::kSum) {}
bool keep_dims = false;
};
@@ -1225,7 +1225,7 @@ struct TensorFlowSumOperator : Operator {
// inputs[0]: required: the input array
// inputs[1]: required: int array with length of rank(input[0])
struct TensorFlowTileOperator : Operator {
- TensorFlowTileOperator() : Operator(OperatorType::kTensorFlowTile) {}
+ TensorFlowTileOperator() : Operator(OperatorType::kTile) {}
};
// TensorFlow Slice equivalent. Refer to TensorFlow documentation for details.
@@ -1240,7 +1240,7 @@ struct SliceOperator : Operator {
// Not fully supported, just a placeholder to handle TensorFlow graphs and
// support graph transformations to other operator types by matching sub-graphs.
struct TensorFlowSplitOperator : Operator {
- TensorFlowSplitOperator() : Operator(OperatorType::kTensorFlowSplit) {}
+ TensorFlowSplitOperator() : Operator(OperatorType::kSplit) {}
int num_split = 0;
};
@@ -1251,7 +1251,7 @@ struct TensorFlowSplitOperator : Operator {
// dimension then we can change this op into a DepthConcatenation op.
// Otherwise, we hope for some other graph transformation to drop this node.
struct TensorFlowConcatOperator : Operator {
- TensorFlowConcatOperator() : Operator(OperatorType::kTensorFlowConcat) {}
+ TensorFlowConcatOperator() : Operator(OperatorType::kConcat) {}
};
// TensorFlow ConcatV2 equivalent. Refer to TensorFlow documentation for
@@ -1262,7 +1262,7 @@ struct TensorFlowConcatOperator : Operator {
// dimension then we can change this op into a DepthConcatenation op.
// Otherwise, we hope for some other graph transformation to drop this node.
struct TensorFlowConcatV2Operator : Operator {
- TensorFlowConcatV2Operator() : Operator(OperatorType::kTensorFlowConcatV2) {}
+ TensorFlowConcatV2Operator() : Operator(OperatorType::kConcatV2) {}
};
// TensorFlow Merge equivalent. Refer to TensorFlow documentation for details.
@@ -1278,7 +1278,7 @@ struct TensorFlowConcatV2Operator : Operator {
// control flow that can be resolved at tooling time (independently of input
// activations).
struct TensorFlowMergeOperator : Operator {
- TensorFlowMergeOperator() : Operator(OperatorType::kTensorFlowMerge) {}
+ TensorFlowMergeOperator() : Operator(OperatorType::kMerge) {}
};
// TensorFlow Switch equivalent. Refer to TensorFlow documentation for details.
@@ -1301,7 +1301,7 @@ struct TensorFlowMergeOperator : Operator {
// control flow that can be resolved at tooling time (independently of input
// activations).
struct TensorFlowSwitchOperator : Operator {
- TensorFlowSwitchOperator() : Operator(OperatorType::kTensorFlowSwitch) {}
+ TensorFlowSwitchOperator() : Operator(OperatorType::kSwitch) {}
};
// TensorFlow All equivalent. Refer to TensorFlow documentation for details.
@@ -1310,7 +1310,7 @@ struct TensorFlowSwitchOperator : Operator {
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowAllOperator : Operator {
- TensorFlowAllOperator() : Operator(OperatorType::kTensorFlowAll) {}
+ TensorFlowAllOperator() : Operator(OperatorType::kAll) {}
};
// TensorFlow Assert equivalent. Refer to TensorFlow documentation for details.
@@ -1318,7 +1318,7 @@ struct TensorFlowAllOperator : Operator {
// support graph transformations to other operator types by matching sub-graphs.
// Typically, we just drop Assert nodes.
struct TensorFlowAssertOperator : Operator {
- TensorFlowAssertOperator() : Operator(OperatorType::kTensorFlowAssert) {}
+ TensorFlowAssertOperator() : Operator(OperatorType::kAssert) {}
};
// TensorFlow Less equivalent. Refer to TensorFlow documentation for details.
@@ -1327,7 +1327,7 @@ struct TensorFlowAssertOperator : Operator {
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowLessOperator : Operator {
- TensorFlowLessOperator() : Operator(OperatorType::kTensorFlowLess) {}
+ TensorFlowLessOperator() : Operator(OperatorType::kLess) {}
};
// TensorFlow LessEqual equivalent. Refer to TensorFlow documentation for
@@ -1337,8 +1337,7 @@ struct TensorFlowLessOperator : Operator {
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowLessEqualOperator : Operator {
- TensorFlowLessEqualOperator()
- : Operator(OperatorType::kTensorFlowLessEqual) {}
+ TensorFlowLessEqualOperator() : Operator(OperatorType::kLessEqual) {}
};
// TensorFlow Less equivalent. Refer to TensorFlow documentation for details.
@@ -1347,7 +1346,7 @@ struct TensorFlowLessEqualOperator : Operator {
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowGreaterOperator : Operator {
- TensorFlowGreaterOperator() : Operator(OperatorType::kTensorFlowGreater) {}
+ TensorFlowGreaterOperator() : Operator(OperatorType::kGreater) {}
};
// TensorFlow GreaterEqual equivalent. Refer to TensorFlow documentation for
@@ -1357,8 +1356,7 @@ struct TensorFlowGreaterOperator : Operator {
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowGreaterEqualOperator : Operator {
- TensorFlowGreaterEqualOperator()
- : Operator(OperatorType::kTensorFlowGreaterEqual) {}
+ TensorFlowGreaterEqualOperator() : Operator(OperatorType::kGreaterEqual) {}
};
// TensorFlow Equal equivalent. Refer to TensorFlow documentation for
@@ -1368,13 +1366,13 @@ struct TensorFlowGreaterEqualOperator : Operator {
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowEqualOperator : Operator {
- TensorFlowEqualOperator() : Operator(OperatorType::kTensorFlowEqual) {}
+ TensorFlowEqualOperator() : Operator(OperatorType::kEqual) {}
};
// TensorFlow Not Equal equivalent. Refer to TensorFlow documentation for
// details.
struct TensorFlowNotEqualOperator : Operator {
- TensorFlowNotEqualOperator() : Operator(OperatorType::kTensorFlowNotEqual) {}
+ TensorFlowNotEqualOperator() : Operator(OperatorType::kNotEqual) {}
};
// Global max reduction: computes the max of all of entries in the input array.
@@ -1386,7 +1384,7 @@ struct TensorFlowNotEqualOperator : Operator {
// TensorFlow equivalent: Max --- except that we only support the special case
// of global reduction across all dimensions.
struct TensorFlowMaxOperator : Operator {
- TensorFlowMaxOperator() : Operator(OperatorType::kTensorFlowMax) {}
+ TensorFlowMaxOperator() : Operator(OperatorType::kMax) {}
bool keep_dims = false;
};
@@ -1399,7 +1397,7 @@ struct TensorFlowMaxOperator : Operator {
// TensorFlow equivalent: Min --- except that we only support the special case
// of global reduction across all dimensions.
struct TensorFlowMinOperator : Operator {
- TensorFlowMinOperator() : Operator(OperatorType::kTensorFlowMin) {}
+ TensorFlowMinOperator() : Operator(OperatorType::kMin) {}
bool keep_dims = false;
};
@@ -1412,7 +1410,7 @@ struct TensorFlowMinOperator : Operator {
//
// TensorFlow equivalent: Maximum
struct TensorFlowMaximumOperator : Operator {
- TensorFlowMaximumOperator() : Operator(OperatorType::kTensorFlowMaximum) {}
+ TensorFlowMaximumOperator() : Operator(OperatorType::kMaximum) {}
};
// Element-wise minimum operator. Currently it only supports scalar as
@@ -1424,14 +1422,13 @@ struct TensorFlowMaximumOperator : Operator {
//
// TensorFlow equivalent: Minimum
struct TensorFlowMinimumOperator : Operator {
- TensorFlowMinimumOperator() : Operator(OperatorType::kTensorFlowMinimum) {}
+ TensorFlowMinimumOperator() : Operator(OperatorType::kMinimum) {}
};
// General TF operation, unsupported by tf.mini. Expected to be dropped by
// graph transformations.
struct TensorFlowUnsupportedOperator : Operator {
- TensorFlowUnsupportedOperator()
- : Operator(OperatorType::kTensorFlowUnsupported) {}
+ TensorFlowUnsupportedOperator() : Operator(OperatorType::kUnsupported) {}
// The original TF operation type. Used for diagnostic purposes.
string tensorflow_op;
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 7ba2603a95..1972246807 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -49,7 +49,7 @@ details::OperatorKey GetOperatorKey(
const ::toco::Operator& op,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
string custom_code;
- if (op.type == OperatorType::kTensorFlowUnsupported) {
+ if (op.type == OperatorType::kUnsupported) {
const TensorFlowUnsupportedOperator& unsupported_op =
static_cast<const TensorFlowUnsupportedOperator&>(op);
custom_code = unsupported_op.tensorflow_op;
@@ -211,7 +211,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
ordered_opcodes[op_index] =
CreateOperatorCode(*builder, builtin_ops[name], 0, op_version);
} else {
- // This could be a kTensorFlowUnsupported, in which case we should be
+ // This could be a kUnsupported, in which case we should be
// able to retrieve the original Tensorflow name from the OperatorKey, or
// this could be a proper TOCO operator that is completely unknown to TF
// Lite.
@@ -268,7 +268,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
: tflite_op_it->second.get();
// This is a custom op unless we can find it in ops_by_type, and even then
- // it could be a custom op (such as kTensorFlowUnsupported).
+ // it could be a custom op (such as kUnsupported).
auto options = Options::Custom(0);
std::vector<bool> mutating_input_variables;
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 098d2163e6..58ea5c725c 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -45,7 +45,7 @@ namespace details {
using TensorsMap = std::unordered_map<string, int>;
// A key to identify an operator.
-// Only when `type` is `kTensorFlowUnsupported`, `custom_code` is filled to
+// Only when `type` is `kUnsupported`, `custom_code` is filled to
// identify which operation is used.
struct OperatorKey {
OperatorKey(OperatorType type, const std::string& custom_code, int version)
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index 409e7d72a5..d1fdbcb8e9 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -73,8 +73,8 @@ TEST_F(ExportTest, LoadOperatorsMap) {
EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]);
- EXPECT_EQ(3, operators[details::OperatorKey(
- OperatorType::kTensorFlowUnsupported, "MyCrazyOp", 1)]);
+ EXPECT_EQ(3, operators[details::OperatorKey(OperatorType::kUnsupported,
+ "MyCrazyOp", 1)]);
}
TEST_F(ExportTest, Export) {
diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc
index cb44a5e6d7..d1867bd4fa 100644
--- a/tensorflow/contrib/lite/toco/tflite/import.cc
+++ b/tensorflow/contrib/lite/toco/tflite/import.cc
@@ -124,7 +124,7 @@ void ImportOperators(
new_op = ops_by_name.at(effective_opname)
->Deserialize(input_op->builtin_options(),
input_op->custom_options());
- if (new_op->type == OperatorType::kTensorFlowUnsupported) {
+ if (new_op->type == OperatorType::kUnsupported) {
auto* unsupported_op =
static_cast<TensorFlowUnsupportedOperator*>(new_op.get());
unsupported_op->tensorflow_op = opname;
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index fd6c849889..290a925c1e 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1114,8 +1114,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(new Pad(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
ops.emplace_back(
new PadV2(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
- ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE,
- OperatorType::kTensorFlowReshape));
+ ops.emplace_back(
+ new Reshape(::tflite::BuiltinOperator_RESHAPE, OperatorType::kReshape));
ops.emplace_back(
new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax));
ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH,
@@ -1126,14 +1126,13 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kTranspose));
ops.emplace_back(
new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
- ops.emplace_back(
- new Sum(::tflite::BuiltinOperator_SUM, OperatorType::kTensorFlowSum));
+ ops.emplace_back(new Sum(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR,
OperatorType::kResizeBilinear));
ops.emplace_back(
new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze));
- ops.emplace_back(new Split(::tflite::BuiltinOperator_SPLIT,
- OperatorType::kTensorFlowSplit));
+ ops.emplace_back(
+ new Split(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE,
OperatorType::kStridedSlice));
ops.emplace_back(
@@ -1145,28 +1144,27 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(
new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax));
ops.emplace_back(
- new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTensorFlowTile));
+ new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS,
OperatorType::kExpandDims));
ops.emplace_back(new TransposeConv(::tflite::BuiltinOperator_TRANSPOSE_CONV,
OperatorType::kTransposeConv));
ops.emplace_back(new SparseToDense(::tflite::BuiltinOperator_SPARSE_TO_DENSE,
OperatorType::kSparseToDense));
- ops.emplace_back(new Shape(::tflite::BuiltinOperator_SHAPE,
- OperatorType::kTensorFlowShape));
+ ops.emplace_back(
+ new Shape(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
// Custom Operators.
ops.emplace_back(
new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant));
- ops.emplace_back(new TensorFlowUnsupported(
- "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported));
+ ops.emplace_back(new TensorFlowUnsupported("TENSORFLOW_UNSUPPORTED",
+ OperatorType::kUnsupported));
// There operators are supported by Toco, but not by TF Lite, and has no
// attributes.
ops.emplace_back(
new SimpleOperator<AddNOperator>("ADDN", OperatorType::kAddN));
-
// Simple Operators.
ops.emplace_back(new SimpleOperator<DequantizeOperator>(
"DEQUANTIZE", OperatorType::kDequantize));
@@ -1188,21 +1186,21 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(new SimpleOperator<LogSoftmaxOperator>(
"LOG_SOFTMAX", OperatorType::kLogSoftmax));
ops.emplace_back(new SimpleOperator<TensorFlowMaximumOperator>(
- "MAXIMUM", OperatorType::kTensorFlowMaximum));
+ "MAXIMUM", OperatorType::kMaximum)); // Element-wise Maximum
ops.emplace_back(new SimpleOperator<TensorFlowMinimumOperator>(
- "MINIMUM", OperatorType::kTensorFlowMinimum));
+ "MINIMUM", OperatorType::kMinimum)); // Element-wise Minimum
ops.emplace_back(new SimpleOperator<TensorFlowGreaterOperator>(
- "GREATER", OperatorType::kTensorFlowGreater));
+ "GREATER", OperatorType::kGreater));
ops.emplace_back(new SimpleOperator<TensorFlowGreaterEqualOperator>(
- "GREATER_EQUAL", OperatorType::kTensorFlowGreaterEqual));
- ops.emplace_back(new SimpleOperator<TensorFlowLessOperator>(
- "LESS", OperatorType::kTensorFlowLess));
+ "GREATER_EQUAL", OperatorType::kGreaterEqual));
+ ops.emplace_back(
+ new SimpleOperator<TensorFlowLessOperator>("LESS", OperatorType::kLess));
ops.emplace_back(new SimpleOperator<TensorFlowLessEqualOperator>(
- "LESS_EQUAL", OperatorType::kTensorFlowLessEqual));
+ "LESS_EQUAL", OperatorType::kLessEqual));
ops.emplace_back(new SimpleOperator<TensorFlowEqualOperator>(
- "EQUAL", OperatorType::kTensorFlowEqual));
+ "EQUAL", OperatorType::kEqual));
ops.emplace_back(new SimpleOperator<TensorFlowNotEqualOperator>(
- "NOT_EQUAL", OperatorType::kTensorFlowNotEqual));
+ "NOT_EQUAL", OperatorType::kNotEqual));
ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg));
ops.emplace_back(
new SimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect));
@@ -1211,10 +1209,10 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
// Element-wise operator
ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin));
ops.emplace_back(new SimpleOperator<LogOperator>("LOG", OperatorType::kLog));
- ops.emplace_back(new SimpleOperator<TensorFlowSqrtOperator>(
- "SQRT", OperatorType::kTensorFlowSqrt));
+ ops.emplace_back(
+ new SimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt));
ops.emplace_back(new SimpleOperator<TensorFlowRsqrtOperator>(
- "RSQRT", OperatorType::kTensorFlowRsqrt));
+ "RSQRT", OperatorType::kRsqrt));
return ops;
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index bd881d079e..79c8e5d738 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -112,24 +112,20 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<LogSoftmaxOperator>("LOG_SOFTMAX",
OperatorType::kLogSoftmax);
CheckSimpleOperator<TensorFlowMaximumOperator>(
- "MAXIMUM", OperatorType::kTensorFlowMaximum);
+ "MAXIMUM", OperatorType::kMaximum); // Element-wise Maximum
CheckSimpleOperator<TensorFlowMinimumOperator>(
- "MINIMUM", OperatorType::kTensorFlowMinimum);
- CheckSimpleOperator<TensorFlowLessOperator>("LESS",
- OperatorType::kTensorFlowLess);
+ "MINIMUM", OperatorType::kMinimum); // Element-wise Minimum
+ CheckSimpleOperator<TensorFlowLessOperator>("LESS", OperatorType::kLess);
CheckSimpleOperator<NegOperator>("NEG", OperatorType::kNeg);
CheckSimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect);
CheckSimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice);
CheckSimpleOperator<SinOperator>("SIN", OperatorType::kSin);
- CheckSimpleOperator<TensorFlowEqualOperator>("EQUAL",
- OperatorType::kTensorFlowEqual);
- CheckSimpleOperator<TensorFlowNotEqualOperator>(
- "NOT_EQUAL", OperatorType::kTensorFlowNotEqual);
+ CheckSimpleOperator<TensorFlowEqualOperator>("EQUAL", OperatorType::kEqual);
+ CheckSimpleOperator<TensorFlowNotEqualOperator>("NOT_EQUAL",
+ OperatorType::kNotEqual);
CheckSimpleOperator<LogOperator>("LOG", OperatorType::kLog);
- CheckSimpleOperator<TensorFlowSqrtOperator>("SQRT",
- OperatorType::kTensorFlowSqrt);
- CheckSimpleOperator<TensorFlowRsqrtOperator>("RSQRT",
- OperatorType::kTensorFlowRsqrt);
+ CheckSimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt);
+ CheckSimpleOperator<TensorFlowRsqrtOperator>("RSQRT", OperatorType::kRsqrt);
}
TEST_F(OperatorTest, BuiltinAdd) {
@@ -258,7 +254,7 @@ TEST_F(OperatorTest, BuiltinReshape) {
TensorFlowReshapeOperator op;
op.shape = {1, 2, 4, 5, 8};
auto output_toco_op = SerializeAndDeserialize(
- GetOperator("RESHAPE", OperatorType::kTensorFlowReshape), op);
+ GetOperator("RESHAPE", OperatorType::kReshape), op);
EXPECT_EQ(op.shape, output_toco_op->shape);
}
@@ -281,8 +277,8 @@ TEST_F(OperatorTest, BuiltinSpaceToDepth) {
TEST_F(OperatorTest, CustomSplit) {
TensorFlowSplitOperator op;
op.num_split = 123;
- auto output_toco_op = SerializeAndDeserialize(
- GetOperator("SPLIT", OperatorType::kTensorFlowSplit), op);
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("SPLIT", OperatorType::kSplit), op);
EXPECT_EQ(op.num_split, output_toco_op->num_split);
}
@@ -434,8 +430,8 @@ TEST_F(OperatorTest, BuiltinTransposeConv) {
TEST_F(OperatorTest, BuiltinShape) {
TensorFlowShapeOperator op;
op.output_data_type = ArrayDataType::kInt64;
- auto output_toco_op = SerializeAndDeserialize(
- GetOperator("SHAPE", OperatorType::kTensorFlowShape), op);
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("SHAPE", OperatorType::kShape), op);
EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
}
@@ -467,10 +463,8 @@ TEST_F(OperatorTest, TensorFlowUnsupported) {
}
node_def.SerializeToString(&op.tensorflow_node_def);
- auto output_toco_op =
- SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED",
- OperatorType::kTensorFlowUnsupported),
- op);
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
::tensorflow::NodeDef output_node_def;
output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
@@ -493,10 +487,8 @@ TEST_F(OperatorTest, TensorFlowUnsupported) {
TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
TensorFlowUnsupportedOperator op;
op.tensorflow_op = "MyCustomUnsupportedOp";
- auto output_toco_op =
- SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED",
- OperatorType::kTensorFlowUnsupported),
- op);
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
::tensorflow::NodeDef output_node_def;
output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 3173d524b7..2534d1ef2a 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -34,11 +34,11 @@ limitations under the License.
namespace toco {
namespace {
-// CHECK-fails if the model contains a kTensorFlowUnsupported operation.
+// CHECK-fails if the model contains a kUnsupported operation.
void CheckUnsupportedOperations(const Model& model) {
std::set<string> unsupported_ops;
for (auto& op : model.operators) {
- if (op->type == OperatorType::kTensorFlowUnsupported) {
+ if (op->type == OperatorType::kUnsupported) {
unsupported_ops.insert(
static_cast<const TensorFlowUnsupportedOperator*>(op.get())
->tensorflow_op);
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 92bab5246c..a52c812ef4 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -338,23 +338,23 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Div)
HANDLE_OPERATORTYPENAME_CASE(Tanh)
HANDLE_OPERATORTYPENAME_CASE(Sin)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert)
+ HANDLE_OPERATORTYPENAME_CASE(All)
+ HANDLE_OPERATORTYPENAME_CASE(Assert)
HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
HANDLE_OPERATORTYPENAME_CASE(Fill)
HANDLE_OPERATORTYPENAME_CASE(FloorMod)
HANDLE_OPERATORTYPENAME_CASE(FloorDiv)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreater)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreaterEqual)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowIdentity)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowLess)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowLessEqual)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowMatMul)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowMax)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowMaximum)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowMerge)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowMin)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum)
+ HANDLE_OPERATORTYPENAME_CASE(Greater)
+ HANDLE_OPERATORTYPENAME_CASE(GreaterEqual)
+ HANDLE_OPERATORTYPENAME_CASE(Identity)
+ HANDLE_OPERATORTYPENAME_CASE(Less)
+ HANDLE_OPERATORTYPENAME_CASE(LessEqual)
+ HANDLE_OPERATORTYPENAME_CASE(MatMul)
+ HANDLE_OPERATORTYPENAME_CASE(Max) // Reduction Max
+ HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum
+ HANDLE_OPERATORTYPENAME_CASE(Merge)
+ HANDLE_OPERATORTYPENAME_CASE(Min) // Reduction Min
+ HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum
HANDLE_OPERATORTYPENAME_CASE(Neg)
HANDLE_OPERATORTYPENAME_CASE(Pad)
HANDLE_OPERATORTYPENAME_CASE(PadV2)
@@ -362,22 +362,22 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Stack)
HANDLE_OPERATORTYPENAME_CASE(Range)
HANDLE_OPERATORTYPENAME_CASE(Rank)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowReshape)
+ HANDLE_OPERATORTYPENAME_CASE(Reshape)
HANDLE_OPERATORTYPENAME_CASE(Squeeze)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowRsqrt)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowShape)
+ HANDLE_OPERATORTYPENAME_CASE(Rsqrt)
+ HANDLE_OPERATORTYPENAME_CASE(Shape)
HANDLE_OPERATORTYPENAME_CASE(Slice)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowSplit)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowSqrt)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowSquare)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowSwitch)
+ HANDLE_OPERATORTYPENAME_CASE(Split)
+ HANDLE_OPERATORTYPENAME_CASE(Sqrt)
+ HANDLE_OPERATORTYPENAME_CASE(Square)
+ HANDLE_OPERATORTYPENAME_CASE(Switch)
HANDLE_OPERATORTYPENAME_CASE(Sub)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowSum)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowTile)
+ HANDLE_OPERATORTYPENAME_CASE(Sum)
+ HANDLE_OPERATORTYPENAME_CASE(Tile)
HANDLE_OPERATORTYPENAME_CASE(Transpose)
HANDLE_OPERATORTYPENAME_CASE(TransposeConv)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcat)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcatV2)
+ HANDLE_OPERATORTYPENAME_CASE(Concat)
+ HANDLE_OPERATORTYPENAME_CASE(ConcatV2)
HANDLE_OPERATORTYPENAME_CASE(Cast)
HANDLE_OPERATORTYPENAME_CASE(Floor)
HANDLE_OPERATORTYPENAME_CASE(Gather)
@@ -388,14 +388,14 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Svdf)
HANDLE_OPERATORTYPENAME_CASE(ArgMax)
HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported)
+ HANDLE_OPERATORTYPENAME_CASE(Unsupported)
HANDLE_OPERATORTYPENAME_CASE(Exp)
HANDLE_OPERATORTYPENAME_CASE(DynamicPartition)
HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
HANDLE_OPERATORTYPENAME_CASE(Select)
HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowEqual)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowNotEqual)
+ HANDLE_OPERATORTYPENAME_CASE(Equal)
+ HANDLE_OPERATORTYPENAME_CASE(NotEqual)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
@@ -403,7 +403,7 @@ const char* OperatorTypeName(OperatorType type) {
}
string HelpfulOperatorTypeName(const Operator& op) {
- if (op.type == OperatorType::kTensorFlowUnsupported) {
+ if (op.type == OperatorType::kUnsupported) {
return toco::port::StringF(
"(Unsupported TensorFlow op: %s)",
static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op);
@@ -413,16 +413,20 @@ string HelpfulOperatorTypeName(const Operator& op) {
bool OperatorSupportsFusedActivation(OperatorType type) {
switch (type) {
- case OperatorType::kConcatenation:
- case OperatorType::kFakeQuant:
- case OperatorType::kGather:
- case OperatorType::kSlice:
- case OperatorType::kSqueeze:
- case OperatorType::kTensorFlowReshape:
- case OperatorType::kTensorFlowSplit:
- return false;
- default:
+ case OperatorType::kAdd:
+ case OperatorType::kAveragePool:
+ case OperatorType::kBatchNormalization:
+ case OperatorType::kConv:
+ case OperatorType::kDepthwiseConv:
+ case OperatorType::kDiv:
+ case OperatorType::kFullyConnected:
+ case OperatorType::kL2Pool:
+ case OperatorType::kMaxPool:
+ case OperatorType::kMul:
+ case OperatorType::kSub:
return true;
+ default:
+ return false;
}
}
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index 7681ce9d39..791ced8d01 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -101,6 +101,8 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
const char* OperatorTypeName(OperatorType type);
string HelpfulOperatorTypeName(const Operator& op);
+// Whether the operator can be fused with an activation function. Note that this
+// will return false by default for new operators; fusing support is opt-in.
bool OperatorSupportsFusedActivation(OperatorType type);
void DumpGraphvizVideoFrame(const Model& model);
diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc
index a683867374..8609e5bedd 100644
--- a/tensorflow/contrib/lite/toco/tooling_util_test.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc
@@ -175,4 +175,10 @@ TEST(NumElementsTest, UnsignedInt64) {
EXPECT_EQ(status.error_message(), kLargeTensorMessage);
}
+TEST(FusedActivationTest, DefaultsToUnfused) {
+ EXPECT_TRUE(OperatorSupportsFusedActivation(OperatorType::kAdd));
+ EXPECT_FALSE(OperatorSupportsFusedActivation(OperatorType::kNone));
+ EXPECT_FALSE(OperatorSupportsFusedActivation(static_cast<OperatorType>(255)));
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD
index 8857062c00..183a545295 100644
--- a/tensorflow/contrib/lite/tools/benchmark/BUILD
+++ b/tensorflow/contrib/lite/tools/benchmark/BUILD
@@ -67,6 +67,16 @@ cc_library(
)
cc_library(
+ name = "benchmark_params",
+ srcs = [
+ "benchmark_params.cc",
+ "logging.h",
+ ],
+ hdrs = ["benchmark_params.h"],
+ copts = common_copts,
+)
+
+cc_library(
name = "benchmark_model_lib",
srcs = [
"benchmark_model.cc",
@@ -75,6 +85,7 @@ cc_library(
hdrs = ["benchmark_model.h"],
copts = common_copts,
deps = [
+ ":benchmark_params",
":command_line_flags",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc
index a8a9a6112c..08648bcfe2 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc
@@ -48,6 +48,19 @@ namespace tflite {
namespace benchmark {
using tensorflow::Stat;
+BenchmarkParams BenchmarkModel::DefaultParams() {
+ BenchmarkParams params;
+ params.AddParam("num_runs", BenchmarkParam::Create<int32_t>(50));
+ params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
+ params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
+ params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
+ params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
+ params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
+ return params;
+}
+
+BenchmarkModel::BenchmarkModel() : params_(DefaultParams()) {}
+
void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) {
auto inference_us = results.inference_time_us();
auto init_us = results.startup_latency_us();
@@ -60,24 +73,29 @@ void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) {
std::vector<Flag> BenchmarkModel::GetFlags() {
return {
- Flag("num_runs", &params_.num_runs, "number of runs"),
- Flag("run_delay", &params_.run_delay, "delay between runs in seconds"),
- Flag("num_threads", &params_.num_threads, "number of threads"),
- Flag("benchmark_name", &params_.benchmark_name, "benchmark name"),
- Flag("output_prefix", &params_.output_prefix, "benchmark output prefix"),
- Flag("warmup_runs", &params_.warmup_runs,
- "how many runs to initialize model"),
+ CreateFlag<int32_t>("num_runs", &params_, "number of runs"),
+ CreateFlag<float>("run_delay", &params_, "delay between runs in seconds"),
+ CreateFlag<int32_t>("num_threads", &params_, "number of threads"),
+ CreateFlag<std::string>("benchmark_name", &params_, "benchmark name"),
+ CreateFlag<std::string>("output_prefix", &params_,
+ "benchmark output prefix"),
+ CreateFlag<int32_t>("warmup_runs", &params_,
+ "how many runs to initialize model"),
};
}
void BenchmarkModel::LogFlags() {
- TFLITE_LOG(INFO) << "Num runs: [" << params_.num_runs << "]";
- TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" << params_.run_delay
+ TFLITE_LOG(INFO) << "Num runs: [" << params_.Get<int32_t>("num_runs") << "]";
+ TFLITE_LOG(INFO) << "Inter-run delay (seconds): ["
+ << params_.Get<float>("run_delay") << "]";
+ TFLITE_LOG(INFO) << "Num threads: [" << params_.Get<int32_t>("num_threads")
+ << "]";
+ TFLITE_LOG(INFO) << "Benchmark name: ["
+ << params_.Get<std::string>("benchmark_name") << "]";
+ TFLITE_LOG(INFO) << "Output prefix: ["
+ << params_.Get<std::string>("output_prefix") << "]";
+ TFLITE_LOG(INFO) << "Warmup runs: [" << params_.Get<int32_t>("warmup_runs")
<< "]";
- TFLITE_LOG(INFO) << "Num threads: [" << params_.num_threads << "]";
- TFLITE_LOG(INFO) << "Benchmark name: [" << params_.benchmark_name << "]";
- TFLITE_LOG(INFO) << "Output prefix: [" << params_.output_prefix << "]";
- TFLITE_LOG(INFO) << "Warmup runs: [" << params_.warmup_runs << "]";
}
Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) {
@@ -91,7 +109,7 @@ Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) {
listeners_.OnSingleRunEnd();
run_stats.UpdateStat(end_us - start_us);
- SleepForSeconds(params_.run_delay);
+ SleepForSeconds(params_.Get<float>("run_delay"));
}
std::stringstream stream;
@@ -117,8 +135,10 @@ void BenchmarkModel::Run(int argc, char **argv) {
<< "ms";
uint64_t input_bytes = ComputeInputBytes();
- Stat<int64_t> warmup_time_us = Run(params_.warmup_runs, WARMUP);
- Stat<int64_t> inference_time_us = Run(params_.num_runs, REGULAR);
+ Stat<int64_t> warmup_time_us =
+ Run(params_.Get<int32_t>("warmup_runs"), WARMUP);
+ Stat<int64_t> inference_time_us =
+ Run(params_.Get<int32_t>("num_runs"), REGULAR);
listeners_.OnBenchmarkEnd(
{startup_latency_us, input_bytes, warmup_time_us, inference_time_us});
}
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
index d48f693693..942e21f67a 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h"
#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h"
#include "tensorflow/core/util/stats_calculator.h"
@@ -63,17 +64,6 @@ class BenchmarkResults {
tensorflow::Stat<int64_t> inference_time_us_;
};
-struct BenchmarkParams {
- BenchmarkParams()
- : num_runs(50), warmup_runs(1), run_delay(-1.0), num_threads(1) {}
- int num_runs;
- int warmup_runs;
- float run_delay;
- int num_threads;
- std::string benchmark_name;
- std::string output_prefix;
-};
-
class BenchmarkListener {
public:
virtual void OnBenchmarkStart(const BenchmarkParams& params) {}
@@ -130,12 +120,22 @@ class BenchmarkLoggingListener : public BenchmarkListener {
void OnBenchmarkEnd(const BenchmarkResults& results) override;
};
+template <typename T>
+Flag CreateFlag(const char* name, BenchmarkParams* params,
+ const std::string& usage) {
+ return Flag(name, [params, name](const T& val) { params->Set<T>(name, val); },
+ params->Get<T>(name), usage);
+}
+
// Benchmarks a model.
//
// Subclasses need to implement initialization and running of the model.
// The results can be collected by adding BenchmarkListener(s).
class BenchmarkModel {
public:
+ static BenchmarkParams DefaultParams();
+ BenchmarkModel();
+ BenchmarkModel(BenchmarkParams params) : params_(std::move(params)) {}
virtual ~BenchmarkModel() {}
bool ParseFlags(int argc, char** argv);
virtual void Init() = 0;
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc
new file mode 100644
index 0000000000..1dcf580a9d
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h"
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/tools/benchmark/logging.h"
+
+namespace tflite {
+namespace benchmark {
+
+void BenchmarkParam::AssertHasSameType(BenchmarkParam::ParamType a,
+ BenchmarkParam::ParamType b) {
+ TFLITE_BENCHMARK_CHECK(a == b) << "Type mismatch while accessing parameter.";
+}
+
+template <>
+BenchmarkParam::ParamType BenchmarkParam::GetValueType<int32_t>() {
+ return BenchmarkParam::ParamType::TYPE_INT32;
+}
+
+template <>
+BenchmarkParam::ParamType BenchmarkParam::GetValueType<bool>() {
+ return BenchmarkParam::ParamType::TYPE_BOOL;
+}
+
+template <>
+BenchmarkParam::ParamType BenchmarkParam::GetValueType<float>() {
+ return BenchmarkParam::ParamType::TYPE_FLOAT;
+}
+
+template <>
+BenchmarkParam::ParamType BenchmarkParam::GetValueType<std::string>() {
+ return BenchmarkParam::ParamType::TYPE_STRING;
+}
+
+void BenchmarkParams::AssertParamExists(const std::string& name) const {
+ TFLITE_BENCHMARK_CHECK(HasParam(name)) << name << " was not found.";
+}
+
+} // namespace benchmark
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h
new file mode 100644
index 0000000000..33448dd162
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/tools/benchmark/logging.h"
+
+namespace tflite {
+namespace benchmark {
+
+template <typename T>
+class TypedBenchmarkParam;
+
+class BenchmarkParam {
+ protected:
+ enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING };
+
+ public:
+ template <typename T>
+ static std::unique_ptr<BenchmarkParam> Create(const T& default_value) {
+ return std::unique_ptr<BenchmarkParam>(
+ new TypedBenchmarkParam<T>(default_value));
+ }
+
+ template <typename T>
+ TypedBenchmarkParam<T>* AsTyped() {
+ AssertHasSameType(GetValueType<T>(), type_);
+ return static_cast<TypedBenchmarkParam<T>*>(this);
+ }
+ virtual ~BenchmarkParam() {}
+ BenchmarkParam(ParamType type) : type_(type) {}
+
+ private:
+ static void AssertHasSameType(ParamType a, ParamType b);
+ template <typename T>
+ static ParamType GetValueType();
+
+ const ParamType type_;
+};
+
+template <typename T>
+class TypedBenchmarkParam : public BenchmarkParam {
+ public:
+ TypedBenchmarkParam(const T& value)
+ : BenchmarkParam(GetValueType<T>()), value_(value) {}
+ void Set(const T& value) { value_ = value; }
+
+ T Get() { return value_; }
+
+ private:
+ T value_;
+};
+
+class BenchmarkParams {
+ public:
+ void AddParam(const std::string& name,
+ std::unique_ptr<BenchmarkParam> value) {
+ params_[name] = std::move(value);
+ }
+
+ bool HasParam(const std::string& name) const {
+ return params_.find(name) != params_.end();
+ }
+
+ template <typename T>
+ void Set(const std::string& name, const T& value) {
+ AssertParamExists(name);
+ params_.at(name)->AsTyped<T>()->Set(value);
+ }
+
+ template <typename T>
+ T Get(const std::string& name) const {
+ AssertParamExists(name);
+ return params_.at(name)->AsTyped<T>()->Get();
+ }
+
+ private:
+ void AssertParamExists(const std::string& name) const;
+ std::unordered_map<std::string, std::unique_ptr<BenchmarkParam>> params_;
+};
+
+} // namespace benchmark
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index 5f803cec19..73affc26b0 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -162,15 +162,37 @@ bool PopulateInputLayerInfo(
return true;
}
+BenchmarkParams GetDefaultParams() {
+ BenchmarkParams default_params = BenchmarkModel::DefaultParams();
+ default_params.AddParam("graph", BenchmarkParam::Create<std::string>(""));
+ default_params.AddParam("input_layer",
+ BenchmarkParam::Create<std::string>(""));
+ default_params.AddParam("input_layer_shape",
+ BenchmarkParam::Create<std::string>(""));
+ default_params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(false));
+ return default_params;
+}
+
} // namespace
+BenchmarkTfLiteModel::BenchmarkTfLiteModel()
+ : BenchmarkModel(GetDefaultParams()) {
+ AddListener(&profiling_listener_);
+}
+
+BenchmarkTfLiteModel::BenchmarkTfLiteModel(BenchmarkParams params)
+ : BenchmarkModel(std::move(params)) {
+ AddListener(&profiling_listener_);
+}
+
std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
std::vector<Flag> flags = BenchmarkTfLiteModel::BenchmarkModel::GetFlags();
std::vector<Flag> specific_flags = {
- Flag("graph", &graph, "graph file name"),
- Flag("input_layer", &input_layer_string, "input layer names"),
- Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"),
- Flag("use_nnapi", &use_nnapi, "use nnapi api")};
+ CreateFlag<std::string>("graph", &params_, "graph file name"),
+ CreateFlag<std::string>("input_layer", &params_, "input layer names"),
+ CreateFlag<std::string>("input_layer_shape", &params_,
+ "input layer shape"),
+ CreateFlag<bool>("use_nnapi", &params_, "use nnapi api")};
flags.insert(flags.end(), specific_flags.begin(), specific_flags.end());
return flags;
@@ -178,19 +200,22 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
void BenchmarkTfLiteModel::LogFlags() {
BenchmarkModel::LogFlags();
- TFLITE_LOG(INFO) << "Graph: [" << graph << "]";
- TFLITE_LOG(INFO) << "Input layers: [" << input_layer_string << "]";
- TFLITE_LOG(INFO) << "Input shapes: [" << input_layer_shape_string << "]";
- TFLITE_LOG(INFO) << "Use nnapi : [" << use_nnapi << "]";
+ TFLITE_LOG(INFO) << "Graph: [" << params_.Get<std::string>("graph") << "]";
+ TFLITE_LOG(INFO) << "Input layers: ["
+ << params_.Get<std::string>("input_layer") << "]";
+ TFLITE_LOG(INFO) << "Input shapes: ["
+ << params_.Get<std::string>("input_layer_shape") << "]";
+ TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get<bool>("use_nnapi") << "]";
}
bool BenchmarkTfLiteModel::ValidateFlags() {
- if (graph.empty()) {
+ if (params_.Get<std::string>("graph").empty()) {
TFLITE_LOG(ERROR)
<< "Please specify the name of your TF Lite input file with --graph";
return false;
}
- return PopulateInputLayerInfo(input_layer_string, input_layer_shape_string,
+ return PopulateInputLayerInfo(params_.Get<std::string>("input_layer"),
+ params_.Get<std::string>("input_layer_shape"),
&inputs);
}
@@ -205,6 +230,7 @@ uint64_t BenchmarkTfLiteModel::ComputeInputBytes() {
}
void BenchmarkTfLiteModel::Init() {
+ std::string graph = params_.Get<std::string>("graph");
model = tflite::FlatBufferModel::BuildFromFile(graph.c_str());
if (!model) {
TFLITE_LOG(FATAL) << "Failed to mmap model " << graph;
@@ -226,10 +252,14 @@ void BenchmarkTfLiteModel::Init() {
}
profiling_listener_.SetInterpreter(interpreter.get());
- if (params_.num_threads != -1) {
- interpreter->SetNumThreads(params_.num_threads);
+ const int32_t num_threads = params_.Get<int32_t>("num_threads");
+
+ if (num_threads != -1) {
+ interpreter->SetNumThreads(num_threads);
}
+ bool use_nnapi = params_.Get<bool>("use_nnapi");
+
interpreter->UseNNAPI(use_nnapi);
auto interpreter_inputs = interpreter->inputs();
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index ffb93da964..50cc3f24b3 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -50,9 +50,8 @@ class ProfilingListener : public BenchmarkListener {
// Benchmarks a TFLite model by running tflite interpreter.
class BenchmarkTfLiteModel : public BenchmarkModel {
public:
- BenchmarkTfLiteModel() : use_nnapi(false) {
- AddListener(&profiling_listener_);
- }
+ BenchmarkTfLiteModel();
+ BenchmarkTfLiteModel(BenchmarkParams params);
std::vector<Flag> GetFlags() override;
void LogFlags() override;
@@ -70,13 +69,7 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
private:
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
- std::string graph;
- std::string input_layer_string;
- std::string input_layer_type_string;
- std::string input_layer_shape_string;
- std::string input_layer_values_string;
std::vector<InputLayerInfo> inputs;
- bool use_nnapi;
ProfilingListener profiling_listener_;
};
diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc
index 8195fc44be..ff818b9dcb 100644
--- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <cstring>
#include <sstream>
#include <string>
+#include <utility>
#include <vector>
namespace tflite {
@@ -44,76 +45,79 @@ bool ParseFlag(const std::string& arg, const std::string& flag,
}
template <typename T>
-bool ParseFlag(const std::string& flag_value, T* value) {
+bool ParseFlag(const std::string& flag_value,
+ const std::function<void(const T&)>& hook) {
std::istringstream stream(flag_value);
T read_value;
stream >> read_value;
if (!stream.eof() && !stream.good()) {
return false;
}
- *value = read_value;
+ hook(read_value);
return true;
}
-bool ParseBoolFlag(const std::string& flag_value, bool* value) {
+bool ParseBoolFlag(const std::string& flag_value,
+ const std::function<void(const bool&)>& hook) {
if (flag_value != "true" && flag_value != "false") {
return false;
}
- *value = (flag_value == "true");
+ hook(flag_value == "true");
return true;
}
-
-bool ParseStringFlag(const std::string& flag_value, std::string* value) {
- *value = flag_value;
- return true;
-}
-
} // namespace
-Flag::Flag(const char* name, int32_t* dst, const std::string& usage_text)
+Flag::Flag(const char* name, const std::function<void(const int32_t&)>& hook,
+ int32_t default_value, const std::string& usage_text)
: name_(name),
type_(TYPE_INT32),
- value_hook_([dst](const std::string& flag_value) {
- return ParseFlag<int32_t>(flag_value, dst);
+ value_hook_([hook](const std::string& flag_value) {
+ return ParseFlag<int32_t>(flag_value, hook);
}),
- default_for_display_(ToString(*dst)),
+ default_for_display_(ToString(default_value)),
usage_text_(usage_text) {}
-Flag::Flag(const char* name, int64_t* dst, const std::string& usage_text)
+Flag::Flag(const char* name, const std::function<void(const int64_t&)>& hook,
+ int64_t default_value, const std::string& usage_text)
: name_(name),
type_(TYPE_INT64),
- value_hook_([dst](const std::string& flag_value) {
- return ParseFlag<int64_t>(flag_value, dst);
+ value_hook_([hook](const std::string& flag_value) {
+ return ParseFlag<int64_t>(flag_value, hook);
}),
- default_for_display_(ToString(*dst)),
+ default_for_display_(ToString(default_value)),
usage_text_(usage_text) {}
-Flag::Flag(const char* name, float* dst, const std::string& usage_text)
+Flag::Flag(const char* name, const std::function<void(const float&)>& hook,
+ float default_value, const std::string& usage_text)
: name_(name),
type_(TYPE_FLOAT),
- value_hook_([dst](const std::string& flag_value) {
- return ParseFlag<float>(flag_value, dst);
+ value_hook_([hook](const std::string& flag_value) {
+ return ParseFlag<float>(flag_value, hook);
}),
- default_for_display_(ToString(*dst)),
+ default_for_display_(ToString(default_value)),
usage_text_(usage_text) {}
-Flag::Flag(const char* name, bool* dst, const std::string& usage_text)
+Flag::Flag(const char* name, const std::function<void(const bool&)>& hook,
+ bool default_value, const std::string& usage_text)
: name_(name),
type_(TYPE_BOOL),
- value_hook_([dst](const std::string& flag_value) {
- return ParseBoolFlag(flag_value, dst);
+ value_hook_([hook](const std::string& flag_value) {
+ return ParseBoolFlag(flag_value, hook);
}),
- default_for_display_((*dst) ? "true" : "false"),
+ default_for_display_(default_value ? "true" : "false"),
usage_text_(usage_text) {}
-Flag::Flag(const char* name, std::string* dst, const std::string& usage_text)
+Flag::Flag(const char* name,
+ const std::function<void(const std::string&)>& hook,
+ const std::string& default_value, const std::string& usage_text)
: name_(name),
type_(TYPE_STRING),
- value_hook_([dst](const std::string& flag_value) {
- return ParseStringFlag(flag_value, dst);
+ value_hook_([hook](const std::string& flag_value) {
+ hook(flag_value);
+ return true;
}),
- default_for_display_(*dst),
+ default_for_display_(default_value),
usage_text_(usage_text) {}
bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const {
diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h
index 36f9e64767..2e514ae3ea 100644
--- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h
+++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h
@@ -33,10 +33,11 @@ namespace tflite {
// int some_int = 10;
// bool some_switch = false;
// std::string some_name = "something";
+//
// std::vector<tensorFlow::Flag> flag_list = {
-// Flag("some_int", &some_int, "an integer that affects X"),
-// Flag("some_switch", &some_switch, "a bool that affects Y"),
-// Flag("some_name", &some_name, "a std::string that affects Z")
+// Flag::CreateFlag("some_int", &some_int, "an integer that affects X"),
+// Flag::CreateFlag("some_switch", &some_switch, "a bool that affects Y"),
+// Flag::CreateFlag("some_name", &some_name, "a string that affects Z")
// };
// // Get usage message before ParseFlags() to capture default values.
// std::string usage = Flag::Usage(argv[0], flag_list);
@@ -63,11 +64,21 @@ namespace tflite {
// text, and a pointer to the corresponding variable.
class Flag {
public:
- Flag(const char* name, int32_t* dst, const std::string& usage_text);
- Flag(const char* name, int64_t* dst, const std::string& usage_text);
- Flag(const char* name, bool* dst, const std::string& usage_text);
- Flag(const char* name, std::string* dst, const std::string& usage_text);
- Flag(const char* name, float* dst, const std::string& usage_text);
+ template <typename T>
+ static Flag CreateFlag(const char* name, T* val, const char* usage) {
+ return Flag(name, [val](const T& v) { *val = v; }, *val, usage);
+ }
+
+ Flag(const char* name, const std::function<void(const int32_t&)>& hook,
+ int32_t default_value, const std::string& usage_text);
+ Flag(const char* name, const std::function<void(const int64_t&)>& hook,
+ int64_t default_value, const std::string& usage_text);
+ Flag(const char* name, const std::function<void(const float&)>& hook,
+ float default_value, const std::string& usage_text);
+ Flag(const char* name, const std::function<void(const bool&)>& hook,
+ bool default_value, const std::string& usage_text);
+ Flag(const char* name, const std::function<void(const std::string&)>& hook,
+ const std::string& default_value, const std::string& usage_text);
private:
friend class Flags;
diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc
index 620d61b027..03da805109 100644
--- a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc
@@ -34,15 +34,15 @@ TEST(CommandLineFlagsTest, BasicUsage) {
"--some_name=somethingelse",
"--some_float=42.0"};
int argc = 6;
- bool parsed_ok =
- Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
- {
- Flag("some_int32", &some_int32, "some int32"),
- Flag("some_int64", &some_int64, "some int64"),
- Flag("some_switch", &some_switch, "some switch"),
- Flag("some_name", &some_name, "some name"),
- Flag("some_float", &some_float, "some float"),
- });
+ bool parsed_ok = Flags::Parse(
+ &argc, reinterpret_cast<const char**>(argv_strings),
+ {
+ Flag::CreateFlag("some_int32", &some_int32, "some int32"),
+ Flag::CreateFlag("some_int64", &some_int64, "some int64"),
+ Flag::CreateFlag("some_switch", &some_switch, "some switch"),
+ Flag::CreateFlag("some_name", &some_name, "some name"),
+ Flag::CreateFlag("some_float", &some_float, "some float"),
+ });
EXPECT_EQ(true, parsed_ok);
EXPECT_EQ(20, some_int32);
@@ -57,9 +57,9 @@ TEST(CommandLineFlagsTest, EmptyStringFlag) {
int argc = 2;
std::string some_string = "invalid";
const char* argv_strings[] = {"program_name", "--some_string="};
- bool parsed_ok =
- Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
- {Flag("some_string", &some_string, "some string")});
+ bool parsed_ok = Flags::Parse(
+ &argc, reinterpret_cast<const char**>(argv_strings),
+ {Flag::CreateFlag("some_string", &some_string, "some string")});
EXPECT_EQ(true, parsed_ok);
EXPECT_EQ(some_string, "");
@@ -72,7 +72,7 @@ TEST(CommandLineFlagsTest, BadIntValue) {
const char* argv_strings[] = {"program_name", "--some_int=notanumber"};
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
- {Flag("some_int", &some_int, "some int")});
+ {Flag::CreateFlag("some_int", &some_int, "some int")});
EXPECT_EQ(false, parsed_ok);
EXPECT_EQ(10, some_int);
@@ -83,9 +83,9 @@ TEST(CommandLineFlagsTest, BadBoolValue) {
bool some_switch = false;
int argc = 2;
const char* argv_strings[] = {"program_name", "--some_switch=notabool"};
- bool parsed_ok =
- Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
- {Flag("some_switch", &some_switch, "some switch")});
+ bool parsed_ok = Flags::Parse(
+ &argc, reinterpret_cast<const char**>(argv_strings),
+ {Flag::CreateFlag("some_switch", &some_switch, "some switch")});
EXPECT_EQ(false, parsed_ok);
EXPECT_EQ(false, some_switch);
@@ -98,7 +98,7 @@ TEST(CommandLineFlagsTest, BadFloatValue) {
const char* argv_strings[] = {"program_name", "--some_float=notanumber"};
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
- {Flag("some_float", &some_float, "some float")});
+ {Flag::CreateFlag("some_float", &some_float, "some float")});
EXPECT_EQ(false, parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
@@ -136,10 +136,11 @@ TEST(CommandLineFlagsTest, UsageString) {
// match against, and we don't want a flakey test.
const std::string tool_name = "some_tool_name";
std::string usage = Flags::Usage(
- tool_name + " <flags>", {Flag("some_int", &some_int, "some int"),
- Flag("some_int64", &some_int64, "some int64"),
- Flag("some_switch", &some_switch, "some switch"),
- Flag("some_name", &some_name, "some name")});
+ tool_name + " <flags>",
+ {Flag::CreateFlag("some_int", &some_int, "some int"),
+ Flag::CreateFlag("some_int64", &some_int64, "some int64"),
+ Flag::CreateFlag("some_switch", &some_switch, "some switch"),
+ Flag::CreateFlag("some_name", &some_name, "some name")});
// Match the usage message, being sloppy about whitespace.
const char* expected_usage =
" usage: some_tool_name <flags>\n"
diff --git a/tensorflow/contrib/makefile/build_all_android.sh b/tensorflow/contrib/makefile/build_all_android.sh
index fc88f59e09..fb9e77ae1b 100755
--- a/tensorflow/contrib/makefile/build_all_android.sh
+++ b/tensorflow/contrib/makefile/build_all_android.sh
@@ -30,6 +30,14 @@ arm64-v8a armeabi armeabi-v7a mips mips64 x86 x86_64 tegra)"
exit 1
}
+echo "********************************************************************"
+echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference."
+echo "You are currently using an older version. Please switch over to TensorFlow Lite."
+echo ""
+echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite"
+echo "********************************************************************"
+echo ""
+
if [[ -z "${NDK_ROOT}" ]]; then
echo "NDK_ROOT should be set as an environment variable" 1>&2
exit 1
diff --git a/tensorflow/contrib/makefile/build_all_ios.sh b/tensorflow/contrib/makefile/build_all_ios.sh
index 0a458a27b3..1d4677ef4b 100755
--- a/tensorflow/contrib/makefile/build_all_ios.sh
+++ b/tensorflow/contrib/makefile/build_all_ios.sh
@@ -31,6 +31,14 @@ usage() {
exit 1
}
+echo "********************************************************************"
+echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference."
+echo "You are currently using an older version. Please switch over to TensorFlow Lite."
+echo ""
+echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite"
+echo "********************************************************************"
+echo ""
+
DEFAULT_ARCH="i386 x86_64 armv7 armv7s arm64"
while getopts "a:g:T" opt_name; do
case "$opt_name" in
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index a6be2084aa..b14202ff9e 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -1064,7 +1064,7 @@ def streaming_auc(predictions,
name=name)
-def _compute_dynamic_auc(labels, predictions, curve='ROC'):
+def _compute_dynamic_auc(labels, predictions, curve='ROC', weights=None):
"""Computes the apporixmate AUC by a Riemann sum with data-derived thresholds.
Computes the area under the ROC or PR curve using each prediction as a
@@ -1077,13 +1077,22 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'):
predictions: A 1-D `Tensor` of predictions whose values are `float64`.
curve: The name of the curve to be computed, 'ROC' for the Receiving
Operating Characteristic or 'PR' for the Precision-Recall curve.
+ weights: A 1-D `Tensor` of weights whose values are `float64`.
Returns:
A scalar `Tensor` containing the area-under-curve value for the input.
"""
- # Count the total number of positive and negative labels in the input.
+ # Compute the total weight and the total positive weight.
size = array_ops.size(predictions)
- total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32)
+ if weights is None:
+ weights = array_ops.ones_like(labels, dtype=dtypes.float64)
+ labels, predictions, weights = metrics_impl._remove_squeezable_dimensions(
+ labels, predictions, weights)
+ total_weight = math_ops.reduce_sum(weights)
+ total_positive = math_ops.reduce_sum(
+ array_ops.where(
+ math_ops.greater(labels, 0), weights,
+ array_ops.zeros_like(labels, dtype=dtypes.float64)))
def continue_computing_dynamic_auc():
"""Continues dynamic auc computation, entered if labels are not all equal.
@@ -1091,9 +1100,11 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'):
Returns:
A scalar `Tensor` containing the area-under-curve value.
"""
- # Sort the predictions descending, and the corresponding labels as well.
+ # Sort the predictions descending, keeping the same order for the
+ # corresponding labels and weights.
ordered_predictions, indices = nn.top_k(predictions, k=size)
ordered_labels = array_ops.gather(labels, indices)
+ ordered_weights = array_ops.gather(weights, indices)
# Get the counts of the unique ordered predictions.
_, _, counts = array_ops.unique_with_counts(ordered_predictions)
@@ -1103,23 +1114,39 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'):
array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32)
# Count the positives to the left of the split indices.
- positives = math_ops.cast(
- array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]),
- dtypes.int32)
- true_positives = array_ops.gather(positives, splits)
+ true_positives = array_ops.gather(
+ array_ops.pad(
+ math_ops.cumsum(
+ array_ops.where(
+ math_ops.greater(ordered_labels, 0), ordered_weights,
+ array_ops.zeros_like(ordered_labels,
+ dtype=dtypes.float64))),
+ paddings=[[1, 0]]), splits)
if curve == 'ROC':
- # Count the negatives to the left of every split point and the total
- # number of negatives for computing the FPR.
- false_positives = math_ops.subtract(splits, true_positives)
- total_negative = size - total_positive
+ # Compute the weight of the negatives to the left of every split point and
+ # the total weight of the negatives number of negatives for computing the
+ # FPR.
+ false_positives = array_ops.gather(
+ array_ops.pad(
+ math_ops.cumsum(
+ array_ops.where(
+ math_ops.less(ordered_labels, 1), ordered_weights,
+ array_ops.zeros_like(
+ ordered_labels, dtype=dtypes.float64))),
+ paddings=[[1, 0]]), splits)
+ total_negative = total_weight - total_positive
x_axis_values = math_ops.truediv(false_positives, total_negative)
y_axis_values = math_ops.truediv(true_positives, total_positive)
elif curve == 'PR':
x_axis_values = math_ops.truediv(true_positives, total_positive)
# For conformance, set precision to 1 when the number of positive
# classifications is 0.
+ positives = array_ops.gather(
+ array_ops.pad(math_ops.cumsum(ordered_weights), paddings=[[1, 0]]),
+ splits)
y_axis_values = array_ops.where(
- math_ops.greater(splits, 0), math_ops.truediv(true_positives, splits),
+ math_ops.greater(splits, 0),
+ math_ops.truediv(true_positives, positives),
array_ops.ones_like(true_positives, dtype=dtypes.float64))
# Calculate trapezoid areas.
@@ -1133,7 +1160,7 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'):
return control_flow_ops.cond(
math_ops.logical_or(
math_ops.equal(total_positive, 0), math_ops.equal(
- total_positive, size)),
+ total_positive, total_weight)),
true_fn=lambda: array_ops.constant(0, dtypes.float64),
false_fn=continue_computing_dynamic_auc)
@@ -1143,7 +1170,8 @@ def streaming_dynamic_auc(labels,
curve='ROC',
metrics_collections=(),
updates_collections=(),
- name=None):
+ name=None,
+ weights=None):
"""Computes the apporixmate AUC by a Riemann sum with data-derived thresholds.
USAGE NOTE: this approach requires storing all of the predictions and labels
@@ -1168,6 +1196,8 @@ def streaming_dynamic_auc(labels,
should be added to.
name: An optional name for the variable_scope that contains the metric
variables.
+ weights: A 'Tensor' of non-negative weights whose values are castable to
+ `float64`. Will be flattened into a 1-D `Tensor`.
Returns:
auc: A scalar `Tensor` containing the current area-under-curve value.
@@ -1195,14 +1225,24 @@ def streaming_dynamic_auc(labels,
check_ops.assert_less_equal(
labels,
array_ops.ones_like(labels, dtypes.int64),
- message='labels must be 0 or 1, at least one is >1')
+ message='labels must be 0 or 1, at least one is >1'),
]):
preds_accum, update_preds = streaming_concat(
predictions, name='concat_preds')
labels_accum, update_labels = streaming_concat(
labels, name='concat_labels')
- update_op = control_flow_ops.group(update_labels, update_preds)
- auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve)
+ if weights is not None:
+ weights = array_ops.reshape(
+ math_ops.cast(weights, dtypes.float64), [-1])
+ weights_accum, update_weights = streaming_concat(
+ weights, name='concat_weights')
+ update_op = control_flow_ops.group(update_labels, update_preds,
+ update_weights)
+ else:
+ weights_accum = None
+ update_op = control_flow_ops.group(update_labels, update_preds)
+ auc = _compute_dynamic_auc(
+ labels_accum, preds_accum, curve=curve, weights=weights_accum)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
if metrics_collections:
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index e720097636..a09fc4abd4 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2127,6 +2127,44 @@ class StreamingDynamicAUCTest(test.TestCase):
sess.run(update_op)
self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5)
+ def testWithWeights(self):
+ batch_size = 10
+ num_batches = 100
+ labels = np.array([])
+ predictions = np.array([])
+ weights = np.array([])
+ tf_labels = variables.Variable(
+ array_ops.ones(batch_size, dtypes_lib.int32),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.int32)
+ tf_predictions = variables.Variable(
+ array_ops.ones(batch_size),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.float32)
+ tf_weights = variables.Variable(
+ array_ops.ones(batch_size),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.float32)
+ auc, update_op = metrics.streaming_dynamic_auc(tf_labels,
+ tf_predictions,
+ weights=tf_weights)
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ for _ in xrange(num_batches):
+ new_labels = np.random.randint(0, 2, size=batch_size)
+ noise = np.random.uniform(-0.2, 0.2, size=batch_size)
+ new_predictions = 0.4 + 0.2 * new_labels + noise
+ new_weights = np.random.uniform(0.0, 3.0, size=batch_size)
+ labels = np.concatenate([labels, new_labels])
+ predictions = np.concatenate([predictions, new_predictions])
+ weights = np.concatenate([weights, new_weights])
+ sess.run([tf_labels.assign(new_labels),
+ tf_predictions.assign(new_predictions),
+ tf_weights.assign(new_weights)])
+ sess.run(update_op)
+ expected_auc = _np_auc(predictions, labels, weights)
+ self.assertAlmostEqual(expected_auc, auc.eval())
+
class AucWithConfidenceIntervalsTest(test.TestCase):
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index a44f29fa37..c6f3bd6ee1 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -162,12 +162,12 @@ def _get_processor(v):
def _var_key_v2(var):
"""Key for representing a primary variable, for looking up slots."""
# pylint: disable=protected-access
- if hasattr(var, "_mirrored_container"):
- mirrored_container = var._mirrored_container()
- assert mirrored_container is not None
+ if hasattr(var, "_distributed_container"):
+ distributed_container = var._distributed_container()
+ assert distributed_container is not None
if context.executing_eagerly():
- return mirrored_container._unique_id
- return mirrored_container._shared_name
+ return distributed_container._unique_id
+ return distributed_container._shared_name
if context.executing_eagerly():
return var._unique_id
return var.op.name
diff --git a/tensorflow/core/api_def/python_api/api_def_Acos.pbtxt b/tensorflow/core/api_def/python_api/api_def_Acos.pbtxt
new file mode 100644
index 0000000000..ca1ee78526
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Acos.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Acos"
+ endpoint {
+ name: "math.acos"
+ }
+ endpoint {
+ name: "acos"
+ deprecation_message: "tf.acos is deprecated, please use tf.math.acos instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Acosh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Acosh.pbtxt
new file mode 100644
index 0000000000..7503353e41
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Acosh.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Acosh"
+ endpoint {
+ name: "math.acosh"
+ }
+ endpoint {
+ name: "acosh"
+ deprecation_message: "tf.acosh is deprecated, please use tf.math.acosh instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Add.pbtxt b/tensorflow/core/api_def/python_api/api_def_Add.pbtxt
new file mode 100644
index 0000000000..cc5d68b15d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Add.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Add"
+ endpoint {
+ name: "math.add"
+ }
+ endpoint {
+ name: "add"
+ deprecation_message: "tf.add is deprecated, please use tf.math.add instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_AsString.pbtxt b/tensorflow/core/api_def/python_api/api_def_AsString.pbtxt
new file mode 100644
index 0000000000..9306eaf373
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_AsString.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "AsString"
+ endpoint {
+ name: "dtypes.as_string"
+ }
+ endpoint {
+ name: "as_string"
+ deprecation_message: "tf.as_string is deprecated, please use tf.dtypes.as_string instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Asin.pbtxt b/tensorflow/core/api_def/python_api/api_def_Asin.pbtxt
new file mode 100644
index 0000000000..7622af7b45
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Asin.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Asin"
+ endpoint {
+ name: "math.asin"
+ }
+ endpoint {
+ name: "asin"
+ deprecation_message: "tf.asin is deprecated, please use tf.math.asin instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Asinh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Asinh.pbtxt
new file mode 100644
index 0000000000..395275c21d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Asinh.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Asinh"
+ endpoint {
+ name: "math.asinh"
+ }
+ endpoint {
+ name: "asinh"
+ deprecation_message: "tf.asinh is deprecated, please use tf.math.asinh instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Atan.pbtxt b/tensorflow/core/api_def/python_api/api_def_Atan.pbtxt
new file mode 100644
index 0000000000..dfcd632558
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Atan.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Atan"
+ endpoint {
+ name: "math.atan"
+ }
+ endpoint {
+ name: "atan"
+ deprecation_message: "tf.atan is deprecated, please use tf.math.atan instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Atan2.pbtxt b/tensorflow/core/api_def/python_api/api_def_Atan2.pbtxt
new file mode 100644
index 0000000000..fba79507aa
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Atan2.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Atan2"
+ endpoint {
+ name: "math.atan2"
+ }
+ endpoint {
+ name: "atan2"
+ deprecation_message: "tf.atan2 is deprecated, please use tf.math.atan2 instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Atanh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Atanh.pbtxt
new file mode 100644
index 0000000000..f7164c33e8
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Atanh.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Atanh"
+ endpoint {
+ name: "math.atanh"
+ }
+ endpoint {
+ name: "atanh"
+ deprecation_message: "tf.atanh is deprecated, please use tf.math.atanh instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
new file mode 100644
index 0000000000..56e49a2221
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "BatchToSpaceND"
+ endpoint {
+ name: "manip.batch_to_space_nd"
+ }
+ endpoint {
+ name: "batch_to_space_nd"
+ deprecation_message: "tf.batch_to_space_nd is deprecated, please use tf.manip.batch_to_space_nd instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt b/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt
new file mode 100644
index 0000000000..7c37b534c7
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Betainc"
+ endpoint {
+ name: "math.betainc"
+ }
+ endpoint {
+ name: "betainc"
+ deprecation_message: "tf.betainc is deprecated, please use tf.math.betainc instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt b/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt
new file mode 100644
index 0000000000..0c72cf2edd
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Ceil"
+ endpoint {
+ name: "math.ceil"
+ }
+ endpoint {
+ name: "ceil"
+ deprecation_message: "tf.ceil is deprecated, please use tf.math.ceil instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt b/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt
new file mode 100644
index 0000000000..7ea52d30b6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "CheckNumerics"
+ endpoint {
+ name: "debugging.check_numerics"
+ }
+ endpoint {
+ name: "check_numerics"
+ deprecation_message: "tf.check_numerics is deprecated, please use tf.debugging.check_numerics instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt
index 2676c92bfb..568fab4037 100644
--- a/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt
@@ -1,9 +1,10 @@
op {
graph_op_name: "Cholesky"
endpoint {
- name: "cholesky"
+ name: "linalg.cholesky"
}
endpoint {
- name: "linalg.cholesky"
+ name: "cholesky"
+ deprecation_message: "tf.cholesky is deprecated, please use tf.linalg.cholesky instead."
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Cos.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cos.pbtxt
new file mode 100644
index 0000000000..6550cd2d4e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Cos.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Cos"
+ endpoint {
+ name: "math.cos"
+ }
+ endpoint {
+ name: "cos"
+ deprecation_message: "tf.cos is deprecated, please use tf.math.cos instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Cosh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cosh.pbtxt
new file mode 100644
index 0000000000..ef82a45a80
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Cosh.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Cosh"
+ endpoint {
+ name: "math.cosh"
+ }
+ endpoint {
+ name: "cosh"
+ deprecation_message: "tf.cosh is deprecated, please use tf.math.cosh instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt
new file mode 100644
index 0000000000..33c1b8c617
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Cross"
+ endpoint {
+ name: "linalg.cross"
+ }
+ endpoint {
+ name: "cross"
+ deprecation_message: "tf.cross is deprecated, please use tf.linalg.cross instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt
new file mode 100644
index 0000000000..55c43ceba2
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "DecodeBase64"
+ endpoint {
+ name: "io.decode_base64"
+ }
+ endpoint {
+ name: "decode_base64"
+ deprecation_message: "tf.decode_base64 is deprecated, please use tf.io.decode_base64 instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt
new file mode 100644
index 0000000000..5f6be24cc4
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "DecodeCompressed"
+ endpoint {
+ name: "io.decode_compressed"
+ }
+ endpoint {
+ name: "decode_compressed"
+ deprecation_message: "tf.decode_compressed is deprecated, please use tf.io.decode_compressed instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt
new file mode 100644
index 0000000000..3759047f57
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "DecodeJSONExample"
+ endpoint {
+ name: "io.decode_json_example"
+ }
+ endpoint {
+ name: "decode_json_example"
+ deprecation_message: "tf.decode_json_example is deprecated, please use tf.io.decode_json_example instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt
new file mode 100644
index 0000000000..a83f702dca
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "DecodeRaw"
+ endpoint {
+ name: "io.decode_raw"
+ }
+ endpoint {
+ name: "decode_raw"
+ deprecation_message: "tf.decode_raw is deprecated, please use tf.io.decode_raw instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt b/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt
new file mode 100644
index 0000000000..c9b4f76fab
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Dequantize"
+ endpoint {
+ name: "quantization.dequantize"
+ }
+ endpoint {
+ name: "dequantize"
+ deprecation_message: "tf.dequantize is deprecated, please use tf.quantization.dequantize instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt b/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt
new file mode 100644
index 0000000000..2043facfa9
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Diag"
+ endpoint {
+ name: "linalg.tensor_diag"
+ }
+ endpoint {
+ name: "diag"
+ deprecation_message: "tf.diag is deprecated, please use tf.linalg.tensor_diag instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt b/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt
new file mode 100644
index 0000000000..7fa30b2347
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "DiagPart"
+ endpoint {
+ name: "linalg.tensor_diag_part"
+ }
+ endpoint {
+ name: "diag_part"
+ deprecation_message: "tf.diag_part is deprecated, please use tf.linalg.tensor_diag_part instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt
new file mode 100644
index 0000000000..03f57678a8
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Digamma"
+ endpoint {
+ name: "math.digamma"
+ }
+ endpoint {
+ name: "digamma"
+ deprecation_message: "tf.digamma is deprecated, please use tf.math.digamma instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt b/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt
new file mode 100644
index 0000000000..47b4ab4da4
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "EncodeBase64"
+ endpoint {
+ name: "io.encode_base64"
+ }
+ endpoint {
+ name: "encode_base64"
+ deprecation_message: "tf.encode_base64 is deprecated, please use tf.io.encode_base64 instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Equal.pbtxt b/tensorflow/core/api_def/python_api/api_def_Equal.pbtxt
new file mode 100644
index 0000000000..2630962f7d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Equal.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Equal"
+ endpoint {
+ name: "math.equal"
+ }
+ endpoint {
+ name: "equal"
+ deprecation_message: "tf.equal is deprecated, please use tf.math.equal instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt b/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt
new file mode 100644
index 0000000000..6a511b3251
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Erfc"
+ endpoint {
+ name: "math.erfc"
+ }
+ endpoint {
+ name: "erfc"
+ deprecation_message: "tf.erfc is deprecated, please use tf.math.erfc instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Exp.pbtxt b/tensorflow/core/api_def/python_api/api_def_Exp.pbtxt
new file mode 100644
index 0000000000..e1fd718ff0
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Exp.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Exp"
+ endpoint {
+ name: "math.exp"
+ }
+ endpoint {
+ name: "exp"
+ deprecation_message: "tf.exp is deprecated, please use tf.math.exp instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt b/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt
new file mode 100644
index 0000000000..ca25706407
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Expm1"
+ endpoint {
+ name: "math.expm1"
+ }
+ endpoint {
+ name: "expm1"
+ deprecation_message: "tf.expm1 is deprecated, please use tf.math.expm1 instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt b/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt
new file mode 100644
index 0000000000..d302e26ad2
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "ExtractImagePatches"
+ endpoint {
+ name: "image.extract_image_patches"
+ }
+ endpoint {
+ name: "extract_image_patches"
+ deprecation_message: "tf.extract_image_patches is deprecated, please use tf.image.extract_image_patches instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt b/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt
index 3bcab99415..57a00a08e3 100644
--- a/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt
@@ -1,9 +1,10 @@
op {
graph_op_name: "FFT"
endpoint {
- name: "fft"
+ name: "spectral.fft"
}
endpoint {
- name: "spectral.fft"
+ name: "fft"
+ deprecation_message: "tf.fft is deprecated, please use tf.spectral.fft instead."
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt
new file mode 100644
index 0000000000..cd14b13675
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "FakeQuantWithMinMaxArgs"
+ endpoint {
+ name: "quantization.fake_quant_with_min_max_args"
+ }
+ endpoint {
+ name: "fake_quant_with_min_max_args"
+ deprecation_message: "tf.fake_quant_with_min_max_args is deprecated, please use tf.quantization.fake_quant_with_min_max_args instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt
new file mode 100644
index 0000000000..d55cb69d1d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "FakeQuantWithMinMaxArgsGradient"
+ endpoint {
+ name: "quantization.fake_quant_with_min_max_args_gradient"
+ }
+ endpoint {
+ name: "fake_quant_with_min_max_args_gradient"
+ deprecation_message: "tf.fake_quant_with_min_max_args_gradient is deprecated, please use tf.quantization.fake_quant_with_min_max_args_gradient instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt
new file mode 100644
index 0000000000..6ff4f2cdb2
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "FakeQuantWithMinMaxVars"
+ endpoint {
+ name: "quantization.fake_quant_with_min_max_vars"
+ }
+ endpoint {
+ name: "fake_quant_with_min_max_vars"
+ deprecation_message: "tf.fake_quant_with_min_max_vars is deprecated, please use tf.quantization.fake_quant_with_min_max_vars instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt
new file mode 100644
index 0000000000..817a35cc6c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "FakeQuantWithMinMaxVarsGradient"
+ endpoint {
+ name: "quantization.fake_quant_with_min_max_vars_gradient"
+ }
+ endpoint {
+ name: "fake_quant_with_min_max_vars_gradient"
+ deprecation_message: "tf.fake_quant_with_min_max_vars_gradient is deprecated, please use tf.quantization.fake_quant_with_min_max_vars_gradient instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt
new file mode 100644
index 0000000000..275c0d5225
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "FakeQuantWithMinMaxVarsPerChannel"
+ endpoint {
+ name: "quantization.fake_quant_with_min_max_vars_per_channel"
+ }
+ endpoint {
+ name: "fake_quant_with_min_max_vars_per_channel"
+ deprecation_message: "tf.fake_quant_with_min_max_vars_per_channel is deprecated, please use tf.quantization.fake_quant_with_min_max_vars_per_channel instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt
new file mode 100644
index 0000000000..897312897f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "FakeQuantWithMinMaxVarsPerChannelGradient"
+ endpoint {
+ name: "quantization.fake_quant_with_min_max_vars_per_channel_gradient"
+ }
+ endpoint {
+ name: "fake_quant_with_min_max_vars_per_channel_gradient"
+ deprecation_message: "tf.fake_quant_with_min_max_vars_per_channel_gradient is deprecated, please use tf.quantization.fake_quant_with_min_max_vars_per_channel_gradient instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Floor.pbtxt b/tensorflow/core/api_def/python_api/api_def_Floor.pbtxt
new file mode 100644
index 0000000000..788d95edc1
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Floor.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Floor"
+ endpoint {
+ name: "math.floor"
+ }
+ endpoint {
+ name: "floor"
+ deprecation_message: "tf.floor is deprecated, please use tf.math.floor instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
new file mode 100644
index 0000000000..371dc740df
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "GatherNd"
+ endpoint {
+ name: "manip.gather_nd"
+ }
+ endpoint {
+ name: "gather_nd"
+ deprecation_message: "tf.gather_nd is deprecated, please use tf.manip.gather_nd instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Greater.pbtxt b/tensorflow/core/api_def/python_api/api_def_Greater.pbtxt
new file mode 100644
index 0000000000..c8c56515b2
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Greater.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Greater"
+ endpoint {
+ name: "math.greater"
+ }
+ endpoint {
+ name: "greater"
+ deprecation_message: "tf.greater is deprecated, please use tf.math.greater instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_GreaterEqual.pbtxt b/tensorflow/core/api_def/python_api/api_def_GreaterEqual.pbtxt
new file mode 100644
index 0000000000..ccb390fb3e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_GreaterEqual.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "GreaterEqual"
+ endpoint {
+ name: "math.greater_equal"
+ }
+ endpoint {
+ name: "greater_equal"
+ deprecation_message: "tf.greater_equal is deprecated, please use tf.math.greater_equal instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt b/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt
index 6bbc4ed720..267ad8d0a0 100644
--- a/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt
@@ -1,9 +1,10 @@
op {
graph_op_name: "IFFT"
endpoint {
- name: "ifft"
+ name: "spectral.ifft"
}
endpoint {
- name: "spectral.ifft"
+ name: "ifft"
+ deprecation_message: "tf.ifft is deprecated, please use tf.spectral.ifft instead."
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt
new file mode 100644
index 0000000000..4e7e3a6e57
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Igamma"
+ endpoint {
+ name: "math.igamma"
+ }
+ endpoint {
+ name: "igamma"
+ deprecation_message: "tf.igamma is deprecated, please use tf.math.igamma instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt b/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt
new file mode 100644
index 0000000000..ea92a0916b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Igammac"
+ endpoint {
+ name: "math.igammac"
+ }
+ endpoint {
+ name: "igammac"
+ deprecation_message: "tf.igammac is deprecated, please use tf.math.igammac instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt b/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt
new file mode 100644
index 0000000000..bce642b96a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "InvertPermutation"
+ endpoint {
+ name: "math.invert_permutation"
+ }
+ endpoint {
+ name: "invert_permutation"
+ deprecation_message: "tf.invert_permutation is deprecated, please use tf.math.invert_permutation instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt b/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt
new file mode 100644
index 0000000000..a2c12f2ea0
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "IsFinite"
+ endpoint {
+ name: "debugging.is_finite"
+ }
+ endpoint {
+ name: "is_finite"
+ deprecation_message: "tf.is_finite is deprecated, please use tf.debugging.is_finite instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt b/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt
new file mode 100644
index 0000000000..7c29811fd7
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "IsInf"
+ endpoint {
+ name: "debugging.is_inf"
+ }
+ endpoint {
+ name: "is_inf"
+ deprecation_message: "tf.is_inf is deprecated, please use tf.debugging.is_inf instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt b/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt
new file mode 100644
index 0000000000..459cf3ccbd
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "IsNan"
+ endpoint {
+ name: "debugging.is_nan"
+ }
+ endpoint {
+ name: "is_nan"
+ deprecation_message: "tf.is_nan is deprecated, please use tf.debugging.is_nan instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Less.pbtxt b/tensorflow/core/api_def/python_api/api_def_Less.pbtxt
new file mode 100644
index 0000000000..15cbdc6d8e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Less.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Less"
+ endpoint {
+ name: "math.less"
+ }
+ endpoint {
+ name: "less"
+ deprecation_message: "tf.less is deprecated, please use tf.math.less instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_LessEqual.pbtxt b/tensorflow/core/api_def/python_api/api_def_LessEqual.pbtxt
new file mode 100644
index 0000000000..35aa18698f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_LessEqual.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "LessEqual"
+ endpoint {
+ name: "math.less_equal"
+ }
+ endpoint {
+ name: "less_equal"
+ deprecation_message: "tf.less_equal is deprecated, please use tf.math.less_equal instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt
new file mode 100644
index 0000000000..89886b09d3
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Lgamma"
+ endpoint {
+ name: "math.lgamma"
+ }
+ endpoint {
+ name: "lgamma"
+ deprecation_message: "tf.lgamma is deprecated, please use tf.math.lgamma instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Log.pbtxt b/tensorflow/core/api_def/python_api/api_def_Log.pbtxt
new file mode 100644
index 0000000000..fb82aa7e43
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Log.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Log"
+ endpoint {
+ name: "math.log"
+ }
+ endpoint {
+ name: "log"
+ deprecation_message: "tf.log is deprecated, please use tf.math.log instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Log1p.pbtxt b/tensorflow/core/api_def/python_api/api_def_Log1p.pbtxt
new file mode 100644
index 0000000000..6b451aa546
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Log1p.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Log1p"
+ endpoint {
+ name: "math.log1p"
+ }
+ endpoint {
+ name: "log1p"
+ deprecation_message: "tf.log1p is deprecated, please use tf.math.log1p instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_LogicalAnd.pbtxt b/tensorflow/core/api_def/python_api/api_def_LogicalAnd.pbtxt
new file mode 100644
index 0000000000..403a8c71ff
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_LogicalAnd.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "LogicalAnd"
+ endpoint {
+ name: "math.logical_and"
+ }
+ endpoint {
+ name: "logical_and"
+ deprecation_message: "tf.logical_and is deprecated, please use tf.math.logical_and instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_LogicalNot.pbtxt b/tensorflow/core/api_def/python_api/api_def_LogicalNot.pbtxt
new file mode 100644
index 0000000000..f228958c77
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_LogicalNot.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "LogicalNot"
+ endpoint {
+ name: "math.logical_not"
+ }
+ endpoint {
+ name: "logical_not"
+ deprecation_message: "tf.logical_not is deprecated, please use tf.math.logical_not instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_LogicalOr.pbtxt b/tensorflow/core/api_def/python_api/api_def_LogicalOr.pbtxt
new file mode 100644
index 0000000000..ab89f236e7
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_LogicalOr.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "LogicalOr"
+ endpoint {
+ name: "math.logical_or"
+ }
+ endpoint {
+ name: "logical_or"
+ deprecation_message: "tf.logical_or is deprecated, please use tf.math.logical_or instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt
new file mode 100644
index 0000000000..8930d66940
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "MatchingFiles"
+ endpoint {
+ name: "io.matching_files"
+ }
+ endpoint {
+ name: "matching_files"
+ deprecation_message: "tf.matching_files is deprecated, please use tf.io.matching_files instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt
index 89b1c1f5a9..bad2f03f32 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_band_part"
+ deprecation_message: "tf.matrix_band_part is deprecated, please use tf.linalg.band_part instead."
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt
index 4d289f542f..d241d4d721 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_determinant"
+ deprecation_message: "tf.matrix_determinant is deprecated, please use tf.linalg.det instead."
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt
index fd9d34635e..208b37e297 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_diag"
+ deprecation_message: "tf.matrix_diag is deprecated, please use tf.linalg.diag instead."
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt
index fa5d1f10af..a8a50e8a89 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_diag_part"
+ deprecation_message: "tf.matrix_diag_part is deprecated, please use tf.linalg.diag_part instead."
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt
index c0ddd73704..944513fcd9 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_inverse"
+ deprecation_message: "tf.matrix_inverse is deprecated, please use tf.linalg.inv instead."
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt
index 01f4f0e89d..a6080dbc2d 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_set_diag"
+ deprecation_message: "tf.matrix_set_diag is deprecated, please use tf.linalg.set_diag instead."
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt
index cef763e4e9..caba80326b 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_solve"
+ deprecation_message: "tf.matrix_solve is deprecated, please use tf.linalg.solve instead."
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt
index a0d576aa31..a4dfa538ed 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_triangular_solve"
+ deprecation_message: "tf.matrix_triangular_solve is deprecated, please use tf.linalg.triangular_solve instead."
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Maximum.pbtxt b/tensorflow/core/api_def/python_api/api_def_Maximum.pbtxt
new file mode 100644
index 0000000000..90af9e145b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Maximum.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Maximum"
+ endpoint {
+ name: "math.maximum"
+ }
+ endpoint {
+ name: "maximum"
+ deprecation_message: "tf.maximum is deprecated, please use tf.math.maximum instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Minimum.pbtxt b/tensorflow/core/api_def/python_api/api_def_Minimum.pbtxt
new file mode 100644
index 0000000000..33bcd6f667
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Minimum.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Minimum"
+ endpoint {
+ name: "math.minimum"
+ }
+ endpoint {
+ name: "minimum"
+ deprecation_message: "tf.minimum is deprecated, please use tf.math.minimum instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_NotEqual.pbtxt b/tensorflow/core/api_def/python_api/api_def_NotEqual.pbtxt
new file mode 100644
index 0000000000..385565daaf
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_NotEqual.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "NotEqual"
+ endpoint {
+ name: "math.not_equal"
+ }
+ endpoint {
+ name: "not_equal"
+ deprecation_message: "tf.not_equal is deprecated, please use tf.math.not_equal instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt b/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt
new file mode 100644
index 0000000000..29f02ab1ac
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "ParseTensor"
+ endpoint {
+ name: "io.parse_tensor"
+ }
+ endpoint {
+ name: "parse_tensor"
+ deprecation_message: "tf.parse_tensor is deprecated, please use tf.io.parse_tensor instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt
new file mode 100644
index 0000000000..567a448642
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Polygamma"
+ endpoint {
+ name: "math.polygamma"
+ }
+ endpoint {
+ name: "polygamma"
+ deprecation_message: "tf.polygamma is deprecated, please use tf.math.polygamma instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt
index b19da0d817..a9371b5d9b 100644
--- a/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "qr"
+ deprecation_message: "tf.qr is deprecated, please use tf.linalg.qr instead."
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt b/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt
new file mode 100644
index 0000000000..44508ef079
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "QuantizedConcat"
+ endpoint {
+ name: "quantization.quantized_concat"
+ }
+ endpoint {
+ name: "quantized_concat"
+ deprecation_message: "tf.quantized_concat is deprecated, please use tf.quantization.quantized_concat instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt
new file mode 100644
index 0000000000..7c38fae31c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "ReadFile"
+ endpoint {
+ name: "io.read_file"
+ }
+ endpoint {
+ name: "read_file"
+ deprecation_message: "tf.read_file is deprecated, please use tf.io.read_file instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt
new file mode 100644
index 0000000000..0f37e99f4f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Reciprocal"
+ endpoint {
+ name: "math.reciprocal"
+ }
+ endpoint {
+ name: "reciprocal"
+ deprecation_message: "tf.reciprocal is deprecated, please use tf.math.reciprocal instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
new file mode 100644
index 0000000000..6938e20e57
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "RegexReplace"
+ endpoint {
+ name: "strings.regex_replace"
+ }
+ endpoint {
+ name: "regex_replace"
+ deprecation_message: "tf.regex_replace is deprecated, please use tf.strings.regex_replace instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
new file mode 100644
index 0000000000..907d95a6f0
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Reshape"
+ endpoint {
+ name: "manip.reshape"
+ }
+ endpoint {
+ name: "reshape"
+ deprecation_message: "tf.reshape is deprecated, please use tf.manip.reshape instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
index 8307a3c2dd..bbe9e97d60 100644
--- a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
@@ -1,6 +1,14 @@
op {
graph_op_name: "ReverseV2"
endpoint {
+ name: "manip.reverse"
+ }
+ endpoint {
+ name: "reverse"
+ deprecation_message: "tf.reverse is deprecated, please use tf.manip.reverse instead."
+ }
+ endpoint {
name: "reverse_v2"
+ deprecation_message: "tf.reverse_v2 is deprecated, please use tf.manip.reverse instead."
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt b/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt
new file mode 100644
index 0000000000..4330a80d04
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Rint"
+ endpoint {
+ name: "math.rint"
+ }
+ endpoint {
+ name: "rint"
+ deprecation_message: "tf.rint is deprecated, please use tf.math.rint instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt b/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt
new file mode 100644
index 0000000000..6a45f4aff5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Rsqrt"
+ endpoint {
+ name: "math.rsqrt"
+ }
+ endpoint {
+ name: "rsqrt"
+ deprecation_message: "tf.rsqrt is deprecated, please use tf.math.rsqrt instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
new file mode 100644
index 0000000000..cabf171cb0
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "ScatterNd"
+ endpoint {
+ name: "manip.scatter_nd"
+ }
+ endpoint {
+ name: "scatter_nd"
+ deprecation_message: "tf.scatter_nd is deprecated, please use tf.manip.scatter_nd instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt
new file mode 100644
index 0000000000..65e34a1fcf
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SegmentMax"
+ endpoint {
+ name: "math.segment_max"
+ }
+ endpoint {
+ name: "segment_max"
+ deprecation_message: "tf.segment_max is deprecated, please use tf.math.segment_max instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt
new file mode 100644
index 0000000000..f1e19c5571
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SegmentMean"
+ endpoint {
+ name: "math.segment_mean"
+ }
+ endpoint {
+ name: "segment_mean"
+ deprecation_message: "tf.segment_mean is deprecated, please use tf.math.segment_mean instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt
new file mode 100644
index 0000000000..fd9a3c380d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SegmentMin"
+ endpoint {
+ name: "math.segment_min"
+ }
+ endpoint {
+ name: "segment_min"
+ deprecation_message: "tf.segment_min is deprecated, please use tf.math.segment_min instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt
new file mode 100644
index 0000000000..f2be8baafc
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SegmentProd"
+ endpoint {
+ name: "math.segment_prod"
+ }
+ endpoint {
+ name: "segment_prod"
+ deprecation_message: "tf.segment_prod is deprecated, please use tf.math.segment_prod instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt
new file mode 100644
index 0000000000..c7cc1d0c9f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SegmentSum"
+ endpoint {
+ name: "math.segment_sum"
+ }
+ endpoint {
+ name: "segment_sum"
+ deprecation_message: "tf.segment_sum is deprecated, please use tf.math.segment_sum instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Sin.pbtxt b/tensorflow/core/api_def/python_api/api_def_Sin.pbtxt
new file mode 100644
index 0000000000..0794334987
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Sin.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Sin"
+ endpoint {
+ name: "math.sin"
+ }
+ endpoint {
+ name: "sin"
+ deprecation_message: "tf.sin is deprecated, please use tf.math.sin instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Sinh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Sinh.pbtxt
new file mode 100644
index 0000000000..c42f8678c6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Sinh.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Sinh"
+ endpoint {
+ name: "math.sinh"
+ }
+ endpoint {
+ name: "sinh"
+ deprecation_message: "tf.sinh is deprecated, please use tf.math.sinh instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Softplus.pbtxt b/tensorflow/core/api_def/python_api/api_def_Softplus.pbtxt
index 2de56c27be..c4da47241b 100644
--- a/tensorflow/core/api_def/python_api/api_def_Softplus.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Softplus.pbtxt
@@ -1,6 +1,9 @@
op {
graph_op_name: "Softplus"
endpoint {
+ name: "math.softplus"
+ }
+ endpoint {
name: "nn.softplus"
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Softsign.pbtxt b/tensorflow/core/api_def/python_api/api_def_Softsign.pbtxt
index b47412d135..852d205024 100644
--- a/tensorflow/core/api_def/python_api/api_def_Softsign.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Softsign.pbtxt
@@ -3,4 +3,7 @@ op {
endpoint {
name: "nn.softsign"
}
+ endpoint {
+ name: "math.softsign"
+ }
}
diff --git a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
new file mode 100644
index 0000000000..63a7547e14
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SpaceToBatchND"
+ endpoint {
+ name: "manip.space_to_batch_nd"
+ }
+ endpoint {
+ name: "space_to_batch_nd"
+ deprecation_message: "tf.space_to_batch_nd is deprecated, please use tf.manip.space_to_batch_nd instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt b/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt
new file mode 100644
index 0000000000..01a33a3346
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SquaredDifference"
+ endpoint {
+ name: "math.squared_difference"
+ }
+ endpoint {
+ name: "squared_difference"
+ deprecation_message: "tf.squared_difference is deprecated, please use tf.math.squared_difference instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt
new file mode 100644
index 0000000000..53c1b8053d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "StringJoin"
+ endpoint {
+ name: "strings.join"
+ }
+ endpoint {
+ name: "string_join"
+ deprecation_message: "tf.string_join is deprecated, please use tf.strings.join instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt
new file mode 100644
index 0000000000..364806e1f5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "StringStrip"
+ endpoint {
+ name: "strings.strip"
+ }
+ endpoint {
+ name: "string_strip"
+ deprecation_message: "tf.string_strip is deprecated, please use tf.strings.strip instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt
new file mode 100644
index 0000000000..b0e93d2b22
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "StringToHashBucket"
+ endpoint {
+ name: "strings.to_hash_bucket"
+ }
+ endpoint {
+ name: "string_to_hash_bucket"
+ deprecation_message: "tf.string_to_hash_bucket is deprecated, please use tf.strings.to_hash_bucket instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt
new file mode 100644
index 0000000000..9576e1a9de
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "StringToHashBucketFast"
+ endpoint {
+ name: "strings.to_hash_bucket_fast"
+ }
+ endpoint {
+ name: "string_to_hash_bucket_fast"
+ deprecation_message: "tf.string_to_hash_bucket_fast is deprecated, please use tf.strings.to_hash_bucket_fast instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt
new file mode 100644
index 0000000000..e8c7c12608
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "StringToHashBucketStrong"
+ endpoint {
+ name: "strings.to_hash_bucket_strong"
+ }
+ endpoint {
+ name: "string_to_hash_bucket_strong"
+ deprecation_message: "tf.string_to_hash_bucket_strong is deprecated, please use tf.strings.to_hash_bucket_strong instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt
new file mode 100644
index 0000000000..9de1ca0b30
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "StringToNumber"
+ endpoint {
+ name: "strings.to_number"
+ }
+ endpoint {
+ name: "string_to_number"
+ deprecation_message: "tf.string_to_number is deprecated, please use tf.strings.to_number instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
new file mode 100644
index 0000000000..25d1bb3f51
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Substr"
+ endpoint {
+ name: "strings.substr"
+ }
+ endpoint {
+ name: "substr"
+ deprecation_message: "tf.substr is deprecated, please use tf.strings.substr instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Tan.pbtxt b/tensorflow/core/api_def/python_api/api_def_Tan.pbtxt
new file mode 100644
index 0000000000..8bcf381dd4
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Tan.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Tan"
+ endpoint {
+ name: "math.tan"
+ }
+ endpoint {
+ name: "tan"
+ deprecation_message: "tf.tan is deprecated, please use tf.math.tan instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
new file mode 100644
index 0000000000..0b9053a529
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Tile"
+ endpoint {
+ name: "manip.tile"
+ }
+ endpoint {
+ name: "tile"
+ deprecation_message: "tf.tile is deprecated, please use tf.manip.tile instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt
new file mode 100644
index 0000000000..1ea59d2e63
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "UnsortedSegmentMax"
+ endpoint {
+ name: "math.unsorted_segment_max"
+ }
+ endpoint {
+ name: "unsorted_segment_max"
+ deprecation_message: "tf.unsorted_segment_max is deprecated, please use tf.math.unsorted_segment_max instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt
new file mode 100644
index 0000000000..9857def6fe
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "UnsortedSegmentMin"
+ endpoint {
+ name: "math.unsorted_segment_min"
+ }
+ endpoint {
+ name: "unsorted_segment_min"
+ deprecation_message: "tf.unsorted_segment_min is deprecated, please use tf.math.unsorted_segment_min instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt
new file mode 100644
index 0000000000..d9e3f7be69
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "UnsortedSegmentProd"
+ endpoint {
+ name: "math.unsorted_segment_prod"
+ }
+ endpoint {
+ name: "unsorted_segment_prod"
+ deprecation_message: "tf.unsorted_segment_prod is deprecated, please use tf.math.unsorted_segment_prod instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt
new file mode 100644
index 0000000000..0cffd12404
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "UnsortedSegmentSum"
+ endpoint {
+ name: "math.unsorted_segment_sum"
+ }
+ endpoint {
+ name: "unsorted_segment_sum"
+ deprecation_message: "tf.unsorted_segment_sum is deprecated, please use tf.math.unsorted_segment_sum instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt b/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt
new file mode 100644
index 0000000000..f28a9151ca
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "WriteFile"
+ endpoint {
+ name: "io.write_file"
+ }
+ endpoint {
+ name: "write_file"
+ deprecation_message: "tf.write_file is deprecated, please use tf.io.write_file instead."
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt b/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt
new file mode 100644
index 0000000000..a84ffcdf14
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Zeta"
+ endpoint {
+ name: "math.zeta"
+ }
+ endpoint {
+ name: "zeta"
+ deprecation_message: "tf.zeta is deprecated, please use tf.math.zeta instead."
+ }
+}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
index d0684f1833..159435fd7d 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
@@ -26,6 +26,8 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/worker.pb.h"
+// (Omitted internal-only flag)
+
namespace tensorflow {
namespace grpc {
@@ -168,15 +170,20 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
(header.size() +
VarLengthEncodingSize(RecvTensorResponse::kTensorFieldNumber,
overall_tensor_proto_bytesize));
- // If "tensor_data_is_large == false", we copy the tensor data to the
- // end of the buffer we are preparing that holds the rest of the
+ // If "share_tensor_slice_memory == false", we copy the tensor data to
+ // the end of the buffer we are preparing that holds the rest of the
// RecvTensorResponse protocol buffer.
//
- // If "tensor_data_is_large == true", we arrange to share the backing
- // store of the data by creating a slice that also points to the
+ // If "share_tensor_slice_memory == true", we arrange to share the
+ // backing store of the data by creating a slice that also points to the
// backing store, with appropriate reference counts to keep the
// backing store alive as needed.
- bool tensor_data_is_large = (tdata.size() > kLargeTensorBytes);
+ //
+ // We enable this behavior if the tensor is large.
+ bool share_tensor_slice_memory = (tdata.size() > kLargeTensorBytes);
+
+ // (Omitted internal-only conditional)
+
size_t encoder_size = expected_size - tdata.size();
// Encode all but the actual "tdata", but including the tag and
@@ -201,10 +208,11 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
::grpc::Slice slices[2];
int num_slices = 0;
{
- size_t slice_len = e.size() + (tensor_data_is_large ? 0 : tdata.size());
+ size_t slice_len =
+ e.size() + (share_tensor_slice_memory ? 0 : tdata.size());
slices[0] = ::grpc::Slice(slice_len);
memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
- if (!tensor_data_is_large) {
+ if (!share_tensor_slice_memory) {
// (E)
memcpy(const_cast<uint8_t*>(slices[0].begin()) + e.size(), tdata.data(),
tdata.size());
@@ -212,7 +220,7 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
num_slices += 1;
}
- if (tensor_data_is_large) {
+ if (share_tensor_slice_memory) {
// (E) Encode tensor data, but by sharing backing store
const TensorBuffer* buf = DMAHelper::buffer(&val);
buf->Ref();
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index b994d26397..d34eecd009 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -78,6 +78,14 @@ string GetDataFormat(const OpInfo& op_features) {
return data_format;
}
+string GetFilterFormat(const OpInfo& op_features) {
+ string filter_format = "HWIO"; // Default format.
+ if (op_features.attr().find("filter_format") != op_features.attr().end()) {
+ filter_format = op_features.attr().at("filter_format").s();
+ }
+ return filter_format;
+}
+
Padding GetPadding(const OpInfo& op_features) {
if (op_features.attr().find("padding") != op_features.attr().end() &&
op_features.attr().at("padding").s() == "VALID") {
@@ -513,29 +521,44 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
y_index = 3;
channel_index = 1;
} else {
+ // Use NHWC.
x_index = 1;
y_index = 2;
channel_index = 3;
}
+ const string& filter_format = GetFilterFormat(op_features);
+ int filter_x_index, filter_y_index, in_channel_index, out_channel_index;
+ if (filter_format == "HWIO") {
+ filter_x_index = 0;
+ filter_y_index = 1;
+ in_channel_index = 2;
+ out_channel_index = 3;
+ } else {
+ // Use OIHW
+ filter_x_index = 2;
+ filter_y_index = 3;
+ in_channel_index = 1;
+ out_channel_index = 0;
+ }
int64 batch = image_shape.dim(0).size();
int64 ix = image_shape.dim(x_index).size();
int64 iy = image_shape.dim(y_index).size();
int64 iz = image_shape.dim(channel_index).size();
- int64 kx = filter_shape.dim(0).size();
- int64 ky = filter_shape.dim(1).size();
+ int64 kx = filter_shape.dim(filter_x_index).size();
+ int64 ky = filter_shape.dim(filter_y_index).size();
std::vector<int64> strides = GetStrides(op_features);
const auto padding = GetPadding(op_features);
int64 sx = strides[x_index];
int64 sy = strides[y_index];
int64 ox = GetOutputSize(ix, kx, sx, padding);
int64 oy = GetOutputSize(iy, ky, sy, padding);
- int64 oz = filter_shape.dim(3).size();
+ int64 oz = filter_shape.dim(out_channel_index).size();
// Only check equality when both sizes are known (in other words, when
// neither is set to a minimum dimension size of 1).
- if (iz != 1 && filter_shape.dim(2).size() != 1) {
- CHECK_EQ(iz, filter_shape.dim(2).size());
+ if (iz != 1 && filter_shape.dim(in_channel_index).size() != 1) {
+ CHECK_EQ(iz, filter_shape.dim(in_channel_index).size());
} else {
- iz = std::max<int64>(iz, filter_shape.dim(2).size());
+ iz = std::max<int64>(iz, filter_shape.dim(in_channel_index).size());
}
OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding};
@@ -1054,6 +1077,24 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
//
// For more information, see
// contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+
+ // TODO(yaozhang): Support other data formats (NCHW_VECT_C, NHWC_VECT_W) and
+ // filter formats (OIHW_VECT_I).
+ string data_format = GetDataFormat(op_context.op_info);
+ if (data_format != "NCHW" && data_format != "NHWC") {
+ LOG(WARNING) << "unsupported data format: " << data_format;
+ Costs cost = Costs::ZeroCosts();
+ cost.inaccurate = true;
+ return cost;
+ }
+ string filter_format = GetFilterFormat(op_context.op_info);
+ if (filter_format != "HWIO" && filter_format != "OIHW") {
+ LOG(WARNING) << "unsupported filter format: " << filter_format;
+ Costs cost = Costs::ZeroCosts();
+ cost.inaccurate = true;
+ return cost;
+ }
+
auto& conv_input = op_context.op_info.inputs(0);
auto& filter = op_context.op_info.inputs(1);
auto& bias = op_context.op_info.inputs(2);
@@ -1069,28 +1110,12 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
// Construct the shape of our output tensor from our convolution dimensions
// and format, as it may not be available yet.
- //
// TODO(varomodt): should we centralize the Conv2D input/output shapes?
- bool unknown_conv_format = false;
OpInfo::TensorProperties output;
- switch (GetConvolutionFormat(op_context)) {
- case NCHW:
- output =
- DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
- break;
- case NHWC:
- output =
- DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
- break;
- default:
- // TODO(b/77722245): support cost estimation for NCHW_VECT_C.
- LOG(WARNING) << "unsupported data format: "
- << GetDataFormat(op_context.op_info)
- << " Defaulting to NHWC.";
- output =
- DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
- unknown_conv_format = true;
- break;
+ if (data_format == "NCHW") {
+ output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
+ } else if (data_format == "NHWC") {
+ output = DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
}
// Add the operations the fused op always computes.
@@ -1115,7 +1140,7 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
// Construct component operations and run the cost computation.
auto costs = PredictFusedOp(op_context_with_output, component_ops);
- costs.inaccurate |= found_unknown_shapes || unknown_conv_format;
+ costs.inaccurate |= found_unknown_shapes;
return costs;
}
@@ -1568,20 +1593,6 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
}
/* static */
-OpLevelCostEstimator::ConvolutionFormat
-OpLevelCostEstimator::GetConvolutionFormat(const OpContext& op_context) {
- auto data_format = GetDataFormat(op_context.op_info);
- if (data_format == "NCHW") {
- return NCHW;
- } else if (data_format == "NHWC") {
- return NHWC;
- } else if (data_format == "NCHW_VECT_C") {
- return NCHW_VECT_C;
- }
-
- return UNKNOWN_CONVOLUTION_FORMAT;
-}
-
void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
Costs* costs) const {
if (compute_memory_overlap_) {
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index d384f57279..a277dfdf65 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -84,13 +84,6 @@ class OpLevelCostEstimator {
int64 sy; // Stride y.
Padding padding; // SAME or VALID.
};
- enum ConvolutionFormat {
- UNKNOWN_CONVOLUTION_FORMAT,
- NHWC,
- NCHW,
- NCHW_VECT_C,
- NCHW_VECT_W,
- };
int64 CountConv2DOperations(const OpInfo& op_features,
bool* found_unknown_shapes) const;
int64 CountConv2DOperations(const OpInfo& op_features,
@@ -198,9 +191,6 @@ class OpLevelCostEstimator {
static OpInfo::TensorProperties DescribeTensor(
DataType type, const std::vector<int64>& dims);
- // Returns the Conv2D format for this operation.
- static ConvolutionFormat GetConvolutionFormat(const OpContext& op_context);
-
// This method calculates the execution time depending on whether IO can
// overlap with computation. It assumes the memory and the compute times have
// already been calculated.
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index b2c021b73a..77352f6652 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -155,19 +155,38 @@ OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1,
// Note that this assumes the NHWC data format.
OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
int iz2, int kx, int ky, int ox,
- int oy, int oz,
- bool has_side_input) {
+ int oy, int oz, bool has_side_input,
+ const string& data_format,
+ const string& filter_format) {
OpContext op_context;
SetCpuDevice(&op_context.op_info);
op_context.op_info.set_op("FusedConv2DBiasActivation");
- DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
- DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
+ auto* attr_data_format = op_context.op_info.mutable_attr();
+ SetAttrValue(data_format, &(*attr_data_format)["data_format"]);
+ auto* attr_filter_format = op_context.op_info.mutable_attr();
+ SetAttrValue(filter_format, &(*attr_filter_format)["filter_format"]);
+ if (data_format == "NHWC") {
+ DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
+ } else {
+ // Use the NCHW format.
+ DescribeTensor4D(batch, iz1, ix, iy, op_context.op_info.add_inputs());
+ }
+ if (filter_format == "HWIO") {
+ DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
+ } else {
+ // Use the OIHW format.
+ DescribeTensor4D(oz, iz2, kx, ky, op_context.op_info.add_inputs());
+ }
DescribeTensor1D(oz, op_context.op_info.add_inputs());
// Add the side_input, if any.
auto side_input = op_context.op_info.add_inputs();
if (has_side_input) {
- DescribeTensor4D(batch, ox, oy, oz, side_input);
+ if (data_format == "NHWC") {
+ DescribeTensor4D(batch, ox, oy, oz, side_input);
+ } else {
+ DescribeTensor4D(batch, oz, ox, oy, side_input);
+ }
}
// Add the scaling tensors.
@@ -549,25 +568,79 @@ TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
SetComputeMemoryOverlap(false); // Set it back to default.
}
-TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationExecutionTime) {
+TEST_F(OpLevelCostEstimatorTest,
+ FusedConv2DBiasActivationNCHW_HWIO_NoSideInput) {
auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
- 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true));
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false,
+ "NCHW", "HWIO"));
+ EXPECT_EQ(Costs::Duration(825345), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355321038), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(356146383), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
+}
+
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_HWIO) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NCHW", "HWIO"));
EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
EXPECT_FALSE(cost.inaccurate);
}
-TEST_F(OpLevelCostEstimatorTest,
- FusedConv2DBiasActivationNoSideInputExecutionTime) {
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW) {
auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
- 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false));
- EXPECT_EQ(Costs::Duration(825345), cost.memory_time);
- EXPECT_EQ(Costs::Duration(355321038), cost.compute_time);
- EXPECT_EQ(Costs::Duration(356146383), cost.execution_time);
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NCHW", "OIHW"));
+ EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
EXPECT_FALSE(cost.inaccurate);
}
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_HWIO) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NHWC", "HWIO"));
+ EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
+}
+
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_OIHW) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NHWC", "OIHW"));
+ EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
+}
+
+// TODO(yaozhang): Update once NCHW_VECT_C is supported.
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_VECT_C_OIHW) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NCHW_VECT_C", "OIHW"));
+ EXPECT_EQ(Costs::Duration(0), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(0), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(0), cost.execution_time);
+ EXPECT_TRUE(cost.inaccurate);
+}
+
+// TODO(yaozhang): Update once OIHW_VECT_I is supported.
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW_VECT_I) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NCHW", "OIHW_VECT_I"));
+ EXPECT_EQ(Costs::Duration(0), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(0), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(0), cost.execution_time);
+ EXPECT_TRUE(cost.inaccurate);
+}
+
TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 1));
EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
@@ -655,8 +728,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
TensorProto tensor_proto;
TensorShapeProto tensor_shape_proto;
- // Dimension larger than max value; should fail while converting to Tensor
- // class.
+ // Dimension larger than max value; should fail while converting to
+ // Tensor class.
tensor_proto.mutable_tensor_shape()->add_dim()->set_size(255);
EXPECT_FALSE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
@@ -676,8 +749,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
// Check GetTensorShapeProtoFromTensorProto() resturns correct values.
{
std::vector<int64> shape_expected = {10, 20, 30, 40};
- GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/false,
- &tensor_proto);
+ GetTensorProto(DT_INT32, {4}, shape_expected,
+ /*tensor_content=*/false, &tensor_proto);
EXPECT_TRUE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
ExpectTensorShape(shape_expected, tensor_shape_proto);
@@ -685,8 +758,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
{
std::vector<int64> shape_expected = {40, 20, 90, 40};
- GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/false,
- &tensor_proto);
+ GetTensorProto(DT_INT64, {4}, shape_expected,
+ /*tensor_content=*/false, &tensor_proto);
EXPECT_TRUE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
ExpectTensorShape(shape_expected, tensor_shape_proto);
@@ -694,8 +767,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
{
std::vector<int64> shape_expected = {10, 20, 30, 40};
- GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/true,
- &tensor_proto);
+ GetTensorProto(DT_INT32, {4}, shape_expected,
+ /*tensor_content=*/true, &tensor_proto);
EXPECT_TRUE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
ExpectTensorShape(shape_expected, tensor_shape_proto);
@@ -703,8 +776,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
{
std::vector<int64> shape_expected = {40, 20, 90, 40};
- GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/true,
- &tensor_proto);
+ GetTensorProto(DT_INT64, {4}, shape_expected,
+ /*tensor_content=*/true, &tensor_proto);
EXPECT_TRUE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
ExpectTensorShape(shape_expected, tensor_shape_proto);
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index e292ff200a..792eb74e31 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -138,6 +138,9 @@ cc_library(
tf_cc_test(
name = "serial_device_batch_scheduler_test",
srcs = ["serial_device_batch_scheduler_test.cc"],
+ tags = [
+ "notap", # b/110374108
+ ],
deps = [
":fake_clock_env",
":serial_device_batch_scheduler",
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index 6949e5b5fd..6b7544fd4c 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -159,7 +159,7 @@ struct TransformFilter {
Eigen::DSizes<IndexType, NDIMS> expanded_dims;
expanded_dims[0] = in.dimension(NDIMS - 1); // output filters
expanded_dims[1] = in.dimension(NDIMS - 2); // input filters
- for (int i = 0; i < NDIMS; ++i) { // spatial dimensions
+ for (int i = 0; i < NDIMS - 2; ++i) { // spatial dimensions
expanded_dims[i + 2] = in.dimension(i);
}
diff --git a/tensorflow/core/kernels/cwise_op_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_equal_to_1.cc
index ea10ebe9a0..931f59014b 100644
--- a/tensorflow/core/kernels/cwise_op_equal_to_1.cc
+++ b/tensorflow/core/kernels/cwise_op_equal_to_1.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER6(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
- uint8, int8, int16);
+REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
+ uint8, int8, int16, bfloat16);
REGISTER_KERNEL_BUILDER(
Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<float>("T"),
ApproximateEqualOp<CPUDevice, float>);
diff --git a/tensorflow/core/kernels/cwise_op_greater.cc b/tensorflow/core/kernels/cwise_op_greater.cc
index a4ea408836..b385e9e545 100644
--- a/tensorflow/core/kernels/cwise_op_greater.cc
+++ b/tensorflow/core/kernels/cwise_op_greater.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER8(BinaryOp, CPU, "Greater", functor::greater, float, Eigen::half,
- double, int32, int64, uint8, int8, int16);
+REGISTER9(BinaryOp, CPU, "Greater", functor::greater, float, Eigen::half,
+ double, int32, int64, uint8, int8, int16, bfloat16);
#if GOOGLE_CUDA
REGISTER7(BinaryOp, GPU, "Greater", functor::greater, float, Eigen::half,
double, int64, uint8, int8, int16);
diff --git a/tensorflow/core/kernels/cwise_op_greater_equal.cc b/tensorflow/core/kernels/cwise_op_greater_equal.cc
index 3f34d6269e..8bfc018052 100644
--- a/tensorflow/core/kernels/cwise_op_greater_equal.cc
+++ b/tensorflow/core/kernels/cwise_op_greater_equal.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER8(BinaryOp, CPU, "GreaterEqual", functor::greater_equal, float,
- Eigen::half, double, int32, int64, uint8, int8, int16);
+REGISTER9(BinaryOp, CPU, "GreaterEqual", functor::greater_equal, float,
+ Eigen::half, double, int32, int64, uint8, int8, int16, bfloat16);
#if GOOGLE_CUDA
REGISTER7(BinaryOp, GPU, "GreaterEqual", functor::greater_equal, float,
Eigen::half, double, int64, uint8, int8, int16);
diff --git a/tensorflow/core/kernels/cwise_op_less.cc b/tensorflow/core/kernels/cwise_op_less.cc
index 575968126f..e369fdcf8a 100644
--- a/tensorflow/core/kernels/cwise_op_less.cc
+++ b/tensorflow/core/kernels/cwise_op_less.cc
@@ -16,8 +16,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER9(BinaryOp, CPU, "Less", functor::less, float, Eigen::half, double,
- bfloat16, int32, int64, uint8, int8, int16);
+REGISTER5(BinaryOp, CPU, "Less", functor::less, float, Eigen::half, double,
+ bfloat16, int32);
+REGISTER5(BinaryOp, CPU, "Less", functor::less, int64, uint8, int8, int16,
+ bfloat16);
+
#if GOOGLE_CUDA
REGISTER7(BinaryOp, GPU, "Less", functor::less, float, Eigen::half, double,
int64, uint8, int8, int16);
diff --git a/tensorflow/core/kernels/cwise_op_less_equal.cc b/tensorflow/core/kernels/cwise_op_less_equal.cc
index 499200d054..3353e117cd 100644
--- a/tensorflow/core/kernels/cwise_op_less_equal.cc
+++ b/tensorflow/core/kernels/cwise_op_less_equal.cc
@@ -16,8 +16,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER9(BinaryOp, CPU, "LessEqual", functor::less_equal, float, Eigen::half,
- bfloat16, double, int32, int64, uint8, int8, int16);
+REGISTER5(BinaryOp, CPU, "LessEqual", functor::less_equal, float, Eigen::half,
+ bfloat16, double, int32);
+REGISTER5(BinaryOp, CPU, "LessEqual", functor::less_equal, int64, uint8, int8,
+ int16, bfloat16);
+
#if GOOGLE_CUDA
REGISTER7(BinaryOp, GPU, "LessEqual", functor::less_equal, float, Eigen::half,
double, int64, uint8, int8, int16);
diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
index 935619711c..9f1e575805 100644
--- a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
+++ b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER6(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
- double, uint8, int8, int16);
+REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
+ double, uint8, int8, int16, bfloat16);
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
double, uint8);
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 3fc25772f6..c1b59e44a6 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -4,14 +4,16 @@
# Public targets:
# ":platform" - Low-level and platform-specific Python code.
-package(default_visibility = [
+visibility = [
"//engedu/ml/tf_from_scratch:__pkg__",
"//tensorflow:internal",
"//tensorflow/contrib/lite/toco/python:__pkg__",
"//tensorflow_models:__subpackages__",
# TODO(aselle): to pass open source test.
"//bazel_pip/tensorflow/contrib/lite/toco/python:__pkg__",
-])
+]
+
+package(default_visibility = visibility)
licenses(["notice"]) # Apache 2.0
@@ -55,12 +57,12 @@ py_library(
"//tensorflow/contrib/lite/toco/python:__pkg__", # TODO(b/34059704): remove when fixed
"//tensorflow/python/debug:__pkg__", # TODO(b/34059704): remove when fixed
"//tensorflow/python/tools:__pkg__", # TODO(b/34059704): remove when fixed
- "//tensorflow/tools/api/generator:__pkg__",
"//tensorflow/tools/quantization:__pkg__", # TODO(b/34059704): remove when fixed
],
deps = [
":no_contrib",
"//tensorflow/contrib:contrib_py",
+ "//tensorflow/python/estimator:estimator_py",
],
)
@@ -126,7 +128,6 @@ py_library(
":weights_broadcast_ops",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/data",
- "//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/feature_column:feature_column_py",
"//tensorflow/python/keras",
"//tensorflow/python/ops/distributions",
@@ -358,6 +359,9 @@ cc_library(
name = "ndarray_tensor",
srcs = ["lib/core/ndarray_tensor.cc"],
hdrs = ["lib/core/ndarray_tensor.h"],
+ visibility = visibility + [
+ "//learning/deepmind/courier:__subpackages__",
+ ],
deps = [
":bfloat16_lib",
":ndarray_tensor_bridge",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index cf707fb2c7..a2ab63bb48 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -79,7 +79,6 @@ from tensorflow.python.ops import initializers_ns as initializers
# Bring in subpackages.
from tensorflow.python import data
from tensorflow.python import keras
-from tensorflow.python.estimator import estimator_lib as estimator
from tensorflow.python.feature_column import feature_column_lib as feature_column
from tensorflow.python.layers import layers
from tensorflow.python.ops import bitwise_ops as bitwise
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index def730371d..985cb90436 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -135,7 +135,7 @@ tensorflow::ImportNumpy();
// Convert TF_DeviceListMemoryBytes and TF_Dim int64_t output to Python integers
%typemap(out) int64_t {
- $result = PyInt_FromLong($1);
+ $result = PyLong_FromLongLong($1);
}
// We use TF_OperationGetControlInputs_wrapper instead of
@@ -610,7 +610,7 @@ def TF_Reset(target, containers=None, config=None):
}
for (size_t i = 0; i < $1.size(); ++i) {
- PyList_SET_ITEM($result, i, PyLong_FromLong($1[i]));
+ PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i]));
}
}
@@ -673,7 +673,7 @@ def TF_Reset(target, containers=None, config=None):
}
for (size_t i = 0; i < $1.size(); ++i) {
- PyList_SET_ITEM($result, i, PyInt_FromLong($1[i]));
+ PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i]));
}
}
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 9e7af878d3..c44a6e6c84 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -809,11 +809,12 @@ class Dataset(object):
def batch(self, batch_size, drop_remainder=False):
"""Combines consecutive elements of this dataset into batches.
- NOTE: If the number of elements (`N`) in this dataset is not an exact
- multiple of `batch_size`, the final batch contain smaller tensors with
- shape `N % batch_size` in the batch dimension. If your program depends on
- the batches having the same shape, consider using the
- @{tf.contrib.data.batch_and_drop_remainder} transformation instead.
+ The tensors in the resulting element will have an additional outer
+ dimension, which will be `batch_size` (or `N % batch_size` for the last
+ element if `batch_size` does not divide the number of input elements `N`
+ evenly and `drop_remainder` is `False`). If your program depends on the
+ batches having the same outer dimension, you should set the `drop_remainder`
+ argument to `True` to prevent the smaller batch from being produced.
Args:
batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
@@ -836,13 +837,19 @@ class Dataset(object):
"""Combines consecutive elements of this dataset into padded batches.
This transformation combines multiple consecutive elements of the input
- dataset into a single element. Like @{tf.data.Dataset.batch}, the tensors
- in the resulting element have an additional outer dimension, which will be
- `batch_size` for all but the last element, and `N % batch_size` for the
- last element (where `N` is the number of elements in this dataset). Unlike
- @{tf.data.Dataset.batch}, the elements may have different shapes for some
- of their components, and this transformation will pad each component to
- the respective shape in `padding_shapes`. The `padding_shapes` argument
+ dataset into a single element.
+
+ Like @{tf.data.Dataset.batch}, the tensors in the resulting element will
+ have an additional outer dimension, which will be `batch_size` (or
+ `N % batch_size` for the last element if `batch_size` does not divide the
+ number of input elements `N` evenly and `drop_remainder` is `False`). If
+ your program depends on the batches having the same outer dimension, you
+ should set the `drop_remainder` argument to `True` to prevent the smaller
+ batch from being produced.
+
+ Unlike @{tf.data.Dataset.batch}, the input elements to be batched may have
+ different shapes, and this transformation will pad each component to the
+ respective shape in `padding_shapes`. The `padding_shapes` argument
determines the resulting shape for each dimension of each component in an
output element:
@@ -852,12 +859,6 @@ class Dataset(object):
will be padded out to the maximum length of all elements in that
dimension.
- NOTE: If the number of elements (`N`) in this dataset is not an exact
- multiple of `batch_size`, the final batch contain smaller tensors with
- shape `N % batch_size` in the batch dimension. If your program depends on
- the batches having the same shape, consider using the
- @{tf.contrib.data.padded_batch_and_drop_remainder} transformation instead.
-
See also @{tf.contrib.data.dense_to_sparse_batch}, which combines elements
that may have different shapes into a @{tf.SparseTensor}.
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 326019ff2a..38e446da0c 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -10,7 +10,10 @@ load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
name = "estimator_py",
- srcs = ["estimator_lib.py"],
+ srcs = [
+ "__init__.py",
+ "estimator_lib.py",
+ ],
srcs_version = "PY2AND3",
visibility = [
"//tensorflow:__pkg__",
@@ -31,7 +34,7 @@ py_library(
":parsing_utils",
":run_config",
":training",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -41,10 +44,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":gc",
- "//tensorflow/python:errors",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:util",
],
@@ -58,10 +58,7 @@ py_test(
deps = [
":estimator",
":exporter",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:platform",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -70,8 +67,7 @@ py_library(
srcs = ["gc.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:platform",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -82,10 +78,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":gc",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:platform",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -95,12 +88,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":export_output",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python/saved_model:signature_constants",
- "//tensorflow/python/saved_model:tag_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -113,12 +101,7 @@ py_test(
deps = [
":export_output",
":model_fn",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:training",
- "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -130,11 +113,7 @@ py_library(
":estimator",
":exporter",
":run_config",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -153,13 +132,7 @@ py_test(
":inputs",
":run_config",
":training",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -168,7 +141,7 @@ py_library(
srcs = ["run_config.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/core:protos_all_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -180,8 +153,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":run_config",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -194,14 +166,7 @@ py_library(
":head",
":model_fn",
":optimizers",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:layers",
- "//tensorflow/python:nn",
- "//tensorflow/python:partitioned_variables",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -225,26 +190,7 @@ py_test(
":numpy_io",
":pandas_io",
":run_config",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -257,20 +203,7 @@ py_library(
":estimator",
":head",
":model_fn",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:boosted_trees_ops",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:lookup_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -284,19 +217,8 @@ py_test(
],
deps = [
":boosted_trees",
- "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:platform_test",
- "//tensorflow/python:resources",
- "//tensorflow/python:training",
- "//tensorflow/python/estimator:numpy_io",
- "//tensorflow/python/feature_column",
+ ":inputs",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -309,14 +231,7 @@ py_library(
":head",
":model_fn",
":optimizers",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:layers",
- "//tensorflow/python:nn",
- "//tensorflow/python:partitioned_variables",
- "//tensorflow/python:summary",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -333,22 +248,7 @@ py_library(
":model_fn",
":numpy_io",
":prediction_keys",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -371,16 +271,7 @@ py_test(
":numpy_io",
":pandas_io",
":prediction_keys",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -396,19 +287,7 @@ py_library(
":linear",
":model_fn",
":optimizers",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:layers",
- "//tensorflow/python:nn",
- "//tensorflow/python:partitioned_variables",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -431,17 +310,7 @@ py_test(
":numpy_io",
":pandas_io",
":prediction_keys",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:nn",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -453,10 +322,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:platform",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python/data",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -467,10 +333,7 @@ py_test(
tags = ["notsan"], # b/67510291
deps = [
":util",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:training",
- "//tensorflow/python/data",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -487,21 +350,7 @@ py_library(
":model_fn",
":run_config",
":util",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:metrics",
- "//tensorflow/python:platform",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python/data",
- "//tensorflow/python/saved_model:builder",
- "//tensorflow/python/saved_model:constants",
- "//tensorflow/python/saved_model:tag_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -520,29 +369,7 @@ py_test(
":model_fn",
":numpy_io",
":run_config",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:layers",
- "//tensorflow/python:lib",
- "//tensorflow/python:lookup_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:saver_test_utils",
- "//tensorflow/python:session",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python:variables",
- "//tensorflow/python/data",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model:loader",
- "//tensorflow/python/saved_model:tag_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -555,9 +382,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -568,10 +393,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":parsing_utils",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -580,9 +402,7 @@ py_library(
srcs = ["export/export_output.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/saved_model:signature_def_utils",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -594,13 +414,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":export_output",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -613,7 +427,7 @@ py_library(
deps = [
":export_export",
":export_output",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -625,13 +439,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":util",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -644,17 +452,8 @@ py_test(
deps = [
":export_export",
":export_output",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/saved_model:signature_constants",
- "//tensorflow/python/saved_model:signature_def_utils",
+ ":util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -667,24 +466,7 @@ py_library(
":metric_keys",
":model_fn",
":prediction_keys",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:lookup_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:nn",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python:weights_broadcast_ops",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -703,23 +485,7 @@ py_test(
":model_fn",
":numpy_io",
":prediction_keys",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -732,7 +498,7 @@ py_library(
deps = [
":numpy_io",
":pandas_io",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -744,11 +510,7 @@ py_library(
":estimator",
":head",
":optimizers",
- "//tensorflow/python:partitioned_variables",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -766,25 +528,7 @@ py_library(
":numpy_io",
":pandas_io",
":run_config",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -802,7 +546,7 @@ py_test(
deps = [
":linear",
":linear_testing_utils",
- "//tensorflow/python:client_testlib",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -831,9 +575,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":numpy_io",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -842,7 +584,7 @@ py_library(
srcs = ["canned/optimizers.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -854,8 +596,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":optimizers",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -873,9 +614,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":pandas_io",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -895,15 +634,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -917,7 +648,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":inputs_queues",
- "//tensorflow/python:client_testlib",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -928,10 +659,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":inputs_queues",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:session",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -944,32 +672,7 @@ py_library(
":export_export",
":model_fn",
":run_config",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:layers",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:nn",
- "//tensorflow/python:partitioned_variables",
- "//tensorflow/python:platform",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:session",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:summary",
- "//tensorflow/python:tensor_util",
- "//tensorflow/python:training",
- "//tensorflow/python:training_util",
- "//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/keras:backend",
- "//tensorflow/python/keras:engine",
- "//tensorflow/python/keras:layers",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model",
- "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -984,18 +687,9 @@ py_test(
],
deps = [
":keras",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:numpy_io",
"//tensorflow/python/estimator:run_config",
- "//tensorflow/python/keras",
- "//tensorflow/python/keras:backend",
- "//tensorflow/python/keras:engine",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/python/estimator/__init__.py b/tensorflow/python/estimator/__init__.py
index e69de29bb2..8cf8df567f 100644
--- a/tensorflow/python/estimator/__init__.py
+++ b/tensorflow/python/estimator/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Import Estimator APIs.
+
+Note: This file is imported by the create_estimator_api genrule. It must
+transitively import all Estimator modules/packages for their @estimator_export
+annotations to generate the public Estimator python API.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.python.estimator.estimator_lib
diff --git a/tensorflow/python/estimator/api/BUILD b/tensorflow/python/estimator/api/BUILD
index cddee9b8f3..aa5a29e6dd 100644
--- a/tensorflow/python/estimator/api/BUILD
+++ b/tensorflow/python/estimator/api/BUILD
@@ -14,4 +14,5 @@ gen_api_init_files(
api_name = "estimator",
output_files = ESTIMATOR_API_INIT_FILES,
package = "tensorflow.python.estimator",
+ package_dep = "//tensorflow/python/estimator:estimator_py",
)
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 2f439f765e..312eb9a035 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -45,7 +45,6 @@ from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -446,7 +445,6 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects,
saver.save(sess, os.path.join(estimator.model_dir, 'keras_model.ckpt'))
-@tf_export('keras.estimator.model_to_estimator')
def model_to_estimator(keras_model=None,
keras_model_path=None,
custom_objects=None,
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 3ed5c9e6a4..708ab1707e 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -67,6 +67,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util import nest
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.protobuf import compare
from tensorflow.python.util.tf_export import tf_export
@@ -618,6 +619,11 @@ def run_in_graph_and_eager_modes(__unused__=None,
assert not __unused__, "Add () after run_in_graph_and_eager_modes."
def decorator(f):
+ if tf_inspect.isclass(f):
+ raise ValueError(
+ "`run_test_in_graph_and_eager_modes` only supports test methods. "
+ "Did you mean to use `run_all_tests_in_graph_and_eager_modes`?")
+
def decorated(self, **kwargs):
with context.graph_mode():
with self.test_session(use_gpu=use_gpu):
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 0178908bcc..2a7cf88d6e 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -595,6 +595,14 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo)
self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def))
+ def testRunInGraphAndEagerModesOnTestCase(self):
+ msg = "`run_test_in_graph_and_eager_modes` only supports test methods.*"
+ with self.assertRaisesRegexp(ValueError, msg):
+ @test_util.run_in_graph_and_eager_modes()
+ class Foo(object):
+ pass
+ del Foo # Make pylint unused happy.
+
class GarbageCollectionTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 9012f4ee38..bc33dddc95 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -39,6 +39,7 @@ py_library(
"datasets/imdb.py",
"datasets/mnist.py",
"datasets/reuters.py",
+ "estimator/__init__.py",
"preprocessing/__init__.py",
"preprocessing/image.py",
"preprocessing/sequence.py",
@@ -866,6 +867,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:util",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/python/keras/__init__.py b/tensorflow/python/keras/__init__.py
index 3493069a5b..198c66d9e1 100644
--- a/tensorflow/python/keras/__init__.py
+++ b/tensorflow/python/keras/__init__.py
@@ -27,6 +27,7 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import constraints
from tensorflow.python.keras import datasets
+from tensorflow.python.keras import estimator
from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index c55a756bcc..fed779650e 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import itertools
import json
import os
import weakref
@@ -4245,58 +4246,115 @@ def pool3d(x,
return x
-def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
- """Apply 1D conv with un-shared weights.
-
- Arguments:
- inputs: 3D tensor with shape:
- (batch_size, steps, input_dim)
- if data_format is "channels_last" or
- (batch_size, input_dim, steps)
- if data_format is "channels_first".
- kernel: the unshared weight for convolution,
- with shape (output_length, feature_dim, filters)
- kernel_size: a tuple of a single integer,
- specifying the length of the 1D convolution window
- strides: a tuple of a single integer,
- specifying the stride length of the convolution
- data_format: the data format, channels_first or channels_last
-
- Returns:
- the tensor after 1d conv with un-shared weights, with shape (batch_size,
- output_length, filters)
+def local_conv(inputs,
+ kernel,
+ kernel_size,
+ strides,
+ output_shape,
+ data_format=None):
+ """Apply N-D convolution with un-shared weights.
+
+ Arguments:
+ inputs: (N+2)-D tensor with shape
+ (batch_size, channels_in, d_in1, ..., d_inN)
+ if data_format='channels_first', or
+ (batch_size, d_in1, ..., d_inN, channels_in)
+ if data_format='channels_last'.
+ kernel: the unshared weight for N-D convolution,
+ with shape (output_items, feature_dim, channels_out), where
+ feature_dim = np.prod(kernel_size) * channels_in,
+ output_items = np.prod(output_shape).
+ kernel_size: a tuple of N integers, specifying the
+ spatial dimensions of the N-D convolution window.
+ strides: a tuple of N integers, specifying the strides
+ of the convolution along the spatial dimensions.
+ output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial
+ dimensionality of the output.
+ data_format: string, "channels_first" or "channels_last".
+
+ Returns:
+ An (N+2)-D tensor with shape:
+ (batch_size, channels_out) + output_shape
+ if data_format='channels_first', or:
+ (batch_size,) + output_shape + (channels_out,)
+ if data_format='channels_last'.
Raises:
- ValueError: if `data_format` is neither `channels_last` or
- `channels_first`.
+ ValueError: if `data_format` is neither
+ `channels_last` nor `channels_first`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
- stride = strides[0]
kernel_shape = int_shape(kernel)
- output_length = kernel_shape[0]
feature_dim = kernel_shape[1]
+ channels_out = kernel_shape[-1]
+ ndims = len(output_shape)
+ spatial_dimensions = list(range(ndims))
xs = []
- for i in range(output_length):
- slice_length = slice(i * stride, i * stride + kernel_size[0])
+ output_axes_ticks = [range(axis_max) for axis_max in output_shape]
+ for position in itertools.product(*output_axes_ticks):
+ slices = [slice(None)]
+
if data_format == 'channels_first':
- xs.append(reshape(inputs[:, :, slice_length], (1, -1, feature_dim)))
- else:
- xs.append(reshape(inputs[:, slice_length, :], (1, -1, feature_dim)))
+ slices.append(slice(None))
+
+ slices.extend([slice(position[d] * strides[d],
+ position[d] * strides[d] + kernel_size[d])
+ for d in spatial_dimensions])
+
+ if data_format == 'channels_last':
+ slices.append(slice(None))
+
+ xs.append(reshape(inputs[slices], (1, -1, feature_dim)))
x_aggregate = concatenate(xs, axis=0)
- # Shape: `(output_length, batch_size, filters)`.
output = batch_dot(x_aggregate, kernel)
+ output = reshape(output, output_shape + (-1, channels_out))
if data_format == 'channels_first':
- output = permute_dimensions(output, (1, 2, 0))
+ permutation = [ndims, ndims + 1] + spatial_dimensions
else:
- output = permute_dimensions(output, (1, 0, 2))
- return output
+ permutation = [ndims] + spatial_dimensions + [ndims + 1]
+
+ return permute_dimensions(output, permutation)
+
+
+def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
+ """Apply 1D conv with un-shared weights.
+
+ Arguments:
+ inputs: 3D tensor with shape:
+ (batch_size, steps, input_dim)
+ if data_format is "channels_last" or
+ (batch_size, input_dim, steps)
+ if data_format is "channels_first".
+ kernel: the unshared weight for convolution,
+ with shape (output_length, feature_dim, filters).
+ kernel_size: a tuple of a single integer,
+ specifying the length of the 1D convolution window.
+ strides: a tuple of a single integer,
+ specifying the stride length of the convolution.
+ data_format: the data format, channels_first or channels_last.
+
+ Returns:
+ A 3d tensor with shape:
+ (batch_size, output_length, filters)
+ if data_format='channels_first'
+ or 3D tensor with shape:
+ (batch_size, filters, output_length)
+ if data_format='channels_last'.
+ """
+ output_shape = (kernel.shape[0],)
+ return local_conv(inputs,
+ kernel,
+ kernel_size,
+ strides,
+ output_shape,
+ data_format)
def local_conv2d(inputs,
@@ -4309,64 +4367,34 @@ def local_conv2d(inputs,
Arguments:
inputs: 4D tensor with shape:
- (batch_size, filters, new_rows, new_cols)
- if data_format='channels_first'
- or 4D tensor with shape:
- (batch_size, new_rows, new_cols, filters)
- if data_format='channels_last'.
+ (batch_size, filters, new_rows, new_cols)
+ if data_format='channels_first'
+ or 4D tensor with shape:
+ (batch_size, new_rows, new_cols, filters)
+ if data_format='channels_last'.
kernel: the unshared weight for convolution,
- with shape (output_items, feature_dim, filters)
+ with shape (output_items, feature_dim, filters).
kernel_size: a tuple of 2 integers, specifying the
- width and height of the 2D convolution window.
+ width and height of the 2D convolution window.
strides: a tuple of 2 integers, specifying the strides
- of the convolution along the width and height.
- output_shape: a tuple with (output_row, output_col)
- data_format: the data format, channels_first or channels_last
+ of the convolution along the width and height.
+ output_shape: a tuple with (output_row, output_col).
+ data_format: the data format, channels_first or channels_last.
Returns:
- A 4d tensor with shape:
+ A 4D tensor with shape:
(batch_size, filters, new_rows, new_cols)
if data_format='channels_first'
or 4D tensor with shape:
(batch_size, new_rows, new_cols, filters)
if data_format='channels_last'.
-
- Raises:
- ValueError: if `data_format` is neither
- `channels_last` or `channels_first`.
"""
- if data_format is None:
- data_format = image_data_format()
- if data_format not in {'channels_first', 'channels_last'}:
- raise ValueError('Unknown data_format: ' + str(data_format))
-
- stride_row, stride_col = strides
- output_row, output_col = output_shape
- kernel_shape = int_shape(kernel)
- feature_dim = kernel_shape[1]
- filters = kernel_shape[2]
-
- xs = []
- for i in range(output_row):
- for j in range(output_col):
- slice_row = slice(i * stride_row, i * stride_row + kernel_size[0])
- slice_col = slice(j * stride_col, j * stride_col + kernel_size[1])
- if data_format == 'channels_first':
- xs.append(
- reshape(inputs[:, :, slice_row, slice_col], (1, -1, feature_dim)))
- else:
- xs.append(
- reshape(inputs[:, slice_row, slice_col, :], (1, -1, feature_dim)))
-
- x_aggregate = concatenate(xs, axis=0)
- output = batch_dot(x_aggregate, kernel)
- output = reshape(output, (output_row, output_col, -1, filters))
-
- if data_format == 'channels_first':
- output = permute_dimensions(output, (2, 3, 0, 1))
- else:
- output = permute_dimensions(output, (2, 0, 1, 3))
- return output
+ return local_conv(inputs,
+ kernel,
+ kernel_size,
+ strides,
+ output_shape,
+ data_format)
@tf_export('keras.backend.bias_add')
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 98f36ad87f..2ba6c8ef15 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
import scipy.sparse
@@ -662,7 +663,7 @@ class BackendShapeOpsTest(test.TestCase):
np_kwargs={'data_format': 'channels_first'})
-class BackendNNOpsTest(test.TestCase):
+class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
def test_bias_add(self):
with self.test_session():
@@ -811,52 +812,117 @@ class BackendNNOpsTest(test.TestCase):
padding='same', data_format='channels_last')
self.assertEqual(y.get_shape().as_list(), [10, 5, 5])
- def test_local_conv1d_channels_dim(self):
- input_length = 5
- input_dim = 3
+ def test_local_conv_channels_dim(self):
+ filters = 3
batch_size = 2
- inputs = np.random.normal(0, 1, (batch_size, input_dim, input_length))
- inputs_cf = keras.backend.variable(inputs)
+ for input_shape in [(3, 5), (2, 3, 5), (2, 5, 3, 4)]:
+ channels_in = input_shape[0]
+ input_spatial_shape = input_shape[1:]
+ dim = len(input_spatial_shape)
- filters = 4
- for kernel_size in [(1,), (2,), (3,)]:
- for strides in [(1,), (2,), (3,)]:
- output_length = (input_length - kernel_size[0]
- + strides[0]) // strides[0]
+ inputs = np.random.normal(0, 1, (batch_size,) + input_shape)
+ inputs_cf = keras.backend.variable(inputs)
- kernel_shape = (output_length, kernel_size[0] * input_dim, filters)
- kernel = np.random.normal(0, 1, (output_length,
- input_dim,
- kernel_size[0],
- filters))
- kernel_cf = np.reshape(kernel, kernel_shape)
- kernel_cf = keras.backend.variable(kernel_cf)
+ for kernel_size in [1, 2]:
+ for stride in [1, 2]:
+ kernel_sizes = (kernel_size,) * dim
+ strides = (stride,) * dim
- conv_cf = keras.backend.local_conv1d(inputs_cf,
+ output_shape = tuple([(i - kernel_size + stride) // stride
+ for i in input_spatial_shape])
+
+ kernel_shape = (np.prod(output_shape),
+ np.prod(kernel_sizes) * channels_in,
+ filters)
+
+ kernel = np.random.normal(
+ 0,
+ 1,
+ output_shape + (channels_in, np.prod(kernel_sizes), filters)
+ )
+
+ kernel_cf = np.reshape(kernel, kernel_shape)
+ kernel_cf = keras.backend.variable(kernel_cf)
+
+ conv_cf = keras.backend.local_conv(inputs_cf,
kernel_cf,
- kernel_size,
+ kernel_sizes,
strides,
+ output_shape,
'channels_first')
- inputs_cl = np.transpose(inputs, (0, 2, 1))
- inputs_cl = keras.backend.variable(inputs_cl)
+ inputs_cl = np.transpose(inputs, [0, 2] + list(range(3, dim + 2)) +
+ [1])
+ inputs_cl = keras.backend.variable(inputs_cl)
- kernel_cl = np.reshape(np.transpose(kernel, (0, 2, 1, 3)),
- kernel_shape)
- kernel_cl = keras.backend.variable(kernel_cl)
+ kernel_cl = np.reshape(
+ np.transpose(kernel, list(range(dim)) + [dim + 1, dim, dim + 2]),
+ kernel_shape
+ )
+ kernel_cl = keras.backend.variable(kernel_cl)
- conv_cl = keras.backend.local_conv1d(inputs_cl,
+ conv_cl = keras.backend.local_conv(inputs_cl,
kernel_cl,
- kernel_size,
+ kernel_sizes,
strides,
+ output_shape,
'channels_last')
- with self.test_session():
- conv_cf = keras.backend.eval(conv_cf)
- conv_cl = keras.backend.eval(conv_cl)
+ with self.test_session():
+ conv_cf = keras.backend.eval(conv_cf)
+ conv_cl = keras.backend.eval(conv_cl)
+
+ self.assertAllCloseAccordingToType(
+ conv_cf,
+ np.transpose(conv_cl,
+ [0, dim + 1] + list(range(1, dim + 1))),
+ atol=1e-5
+ )
+
+ @parameterized.named_parameters(
+ ('local_conv1d', (5, 6), (3,), (1,), (3,)),
+ ('local_conv2d', (4, 5, 6), (3, 3), (1, 1), (2, 3)))
+ def test_local_conv_1d_and_2d(self,
+ input_shape,
+ kernel_sizes,
+ strides,
+ output_shape):
+ filters = 3
+ batch_size = 2
+
+ inputs = np.random.normal(0, 1, (batch_size,) + input_shape)
+ inputs = keras.backend.variable(inputs)
+
+ kernel = np.random.normal(0, 1, (np.prod(output_shape),
+ np.prod(kernel_sizes) * input_shape[-1],
+ filters))
+ kernel = keras.backend.variable(kernel)
+
+ local_conv = keras.backend.local_conv(inputs,
+ kernel,
+ kernel_sizes,
+ strides,
+ output_shape,
+ 'channels_last')
+ if len(output_shape) == 1:
+ local_conv_dim = keras.backend.local_conv1d(inputs,
+ kernel,
+ kernel_sizes,
+ strides,
+ 'channels_last')
+ else:
+ local_conv_dim = keras.backend.local_conv2d(inputs,
+ kernel,
+ kernel_sizes,
+ strides,
+ output_shape,
+ 'channels_last')
+
+ with self.test_session():
+ local_conv = keras.backend.eval(local_conv)
+ local_conv_dim = keras.backend.eval(local_conv_dim)
- self.assertAllCloseAccordingToType(conv_cf,
- np.transpose(conv_cl, (0, 2, 1)))
+ self.assertAllCloseAccordingToType(local_conv, local_conv_dim)
def test_conv2d(self):
val = np.random.random((10, 4, 10, 10))
diff --git a/tensorflow/python/keras/estimator/__init__.py b/tensorflow/python/keras/estimator/__init__.py
new file mode 100644
index 0000000000..cb86a69990
--- /dev/null
+++ b/tensorflow/python/keras/estimator/__init__.py
@@ -0,0 +1,46 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras estimator API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.util.tf_export import tf_export
+
+# Keras has undeclared dependency on tensorflow/estimator:estimator_py.
+# As long as you depend //third_party/py/tensorflow:tensorflow target
+# everything will work as normal.
+
+try:
+ import tensorflow.python.estimator.keras as keras_lib # pylint: disable=g-import-not-at-top
+ model_to_estimator = tf_export('keras.estimator.model_to_estimator')(
+ keras_lib.model_to_estimator)
+except Exception: # pylint: disable=broad-except
+
+ # pylint: disable=unused-argument
+ def stub_model_to_estimator(keras_model=None,
+ keras_model_path=None,
+ custom_objects=None,
+ model_dir=None,
+ config=None):
+ raise NotImplementedError(
+ 'tf.keras.estimator.model_to_estimator function not available in your '
+ 'installation.')
+ # pylint: enable=unused-argument
+
+ model_to_estimator = tf_export('keras.estimator.model_to_estimator')(
+ stub_model_to_estimator)
+
diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py
index f222ea3083..0983e35e21 100644
--- a/tensorflow/python/keras/layers/local.py
+++ b/tensorflow/python/keras/layers/local.py
@@ -140,9 +140,9 @@ class LocallyConnected1D(Layer):
if input_dim is None:
raise ValueError('Axis 2 of input should be fully-defined. '
'Found shape:', input_shape)
- output_length = conv_utils.conv_output_length(
+ self.output_length = conv_utils.conv_output_length(
input_length, self.kernel_size[0], self.padding, self.strides[0])
- self.kernel_shape = (output_length, self.kernel_size[0] * input_dim,
+ self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim,
self.filters)
self.kernel = self.add_weight(
shape=self.kernel_shape,
@@ -152,7 +152,7 @@ class LocallyConnected1D(Layer):
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight(
- shape=(output_length, self.filters),
+ shape=(self.output_length, self.filters),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
@@ -182,12 +182,13 @@ class LocallyConnected1D(Layer):
return (input_shape[0], length, self.filters)
def call(self, inputs):
- output = K.local_conv1d(inputs, self.kernel, self.kernel_size,
- self.strides, self.data_format)
+ output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
+ (self.output_length,), self.data_format)
+
if self.use_bias:
output = K.bias_add(output, self.bias, data_format=self.data_format)
- if self.activation is not None:
- output = self.activation(output)
+
+ output = self.activation(output)
return output
def get_config(self):
@@ -400,9 +401,8 @@ class LocallyConnected2D(Layer):
return (input_shape[0], rows, cols, self.filters)
def call(self, inputs):
- output = K.local_conv2d(inputs, self.kernel, self.kernel_size, self.strides,
- (self.output_row, self.output_col),
- self.data_format)
+ output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
+ (self.output_row, self.output_col), self.data_format)
if self.use_bias:
output = K.bias_add(output, self.bias, data_format=self.data_format)
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index 7759561ef9..00d0fc67d1 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -45,7 +45,9 @@ class Wrapper(Layer):
"""
def __init__(self, layer, **kwargs):
+ assert isinstance(layer, Layer)
self.layer = layer
+ self._track_checkpointable(layer, name='layer')
# Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
# the inner layer has update ops that depend on its inputs (as opposed
# to the inputs to the Wrapper layer).
@@ -154,9 +156,16 @@ class TimeDistributed(Wrapper):
Arguments:
layer: a layer instance.
+
+ Raises:
+ ValueError: If not initialized with a `Layer` instance.
"""
def __init__(self, layer, **kwargs):
+ if not isinstance(layer, Layer):
+ raise ValueError(
+ 'Please initialize `TimeDistributed` layer with a '
+ '`Layer` instance. You passed: {input}'.format(input=layer))
super(TimeDistributed, self).__init__(layer, **kwargs)
self.supports_masking = True
@@ -249,7 +258,8 @@ class Bidirectional(Wrapper):
they will be returned as a list.
Raises:
- ValueError: In case of invalid `merge_mode` argument.
+ ValueError: If not initialized with a `Layer` instance or
+ In case of invalid `merge_mode` argument.
Examples:
@@ -265,6 +275,10 @@ class Bidirectional(Wrapper):
"""
def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
+ if not isinstance(layer, Layer):
+ raise ValueError(
+ 'Please initialize `Bidirectional` layer with a '
+ '`Layer` instance. You passed: {input}'.format(input=layer))
if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]:
raise ValueError('Invalid merge mode. '
'Merge mode should be one of '
diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py
index 5eab6aba8a..3b997732b5 100644
--- a/tensorflow/python/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/layers/wrappers_test.py
@@ -23,8 +23,10 @@ import copy
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.training.rmsprop import RMSPropOptimizer
@@ -85,6 +87,10 @@ class TimeDistributedTest(test.TestCase):
# test config
model.get_config()
+ checkpointed_objects = set(checkpointable_util.list_objects(model))
+ for v in model.variables:
+ self.assertIn(v, checkpointed_objects)
+
def test_timedistributed_static_batch_size(self):
model = keras.models.Sequential()
model.add(
@@ -97,6 +103,13 @@ class TimeDistributedTest(test.TestCase):
epochs=1,
batch_size=10)
+ def test_timedistributed_invalid_init(self):
+ x = constant_op.constant(np.zeros((1, 1)).astype('float32'))
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Please initialize `TimeDistributed` layer with a `Layer` instance.'):
+ keras.layers.TimeDistributed(x)
+
def test_timedistributed_conv2d(self):
with self.test_session():
model = keras.models.Sequential()
@@ -220,6 +233,13 @@ class BidirectionalTest(test.TestCase):
model = keras.models.model_from_json(model.to_json())
model.summary()
+ def test_bidirectional_invalid_init(self):
+ x = constant_op.constant(np.zeros((1, 1)).astype('float32'))
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Please initialize `Bidirectional` layer with a `Layer` instance.'):
+ keras.layers.Bidirectional(x)
+
def test_bidirectional_weight_loading(self):
rnn = keras.layers.SimpleRNN
samples = 2
@@ -424,6 +444,42 @@ class BidirectionalTest(test.TestCase):
layer.trainable = True
assert len(layer.trainable_weights) == 6
+ def test_Bidirectional_updates(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3, 2))
+ x_reachable_update = x * x
+ layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
+ _ = layer(x)
+ assert not layer.updates
+ assert not layer.get_updates_for(None)
+ assert not layer.get_updates_for(x)
+ layer.forward_layer.add_update(x_reachable_update, inputs=x)
+ layer.forward_layer.add_update(1, inputs=None)
+ layer.backward_layer.add_update(x_reachable_update, inputs=x)
+ layer.backward_layer.add_update(1, inputs=None)
+ assert len(layer.updates) == 4
+ assert len(layer.get_updates_for(None)) == 2
+ assert len(layer.get_updates_for(x)) == 2
+
+ def test_Bidirectional_losses(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3, 2))
+ x_reachable_loss = x * x
+ layer = keras.layers.Bidirectional(
+ keras.layers.SimpleRNN(
+ 3, kernel_regularizer='l1', bias_regularizer='l1'))
+ _ = layer(x)
+ assert len(layer.losses) == 4
+ assert len(layer.get_losses_for(None)) == 4
+ assert not layer.get_losses_for(x)
+ layer.forward_layer.add_loss(x_reachable_loss, inputs=x)
+ layer.forward_layer.add_loss(1, inputs=None)
+ layer.backward_layer.add_loss(x_reachable_loss, inputs=x)
+ layer.backward_layer.add_loss(1, inputs=None)
+ assert len(layer.losses) == 8
+ assert len(layer.get_losses_for(None)) == 6
+ assert len(layer.get_losses_for(x)) == 2
+
def test_Bidirectional_with_constants(self):
with self.test_session():
# Test basic case.
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index ccd05a8820..b61232cded 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -96,7 +96,8 @@ class UnaryOpTest(test.TestCase):
np_ans = np_func(x)
with self.test_session(use_gpu=False):
inx = ops.convert_to_tensor(x)
- if x.dtype in (np.float32, np.float64):
+ if x.dtype in (np.float32, np.float64,
+ dtypes_lib.bfloat16.as_numpy_dtype):
y = 1.1 * tf_func(inx)
np_ans *= 1.1
else:
@@ -105,6 +106,8 @@ class UnaryOpTest(test.TestCase):
self.assertShapeEqual(np_ans, y)
if x.dtype == np.float16:
self.assertAllClose(np_ans, tf_cpu, rtol=1e-3, atol=1e-3)
+ elif x.dtype == dtypes_lib.bfloat16.as_numpy_dtype:
+ self.assertAllClose(np_ans, tf_cpu, rtol=1e-2, atol=1e-2)
else:
self.assertAllClose(np_ans, tf_cpu)
diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD
index 985922245e..14532965d8 100644
--- a/tensorflow/python/kernel_tests/distributions/BUILD
+++ b/tensorflow/python/kernel_tests/distributions/BUILD
@@ -135,6 +135,10 @@ cuda_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
+ tags = [
+ "noguitar", # b/110489471
+ "notap", # b/110489471
+ ],
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
index 095d1cde15..ed5ea8b034 100644
--- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
@@ -22,6 +22,7 @@ import importlib
import numpy as np
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@@ -272,6 +273,16 @@ class BernoulliTest(test.TestCase):
dist = bernoulli.Bernoulli(np.log([.2, .4]))
self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())
+ @test_util.run_in_graph_and_eager_modes()
+ def testNotReparameterized(self):
+ p = constant_op.constant([0.2, 0.6])
+ with backprop.GradientTape() as tape:
+ tape.watch(p)
+ dist = bernoulli.Bernoulli(probs=p)
+ samples = dist.sample(100)
+ grad_p = tape.gradient(samples, p)
+ self.assertIsNone(grad_p)
+
def testSampleActsLikeSampleN(self):
with self.test_session() as sess:
p = [0.2, 0.6]
diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py
index 68b4ffdb58..d8939433ce 100644
--- a/tensorflow/python/kernel_tests/distributions/categorical_test.py
+++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
@@ -376,6 +377,15 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
self.assertAllClose(
[0.4**2 + 0.6**2], [prob_val[:, :, :, 1].mean()], atol=1e-2)
+ def testNotReparameterized(self):
+ p = constant_op.constant([0.3, 0.3, 0.4])
+ with backprop.GradientTape() as tape:
+ tape.watch(p)
+ dist = categorical.Categorical(p)
+ samples = dist.sample(100)
+ grad_p = tape.gradient(samples, p)
+ self.assertIsNone(grad_p)
+
def testLogPMFBroadcasting(self):
with self.test_session():
# 1 x 2 x 2
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
index 7922fb0606..9344785b09 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
@@ -17,6 +17,9 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+
+from tensorflow.python.eager import backprop
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -475,6 +478,21 @@ class DirichletMultinomialTest(test.TestCase):
self.assertAllClose(
actual_covariance_, sample_covariance_, atol=0., rtol=0.15)
+ def testNotReparameterized(self):
+ total_count = constant_op.constant(5.0)
+ concentration = constant_op.constant([0.1, 0.1, 0.1])
+ with backprop.GradientTape() as tape:
+ tape.watch(total_count)
+ tape.watch(concentration)
+ dist = ds.DirichletMultinomial(
+ total_count=total_count,
+ concentration=concentration)
+ samples = dist.sample(100)
+ grad_total_count, grad_concentration = tape.gradient(
+ samples, [total_count, concentration])
+ self.assertIsNone(grad_total_count)
+ self.assertIsNone(grad_concentration)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py
index ebcd41b0e2..850da3e969 100644
--- a/tensorflow/python/kernel_tests/distributions/exponential_test.py
+++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py
@@ -23,6 +23,7 @@ import importlib
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import nn_ops
@@ -163,6 +164,15 @@ class ExponentialTest(test.TestCase):
stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
0.01)
+ def testFullyReparameterized(self):
+ lam = constant_op.constant([0.1, 1.0])
+ with backprop.GradientTape() as tape:
+ tape.watch(lam)
+ exponential = exponential_lib.Exponential(rate=lam)
+ samples = exponential.sample(100)
+ grad_lam = tape.gradient(samples, lam)
+ self.assertIsNotNone(grad_lam)
+
def testExponentialWithSoftplusRate(self):
with self.test_session():
lam = [-2.2, -3.4]
diff --git a/tensorflow/python/kernel_tests/distributions/gamma_test.py b/tensorflow/python/kernel_tests/distributions/gamma_test.py
index 5e4813ac07..154e859f3c 100644
--- a/tensorflow/python/kernel_tests/distributions/gamma_test.py
+++ b/tensorflow/python/kernel_tests/distributions/gamma_test.py
@@ -21,9 +21,9 @@ import importlib
import numpy as np
-from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import gamma as gamma_lib
@@ -45,6 +45,7 @@ special = try_import("scipy.special")
stats = try_import("scipy.stats")
+@test_util.run_all_in_graph_and_eager_modes
class GammaTest(test.TestCase):
def testGammaShape(self):
@@ -53,9 +54,9 @@ class GammaTest(test.TestCase):
beta = constant_op.constant(11.0)
gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- self.assertEqual(gamma.batch_shape_tensor().eval(), (5,))
+ self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,))
self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5]))
- self.assertAllEqual(gamma.event_shape_tensor().eval(), [])
+ self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), [])
self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([]))
def testGammaLogPDF(self):
@@ -74,8 +75,8 @@ class GammaTest(test.TestCase):
if not stats:
return
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
- self.assertAllClose(log_pdf.eval(), expected_log_pdf)
- self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
+ self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+ self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
def testGammaLogPDFMultidimensional(self):
with self.test_session():
@@ -87,10 +88,10 @@ class GammaTest(test.TestCase):
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
log_pdf = gamma.log_prob(x)
- log_pdf_values = log_pdf.eval()
+ log_pdf_values = self.evaluate(log_pdf)
self.assertEqual(log_pdf.get_shape(), (6, 2))
pdf = gamma.prob(x)
- pdf_values = pdf.eval()
+ pdf_values = self.evaluate(pdf)
self.assertEqual(pdf.get_shape(), (6, 2))
if not stats:
return
@@ -108,10 +109,10 @@ class GammaTest(test.TestCase):
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
log_pdf = gamma.log_prob(x)
- log_pdf_values = log_pdf.eval()
+ log_pdf_values = self.evaluate(log_pdf)
self.assertEqual(log_pdf.get_shape(), (6, 2))
pdf = gamma.prob(x)
- pdf_values = pdf.eval()
+ pdf_values = self.evaluate(pdf)
self.assertEqual(pdf.get_shape(), (6, 2))
if not stats:
@@ -135,7 +136,7 @@ class GammaTest(test.TestCase):
if not stats:
return
expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
- self.assertAllClose(cdf.eval(), expected_cdf)
+ self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testGammaMean(self):
with self.test_session():
@@ -146,7 +147,7 @@ class GammaTest(test.TestCase):
if not stats:
return
expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
- self.assertAllClose(gamma.mean().eval(), expected_means)
+ self.assertAllClose(self.evaluate(gamma.mean()), expected_means)
def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
with self.test_session():
@@ -155,7 +156,7 @@ class GammaTest(test.TestCase):
gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
expected_modes = (alpha_v - 1) / beta_v
self.assertEqual(gamma.mode().get_shape(), (3,))
- self.assertAllClose(gamma.mode().eval(), expected_modes)
+ self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
with self.test_session():
@@ -166,7 +167,7 @@ class GammaTest(test.TestCase):
rate=beta_v,
allow_nan_stats=False)
with self.assertRaisesOpError("x < y"):
- gamma.mode().eval()
+ self.evaluate(gamma.mode())
def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self):
with self.test_session():
@@ -179,7 +180,7 @@ class GammaTest(test.TestCase):
expected_modes = (alpha_v - 1) / beta_v
expected_modes[0] = np.nan
self.assertEqual(gamma.mode().get_shape(), (3,))
- self.assertAllClose(gamma.mode().eval(), expected_modes)
+ self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
def testGammaVariance(self):
with self.test_session():
@@ -190,7 +191,7 @@ class GammaTest(test.TestCase):
if not stats:
return
expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
- self.assertAllClose(gamma.variance().eval(), expected_variances)
+ self.assertAllClose(self.evaluate(gamma.variance()), expected_variances)
def testGammaStd(self):
with self.test_session():
@@ -201,7 +202,7 @@ class GammaTest(test.TestCase):
if not stats:
return
expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
- self.assertAllClose(gamma.stddev().eval(), expected_stddev)
+ self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev)
def testGammaEntropy(self):
with self.test_session():
@@ -212,10 +213,10 @@ class GammaTest(test.TestCase):
if not stats:
return
expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
- self.assertAllClose(gamma.entropy().eval(), expected_entropy)
+ self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy)
def testGammaSampleSmallAlpha(self):
- with session.Session():
+ with self.test_session():
alpha_v = 0.05
beta_v = 1.0
alpha = constant_op.constant(alpha_v)
@@ -223,7 +224,7 @@ class GammaTest(test.TestCase):
n = 100000
gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
samples = gamma.sample(n, seed=137)
- sample_values = samples.eval()
+ sample_values = self.evaluate(samples)
self.assertEqual(samples.get_shape(), (n,))
self.assertEqual(sample_values.shape, (n,))
self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
@@ -240,7 +241,7 @@ class GammaTest(test.TestCase):
atol=.15)
def testGammaSample(self):
- with session.Session():
+ with self.test_session():
alpha_v = 4.0
beta_v = 3.0
alpha = constant_op.constant(alpha_v)
@@ -248,7 +249,7 @@ class GammaTest(test.TestCase):
n = 100000
gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
samples = gamma.sample(n, seed=137)
- sample_values = samples.eval()
+ sample_values = self.evaluate(samples)
self.assertEqual(samples.get_shape(), (n,))
self.assertEqual(sample_values.shape, (n,))
self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
@@ -265,13 +266,13 @@ class GammaTest(test.TestCase):
atol=.15)
def testGammaSampleMultiDimensional(self):
- with session.Session():
+ with self.test_session():
alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
n = 10000
samples = gamma.sample(n, seed=137)
- sample_values = samples.eval()
+ sample_values = self.evaluate(samples)
self.assertEqual(samples.get_shape(), (n, 10, 100))
self.assertEqual(sample_values.shape, (n, 10, 100))
zeros = np.zeros_like(alpha_v + beta_v) # 10 x 100
@@ -306,12 +307,12 @@ class GammaTest(test.TestCase):
return ks < 0.02
def testGammaPdfOfSampleMultiDims(self):
- with session.Session() as sess:
+ with self.test_session():
gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
num = 50000
samples = gamma.sample(num, seed=137)
pdfs = gamma.prob(samples)
- sample_vals, pdf_vals = sess.run([samples, pdfs])
+ sample_vals, pdf_vals = self.evaluate([samples, pdfs])
self.assertEqual(samples.get_shape(), (num, 2, 2))
self.assertEqual(pdfs.get_shape(), (num, 2, 2))
self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
@@ -345,18 +346,18 @@ class GammaTest(test.TestCase):
with self.test_session():
alpha_v = constant_op.constant(0.0, name="alpha")
beta_v = constant_op.constant(1.0, name="beta")
- gamma = gamma_lib.Gamma(concentration=alpha_v,
- rate=beta_v,
- validate_args=True)
- with self.assertRaisesOpError("alpha"):
- gamma.mean().eval()
+ with self.assertRaisesOpError("x > 0"):
+ gamma = gamma_lib.Gamma(concentration=alpha_v,
+ rate=beta_v,
+ validate_args=True)
+ self.evaluate(gamma.mean())
alpha_v = constant_op.constant(1.0, name="alpha")
beta_v = constant_op.constant(0.0, name="beta")
- gamma = gamma_lib.Gamma(concentration=alpha_v,
- rate=beta_v,
- validate_args=True)
- with self.assertRaisesOpError("beta"):
- gamma.mean().eval()
+ with self.assertRaisesOpError("x > 0"):
+ gamma = gamma_lib.Gamma(concentration=alpha_v,
+ rate=beta_v,
+ validate_args=True)
+ self.evaluate(gamma.mean())
def testGammaWithSoftplusConcentrationRate(self):
with self.test_session():
@@ -364,10 +365,10 @@ class GammaTest(test.TestCase):
beta_v = constant_op.constant([1.0, -3.6], name="beta")
gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
concentration=alpha_v, rate=beta_v)
- self.assertAllEqual(nn_ops.softplus(alpha_v).eval(),
- gamma.concentration.eval())
- self.assertAllEqual(nn_ops.softplus(beta_v).eval(),
- gamma.rate.eval())
+ self.assertAllEqual(self.evaluate(nn_ops.softplus(alpha_v)),
+ self.evaluate(gamma.concentration))
+ self.assertAllEqual(self.evaluate(nn_ops.softplus(beta_v)),
+ self.evaluate(gamma.rate))
def testGammaGammaKL(self):
alpha0 = np.array([3.])
@@ -377,15 +378,15 @@ class GammaTest(test.TestCase):
beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])
# Build graph.
- with self.test_session() as sess:
+ with self.test_session():
g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
x = g0.sample(int(1e4), seed=0)
kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
kl_actual = kullback_leibler.kl_divergence(g0, g1)
- # Execute graph.
- [kl_sample_, kl_actual_] = sess.run([kl_sample, kl_actual])
+ # Execute graph.
+ [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])
self.assertEqual(beta0.shape, kl_actual.get_shape())
diff --git a/tensorflow/python/kernel_tests/distributions/laplace_test.py b/tensorflow/python/kernel_tests/distributions/laplace_test.py
index 918c7f63f2..24b243f647 100644
--- a/tensorflow/python/kernel_tests/distributions/laplace_test.py
+++ b/tensorflow/python/kernel_tests/distributions/laplace_test.py
@@ -22,6 +22,7 @@ import importlib
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
@@ -255,6 +256,18 @@ class LaplaceTest(test.TestCase):
atol=0.)
self.assertTrue(self._kstest(loc_v, scale_v, sample_values))
+ def testLaplaceFullyReparameterized(self):
+ loc = constant_op.constant(4.0)
+ scale = constant_op.constant(3.0)
+ with backprop.GradientTape() as tape:
+ tape.watch(loc)
+ tape.watch(scale)
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ samples = laplace.sample(100)
+ grad_loc, grad_scale = tape.gradient(samples, [loc, scale])
+ self.assertIsNotNone(grad_loc)
+ self.assertIsNotNone(grad_scale)
+
def testLaplaceSampleMultiDimensional(self):
with session.Session():
loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
diff --git a/tensorflow/python/kernel_tests/distributions/multinomial_test.py b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
index e24e8ade73..6d5d40123e 100644
--- a/tensorflow/python/kernel_tests/distributions/multinomial_test.py
+++ b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
@@ -18,6 +18,8 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import backprop
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -343,6 +345,20 @@ class MultinomialTest(test.TestCase):
self.assertAllClose(
actual_covariance_, sample_covariance_, atol=0., rtol=0.10)
+ def testNotReparameterized(self):
+ total_count = constant_op.constant(5.0)
+ p = constant_op.constant([0.2, 0.6])
+ with backprop.GradientTape() as tape:
+ tape.watch(total_count)
+ tape.watch(p)
+ dist = multinomial.Multinomial(
+ total_count=total_count,
+ probs=p)
+ samples = dist.sample(100)
+ grad_total_count, grad_p = tape.gradient(samples, [total_count, p])
+ self.assertIsNone(grad_total_count)
+ self.assertIsNone(grad_p)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py
index d793e03272..c7e00ff8d8 100644
--- a/tensorflow/python/kernel_tests/distributions/normal_test.py
+++ b/tensorflow/python/kernel_tests/distributions/normal_test.py
@@ -23,6 +23,7 @@ import math
import numpy as np
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -453,6 +454,18 @@ class NormalTest(test.TestCase):
self.assertAllEqual(expected_samples_shape, samples.get_shape())
self.assertAllEqual(expected_samples_shape, sample_values.shape)
+ def testNormalFullyReparameterized(self):
+ mu = constant_op.constant(4.0)
+ sigma = constant_op.constant(3.0)
+ with backprop.GradientTape() as tape:
+ tape.watch(mu)
+ tape.watch(sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
+ samples = normal.sample(100)
+ grad_mu, grad_sigma = tape.gradient(samples, [mu, sigma])
+ self.assertIsNotNone(grad_mu)
+ self.assertIsNotNone(grad_sigma)
+
@test_util.run_in_graph_and_eager_modes()
def testNormalSampleMultiDimensional(self):
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py
index e74051c901..978fff1cc1 100644
--- a/tensorflow/python/kernel_tests/distributions/uniform_test.py
+++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py
@@ -22,6 +22,7 @@ import importlib
import numpy as np
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
@@ -299,6 +300,18 @@ class UniformTest(test.TestCase):
expected_pdf = [1.0, 0.1]
self.assertAllClose(expected_pdf, self.evaluate(pdf))
+ def testFullyReparameterized(self):
+ a = constant_op.constant(0.1)
+ b = constant_op.constant(0.8)
+ with backprop.GradientTape() as tape:
+ tape.watch(a)
+ tape.watch(b)
+ uniform = uniform_lib.Uniform(a, b)
+ samples = uniform.sample(100)
+ grad_a, grad_b = tape.gradient(samples, [a, b])
+ self.assertIsNotNone(grad_a)
+ self.assertIsNotNone(grad_b)
+
# Eager doesn't pass due to a type mismatch in one of the ops.
def testUniformFloat64(self):
uniform = uniform_lib.Uniform(
diff --git a/tensorflow/python/kernel_tests/random/BUILD b/tensorflow/python/kernel_tests/random/BUILD
index a9bd68971e..3b3a28fc9a 100644
--- a/tensorflow/python/kernel_tests/random/BUILD
+++ b/tensorflow/python/kernel_tests/random/BUILD
@@ -88,10 +88,6 @@ cuda_py_test(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:random_ops",
],
- tags = [
- "manual",
- "no_oss",
- ],
)
cuda_py_test(
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index fae63b1132..361667ec49 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import gen_math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_array_ops import *
+from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -2609,14 +2610,6 @@ def where(condition, x=None, y=None, name=None):
raise ValueError("x and y must both be non-None or both be None.")
-@tf_export("reverse")
-def reverse(tensor, axis, name=None):
- return gen_array_ops.reverse_v2(tensor, axis, name)
-
-
-reverse.__doc__ = gen_array_ops.reverse_v2.__doc__
-
-
# pylint: disable=redefined-builtin
@tf_export("reverse_sequence")
@deprecation.deprecated_args(
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 9413bfa2af..837c144467 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -3348,12 +3348,6 @@ def group(*inputs, **kwargs):
if not hasattr(inp, "device"):
raise TypeError("Expected tf.group() expected Tensor arguments not "
"'%s' with type '%s'" % (inp, type(inp)))
- if not hasattr(inp, "device"):
- if isinstance(inp, list):
- raise TypeError("To call tf.group() with a list, use "
- "tf.group(*[...]) not tf.group([...]).")
- raise TypeError("Expected tf.group() expected Tensor arguments not "
- "'%s' with type '%s'" % (inp, type(inp)))
dev = inp.device
if dev in ops_on_device:
ops_on_device[dev].append(inp)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index cae29eea93..fe9ffde11c 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -730,15 +730,15 @@ class Optimizer(
if not named_slots:
return None
- if hasattr(var, "_mirrored_container"):
+ if hasattr(var, "_distributed_container"):
# NOTE: If this isn't patched, then there is no `handle` in
# `_resource_apply_dense`.
- mirrored_container = var._mirrored_container()
- assert mirrored_container is not None
+ distributed_container = var._distributed_container()
+ assert distributed_container is not None
if context.executing_eagerly():
- key = mirrored_container._unique_id
+ key = distributed_container._unique_id
else:
- key = (mirrored_container.graph, mirrored_container._shared_name)
+ key = (distributed_container.graph, distributed_container._shared_name)
# pylint: enable=protected-access
mirrored_slot = named_slots.get(key, None)
if mirrored_slot is None: return None
@@ -839,7 +839,7 @@ class Optimizer(
def _get_non_slot_variable(self, name, graph=None):
non_slot = self._non_slot_dict.get((name, graph), None)
- if hasattr(non_slot, "_mirrored_container"):
+ if hasattr(non_slot, "_distributed_container"):
# This is a mirrored non-slot. In order to enable code like `_finish`
# to assign to a non-slot, return the current context replica.
return non_slot.get()
diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD
index 6065c12cad..8c760e6f52 100644
--- a/tensorflow/tools/api/generator/BUILD
+++ b/tensorflow/tools/api/generator/BUILD
@@ -3,38 +3,37 @@
licenses(["notice"]) # Apache 2.0
-exports_files(["LICENSE"])
-
load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
load("//tensorflow/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES")
-py_library(
- name = "doc_srcs",
- srcs = ["doc_srcs.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:util",
+exports_files(
+ [
+ "LICENSE",
+ "create_python_api.py",
],
)
-py_binary(
- name = "create_python_api",
- srcs = ["create_python_api.py"],
+py_library(
+ name = "doc_srcs",
+ srcs = ["doc_srcs.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- ":doc_srcs",
- "//tensorflow/python:no_contrib",
+ "//tensorflow/python:util",
],
)
py_test(
name = "create_python_api_test",
- srcs = ["create_python_api_test.py"],
+ srcs = [
+ "create_python_api.py",
+ "create_python_api_test.py",
+ ],
srcs_version = "PY2AND3",
deps = [
- ":create_python_api",
+ ":doc_srcs",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:no_contrib",
],
)
@@ -67,5 +66,6 @@ py_test(
":doc_srcs",
"//tensorflow/python:client_testlib",
"//tensorflow/python:no_contrib",
+ "//tensorflow/python/estimator:estimator_py",
],
)
diff --git a/tensorflow/tools/api/generator/api_gen.bzl b/tensorflow/tools/api/generator/api_gen.bzl
index 41713a94ec..d746b5d3e4 100644
--- a/tensorflow/tools/api/generator/api_gen.bzl
+++ b/tensorflow/tools/api/generator/api_gen.bzl
@@ -8,13 +8,16 @@ TENSORFLOW_API_INIT_FILES = [
"bitwise/__init__.py",
"compat/__init__.py",
"data/__init__.py",
+ "debugging/__init__.py",
"distributions/__init__.py",
"distributions/bijectors/__init__.py",
+ "dtypes/__init__.py",
"errors/__init__.py",
"feature_column/__init__.py",
"gfile/__init__.py",
"graph_util/__init__.py",
"image/__init__.py",
+ "io/__init__.py",
"initializers/__init__.py",
"keras/__init__.py",
"keras/activations/__init__.py",
@@ -65,6 +68,7 @@ TENSORFLOW_API_INIT_FILES = [
"nn/rnn_cell/__init__.py",
"profiler/__init__.py",
"python_io/__init__.py",
+ "quantization/__init__.py",
"resource_loader/__init__.py",
"strings/__init__.py",
"saved_model/__init__.py",
@@ -114,22 +118,44 @@ ESTIMATOR_API_INIT_FILES = [
# template will be replaced with root imports collected by this genrule.
# srcs: genrule sources. If passing root_init_template, the template file
# must be included in sources.
-def gen_api_init_files(name,
- output_files=TENSORFLOW_API_INIT_FILES,
- root_init_template=None,
- srcs=[],
- api_name="tensorflow",
- package="tensorflow.python"):
- root_init_template_flag = ""
- if root_init_template:
- root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
- native.genrule(
- name = name,
- outs = output_files,
- cmd = (
- "$(location //tensorflow/tools/api/generator:create_python_api) " +
- root_init_template_flag + " --apidir=$(@D) --apiname=" + api_name + " --package=" + package + " $(OUTS)"),
- srcs = srcs,
- tools = ["//tensorflow/tools/api/generator:create_python_api"],
- visibility = ["//tensorflow:__pkg__"],
- )
+# api_name: Name of the project that you want to generate API files for
+# (e.g. "tensorflow" or "estimator").
+# package: Python package containing the @tf_export decorators you want to
+# process
+# package_dep: Python library target containing your package.
+
+def gen_api_init_files(
+ name,
+ output_files = TENSORFLOW_API_INIT_FILES,
+ root_init_template = None,
+ srcs = [],
+ api_name = "tensorflow",
+ package = "tensorflow.python",
+ package_dep = "//tensorflow/python:no_contrib"):
+ root_init_template_flag = ""
+ if root_init_template:
+ root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
+
+ api_gen_binary_target = "create_" + package + "_api"
+ native.py_binary(
+ name = "create_" + package + "_api",
+ srcs = ["//tensorflow/tools/api/generator:create_python_api.py"],
+ main = "//tensorflow/tools/api/generator:create_python_api.py",
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ package_dep,
+ "//tensorflow/tools/api/generator:doc_srcs",
+ ],
+ )
+
+ native.genrule(
+ name = name,
+ outs = output_files,
+ cmd = (
+ "$(location :" + api_gen_binary_target + ") " +
+ root_init_template_flag + " --apidir=$(@D) --apiname=" + api_name + " --package=" + package + " $(OUTS)"),
+ srcs = srcs,
+ tools = [":" + api_gen_binary_target ],
+ visibility = ["//tensorflow:__pkg__"],
+ )
diff --git a/tensorflow/tools/api/generator/doc_srcs.py b/tensorflow/tools/api/generator/doc_srcs.py
index ccd5bea481..ad1988494d 100644
--- a/tensorflow/tools/api/generator/doc_srcs.py
+++ b/tensorflow/tools/api/generator/doc_srcs.py
@@ -43,7 +43,7 @@ _TENSORFLOW_DOC_SOURCES = {
'gfile': DocSource(docstring_module_name='platform.gfile'),
'graph_util': DocSource(docstring_module_name='framework.graph_util'),
'image': DocSource(docstring_module_name='ops.image_ops'),
- 'keras.estimator': DocSource(docstring_module_name='estimator.keras'),
+ 'keras.estimator': DocSource(docstring_module_name='keras.estimator'),
'linalg': DocSource(docstring_module_name='ops.linalg_ops'),
'logging': DocSource(docstring_module_name='ops.logging_ops'),
'losses': DocSource(docstring_module_name='ops.losses.losses'),
diff --git a/tensorflow/tools/api/golden/tensorflow.debugging.pbtxt b/tensorflow/tools/api/golden/tensorflow.debugging.pbtxt
new file mode 100644
index 0000000000..d9efe97821
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.debugging.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.debugging"
+tf_module {
+ member_method {
+ name: "check_numerics"
+ argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "is_finite"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "is_inf"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "is_nan"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt b/tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt
new file mode 100644
index 0000000000..98e1feed00
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt
@@ -0,0 +1,7 @@
+path: "tensorflow.dtypes"
+tf_module {
+ member_method {
+ name: "as_string"
+ argspec: "args=[\'input\', \'precision\', \'scientific\', \'shortest\', \'width\', \'fill\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'False\', \'False\', \'-1\', \'\', \'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
index e268fa3f61..e89b4dbffd 100644
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
@@ -85,6 +85,10 @@ tf_module {
argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'None\'], "
}
member_method {
+ name: "extract_image_patches"
+ argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "extract_jpeg_shape"
argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/tensorflow.io.pbtxt
new file mode 100644
index 0000000000..3a36c168aa
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.io.pbtxt
@@ -0,0 +1,39 @@
+path: "tensorflow.io"
+tf_module {
+ member_method {
+ name: "decode_base64"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "decode_compressed"
+ argspec: "args=[\'bytes\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
+ }
+ member_method {
+ name: "decode_json_example"
+ argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "decode_raw"
+ argspec: "args=[\'bytes\', \'out_type\', \'little_endian\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ }
+ member_method {
+ name: "encode_base64"
+ argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "matching_files"
+ argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "parse_tensor"
+ argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "read_file"
+ argspec: "args=[\'filename\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "write_file"
+ argspec: "args=[\'filename\', \'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
index 00b9238543..3b5845f99a 100644
--- a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
@@ -69,6 +69,10 @@ tf_module {
argspec: "args=[\'chol\', \'rhs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "cross"
+ argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "det"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -141,6 +145,14 @@ tf_module {
argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
}
member_method {
+ name: "tensor_diag"
+ argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "tensor_diag_part"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "tensordot"
argspec: "args=[\'a\', \'b\', \'axes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.manip.pbtxt b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt
index 0b84165285..9add462396 100644
--- a/tensorflow/tools/api/golden/tensorflow.manip.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt
@@ -1,7 +1,35 @@
path: "tensorflow.manip"
tf_module {
member_method {
+ name: "batch_to_space_nd"
+ argspec: "args=[\'input\', \'block_shape\', \'crops\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "gather_nd"
+ argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "reshape"
+ argspec: "args=[\'tensor\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "reverse"
+ argspec: "args=[\'tensor\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "roll"
argspec: "args=[\'input\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=None"
}
+ member_method {
+ name: "scatter_nd"
+ argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "space_to_batch_nd"
+ argspec: "args=[\'input\', \'block_shape\', \'paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "tile"
+ argspec: "args=[\'input\', \'multiples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/tensorflow.math.pbtxt
index 03fbf6266d..25573cb494 100644
--- a/tensorflow/tools/api/golden/tensorflow.math.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.math.pbtxt
@@ -1,6 +1,38 @@
path: "tensorflow.math"
tf_module {
member_method {
+ name: "acos"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "acosh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "asin"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "asinh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "atan"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "atan2"
+ argspec: "args=[\'y\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "atanh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "bessel_i0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'bessel_i0\'], "
}
@@ -17,7 +49,191 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "betainc"
+ argspec: "args=[\'a\', \'b\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "ceil"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "cos"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "cosh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "digamma"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "equal"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "erfc"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "exp"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "expm1"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "floor"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "greater"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "greater_equal"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "igamma"
+ argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "igammac"
+ argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "invert_permutation"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "less"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "less_equal"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "lgamma"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "log"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "log1p"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "logical_and"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "logical_not"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "logical_or"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "maximum"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "minimum"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "not_equal"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "polygamma"
+ argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "polyval"
argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
+ member_method {
+ name: "reciprocal"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "rint"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "rsqrt"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "segment_max"
+ argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "segment_mean"
+ argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "segment_min"
+ argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "segment_prod"
+ argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "segment_sum"
+ argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "sin"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "sinh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "softplus"
+ argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "softsign"
+ argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "squared_difference"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "tan"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "unsorted_segment_max"
+ argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "unsorted_segment_min"
+ argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "unsorted_segment_prod"
+ argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "unsorted_segment_sum"
+ argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "zeta"
+ argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 01b8058118..20d61aae9d 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -309,6 +309,10 @@ tf_module {
mtype: "<type \'module\'>"
}
member {
+ name: "debugging"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "distributions"
mtype: "<type \'module\'>"
}
@@ -317,6 +321,10 @@ tf_module {
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
+ name: "dtypes"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "errors"
mtype: "<type \'module\'>"
}
@@ -381,6 +389,10 @@ tf_module {
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
+ name: "io"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "keras"
mtype: "<type \'module\'>"
}
@@ -457,6 +469,10 @@ tf_module {
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
+ name: "quantization"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "quint16"
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
diff --git a/tensorflow/tools/api/golden/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/tensorflow.quantization.pbtxt
new file mode 100644
index 0000000000..6d865efed0
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.quantization.pbtxt
@@ -0,0 +1,35 @@
+path: "tensorflow.quantization"
+tf_module {
+ member_method {
+ name: "dequantize"
+ argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\'], "
+ }
+ member_method {
+ name: "fake_quant_with_min_max_args"
+ argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "fake_quant_with_min_max_args_gradient"
+ argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "fake_quant_with_min_max_vars"
+ argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "fake_quant_with_min_max_vars_gradient"
+ argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "fake_quant_with_min_max_vars_per_channel"
+ argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "fake_quant_with_min_max_vars_per_channel_gradient"
+ argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "quantized_concat"
+ argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt
index b641c39feb..9a831fed26 100644
--- a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt
@@ -1,11 +1,43 @@
path: "tensorflow.strings"
tf_module {
member_method {
+ name: "join"
+ argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
+ }
+ member_method {
name: "regex_full_match"
argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "regex_replace"
+ argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ }
+ member_method {
name: "split"
argspec: "args=[\'source\', \'sep\', \'maxsplit\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], "
}
+ member_method {
+ name: "strip"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "substr"
+ argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "to_hash_bucket"
+ argspec: "args=[\'string_tensor\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "to_hash_bucket_fast"
+ argspec: "args=[\'input\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "to_hash_bucket_strong"
+ argspec: "args=[\'input\', \'num_buckets\', \'key\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "to_number"
+ argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
+ }
}
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 1a50b2a487..696f9b08b3 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -452,11 +452,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/a587557962e93552e1a8b9270b435b021891e9cd.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/a587557962e93552e1a8b9270b435b021891e9cd.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/19357eaea4f9599bcb228611719e0c5b8fc65298.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/19357eaea4f9599bcb228611719e0c5b8fc65298.tar.gz",
],
- sha256 = "5cf25652e8913e88ce2fb02f1186affd25cf5c1cb2146f9754881daaf3450ddb",
- strip_prefix = "llvm-a587557962e93552e1a8b9270b435b021891e9cd",
+ sha256 = "c07971d102ae5353c4a22c15e82e75f4347a16260c52060187baf4b113161216",
+ strip_prefix = "llvm-19357eaea4f9599bcb228611719e0c5b8fc65298",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)