aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2018-01-26 07:39:13 -0800
committerGravatar Jianwei Xie <xiejw@google.com>2018-01-26 07:39:13 -0800
commit8209078d766038179ae39662b8c230942712ce31 (patch)
treed46082d798260617d7baa6c46ff21a21d2e506fa
parent73cf824a24e46766a1674c7879d8c48bd0728083 (diff)
parentabdc62aee1eeba32be56d761a2f9988306356084 (diff)
solve push conflict
-rw-r--r--configure.py119
-rw-r--r--tensorflow/BUILD9
-rw-r--r--tensorflow/compiler/xla/BUILD1
-rw-r--r--tensorflow/compiler/xla/client/BUILD1
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc10
-rw-r--r--tensorflow/compiler/xla/map_util.h21
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py25
-rw-r--r--tensorflow/compiler/xla/service/BUILD13
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc37
-rw-r--r--tensorflow/compiler/xla/service/compile_only_service.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/gpu/conditional_thunk.cc72
-rw-r--r--tensorflow/compiler/xla/service/gpu/conditional_thunk.h65
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc43
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc31
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc120
-rw-r--r--tensorflow/compiler/xla/service/gpu/thunk.h1
-rw-r--r--tensorflow/compiler/xla/service/local_service.cc3
-rw-r--r--tensorflow/compiler/xla/service/service.cc33
-rw-r--r--tensorflow/compiler/xla/service/service.h6
-rw-r--r--tensorflow/compiler/xla/service/source_map_util.cc66
-rw-r--r--tensorflow/compiler/xla/service/source_map_util.h46
-rw-r--r--tensorflow/compiler/xla/tests/BUILD9
-rw-r--r--tensorflow/compiler/xla/tests/check_execution_arity_test.cc9
-rw-r--r--tensorflow/compiler/xla/tests/conditional_test.cc36
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc76
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc125
-rw-r--r--tensorflow/compiler/xla/util.cc15
-rw-r--r--tensorflow/compiler/xla/util.h10
-rw-r--r--tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc28
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt3
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt2
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_elements.py75
-rw-r--r--tensorflow/contrib/estimator/BUILD3
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py89
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py4
-rw-r--r--tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc7
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_test.py150
-rw-r--r--tensorflow/contrib/learn/python/learn/datasets/__init__.py23
-rw-r--r--tensorflow/contrib/learn/python/learn/datasets/synthetic.py66
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py262
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py223
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimators_test.py32
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors.py102
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/export_test.py34
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/gc_test.py49
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h2
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.h16
-rw-r--r--tensorflow/contrib/lite/kernels/pad.cc17
-rw-r--r--tensorflow/contrib/lite/model.cc20
-rw-r--r--tensorflow/contrib/lite/model_test.cc9
-rwxr-xr-x[-rw-r--r--]tensorflow/contrib/lite/schema/schema_generated.h0
-rw-r--r--tensorflow/contrib/lite/toco/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/allocate_transient_arrays.cc4
-rw-r--r--tensorflow/contrib/lite/toco/args.h1
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.cc18
-rw-r--r--tensorflow/contrib/lite/toco/model_flags.proto37
-rw-r--r--tensorflow/contrib/lite/toco/toco_port.h21
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc2
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc27
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h22
-rw-r--r--tensorflow/contrib/lite/tools/BUILD25
-rw-r--r--tensorflow/contrib/lite/tools/verifier.cc43
-rw-r--r--tensorflow/contrib/lite/tools/verifier.h31
-rw-r--r--tensorflow/contrib/lite/tools/verifier_test.cc136
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py77
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py933
-rw-r--r--tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py86
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer.py115
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py99
-rw-r--r--tensorflow/contrib/predictor/predictor_factories_test.py4
-rw-r--r--tensorflow/contrib/py2tf/converters/break_canonicalization.py26
-rw-r--r--tensorflow/contrib/py2tf/converters/builtin_functions.py7
-rw-r--r--tensorflow/contrib/py2tf/converters/call_trees.py19
-rw-r--r--tensorflow/contrib/py2tf/converters/continue_canonicalization.py22
-rw-r--r--tensorflow/contrib/py2tf/converters/control_flow.py96
-rw-r--r--tensorflow/contrib/py2tf/converters/for_canonicalization.py28
-rw-r--r--tensorflow/contrib/py2tf/converters/side_effect_guards.py24
-rw-r--r--tensorflow/contrib/py2tf/pyct/templates.py55
-rw-r--r--tensorflow/contrib/py2tf/pyct/templates_test.py36
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc11
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/helper.py36
-rw-r--r--tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py34
-rw-r--r--tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py48
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h4
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc2
-rw-r--r--tensorflow/contrib/tensorrt/BUILD45
-rw-r--r--tensorflow/contrib/tensorrt/tensorrt_test.cc159
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc11
-rw-r--r--tensorflow/contrib/tpu/profiler/tf_op_stats.proto5
-rw-r--r--tensorflow/contrib/tpu/profiler/tpu_profiler.proto13
-rw-r--r--tensorflow/contrib/tpu/python/ops/tpu_ops.py2
-rw-r--r--tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py30
-rw-r--r--tensorflow/core/BUILD7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_FusedResizeAndPadConv2D.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_QuantizedResizeBilinear.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeArea.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeBicubic.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeBicubicGrad.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeBilinear.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeBilinearGrad.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighbor.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighborGrad.pbtxt5
-rw-r--r--tensorflow/core/framework/op_gen_lib.cc29
-rw-r--r--tensorflow/core/framework/op_kernel.cc7
-rw-r--r--tensorflow/core/graph/costmodel.cc53
-rw-r--r--tensorflow/core/graph/costmodel.h4
-rw-r--r--tensorflow/core/grappler/costs/cost_estimator.h3
-rw-r--r--tensorflow/core/grappler/costs/measuring_cost_estimator.cc4
-rw-r--r--tensorflow/core/grappler/costs/op_performance_data.proto7
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.h4
-rw-r--r--tensorflow/core/grappler/grappler_item.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer.cc2
-rw-r--r--tensorflow/core/kernels/bias_op_gpu.cu.cc18
-rw-r--r--tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc13
-rw-r--r--tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc21
-rw-r--r--tensorflow/core/kernels/summary_kernels.cc43
-rw-r--r--tensorflow/core/kernels/svd_op_gpu.cu.cc4
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc19
-rw-r--r--tensorflow/core/ops/image_ops.cc36
-rw-r--r--tensorflow/core/ops/training_ops.cc42
-rw-r--r--tensorflow/core/profiler/internal/tfprof_timeline.h1
-rw-r--r--tensorflow/core/profiler/internal/tfprof_utils.cc3
-rw-r--r--tensorflow/core/util/cuda_device_functions.h499
-rw-r--r--tensorflow/core/util/cuda_kernel_helper.h857
-rw-r--r--tensorflow/core/util/cuda_kernel_helper_test.cu.cc60
-rw-r--r--tensorflow/core/util/cuda_launch_config.h284
-rw-r--r--tensorflow/examples/how_tos/reading_data/fully_connected_reader.py49
-rw-r--r--tensorflow/examples/label_image/label_image.py41
-rw-r--r--tensorflow/python/client/session.py185
-rw-r--r--tensorflow/python/client/session_test.py310
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD2
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc3
-rw-r--r--tensorflow/python/estimator/BUILD1
-rw-r--r--tensorflow/python/estimator/canned/head.py134
-rw-r--r--tensorflow/python/estimator/canned/head_test.py251
-rw-r--r--tensorflow/python/estimator/inputs/queues/feeding_functions.py101
-rw-r--r--tensorflow/python/framework/test_util.py12
-rw-r--r--tensorflow/python/grappler/cost_analyzer_tool.py41
-rw-r--r--tensorflow/python/grappler/tf_optimizer.i5
-rwxr-xr-xtensorflow/python/keras/BUILD32
-rw-r--r--tensorflow/python/keras/_impl/keras/__init__.py2
-rw-r--r--tensorflow/python/keras/_impl/keras/activations.py10
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/__init__.py5
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/densenet.py346
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/densenet_test.py101
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py156
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py19
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/inception_v3.py21
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/mobilenet.py70
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/nasnet.py783
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/nasnet_test.py76
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/resnet50.py51
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/vgg16.py60
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/vgg19.py73
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/xception.py90
-rw-r--r--tensorflow/python/keras/_impl/keras/backend.py289
-rw-r--r--tensorflow/python/keras/_impl/keras/backend_test.py12
-rw-r--r--tensorflow/python/keras/_impl/keras/callbacks.py132
-rw-r--r--tensorflow/python/keras/_impl/keras/constraints.py22
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/boston_housing.py10
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/cifar.py3
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/cifar10.py4
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/cifar100.py6
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py7
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/imdb.py63
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/mnist.py10
-rw-r--r--tensorflow/python/keras/_impl/keras/datasets/reuters.py58
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/topology.py61
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training.py588
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_test.py87
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/advanced_activations.py68
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py6
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/convolutional.py139
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py44
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/convolutional_test.py66
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/embeddings.py26
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/local.py110
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/merge.py110
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/noise.py49
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/recurrent.py650
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/recurrent_test.py99
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/wrappers.py58
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/wrappers_test.py125
-rw-r--r--tensorflow/python/keras/_impl/keras/losses.py20
-rw-r--r--tensorflow/python/keras/_impl/keras/metrics.py9
-rw-r--r--tensorflow/python/keras/_impl/keras/models.py4
-rw-r--r--tensorflow/python/keras/_impl/keras/models_test.py29
-rw-r--r--tensorflow/python/keras/_impl/keras/optimizers.py125
-rw-r--r--tensorflow/python/keras/_impl/keras/optimizers_test.py1
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/image.py164
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/sequence.py25
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/text.py51
-rw-r--r--tensorflow/python/keras/_impl/keras/preprocessing/text_test.py16
-rw-r--r--tensorflow/python/keras/_impl/keras/regularizers.py2
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/data_utils.py213
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/generic_utils.py6
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/io_utils.py18
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/layer_utils.py35
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/np_utils.py5
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/vis_utils.py35
-rw-r--r--tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py71
-rw-r--r--tensorflow/python/keras/applications/__init__.py7
-rw-r--r--tensorflow/python/keras/applications/densenet/__init__.py29
-rw-r--r--tensorflow/python/keras/applications/nasnet/__init__.py28
-rw-r--r--tensorflow/python/keras/layers/__init__.py3
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py97
-rw-r--r--tensorflow/python/kernel_tests/diag_op_test.py225
-rw-r--r--tensorflow/python/kernel_tests/map_stage_op_test.py105
-rw-r--r--tensorflow/python/kernel_tests/pooling_ops_test.py72
-rw-r--r--tensorflow/python/kernel_tests/reader_ops_test.py23
-rw-r--r--tensorflow/python/kernel_tests/relu_op_test.py20
-rw-r--r--tensorflow/python/kernel_tests/scalar_test.py4
-rw-r--r--tensorflow/python/kernel_tests/sparse_slice_op_test.py102
-rw-r--r--tensorflow/python/kernel_tests/stage_op_test.py34
-rw-r--r--tensorflow/python/layers/base.py139
-rw-r--r--tensorflow/python/layers/maxout.py34
-rw-r--r--tensorflow/python/layers/network.py7
-rw-r--r--tensorflow/python/ops/array_grad.py116
-rw-r--r--tensorflow/python/ops/control_flow_ops.py481
-rw-r--r--tensorflow/python/ops/data_flow_ops.py390
-rw-r--r--tensorflow/python/ops/gradients_impl.py91
-rw-r--r--tensorflow/python/ops/image_ops_impl.py5
-rw-r--r--tensorflow/python/ops/nn_batchnorm_test.py15
-rw-r--r--tensorflow/python/ops/nn_grad.py298
-rw-r--r--tensorflow/python/ops/nn_grad_test.py13
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py38
-rw-r--r--tensorflow/python/ops/rnn.py5
-rw-r--r--tensorflow/python/ops/special_math_ops.py116
-rw-r--r--tensorflow/python/ops/special_math_ops_test.py74
-rw-r--r--tensorflow/python/ops/variable_scope.py4
-rw-r--r--tensorflow/python/ops/variables.py63
-rw-r--r--tensorflow/python/profiler/model_analyzer_test.py10
-rw-r--r--tensorflow/python/tools/inspect_checkpoint.py14
-rw-r--r--tensorflow/python/training/checkpoint_utils.py5
-rw-r--r--tensorflow/python/training/coordinator_test.py70
-rw-r--r--tensorflow/python/training/moving_averages.py4
-rw-r--r--tensorflow/python/training/optimizer.py2
-rw-r--r--tensorflow/tensorflow.bzl16
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.applications.densenet.pbtxt23
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.applications.nasnet.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt28
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt186
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt186
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt183
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt2
-rwxr-xr-xtensorflow/tools/ci_build/ci_sanity.sh8
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade.py4
-rw-r--r--tensorflow/workspace.bzl2
-rw-r--r--third_party/gpus/cuda_configure.bzl104
-rw-r--r--third_party/tensorrt/BUILD0
-rw-r--r--third_party/tensorrt/BUILD.tpl67
-rw-r--r--third_party/tensorrt/build_defs.bzl.tpl7
-rw-r--r--third_party/tensorrt/tensorrt_configure.bzl224
297 files changed, 11470 insertions, 5925 deletions
diff --git a/configure.py b/configure.py
index cf16ef4837..083fed1710 100644
--- a/configure.py
+++ b/configure.py
@@ -43,6 +43,7 @@ _DEFAULT_CUDA_PATH = '/usr/local/cuda'
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
+_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu'
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
@@ -959,6 +960,119 @@ def set_tf_cudnn_version(environ_cp):
write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)
+def set_tf_tensorrt_install_path(environ_cp):
+ """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.
+
+ Adapted from code contributed by Sami Kama (https://github.com/samikama).
+
+ Args:
+ environ_cp: copy of the os.environ.
+
+ Raises:
+ ValueError: if this method was called under non-Linux platform.
+ UserInputError: if user has provided invalid input multiple times.
+ """
+ if not is_linux():
+ raise ValueError('Currently TensorRT is only supported on Linux platform.')
+
+ # Ask user whether to add TensorRT support.
+ if str(int(get_var(
+ environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1':
+ return
+
+ for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
+ ask_tensorrt_path = (r'Please specify the location where TensorRT is '
+ 'installed. [Default is %s]:') % (
+ _DEFAULT_TENSORRT_PATH_LINUX)
+ trt_install_path = get_from_env_or_user_or_default(
+ environ_cp, 'TENSORRT_INSTALL_PATH', ask_tensorrt_path,
+ _DEFAULT_TENSORRT_PATH_LINUX)
+
+ # Result returned from "read" will be used unexpanded. That make "~"
+ # unusable. Going through one more level of expansion to handle that.
+ trt_install_path = os.path.realpath(
+ os.path.expanduser(trt_install_path))
+
+ def find_libs(search_path):
+ """Search for libnvinfer.so in "search_path"."""
+ fl = set()
+ if os.path.exists(search_path) and os.path.isdir(search_path):
+ fl.update([os.path.realpath(os.path.join(search_path, x))
+ for x in os.listdir(search_path) if 'libnvinfer.so' in x])
+ return fl
+
+ possible_files = find_libs(trt_install_path)
+ possible_files.update(find_libs(os.path.join(trt_install_path, 'lib')))
+ possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
+
+ def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver):
+ """Check the compatibility between tensorrt and cudnn/cudart libraries."""
+ ldd_bin = which('ldd') or '/usr/bin/ldd'
+ ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep)
+ cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
+ cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
+ cudnn = None
+ cudart = None
+ for line in ldd_out:
+ if 'libcudnn.so' in line:
+ cudnn = cudnn_pattern.search(line)
+ elif 'libcudart.so' in line:
+ cudart = cuda_pattern.search(line)
+ if cudnn and len(cudnn.group(1)):
+ cudnn = convert_version_to_int(cudnn.group(1))
+ if cudart and len(cudart.group(1)):
+ cudart = convert_version_to_int(cudart.group(1))
+ return (cudnn == cudnn_ver) and (cudart == cuda_ver)
+
+ cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
+ cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
+ nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
+ highest_ver = [0, None, None]
+
+ for lib_file in possible_files:
+ if is_compatible(lib_file, cuda_ver, cudnn_ver):
+ ver_str = nvinfer_pattern.search(lib_file).group(1)
+ ver = convert_version_to_int(ver_str) if len(ver_str) else 0
+ if ver > highest_ver[0]:
+ highest_ver = [ver, ver_str, lib_file]
+ if highest_ver[1] is not None:
+ trt_install_path = os.path.dirname(highest_ver[2])
+ tf_tensorrt_version = highest_ver[1]
+ break
+
+ # Try another alternative from ldconfig.
+ ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
+ ldconfig_output = run_shell([ldconfig_bin, '-p'])
+ search_result = re.search(
+ '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output)
+ if search_result:
+ libnvinfer_path_from_ldconfig = search_result.group(2)
+ if os.path.exists(libnvinfer_path_from_ldconfig):
+ if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver):
+ trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
+ tf_tensorrt_version = search_result.group(1)
+ break
+
+ # Reset and Retry
+ print('Invalid path to TensorRT. None of the following files can be found:')
+ print(trt_install_path)
+ print(os.path.join(trt_install_path, 'lib'))
+ print(os.path.join(trt_install_path, 'lib64'))
+ if search_result:
+ print(libnvinfer_path_from_ldconfig)
+
+ else:
+ raise UserInputError('Invalid TF_TENSORRT setting was provided %d '
+ 'times in a row. Assuming to be a scripting mistake.' %
+ _DEFAULT_PROMPT_ASK_ATTEMPTS)
+
+ # Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION
+ environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
+ write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
+ environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version
+ write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version)
+
+
def get_native_cuda_compute_capabilities(environ_cp):
"""Get native cuda compute capabilities.
@@ -1244,9 +1358,11 @@ def main():
environ_cp['TF_NEED_COMPUTECPP'] = '0'
environ_cp['TF_NEED_OPENCL'] = '0'
environ_cp['TF_CUDA_CLANG'] = '0'
+ environ_cp['TF_NEED_TENSORRT'] = '0'
if is_macos():
environ_cp['TF_NEED_JEMALLOC'] = '0'
+ environ_cp['TF_NEED_TENSORRT'] = '0'
set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
'with_jemalloc', True)
@@ -1278,6 +1394,8 @@ def main():
'TF_CUDA_CONFIG_REPO' not in environ_cp):
set_tf_cuda_version(environ_cp)
set_tf_cudnn_version(environ_cp)
+ if is_linux():
+ set_tf_tensorrt_install_path(environ_cp)
set_tf_cuda_compute_capabilities(environ_cp)
set_tf_cuda_clang(environ_cp)
@@ -1332,6 +1450,7 @@ def main():
'more details.')
config_info_line('mkl', 'Build with MKL support.')
config_info_line('monolithic', 'Config for mostly static monolithic build.')
+ config_info_line('tensorrt', 'Build with TensorRT support.')
if __name__ == '__main__':
main()
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 63849943e4..b26c525525 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -370,6 +370,14 @@ config_setting(
visibility = ["//visibility:public"],
)
+# TODO(laigd): consider removing this option and make TensorRT enabled
+# automatically when CUDA is enabled.
+config_setting(
+ name = "with_tensorrt_support",
+ values = {"define": "with_tensorrt_support=true"},
+ visibility = ["//visibility:public"],
+)
+
package_group(
name = "internal",
packages = [
@@ -566,6 +574,7 @@ filegroup(
"//tensorflow/contrib/tensor_forest/proto:all_files",
"//tensorflow/contrib/tensorboard:all_files",
"//tensorflow/contrib/tensorboard/db:all_files",
+ "//tensorflow/contrib/tensorrt:all_files",
"//tensorflow/contrib/testing:all_files",
"//tensorflow/contrib/text:all_files",
"//tensorflow/contrib/tfprof:all_files",
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 438f1443f1..c22fd37129 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -182,6 +182,7 @@ cc_library(
deps = [
":status",
":status_macros",
+ ":statusor",
":types",
":xla_data_proto",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index d6b4ebfc39..952109dde2 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -98,6 +98,7 @@ cc_library(
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/compiler/xla/service:source_map_util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@llvm//:support",
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 523169fdd2..fbeedfcecd 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -21,10 +21,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
+#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace se = ::perftools::gputools;
+using xla::source_map_util::InvalidParameterArgument;
+
namespace xla {
ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
@@ -79,9 +82,10 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(
for (int i = 0; i < arguments.size(); ++i) {
if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
arguments[i]->on_host_shape())) {
- return InvalidArgument(
- "argument does not match shape or layout of computation parameter "
- "%d: expected %s, got %s",
+ return InvalidParameterArgument(
+ executable_.get(), i,
+ "Argument does not match shape or layout of computation parameter "
+ "%d: want %s, got %s",
i,
ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape())
.c_str(),
diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h
index 50659c1240..0ad0b91330 100644
--- a/tensorflow/compiler/xla/map_util.h
+++ b/tensorflow/compiler/xla/map_util.h
@@ -16,6 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_
+#include <functional>
+#include <sstream>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -44,6 +49,22 @@ typename Collection::value_type::second_type& FindOrDie(
return it->second;
}
+// Like FindOrDie but returns an error instead of dying if `key` is not in
+// `container`.
+template <class Collection>
+StatusOr<
+ std::reference_wrapper<const typename Collection::value_type::second_type>>
+MaybeFind(const Collection& collection,
+ const typename Collection::value_type::first_type& key) {
+ typename Collection::const_iterator it = collection.find(key);
+ if (it == collection.end()) {
+ std::ostringstream os;
+ os << key;
+ return NotFound("key not found: %s", os.str().c_str());
+ }
+ return {it->second};
+}
+
// Inserts the key-value pair into the collection. Dies if key was already
// present.
template <class Collection>
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 9cfe1249f5..66ace613a0 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -36,15 +36,22 @@ from tensorflow.compiler.xla.python import pywrap_xla as c_api
# pylint: disable=invalid-name
-OpMetadata = collections.namedtuple(
- 'OpMetadata',
- [
- 'op_type',
- 'op_name',
- 'source_file',
- 'source_line',
- ],
-)
+_OP_METADATA_FIELDS = [
+ 'op_type',
+ 'op_name',
+ 'source_file',
+ 'source_line',
+]
+OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS)
+
+
+def OpMetadataToProto(pyobj):
+ proto = xla_data_pb2.OpMetadata()
+ for field in _OP_METADATA_FIELDS:
+ attr = getattr(pyobj, field)
+ if attr is not None:
+ setattr(proto, field, attr)
+ return proto
def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1):
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 9a0acda94f..469acc330c 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -460,6 +460,7 @@ cc_library(
":hlo_proto_util",
":platform_util",
":session_proto",
+ ":source_map_util",
":transfer_manager",
":user_computation",
":versioned_computation_handle",
@@ -2348,6 +2349,18 @@ tf_cc_test(
],
)
+cc_library(
+ name = "source_map_util",
+ srcs = ["source_map_util.cc"],
+ hdrs = ["source_map_util.h"],
+ deps = [
+ ":executable",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
# -----------------------------------------------------------------------------
filegroup(
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 323620c131..d5594dc07c 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -1358,6 +1358,43 @@ void BufferAssigner::BuildColocatedBufferSets(
index, points_to_analysis, &colocated_set);
AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
});
+
+ // Add true_operand and conditional.true_computation.parameter(0) as a
+ // colocated buffer set. Note that this has to be done for each subshape
+ // in the true_operand of the conditional.
+ ShapeUtil::ForEachSubshape(
+ conditional_hlo->operand(1)->shape(),
+ [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets](
+ const Shape& /*subshape*/, const ShapeIndex& index) {
+ std::vector<const LogicalBuffer*> true_set;
+ // Add conditional.true_operand.
+ AddBufferToColocatedSet(conditional_hlo->operand(1), index,
+ points_to_analysis, &true_set);
+ // Add conditional.true_computation.parameter_instruction(0).
+ AddBufferToColocatedSet(
+ conditional_hlo->true_computation()->parameter_instruction(0),
+ index, points_to_analysis, &true_set);
+ AddSetToColocatedBufferSets(true_set, colocated_buffer_sets);
+ });
+
+ // Add false_operand and conditional.false_computation.parameter(0) as a
+ // colocated buffer set. Note that this has to be done for each subshape
+ // in the false_operand of the conditional.
+ ShapeUtil::ForEachSubshape(
+ conditional_hlo->operand(2)->shape(),
+ [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets](
+ const Shape& /*subshape*/, const ShapeIndex& index) {
+ std::vector<const LogicalBuffer*> false_set;
+ // Add conditional.false_operand.
+ AddBufferToColocatedSet(conditional_hlo->operand(2), index,
+ points_to_analysis, &false_set);
+ // Add conditional.false_computation.parameter_instruction(0).
+ AddBufferToColocatedSet(
+ conditional_hlo->false_computation()->parameter_instruction(
+ 0),
+ index, points_to_analysis, &false_set);
+ AddSetToColocatedBufferSets(false_set, colocated_buffer_sets);
+ });
}
}
}
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index b9306a8bb0..dab73596e1 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -101,7 +101,7 @@ CompileOnlyService::CompileAheadOfTime(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(*program_shape, instance.argument_layouts,
- &execution_options));
+ &execution_options, *user_computation));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
computation_tracker_.BuildHloModule(
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index df5e2e35f8..3c3328b9cd 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -228,6 +228,7 @@ cc_library(
cc_library(
name = "gpu_executable",
srcs = [
+ "conditional_thunk.cc",
"convolution_thunk.cc",
"copy_thunk.cc",
"cudnn_batchnorm_thunk.cc",
@@ -243,6 +244,7 @@ cc_library(
"while_thunk.cc",
],
hdrs = [
+ "conditional_thunk.h",
"convolution_thunk.h",
"copy_thunk.h",
"cudnn_batchnorm_thunk.h",
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
new file mode 100644
index 0000000000..790ca535b1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+namespace gpu {
+
+ConditionalThunk::ConditionalThunk(
+ const BufferAllocation::Slice& predicate_buffer_index,
+ const BufferAllocation::Slice& true_operand_buffer_index,
+ const BufferAllocation::Slice& false_operand_buffer_index,
+ ThunkSequence true_thunk_sequence, ThunkSequence false_thunk_sequence,
+ const HloInstruction* hlo)
+ : Thunk(Kind::kConditional, hlo),
+ predicate_buffer_index_(predicate_buffer_index),
+ true_operand_buffer_index_(true_operand_buffer_index),
+ false_operand_buffer_index_(false_operand_buffer_index),
+ true_thunk_(std::move(true_thunk_sequence), hlo),
+ false_thunk_(std::move(false_thunk_sequence), hlo) {}
+
+Status ConditionalThunk::Initialize(const GpuExecutable& executable) {
+ TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable));
+ TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable));
+ return Status::OK();
+}
+
+Status ConditionalThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) {
+ // Copy the predicate value from device.
+ bool predicate;
+ perftools::gputools::DeviceMemoryBase predicate_address =
+ buffer_allocations.GetDeviceAddress(predicate_buffer_index_);
+ stream->ThenMemcpy(&predicate, predicate_address, sizeof(bool));
+
+ Status block_status = stream->BlockHostUntilDone();
+ if (!block_status.ok()) {
+ return InternalError("Failed to retrieve predicate value on stream %p: %s.",
+ stream, block_status.error_message().c_str());
+ }
+
+ // Execute the true or the false computation depending on the value of the
+ // predicate.
+ if (predicate) {
+ TF_RETURN_IF_ERROR(true_thunk_.ExecuteOnStream(buffer_allocations, stream));
+ } else {
+ TF_RETURN_IF_ERROR(
+ false_thunk_.ExecuteOnStream(buffer_allocations, stream));
+ }
+
+ return Status::OK();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
new file mode 100644
index 0000000000..7725c46a3b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// ConditionalThunk implements the conditional instruction on GPU by reading the
+// predicate of the conditional and executing the true or the false computation
+// depending on the value of the predicate.
+//
+// ConditionalThunk assumes that the buffers of the conditional result and the
+// result of the true and false computations share the same allocation. Also,
+// the buffers of the true operand of the conditional and that of the parameter
+// instruction of the true computation share the same allocation. Similarly, the
+// buffers of the false operand and that of the parameter instruction of the
+// false computation share the same allocation.
+class ConditionalThunk : public Thunk {
+ public:
+ ConditionalThunk(const BufferAllocation::Slice& predicate_buffer_index,
+ const BufferAllocation::Slice& true_operand_buffer_index,
+ const BufferAllocation::Slice& false_operand_buffer_index,
+ ThunkSequence true_thunk_sequence,
+ ThunkSequence false_thunk_sequence,
+ const HloInstruction* hlo);
+
+ ConditionalThunk(const ConditionalThunk&) = delete;
+ ConditionalThunk& operator=(const ConditionalThunk&) = delete;
+
+ Status Initialize(const GpuExecutable& executable) override;
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) override;
+
+ private:
+ BufferAllocation::Slice predicate_buffer_index_;
+ BufferAllocation::Slice true_operand_buffer_index_;
+ BufferAllocation::Slice false_operand_buffer_index_;
+ SequentialThunk true_thunk_;
+ SequentialThunk false_thunk_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index e67087d822..e3b493c663 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -36,7 +36,7 @@ namespace gpu {
StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy(
HloInstruction* hlo) {
- HloInstruction*& copy = inserted_copies_[hlo];
+ HloInstruction*& copy = hlo_to_copy_map_[hlo];
if (copy == nullptr) {
TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo));
}
@@ -86,27 +86,34 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
}
}
- // Init values of a while node cannot be constants. Insert copies for any
- // constants found at the operand of a while.
- tensorflow::gtl::FlatSet<HloInstruction*> copied_constants;
+ // Init values of while and conditional nodes cannot be constants. Insert
+ // copies for any constants found at the operands of these nodes.
+ tensorflow::gtl::FlatSet<HloInstruction*> inserted_copies;
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() != HloOpcode::kWhile) {
+ if (instruction->opcode() != HloOpcode::kWhile &&
+ instruction->opcode() != HloOpcode::kConditional) {
continue;
}
- for (auto& pair :
- dataflow->GetInstructionValueSet(instruction->operand(0))) {
- const HloValueSet& value_set = pair.second;
- for (const HloValue* value : value_set.values()) {
- if (value->defining_instruction()->opcode() ==
- HloOpcode::kConstant &&
- !ContainsKey(copied_constants, value->defining_instruction())) {
- HloInstruction* constant = value->defining_instruction();
- TF_ASSIGN_OR_RETURN(HloInstruction * copy,
- FindOrInsertCopy(constant));
- TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy));
- copied_constants.insert(constant);
- changed = true;
+ for (auto operand : instruction->operands()) {
+ // Skip the operands that have already been replaced with a copy in a
+ // previous iteration (which is possible when a constant is used as an
+ // operand in multiple places).
+ if (ContainsKey(inserted_copies, operand)) {
+ continue;
+ }
+ for (auto& pair : dataflow->GetInstructionValueSet(operand)) {
+ const HloValueSet& value_set = pair.second;
+ for (const HloValue* value : value_set.values()) {
+ if (value->defining_instruction()->IsConstant() &&
+ !ContainsKey(hlo_to_copy_map_, value->defining_instruction())) {
+ HloInstruction* constant = value->defining_instruction();
+ TF_ASSIGN_OR_RETURN(HloInstruction * copy,
+ FindOrInsertCopy(constant));
+ TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy));
+ inserted_copies.insert(copy);
+ changed = true;
+ }
}
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
index 4d77f337e6..0c6f9b511f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
@@ -32,13 +32,13 @@ class GpuCopyInsertion : public HloPassInterface {
StatusOr<bool> Run(HloModule* module) override;
protected:
- // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making
+ // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making
// duplicate copies.
StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
// A map containing all copies inserted to materialize operands of library
// calls. The key is the copied instruction and the value is the copy.
- tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> inserted_copies_;
+ tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> hlo_to_copy_map_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 095c3df3bf..23b72c3f71 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -758,37 +758,6 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) {
"to a cudnn CustomCall using CudnnBatchNormRewriter.");
}
-Status IrEmitter::HandleConditional(HloInstruction* conditional) {
- auto pred = conditional->operand(0);
- auto true_arg = conditional->operand(1);
- auto false_arg = conditional->operand(2);
-
- llvm::Value* conditional_result = GetBasePointer(*conditional);
-
- llvm::LoadInst* pred_value = ir_builder_.CreateLoad(
- GetBasePointer(*pred),
- llvm_ir::AsStringRef(IrName(conditional, "load_predicate_value")));
- llvm::Value* pred_cond = ir_builder_.CreateICmpNE(
- pred_value,
- llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0),
- llvm_ir::AsStringRef(IrName(conditional, "boolean_predicate")));
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- pred_cond, IrName(conditional, "if_then_else"), &ir_builder_);
-
- SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *conditional->true_computation(), {GetBasePointer(*true_arg)},
- conditional_result));
-
- SetToFirstInsertPoint(if_data.false_block, &ir_builder_);
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *conditional->false_computation(), {GetBasePointer(*false_arg)},
- conditional_result));
-
- SetToFirstInsertPoint(if_data.after_block, &ir_builder_);
- return Status::OK();
-}
-
llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest(
const llvm_ir::IrArray& operand_array, int64 reduction_dimension,
tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 39bafaa346..3aa178410f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -96,7 +96,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleCall(HloInstruction* call) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
Status HandleRng(HloInstruction* random) override;
- Status HandleConditional(HloInstruction* conditional) override;
Status HandleBatchNormInference(HloInstruction* batch_norm) override;
Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
@@ -367,6 +366,11 @@ class IrEmitterUnnested : public IrEmitter {
std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
const int64 loop_limit);
+ // Returns a ConditionalThunk that executes the thunk sequence for
+ // 'true_computation' or 'false_computation' depending on the value of the
+ // predicate in the given conditional instruction.
+ std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);
+
Status Postprocess(HloInstruction* hlo) override;
// Returns the last generated thunk.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index be35351e87..fc8783e753 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
@@ -272,8 +273,8 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
}
Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
- thunk_sequence_->push_back(BuildKernelThunk(conditional));
- return IrEmitter::HandleConditional(conditional);
+ thunk_sequence_->emplace_back(BuildConditionalThunk(conditional));
+ return Status::OK();
}
Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
@@ -2102,6 +2103,24 @@ Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo,
namespace {
+// Checks that the buffers corresponding to the given two HLOs share the same
+// allocation.
+Status CheckHloBuffersShareAllocation(
+ const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index,
+ const BufferAssignment& buffer_assignment) {
+ const BufferAllocation::Slice slice_a =
+ buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
+ const BufferAllocation::Slice slice_b =
+ buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
+ if (slice_a != slice_b) {
+ return InternalError(
+ "instruction %s %s does not share allocation with instruction %s %s",
+ a->ToString().c_str(), slice_a.ToString().c_str(),
+ b->ToString().c_str(), slice_b.ToString().c_str());
+ }
+ return Status::OK();
+}
+
// Checks that all buffers used during while loop iteration share the same
// buffer allocation. This includes buffers for while result, while init
// operand, condition parameter, body parameter and body result.
@@ -2111,37 +2130,65 @@ Status CheckWhileBuffersShareAllocation(
const BufferAssignment& buffer_assignment) {
return ShapeUtil::ForEachSubshapeWithStatus(
xla_while->shape(),
- [&buffer_assignment, &xla_while](const Shape& /*subshape*/,
- const ShapeIndex& index) -> Status {
- auto check = [&buffer_assignment](const HloInstruction* a,
- const HloInstruction* b,
- const ShapeIndex& index) -> Status {
- const BufferAllocation::Slice slice_a =
- buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
- const BufferAllocation::Slice slice_b =
- buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
- if (slice_a != slice_b) {
- return InternalError(
- "instruction %s %s does not share allocation with "
- "instruction %s %s",
- a->ToString().c_str(), slice_a.ToString().c_str(),
- b->ToString().c_str(), slice_b.ToString().c_str());
- }
- return Status::OK();
- };
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
const HloInstruction* condition_parameter =
xla_while->while_condition()->parameter_instruction(0);
const HloComputation* body = xla_while->while_body();
const HloInstruction* body_parameter = body->parameter_instruction(0);
const HloInstruction* body_result = body->root_instruction();
- TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
- TF_RETURN_IF_ERROR(check(xla_while, condition_parameter, index));
- TF_RETURN_IF_ERROR(check(xla_while, body_parameter, index));
- TF_RETURN_IF_ERROR(check(xla_while, body_result, index));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ xla_while, xla_while->operand(0), index, buffer_assignment));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ xla_while, condition_parameter, index, buffer_assignment));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ xla_while, body_parameter, index, buffer_assignment));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ xla_while, body_result, index, buffer_assignment));
return Status::OK();
});
}
+// Checks that the buffers used in a conditional instruction are shared with the
+// operands and result as follows:
+// * The result buffer of the conditional should share the allocation with the
+// result buffers of the true and false computations.
+// * The buffer of operand 1 should share the allocation with the buffer of
+// the parameter 0 instruction of the true computation.
+// * The buffer of operand 2 should share the allocation with the buffer of
+// the parameter 0 instruction of the false computation.
+Status CheckConditionalBuffersShareAllocation(
+ const HloInstruction* conditional,
+ const BufferAssignment& buffer_assignment) {
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ conditional->shape(),
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ conditional, conditional->true_computation()->root_instruction(),
+ index, buffer_assignment));
+ TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
+ conditional, conditional->false_computation()->root_instruction(),
+ index, buffer_assignment));
+ return Status::OK();
+ }));
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ conditional->operand(1)->shape(),
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
+ return CheckHloBuffersShareAllocation(
+ conditional->operand(1),
+ conditional->true_computation()->parameter_instruction(0), index,
+ buffer_assignment);
+ }));
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ conditional->operand(2)->shape(),
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
+ return CheckHloBuffersShareAllocation(
+ conditional->operand(2),
+ conditional->false_computation()->parameter_instruction(0), index,
+ buffer_assignment);
+ }));
+ return Status::OK();
+}
+
} // namespace
std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
@@ -2184,6 +2231,31 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
ir_emitter_body.ConsumeThunkSequence(), hlo);
}
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
+ const HloInstruction* hlo) {
+ // Check that the buffers used in conditional are shared with the operands and
+ // result appropriately.
+ TF_CHECK_OK(CheckConditionalBuffersShareAllocation(
+ hlo, ir_emitter_context_->buffer_assignment()));
+
+ HloComputation* true_computation = hlo->true_computation();
+ IrEmitterUnnested ir_emitter_true(hlo_module_config_, true_computation,
+ ir_emitter_context_);
+ TF_CHECK_OK(true_computation->root_instruction()->Accept(&ir_emitter_true));
+
+ HloComputation* false_computation = hlo->false_computation();
+ IrEmitterUnnested ir_emitter_false(hlo_module_config_, false_computation,
+ ir_emitter_context_);
+ TF_CHECK_OK(false_computation->root_instruction()->Accept(&ir_emitter_false));
+
+ return MakeUnique<ConditionalThunk>(
+ GetAllocationSlice(*hlo->operand(0)),
+ GetAllocationSlice(*hlo->operand(1)),
+ GetAllocationSlice(*hlo->operand(2)),
+ std::move(*ir_emitter_true.ConsumeThunkSequence()),
+ std::move(*ir_emitter_false.ConsumeThunkSequence()), hlo);
+}
+
Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
const HloInstruction& hlo,
const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) {
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index 625c3f8bea..2c3032d79b 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -41,6 +41,7 @@ class GpuExecutable;
class Thunk {
public:
enum class Kind {
+ kConditional,
kConvolution,
kCopy,
kCudnnBatchNormBackward,
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 2194d24257..f30530db08 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -128,7 +128,8 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, argument_layouts, &execution_options));
+ CreateModuleConfig(*program_shape, argument_layouts, &execution_options,
+ *user_computation));
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
execute_backend_->stream_executor(device_ordinal));
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 926ebbe314..849df1d8e6 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -56,6 +57,7 @@ namespace se = ::perftools::gputools;
using ::tensorflow::strings::Printf;
using ::tensorflow::strings::StrCat;
+using ::xla::source_map_util::InvalidParameterArgument;
namespace xla {
@@ -261,7 +263,8 @@ StatusOr<std::vector<const ShapedBuffer*>> Service::ResolveAndValidateArguments(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
- const ExecutionOptions* execution_options) {
+ const ExecutionOptions* execution_options,
+ const UserComputation& user_computation) {
auto config = MakeUnique<HloModuleConfig>(program_shape);
auto* computation_layout = config->mutable_entry_computation_layout();
@@ -275,8 +278,10 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
// ProgramShape.
if (!ShapeUtil::Compatible(*argument_shapes[i],
program_shape.parameters(i))) {
- return InvalidArgument(
- "computation expects parameter %d to have shape %s, given shape %s",
+ return InvalidParameterArgument(
+ *user_computation.ParameterMetadata(i).value(),
+ "Argument does not match shape of computation parameter %d: want %s, "
+ "got %s",
i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
ShapeUtil::HumanString(*argument_shapes[i]).c_str());
}
@@ -318,12 +323,14 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const ExecutionOptions& execution_options) {
+ const ExecutionOptions& execution_options,
+ const UserComputation& user_computation) {
std::vector<const Shape*> argument_shapes;
for (const auto* arg : arguments) {
argument_shapes.push_back(&arg->on_host_shape());
}
- return CreateModuleConfig(program_shape, argument_shapes, &execution_options);
+ return CreateModuleConfig(program_shape, argument_shapes, &execution_options,
+ user_computation);
}
StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
@@ -742,9 +749,10 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
// Create an HloModuleConfig object for the computation, given the shape of
// the program and the argument allocations.
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, arguments,
- request.execution_options()));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModuleConfig> module_config,
+ CreateModuleConfig(*program_shape, arguments,
+ request.execution_options(), *user_computation));
VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -852,7 +860,8 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg,
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, arguments, arg->execution_options()));
+ CreateModuleConfig(*program_shape, arguments, arg->execution_options(),
+ *user_computation));
VLOG(3) << "Execute created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -916,7 +925,8 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, arguments, arg->execution_options()));
+ CreateModuleConfig(*program_shape, arguments, arg->execution_options(),
+ *user_computation));
VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -1236,7 +1246,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(program_shape, {}, execution_options));
+ CreateModuleConfig(program_shape, {}, execution_options,
+ *user_computation));
// Exclude dead parameter instructions for the purpose of computing constants.
TF_ASSIGN_OR_RETURN(
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 0a7d0b3a7d..ca77e8fe3a 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -251,7 +251,8 @@ class Service : public ServiceInterface {
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const ExecutionOptions& execution_options);
+ const ExecutionOptions& execution_options,
+ const UserComputation& user_computation);
protected:
friend class LocalExecutable;
@@ -275,7 +276,8 @@ class Service : public ServiceInterface {
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
- const ExecutionOptions* execution_options);
+ const ExecutionOptions* execution_options,
+ const UserComputation& user_computation);
// Builds an Executable for the given parameters.
StatusOr<std::unique_ptr<Executable>> BuildExecutable(
diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc
new file mode 100644
index 0000000000..8cbaac7b37
--- /dev/null
+++ b/tensorflow/compiler/xla/service/source_map_util.cc
@@ -0,0 +1,66 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/source_map_util.h"
+
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+namespace source_map_util {
+namespace {
+
+Status InvalidParameterArgumentV(const OpMetadata& op_metadata,
+ const char* format, va_list args) {
+ string message;
+ tensorflow::strings::Appendv(&message, format, args);
+ if (!op_metadata.source_file().empty()) {
+ tensorflow::strings::Appendf(&message, " (%s:%d)",
+ op_metadata.source_file().c_str(),
+ op_metadata.source_line());
+ }
+ return InvalidArgument("%s", message.c_str());
+}
+
+} // namespace
+
+Status InvalidParameterArgument(const OpMetadata& op_metadata,
+ const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ Status result = InvalidParameterArgumentV(op_metadata, format, args);
+ va_end(args);
+ return result;
+}
+
+Status InvalidParameterArgument(Executable* executable, int parameter_number,
+ const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ if (executable != nullptr && executable->has_module()) {
+ const HloModule& module = executable->module();
+ const HloComputation& computation = *module.entry_computation();
+ HloInstruction* param = computation.parameter_instruction(parameter_number);
+ const OpMetadata& metadata = param->metadata();
+ Status result = InvalidParameterArgumentV(metadata, format, args);
+ va_end(args);
+ return result;
+ }
+ Status result = InvalidArgumentV(format, args);
+ va_end(args);
+ return result;
+}
+
+} // namespace source_map_util
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h
new file mode 100644
index 0000000000..a776d745f4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/source_map_util.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_
+
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+namespace source_map_util {
+
+// Creates an INVALID_ARUGMENT status with the given format string.
+//
+// Also, attempts to extract the OpMetadata for parameter_number on executable
+// and append it to the status message for source mapping to user code.
+//
+// executable may be nullptr, but parameter_number should not be out of bounds
+// or a CHECK-failure may occur.
+Status InvalidParameterArgument(Executable* executable, int parameter_number,
+ const char* format, ...)
+ TF_PRINTF_ATTRIBUTE(3, 4);
+
+// As above, but takes the parameter metadata directly instead of extracting it
+// from the executable.
+Status InvalidParameterArgument(const OpMetadata& op_metadata,
+ const char* format, ...)
+ TF_PRINTF_ATTRIBUTE(2, 3);
+
+} // namespace source_map_util
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 3afd52b6b2..4410647f84 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -351,6 +351,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:platform_util",
@@ -848,7 +849,8 @@ xla_test(
name = "half_test",
srcs = ["half_test.cc"],
backends = [
- "cpu",
+ # TODO(b/72509305): Flaky (fails with SEGV) as of 2018-01-25
+ # "cpu",
"gpu",
],
deps = [
@@ -1034,7 +1036,10 @@ xla_test(
name = "select_and_scatter_test",
timeout = "long",
srcs = ["select_and_scatter_test.cc"],
- tags = ["enable_for_xla_interpreter"],
+ tags = [
+ "enable_for_xla_interpreter",
+ "optonly",
+ ],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal_util",
diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
index 659660d91e..f594cc10ac 100644
--- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
+++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
@@ -104,7 +104,8 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
ASSERT_FALSE(status.ok());
ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT);
ASSERT_THAT(status.status().error_message(),
- ContainsRegex("expects parameter 0"));
+ ContainsRegex(
+ "Argument does not match shape of computation parameter 0"));
// Shape mismatch in parameter 1 (rank)
status = client_->Execute(computation, {f32_data.get(), f32_data.get()},
@@ -112,7 +113,8 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
ASSERT_FALSE(status.ok());
ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT);
ASSERT_THAT(status.status().error_message(),
- ContainsRegex("expects parameter 1"));
+ ContainsRegex(
+ "Argument does not match shape of computation parameter 1"));
// Shape mismatch in parameter 1 (element type)
status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()},
@@ -120,7 +122,8 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
ASSERT_FALSE(status.ok());
ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT);
ASSERT_THAT(status.status().error_message(),
- ContainsRegex("expects parameter 1"));
+ ContainsRegex(
+ "Argument does not match shape of computation parameter 1"));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index 0016b6cc61..bc82167482 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -355,8 +355,7 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
}
// Test true and false computations that return a tuple of arrays.
-// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
-XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnTupleOfArrays)) {
+XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
ComputationBuilder builder(client_, TestName());
auto pred = builder.ConstantR0<bool>(true);
auto operands = builder.Tuple({builder.ConstantR1<float>({12.2f, 15.8f}),
@@ -373,9 +372,7 @@ XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnTupleOfArrays)) {
// Test true and false computations that return a tuple of a predicate, a
// scalar, and an array.
-// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
-XLA_TEST_F(ConditionalOpTest,
- DISABLED_ON_GPU(ReturnTupleofPredicateScalarArray)) {
+XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
ComputationBuilder true_builder(client_, TestName() + ".true");
{
true_builder.Parameter(0, empty_tuple_, "tuple");
@@ -413,8 +410,7 @@ XLA_TEST_F(ConditionalOpTest,
}
// Test true and false computations that return a nested tuple.
-// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend.
-XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnNestedTuple)) {
+XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
ComputationBuilder true_builder(client_, TestName() + ".true");
{
true_builder.Parameter(0, empty_tuple_, "tuple");
@@ -532,6 +528,32 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
+XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
+ ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional");
+ {
+ Shape r0bool = ShapeUtil::MakeShape(PRED, {});
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
+ auto param0 = inner_builder.Parameter(0, tuple_shape, "param0");
+ auto pred_cond = inner_builder.GetTupleElement(param0, 0);
+ auto true_operand = inner_builder.GetTupleElement(param0, 1);
+ auto false_operand = inner_builder.GetTupleElement(param0, 2);
+ inner_builder.Conditional(pred_cond, true_operand,
+ CreateR0CeilComputation(), false_operand,
+ CreateR0FloorComputation());
+ }
+ auto inner_builder_result = inner_builder.Build();
+ EXPECT_IS_OK(inner_builder_result.status());
+
+ ComputationBuilder builder(client_, TestName());
+ auto pred2 = builder.ConstantR0<bool>(false);
+ auto operand1 = builder.ConstantR0<float>(1.1f);
+ auto operand2 = builder.ConstantR0<float>(12.2f);
+ auto tuple_operand = builder.Tuple({pred2, operand1, operand2});
+ builder.Call(inner_builder_result.ConsumeValueOrDie(), {tuple_operand});
+
+ ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
+}
+
// Test a mismatch in the shape of the true operand and true computation.
XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
ComputationBuilder builder(client_, TestName());
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 73b37e201a..7f3c72671d 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -1016,37 +1016,39 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
::testing::tuple<R2ReduceWindowTestData, bool>> {
protected:
R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
-};
-TEST_P(R2ReduceWindowTest, Add) {
- ComputationBuilder b(client_, TestName());
- const auto& param = ::testing::get<0>(GetParam());
- CHECK(param.reducer == kAdd);
-
- const float kInitValue = 0.0f;
- Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
- std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2DWithLayout(
- input, LayoutUtil::MakeLayout(param.layout));
+ void DoIt() {
+ ComputationBuilder b(client_, TestName());
+ const auto& param = ::testing::get<0>(GetParam());
+ CHECK(param.reducer == kAdd);
- ComputationDataHandle parameter;
- auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
- &b, &parameter);
- auto init_value =
- CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
- b.ReduceWindow(/*operand=*/parameter,
- /*init_value=*/init_value,
- /*computation=*/CreateScalarAddComputation(FloatType(), &b),
- /*window_dimensions=*/param.window_bounds,
- /*window_strides=*/param.strides, /*padding=*/param.padding);
+ const float kInitValue = 0.0f;
+ Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
+ std::unique_ptr<Literal> input_literal =
+ Literal::CreateR2FromArray2DWithLayout(
+ input, LayoutUtil::MakeLayout(param.layout));
- auto expected = ReferenceUtil::ReduceWindow2DAdd(
- /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
- /*stride=*/param.strides, /*padding=*/param.padding);
+ ComputationDataHandle parameter;
+ auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
+ &b, &parameter);
+ auto init_value =
+ CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
+ b.ReduceWindow(/*operand=*/parameter,
+ /*init_value=*/init_value,
+ /*computation=*/CreateScalarAddComputation(FloatType(), &b),
+ /*window_dimensions=*/param.window_bounds,
+ /*window_strides=*/param.strides, /*padding=*/param.padding);
+
+ auto expected = ReferenceUtil::ReduceWindow2DAdd(
+ /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
+ /*stride=*/param.strides, /*padding=*/param.padding);
+
+ ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
+ {input_arg.get()}, DefaultErrorSpec());
+ }
+};
- ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
- {input_arg.get()}, DefaultErrorSpec());
-}
+TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }
INSTANTIATE_TEST_CASE_P(
R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
@@ -1054,6 +1056,26 @@ INSTANTIATE_TEST_CASE_P(
::testing::ValuesIn(use_bfloat16_params)),
R2ReduceWindowTestDataToString);
+class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {};
+
+// TODO(b/72234705): Fix the test cases failed on CPU and GPU.
+XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test,
+ DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) {
+ DoIt();
+}
+
+const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = {
+ {/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128},
+ /*strides=*/{1, 1}, /*layout=*/{1, 0},
+ /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+};
+
+INSTANTIATE_TEST_CASE_P(
+ R2ReduceWindowFailingInstantiation, R2ReduceWindowFailingCpuGpuBf16Test,
+ ::testing::Combine(::testing::ValuesIn(kR2FailingValuesCpuGpuBf16Test),
+ ::testing::ValuesIn(use_bfloat16_params)),
+ R2ReduceWindowTestDataToString);
+
struct R1ReduceWindowTestData {
int64 base_bounds[1];
int64 window_bounds[1];
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 1d2f436194..9ad2a19853 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -19,12 +19,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -32,6 +34,7 @@ limitations under the License.
namespace xla {
namespace {
namespace se = ::perftools::gputools;
+namespace gtl = ::tensorflow::gtl;
class HloProfileTest : public ClientLibraryTestBase {};
@@ -43,39 +46,74 @@ struct ParsedProfileOutputLine {
string trops;
string bytes_per_sec;
string bytes_per_cycle;
- string name;
+ string opcode;
};
-StatusOr<ParsedProfileOutputLine> ParseProfileOutputLine(const string& line,
- bool expect_flops,
- bool expect_trops) {
+::testing::AssertionResult HasFlops(
+ const ParsedProfileOutputLine& parsed_line) {
+ if (RE2::FullMatch(parsed_line.flops, "[0-9.TGMk]+FLOP/s")) {
+ return ::testing::AssertionSuccess()
+ << "'flops' field present in " << parsed_line.opcode << ": '"
+ << parsed_line.flops << "'";
+ }
+
+ return ::testing::AssertionFailure()
+ << "'flops' field absent in " << parsed_line.opcode << ": '"
+ << parsed_line.flops << "'";
+}
+
+::testing::AssertionResult HasTrops(
+ const ParsedProfileOutputLine& parsed_line) {
+ if (RE2::FullMatch(parsed_line.trops, "[0-9.TGMk]+TROP/s")) {
+ return ::testing::AssertionSuccess()
+ << "'trops' field present in " << parsed_line.opcode << ": '"
+ << parsed_line.trops << "'";
+ }
+
+ return ::testing::AssertionFailure()
+ << "'trops' field absent in " << parsed_line.opcode << ": '"
+ << parsed_line.trops << "'";
+}
+
+Status ParseOneProfileOutputLine(
+ const string& line, bool expect_hlo,
+ gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results) {
string separator = "[^:]*:: +";
string match_percentage = "\\d+\\.\\d\\d%";
string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)";
string match_usecs = "([0-9.]+) usec";
- string match_flops = expect_flops ? "([0-9.TGMk]+)FLOP/s" : "(<none>)";
- string match_trops = expect_trops ? "([0-9.TGMk]+)TROP/s" : "(<none>)";
+ string match_flops = "([^ ]+)";
+ string match_trops = "([^ ]+)";
string match_bytes_per_sec = "([0-9.TGMKi]+)B/s";
string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle";
+
+ // The underlined part is what we're trying to match with match_opcode:
+ //
+ // %dot33 = f32[256,256]{1,0} dot(...)
+ // ^^^
+
+ string match_opcode =
+ expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])";
string regexp_pattern = tensorflow::strings::StrCat(
" +", match_cycles, separator, match_usecs, separator, match_flops,
separator, match_trops, separator, match_bytes_per_sec, separator,
- match_bytes_per_cycle, separator, "(.*)");
+ match_bytes_per_cycle, separator, match_opcode);
- RE2 pattern(regexp_pattern);
ParsedProfileOutputLine parsed_line;
bool matched = RE2::FullMatch(
- line, pattern, &parsed_line.cycles, &parsed_line.cycles_percentage,
+ line, regexp_pattern, &parsed_line.cycles, &parsed_line.cycles_percentage,
&parsed_line.usec, &parsed_line.flops, &parsed_line.trops,
&parsed_line.bytes_per_sec, &parsed_line.bytes_per_cycle,
- &parsed_line.name);
+ &parsed_line.opcode);
if (!matched) {
return tensorflow::errors::InvalidArgument(
"Input did not match regexp. Input: ", line,
", Regexp: ", regexp_pattern);
}
- return parsed_line;
+ InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
+
+ return Status::OK();
}
// Returns void so that we can ASSERT.
@@ -148,7 +186,7 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) {
ClientLibrary::GetOrCreateLocalClient(platform));
ComputationBuilder builder(client, TestName());
- auto result = builder.Tanh(builder.Dot(
+ auto result = builder.Tanh(builder.Add(
builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"),
builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs")));
@@ -161,31 +199,43 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) {
std::vector<string> profile_output_lines =
tensorflow::str_util::Split(profile_output, '\n');
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine total_profile,
- ParseProfileOutputLine(profile_output_lines[1], /*expect_flops=*/true,
- /*expect_trops=*/true));
+ gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine dot_profile,
- ParseProfileOutputLine(profile_output_lines[2], /*expect_flops=*/true,
- /*expect_trops=*/false));
+ TF_ASSERT_OK(ParseOneProfileOutputLine(
+ profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines));
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine tanh_profile,
- ParseProfileOutputLine(profile_output_lines[3], /*expect_flops=*/false,
- /*expect_trops=*/true));
+ TF_ASSERT_OK(ParseOneProfileOutputLine(
+ profile_output_lines[2], /*expect_hlo=*/true, &parsed_profile_lines));
+
+ TF_ASSERT_OK(ParseOneProfileOutputLine(
+ profile_output_lines[3], /*expect_hlo=*/true, &parsed_profile_lines));
+
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_profile,
+ MaybeFind(parsed_profile_lines, "[total]"));
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile,
+ MaybeFind(parsed_profile_lines, "add"));
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine tanh_profile,
+ MaybeFind(parsed_profile_lines, "tanh"));
EXPECT_GT(total_profile.cycles, 0);
EXPECT_EQ(total_profile.cycles_percentage, "100.00%");
+ EXPECT_TRUE(HasFlops(total_profile));
+ EXPECT_TRUE(HasTrops(total_profile));
+
EXPECT_GT(total_profile.cycles, dot_profile.cycles);
EXPECT_NE(dot_profile.cycles_percentage, "0.00%");
EXPECT_NE(dot_profile.cycles_percentage, "100.00%");
+ EXPECT_TRUE(HasFlops(dot_profile));
+ EXPECT_FALSE(HasTrops(dot_profile));
+
EXPECT_GT(total_profile.cycles, tanh_profile.cycles);
EXPECT_NE(tanh_profile.cycles_percentage, "0.00%");
EXPECT_NE(tanh_profile.cycles_percentage, "100.00%");
+
+ EXPECT_FALSE(HasFlops(tanh_profile));
+ EXPECT_TRUE(HasTrops(tanh_profile));
}
// TODO(b/71364943): This test exposes a bug in the parallel CPU backend.
@@ -220,7 +270,7 @@ XLA_TEST_F(HloProfileTest,
auto matrix = builder.GetTupleElement(state, 1);
auto next_iteration = builder.Add(builder.GetTupleElement(state, 0),
builder.ConstantR0<int32>(1));
- builder.Tuple({next_iteration, builder.Dot(matrix, matrix)});
+ builder.Tuple({next_iteration, builder.Add(matrix, matrix)});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
@@ -249,20 +299,23 @@ XLA_TEST_F(HloProfileTest,
ASSERT_NE(while_body_profile_start, profile_output_lines.end());
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine total_while_body_profile,
- ParseProfileOutputLine(*std::next(while_body_profile_start, 1),
- /*expect_flops=*/false,
- /*expect_trops=*/false));
+ gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
- TF_ASSERT_OK_AND_ASSIGN(
- ParsedProfileOutputLine dot_profile,
- ParseProfileOutputLine(*std::next(while_body_profile_start, 2),
- /*expect_flops=*/false,
- /*expect_trops=*/false));
+ TF_ASSERT_OK(
+ ParseOneProfileOutputLine(*std::next(while_body_profile_start, 1),
+ /*expect_hlo=*/false, &parsed_profile_lines));
+
+ TF_ASSERT_OK(
+ ParseOneProfileOutputLine(*std::next(while_body_profile_start, 2),
+ /*expect_hlo=*/true, &parsed_profile_lines));
+
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile,
+ MaybeFind(parsed_profile_lines, "[total]"));
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile,
+ MaybeFind(parsed_profile_lines, "add"));
EXPECT_GT(total_while_body_profile.cycles, 0);
- EXPECT_EQ(total_while_body_profile.name, "[total]");
+ EXPECT_EQ(total_while_body_profile.opcode, "[total]");
EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%");
EXPECT_GT(total_while_body_profile.cycles, dot_profile.cycles);
diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc
index fe5d29a6b6..b020905035 100644
--- a/tensorflow/compiler/xla/util.cc
+++ b/tensorflow/compiler/xla/util.cc
@@ -30,9 +30,7 @@ limitations under the License.
#include "tensorflow/core/platform/stacktrace.h"
namespace xla {
-namespace {
-// Logs the provided status message with a backtrace.
Status WithLogBacktrace(const Status& status) {
CHECK(!status.ok());
VLOG(1) << status.ToString();
@@ -40,8 +38,6 @@ Status WithLogBacktrace(const Status& status) {
return status;
}
-} // namespace
-
ScopedLoggingTimer::ScopedLoggingTimer(const string& label, bool enabled)
: enabled(enabled), label(label) {
if (enabled) {
@@ -74,13 +70,18 @@ Status AppendStatus(Status prior, tensorflow::StringPiece context) {
// Implementation note: we can't common these out (without using macros) because
// they all need to va_start/va_end their varargs in their frame.
-Status InvalidArgument(const char* format, ...) {
+Status InvalidArgumentV(const char* format, va_list args) {
string message;
+ tensorflow::strings::Appendv(&message, format, args);
+ return WithLogBacktrace(tensorflow::errors::InvalidArgument(message));
+}
+
+Status InvalidArgument(const char* format, ...) {
va_list args;
va_start(args, format);
- tensorflow::strings::Appendv(&message, format, args);
+ Status result = InvalidArgumentV(format, args);
va_end(args);
- return WithLogBacktrace(tensorflow::errors::InvalidArgument(message));
+ return result;
}
Status Unimplemented(const char* format, ...) {
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 1d7dd34449..4bc2d632cd 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -40,6 +40,13 @@ limitations under the License.
namespace xla {
+// Logs the provided status message with a backtrace.
+//
+// For use by Status-factories, logs a backtrace at the point where the status
+// is created, such that we can use --vmodule=util=1 to see all status
+// creation backtraces.
+Status WithLogBacktrace(const Status& status);
+
// Ranks greater than 8 are very rare, so use InlinedVector<int64, 8> to store
// the bounds and indices. And for the rare cases of ranks greater than 8,
// the InlinedVector will just behave like an std::vector<> and allocate the
@@ -207,6 +214,9 @@ Status ResourceExhausted(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
Status NotFound(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+// Passed-varargs variant of the InvalidArgument factory above.
+Status InvalidArgumentV(const char* format, va_list args);
+
// Splits the lines of the original, replaces leading whitespace with the prefix
// given by "indentation", and returns the string joined by newlines again. As a
// side effect, any additional trailing whitespace is removed.
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
index deb324634b..1bfd27305d 100644
--- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
namespace tensorflow {
-
namespace {
constexpr size_t kBufferSize = 1024 * 1024; // In bytes.
@@ -40,33 +39,6 @@ Status ParseJson(StringPiece json, Json::Value* result) {
return Status::OK();
}
-string ColumnTypeToString(BigQueryTableAccessor::ColumnType enum_type) {
- switch (enum_type) {
- case BigQueryTableAccessor::ColumnType::kRecord:
- return "RECORD";
- case BigQueryTableAccessor::ColumnType::kString:
- return "STRING";
- case BigQueryTableAccessor::ColumnType::kBytes:
- return "BYTES";
- case BigQueryTableAccessor::ColumnType::kInteger:
- return "INTEGER";
- case BigQueryTableAccessor::ColumnType::kFloat:
- return "FLOAT";
- case BigQueryTableAccessor::ColumnType::kBoolean:
- return "BOOLEAN";
- case BigQueryTableAccessor::ColumnType::kTimestamp:
- return "TIMESTAMP";
- case BigQueryTableAccessor::ColumnType::kDate:
- return "DATE";
- case BigQueryTableAccessor::ColumnType::kTime:
- return "TIME";
- case BigQueryTableAccessor::ColumnType::kDatetime:
- return "DATETIME";
- case BigQueryTableAccessor::ColumnType::kNone:
- return "NONE";
- }
-}
-
Status ParseColumnType(const string& type,
BigQueryTableAccessor::ColumnType* enum_type) {
if (type == "RECORD") {
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index 817e96f5da..12bfd3c62b 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -134,6 +134,9 @@ if(WIN32)
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /D_ITERATOR_DEBUG_LEVEL=0")
set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /D_ITERATOR_DEBUG_LEVEL=0")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /D_ITERATOR_DEBUG_LEVEL=0")
+
+ # Try to avoid flaky failures due to failed generation of generate.stamp files.
+ set(CMAKE_SUPPRESS_REGENERATION ON)
endif()
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index 7db454bd83..9ce8b3cc9c 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -33,9 +33,11 @@ tensorflow/python/grappler
tensorflow/python/keras
tensorflow/python/keras/activations
tensorflow/python/keras/applications
+tensorflow/python/keras/applications/densenet
tensorflow/python/keras/applications/inception_resnet_v2
tensorflow/python/keras/applications/inception_v3
tensorflow/python/keras/applications/mobilenet
+tensorflow/python/keras/applications/nasnet
tensorflow/python/keras/applications/resnet50
tensorflow/python/keras/applications/vgg16
tensorflow/python/keras/applications/vgg19
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
index bae66ffd42..b806799202 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -35,10 +35,10 @@ from tensorflow.python.ops.variables import Variable
from tensorflow.python.client.session import Session
from tensorflow.python.framework import ops
-__all__ = ["copy_op_to_graph", "copy_variable_to_graph", "get_copied_op"]
+__all__ = ['copy_op_to_graph', 'copy_variable_to_graph', 'get_copied_op']
-def copy_variable_to_graph(org_instance, to_graph, scope=""):
+def copy_variable_to_graph(org_instance, to_graph, scope=''):
"""Given a `Variable` instance from one `Graph`, initializes and returns
a copy of it from another `Graph`, under the specified scope
(default `""`).
@@ -56,12 +56,11 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""):
"""
if not isinstance(org_instance, Variable):
- raise TypeError(str(org_instance) + " is not a Variable")
+ raise TypeError(str(org_instance) + ' is not a Variable')
#The name of the new variable
- if scope != "":
- new_name = (scope + '/' +
- org_instance.name[:org_instance.name.index(':')])
+ if scope != '':
+ new_name = (scope + '/' + org_instance.name[:org_instance.name.index(':')])
else:
new_name = org_instance.name[:org_instance.name.index(':')]
@@ -73,15 +72,15 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""):
for name, collection in org_instance.graph._collections.items():
if org_instance in collection:
if (name == ops.GraphKeys.GLOBAL_VARIABLES or
- name == ops.GraphKeys.TRAINABLE_VARIABLES or
- scope == ''):
+ name == ops.GraphKeys.TRAINABLE_VARIABLES or scope == ''):
collections.append(name)
else:
collections.append(scope + '/' + name)
#See if its trainable.
- trainable = (org_instance in org_instance.graph.get_collection(
- ops.GraphKeys.TRAINABLE_VARIABLES))
+ trainable = (
+ org_instance in org_instance.graph.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES))
#Get the initial value
with org_instance.graph.as_default():
temp_session = Session()
@@ -89,17 +88,17 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""):
#Initialize the new variable
with to_graph.as_default():
- new_var = Variable(init_value,
- trainable,
- name=new_name,
- collections=collections,
- validate_shape=False)
+ new_var = Variable(
+ init_value,
+ trainable,
+ name=new_name,
+ collections=collections,
+ validate_shape=False)
return new_var
-def copy_op_to_graph(org_instance, to_graph, variables,
- scope=""):
+def copy_op_to_graph(org_instance, to_graph, variables, scope=''):
"""Returns a copy of an operation from another Graph under a specified scope.
Given an `Operation` `org_instance` from one `Graph`,
@@ -139,14 +138,12 @@ def copy_op_to_graph(org_instance, to_graph, variables,
#If a variable by the new name already exists, return the
#correspondng tensor that will act as an input
if new_name in copied_variables:
- return to_graph.get_tensor_by_name(
- copied_variables[new_name].name)
+ return to_graph.get_tensor_by_name(copied_variables[new_name].name)
#If an instance of the same name exists, return appropriately
try:
- already_present = to_graph.as_graph_element(new_name,
- allow_tensor=True,
- allow_operation=True)
+ already_present = to_graph.as_graph_element(
+ new_name, allow_tensor=True, allow_operation=True)
return already_present
except:
pass
@@ -184,20 +181,21 @@ def copy_op_to_graph(org_instance, to_graph, variables,
#If it has an original_op parameter, copy it
if op._original_op is not None:
- new_original_op = copy_op_to_graph(op._original_op, to_graph,
- variables, scope)
+ new_original_op = copy_op_to_graph(op._original_op, to_graph, variables,
+ scope)
else:
new_original_op = None
#If it has control inputs, call this function recursively on each.
- new_control_inputs = [copy_op_to_graph(x, to_graph, variables,
- scope)
- for x in op.control_inputs]
+ new_control_inputs = [
+ copy_op_to_graph(x, to_graph, variables, scope)
+ for x in op.control_inputs
+ ]
#If it has inputs, call this function recursively on each.
- new_inputs = [copy_op_to_graph(x, to_graph, variables,
- scope)
- for x in op.inputs]
+ new_inputs = [
+ copy_op_to_graph(x, to_graph, variables, scope) for x in op.inputs
+ ]
#Make a new node_def based on that of the original.
#An instance of tensorflow.core.framework.node_def_pb2.NodeDef, it
@@ -216,13 +214,8 @@ def copy_op_to_graph(org_instance, to_graph, variables,
op_def = deepcopy(op._op_def)
#Initialize a new Operation instance
- new_op = ops.Operation(new_node_def,
- to_graph,
- new_inputs,
- output_types,
- new_control_inputs,
- input_types,
- new_original_op,
+ new_op = ops.Operation(new_node_def, to_graph, new_inputs, output_types,
+ new_control_inputs, input_types, new_original_op,
op_def)
#Use Graph's hidden methods to add the op
to_graph._add_op(new_op) # pylint: disable=protected-access
@@ -233,10 +226,10 @@ def copy_op_to_graph(org_instance, to_graph, variables,
return new_op
else:
- raise TypeError("Could not copy instance: " + str(org_instance))
+ raise TypeError('Could not copy instance: ' + str(org_instance))
-def get_copied_op(org_instance, graph, scope=""):
+def get_copied_op(org_instance, graph, scope=''):
"""Given an `Operation` instance from some `Graph`, returns
its namesake from `graph`, under the specified scope
(default `""`).
@@ -259,5 +252,5 @@ def get_copied_op(org_instance, graph, scope=""):
else:
new_name = org_instance.name
- return graph.as_graph_element(new_name, allow_tensor=True,
- allow_operation=True)
+ return graph.as_graph_element(
+ new_name, allow_tensor=True, allow_operation=True)
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index cdbe05e4d2..6cdbed5b89 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -163,7 +163,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:check_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lookup_ops",
@@ -177,7 +177,6 @@ py_library(
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:prediction_keys",
- "//tensorflow/python/estimator:util",
"//tensorflow/python/ops/losses",
"//tensorflow/python/saved_model:signature_constants",
],
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index fd0994490a..238cf287b7 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.estimator import model_fn
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
@@ -29,7 +28,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
@@ -45,6 +43,7 @@ def multi_class_head(n_classes,
weight_column=None,
label_vocabulary=None,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
"""Creates a `_Head` for multi class classification.
@@ -65,6 +64,12 @@ def multi_class_head(n_classes,
labels have shape `[batch_size, 1]`, the loss is the weighted sum over
`batch_size`.
+ Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+ `(labels, logits, features)` as arguments and returns unreduced loss with
+ shape `[D0, D1, ... DN, 1]`. `loss_fn` must support integer `labels` with
+ shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
+ the input labels before passing them to `loss_fn`.
+
Args:
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
`binary_classification_head`).
@@ -79,6 +84,7 @@ def multi_class_head(n_classes,
`label_vocabulary` is not provided but labels are strings.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
reduce training loss over batch. Defaults to `SUM`.
+ loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -94,12 +100,17 @@ def multi_class_head(n_classes,
weight_column=weight_column,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction,
+ loss_fn=loss_fn,
name=name)
def binary_classification_head(
- weight_column=None, thresholds=None, label_vocabulary=None,
- loss_reduction=losses.Reduction.SUM, name=None):
+ weight_column=None,
+ thresholds=None,
+ label_vocabulary=None,
+ loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
+ name=None):
"""Creates a `_Head` for single label binary classification.
This head uses `sigmoid_cross_entropy_with_logits` loss.
@@ -119,6 +130,12 @@ def binary_classification_head(
labels have shape `[batch_size, 1]`, the loss is the weighted sum over
`batch_size`.
+ Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+ `(labels, logits, features)` as arguments and returns unreduced loss with
+ shape `[D0, D1, ... DN, 1]`. `loss_fn` must support float `labels` with
+ shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
+ the input labels before passing them to `loss_fn`.
+
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -136,6 +153,7 @@ def binary_classification_head(
is not provided but labels are strings.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
reduce training loss over batch. Defaults to `SUM`.
+ loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -151,12 +169,14 @@ def binary_classification_head(
thresholds=thresholds,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction,
+ loss_fn=loss_fn,
name=name)
def regression_head(weight_column=None,
label_dimension=1,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
"""Creates a `_Head` for regression using the `mean_squared_error` loss.
@@ -175,6 +195,10 @@ def regression_head(weight_column=None,
`[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or
`[D0, D1, ... DN, label_dimension]`.
+ Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+ `(labels, logits, features)` as arguments and returns unreduced loss with
+ shape `[D0, D1, ... DN, label_dimension]`.
+
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -185,6 +209,7 @@ def regression_head(weight_column=None,
`[batch_size, label_dimension]`).
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
reduce training loss over batch. Defaults to `SUM`.
+ loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -198,6 +223,7 @@ def regression_head(weight_column=None,
weight_column=weight_column,
label_dimension=label_dimension,
loss_reduction=loss_reduction,
+ loss_fn=loss_fn,
name=name)
@@ -287,7 +313,7 @@ def multi_label_head(n_classes,
'Length of label_vocabulary must be n_classes ({}). '
'Given: {}'.format(n_classes, len(label_vocabulary)))
if loss_fn:
- _validate_loss_fn_args(loss_fn)
+ head_lib._validate_loss_fn_args(loss_fn) # pylint:disable=protected-access
if (loss_reduction not in losses.Reduction.all() or
loss_reduction == losses.Reduction.NONE):
raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
@@ -371,9 +397,9 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
labels=processed_labels, logits=logits,
expected_labels_dimension=self.logits_dimension)
if self._loss_fn:
- unweighted_loss = _call_loss_fn(
+ unweighted_loss = head_lib._call_loss_fn( # pylint:disable=protected-access
loss_fn=self._loss_fn, labels=processed_labels, logits=logits,
- features=features)
+ features=features, expected_loss_dim=1)
else:
unweighted_loss = losses.sigmoid_cross_entropy(
multi_class_labels=processed_labels, logits=logits,
@@ -555,52 +581,3 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
threshold=threshold,
name=recall_key))
return metric_ops
-
-
-def _validate_loss_fn_args(loss_fn):
- """Validates loss_fn arguments.
-
- Required arguments: labels, logits.
- Optional arguments: features.
-
- Args:
- loss_fn: The loss function.
- Raises:
- ValueError: If the signature is unexpected.
- """
- loss_fn_args = util.fn_args(loss_fn)
- for required_arg in ['labels', 'logits']:
- if required_arg not in loss_fn_args:
- raise ValueError(
- 'loss_fn must contain argument: {}. '
- 'Given arguments: {}'.format(required_arg, loss_fn_args))
- invalid_args = list(set(loss_fn_args) - set(['labels', 'logits', 'features']))
- if invalid_args:
- raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args))
-
-
-def _call_loss_fn(loss_fn, labels, logits, features):
- """Calls loss_fn and checks the returned shape.
-
- Args:
- loss_fn: The loss function.
- labels: Processed labels Tensor.
- logits: Logits Tensor of shape [batch_size, logits_dimension].
- features: Features dict.
- Returns:
- Loss Tensor with shape [batch_size, 1].
- """
- loss_fn_args = util.fn_args(loss_fn)
- kwargs = {}
- if 'features' in loss_fn_args:
- kwargs['features'] = features
- unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs)
- batch_size = array_ops.shape(logits)[0]
- loss_shape = array_ops.shape(unweighted_loss)
- check_shape_op = control_flow_ops.Assert(
- math_ops.reduce_all(math_ops.equal(loss_shape, [batch_size, 1])),
- data=[
- 'loss_fn must return Tensor of shape [batch_size, 1]. Given: ',
- loss_shape])
- with ops.control_dependencies([check_shape_op]):
- return array_ops.identity(unweighted_loss)
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index 1adbd6f0fe..43cdfec968 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -381,8 +381,8 @@ class MultiLabelHead(test.TestCase):
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
- r'loss_fn must return Tensor of shape \[batch_size, 1\]\. '
- r'Given: \] \[2\]'):
+ r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 1\]\. \] '
+ r'\[logits_shape: \] \[2 2\] \[loss_shape: \] \[2\]'):
actual_training_loss.eval()
def test_eval_labels_none(self):
diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
index 6a56237f67..bafd1d5941 100644
--- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
@@ -25,13 +25,6 @@ limitations under the License.
namespace tensorflow {
-namespace {
-// Return the string containing the list of valid activation modes, that can be
-// used as an Attr() in REGISTER_OP.
-string GetAllActivationModeAttrString() { return "activation_mode: {'Relu'}"; }
-
-} // namespace
-
// --------------------------------------------------------------------------
// TODO(pauldonnelly): Add support for double inputs and scales to this Op,
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py
index 2eaea23177..fc8f153fe3 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py
@@ -221,8 +221,8 @@ class FeatureColumnTest(test.TestCase):
weighted_sparse_col = fc.weighted_sparse_column(ids, "weights")
self.assertEqual(weighted_sparse_col.name, "ids_weighted_by_weights")
- b = fc.shared_embedding_columns([sparse_col, weighted_sparse_col],
- dimension=4, combiner="mean")
+ b = fc.shared_embedding_columns(
+ [sparse_col, weighted_sparse_col], dimension=4, combiner="mean")
self.assertEqual(len(b), 2)
self.assertEqual(b[0].shared_embedding_name,
"a1_ids_weighted_by_weights_shared_embedding")
@@ -230,8 +230,8 @@ class FeatureColumnTest(test.TestCase):
"a1_ids_weighted_by_weights_shared_embedding")
# Tries reversing order to check compatibility condition.
- b = fc.shared_embedding_columns([weighted_sparse_col, sparse_col],
- dimension=4, combiner="mean")
+ b = fc.shared_embedding_columns(
+ [weighted_sparse_col, sparse_col], dimension=4, combiner="mean")
self.assertEqual(len(b), 2)
self.assertEqual(b[0].shared_embedding_name,
"a1_ids_weighted_by_weights_shared_embedding")
@@ -240,18 +240,17 @@ class FeatureColumnTest(test.TestCase):
# Tries adding two weighted columns to check compatibility between them.
weighted_sparse_col_2 = fc.weighted_sparse_column(ids, "weights_2")
- b = fc.shared_embedding_columns([weighted_sparse_col,
- weighted_sparse_col_2],
- dimension=4, combiner="mean")
+ b = fc.shared_embedding_columns(
+ [weighted_sparse_col, weighted_sparse_col_2],
+ dimension=4,
+ combiner="mean")
self.assertEqual(len(b), 2)
self.assertEqual(
b[0].shared_embedding_name,
- "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding"
- )
+ "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding")
self.assertEqual(
b[1].shared_embedding_name,
- "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding"
- )
+ "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding")
def testSharedEmbeddingColumnDeterminism(self):
# Tests determinism in auto-generated shared_embedding_name.
@@ -286,10 +285,10 @@ class FeatureColumnTest(test.TestCase):
columns = fc.shared_embedding_columns(
[a1, a2], dimension=4, combiner="mean")
columns_copy = copy.deepcopy(columns)
- self.assertEqual(
- columns_copy[0].shared_embedding_name, "a1_a2_shared_embedding")
- self.assertEqual(
- columns_copy[1].shared_embedding_name, "a1_a2_shared_embedding")
+ self.assertEqual(columns_copy[0].shared_embedding_name,
+ "a1_a2_shared_embedding")
+ self.assertEqual(columns_copy[1].shared_embedding_name,
+ "a1_a2_shared_embedding")
def testOneHotColumn(self):
a = fc.sparse_column_with_keys("a", ["a", "b", "c", "d"])
@@ -336,11 +335,11 @@ class FeatureColumnTest(test.TestCase):
weighted_ids = fc.weighted_sparse_column(ids, "weights")
one_hot = fc.one_hot_column(weighted_ids)
features = {
- 'ids': constant_op.constant([['marlo', 'unknown', 'omar']]),
- 'weights': constant_op.constant([[2., 4., 6.]])
+ "ids": constant_op.constant([["marlo", "unknown", "omar"]]),
+ "weights": constant_op.constant([[2., 4., 6.]])
}
one_hot_tensor = feature_column_ops.input_from_feature_columns(
- features, [one_hot])
+ features, [one_hot])
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
@@ -349,11 +348,9 @@ class FeatureColumnTest(test.TestCase):
def testMissingValueInOneHotColumnForSparseColumnWithKeys(self):
ids = fc.sparse_column_with_keys("ids", ["marlo", "omar", "stringer"])
one_hot = fc.one_hot_column(ids)
- features = {
- 'ids': constant_op.constant([['marlo', 'unknown', 'omar']])
- }
+ features = {"ids": constant_op.constant([["marlo", "unknown", "omar"]])}
one_hot_tensor = feature_column_ops.input_from_feature_columns(
- features, [one_hot])
+ features, [one_hot])
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
@@ -379,8 +376,7 @@ class FeatureColumnTest(test.TestCase):
self.assertEqual(d4.default_value, None)
self.assertEqual(d4.is_sparse, True)
# Default value is a list but dimension is None.
- with self.assertRaisesRegexp(ValueError,
- "Only scalar default value.*"):
+ with self.assertRaisesRegexp(ValueError, "Only scalar default value.*"):
fc._real_valued_var_len_column("g5", default_value=[2., 3.])
def testRealValuedVarLenColumnDtypes(self):
@@ -390,18 +386,19 @@ class FeatureColumnTest(test.TestCase):
"rvc": parsing_ops.VarLenFeature(dtype=dtypes.float32)
}, rvc.config)
- rvc = fc._real_valued_var_len_column("rvc", default_value=0,
- is_sparse=False)
- self.assertDictEqual(
- {
- "rvc": parsing_ops.FixedLenSequenceFeature(shape=[],
- dtype=dtypes.float32,
- allow_missing=True,
- default_value=0.0)
- }, rvc.config)
-
- rvc = fc._real_valued_var_len_column("rvc", dtype=dtypes.int32,
- default_value=0, is_sparse=True)
+ rvc = fc._real_valued_var_len_column(
+ "rvc", default_value=0, is_sparse=False)
+ self.assertDictEqual({
+ "rvc":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[],
+ dtype=dtypes.float32,
+ allow_missing=True,
+ default_value=0.0)
+ }, rvc.config)
+
+ rvc = fc._real_valued_var_len_column(
+ "rvc", dtype=dtypes.int32, default_value=0, is_sparse=True)
self.assertDictEqual(
{
"rvc": parsing_ops.VarLenFeature(dtype=dtypes.int32)
@@ -409,8 +406,8 @@ class FeatureColumnTest(test.TestCase):
with self.assertRaisesRegexp(TypeError,
"dtype must be convertible to float"):
- fc._real_valued_var_len_column("rvc", dtype=dtypes.string,
- default_value="", is_sparse=True)
+ fc._real_valued_var_len_column(
+ "rvc", dtype=dtypes.string, default_value="", is_sparse=True)
def testRealValuedColumn(self):
a = fc.real_valued_column("aaa")
@@ -504,13 +501,13 @@ class FeatureColumnTest(test.TestCase):
for output_rank in range(1, 3 + len(dimensions)):
with variable_scope.variable_scope("output_rank_{}".format(output_rank)):
real_valued_output = real_valued_column._to_dnn_input_layer(
- constant_op.constant(
- real_valued_input, dtype=dtypes.float32),
+ constant_op.constant(real_valued_input, dtype=dtypes.float32),
output_rank=output_rank)
with self.test_session() as sess:
real_valued_eval = sess.run(real_valued_output)
- expected_shape = (input_shape[:output_rank - 1] +
- [np.prod(input_shape[output_rank - 1:])])
+ expected_shape = (
+ input_shape[:output_rank - 1] +
+ [np.prod(input_shape[output_rank - 1:])])
self.assertEquals(expected_shape, list(real_valued_eval.shape))
def testRealValuedColumnDensification(self):
@@ -520,8 +517,7 @@ class FeatureColumnTest(test.TestCase):
"sparse_real_valued1", is_sparse=True)
sparse_tensor = sparse_tensor_lib.SparseTensor(
values=[2.0, 5.0], indices=[[0, 0], [2, 0]], dense_shape=[3, 1])
- with self.assertRaisesRegexp(
- ValueError, "Set is_sparse to False"):
+ with self.assertRaisesRegexp(ValueError, "Set is_sparse to False"):
real_valued_column._to_dnn_input_layer(sparse_tensor)
def testRealValuedColumnDeepCopy(self):
@@ -549,9 +545,8 @@ class FeatureColumnTest(test.TestCase):
def testBucketizedColumnRequiresRealValuedColumnDimension(self):
with self.assertRaisesRegexp(
TypeError, "source_column must be an instance of _RealValuedColumn.*"):
- fc.bucketized_column(fc._real_valued_var_len_column("bbb",
- is_sparse=True),
- [0])
+ fc.bucketized_column(
+ fc._real_valued_var_len_column("bbb", is_sparse=True), [0])
def testBucketizedColumnRequiresSortedBuckets(self):
with self.assertRaisesRegexp(ValueError,
@@ -654,20 +649,14 @@ class FeatureColumnTest(test.TestCase):
def testRealValuedColumnDtypes(self):
rvc = fc.real_valued_column("rvc")
- self.assertDictEqual(
- {
- "rvc": parsing_ops.FixedLenFeature(
- [1], dtype=dtypes.float32)
- },
- rvc.config)
+ self.assertDictEqual({
+ "rvc": parsing_ops.FixedLenFeature([1], dtype=dtypes.float32)
+ }, rvc.config)
rvc = fc.real_valued_column("rvc", dtype=dtypes.int32)
- self.assertDictEqual(
- {
- "rvc": parsing_ops.FixedLenFeature(
- [1], dtype=dtypes.int32)
- },
- rvc.config)
+ self.assertDictEqual({
+ "rvc": parsing_ops.FixedLenFeature([1], dtype=dtypes.int32)
+ }, rvc.config)
with self.assertRaisesRegexp(ValueError,
"dtype must be convertible to float"):
@@ -702,8 +691,9 @@ class FeatureColumnTest(test.TestCase):
batch_size = 4
dense_scalar_input = [1, 2, 3, 4]
sparse_column = fc.sparse_column_with_integerized_feature("values", 10)
- features = {"values":
- constant_op.constant(dense_scalar_input, dtype=dtypes.int64)}
+ features = {
+ "values": constant_op.constant(dense_scalar_input, dtype=dtypes.int64)
+ }
sparse_column.insert_transformed_feature(features)
sparse_output = features[sparse_column]
expected_shape = [batch_size, 1]
@@ -731,8 +721,7 @@ class FeatureColumnTest(test.TestCase):
def testSparseColumnKeysDeepCopy(self):
"""Tests deepcopy of sparse_column_with_keys."""
- column = fc.sparse_column_with_keys(
- "a", keys=["key0", "key1", "key2"])
+ column = fc.sparse_column_with_keys("a", keys=["key0", "key1", "key2"])
self.assertEqual("a", column.name)
column_copy = copy.deepcopy(column)
self.assertEqual("a", column_copy.name)
@@ -785,8 +774,9 @@ class FeatureColumnTest(test.TestCase):
a = fc.sparse_column_with_hash_bucket("cross_aaa", hash_bucket_size=100)
b = fc.sparse_column_with_hash_bucket("cross_bbb", hash_bucket_size=100)
cross_col = fc.crossed_column(set([a, b]), hash_bucket_size=10000)
- one_hot_col = fc.one_hot_column(fc.sparse_column_with_hash_bucket(
- "sparse_column_for_one_hot", hash_bucket_size=100))
+ one_hot_col = fc.one_hot_column(
+ fc.sparse_column_with_hash_bucket(
+ "sparse_column_for_one_hot", hash_bucket_size=100))
scattered_embedding_col = fc.scattered_embedding_column(
"scattered_embedding_column", size=100, dimension=10, hash_key=1)
feature_columns = set([
@@ -809,17 +799,13 @@ class FeatureColumnTest(test.TestCase):
"str_id_weights_column":
parsing_ops.VarLenFeature(dtypes.float32),
"real_valued_column1":
- parsing_ops.FixedLenFeature(
- [1], dtype=dtypes.float32),
+ parsing_ops.FixedLenFeature([1], dtype=dtypes.float32),
"real_valued_column2":
- parsing_ops.FixedLenFeature(
- [5], dtype=dtypes.float32),
+ parsing_ops.FixedLenFeature([5], dtype=dtypes.float32),
"real_valued_column_for_bucketization1":
- parsing_ops.FixedLenFeature(
- [1], dtype=dtypes.float32),
+ parsing_ops.FixedLenFeature([1], dtype=dtypes.float32),
"real_valued_column_for_bucketization2":
- parsing_ops.FixedLenFeature(
- [4], dtype=dtypes.float32),
+ parsing_ops.FixedLenFeature([4], dtype=dtypes.float32),
"cross_aaa":
parsing_ops.VarLenFeature(dtypes.string),
"cross_bbb":
@@ -849,11 +835,14 @@ class FeatureColumnTest(test.TestCase):
real_valued_col0 = fc._real_valued_var_len_column(
"real_valued_column0", is_sparse=True)
real_valued_col1 = fc._real_valued_var_len_column(
- "real_valued_column1", dtype=dtypes.int64, default_value=0,
+ "real_valued_column1",
+ dtype=dtypes.int64,
+ default_value=0,
is_sparse=False)
feature_columns = set([real_valued_col0, real_valued_col1])
expected_config = {
- "real_valued_column0": parsing_ops.VarLenFeature(dtype=dtypes.float32),
+ "real_valued_column0":
+ parsing_ops.VarLenFeature(dtype=dtypes.float32),
"real_valued_column1":
parsing_ops.FixedLenSequenceFeature(
[], dtype=dtypes.int64, allow_missing=True, default_value=0),
@@ -874,7 +863,9 @@ class FeatureColumnTest(test.TestCase):
real_valued_col5 = fc._real_valued_var_len_column(
"real_valued_column5", default_value=2, is_sparse=True)
real_valued_col6 = fc._real_valued_var_len_column(
- "real_valued_column6", dtype=dtypes.int64, default_value=1,
+ "real_valued_column6",
+ dtype=dtypes.int64,
+ default_value=1,
is_sparse=False)
feature_columns = [
real_valued_col1, real_valued_col2, real_valued_col3, real_valued_col4,
@@ -902,8 +893,7 @@ class FeatureColumnTest(test.TestCase):
parsing_ops.VarLenFeature(dtype=dtypes.float32),
"real_valued_column6":
parsing_ops.FixedLenSequenceFeature(
- [], dtype=dtypes.int64, allow_missing=True,
- default_value=1)
+ [], dtype=dtypes.int64, allow_missing=True, default_value=1)
},
config)
@@ -1104,8 +1094,8 @@ class FeatureColumnTest(test.TestCase):
# This will initialize the crossed column weights from provided checkpoint
# and return a [4, 1] tensor which is same as weights variable. Since we
# won't modify weights, this should be same as 'saved_col_weights'.
- _, col_weights, _ = (feature_column_ops.weighted_sum_from_feature_columns(
- {
+ _, col_weights, _ = (
+ feature_column_ops.weighted_sum_from_feature_columns({
sparse_col_1.name: input_tensor,
sparse_col_2.name: input_tensor
}, [crossed_col_initialized], 1))
diff --git a/tensorflow/contrib/learn/python/learn/datasets/__init__.py b/tensorflow/contrib/learn/python/learn/datasets/__init__.py
index a3521b4109..7240b0de14 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/datasets/__init__.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Dataset utilities and synthetic/reference datasets."""
from __future__ import absolute_import
@@ -46,11 +45,12 @@ DATASETS = {
# List of all synthetic datasets
SYNTHETIC = {
- # All of these will return ['data', 'target'] -> base.Dataset
- 'circles': synthetic.circles,
- 'spirals': synthetic.spirals
+ # All of these will return ['data', 'target'] -> base.Dataset
+ 'circles': synthetic.circles,
+ 'spirals': synthetic.spirals
}
+
def load_dataset(name, size='small', test_with_fake_data=False):
"""Loads dataset by name.
@@ -83,23 +83,28 @@ def make_dataset(name, n_samples=100, noise=None, seed=42, *args, **kwargs):
seed: int or None, seed for noise
Returns:
- Shuffled features and labels for given synthetic dataset of type `base.Dataset`
+ Shuffled features and labels for given synthetic dataset of type
+ `base.Dataset`
Raises:
ValueError: Raised if `name` not found
Note:
- - This is a generic synthetic data generator - individual generators might have more parameters!
+ - This is a generic synthetic data generator - individual generators might
+ have more parameters!
See documentation for individual parameters
- - Note that the `noise` parameter uses `numpy.random.normal` and depends on `numpy`'s seed
+ - Note that the `noise` parameter uses `numpy.random.normal` and depends on
+ `numpy`'s seed
TODO:
- Support multiclass datasets
- - Need shuffling routine. Currently synthetic datasets are reshuffled to avoid train/test correlation,
+ - Need shuffling routine. Currently synthetic datasets are reshuffled to
+ avoid train/test correlation,
but that hurts reprodusability
"""
# seed = kwargs.pop('seed', None)
if name not in SYNTHETIC:
raise ValueError('Synthetic dataset not found or not implemeted: %s' % name)
else:
- return SYNTHETIC[name](n_samples=n_samples, noise=noise, seed=seed, *args, **kwargs)
+ return SYNTHETIC[name](
+ n_samples=n_samples, noise=noise, seed=seed, *args, **kwargs)
diff --git a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py
index 907dc0f3df..649996c49c 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py
+++ b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Synthetic dataset generators."""
from __future__ import absolute_import
@@ -23,18 +22,27 @@ import numpy as np
from tensorflow.contrib.learn.python.learn.datasets.base import Dataset
-def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args, **kwargs):
+
+def circles(n_samples=100,
+ noise=None,
+ seed=None,
+ factor=0.8,
+ n_classes=2,
+ *args,
+ **kwargs):
"""Create circles separated by some value
Args:
n_samples: int, number of datapoints to generate
noise: float or None, standard deviation of the Gaussian noise added
seed: int or None, seed for the noise
- factor: float, size factor of the inner circles with respect to the outer ones
+ factor: float, size factor of the inner circles with respect to the outer
+ ones
n_classes: int, number of classes to generate
Returns:
- Shuffled features and labels for 'circles' synthetic dataset of type `base.Dataset`
+ Shuffled features and labels for 'circles' synthetic dataset of type
+ `base.Dataset`
Note:
The multi-class support might not work as expected if `noise` is enabled
@@ -54,7 +62,7 @@ def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args
if seed is not None:
np.random.seed(seed)
# Algo: 1) Generate initial circle, 2) For ever class generate a smaller radius circle
- linspace = np.linspace(0, 2*np.pi, n_samples // n_classes)
+ linspace = np.linspace(0, 2 * np.pi, n_samples // n_classes)
circ_x = np.empty(0, dtype=np.int32)
circ_y = np.empty(0, dtype=np.int32)
base_cos = np.cos(linspace)
@@ -66,12 +74,12 @@ def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args
circ_y = np.append(circ_y, base_sin)
base_cos *= factor
base_sin *= factor
- y = np.append(y, label*np.ones(n_samples // n_classes, dtype=np.int32))
+ y = np.append(y, label * np.ones(n_samples // n_classes, dtype=np.int32))
# Add more points if n_samples is not divisible by n_classes (unbalanced!)
extras = n_samples % n_classes
- circ_x = np.append(circ_x, np.cos(np.random.rand(extras)*2*np.pi))
- circ_y = np.append(circ_y, np.sin(np.random.rand(extras)*2*np.pi))
+ circ_x = np.append(circ_x, np.cos(np.random.rand(extras) * 2 * np.pi))
+ circ_y = np.append(circ_y, np.sin(np.random.rand(extras) * 2 * np.pi))
y = np.append(y, np.zeros(extras, dtype=np.int32))
# Reshape the features/labels
@@ -85,10 +93,13 @@ def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args
return Dataset(data=X[indices], target=y[indices])
-def spirals(n_samples=100, noise=None, seed=None,
- mode = 'archimedes',
- n_loops = 2,
- *args, **kwargs):
+def spirals(n_samples=100,
+ noise=None,
+ seed=None,
+ mode='archimedes',
+ n_loops=2,
+ *args,
+ **kwargs):
"""Create spirals
Currently only binary classification is supported for spiral generation
@@ -104,7 +115,8 @@ def spirals(n_samples=100, noise=None, seed=None,
'fermat': a spiral with branch distances decreasing (sqrt)
Returns:
- Shuffled features and labels for 'spirals' synthetic dataset of type `base.Dataset`
+ Shuffled features and labels for 'spirals' synthetic dataset of type
+ `base.Dataset`
Raises:
ValueError: If the generation `mode` is not valid
@@ -112,34 +124,35 @@ def spirals(n_samples=100, noise=None, seed=None,
TODO:
- Generation of unbalanced data
"""
- n_classes = 2 # I am not sure how to make it multiclass
+ n_classes = 2 # I am not sure how to make it multiclass
_modes = {
- 'archimedes': _archimedes_spiral,
- 'bernoulli': _bernoulli_spiral,
- 'fermat': _fermat_spiral
+ 'archimedes': _archimedes_spiral,
+ 'bernoulli': _bernoulli_spiral,
+ 'fermat': _fermat_spiral
}
if mode is None or mode not in _modes:
- raise ValueError("Cannot generate spiral with mode %s"%mode)
+ raise ValueError('Cannot generate spiral with mode %s' % mode)
if seed is not None:
np.random.seed(seed)
- linspace = np.linspace(0, 2*n_loops*np.pi, n_samples // n_classes)
+ linspace = np.linspace(0, 2 * n_loops * np.pi, n_samples // n_classes)
spir_x = np.empty(0, dtype=np.int32)
spir_y = np.empty(0, dtype=np.int32)
y = np.empty(0, dtype=np.int32)
for label in range(n_classes):
- base_cos, base_sin = _modes[mode](linspace, label*np.pi, *args, **kwargs)
+ base_cos, base_sin = _modes[mode](linspace, label * np.pi, *args, **kwargs)
spir_x = np.append(spir_x, base_cos)
spir_y = np.append(spir_y, base_sin)
- y = np.append(y, label*np.ones(n_samples // n_classes, dtype=np.int32))
+ y = np.append(y, label * np.ones(n_samples // n_classes, dtype=np.int32))
# Add more points if n_samples is not divisible by n_classes (unbalanced!)
extras = n_samples % n_classes
if extras > 0:
- x_exrta, y_extra = _modes[mode](np.random.rand(extras)*2*np.pi, *args, **kwargs)
+ x_exrta, y_extra = _modes[mode](np.random.rand(extras) * 2 * np.pi, *args,
+ **kwargs)
spir_x = np.append(spir_x, x_extra)
spir_y = np.append(spir_y, y_extra)
y = np.append(y, np.zeros(extras, dtype=np.int32))
@@ -162,7 +175,8 @@ def _archimedes_spiral(theta, theta_offset=0., *args, **kwargs):
theta: array-like, angles from polar coordinates to be converted
theta_offset: float, angle offset in radians (2*pi = 0)
"""
- x, y = theta*np.cos(theta + theta_offset), theta*np.sin(theta + theta_offset)
+ x, y = theta * np.cos(theta + theta_offset), theta * np.sin(
+ theta + theta_offset)
x_norm = np.max(np.abs(x))
y_norm = np.max(np.abs(y))
x, y = x / x_norm, y / y_norm
@@ -181,7 +195,8 @@ def _bernoulli_spiral(theta, theta_offset=0., *args, **kwargs):
"""
exp_scale = kwargs.pop('exp_scale', 0.1)
- x, y = np.exp(exp_scale*theta)*np.cos(theta + theta_offset), np.exp(exp_scale*theta)*np.sin(theta + theta_offset)
+ x, y = np.exp(exp_scale * theta) * np.cos(theta + theta_offset), np.exp(
+ exp_scale * theta) * np.sin(theta + theta_offset)
x_norm = np.max(np.abs(x))
y_norm = np.max(np.abs(y))
x, y = x / x_norm, y / y_norm
@@ -195,7 +210,8 @@ def _fermat_spiral(theta, theta_offset=0., *args, **kwargs):
theta: array-like, angles from polar coordinates to be converted
theta_offset: float, angle offset in radians (2*pi = 0)
"""
- x, y = np.sqrt(theta)*np.cos(theta + theta_offset), np.sqrt(theta)*np.sin(theta + theta_offset)
+ x, y = np.sqrt(theta) * np.cos(theta + theta_offset), np.sqrt(theta) * np.sin(
+ theta + theta_offset)
x_norm = np.max(np.abs(x))
y_norm = np.max(np.abs(y))
x, y = x / x_norm, y / y_norm
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 50c74add86..8d59fe66d9 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Base Estimator class."""
from __future__ import absolute_import
@@ -76,7 +75,6 @@ from tensorflow.python.util import compat
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
-
AS_ITERABLE_DATE = '2016-09-15'
AS_ITERABLE_INSTRUCTIONS = (
'The default behavior of predict() is changing. The default value for\n'
@@ -223,8 +221,11 @@ def _get_replica_device_setter(config):
if config.num_ps_replicas > 0:
return device_setter.replica_device_setter(
- ps_tasks=config.num_ps_replicas, worker_device=worker_device,
- merge_devices=True, ps_ops=ps_ops, cluster=config.cluster_spec)
+ ps_tasks=config.num_ps_replicas,
+ worker_device=worker_device,
+ merge_devices=True,
+ ps_ops=ps_ops,
+ cluster=config.cluster_spec)
else:
return None
@@ -284,10 +285,10 @@ def _make_metrics_ops(metrics, features, labels, predictions):
raise ValueError('Invalid metric for {}. It returned a tuple with '
'len {}, expected 2.'.format(name, len(name)))
if not isinstance(predictions, dict):
- raise ValueError(
- 'Metrics passed provide (name, prediction), '
- 'but predictions are not dict. '
- 'Metrics: %s, Predictions: %s.' % (metrics, predictions))
+ raise ValueError('Metrics passed provide (name, prediction), '
+ 'but predictions are not dict. '
+ 'Metrics: %s, Predictions: %s.' % (metrics,
+ predictions))
# Here are two options: labels are single Tensor or a dict.
if isinstance(labels, dict) and name[1] in labels:
# If labels are dict and the prediction name is in it, apply metric.
@@ -298,10 +299,10 @@ def _make_metrics_ops(metrics, features, labels, predictions):
else:
# Single head metrics.
if isinstance(predictions, dict):
- raise ValueError(
- 'Metrics passed provide only name, no prediction, '
- 'but predictions are dict. '
- 'Metrics: %s, Labels: %s.' % (metrics, labels_tensor_or_dict))
+ raise ValueError('Metrics passed provide only name, no prediction, '
+ 'but predictions are dict. '
+ 'Metrics: %s, Labels: %s.' % (metrics,
+ labels_tensor_or_dict))
result[name] = metric(predictions, labels_tensor_or_dict)
return result
@@ -369,9 +370,8 @@ def _write_dict_to_summary(output_dir, dictionary, current_global_step):
logging.info(
'Summary for np.ndarray is not visible in Tensorboard by default. '
'Consider using a Tensorboard plugin for visualization (see '
- 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md ' # pylint:disable=line-too-long
- 'for more information).'
- )
+ 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md'
+ ' for more information).')
else:
logging.warn(
'Skipping summary for %s, must be a float, np.float32, np.int64, '
@@ -385,8 +385,8 @@ GraphRewriteSpec = collections.namedtuple('GraphRewriteSpec',
['tags', 'transforms'])
-class BaseEstimator(
- sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable):
+class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
+ trainable.Trainable):
"""Abstract BaseEstimator class to train and evaluate TensorFlow models.
Users should not instantiate or subclass this class. Instead, use an
@@ -428,7 +428,7 @@ class BaseEstimator(
# necessary.
# pylint: disable=g-doc-exception
raise ValueError(
- "model_dir are set both in constructor and RunConfig, but with "
+ 'model_dir are set both in constructor and RunConfig, but with '
"different values. In constructor: '{}', in RunConfig: "
"'{}' ".format(model_dir, self._config.model_dir))
# pylint: enable=g-doc-exception
@@ -457,12 +457,16 @@ class BaseEstimator(
# TODO(wicke): make RunConfig immutable, and then return it without a copy.
return copy.deepcopy(self._config)
- @deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
- ('y', None), ('batch_size', None)
- )
- def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
- monitors=None, max_steps=None):
+ @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS,
+ ('x', None), ('y', None), ('batch_size', None))
+ def fit(self,
+ x=None,
+ y=None,
+ input_fn=None,
+ steps=None,
+ batch_size=None,
+ monitors=None,
+ max_steps=None):
# pylint: disable=g-doc-args,g-doc-return-or-yield
"""See `Trainable`.
@@ -494,13 +498,15 @@ class BaseEstimator(
logging.info('Loss for final step: %s.', loss)
return self
- @deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
- ('y', None), ('batch_size', None)
- )
- def partial_fit(
- self, x=None, y=None, input_fn=None, steps=1, batch_size=None,
- monitors=None):
+ @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS,
+ ('x', None), ('y', None), ('batch_size', None))
+ def partial_fit(self,
+ x=None,
+ y=None,
+ input_fn=None,
+ steps=1,
+ batch_size=None,
+ monitors=None):
"""Incremental fit on a batch of samples.
This method is expected to be called several times consecutively
@@ -536,13 +542,16 @@ class BaseEstimator(
"""
logging.warning('The current implementation of partial_fit is not optimized'
' for use in a loop. Consider using fit() instead.')
- return self.fit(x=x, y=y, input_fn=input_fn, steps=steps,
- batch_size=batch_size, monitors=monitors)
+ return self.fit(
+ x=x,
+ y=y,
+ input_fn=input_fn,
+ steps=steps,
+ batch_size=batch_size,
+ monitors=monitors)
- @deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
- ('y', None), ('batch_size', None)
- )
+ @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS,
+ ('x', None), ('y', None), ('batch_size', None))
def evaluate(self,
x=None,
y=None,
@@ -584,13 +593,14 @@ class BaseEstimator(
eval_results.update({'global_step': global_step})
return eval_results
- @deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
- ('batch_size', None), ('as_iterable', True)
- )
- def predict(
- self, x=None, input_fn=None, batch_size=None, outputs=None,
- as_iterable=True):
+ @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS,
+ ('x', None), ('batch_size', None), ('as_iterable', True))
+ def predict(self,
+ x=None,
+ input_fn=None,
+ batch_size=None,
+ outputs=None,
+ as_iterable=True):
"""Returns predictions for given features.
Args:
@@ -651,16 +661,17 @@ class BaseEstimator(
return self._model_dir
@deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.')
- def export(self,
- export_dir,
- input_fn=export._default_input_fn, # pylint: disable=protected-access
- input_feature_key=None,
- use_deprecated_input_fn=True,
- signature_fn=None,
- prediction_key=None,
- default_batch_size=1,
- exports_to_keep=None,
- checkpoint_path=None):
+ def export(
+ self,
+ export_dir,
+ input_fn=export._default_input_fn, # pylint: disable=protected-access
+ input_feature_key=None,
+ use_deprecated_input_fn=True,
+ signature_fn=None,
+ prediction_key=None,
+ default_batch_size=1,
+ exports_to_keep=None,
+ checkpoint_path=None):
"""Exports inference graph into given dir.
Args:
@@ -798,8 +809,8 @@ class BaseEstimator(
logging.debug('Setting feature info to %s.', str(self._features_info))
if labels is not None:
if self._labels_info is not None:
- logging.debug('Given labels: %s, required signatures: %s.',
- str(labels), str(self._labels_info))
+ logging.debug('Given labels: %s, required signatures: %s.', str(labels),
+ str(self._labels_info))
if not tensor_signature.tensors_compatible(labels, self._labels_info):
raise ValueError('Labels are incompatible with given information. '
'Given labels: %s, required signatures: %s.' %
@@ -850,13 +861,13 @@ class BaseEstimator(
if not checkpoint_path:
latest_path = saver.latest_checkpoint(self._model_dir)
if not latest_path:
- raise NotFittedError("Couldn't find trained model at %s."
- % self._model_dir)
+ raise NotFittedError(
+ "Couldn't find trained model at %s." % self._model_dir)
checkpoint_path = latest_path
# Setup output directory.
- eval_dir = os.path.join(self._model_dir, 'eval' if not name else
- 'eval_' + name)
+ eval_dir = os.path.join(self._model_dir, 'eval'
+ if not name else 'eval_' + name)
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
@@ -879,8 +890,7 @@ class BaseEstimator(
'Use steps=None if intended.')
if steps:
hooks.append(
- evaluation.StopAfterNEvalsHook(
- steps, log_progress=log_progress))
+ evaluation.StopAfterNEvalsHook(steps, log_progress=log_progress))
global_step_key = 'global_step'
while global_step_key in eval_dict:
@@ -916,8 +926,8 @@ class BaseEstimator(
# Check that model has been trained.
checkpoint_path = saver.latest_checkpoint(self._model_dir)
if not checkpoint_path:
- raise NotFittedError("Couldn't find trained model at %s."
- % self._model_dir)
+ raise NotFittedError(
+ "Couldn't find trained model at %s." % self._model_dir)
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
@@ -979,7 +989,8 @@ class BaseEstimator(
existing_keys = predictions.keys()
predictions = {
key: value
- for key, value in six.iteritems(predictions) if key in outputs
+ for key, value in six.iteritems(predictions)
+ if key in outputs
}
if not predictions:
raise ValueError('Expected to run at least one output from %s, '
@@ -1045,8 +1056,7 @@ class BaseEstimator(
chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
save_checkpoint_secs=0, # Saving is handled by a hook.
save_summaries_steps=self._config.save_summary_steps,
- config=self._session_config
- ) as mon_sess:
+ config=self._session_config) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
@@ -1137,8 +1147,7 @@ class Estimator(BaseEstimator):
if params is not None and 'params' not in model_fn_args:
raise ValueError('Estimator\'s model_fn (%s) does not have a params '
'argument, but params (%s) were passed to the '
- 'Estimator\'s constructor.' %
- (model_fn, params))
+ 'Estimator\'s constructor.' % (model_fn, params))
if params is None and 'params' in model_fn_args:
logging.warning('Estimator\'s model_fn (%s) includes params '
'argument, but params are not passed to Estimator.',
@@ -1192,8 +1201,9 @@ class Estimator(BaseEstimator):
# Custom metrics should overwrite defaults.
if metrics:
- model_fn_ops.eval_metric_ops.update(_make_metrics_ops(
- metrics, features, labels, model_fn_ops.predictions))
+ model_fn_ops.eval_metric_ops.update(
+ _make_metrics_ops(metrics, features, labels,
+ model_fn_ops.predictions))
return model_fn_ops
@@ -1238,8 +1248,8 @@ class Estimator(BaseEstimator):
Raises:
ValueError: if `metrics` don't match `labels`.
"""
- model_fn_ops = self._call_model_fn(
- features, labels, model_fn_lib.ModeKeys.EVAL, metrics)
+ model_fn_ops = self._call_model_fn(features, labels,
+ model_fn_lib.ModeKeys.EVAL, metrics)
if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops:
model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = (
@@ -1263,14 +1273,16 @@ class Estimator(BaseEstimator):
self._labels_info)
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER)
- def export_savedmodel(
- self, export_dir_base, serving_input_fn,
- default_output_alternative_key=None,
- assets_extra=None,
- as_text=False,
- checkpoint_path=None,
- graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),),
- strip_default_attrs=False):
+ def export_savedmodel(self,
+ export_dir_base,
+ serving_input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ checkpoint_path=None,
+ graph_rewrite_specs=(GraphRewriteSpec(
+ (tag_constants.SERVING,), ()),),
+ strip_default_attrs=False):
# pylint: disable=line-too-long
"""Exports inference graph as a SavedModel into given dir.
@@ -1297,7 +1309,8 @@ class Estimator(BaseEstimator):
default serving tag ("serve") and no rewriting.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
- [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ [Stripping Default-Valued
+ Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
The string path to the exported directory.
@@ -1313,8 +1326,8 @@ class Estimator(BaseEstimator):
# Locate the latest checkpoint
checkpoint_path = saver.latest_checkpoint(self._model_dir)
if not checkpoint_path:
- raise NotFittedError("Couldn't find trained model at %s."
- % self._model_dir)
+ raise NotFittedError(
+ "Couldn't find trained model at %s." % self._model_dir)
export_dir = saved_model_export_utils.get_timestamped_export_dir(
export_dir_base)
@@ -1348,10 +1361,10 @@ class Estimator(BaseEstimator):
saved_model_export_utils.get_output_alternatives(
model_fn_ops, default_output_alternative_key))
- init_op = control_flow_ops.group(
- variables.local_variables_initializer(),
- resources.initialize_resources(resources.shared_resources()),
- lookup_ops.tables_initializer())
+ init_op = control_flow_ops.group(variables.local_variables_initializer(),
+ resources.initialize_resources(
+ resources.shared_resources()),
+ lookup_ops.tables_initializer())
# Build the SignatureDefs from all pairs of input and output alternatives
signature_def_map = saved_model_export_utils.build_all_signature_defs(
@@ -1381,10 +1394,10 @@ class Estimator(BaseEstimator):
# TODO(soergel): switch to main_op or otherwise update when dust settles
builder.add_meta_graph_and_variables(
- session, untransformed_tags,
+ session,
+ untransformed_tags,
signature_def_map=signature_def_map,
- assets_collection=ops.get_collection(
- ops.GraphKeys.ASSET_FILEPATHS),
+ assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS),
legacy_init_op=init_op,
strip_default_attrs=strip_default_attrs)
@@ -1395,12 +1408,16 @@ class Estimator(BaseEstimator):
if graph_rewrite_specs[1:]:
# Prepare the input_names and output_names needed for the
# meta_graph_transform call below.
- input_names = [tensor.name
- for input_dict in input_alternatives.values()
- for tensor in input_dict.values()]
- output_names = [tensor.name
- for output_alternative in output_alternatives.values()
- for tensor in output_alternative[1].values()]
+ input_names = [
+ tensor.name
+ for input_dict in input_alternatives.values()
+ for tensor in input_dict.values()
+ ]
+ output_names = [
+ tensor.name
+ for output_alternative in output_alternatives.values()
+ for tensor in output_alternative[1].values()
+ ]
# Write the additional MetaGraphDefs
for graph_rewrite_spec in graph_rewrite_specs[1:]:
@@ -1419,11 +1436,11 @@ class Estimator(BaseEstimator):
# Add the extra assets
if assets_extra:
- assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir),
- compat.as_bytes('assets.extra'))
+ assets_extra_path = os.path.join(
+ compat.as_bytes(temp_export_dir), compat.as_bytes('assets.extra'))
for dest_relative, source in assets_extra.items():
- dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
- compat.as_bytes(dest_relative))
+ dest_absolute = os.path.join(
+ compat.as_bytes(assets_extra_path), compat.as_bytes(dest_relative))
dest_path = os.path.dirname(dest_absolute)
gfile.MakeDirs(dest_path)
gfile.Copy(source, dest_absolute)
@@ -1443,25 +1460,36 @@ class SKCompat(sklearn.BaseEstimator):
def fit(self, x, y, batch_size=128, steps=None, max_steps=None,
monitors=None):
- input_fn, feed_fn = _get_input_fn(x, y, input_fn=None, feed_fn=None,
- batch_size=batch_size, shuffle=True,
- epochs=None)
+ input_fn, feed_fn = _get_input_fn(
+ x,
+ y,
+ input_fn=None,
+ feed_fn=None,
+ batch_size=batch_size,
+ shuffle=True,
+ epochs=None)
all_monitors = []
if feed_fn:
all_monitors = [basic_session_run_hooks.FeedFnHook(feed_fn)]
if monitors:
all_monitors.extend(monitors)
- self._estimator.fit(input_fn=input_fn,
- steps=steps,
- max_steps=max_steps,
- monitors=all_monitors)
+ self._estimator.fit(
+ input_fn=input_fn,
+ steps=steps,
+ max_steps=max_steps,
+ monitors=all_monitors)
return self
def score(self, x, y, batch_size=128, steps=None, metrics=None, name=None):
- input_fn, feed_fn = _get_input_fn(x, y, input_fn=None,
- feed_fn=None, batch_size=batch_size,
- shuffle=False, epochs=1)
+ input_fn, feed_fn = _get_input_fn(
+ x,
+ y,
+ input_fn=None,
+ feed_fn=None,
+ batch_size=batch_size,
+ shuffle=False,
+ epochs=1)
if metrics is not None and not isinstance(metrics, dict):
raise ValueError('Metrics argument should be None or dict. '
'Got %s.' % metrics)
@@ -1477,8 +1505,13 @@ class SKCompat(sklearn.BaseEstimator):
def predict(self, x, batch_size=128, outputs=None):
input_fn, feed_fn = _get_input_fn(
- x, None, input_fn=None, feed_fn=None, batch_size=batch_size,
- shuffle=False, epochs=1)
+ x,
+ None,
+ input_fn=None,
+ feed_fn=None,
+ batch_size=batch_size,
+ shuffle=False,
+ epochs=1)
results = list(
self._estimator._infer_model(
input_fn=input_fn,
@@ -1489,7 +1522,6 @@ class SKCompat(sklearn.BaseEstimator):
if not isinstance(results[0], dict):
return np.concatenate([output for output in results], axis=0)
return {
- key: np.concatenate(
- [output[key] for output in results], axis=0)
+ key: np.concatenate([output[key] for output in results], axis=0)
for key in results[0]
}
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index 5f682838b7..d81a534b79 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -111,8 +111,8 @@ def boston_eval_fn():
constant_op.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM])
labels = array_ops.reshape(
constant_op.constant(boston.target), [n_examples, 1])
- return array_ops.concat([features, features], 0), array_ops.concat(
- [labels, labels], 0)
+ return array_ops.concat([features, features],
+ 0), array_ops.concat([labels, labels], 0)
def extract(data, key):
@@ -147,7 +147,10 @@ def linear_model_fn(features, labels, mode):
(_, features), = features.items()
prediction, loss = (models.linear_regression_zero_init(features, labels))
train_op = optimizers.optimize_loss(
- loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1)
+ loss,
+ training_util.get_global_step(),
+ optimizer='Adagrad',
+ learning_rate=0.1)
return prediction, loss, train_op
@@ -157,7 +160,10 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode):
model_fn.ModeKeys.INFER)
prediction, loss = (models.linear_regression_zero_init(features, labels))
train_op = optimizers.optimize_loss(
- loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1)
+ loss,
+ training_util.get_global_step(),
+ optimizer='Adagrad',
+ learning_rate=0.1)
return model_fn.ModelFnOps(
mode=mode, predictions=prediction, loss=loss, train_op=train_op)
@@ -168,7 +174,10 @@ def logistic_model_no_mode_fn(features, labels):
labels = array_ops.one_hot(labels, 3, 1, 0)
prediction, loss = (models.logistic_regression_zero_init(features, labels))
train_op = optimizers.optimize_loss(
- loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1)
+ loss,
+ training_util.get_global_step(),
+ optimizer='Adagrad',
+ learning_rate=0.1)
return {
'class': math_ops.argmax(prediction, 1),
'prob': prediction
@@ -184,14 +193,12 @@ def _build_estimator_for_export_tests(tmpdir):
def _input_fn():
iris = base.load_iris()
return {
- 'feature': constant_op.constant(
- iris.data, dtype=dtypes.float32)
+ 'feature': constant_op.constant(iris.data, dtype=dtypes.float32)
}, constant_op.constant(
iris.target, shape=[150], dtype=dtypes.int32)
feature_columns = [
- feature_column_lib.real_valued_column(
- 'feature', dimension=4)
+ feature_column_lib.real_valued_column('feature', dimension=4)
]
est = linear.LinearRegressor(feature_columns)
@@ -291,8 +298,8 @@ class CheckCallsMonitor(monitors_lib.BaseMonitor):
self.begin_calls == self.expect_calls)
-def _model_fn_ops(
- expected_features, expected_labels, actual_features, actual_labels, mode):
+def _model_fn_ops(expected_features, expected_labels, actual_features,
+ actual_labels, mode):
assert_ops = tuple([
check_ops.assert_equal(
expected_features[k], actual_features[k], name='assert_%s' % k)
@@ -310,11 +317,11 @@ def _model_fn_ops(
def _make_input_fn(features, labels):
+
def _input_fn():
- return {
- k: constant_op.constant(v)
- for k, v in six.iteritems(features)
- }, constant_op.constant(labels)
+ return {k: constant_op.constant(v)
+ for k, v in six.iteritems(features)}, constant_op.constant(labels)
+
return _input_fn
@@ -369,11 +376,13 @@ class EstimatorModelFnTest(test.TestCase):
self.assertEqual(expected_params, params)
self.assertTrue(config.i_am_test)
return _model_fn_ops(features, labels, arg0, arg1, mode)
+
partial_model_fn = functools.partial(
_model_fn, foo=expected_foo, bar=expected_bar)
est = estimator.Estimator(
- model_fn=partial_model_fn, params=expected_params,
+ model_fn=partial_model_fn,
+ params=expected_params,
config=expected_config)
self.assertEqual(0, model_fn_call_count[0])
est.fit(input_fn=_make_input_fn(features, labels), steps=1)
@@ -382,7 +391,12 @@ class EstimatorModelFnTest(test.TestCase):
def testModelFnWithModelDir(self):
expected_param = {'some_param': 'some_value'}
expected_model_dir = tempfile.mkdtemp()
- def _argument_checker(features, labels, mode, params, config=None,
+
+ def _argument_checker(features,
+ labels,
+ mode,
+ params,
+ config=None,
model_dir=None):
_, _, _ = features, labels, config
self.assertEqual(model_fn.ModeKeys.TRAIN, mode)
@@ -390,9 +404,11 @@ class EstimatorModelFnTest(test.TestCase):
self.assertEqual(model_dir, expected_model_dir)
return (constant_op.constant(0.), constant_op.constant(0.),
training_util.get_global_step().assign_add(1))
- est = estimator.Estimator(model_fn=_argument_checker,
- params=expected_param,
- model_dir=expected_model_dir)
+
+ est = estimator.Estimator(
+ model_fn=_argument_checker,
+ params=expected_param,
+ model_dir=expected_model_dir)
est.fit(input_fn=boston_input_fn, steps=1)
def testInvalidModelFn_no_train_op(self):
@@ -447,8 +463,7 @@ class EstimatorModelFnTest(test.TestCase):
est.predict(input_fn=boston_input_fn)
with self.assertRaisesRegexp(ValueError, 'Missing prediction'):
est.predict(
- input_fn=functools.partial(
- boston_input_fn, num_epochs=1),
+ input_fn=functools.partial(boston_input_fn, num_epochs=1),
as_iterable=True)
def testModelFnScaffoldInTraining(self):
@@ -498,15 +513,17 @@ class EstimatorModelFnTest(test.TestCase):
self.assertTrue(self.mock_saver.restore.called)
est.predict(input_fn=input_fn)
self.assertTrue(self.mock_saver.restore.called)
+
def serving_input_fn():
- serialized_tf_example = array_ops.placeholder(dtype=dtypes.string,
- shape=[None],
- name='input_example_tensor')
+ serialized_tf_example = array_ops.placeholder(
+ dtype=dtypes.string, shape=[None], name='input_example_tensor')
features, labels = input_fn()
- return input_fn_utils.InputFnOps(
- features, labels, {'examples': serialized_tf_example})
+ return input_fn_utils.InputFnOps(features, labels, {
+ 'examples': serialized_tf_example
+ })
- est.export_savedmodel(os.path.join(est.model_dir, 'export'), serving_input_fn)
+ est.export_savedmodel(
+ os.path.join(est.model_dir, 'export'), serving_input_fn)
self.assertTrue(self.mock_saver.restore.called)
@@ -550,33 +567,28 @@ class EstimatorTest(test.TestCase):
def testRunConfigModelDir(self):
config = run_config.RunConfig(model_dir='test_dir')
- est = estimator.Estimator(model_fn=linear_model_fn,
- config=config)
+ est = estimator.Estimator(model_fn=linear_model_fn, config=config)
self.assertEqual('test_dir', est.config.model_dir)
self.assertEqual('test_dir', est.model_dir)
def testModelDirAndRunConfigModelDir(self):
config = run_config.RunConfig(model_dir='test_dir')
- est = estimator.Estimator(model_fn=linear_model_fn,
- config=config,
- model_dir='test_dir')
+ est = estimator.Estimator(
+ model_fn=linear_model_fn, config=config, model_dir='test_dir')
self.assertEqual('test_dir', est.config.model_dir)
with self.assertRaisesRegexp(
- ValueError,
- 'model_dir are set both in constructor and RunConfig, '
+ ValueError, 'model_dir are set both in constructor and RunConfig, '
'but with different'):
- estimator.Estimator(model_fn=linear_model_fn,
- config=config,
- model_dir='different_dir')
+ estimator.Estimator(
+ model_fn=linear_model_fn, config=config, model_dir='different_dir')
def testModelDirIsCopiedToRunConfig(self):
config = run_config.RunConfig()
self.assertIsNone(config.model_dir)
- est = estimator.Estimator(model_fn=linear_model_fn,
- model_dir='test_dir',
- config=config)
+ est = estimator.Estimator(
+ model_fn=linear_model_fn, model_dir='test_dir', config=config)
self.assertEqual('test_dir', est.config.model_dir)
self.assertEqual('test_dir', est.model_dir)
@@ -656,25 +668,27 @@ class EstimatorTest(test.TestCase):
boston = base.load_boston()
output_dir = tempfile.mkdtemp()
est = estimator.SKCompat(
- estimator.Estimator(
- model_fn=linear_model_fn, model_dir=output_dir))
+ estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir))
float64_labels = boston.target.astype(np.float64)
est.fit(x=boston.data, y=float64_labels, steps=50)
scores = est.score(
x=boston.data,
y=float64_labels,
- metrics={'MSE': metric_ops.streaming_mean_squared_error})
+ metrics={
+ 'MSE': metric_ops.streaming_mean_squared_error
+ })
del est
# Create another estimator object with the same output dir.
est2 = estimator.SKCompat(
- estimator.Estimator(
- model_fn=linear_model_fn, model_dir=output_dir))
+ estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir))
# Check we can evaluate and predict.
scores2 = est2.score(
x=boston.data,
y=float64_labels,
- metrics={'MSE': metric_ops.streaming_mean_squared_error})
+ metrics={
+ 'MSE': metric_ops.streaming_mean_squared_error
+ })
self.assertAllClose(scores['MSE'], scores2['MSE'])
predictions = np.array(list(est2.predict(x=boston.data)))
other_score = _sklearn.mean_squared_error(predictions, float64_labels)
@@ -685,14 +699,15 @@ class EstimatorTest(test.TestCase):
scores3 = est2.score(
x=boston.data,
y=float64_labels,
- metrics={'MSE': metric_ops.streaming_mean_squared_error})
+ metrics={
+ 'MSE': metric_ops.streaming_mean_squared_error
+ })
self.assertLess(scores3['MSE'], scores['MSE'])
def test_checkpoint_contains_relative_paths(self):
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(
- model_dir=tmpdir,
- model_fn=linear_model_fn_with_model_fn_ops)
+ model_dir=tmpdir, model_fn=linear_model_fn_with_model_fn_ops)
est.fit(input_fn=boston_input_fn, steps=5)
checkpoint_file_content = file_io.read_file_to_string(
@@ -700,22 +715,20 @@ class EstimatorTest(test.TestCase):
ckpt = checkpoint_state_pb2.CheckpointState()
text_format.Merge(checkpoint_file_content, ckpt)
self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
- self.assertAllEqual(
- ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
+ self.assertAllEqual(['model.ckpt-1', 'model.ckpt-5'],
+ ckpt.all_model_checkpoint_paths)
def test_train_save_copy_reload(self):
tmpdir = tempfile.mkdtemp()
model_dir1 = os.path.join(tmpdir, 'model_dir1')
est1 = estimator.Estimator(
- model_dir=model_dir1,
- model_fn=linear_model_fn_with_model_fn_ops)
+ model_dir=model_dir1, model_fn=linear_model_fn_with_model_fn_ops)
est1.fit(input_fn=boston_input_fn, steps=5)
model_dir2 = os.path.join(tmpdir, 'model_dir2')
os.renames(model_dir1, model_dir2)
est2 = estimator.Estimator(
- model_dir=model_dir2,
- model_fn=linear_model_fn_with_model_fn_ops)
+ model_dir=model_dir2, model_fn=linear_model_fn_with_model_fn_ops)
self.assertEqual(5, est2.get_variable_value('global_step'))
est2.fit(input_fn=boston_input_fn, steps=5)
self.assertEqual(10, est2.get_variable_value('global_step'))
@@ -724,7 +737,9 @@ class EstimatorTest(test.TestCase):
boston = base.load_boston()
est = estimator.SKCompat(
estimator.Estimator(
- model_fn=linear_model_params_fn, params={'learning_rate': 0.01}))
+ model_fn=linear_model_params_fn, params={
+ 'learning_rate': 0.01
+ }))
est.fit(x=boston.data, y=boston.target, steps=100)
def testHooksNotChanged(self):
@@ -824,11 +839,13 @@ class EstimatorTest(test.TestCase):
def testMonitorsForFit(self):
est = estimator.Estimator(model_fn=linear_model_fn)
- est.fit(input_fn=boston_input_fn,
- steps=21,
- monitors=[CheckCallsMonitor(expect_calls=21)])
+ est.fit(
+ input_fn=boston_input_fn,
+ steps=21,
+ monitors=[CheckCallsMonitor(expect_calls=21)])
def testHooksForEvaluate(self):
+
class CheckCallHook(session_run_hook.SessionRunHook):
def __init__(self):
@@ -874,7 +891,9 @@ class EstimatorTest(test.TestCase):
est.evaluate(
input_fn=boston_input_fn,
steps=200,
- metrics={'MSE': _streaming_mean_squared_error_histogram})
+ metrics={
+ 'MSE': _streaming_mean_squared_error_histogram
+ })
events = util_test.latest_events(est.model_dir + '/eval')
output_values = {}
for e in events:
@@ -903,7 +922,9 @@ class EstimatorTest(test.TestCase):
est.evaluate(
input_fn=boston_input_fn,
steps=200,
- metrics={'PMT': _streaming_precition_mean_tensor})
+ metrics={
+ 'PMT': _streaming_precition_mean_tensor
+ })
events = util_test.latest_events(est.model_dir + '/eval')
output_values = {}
for e in events:
@@ -956,8 +977,8 @@ class EstimatorTest(test.TestCase):
self.assertTrue(
gfile.Exists(
os.path.join(
- compat.as_bytes(export_dir), compat.as_bytes(
- 'saved_model.pb'))))
+ compat.as_bytes(export_dir),
+ compat.as_bytes('saved_model.pb'))))
self.assertTrue(
gfile.Exists(
os.path.join(
@@ -1017,11 +1038,11 @@ class EstimatorTest(test.TestCase):
self.assertTrue('input_example_tensor' in graph_ops)
self.assertTrue('ParseExample/ParseExample' in graph_ops)
self.assertTrue('linear/linear/feature/matmul' in graph_ops)
- self.assertItemsEqual(
- ['bogus_lookup', 'feature'],
- [compat.as_str_any(x) for x in graph.get_collection(
- constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS)])
-
+ self.assertItemsEqual(['bogus_lookup', 'feature'], [
+ compat.as_str_any(x)
+ for x in graph.get_collection(
+ constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS)
+ ])
# cleanup
gfile.DeleteRecursively(tmpdir)
@@ -1039,8 +1060,8 @@ class EstimatorTest(test.TestCase):
self.assertTrue(
gfile.Exists(
os.path.join(
- compat.as_bytes(export_dir), compat.as_bytes(
- 'saved_model.pb'))))
+ compat.as_bytes(export_dir),
+ compat.as_bytes('saved_model.pb'))))
self.assertTrue(
gfile.Exists(
os.path.join(
@@ -1083,19 +1104,22 @@ class EstimatorTest(test.TestCase):
export_dir_base = os.path.join(
compat.as_bytes(tmpdir), compat.as_bytes('export'))
export_dir = est.export_savedmodel(
- export_dir_base, serving_input_fn, assets_extra=assets_extra,
+ export_dir_base,
+ serving_input_fn,
+ assets_extra=assets_extra,
graph_rewrite_specs=[
estimator.GraphRewriteSpec(['tag_1'], []),
estimator.GraphRewriteSpec(['tag_2', 'tag_3'],
- ['strip_unused_nodes'])])
+ ['strip_unused_nodes'])
+ ])
self.assertTrue(gfile.Exists(export_dir_base))
self.assertTrue(gfile.Exists(export_dir))
self.assertTrue(
gfile.Exists(
os.path.join(
- compat.as_bytes(export_dir), compat.as_bytes(
- 'saved_model.pb'))))
+ compat.as_bytes(export_dir),
+ compat.as_bytes('saved_model.pb'))))
self.assertTrue(
gfile.Exists(
os.path.join(
@@ -1208,18 +1232,15 @@ class InferRealValuedColumnsTest(test.TestCase):
self.assertEqual(1, len(feature_columns))
feature_column = feature_columns[0]
self.assertEqual('', feature_column.name)
- self.assertEqual(
- {
- '':
- parsing_ops.FixedLenFeature(
- shape=expected_shape, dtype=expected_dtype)
- },
- feature_column.config)
+ self.assertEqual({
+ '':
+ parsing_ops.FixedLenFeature(
+ shape=expected_shape, dtype=expected_dtype)
+ }, feature_column.config)
def testInt32Input(self):
feature_columns = estimator.infer_real_valued_columns_from_input(
- np.ones(
- shape=[7, 8], dtype=np.int32))
+ np.ones(shape=[7, 8], dtype=np.int32))
self._assert_single_feature_column([8], dtypes.int32, feature_columns)
def testInt32InputFn(self):
@@ -1229,8 +1250,7 @@ class InferRealValuedColumnsTest(test.TestCase):
def testInt64Input(self):
feature_columns = estimator.infer_real_valued_columns_from_input(
- np.ones(
- shape=[7, 8], dtype=np.int64))
+ np.ones(shape=[7, 8], dtype=np.int64))
self._assert_single_feature_column([8], dtypes.int64, feature_columns)
def testInt64InputFn(self):
@@ -1240,8 +1260,7 @@ class InferRealValuedColumnsTest(test.TestCase):
def testFloat32Input(self):
feature_columns = estimator.infer_real_valued_columns_from_input(
- np.ones(
- shape=[7, 8], dtype=np.float32))
+ np.ones(shape=[7, 8], dtype=np.float32))
self._assert_single_feature_column([8], dtypes.float32, feature_columns)
def testFloat32InputFn(self):
@@ -1251,8 +1270,7 @@ class InferRealValuedColumnsTest(test.TestCase):
def testFloat64Input(self):
feature_columns = estimator.infer_real_valued_columns_from_input(
- np.ones(
- shape=[7, 8], dtype=np.float64))
+ np.ones(shape=[7, 8], dtype=np.float64))
self._assert_single_feature_column([8], dtypes.float64, feature_columns)
def testFloat64InputFn(self):
@@ -1271,8 +1289,8 @@ class InferRealValuedColumnsTest(test.TestCase):
ValueError, 'on integer or non floating types are not supported'):
# pylint: disable=g-long-lambda
estimator.infer_real_valued_columns_from_input_fn(
- lambda: (constant_op.constant(False, shape=[7, 8], dtype=dtypes.bool),
- None))
+ lambda: (constant_op.constant(False, shape=[7, 8], dtype=dtypes.bool), None)
+ )
def testStringInput(self):
with self.assertRaisesRegexp(
@@ -1309,8 +1327,9 @@ class ReplicaDeviceSetterTest(test.TestCase):
def testVariablesAreOnPs(self):
tf_config = {'cluster': {run_config.TaskType.PS: ['fake_ps_0']}}
- with test.mock.patch.dict('os.environ',
- {'TF_CONFIG': json.dumps(tf_config)}):
+ with test.mock.patch.dict('os.environ', {
+ 'TF_CONFIG': json.dumps(tf_config)
+ }):
config = run_config.RunConfig()
with ops.device(estimator._get_replica_device_setter(config)):
@@ -1337,14 +1356,14 @@ class ReplicaDeviceSetterTest(test.TestCase):
def testMutableHashTableIsOnPs(self):
tf_config = {'cluster': {run_config.TaskType.PS: ['fake_ps_0']}}
- with test.mock.patch.dict('os.environ',
- {'TF_CONFIG': json.dumps(tf_config)}):
+ with test.mock.patch.dict('os.environ', {
+ 'TF_CONFIG': json.dumps(tf_config)
+ }):
config = run_config.RunConfig()
with ops.device(estimator._get_replica_device_setter(config)):
default_val = constant_op.constant([-1, -1], dtypes.int64)
- table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
- default_val)
+ table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val)
input_string = constant_op.constant(['brain', 'salad', 'tank'])
output = table.lookup(input_string)
self.assertDeviceEqual('/job:ps/task:0', table._table_ref.device)
@@ -1354,8 +1373,7 @@ class ReplicaDeviceSetterTest(test.TestCase):
with ops.device(
estimator._get_replica_device_setter(run_config.RunConfig())):
default_val = constant_op.constant([-1, -1], dtypes.int64)
- table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
- default_val)
+ table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val)
input_string = constant_op.constant(['brain', 'salad', 'tank'])
output = table.lookup(input_string)
self.assertDeviceEqual('', table._table_ref.device)
@@ -1371,8 +1389,9 @@ class ReplicaDeviceSetterTest(test.TestCase):
'index': 3
}
}
- with test.mock.patch.dict('os.environ',
- {'TF_CONFIG': json.dumps(tf_config)}):
+ with test.mock.patch.dict('os.environ', {
+ 'TF_CONFIG': json.dumps(tf_config)
+ }):
config = run_config.RunConfig()
with ops.device(estimator._get_replica_device_setter(config)):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py
index 8131e0fde6..2113fae394 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py
@@ -72,9 +72,11 @@ class FeatureEngineeringFunctionTest(test.TestCase):
# predictions = transformed_x (9)
self.assertEqual(9., prediction)
metrics = estimator.evaluate(
- input_fn=input_fn, steps=1,
- metrics={"label":
- metric_spec.MetricSpec(lambda predictions, labels: labels)})
+ input_fn=input_fn,
+ steps=1,
+ metrics={
+ "label": metric_spec.MetricSpec(lambda predictions, labels: labels)
+ })
# labels = transformed_y (99)
self.assertEqual(99., metrics["label"])
@@ -82,10 +84,10 @@ class FeatureEngineeringFunctionTest(test.TestCase):
def input_fn():
return {
- "x": constant_op.constant(["9."])
- }, {
- "y": constant_op.constant(["99."])
- }
+ "x": constant_op.constant(["9."])
+ }, {
+ "y": constant_op.constant(["99."])
+ }
def feature_engineering_fn(features, labels):
# Github #12205: raise a TypeError if called twice.
@@ -104,15 +106,17 @@ class FeatureEngineeringFunctionTest(test.TestCase):
return predictions, loss, update_global_step
estimator = estimator_lib.Estimator(
- model_fn=model_fn, feature_engineering_fn=feature_engineering_fn)
+ model_fn=model_fn, feature_engineering_fn=feature_engineering_fn)
estimator.fit(input_fn=input_fn, steps=1)
prediction = next(estimator.predict(input_fn=input_fn, as_iterable=True))
# predictions = transformed_x (9)
self.assertEqual(9., prediction)
metrics = estimator.evaluate(
- input_fn=input_fn, steps=1,
- metrics={"label":
- metric_spec.MetricSpec(lambda predictions, labels: labels)})
+ input_fn=input_fn,
+ steps=1,
+ metrics={
+ "label": metric_spec.MetricSpec(lambda predictions, labels: labels)
+ })
# labels = transformed_y (99)
self.assertEqual(99., metrics["label"])
@@ -150,12 +154,10 @@ class FeatureEngineeringFunctionTest(test.TestCase):
# predictions = x
prediction_with_fe_fn = next(
- estimator_with_fe_fn.predict(
- input_fn=input_fn, as_iterable=True))
+ estimator_with_fe_fn.predict(input_fn=input_fn, as_iterable=True))
self.assertEqual(9., prediction_with_fe_fn)
prediction_without_fe_fn = next(
- estimator_without_fe_fn.predict(
- input_fn=input_fn, as_iterable=True))
+ estimator_without_fe_fn.predict(input_fn=input_fn, as_iterable=True))
self.assertEqual(1., prediction_without_fe_fn)
diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py
index 3e0b1ad21a..0948dee7e2 100644
--- a/tensorflow/contrib/learn/python/learn/monitors.py
+++ b/tensorflow/contrib/learn/python/learn/monitors.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Monitors instrument the training process.
@@get_default_monitors
@@ -151,8 +150,8 @@ class BaseMonitor(object):
ValueError: if we've not begun an epoch, or `epoch` number does not match.
"""
if self._current_epoch != epoch:
- raise ValueError(
- "epoch_end expected %s but got %s.", self._current_epoch, epoch)
+ raise ValueError("epoch_end expected %s but got %s.", self._current_epoch,
+ epoch)
self._current_epoch = None
def step_begin(self, step):
@@ -171,8 +170,8 @@ class BaseMonitor(object):
ValueError: if we've already begun a step, or `step` < 0, or
`step` > `max_steps`.
"""
- if (step < 0) or (
- (self._max_steps is not None) and (step > self._max_steps)):
+ if (step < 0) or ((self._max_steps is not None) and
+ (step > self._max_steps)):
raise ValueError("Invalid step %s." % step)
self._current_step = step
return []
@@ -203,8 +202,8 @@ class BaseMonitor(object):
ValueError: if we've not begun a step, or `step` number does not match.
"""
if self._current_step != step:
- raise ValueError(
- "step_end expected %s but got %s.", self._current_step, step)
+ raise ValueError("step_end expected %s but got %s.", self._current_step,
+ step)
self._current_step = None
return False
@@ -253,6 +252,7 @@ class EveryN(BaseMonitor):
treatment.
"""
+
# TODO(ipolosukhin): Add also every n seconds.
def __init__(self, every_n_steps=100, first_n_steps=1):
@@ -475,8 +475,8 @@ class LoggingTrainable(EveryN):
super(LoggingTrainable, self).every_n_step_begin(step)
# Get a list of trainable variables at the beginning of every N steps.
# We cannot get this in __init__ because train_op has not been generated.
- trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES,
- scope=self._scope)
+ trainables = ops.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES, scope=self._scope)
self._names = {}
for var in trainables:
self._names[var.name] = var.value().name
@@ -561,12 +561,19 @@ class ValidationMonitor(EveryN):
provided.
"""
- def __init__(self, x=None, y=None, input_fn=None, batch_size=None,
+ def __init__(self,
+ x=None,
+ y=None,
+ input_fn=None,
+ batch_size=None,
eval_steps=None,
- every_n_steps=100, metrics=None, hooks=None,
+ every_n_steps=100,
+ metrics=None,
+ hooks=None,
early_stopping_rounds=None,
early_stopping_metric="loss",
- early_stopping_metric_minimize=True, name=None):
+ early_stopping_metric_minimize=True,
+ name=None):
"""Initializes a ValidationMonitor.
Args:
@@ -597,8 +604,8 @@ class ValidationMonitor(EveryN):
Raises:
ValueError: If both x and input_fn are provided.
"""
- super(ValidationMonitor, self).__init__(every_n_steps=every_n_steps,
- first_n_steps=-1)
+ super(ValidationMonitor, self).__init__(
+ every_n_steps=every_n_steps, first_n_steps=-1)
# TODO(mdan): Checks like this are already done by evaluate.
if x is None and input_fn is None:
raise ValueError("Either x or input_fn should be provided.")
@@ -654,20 +661,27 @@ class ValidationMonitor(EveryN):
def _evaluate_estimator(self):
if isinstance(self._estimator, core_estimator.Estimator):
- if any((x is not None for x in
- [self.x, self.y, self.batch_size, self.metrics])):
+ if any((x is not None
+ for x in [self.x, self.y, self.batch_size, self.metrics])):
raise ValueError(
"tf.estimator.Estimator does not support following "
"arguments: x, y, batch_size, metrics. Should set as `None` "
"in ValidationMonitor")
return self._estimator.evaluate(
- input_fn=self.input_fn, steps=self.eval_steps, hooks=self.hooks,
+ input_fn=self.input_fn,
+ steps=self.eval_steps,
+ hooks=self.hooks,
name=self.name)
else:
return self._estimator.evaluate(
- x=self.x, y=self.y, input_fn=self.input_fn,
- batch_size=self.batch_size, steps=self.eval_steps,
- metrics=self.metrics, hooks=self.hooks, name=self.name)
+ x=self.x,
+ y=self.y,
+ input_fn=self.input_fn,
+ batch_size=self.batch_size,
+ steps=self.eval_steps,
+ metrics=self.metrics,
+ hooks=self.hooks,
+ name=self.name)
def every_n_step_end(self, step, outputs):
super(ValidationMonitor, self).every_n_step_end(step, outputs)
@@ -700,8 +714,9 @@ class ValidationMonitor(EveryN):
# Early stopping logic.
if self.early_stopping_rounds is not None:
if self.early_stopping_metric not in validation_outputs:
- raise ValueError("Metric %s missing from outputs %s." % (
- self.early_stopping_metric, set(validation_outputs.keys())))
+ raise ValueError("Metric %s missing from outputs %s." %
+ (self.early_stopping_metric,
+ set(validation_outputs.keys())))
current_value = validation_outputs[self.early_stopping_metric]
if (self._best_value is None or (self.early_stopping_metric_minimize and
(current_value < self._best_value)) or
@@ -712,9 +727,9 @@ class ValidationMonitor(EveryN):
self._best_value_step = step
stop_now = (step - self._best_value_step >= self.early_stopping_rounds)
if stop_now:
- logging.info("Stopping. Best step: {} with {} = {}."
- .format(self._best_value_step,
- self.early_stopping_metric, self._best_value))
+ logging.info("Stopping. Best step: {} with {} = {}.".format(
+ self._best_value_step, self.early_stopping_metric,
+ self._best_value))
self._early_stopped = True
return True
return False
@@ -763,8 +778,11 @@ class CaptureVariable(EveryN):
self._var_values[step] = _extract_output(outputs, self._var_name)
-def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100,
- output_dir=None, summary_writer=None):
+def get_default_monitors(loss_op=None,
+ summary_op=None,
+ save_summary_steps=100,
+ output_dir=None,
+ summary_writer=None):
"""Returns a default set of typically-used monitors.
Args:
@@ -782,9 +800,12 @@ def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100,
if loss_op is not None:
monitors.append(PrintTensor(tensor_names={"loss": loss_op.name}))
if summary_op is not None:
- monitors.append(SummarySaver(summary_op, save_steps=save_summary_steps,
- output_dir=output_dir,
- summary_writer=summary_writer))
+ monitors.append(
+ SummarySaver(
+ summary_op,
+ save_steps=save_summary_steps,
+ output_dir=output_dir,
+ summary_writer=summary_writer))
return monitors
@@ -794,8 +815,10 @@ class GraphDump(BaseMonitor):
Note, this is very expensive, prefer `PrintTensor` in production.
"""
- IGNORE_OPS = ["Const", "Assign", "Identity", "Placeholder",
- "RandomUniform", "Cast", "RestoreSlice"]
+ IGNORE_OPS = [
+ "Const", "Assign", "Identity", "Placeholder", "RandomUniform", "Cast",
+ "RestoreSlice"
+ ]
def __init__(self, ignore_ops=None):
"""Initializes GraphDump monitor.
@@ -881,8 +904,8 @@ class ExportMonitor(EveryN):
"""Monitor that exports Estimator every N steps."""
@deprecation.deprecated("2017-03-25",
- "ExportMonitor is deprecated. Please pass an "
- "ExportStrategy to Experiment instead.")
+ "ExportMonitor is deprecated. Please pass an "
+ "ExportStrategy to Experiment instead.")
def __init__(self,
every_n_steps,
export_dir,
@@ -1088,8 +1111,7 @@ class CheckpointSaver(BaseMonitor):
class StepCounter(EveryN):
"""Steps per second monitor."""
- def __init__(self, every_n_steps=100, output_dir=None,
- summary_writer=None):
+ def __init__(self, every_n_steps=100, output_dir=None, summary_writer=None):
super(StepCounter, self).__init__(every_n_steps=every_n_steps)
self._summary_tag = "global_step/sec"
self._last_reported_step = None
@@ -1101,7 +1123,8 @@ class StepCounter(EveryN):
def set_estimator(self, estimator):
super(StepCounter, self).set_estimator(estimator)
if self._summary_writer is None:
- self._summary_writer = core_summary.FileWriterCache.get(estimator.model_dir)
+ self._summary_writer = core_summary.FileWriterCache.get(
+ estimator.model_dir)
def every_n_step_end(self, current_step, outputs):
current_time = time.time()
@@ -1109,8 +1132,9 @@ class StepCounter(EveryN):
added_steps = current_step - self._last_reported_step
elapsed_time = current_time - self._last_reported_time
steps_per_sec = added_steps / elapsed_time
- summary = Summary(value=[Summary.Value(tag=self._summary_tag,
- simple_value=steps_per_sec)])
+ summary = Summary(value=[
+ Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
+ ])
self._summary_writer.add_summary(summary, current_step)
self._last_reported_step = current_step
self._last_reported_time = current_time
diff --git a/tensorflow/contrib/learn/python/learn/utils/export_test.py b/tensorflow/contrib/learn/python/learn/utils/export_test.py
index 95070ada3b..9bfb1fc952 100644
--- a/tensorflow/contrib/learn/python/learn/utils/export_test.py
+++ b/tensorflow/contrib/learn/python/learn/utils/export_test.py
@@ -50,6 +50,7 @@ def _training_input_fn():
class ExportTest(test.TestCase):
+
def _get_default_signature(self, export_meta_filename):
""" Gets the default signature from the export.meta file. """
with session.Session():
@@ -69,18 +70,18 @@ class ExportTest(test.TestCase):
# Only the written checkpoints are exported.
self.assertTrue(
saver.checkpoint_exists(os.path.join(export_dir, '00000001', 'export')),
- 'Exported checkpoint expected but not found: %s' %
- os.path.join(export_dir, '00000001', 'export'))
+ 'Exported checkpoint expected but not found: %s' % os.path.join(
+ export_dir, '00000001', 'export'))
self.assertTrue(
saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')),
- 'Exported checkpoint expected but not found: %s' %
- os.path.join(export_dir, '00000010', 'export'))
+ 'Exported checkpoint expected but not found: %s' % os.path.join(
+ export_dir, '00000010', 'export'))
self.assertEquals(
six.b(os.path.join(export_dir, '00000010')),
export_monitor.last_export_dir)
# Validate the signature
signature = self._get_default_signature(
- os.path.join(export_dir, '00000010', 'export.meta'))
+ os.path.join(export_dir, '00000010', 'export.meta'))
self.assertTrue(signature.HasField(expected_signature))
def testExportMonitor_EstimatorProvidesSignature(self):
@@ -116,8 +117,7 @@ class ExportTest(test.TestCase):
def _serving_input_fn():
return {
_X_KEY:
- random_ops.random_uniform(
- shape=(1,), minval=0.0, maxval=1000.0)
+ random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0)
}, None
input_feature_key = 'my_example_key'
@@ -160,8 +160,7 @@ class ExportTest(test.TestCase):
input_feature_key:
None,
_X_KEY:
- random_ops.random_uniform(
- shape=(1,), minval=0.0, maxval=1000.0)
+ random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0)
}, None
monitor = learn.monitors.ExportMonitor(
@@ -182,8 +181,7 @@ class ExportTest(test.TestCase):
def _serving_input_fn():
return {
input_feature_key:
- array_ops.placeholder(
- dtype=dtypes.string, shape=(1,))
+ array_ops.placeholder(dtype=dtypes.string, shape=(1,))
}, None
monitor = learn.monitors.ExportMonitor(
@@ -204,11 +202,9 @@ class ExportTest(test.TestCase):
def _serving_input_fn():
return {
input_feature_key:
- array_ops.placeholder(
- dtype=dtypes.string, shape=(1,)),
+ array_ops.placeholder(dtype=dtypes.string, shape=(1,)),
_X_KEY:
- random_ops.random_uniform(
- shape=(1,), minval=0.0, maxval=1000.0)
+ random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0)
}, None
export_dir = os.path.join(tempfile.mkdtemp(), 'export')
@@ -227,8 +223,8 @@ class ExportTest(test.TestCase):
def _regression_signature(examples, unused_features, predictions):
signatures = {}
- signatures['regression'] = (exporter.regression_signature(examples,
- predictions))
+ signatures['regression'] = (
+ exporter.regression_signature(examples, predictions))
return signatures['regression'], signatures
random.seed(42)
@@ -248,10 +244,10 @@ class ExportTest(test.TestCase):
with self.assertRaises(errors.NotFoundError):
saver.checkpoint_exists(os.path.join(export_dir, '00000000', 'export'))
self.assertTrue(
- saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')))
+ saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')))
# Validate the signature
signature = self._get_default_signature(
- os.path.join(export_dir, '00000010', 'export.meta'))
+ os.path.join(export_dir, '00000010', 'export.meta'))
self.assertTrue(signature.HasField('regression_signature'))
diff --git a/tensorflow/contrib/learn/python/learn/utils/gc_test.py b/tensorflow/contrib/learn/python/learn/utils/gc_test.py
index 76cfd88e1d..e7d091e18a 100644
--- a/tensorflow/contrib/learn/python/learn/utils/gc_test.py
+++ b/tensorflow/contrib/learn/python/learn/utils/gc_test.py
@@ -34,12 +34,13 @@ def _create_parser(base_dir):
# create a simple parser that pulls the export_version from the directory.
def parser(path):
# Modify the path object for RegEx match for Windows Paths
- if os.name == 'nt':
- match = re.match("^" + compat.as_str_any(base_dir).replace('\\','/') + "/(\\d+)$",
- compat.as_str_any(path.path).replace('\\','/'))
+ if os.name == "nt":
+ match = re.match(
+ "^" + compat.as_str_any(base_dir).replace("\\", "/") + "/(\\d+)$",
+ compat.as_str_any(path.path).replace("\\", "/"))
else:
match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$",
- compat.as_str_any(path.path))
+ compat.as_str_any(path.path))
if not match:
return None
return path._replace(export_version=int(match.group(1)))
@@ -63,7 +64,9 @@ class GcTest(test_util.TensorFlowTestCase):
def testModExportVersion(self):
paths = [
- gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
+ gc.Path("/foo", 4),
+ gc.Path("/foo", 5),
+ gc.Path("/foo", 6),
gc.Path("/foo", 9)
]
mod = gc.mod_export_version(2)
@@ -73,14 +76,21 @@ class GcTest(test_util.TensorFlowTestCase):
def testOneOfEveryNExportVersions(self):
paths = [
- gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3),
- gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7),
- gc.Path("/foo", 8), gc.Path("/foo", 33)
+ gc.Path("/foo", 0),
+ gc.Path("/foo", 1),
+ gc.Path("/foo", 3),
+ gc.Path("/foo", 5),
+ gc.Path("/foo", 6),
+ gc.Path("/foo", 7),
+ gc.Path("/foo", 8),
+ gc.Path("/foo", 33)
]
one_of = gc.one_of_every_n_export_versions(3)
self.assertEqual(
one_of(paths), [
- gc.Path("/foo", 3), gc.Path("/foo", 6), gc.Path("/foo", 8),
+ gc.Path("/foo", 3),
+ gc.Path("/foo", 6),
+ gc.Path("/foo", 8),
gc.Path("/foo", 33)
])
@@ -98,13 +108,19 @@ class GcTest(test_util.TensorFlowTestCase):
f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
self.assertEqual(
f(paths), [
- gc.Path("/foo", 0), gc.Path("/foo", 3), gc.Path("/foo", 6),
- gc.Path("/foo", 7), gc.Path("/foo", 8), gc.Path("/foo", 9)
+ gc.Path("/foo", 0),
+ gc.Path("/foo", 3),
+ gc.Path("/foo", 6),
+ gc.Path("/foo", 7),
+ gc.Path("/foo", 8),
+ gc.Path("/foo", 9)
])
def testNegation(self):
paths = [
- gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
+ gc.Path("/foo", 4),
+ gc.Path("/foo", 5),
+ gc.Path("/foo", 6),
gc.Path("/foo", 9)
]
mod = gc.negation(gc.mod_export_version(2))
@@ -121,8 +137,7 @@ class GcTest(test_util.TensorFlowTestCase):
gfile.MakeDirs(os.path.join(base_dir, "ignore"))
self.assertEqual(
- gc.get_paths(base_dir, _create_parser(base_dir)),
- [
+ gc.get_paths(base_dir, _create_parser(base_dir)), [
gc.Path(os.path.join(base_dir, "0"), 0),
gc.Path(os.path.join(base_dir, "1"), 1),
gc.Path(os.path.join(base_dir, "2"), 2)
@@ -131,10 +146,10 @@ class GcTest(test_util.TensorFlowTestCase):
def testMixedStrTypes(self):
temp_dir = compat.as_bytes(test.get_temp_dir())
- for sub_dir in ['str', b'bytes', u'unicode']:
+ for sub_dir in ["str", b"bytes", u"unicode"]:
base_dir = os.path.join(
- (temp_dir if isinstance(sub_dir, bytes) else temp_dir.decode()),
- sub_dir)
+ (temp_dir
+ if isinstance(sub_dir, bytes) else temp_dir.decode()), sub_dir)
self.assertFalse(gfile.Exists(base_dir))
gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42"))
gc.get_paths(base_dir, _create_parser(base_dir))
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 0b48ef4741..8338fde8ac 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -167,8 +167,6 @@ typedef struct {
} TfLiteLSTMParams;
typedef struct {
- int new_height;
- int new_width;
} TfLiteResizeBilinearParams;
typedef struct {
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index f993fd6a00..fc58978964 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -1504,7 +1504,7 @@ inline void QuantizedDepthwiseConvAccumRowGeneric(
<< "*\n"
<< "* If you would like to carry on with the slow code, compile\n"
<< "* with this preprocessor token defined:\n"
- << "* TFLITE_ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n"
+ << "* ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n"
<< "*\n"
<< "* The right thing to do, if you care about performance, is to add\n"
<< "* a new DepthwiseConv kernel to tfmini to cover your case.\n"
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
index 1cf30ecff9..bfdfba00f5 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.h
+++ b/tensorflow/contrib/lite/kernels/kernel_util.h
@@ -44,6 +44,22 @@ inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
return nullptr;
}
+// Determines whether tensor is constant.
+inline bool IsConstantTensor(TfLiteTensor* tensor) {
+ return tensor->allocation_type == kTfLiteMmapRo;
+}
+
+// Determines whether tensor is dynamic. Note that a tensor can be non-const and
+// not dynamic. This function specificially checks for a dynamic tensor.
+inline bool IsDynamicTensor(TfLiteTensor* tensor) {
+ return tensor->allocation_type == kTfLiteDynamic;
+}
+
+// Sets tensor to dynamic.
+inline void SetTensorToDynamic(TfLiteTensor* tensor) {
+ tensor->allocation_type = kTfLiteDynamic;
+}
+
// Calculates the multiplication factor for a quantized convolution (or
// quantized depthwise convolution) involving the given tensors. Returns an
// error if the scales of the tensors are not compatible.
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 569bf0fe8f..4003ed10df 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -51,17 +51,14 @@ struct PadContext {
// paddings data is present.
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
PadContext* op_context) {
- // TODO(nupurgarg): Our current implementations rely on the inputs being 4D.
- TF_LITE_ENSURE_EQ(context, op_context->dims, 4);
-
// Ensures the paddings array is dims x 2.
TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 0),
op_context->dims);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 1), 2);
// Determines the size of the output tensor.
- const TfLiteIntArray* input_size = op_context->input->dims;
- TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context->dims);
+ TfLiteIntArray* input_size = op_context->input->dims;
+ TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
const int32* paddings_data = GetTensorData<int32>(op_context->paddings);
for (int idx = 0; idx < op_context->dims; ++idx) {
@@ -85,11 +82,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
PadContext op_context(context, node);
TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
- // TODO(nupurgarg): Create wrapper functions for dynamic tensor logic.
+ // TODO(nupurgarg): Our current implementations rely on the inputs being 4D.
+ TF_LITE_ENSURE_EQ(context, op_context.dims, 4);
+
// Exit early if paddings is a non-const tensor. Set output tensor to
// dynamic so output size can be determined in Eval.
- if (op_context.paddings->allocation_type != kTfLiteMmapRo) {
- op_context.output->allocation_type = kTfLiteDynamic;
+ if (!IsConstantTensor(op_context.paddings)) {
+ SetTensorToDynamic(op_context.output);
return kTfLiteOk;
}
return ResizeOutputTensor(context, &op_context);
@@ -100,7 +99,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
PadContext op_context(context, node);
// Resize the output tensor if the output tensor is dynamic.
- if (op_context.output->allocation_type == kTfLiteDynamic) {
+ if (IsDynamicTensor(op_context.output)) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
TfLiteTensorRealloc(op_context.output->bytes, op_context.output);
}
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 303a10af03..415d984ad8 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -30,17 +30,6 @@ limitations under the License.
namespace tflite {
-namespace {
-inline const tflite::Model* VerifyAndGetModel(const void* buf, size_t len) {
- ::flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
- if (VerifyModelBuffer(verifier)) {
- return ::tflite::GetModel(buf);
- } else {
- return nullptr;
- }
-}
-} // namespace
-
const char* kEmptyTensorName = "";
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
@@ -82,7 +71,7 @@ FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file,
}
if (!allocation_->valid() || !CheckModelIdentifier()) return;
- model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes());
+ model_ = ::tflite::GetModel(allocation_->base());
}
bool FlatBufferModel::CheckModelIdentifier() const {
@@ -103,7 +92,7 @@ FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes,
allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter);
if (!allocation_->valid()) return;
- model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes());
+ model_ = ::tflite::GetModel(allocation_->base());
}
FlatBufferModel::FlatBufferModel(const Model* model,
@@ -476,6 +465,11 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_RESIZE_BILINEAR: {
+ auto* params = MallocPOD<TfLiteResizeBilinearParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_ResizeBilinearOptions()) {
+ }
+ builtin_data = reinterpret_cast<void*>(params);
break;
}
case BuiltinOperator_PAD: {
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index 5330c8f594..66f22fd66a 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
-#include <string>
#include "tensorflow/contrib/lite/model.h"
@@ -247,14 +246,6 @@ TEST(BasicFlatBufferModel, TestNullErrorReporter) {
ASSERT_NE(interpreter->Invoke(), kTfLiteOk);
}
-// Test what happens if we cannot bind any of the ops.
-TEST(BasicFlatBufferModel, TestBuildModelFromCorruptedData) {
- std::string corrupted_data = "123";
- auto model = FlatBufferModel::BuildFromBuffer(corrupted_data.c_str(),
- corrupted_data.length());
- ASSERT_FALSE(model);
-}
-
// Test that loading model directly from a Model flatbuffer works.
TEST(BasicFlatBufferModel, TestBuildFromModel) {
TestErrorReporter reporter;
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index c04a73a2bf..c04a73a2bf 100644..100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 041e248790..6fc7e5e3fd 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -160,6 +160,7 @@ cc_library(
],
deps = [
# Placeholder for internal file dependency.
+ "@protobuf_archive//:protobuf_headers",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
index 5961d30bf5..49cc1fc2aa 100644
--- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
@@ -158,9 +158,7 @@ std::size_t TransientArraySize(const Model& model, const string& array_name,
LOG(FATAL)
<< "A RNN state array, " << array_name << ", still does not "
<< "have a known data type after all graph transformations have "
- << "run. That's mostly a toco bug --- sorry. For now, you can "
- << "work around this issue by adding manually_create:true in the "
- << "--rnn_state description of this RNN state.";
+ << "run.";
}
}
LOG(FATAL) << "An array, " << array_name << ", still does not "
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 8004a1a37a..b97a4720a7 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -208,6 +208,7 @@ struct ParsedModelFlags {
Arg<bool> dump_graphviz_video = Arg<bool>(false);
Arg<bool> allow_nonexistent_arrays = Arg<bool>(false);
Arg<bool> allow_nonascii_arrays = Arg<bool>(false);
+ Arg<string> arrays_extra_info_file;
};
// Flags that describe the operation you would like to do (what conversion
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index 790b3443ce..4e2dec15a5 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -148,6 +148,12 @@ bool ParseModelFlagsFromCommandLineFlags(
"ranging from 32 to 127. This is disallowed by default so as to "
"catch common copy-and-paste issues where invisible unicode "
"characters are unwittingly added to these strings."),
+ Flag(
+ "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(),
+ parsed_flags.arrays_extra_info_file.default_value(),
+ "Path to an optional file containing a serialized ArraysExtraInfo "
+ "proto allowing to pass extra information about arrays not specified "
+ "in the input model file, such as extra MinMax information."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -327,9 +333,6 @@ void ReadModelFlagsFromCommandLineFlags(
CHECK(absl::SimpleAtoi(value, &size));
CHECK_GT(size, 0);
rnn_state_proto->set_size(size);
- } else if (key == "manually_create") {
- CHECK_EQ(absl::AsciiStrToLower(value), "true");
- rnn_state_proto->set_manually_create(true);
} else {
LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
}
@@ -368,6 +371,15 @@ void ReadModelFlagsFromCommandLineFlags(
parsed_model_flags.allow_nonascii_arrays.value());
model_flags->set_allow_nonexistent_arrays(
parsed_model_flags.allow_nonexistent_arrays.value());
+
+ if (parsed_model_flags.arrays_extra_info_file.specified()) {
+ string arrays_extra_info_file_contents;
+ port::file::GetContents(parsed_model_flags.arrays_extra_info_file.value(),
+ &arrays_extra_info_file_contents,
+ port::file::Defaults());
+ ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
+ model_flags->mutable_arrays_extra_info());
+ }
}
ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto
index 13fea29a07..e4b39b34e8 100644
--- a/tensorflow/contrib/lite/toco/model_flags.proto
+++ b/tensorflow/contrib/lite/toco/model_flags.proto
@@ -81,19 +81,26 @@ message RnnState {
optional string state_array = 1;
optional string back_edge_source_array = 2;
optional bool discardable = 5;
- // TODO(benoitjacob): drop the 'size' field. Should be redundant with
- // --input_shapes and shapes propagation.
+ // size allows to specify a 1-D shape for the RNN state array.
+ // Will be expanded with 1's to fit the model.
+ // TODO(benoitjacob): should allow a generic, explicit shape.
optional int32 size = 3;
- // TODO(benoitjacob): manually_create is a temporary hack:
- // due to discrepancies between the current toco dims tracking and
- // TensorFlow shapes, for some models we need to manually create RNN state
- // arrays with a specified shape.
- // Maybe we should actually implement back-edges as operators of their own,
- // which would remove the need for much special-casing, including here,
- // we could probably consistently let PropagateFixedSizes handle state
- // arrays.
- // TODO(benoitjacob): should really drop manually_create now.
- optional bool manually_create = 4;
+}
+
+// An ArraysExtraInfo message stores a collection of additional Information
+// about arrays in a model, complementing the information in the model itself.
+// It is intentionally a separate message so that it may be serialized and
+// passed separately from the model. See --arrays_extra_info_file.
+//
+// A typical use case is to manually specify MinMax for specific arrays in a
+// model that does not already contain such MinMax information.
+message ArraysExtraInfo {
+ message Entry {
+ optional string name = 1;
+ optional float min = 2;
+ optional float max = 3;
+ }
+ repeated Entry entries = 1;
}
// ModelFlags encodes properties of a model that, depending on the file
@@ -117,7 +124,7 @@ message RnnState {
// optional int32 input_dims = 11 [ default = 4];
// repeated int32 input_shape = 13;
//
-// Next ID to USE: 18.
+// Next ID to USE: 19.
message ModelFlags {
// Information about the input arrays, i.e. the arrays from which input
// activations will be read.
@@ -160,4 +167,8 @@ message ModelFlags {
// catch common copy-and-paste issues where invisible unicode
// characters are unwittingly added to these strings.
optional bool allow_nonascii_arrays = 17;
+
+ // If set, this ArraysExtraInfo allows to pass extra information about arrays
+ // not specified in the input model file, such as extra MinMax information.
+ optional ArraysExtraInfo arrays_extra_info = 18;
}
diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h
index 0572848cb5..4be3b5a0bf 100644
--- a/tensorflow/contrib/lite/toco/toco_port.h
+++ b/tensorflow/contrib/lite/toco/toco_port.h
@@ -19,6 +19,7 @@ limitations under the License.
// can build and use on google internal environments and on OSX.
#include <string>
+#include "google/protobuf/text_format.h"
#include "tensorflow/contrib/lite/toco/format_port.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/platform.h"
@@ -75,6 +76,26 @@ void CopyToBuffer(const ::Cord& src, char* dest);
#endif // PLATFORM_GOOGLE
void CopyToBuffer(const string& src, char* dest);
} // namespace port
+
+inline bool ParseFromStringOverload(const std::string& in,
+ TFLITE_PROTO_NS::Message* proto) {
+ return TFLITE_PROTO_NS::TextFormat::ParseFromString(in, proto);
+}
+
+template <typename Proto>
+bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents,
+ Proto* proto) {
+ if (proto->ParseFromString(input_file_contents)) {
+ return true;
+ }
+
+ if (ParseFromStringOverload(input_file_contents, proto)) {
+ return true;
+ }
+
+ return false;
+}
+
} // namespace toco
#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_PORT_H_
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 720c33777d..727df1cc76 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -193,6 +193,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
}
SetFinalDataTypeOnInputs(toco_flags, model);
+ UseArraysExtraInfo(model);
// Remove unused ops before performing any other optimizations. This is to
// stop optimizations from crossing the input/output boundaries. For example
@@ -232,6 +233,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
transformations.Add(new ResolveConstantConcatenation);
RunGraphTransformations(model, "general graph transformations",
transformations);
+
if (quantize_output) {
RunGraphTransformations(model, "pre-quantization graph transformations",
{new HardcodeMinMax, new DropFakeQuant});
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 6577bb7781..187c426a5b 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -961,7 +961,9 @@ void CheckModelCounts(const Model& model) {
void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
std::vector<int>* out_dims) {
CHECK(out_dims->empty());
- if (num_dims == 1) {
+ if (num_dims == 0) {
+ return;
+ } else if (num_dims == 1) {
CHECK_EQ(batch, 1);
*out_dims = {depth};
} else if (num_dims == 2) {
@@ -993,13 +995,13 @@ void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) {
if (array.has_shape()) {
num_dims = array.shape().dimensions_count();
}
- std::vector<int> dims;
- MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
CHECK(array.data_type == ArrayDataType::kFloat ||
array.data_type == ArrayDataType::kNone);
array.data_type = ArrayDataType::kFloat;
- if (!array.has_shape()) {
+ if (!array.has_shape() && num_dims >= 0) {
Shape* shape = array.mutable_shape();
+ std::vector<int> dims;
+ MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
*shape->mutable_dims() = dims;
}
}
@@ -1188,9 +1190,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
}
// Creation of the RNN state arrays
for (const auto& rnn_state : model->flags.rnn_states()) {
- if (!rnn_state.manually_create()) {
- continue;
- }
CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
model);
}
@@ -1204,6 +1203,9 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
model->flags.set_allow_nonexistent_arrays(
model_flags.allow_nonexistent_arrays());
+
+ CHECK(!model->flags.has_arrays_extra_info());
+ *model->flags.mutable_arrays_extra_info() = model_flags.arrays_extra_info();
}
void CheckIsReadyForQuantization(const Model& model) {
@@ -1715,4 +1717,15 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) {
}
}
+void UseArraysExtraInfo(Model* model) {
+ for (const auto& entry : model->flags.arrays_extra_info().entries()) {
+ QCHECK(model->HasArray(entry.name()))
+ << "ArraysExtraInfo refers to non-existent array name: "
+ << entry.name();
+ auto& minmax = model->GetArray(entry.name()).GetOrCreateMinMax();
+ minmax.min = entry.min();
+ minmax.max = entry.max();
+ }
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index 5986d63649..2ac51c7e5b 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -23,7 +23,6 @@ limitations under the License.
#include <string>
#include <vector>
-#include "google/protobuf/text_format.h"
#include "tensorflow/core/platform/logging.h"
#if TOCO_SUPPORT_PORTABLE_PROTOS
#include "third_party/protobuf/src/google/protobuf/text_format.h"
@@ -84,25 +83,6 @@ void DumpGraphvizVideoFrame(const Model& model);
void LogDump(int log_level, const string& message, const Model& model);
void LogSummary(int log_level, const string& message, const Model& model);
-inline bool ParseFromStringOverload(const std::string& in,
- TFLITE_PROTO_NS::Message* proto) {
- return TFLITE_PROTO_NS::TextFormat::ParseFromString(in, proto);
-}
-
-template <typename Proto>
-bool ParseFromStringEitherTextOrBinary(const std::string& input_file_contents,
- Proto* proto) {
- if (proto->ParseFromString(input_file_contents)) {
- return true;
- }
-
- if (ParseFromStringOverload(input_file_contents, proto)) {
- return true;
- }
-
- return false;
-}
-
// TODO(b/36075966): Clean up when dims superseded by array shape.
void ExtendShape(Shape* shape, int new_shape_size);
@@ -298,6 +278,8 @@ void CheckFinalDataTypesSatisfied(const Model& model);
ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type);
+void UseArraysExtraInfo(Model* model);
+
} // namespace toco
#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD
index 20df905270..1bffcfb987 100644
--- a/tensorflow/contrib/lite/tools/BUILD
+++ b/tensorflow/contrib/lite/tools/BUILD
@@ -93,3 +93,28 @@ filegroup(
),
visibility = ["//tensorflow:__subpackages__"],
)
+
+cc_library(
+ name = "verifier",
+ srcs = ["verifier.cc"],
+ hdrs = ["verifier.h"],
+ deps = [
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ ],
+)
+
+cc_test(
+ name = "verifier_test",
+ size = "small",
+ srcs = ["verifier_test.cc"],
+ deps = [
+ ":verifier",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
diff --git a/tensorflow/contrib/lite/tools/verifier.cc b/tensorflow/contrib/lite/tools/verifier.cc
new file mode 100644
index 0000000000..95a0895379
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/verifier.cc
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/verifier.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+
+namespace {
+
+const Model* VerifyFlatbufferAndGetModel(const void* buf, size_t len) {
+ ::flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
+ if (VerifyModelBuffer(verifier)) {
+ return ::tflite::GetModel(buf);
+ } else {
+ return nullptr;
+ }
+}
+
+} // namespace
+
+bool Verify(const void* buf, size_t len) {
+ const Model* model = VerifyFlatbufferAndGetModel(buf, len);
+ if (model == nullptr) {
+ return false;
+ }
+
+ return model->version() == TFLITE_SCHEMA_VERSION;
+}
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/verifier.h b/tensorflow/contrib/lite/tools/verifier.h
new file mode 100644
index 0000000000..03e1f22b7e
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/verifier.h
@@ -0,0 +1,31 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_VERIFIER_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_VERIFIER_H_
+
+#include <stdio.h>
+
+namespace tflite {
+
+// Verifies the integrity of a Tensorflow Lite flatbuffer model file.
+// Currently, it verifies:
+// * The file is following a legit flatbuffer schema.
+// * The model is in supported version.
+bool Verify(const void* buf, size_t len);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_VERIFIER_H_
diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc
new file mode 100644
index 0000000000..0481a55a78
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/verifier_test.cc
@@ -0,0 +1,136 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/tools/verifier.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/util.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+
+using flatbuffers::FlatBufferBuilder;
+using flatbuffers::Offset;
+using flatbuffers::Vector;
+
+// Class that abstracts the list of buffers at the end of the TF Lite structure
+class DeferredBufferWriter {
+ public:
+ DeferredBufferWriter() {
+ data_.push_back({}); // sentinel empty buffer.
+ }
+
+ Offset<Vector<Offset<Buffer>>> BuildBuffers(FlatBufferBuilder *builder) {
+ std::vector<Offset<Buffer>> buffer_vector;
+ for (const auto &vec : data_) {
+ auto data_buffer = builder->CreateVector(vec.data(), vec.size());
+ buffer_vector.push_back(tflite::CreateBuffer(*builder, data_buffer));
+ }
+ return builder->CreateVector(buffer_vector);
+ }
+
+ // Registers a buffer index and takes ownership of the data to write to it.
+ int Record(std::vector<uint8_t> data) {
+ int buffer_index = data_.size();
+ data_.emplace_back(std::move(data));
+ return buffer_index;
+ }
+
+ private:
+ std::vector<std::vector<unsigned char>> data_;
+};
+
+TEST(VerifyModel, TestEmptyModel) {
+ FlatBufferBuilder builder;
+ auto model = CreateModel(builder, /*version=*/TFLITE_SCHEMA_VERSION,
+ /*operator_codes=*/0, /*subgraphs=*/0,
+ /*description=*/0, /*buffers=*/0);
+ ::tflite::FinishModelBuffer(builder, model);
+
+ ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize()));
+}
+
+TEST(VerifyModel, TestSimpleModel) {
+ FlatBufferBuilder builder;
+ auto inputs = builder.CreateVector<int32_t>({0});
+ auto outputs = builder.CreateVector<int32_t>({1});
+ auto operator_codes = builder.CreateVector(std::vector<Offset<OperatorCode>>{
+ CreateOperatorCodeDirect(builder, BuiltinOperator_CUSTOM, "test")});
+ auto operators =
+ builder.CreateVector(std::vector<Offset<Operator>>{CreateOperator(
+ builder, /*opcode_index=*/0,
+ /*inputs=*/builder.CreateVector<int32_t>({0}),
+ /*outputs=*/builder.CreateVector<int32_t>({1}), BuiltinOptions_NONE,
+ /*builtin_options=*/0,
+ /*custom_options=*/0, ::tflite::CustomOptionsFormat_FLEXBUFFERS)});
+ std::vector<int> shape;
+ auto tensors = builder.CreateVector(std::vector<Offset<Tensor>>{
+ CreateTensorDirect(builder, &shape, TensorType_INT32, /*buffer=*/0,
+ "input", /*quantization=*/0),
+ CreateTensorDirect(builder, &shape, TensorType_INT32, /*buffer=*/0,
+ "output", /*quantization=*/0)});
+ auto subgraph = std::vector<Offset<SubGraph>>(
+ {CreateSubGraph(builder, tensors, inputs, outputs, operators,
+ builder.CreateString("Main"))});
+
+ auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, operator_codes,
+ builder.CreateVector(subgraph),
+ builder.CreateString("SmartReply"), /*buffers=*/0);
+
+ ::tflite::FinishModelBuffer(builder, model);
+ ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize()));
+}
+
+TEST(VerifyModel, TestCorruptedData) {
+ string model = "123";
+ ASSERT_FALSE(Verify(model.data(), model.size()));
+}
+
+TEST(VerifyModel, TestUnsupportedVersion) {
+ FlatBufferBuilder builder;
+ auto model = CreateModel(builder, /*version=*/1, /*operator_codes=*/0,
+ /*subgraphs=*/0, /*description=*/0, /*buffers=*/0);
+ ::tflite::FinishModelBuffer(builder, model);
+ ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize()));
+}
+
+TEST(VerifyModel, TestRandomModificationIsNotAllowed) {
+ FlatBufferBuilder builder;
+ auto model = CreateModel(builder, /*version=*/TFLITE_SCHEMA_VERSION,
+ /*operator_codes=*/0,
+ /*subgraphs=*/0, /*description=*/0, /*buffers=*/0);
+ ::tflite::FinishModelBuffer(builder, model);
+
+ string model_content(reinterpret_cast<char *>(builder.GetBufferPointer()),
+ builder.GetSize());
+ for (int i = 0; i < model_content.size(); i++) {
+ model_content[i] = (model_content[i] + 137) % 255;
+ EXPECT_FALSE(Verify(model_content.data(), model_content.size()))
+ << "Fail at position: " << i;
+ }
+}
+
+// TODO(yichengfan): make up malicious files to test with.
+
+} // namespace tflite
+
+int main(int argc, char **argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index c3de1c4c62..55946c128b 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -339,9 +339,9 @@ def streaming_mean_tensor(values,
name=name)
-@deprecated(
- None, 'Please switch to tf.metrics.accuracy. Note that the order of the '
- 'labels and predictions arguments has been switched.')
+@deprecated(None,
+ 'Please switch to tf.metrics.accuracy. Note that the order of the '
+ 'labels and predictions arguments has been switched.')
def streaming_accuracy(predictions,
labels,
weights=None,
@@ -936,8 +936,9 @@ def streaming_curve_points(labels=None,
if curve != 'ROC' and curve != 'PR':
raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
kepsilon = _EPSILON # to account for floating point imprecisions
- thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
- for i in range(num_thresholds - 2)]
+ thresholds = [
+ (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
+ ]
thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
values, update_ops = _streaming_confusion_matrix_at_thresholds(
@@ -973,9 +974,8 @@ def streaming_curve_points(labels=None,
return points, update_op
-@deprecated(
- None, 'Please switch to tf.metrics.auc. Note that the order of the '
- 'labels and predictions arguments has been switched.')
+@deprecated(None, 'Please switch to tf.metrics.auc. Note that the order of the '
+ 'labels and predictions arguments has been switched.')
def streaming_auc(predictions,
labels,
weights=None,
@@ -1105,8 +1105,7 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'):
# For conformance, set precision to 1 when the number of positive
# classifications is 0.
y_axis_values = array_ops.where(
- math_ops.greater(splits, 0),
- math_ops.truediv(true_positives, splits),
+ math_ops.greater(splits, 0), math_ops.truediv(true_positives, splits),
array_ops.ones_like(true_positives, dtype=dtypes.float64))
# Calculate trapezoid areas.
@@ -1119,9 +1118,8 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'):
# exception seems excessive) so we return 0, otherwise we finish computing.
return control_flow_ops.cond(
math_ops.logical_or(
- math_ops.equal(total_positive, 0),
- math_ops.equal(total_positive, size)
- ),
+ math_ops.equal(total_positive, 0), math_ops.equal(
+ total_positive, size)),
true_fn=lambda: array_ops.constant(0, dtypes.float64),
false_fn=continue_computing_dynamic_auc)
@@ -1185,10 +1183,10 @@ def streaming_dynamic_auc(labels,
array_ops.ones_like(labels, dtypes.int64),
message='labels must be 0 or 1, at least one is >1')
]):
- preds_accum, update_preds = streaming_concat(predictions,
- name='concat_preds')
- labels_accum, update_labels = streaming_concat(labels,
- name='concat_labels')
+ preds_accum, update_preds = streaming_concat(
+ predictions, name='concat_preds')
+ labels_accum, update_labels = streaming_concat(
+ labels, name='concat_labels')
update_op = control_flow_ops.group(update_labels, update_preds)
auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve)
if updates_collections:
@@ -1571,9 +1569,9 @@ def streaming_precision_at_thresholds(predictions,
name=name)
-@deprecated(
- None, 'Please switch to tf.metrics.recall_at_thresholds. Note that the '
- 'order of the labels and predictions arguments has been switched.')
+@deprecated(None,
+ 'Please switch to tf.metrics.recall_at_thresholds. Note that the '
+ 'order of the labels and predictions arguments has been switched.')
def streaming_recall_at_thresholds(predictions,
labels,
thresholds,
@@ -3299,8 +3297,13 @@ def count(values,
return count_, update_op
-def cohen_kappa(labels, predictions_idx, num_classes, weights=None,
- metrics_collections=None, updates_collections=None, name=None):
+def cohen_kappa(labels,
+ predictions_idx,
+ num_classes,
+ weights=None,
+ metrics_collections=None,
+ updates_collections=None,
+ name=None):
"""Calculates Cohen's kappa.
[Cohen's kappa](https://en.wikipedia.org/wiki/Cohen's_kappa) is a statistic
@@ -3367,14 +3370,15 @@ def cohen_kappa(labels, predictions_idx, num_classes, weights=None,
labels = array_ops.squeeze(labels, axis=[-1])
predictions_idx, labels, weights = (
metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
- predictions=predictions_idx, labels=labels, weights=weights))
+ predictions=predictions_idx,
+ labels=labels,
+ weights=weights))
predictions_idx.get_shape().assert_is_compatible_with(labels.get_shape())
- stat_dtype = (dtypes.int64
- if weights is None or weights.dtype.is_integer
- else dtypes.float32)
- po = metrics_impl.metric_variable(
- (num_classes,), stat_dtype, name='po')
+ stat_dtype = (
+ dtypes.int64
+ if weights is None or weights.dtype.is_integer else dtypes.float32)
+ po = metrics_impl.metric_variable((num_classes,), stat_dtype, name='po')
pe_row = metrics_impl.metric_variable(
(num_classes,), stat_dtype, name='pe_row')
pe_col = metrics_impl.metric_variable(
@@ -3382,9 +3386,12 @@ def cohen_kappa(labels, predictions_idx, num_classes, weights=None,
# Table of the counts of agreement:
counts_in_table = confusion_matrix.confusion_matrix(
- labels, predictions_idx,
- num_classes=num_classes, weights=weights,
- dtype=stat_dtype, name="counts_in_table")
+ labels,
+ predictions_idx,
+ num_classes=num_classes,
+ weights=weights,
+ dtype=stat_dtype,
+ name='counts_in_table')
po_t = array_ops.diag_part(counts_in_table)
pe_row_t = math_ops.reduce_sum(counts_in_table, axis=0)
@@ -3404,12 +3411,14 @@ def cohen_kappa(labels, predictions_idx, num_classes, weights=None,
math_ops.to_double(total))
# kappa = (po - pe) / (N - pe)
k = metrics_impl._safe_scalar_div( # pylint: disable=protected-access
- po_sum - pe_sum, total - pe_sum, name=name)
+ po_sum - pe_sum,
+ total - pe_sum,
+ name=name)
return k
kappa = _calculate_k(po, pe_row, pe_col, name='value')
- update_op = _calculate_k(update_po, update_pe_row, update_pe_col,
- name='update_op')
+ update_op = _calculate_k(
+ update_po, update_pe_row, update_pe_col, name='update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, kappa)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 89aa29f711..e067f08bab 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -46,8 +46,7 @@ def _enqueue_vector(sess, queue, values, shape=None):
shape = (1, len(values))
dtype = queue.dtypes[0]
sess.run(
- queue.enqueue(constant_op.constant(
- values, dtype=dtype, shape=shape)))
+ queue.enqueue(constant_op.constant(values, dtype=dtype, shape=shape)))
def _binary_2d_label_to_sparse_value(labels):
@@ -79,8 +78,8 @@ def _binary_2d_label_to_sparse_value(labels):
batch += 1
shape = [len(labels), len(labels[0])]
return sparse_tensor.SparseTensorValue(
- np.array(indices, np.int64),
- np.array(values, np.int64), np.array(shape, np.int64))
+ np.array(indices, np.int64), np.array(values, np.int64),
+ np.array(shape, np.int64))
def _binary_2d_label_to_sparse(labels):
@@ -125,8 +124,8 @@ def _binary_3d_label_to_sparse_value(labels):
assert label == 0
shape = [len(labels), len(labels[0]), len(labels[0][0])]
return sparse_tensor.SparseTensorValue(
- np.array(indices, np.int64),
- np.array(values, np.int64), np.array(shape, np.int64))
+ np.array(indices, np.int64), np.array(values, np.int64),
+ np.array(shape, np.int64))
def _binary_3d_label_to_sparse(labels):
@@ -669,20 +668,18 @@ class StreamingTruePositivesTest(test.TestCase):
for expand_predictions in [True, False]:
for expand_labels in [True, False]:
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_predictions:
predictions = array_ops.expand_dims(predictions, 2)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_labels:
labels = array_ops.expand_dims(labels, 2)
- tp, tp_update_op = metrics.streaming_true_positives(predictions,
- labels)
+ tp, tp_update_op = metrics.streaming_true_positives(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -692,14 +689,12 @@ class StreamingTruePositivesTest(test.TestCase):
def testWeighted(self):
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
tp, tp_update_op = metrics.streaming_true_positives(
predictions, labels, weights=37.0)
@@ -717,28 +712,25 @@ class StreamingFalseNegativesTest(test.TestCase):
ops.reset_default_graph()
def testVars(self):
- metrics.streaming_false_negatives((0, 1, 0),
- (0, 1, 1))
+ metrics.streaming_false_negatives((0, 1, 0), (0, 1, 1))
_assert_metric_variables(self, ('false_negatives/count:0',))
def testUnweighted(self):
for expand_predictions in [True, False]:
for expand_labels in [True, False]:
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_predictions:
predictions = array_ops.expand_dims(predictions, 2)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_labels:
labels = array_ops.expand_dims(labels, 2)
- fn, fn_update_op = metrics.streaming_false_negatives(predictions,
- labels)
+ fn, fn_update_op = metrics.streaming_false_negatives(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -748,14 +740,12 @@ class StreamingFalseNegativesTest(test.TestCase):
def testWeighted(self):
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
fn, fn_update_op = metrics.streaming_false_negatives(
predictions, labels, weights=((3.0,), (5.0,), (7.0,)))
@@ -773,28 +763,25 @@ class StreamingFalsePositivesTest(test.TestCase):
ops.reset_default_graph()
def testVars(self):
- metrics.streaming_false_positives((0, 1, 0),
- (0, 1, 1))
+ metrics.streaming_false_positives((0, 1, 0), (0, 1, 1))
_assert_metric_variables(self, ('false_positives/count:0',))
def testUnweighted(self):
for expand_predictions in [True, False]:
for expand_labels in [True, False]:
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_predictions:
predictions = array_ops.expand_dims(predictions, 2)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_labels:
labels = array_ops.expand_dims(labels, 2)
- fp, fp_update_op = metrics.streaming_false_positives(predictions,
- labels)
+ fp, fp_update_op = metrics.streaming_false_positives(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -804,20 +791,17 @@ class StreamingFalsePositivesTest(test.TestCase):
def testWeighted(self):
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
fp, fp_update_op = metrics.streaming_false_positives(
predictions,
labels,
- weights=((1.0, 2.0, 3.0, 5.0),
- (7.0, 11.0, 13.0, 17.0),
- (19.0, 23.0, 29.0, 31.0)))
+ weights=((1.0, 2.0, 3.0, 5.0), (7.0, 11.0, 13.0, 17.0), (19.0, 23.0,
+ 29.0, 31.0)))
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -833,28 +817,25 @@ class StreamingTrueNegativesTest(test.TestCase):
ops.reset_default_graph()
def testVars(self):
- metrics.streaming_true_negatives((0, 1, 0),
- (0, 1, 1))
+ metrics.streaming_true_negatives((0, 1, 0), (0, 1, 1))
_assert_metric_variables(self, ('true_negatives/count:0',))
def testUnweighted(self):
for expand_predictions in [True, False]:
for expand_labels in [True, False]:
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_predictions:
predictions = array_ops.expand_dims(predictions, 2)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
if expand_labels:
labels = array_ops.expand_dims(labels, 2)
- tn, tn_update_op = metrics.streaming_true_negatives(predictions,
- labels)
+ tn, tn_update_op = metrics.streaming_true_negatives(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -864,14 +845,12 @@ class StreamingTrueNegativesTest(test.TestCase):
def testWeighted(self):
for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
- predictions = math_ops.cast(constant_op.constant(
- ((1, 0, 1, 0),
- (0, 1, 1, 1),
- (0, 0, 0, 0))), dtype=dtype)
- labels = math_ops.cast(constant_op.constant(
- ((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0))), dtype=dtype)
+ predictions = math_ops.cast(
+ constant_op.constant(((1, 0, 1, 0), (0, 1, 1, 1), (0, 0, 0, 0))),
+ dtype=dtype)
+ labels = math_ops.cast(
+ constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0))),
+ dtype=dtype)
tn, tn_update_op = metrics.streaming_true_negatives(
predictions, labels, weights=((0.0, 2.0, 3.0, 5.0),))
@@ -894,12 +873,9 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase):
_assert_metric_variables(self, ('true_positives:0',))
def testUnweighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
@@ -910,12 +886,9 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase):
self.assertAllEqual((3, 1, 0), tp.eval())
def testWeighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
predictions, labels, weights=37.0, thresholds=(0.15, 0.5, 0.85))
@@ -937,16 +910,14 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase):
(0.0, 1.0, 0.0), (0, 1, 1), thresholds=(
0.15,
0.5,
- 0.85,))
+ 0.85,
+ ))
_assert_metric_variables(self, ('false_negatives:0',))
def testUnweighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
fn, fn_update_op = metrics.streaming_false_negatives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
@@ -957,12 +928,9 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase):
self.assertAllEqual((0, 2, 3), fn.eval())
def testWeighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
fn, fn_update_op = metrics.streaming_false_negatives_at_thresholds(
predictions,
labels,
@@ -988,12 +956,9 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase):
_assert_metric_variables(self, ('false_positives:0',))
def testUnweighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
fp, fp_update_op = metrics.streaming_false_positives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
@@ -1004,18 +969,14 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase):
self.assertAllEqual((7, 4, 2), fp.eval())
def testWeighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
fp, fp_update_op = metrics.streaming_false_positives_at_thresholds(
predictions,
labels,
- weights=((1.0, 2.0, 3.0, 5.0),
- (7.0, 11.0, 13.0, 17.0),
- (19.0, 23.0, 29.0, 31.0)),
+ weights=((1.0, 2.0, 3.0, 5.0), (7.0, 11.0, 13.0, 17.0), (19.0, 23.0,
+ 29.0, 31.0)),
thresholds=(0.15, 0.5, 0.85))
with self.test_session() as sess:
@@ -1037,12 +998,9 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase):
_assert_metric_variables(self, ('true_negatives:0',))
def testUnweighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
tn, tn_update_op = metrics.streaming_true_negatives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
@@ -1053,12 +1011,9 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase):
self.assertAllEqual((2, 5, 7), tn.eval())
def testWeighted(self):
- predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
- (0.2, 0.9, 0.7, 0.6),
- (0.1, 0.2, 0.4, 0.3)))
- labels = constant_op.constant(((0, 1, 1, 0),
- (1, 0, 0, 0),
- (0, 0, 0, 0)))
+ predictions = constant_op.constant(
+ ((0.9, 0.2, 0.8, 0.1), (0.2, 0.9, 0.7, 0.6), (0.1, 0.2, 0.4, 0.3)))
+ labels = constant_op.constant(((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0)))
tn, tn_update_op = metrics.streaming_true_negatives_at_thresholds(
predictions,
labels,
@@ -1393,8 +1348,7 @@ class StreamingFPRTest(test.TestCase):
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
- fpr, update_op = metrics.streaming_false_positive_rate(
- predictions, labels)
+ fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1413,8 +1367,7 @@ class StreamingFPRTest(test.TestCase):
predictions = constant_op.constant(np_inputs)
labels = constant_op.constant(np_inputs)
- fpr, update_op = metrics.streaming_false_positive_rate(
- predictions, labels)
+ fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1424,8 +1377,7 @@ class StreamingFPRTest(test.TestCase):
def testSomeCorrect(self):
predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4))
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
- fpr, update_op = metrics.streaming_false_positive_rate(
- predictions, labels)
+ fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1467,8 +1419,7 @@ class StreamingFPRTest(test.TestCase):
predictions = constant_op.constant(np_inputs)
labels = constant_op.constant(1 - np_inputs)
- fpr, update_op = metrics.streaming_false_positive_rate(
- predictions, labels)
+ fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1478,8 +1429,7 @@ class StreamingFPRTest(test.TestCase):
def testZeroFalsePositivesAndTrueNegativesGivesZeroFPR(self):
predictions = array_ops.ones((1, 4))
labels = array_ops.ones((1, 4))
- fpr, update_op = metrics.streaming_false_positive_rate(
- predictions, labels)
+ fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1521,8 +1471,7 @@ class StreamingFNRTest(test.TestCase):
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
- fnr, update_op = metrics.streaming_false_negative_rate(
- predictions, labels)
+ fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1541,8 +1490,7 @@ class StreamingFNRTest(test.TestCase):
predictions = constant_op.constant(np_inputs)
labels = constant_op.constant(np_inputs)
- fnr, update_op = metrics.streaming_false_negative_rate(
- predictions, labels)
+ fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1552,8 +1500,7 @@ class StreamingFNRTest(test.TestCase):
def testSomeCorrect(self):
predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4))
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
- fnr, update_op = metrics.streaming_false_negative_rate(
- predictions, labels)
+ fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1595,8 +1542,7 @@ class StreamingFNRTest(test.TestCase):
predictions = constant_op.constant(np_inputs)
labels = constant_op.constant(1 - np_inputs)
- fnr, update_op = metrics.streaming_false_negative_rate(
- predictions, labels)
+ fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1606,8 +1552,7 @@ class StreamingFNRTest(test.TestCase):
def testZeroFalseNegativesAndTruePositivesGivesZeroFNR(self):
predictions = array_ops.zeros((1, 4))
labels = array_ops.zeros((1, 4))
- fnr, update_op = metrics.streaming_false_negative_rate(
- predictions, labels)
+ fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -1944,16 +1889,17 @@ class StreamingAUCTest(test.TestCase):
enqueue_ops[i].append(x_queue.enqueue(x_batches[i, :]))
return x_queue.dequeue()
- for weights in (None, np.ones(num_samples), np.random.exponential(
- scale=1.0, size=num_samples)):
+ for weights in (None, np.ones(num_samples),
+ np.random.exponential(scale=1.0, size=num_samples)):
expected_auc = _np_auc(predictions, labels, weights)
with self.test_session() as sess:
enqueue_ops = [[] for i in range(num_batches)]
tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
tf_labels = _enqueue_as_batches(labels, enqueue_ops)
- tf_weights = (_enqueue_as_batches(weights, enqueue_ops) if
- weights is not None else None)
+ tf_weights = (
+ _enqueue_as_batches(weights, enqueue_ops)
+ if weights is not None else None)
for i in range(num_batches):
sess.run(enqueue_ops[i])
@@ -1985,17 +1931,18 @@ class StreamingDynamicAUCTest(test.TestCase):
def testUnknownCurve(self):
with self.assertRaisesRegexp(
ValueError, 'curve must be either ROC or PR, TEST_CURVE unknown'):
- metrics.streaming_dynamic_auc(labels=array_ops.ones((10, 1)),
- predictions=array_ops.ones((10, 1)),
- curve='TEST_CURVE')
+ metrics.streaming_dynamic_auc(
+ labels=array_ops.ones((10, 1)),
+ predictions=array_ops.ones((10, 1)),
+ curve='TEST_CURVE')
def testVars(self):
metrics.streaming_dynamic_auc(
labels=array_ops.ones((10, 1)), predictions=array_ops.ones((10, 1)))
- _assert_metric_variables(self, ['dynamic_auc/concat_labels/array:0',
- 'dynamic_auc/concat_labels/size:0',
- 'dynamic_auc/concat_preds/array:0',
- 'dynamic_auc/concat_preds/size:0'])
+ _assert_metric_variables(self, [
+ 'dynamic_auc/concat_labels/array:0', 'dynamic_auc/concat_labels/size:0',
+ 'dynamic_auc/concat_preds/array:0', 'dynamic_auc/concat_preds/size:0'
+ ])
def testMetricsCollection(self):
my_collection_name = '__metrics__'
@@ -2049,8 +1996,8 @@ class StreamingDynamicAUCTest(test.TestCase):
def testNonZeroOnePredictions(self):
with self.test_session() as sess:
- predictions = constant_op.constant([2.5, -2.5, 2.5, -2.5],
- dtype=dtypes_lib.float32)
+ predictions = constant_op.constant(
+ [2.5, -2.5, 2.5, -2.5], dtype=dtypes_lib.float32)
labels = constant_op.constant([1, 0, 1, 0])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
sess.run(variables.local_variables_initializer())
@@ -2122,9 +2069,10 @@ class StreamingDynamicAUCTest(test.TestCase):
num_batches = 100
labels = np.array([])
predictions = np.array([])
- tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32),
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- dtype=dtypes_lib.int32)
+ tf_labels = variables.Variable(
+ array_ops.ones(batch_size, dtypes_lib.int32),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.int32)
tf_predictions = variables.Variable(
array_ops.ones(batch_size),
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -2195,8 +2143,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
gotten_result: A PrecisionRecallData object.
"""
gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()}
- self.assertItemsEqual(
- list(expected_dict.keys()), list(gotten_dict.keys()))
+ self.assertItemsEqual(list(expected_dict.keys()), list(gotten_dict.keys()))
for key, expected_values in expected_dict.items():
self.assertAllClose(expected_values, gotten_dict[key])
@@ -2261,60 +2208,65 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
sess.run(update_op)
# Then verify idempotency.
- initial_result = {k: value.eval().tolist() for k, value in
- result._asdict().items()}
+ initial_result = {
+ k: value.eval().tolist()
+ for k, value in result._asdict().items()
+ }
for _ in range(3):
self._testResultsEqual(initial_result, result)
def testAllTruePositives(self):
- self._testCase([[1]], [[True]], {
- 'tp': [1, 1, 1],
- 'fp': [0, 0, 0],
- 'tn': [0, 0, 0],
- 'fn': [0, 0, 0],
- 'precision': [1.0, 1.0, 1.0],
- 'recall': [1.0, 1.0, 1.0],
- 'thresholds': [0.0, 0.5, 1.0],
- })
+ self._testCase(
+ [[1]], [[True]], {
+ 'tp': [1, 1, 1],
+ 'fp': [0, 0, 0],
+ 'tn': [0, 0, 0],
+ 'fn': [0, 0, 0],
+ 'precision': [1.0, 1.0, 1.0],
+ 'recall': [1.0, 1.0, 1.0],
+ 'thresholds': [0.0, 0.5, 1.0],
+ })
def testAllTrueNegatives(self):
- self._testCase([[0]], [[False]], {
- 'tp': [0, 0, 0],
- 'fp': [1, 0, 0],
- 'tn': [0, 1, 1],
- 'fn': [0, 0, 0],
- 'precision': [0.0, 0.0, 0.0],
- 'recall': [0.0, 0.0, 0.0],
- 'thresholds': [0.0, 0.5, 1.0],
- })
+ self._testCase(
+ [[0]], [[False]], {
+ 'tp': [0, 0, 0],
+ 'fp': [1, 0, 0],
+ 'tn': [0, 1, 1],
+ 'fn': [0, 0, 0],
+ 'precision': [0.0, 0.0, 0.0],
+ 'recall': [0.0, 0.0, 0.0],
+ 'thresholds': [0.0, 0.5, 1.0],
+ })
def testAllFalsePositives(self):
- self._testCase([[1]], [[False]], {
- 'tp': [0, 0, 0],
- 'fp': [1, 1, 1],
- 'tn': [0, 0, 0],
- 'fn': [0, 0, 0],
- 'precision': [0.0, 0.0, 0.0],
- 'recall': [0.0, 0.0, 0.0],
- 'thresholds': [0.0, 0.5, 1.0],
- })
+ self._testCase(
+ [[1]], [[False]], {
+ 'tp': [0, 0, 0],
+ 'fp': [1, 1, 1],
+ 'tn': [0, 0, 0],
+ 'fn': [0, 0, 0],
+ 'precision': [0.0, 0.0, 0.0],
+ 'recall': [0.0, 0.0, 0.0],
+ 'thresholds': [0.0, 0.5, 1.0],
+ })
def testAllFalseNegatives(self):
- self._testCase([[0]], [[True]], {
- 'tp': [1, 0, 0],
- 'fp': [0, 0, 0],
- 'tn': [0, 0, 0],
- 'fn': [0, 1, 1],
- 'precision': [1.0, 0.0, 0.0],
- 'recall': [1.0, 0.0, 0.0],
- 'thresholds': [0.0, 0.5, 1.0],
- })
+ self._testCase(
+ [[0]], [[True]], {
+ 'tp': [1, 0, 0],
+ 'fp': [0, 0, 0],
+ 'tn': [0, 0, 0],
+ 'fn': [0, 1, 1],
+ 'precision': [1.0, 0.0, 0.0],
+ 'recall': [1.0, 0.0, 0.0],
+ 'thresholds': [0.0, 0.5, 1.0],
+ })
def testManyValues(self):
self._testCase(
[[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
- [[True, False, False, True, True, True]],
- {
+ [[True, False, False, True, True, True]], {
'tp': [4, 3, 0],
'fp': [2, 0, 0],
'tn': [0, 2, 2],
@@ -2327,8 +2279,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
def testManyValuesWithWeights(self):
self._testCase(
[[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
- [[True, False, False, True, True, True]],
- {
+ [[True, False, False, True, True, True]], {
'tp': [1.5, 1.5, 0.0],
'fp': [2.5, 0.0, 0.0],
'tn': [0.0, 2.5, 2.5],
@@ -2644,11 +2595,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
labels = random_ops.random_uniform(
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
thresholds = [0, 0.5, 1.0]
- prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
- labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ predictions, labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ predictions, labels, thresholds)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -2672,11 +2622,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
- prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
- labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ predictions, labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ predictions, labels, thresholds)
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@@ -2690,11 +2639,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
thresholds = [0.5]
- prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
- labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ predictions, labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ predictions, labels, thresholds)
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@@ -2709,11 +2657,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
- prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
- labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ predictions, labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ predictions, labels, thresholds)
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@@ -2779,11 +2726,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
thresholds = [-1.0, 2.0] # lower/higher than any values
- prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
- labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ predictions, labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ predictions, labels, thresholds)
prec_low = prec[0]
prec_high = prec[1]
@@ -2803,11 +2749,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
- prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
- labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ predictions, labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ predictions, labels, thresholds)
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@@ -2872,12 +2817,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
tf_predictions = predictions_queue.dequeue()
tf_labels = labels_queue.dequeue()
- prec, prec_op = metrics.streaming_precision_at_thresholds(tf_predictions,
- tf_labels,
- thresholds)
- rec, rec_op = metrics.streaming_recall_at_thresholds(tf_predictions,
- tf_labels,
- thresholds)
+ prec, prec_op = metrics.streaming_precision_at_thresholds(
+ tf_predictions, tf_labels, thresholds)
+ rec, rec_op = metrics.streaming_recall_at_thresholds(
+ tf_predictions, tf_labels, thresholds)
sess.run(variables.local_variables_initializer())
for _ in range(int(num_samples / batch_size)):
@@ -2921,8 +2864,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
labels=array_ops.ones((10, 1)),
thresholds=[0, 0.5, 1.0],
updates_collections=[my_collection_name])
- self.assertListEqual(
- ops.get_collection(my_collection_name), [update_op])
+ self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testValueTensorIsIdempotent(self):
predictions = random_ops.random_uniform(
@@ -3271,8 +3213,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
labels=array_ops.ones((10, 1)),
thresholds=[0, 0.5, 1.0],
updates_collections=[my_collection_name])
- self.assertListEqual(
- ops.get_collection(my_collection_name), [update_op])
+ self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testValueTensorIsIdempotent(self):
predictions = random_ops.random_uniform(
@@ -3492,8 +3433,7 @@ class StreamingRecallAtKTest(test.TestCase):
def testVars(self):
metrics.streaming_recall_at_k(
predictions=array_ops.ones((self._batch_size, self._num_classes)),
- labels=array_ops.ones(
- (self._batch_size,), dtype=dtypes_lib.int32),
+ labels=array_ops.ones((self._batch_size,), dtype=dtypes_lib.int32),
k=1)
_assert_metric_variables(self,
('recall_at_1/count:0', 'recall_at_1/total:0'))
@@ -3502,8 +3442,7 @@ class StreamingRecallAtKTest(test.TestCase):
my_collection_name = '__metrics__'
mean, _ = metrics.streaming_recall_at_k(
predictions=array_ops.ones((self._batch_size, self._num_classes)),
- labels=array_ops.ones(
- (self._batch_size,), dtype=dtypes_lib.int32),
+ labels=array_ops.ones((self._batch_size,), dtype=dtypes_lib.int32),
k=1,
metrics_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [mean])
@@ -3512,8 +3451,7 @@ class StreamingRecallAtKTest(test.TestCase):
my_collection_name = '__updates__'
_, update_op = metrics.streaming_recall_at_k(
predictions=array_ops.ones((self._batch_size, self._num_classes)),
- labels=array_ops.ones(
- (self._batch_size,), dtype=dtypes_lib.int32),
+ labels=array_ops.ones((self._batch_size,), dtype=dtypes_lib.int32),
k=1,
updates_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
@@ -3715,9 +3653,17 @@ class StreamingSparsePrecisionTest(test.TestCase):
# top_k_predictions has rank < 2.
top_k_predictions = [9, 4, 6, 2, 0]
sp_labels = sparse_tensor.SparseTensorValue(
- indices=np.array([[0,], [1,], [2,]], np.int64),
+ indices=np.array([[
+ 0,
+ ], [
+ 1,
+ ], [
+ 2,
+ ]], np.int64),
values=np.array([2, 7, 8], np.int64),
- dense_shape=np.array([10,], np.int64))
+ dense_shape=np.array([
+ 10,
+ ], np.int64))
with self.assertRaises(ValueError):
precision, _ = metrics.streaming_sparse_precision_at_top_k(
@@ -3774,8 +3720,9 @@ class StreamingSparsePrecisionTest(test.TestCase):
# average of the 2 examples.
labels = np.array([labels_ex1, labels_ex2], dtype=np.int64)
predictions = (predictions_ex1, predictions_ex2)
- streaming_precision = [(ex1 + ex2) / 2
- for ex1, ex2 in zip(precision_ex1, precision_ex2)]
+ streaming_precision = [
+ (ex1 + ex2) / 2 for ex1, ex2 in zip(precision_ex1, precision_ex2)
+ ]
streaming_average_precision = [
(ex1 + ex2) / 2
for ex1, ex2 in zip(avg_precision_ex1, avg_precision_ex2)
@@ -3835,29 +3782,29 @@ class StreamingSparsePrecisionTest(test.TestCase):
(predictions_top_k_ex1[:k],), labels, expected=avg_precision_ex1[i])
def test_average_precision_at_top_k_static_shape_check(self):
- predictions_top_k = array_ops.placeholder(shape=(2, None),
- dtype=dtypes_lib.int64)
+ predictions_top_k = array_ops.placeholder(
+ shape=(2, None), dtype=dtypes_lib.int64)
labels = np.array(((1,), (2,)), dtype=np.int64)
# Fails due to non-static predictions_idx shape.
with self.assertRaises(ValueError):
- metric_ops.streaming_sparse_average_precision_at_top_k(predictions_top_k,
- labels)
+ metric_ops.streaming_sparse_average_precision_at_top_k(
+ predictions_top_k, labels)
predictions_top_k = (2, 1)
# Fails since rank of predictions_idx is less than one.
with self.assertRaises(ValueError):
- metric_ops.streaming_sparse_average_precision_at_top_k(predictions_top_k,
- labels)
+ metric_ops.streaming_sparse_average_precision_at_top_k(
+ predictions_top_k, labels)
predictions_top_k = ((2,), (1,))
# Valid static shape.
- metric_ops.streaming_sparse_average_precision_at_top_k(predictions_top_k,
- labels)
+ metric_ops.streaming_sparse_average_precision_at_top_k(
+ predictions_top_k, labels)
def test_one_label_at_k1_nan(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value(
- [[0, 0, 0, 1], [0, 0, 1, 0]])
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
for labels in (sparse_labels, dense_labels):
@@ -3871,8 +3818,8 @@ class StreamingSparsePrecisionTest(test.TestCase):
def test_one_label_at_k1(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value(
- [[0, 0, 0, 1], [0, 0, 1, 0]])
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
for labels in (sparse_labels, dense_labels):
@@ -3971,8 +3918,8 @@ class StreamingSparsePrecisionTest(test.TestCase):
[5, 7, 2, 9, 6],
]
sp_labels = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
- [1, 3]],
+ indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2], [1,
+ 3]],
# values -1 and 10 are outside the [0, n_classes) range and are ignored.
values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64),
dense_shape=[2, 4])
@@ -4324,8 +4271,8 @@ class StreamingSparseRecallTest(test.TestCase):
def test_one_label_at_k1_nan(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value(
- [[0, 0, 0, 1], [0, 0, 1, 0]])
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
# Classes 0,1 have 0 labels, 0 predictions, classes -1 and 4 are out of
@@ -4340,8 +4287,8 @@ class StreamingSparseRecallTest(test.TestCase):
def test_one_label_at_k1_no_predictions(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value(
- [[0, 0, 0, 1], [0, 0, 1, 0]])
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
for labels in (sparse_labels, dense_labels):
@@ -4354,8 +4301,8 @@ class StreamingSparseRecallTest(test.TestCase):
def test_one_label_at_k1(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value(
- [[0, 0, 0, 1], [0, 0, 1, 0]])
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
for labels in (sparse_labels, dense_labels):
@@ -4374,8 +4321,8 @@ class StreamingSparseRecallTest(test.TestCase):
def test_one_label_at_k1_weighted(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value(
- [[0, 0, 0, 1], [0, 0, 1, 0]])
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
for labels in (sparse_labels, dense_labels):
@@ -4647,8 +4594,8 @@ class StreamingSparseRecallTest(test.TestCase):
[5, 7, 2, 9, 6],
]
sp_labels = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
- [1, 3]],
+ indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2], [1,
+ 3]],
# values -1 and 10 are outside the [0, n_classes) range.
values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64),
dense_shape=[2, 4])
@@ -4661,10 +4608,7 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2,
class_id=2)
self._test_sparse_recall_at_top_k(
- sp_labels,
- top_k_predictions,
- expected=2.0 / 2,
- class_id=2)
+ sp_labels, top_k_predictions, expected=2.0 / 2, class_id=2)
# Class 5: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
@@ -4674,10 +4618,7 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=5)
self._test_sparse_recall_at_top_k(
- sp_labels,
- top_k_predictions,
- expected=1.0 / 1,
- class_id=5)
+ sp_labels, top_k_predictions, expected=1.0 / 1, class_id=5)
# Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
@@ -4687,10 +4628,7 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.0 / 1,
class_id=7)
self._test_sparse_recall_at_top_k(
- sp_labels,
- top_k_predictions,
- expected=0.0 / 1,
- class_id=7)
+ sp_labels, top_k_predictions, expected=0.0 / 1, class_id=7)
# All classes: 8 labels, 3 correct.
self._test_streaming_sparse_recall_at_k(
@@ -4740,10 +4678,8 @@ class StreamingSparseRecallTest(test.TestCase):
[9, 4, 6, 2, 0],
]]
sparse_labels = _binary_3d_label_to_sparse_value(
- [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
- [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
- [[0, 1, 1, 0, 0, 1, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]])
+ [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
+ [[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]])
dense_labels = np.array(
[[[2, 7, 8], [1, 2, 5]], [
[1, 2, 5],
@@ -4771,10 +4707,8 @@ class StreamingSparseRecallTest(test.TestCase):
[9, 4, 6, 2, 0],
]]
labels = _binary_3d_label_to_sparse_value(
- [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
- [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
- [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
+ [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
+ [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
# Class 2: 4 labels, all correct.
self._test_streaming_sparse_recall_at_k(
@@ -4813,10 +4747,8 @@ class StreamingSparseRecallTest(test.TestCase):
[9, 4, 6, 2, 0],
]]
labels = _binary_3d_label_to_sparse_value(
- [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
- [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
- [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
+ [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
+ [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
for class_id in xrange(10):
self._test_streaming_sparse_recall_at_k(
@@ -4867,10 +4799,8 @@ class StreamingSparseRecallTest(test.TestCase):
[9, 4, 6, 2, 0],
]]
labels = _binary_3d_label_to_sparse_value(
- [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
- [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
- [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
+ [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
+ [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
# Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
@@ -4963,10 +4893,8 @@ class StreamingSparseRecallTest(test.TestCase):
weights=[[0, 1], [0, 1]])
def test_sparse_tensor_value(self):
- predictions = [[0.1, 0.3, 0.2, 0.4],
- [0.1, 0.2, 0.3, 0.4]]
- labels = [[0, 0, 1, 0],
- [0, 0, 0, 1]]
+ predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
+ labels = [[0, 0, 1, 0], [0, 0, 0, 1]]
expected_recall = 0.5
with self.test_session():
_, recall = metrics.streaming_sparse_recall_at_k(
@@ -5009,8 +4937,8 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase):
def testValueTensorIsIdempotent(self):
predictions = random_ops.random_normal((10, 3), seed=1)
labels = random_ops.random_normal((10, 3), seed=2)
- error, update_op = metrics.streaming_mean_absolute_error(predictions,
- labels)
+ error, update_op = metrics.streaming_mean_absolute_error(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -5031,8 +4959,8 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase):
[1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
- error, update_op = metrics.streaming_mean_absolute_error(predictions,
- labels, weights)
+ error, update_op = metrics.streaming_mean_absolute_error(
+ predictions, labels, weights)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -5075,8 +5003,8 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
predictions = random_ops.random_normal((10, 3), seed=1)
labels = random_ops.random_normal((10, 3), seed=2)
normalizer = random_ops.random_normal((10, 3), seed=3)
- error, update_op = metrics.streaming_mean_relative_error(predictions,
- labels, normalizer)
+ error, update_op = metrics.streaming_mean_relative_error(
+ predictions, labels, normalizer)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -5200,8 +5128,8 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
[1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
- error, update_op = metrics.streaming_mean_squared_error(predictions, labels,
- weights)
+ error, update_op = metrics.streaming_mean_squared_error(
+ predictions, labels, weights)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -5224,8 +5152,8 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
_enqueue_vector(sess, labels_queue, [2, 4, 6])
labels = labels_queue.dequeue()
- error, update_op = metrics.streaming_mean_squared_error(predictions,
- labels)
+ error, update_op = metrics.streaming_mean_squared_error(
+ predictions, labels)
sess.run(variables.local_variables_initializer())
sess.run(update_op)
@@ -5292,10 +5220,10 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
_enqueue_vector(sess, labels_queue, [2, 4, 6])
labels = labels_queue.dequeue()
- mae, ma_update_op = metrics.streaming_mean_absolute_error(predictions,
- labels)
- mse, ms_update_op = metrics.streaming_mean_squared_error(predictions,
- labels)
+ mae, ma_update_op = metrics.streaming_mean_absolute_error(
+ predictions, labels)
+ mse, ms_update_op = metrics.streaming_mean_squared_error(
+ predictions, labels)
sess.run(variables.local_variables_initializer())
sess.run([ma_update_op, ms_update_op])
@@ -5336,8 +5264,8 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
def testValueTensorIsIdempotent(self):
predictions = random_ops.random_normal((10, 3), seed=1)
labels = random_ops.random_normal((10, 3), seed=2)
- error, update_op = metrics.streaming_root_mean_squared_error(predictions,
- labels)
+ error, update_op = metrics.streaming_root_mean_squared_error(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -5357,8 +5285,8 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
0.0, shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32)
- rmse, update_op = metrics.streaming_root_mean_squared_error(predictions,
- labels)
+ rmse, update_op = metrics.streaming_root_mean_squared_error(
+ predictions, labels)
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
@@ -5372,8 +5300,8 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
labels = constant_op.constant(
[1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32)
- rmse, update_op = metrics.streaming_root_mean_squared_error(predictions,
- labels)
+ rmse, update_op = metrics.streaming_root_mean_squared_error(
+ predictions, labels)
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(math.sqrt(6), update_op.eval(), 5)
@@ -5387,9 +5315,8 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
[1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
- rmse, update_op = metrics.streaming_root_mean_squared_error(predictions,
- labels,
- weights)
+ rmse, update_op = metrics.streaming_root_mean_squared_error(
+ predictions, labels, weights)
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(math.sqrt(13), sess.run(update_op))
@@ -5404,8 +5331,8 @@ class StreamingCovarianceTest(test.TestCase):
def testVars(self):
metrics.streaming_covariance(
- predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones(
- [10, 10]),
+ predictions=math_ops.to_float(math_ops.range(10)) +
+ array_ops.ones([10, 10]),
labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]))
_assert_metric_variables(self, (
'covariance/comoment:0',
@@ -5417,8 +5344,8 @@ class StreamingCovarianceTest(test.TestCase):
def testMetricsCollection(self):
my_collection_name = '__metrics__'
cov, _ = metrics.streaming_covariance(
- predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones(
- [10, 10]),
+ predictions=math_ops.to_float(math_ops.range(10)) +
+ array_ops.ones([10, 10]),
labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]),
metrics_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [cov])
@@ -5426,8 +5353,8 @@ class StreamingCovarianceTest(test.TestCase):
def testUpdatesCollection(self):
my_collection_name = '__updates__'
_, update_op = metrics.streaming_covariance(
- predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones(
- [10, 10]),
+ predictions=math_ops.to_float(math_ops.range(10)) +
+ array_ops.ones([10, 10]),
labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]),
updates_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
@@ -5487,9 +5414,8 @@ class StreamingCovarianceTest(test.TestCase):
cov, update_op = metrics.streaming_covariance(
predictions, labels, weights=weights)
- expected_cov = np.cov([2, 4, 6, 8],
- [1, 3, 2, 7],
- fweights=[0, 1, 3, 1])[0, 1]
+ expected_cov = np.cov(
+ [2, 4, 6, 8], [1, 3, 2, 7], fweights=[0, 1, 3, 1])[0, 1]
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(expected_cov, sess.run(update_op))
self.assertAlmostEqual(expected_cov, cov.eval())
@@ -5514,17 +5440,18 @@ class StreamingCovarianceTest(test.TestCase):
predictions_t: predictions[stride * i:stride * (i + 1)],
labels_t: labels[stride * i:stride * (i + 1)]
}
- self.assertEqual(np.isnan(prev_expected_cov),
- np.isnan(sess.run(cov, feed_dict=feed_dict)))
+ self.assertEqual(
+ np.isnan(prev_expected_cov),
+ np.isnan(sess.run(cov, feed_dict=feed_dict)))
if not np.isnan(prev_expected_cov):
- self.assertAlmostEqual(
- prev_expected_cov, sess.run(cov, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(prev_expected_cov,
+ sess.run(cov, feed_dict=feed_dict), 5)
expected_cov = np.cov(predictions[:stride * (i + 1)],
labels[:stride * (i + 1)])[0, 1]
- self.assertAlmostEqual(
- expected_cov, sess.run(update_op, feed_dict=feed_dict), 5)
- self.assertAlmostEqual(
- expected_cov, sess.run(cov, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_cov,
+ sess.run(update_op, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_cov, sess.run(cov, feed_dict=feed_dict),
+ 5)
prev_expected_cov = expected_cov
def testMultiUpdateWithErrorAndWeights(self):
@@ -5552,18 +5479,20 @@ class StreamingCovarianceTest(test.TestCase):
labels_t: labels[stride * i:stride * (i + 1)],
weights_t: weights[stride * i:stride * (i + 1)]
}
- self.assertEqual(np.isnan(prev_expected_cov),
- np.isnan(sess.run(cov, feed_dict=feed_dict)))
+ self.assertEqual(
+ np.isnan(prev_expected_cov),
+ np.isnan(sess.run(cov, feed_dict=feed_dict)))
if not np.isnan(prev_expected_cov):
- self.assertAlmostEqual(
- prev_expected_cov, sess.run(cov, feed_dict=feed_dict), 5)
- expected_cov = np.cov(predictions[:stride * (i + 1)],
- labels[:stride * (i + 1)],
- fweights=weights[:stride * (i + 1)])[0, 1]
- self.assertAlmostEqual(
- expected_cov, sess.run(update_op, feed_dict=feed_dict), 5)
- self.assertAlmostEqual(
- expected_cov, sess.run(cov, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(prev_expected_cov,
+ sess.run(cov, feed_dict=feed_dict), 5)
+ expected_cov = np.cov(
+ predictions[:stride * (i + 1)],
+ labels[:stride * (i + 1)],
+ fweights=weights[:stride * (i + 1)])[0, 1]
+ self.assertAlmostEqual(expected_cov,
+ sess.run(update_op, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_cov, sess.run(cov, feed_dict=feed_dict),
+ 5)
prev_expected_cov = expected_cov
@@ -5574,8 +5503,8 @@ class StreamingPearsonRTest(test.TestCase):
def testVars(self):
metrics.streaming_pearson_correlation(
- predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones(
- [10, 10]),
+ predictions=math_ops.to_float(math_ops.range(10)) +
+ array_ops.ones([10, 10]),
labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]))
_assert_metric_variables(self, (
'pearson_r/covariance/comoment:0',
@@ -5595,8 +5524,8 @@ class StreamingPearsonRTest(test.TestCase):
def testMetricsCollection(self):
my_collection_name = '__metrics__'
pearson_r, _ = metrics.streaming_pearson_correlation(
- predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones(
- [10, 10]),
+ predictions=math_ops.to_float(math_ops.range(10)) +
+ array_ops.ones([10, 10]),
labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]),
metrics_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [pearson_r])
@@ -5604,8 +5533,8 @@ class StreamingPearsonRTest(test.TestCase):
def testUpdatesCollection(self):
my_collection_name = '__updates__'
_, update_op = metrics.streaming_pearson_correlation(
- predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones(
- [10, 10]),
+ predictions=math_ops.to_float(math_ops.range(10)) +
+ array_ops.ones([10, 10]),
labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10]),
updates_collections=[my_collection_name])
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
@@ -5613,8 +5542,8 @@ class StreamingPearsonRTest(test.TestCase):
def testValueTensorIsIdempotent(self):
labels = random_ops.random_normal((10, 3), seed=2)
predictions = labels * 0.5 + random_ops.random_normal((10, 3), seed=1) * 0.5
- pearson_r, update_op = metrics.streaming_pearson_correlation(predictions,
- labels)
+ pearson_r, update_op = metrics.streaming_pearson_correlation(
+ predictions, labels)
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
@@ -5633,8 +5562,8 @@ class StreamingPearsonRTest(test.TestCase):
predictions = math_ops.to_float(math_ops.range(10))
labels = math_ops.to_float(math_ops.range(10))
- pearson_r, update_op = metrics.streaming_pearson_correlation(predictions,
- labels)
+ pearson_r, update_op = metrics.streaming_pearson_correlation(
+ predictions, labels)
expected_r = np.corrcoef(np.arange(10), np.arange(10))[0, 1]
sess.run(variables.local_variables_initializer())
@@ -5648,8 +5577,8 @@ class StreamingPearsonRTest(test.TestCase):
labels = constant_op.constant(
[1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32)
- pearson_r, update_op = metrics.streaming_pearson_correlation(predictions,
- labels)
+ pearson_r, update_op = metrics.streaming_pearson_correlation(
+ predictions, labels)
expected_r = np.corrcoef([2, 4, 6], [1, 3, 2])[0, 1]
sess.run(variables.local_variables_initializer())
@@ -5698,17 +5627,18 @@ class StreamingPearsonRTest(test.TestCase):
predictions_t: predictions[stride * i:stride * (i + 1)],
labels_t: labels[stride * i:stride * (i + 1)]
}
- self.assertEqual(np.isnan(prev_expected_r),
- np.isnan(sess.run(pearson_r, feed_dict=feed_dict)))
+ self.assertEqual(
+ np.isnan(prev_expected_r),
+ np.isnan(sess.run(pearson_r, feed_dict=feed_dict)))
if not np.isnan(prev_expected_r):
- self.assertAlmostEqual(
- prev_expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(prev_expected_r,
+ sess.run(pearson_r, feed_dict=feed_dict), 5)
expected_r = np.corrcoef(predictions[:stride * (i + 1)],
labels[:stride * (i + 1)])[0, 1]
- self.assertAlmostEqual(
- expected_r, sess.run(update_op, feed_dict=feed_dict), 5)
- self.assertAlmostEqual(
- expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_r,
+ sess.run(update_op, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_r,
+ sess.run(pearson_r, feed_dict=feed_dict), 5)
prev_expected_r = expected_r
def testMultiUpdateWithErrorAndWeights(self):
@@ -5736,19 +5666,21 @@ class StreamingPearsonRTest(test.TestCase):
labels_t: labels[stride * i:stride * (i + 1)],
weights_t: weights[stride * i:stride * (i + 1)]
}
- self.assertEqual(np.isnan(prev_expected_r),
- np.isnan(sess.run(pearson_r, feed_dict=feed_dict)))
+ self.assertEqual(
+ np.isnan(prev_expected_r),
+ np.isnan(sess.run(pearson_r, feed_dict=feed_dict)))
if not np.isnan(prev_expected_r):
- self.assertAlmostEqual(
- prev_expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5)
- cmat = np.cov(predictions[:stride * (i + 1)],
- labels[:stride * (i + 1)],
- fweights=weights[:stride * (i + 1)])
+ self.assertAlmostEqual(prev_expected_r,
+ sess.run(pearson_r, feed_dict=feed_dict), 5)
+ cmat = np.cov(
+ predictions[:stride * (i + 1)],
+ labels[:stride * (i + 1)],
+ fweights=weights[:stride * (i + 1)])
expected_r = cmat[0, 1] / np.sqrt(cmat[0, 0] * cmat[1, 1])
- self.assertAlmostEqual(
- expected_r, sess.run(update_op, feed_dict=feed_dict), 5)
- self.assertAlmostEqual(
- expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_r,
+ sess.run(update_op, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_r,
+ sess.run(pearson_r, feed_dict=feed_dict), 5)
prev_expected_r = expected_r
def testMultiUpdateWithErrorAndSingletonBatches(self):
@@ -5758,7 +5690,7 @@ class StreamingPearsonRTest(test.TestCase):
predictions = np.random.randn(n)
labels = 0.5 * predictions + np.random.randn(n)
stride = 10
- weights = (np.arange(n).reshape(n//stride, stride) % stride == 0)
+ weights = (np.arange(n).reshape(n // stride, stride) % stride == 0)
for row in weights:
np.random.shuffle(row)
# Now, weights is one-hot by row - one item per batch has non-zero weight.
@@ -5778,19 +5710,20 @@ class StreamingPearsonRTest(test.TestCase):
labels_t: labels[stride * i:stride * (i + 1)],
weights_t: weights[stride * i:stride * (i + 1)]
}
- cmat = np.cov(predictions[:stride * (i + 1)],
- labels[:stride * (i + 1)],
- fweights=weights[:stride * (i + 1)])
+ cmat = np.cov(
+ predictions[:stride * (i + 1)],
+ labels[:stride * (i + 1)],
+ fweights=weights[:stride * (i + 1)])
expected_r = cmat[0, 1] / np.sqrt(cmat[0, 0] * cmat[1, 1])
actual_r = sess.run(update_op, feed_dict=feed_dict)
self.assertEqual(np.isnan(expected_r), np.isnan(actual_r))
- self.assertEqual(np.isnan(expected_r),
- np.isnan(sess.run(pearson_r, feed_dict=feed_dict)))
+ self.assertEqual(
+ np.isnan(expected_r),
+ np.isnan(sess.run(pearson_r, feed_dict=feed_dict)))
if not np.isnan(expected_r):
- self.assertAlmostEqual(
- expected_r, actual_r, 5)
- self.assertAlmostEqual(
- expected_r, sess.run(pearson_r, feed_dict=feed_dict), 5)
+ self.assertAlmostEqual(expected_r, actual_r, 5)
+ self.assertAlmostEqual(expected_r,
+ sess.run(pearson_r, feed_dict=feed_dict), 5)
class StreamingMeanCosineDistanceTest(test.TestCase):
@@ -6191,20 +6124,14 @@ class StreamingMeanIOUTest(test.TestCase):
self.assertAlmostEqual(desired_output, miou.eval())
def testUpdateOpEvalIsAccumulatedConfusionMatrix(self):
- predictions = array_ops.concat(
- [
- constant_op.constant(
- 0, shape=[5]), constant_op.constant(
- 1, shape=[5])
- ],
- 0)
- labels = array_ops.concat(
- [
- constant_op.constant(
- 0, shape=[3]), constant_op.constant(
- 1, shape=[7])
- ],
- 0)
+ predictions = array_ops.concat([
+ constant_op.constant(0, shape=[5]),
+ constant_op.constant(1, shape=[5])
+ ], 0)
+ labels = array_ops.concat([
+ constant_op.constant(0, shape=[3]),
+ constant_op.constant(1, shape=[7])
+ ], 0)
num_classes = 2
with self.test_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
@@ -6238,29 +6165,20 @@ class StreamingMeanIOUTest(test.TestCase):
self.assertEqual(0., miou.eval())
def testResultsWithSomeMissing(self):
- predictions = array_ops.concat(
- [
- constant_op.constant(
- 0, shape=[5]), constant_op.constant(
- 1, shape=[5])
- ],
- 0)
- labels = array_ops.concat(
- [
- constant_op.constant(
- 0, shape=[3]), constant_op.constant(
- 1, shape=[7])
- ],
- 0)
+ predictions = array_ops.concat([
+ constant_op.constant(0, shape=[5]),
+ constant_op.constant(1, shape=[5])
+ ], 0)
+ labels = array_ops.concat([
+ constant_op.constant(0, shape=[3]),
+ constant_op.constant(1, shape=[7])
+ ], 0)
num_classes = 2
- weights = array_ops.concat(
- [
- constant_op.constant(
- 0, shape=[1]), constant_op.constant(
- 1, shape=[8]), constant_op.constant(
- 0, shape=[1])
- ],
- 0)
+ weights = array_ops.concat([
+ constant_op.constant(0, shape=[1]),
+ constant_op.constant(1, shape=[8]),
+ constant_op.constant(0, shape=[1])
+ ], 0)
with self.test_session() as sess:
miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes, weights=weights)
@@ -6270,56 +6188,45 @@ class StreamingMeanIOUTest(test.TestCase):
self.assertAlmostEqual(desired_miou, miou.eval())
def testMissingClassInLabels(self):
- labels = constant_op.constant([
- [[0, 0, 1, 1, 0, 0],
- [1, 0, 0, 0, 0, 1]],
- [[1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0]]])
- predictions = constant_op.constant([
- [[0, 0, 2, 1, 1, 0],
- [0, 1, 2, 2, 0, 1]],
- [[0, 0, 2, 1, 1, 1],
- [1, 1, 2, 0, 0, 0]]])
+ labels = constant_op.constant([[[0, 0, 1, 1, 0, 0], [1, 0, 0, 0, 0, 1]],
+ [[1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0]]])
+ predictions = constant_op.constant(
+ [[[0, 0, 2, 1, 1, 0], [0, 1, 2, 2, 0, 1]], [[0, 0, 2, 1, 1, 1],
+ [1, 1, 2, 0, 0, 0]]])
num_classes = 3
with self.test_session() as sess:
- miou, update_op = metrics.streaming_mean_iou(
- predictions, labels, num_classes)
+ miou, update_op = metrics.streaming_mean_iou(predictions, labels,
+ num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op.eval())
- self.assertAlmostEqual(
- 1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / (0 + 5 + 0)),
- miou.eval())
+ self.assertAlmostEqual(1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 /
+ (0 + 5 + 0)), miou.eval())
def testMissingClassOverallSmall(self):
labels = constant_op.constant([0])
predictions = constant_op.constant([0])
num_classes = 2
with self.test_session() as sess:
- miou, update_op = metrics.streaming_mean_iou(
- predictions, labels, num_classes)
+ miou, update_op = metrics.streaming_mean_iou(predictions, labels,
+ num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[1, 0], [0, 0]], update_op.eval())
self.assertAlmostEqual(1, miou.eval())
def testMissingClassOverallLarge(self):
- labels = constant_op.constant([
- [[0, 0, 1, 1, 0, 0],
- [1, 0, 0, 0, 0, 1]],
- [[1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0]]])
- predictions = constant_op.constant([
- [[0, 0, 1, 1, 0, 0],
- [1, 1, 0, 0, 1, 1]],
- [[0, 0, 0, 1, 1, 1],
- [1, 1, 1, 0, 0, 0]]])
+ labels = constant_op.constant([[[0, 0, 1, 1, 0, 0], [1, 0, 0, 0, 0, 1]],
+ [[1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0]]])
+ predictions = constant_op.constant(
+ [[[0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1]], [[0, 0, 0, 1, 1, 1],
+ [1, 1, 1, 0, 0, 0]]])
num_classes = 3
with self.test_session() as sess:
- miou, update_op = metrics.streaming_mean_iou(
- predictions, labels, num_classes)
+ miou, update_op = metrics.streaming_mean_iou(predictions, labels,
+ num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op.eval())
- self.assertAlmostEqual(
- 1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)), miou.eval())
+ self.assertAlmostEqual(1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)),
+ miou.eval())
class StreamingConcatTest(test.TestCase):
@@ -6683,7 +6590,8 @@ class CohenKappaTest(test.TestCase):
_assert_metric_variables(self, (
'cohen_kappa/po:0',
'cohen_kappa/pe_row:0',
- 'cohen_kappa/pe_col:0',))
+ 'cohen_kappa/pe_col:0',
+ ))
def testMetricsCollection(self):
my_collection_name = '__metrics__'
@@ -6705,9 +6613,9 @@ class CohenKappaTest(test.TestCase):
def testValueTensorIsIdempotent(self):
predictions = random_ops.random_uniform(
- (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=1)
+ (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
- (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2)
+ (10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 3)
with self.test_session() as sess:
@@ -6723,10 +6631,7 @@ class CohenKappaTest(test.TestCase):
self.assertAlmostEqual(initial_kappa, kappa.eval(), 5)
def testBasic(self):
- confusion_matrix = np.array([
- [9, 3, 1],
- [4, 8, 2],
- [2, 1, 6]])
+ confusion_matrix = np.array([[9, 3, 1], [4, 8, 2], [2, 1, 6]])
# overall total = 36
# po = [9, 8, 6], sum(po) = 23
# pe_row = [15, 12, 9], pe_col = [13, 14, 9], so pe = [5.42, 4.67, 2.25]
@@ -6738,8 +6643,10 @@ class CohenKappaTest(test.TestCase):
labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
dtypes = [dtypes_lib.int16, dtypes_lib.int32, dtypes_lib.int64]
- shapes = [(len(labels,)), # 1-dim
- (len(labels), 1)] # 2-dim
+ shapes = [
+ (len(labels,)), # 1-dim
+ (len(labels), 1)
+ ] # 2-dim
weights = [None, np.ones_like(labels)]
for dtype in dtypes:
@@ -6795,10 +6702,7 @@ class CohenKappaTest(test.TestCase):
self.assertAlmostEqual(expect, kappa.eval(), 5)
def testWeighted(self):
- confusion_matrix = np.array([
- [9, 3, 1],
- [4, 8, 2],
- [2, 1, 6]])
+ confusion_matrix = np.array([[9, 3, 1], [4, 8, 2], [2, 1, 6]])
labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
num_samples = np.sum(confusion_matrix, dtype=np.int32)
weights = (np.arange(0, num_samples) % 5) / 5.0
@@ -6809,31 +6713,26 @@ class CohenKappaTest(test.TestCase):
with self.test_session() as sess:
predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
labels = constant_op.constant(labels)
- kappa, update_op = metrics.cohen_kappa(labels, predictions, 4,
- weights=weights)
+ kappa, update_op = metrics.cohen_kappa(
+ labels, predictions, 4, weights=weights)
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(expect, sess.run(update_op), 5)
self.assertAlmostEqual(expect, kappa.eval(), 5)
def testWithMultipleUpdates(self):
- confusion_matrix = np.array([
- [90, 30, 10, 20],
- [40, 80, 20, 30],
- [20, 10, 60, 35],
- [15, 25, 30, 25]])
+ confusion_matrix = np.array([[90, 30, 10, 20], [40, 80, 20, 30],
+ [20, 10, 60, 35], [15, 25, 30, 25]])
labels, predictions = self._confusion_matrix_to_samples(confusion_matrix)
num_samples = np.sum(confusion_matrix, dtype=np.int32)
weights = (np.arange(0, num_samples) % 5) / 5.0
num_classes = confusion_matrix.shape[0]
batch_size = num_samples // 10
- predictions_t = array_ops.placeholder(dtypes_lib.float32,
- shape=(batch_size,))
- labels_t = array_ops.placeholder(dtypes_lib.int32,
- shape=(batch_size,))
- weights_t = array_ops.placeholder(dtypes_lib.float32,
- shape=(batch_size,))
+ predictions_t = array_ops.placeholder(
+ dtypes_lib.float32, shape=(batch_size,))
+ labels_t = array_ops.placeholder(dtypes_lib.int32, shape=(batch_size,))
+ weights_t = array_ops.placeholder(dtypes_lib.float32, shape=(batch_size,))
kappa, update_op = metrics.cohen_kappa(
labels_t, predictions_t, num_classes, weights=weights_t)
with self.test_session() as sess:
@@ -6841,10 +6740,13 @@ class CohenKappaTest(test.TestCase):
for idx in range(0, num_samples, batch_size):
batch_start, batch_end = idx, idx + batch_size
- sess.run(update_op,
- feed_dict={labels_t: labels[batch_start:batch_end],
- predictions_t: predictions[batch_start:batch_end],
- weights_t: weights[batch_start:batch_end]})
+ sess.run(
+ update_op,
+ feed_dict={
+ labels_t: labels[batch_start:batch_end],
+ predictions_t: predictions[batch_start:batch_end],
+ weights_t: weights[batch_start:batch_end]
+ })
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(
# labels_np, predictions_np, sample_weight=weights_np)
expect = 0.289965397924
@@ -6862,7 +6764,8 @@ class CohenKappaTest(test.TestCase):
with self.assertRaises(ValueError):
metrics.cohen_kappa(invalid_labels, predictions, 3)
- invalid_predictions = array_ops.placeholder(dtypes_lib.float32, shape=(4, 2))
+ invalid_predictions = array_ops.placeholder(
+ dtypes_lib.float32, shape=(4, 2))
labels = array_ops.placeholder(dtypes_lib.int32, shape=(4, 1))
with self.assertRaises(ValueError):
metrics.cohen_kappa(labels, invalid_predictions, 3)
diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py
index 0d1de869f6..73dd56398c 100644
--- a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py
+++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py
@@ -54,10 +54,10 @@ BATCH_SIZE = 128
DATA_DIR = '/tmp/cifar10_data'
# Constants describing the training process.
-MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average.
-NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays.
+MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average.
+NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor.
-INITIAL_LEARNING_RATE = 0.1 # Initial learning rate.
+INITIAL_LEARNING_RATE = 0.1 # Initial learning rate.
# If a model is trained with multiple GPUs, prefix all Op names with tower_name
# to differentiate the operations. Note that this prefix is removed from the
@@ -82,8 +82,7 @@ def _activation_summary(x):
# session. This helps the clarity of presentation on tensorboard.
tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
tf.summary.histogram(tensor_name + '/activations', x)
- tf.summary.scalar(tensor_name + '/sparsity',
- tf.nn.zero_fraction(x))
+ tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x))
def _variable_on_cpu(name, shape, initializer):
@@ -120,10 +119,9 @@ def _variable_with_weight_decay(name, shape, stddev, wd):
Variable Tensor
"""
dtype = tf.float32
- var = _variable_on_cpu(
- name,
- shape,
- tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
+ var = _variable_on_cpu(name, shape,
+ tf.truncated_normal_initializer(
+ stddev=stddev, dtype=dtype))
if wd is not None:
weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
tf.add_to_collection('losses', weight_decay)
@@ -188,10 +186,8 @@ def inference(images):
# Note that the masks are applied only to the weight tensors
# conv1
with tf.variable_scope('conv1') as scope:
- kernel = _variable_with_weight_decay('weights',
- shape=[5, 5, 3, 64],
- stddev=5e-2,
- wd=0.0)
+ kernel = _variable_with_weight_decay(
+ 'weights', shape=[5, 5, 3, 64], stddev=5e-2, wd=0.0)
conv = tf.nn.conv2d(
images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
@@ -201,18 +197,20 @@ def inference(images):
_activation_summary(conv1)
# pool1
- pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
- padding='SAME', name='pool1')
+ pool1 = tf.nn.max_pool(
+ conv1,
+ ksize=[1, 3, 3, 1],
+ strides=[1, 2, 2, 1],
+ padding='SAME',
+ name='pool1')
# norm1
- norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
- name='norm1')
+ norm1 = tf.nn.lrn(
+ pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')
# conv2
with tf.variable_scope('conv2') as scope:
- kernel = _variable_with_weight_decay('weights',
- shape=[5, 5, 64, 64],
- stddev=5e-2,
- wd=0.0)
+ kernel = _variable_with_weight_decay(
+ 'weights', shape=[5, 5, 64, 64], stddev=5e-2, wd=0.0)
conv = tf.nn.conv2d(
norm1, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME')
biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
@@ -221,19 +219,23 @@ def inference(images):
_activation_summary(conv2)
# norm2
- norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
- name='norm2')
+ norm2 = tf.nn.lrn(
+ conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
# pool2
- pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
- strides=[1, 2, 2, 1], padding='SAME', name='pool2')
+ pool2 = tf.nn.max_pool(
+ norm2,
+ ksize=[1, 3, 3, 1],
+ strides=[1, 2, 2, 1],
+ padding='SAME',
+ name='pool2')
# local3
with tf.variable_scope('local3') as scope:
# Move everything into depth so we can perform a single matrix multiply.
reshape = tf.reshape(pool2, [BATCH_SIZE, -1])
dim = reshape.get_shape()[1].value
- weights = _variable_with_weight_decay('weights', shape=[dim, 384],
- stddev=0.04, wd=0.004)
+ weights = _variable_with_weight_decay(
+ 'weights', shape=[dim, 384], stddev=0.04, wd=0.004)
biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
local3 = tf.nn.relu(
tf.matmul(reshape, pruning.apply_mask(weights, scope)) + biases,
@@ -242,8 +244,8 @@ def inference(images):
# local4
with tf.variable_scope('local4') as scope:
- weights = _variable_with_weight_decay('weights', shape=[384, 192],
- stddev=0.04, wd=0.004)
+ weights = _variable_with_weight_decay(
+ 'weights', shape=[384, 192], stddev=0.04, wd=0.004)
biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
local4 = tf.nn.relu(
tf.matmul(local3, pruning.apply_mask(weights, scope)) + biases,
@@ -255,8 +257,8 @@ def inference(images):
# tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits
# and performs the softmax internally for efficiency.
with tf.variable_scope('softmax_linear') as scope:
- weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES],
- stddev=1/192.0, wd=0.0)
+ weights = _variable_with_weight_decay(
+ 'weights', [192, NUM_CLASSES], stddev=1 / 192.0, wd=0.0)
biases = _variable_on_cpu('biases', [NUM_CLASSES],
tf.constant_initializer(0.0))
softmax_linear = tf.add(
@@ -337,11 +339,12 @@ def train(total_loss, global_step):
decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
# Decay the learning rate exponentially based on the number of steps.
- lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
- global_step,
- decay_steps,
- LEARNING_RATE_DECAY_FACTOR,
- staircase=True)
+ lr = tf.train.exponential_decay(
+ INITIAL_LEARNING_RATE,
+ global_step,
+ decay_steps,
+ LEARNING_RATE_DECAY_FACTOR,
+ staircase=True)
tf.summary.scalar('learning_rate', lr)
# Generate moving averages of all losses and associated summaries.
@@ -365,8 +368,8 @@ def train(total_loss, global_step):
tf.summary.histogram(var.op.name + '/gradients', grad)
# Track the moving averages of all trainable variables.
- variable_averages = tf.train.ExponentialMovingAverage(
- MOVING_AVERAGE_DECAY, global_step)
+ variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,
+ global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
@@ -383,10 +386,13 @@ def maybe_download_and_extract():
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
if not os.path.exists(filepath):
+
def _progress(count, block_size, total_size):
- sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
- float(count * block_size) / float(total_size) * 100.0))
+ sys.stdout.write('\r>> Downloading %s %.1f%%' %
+ (filename,
+ float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
+
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
index 6132cba1f5..716ee9cdf7 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Wrapper optimizer for Elastic Average SGD """
from __future__ import absolute_import
from __future__ import division
@@ -78,23 +77,24 @@ class ElasticAverageCustomGetter(object):
def __call__(self, getter, name, trainable, collections, *args, **kwargs):
if trainable:
with ops.device(self._worker_device):
- local_var = getter(name, trainable=True,
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- *args, **kwargs)
+ local_var = getter(
+ name,
+ trainable=True,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ *args,
+ **kwargs)
global_center_variable = variable_scope.variable(
- name='%s/%s' %
- (GLOBAL_VARIABLE_NAME,
- name),
- initial_value=local_var.initialized_value(),
- trainable=False,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ name='%s/%s' % (GLOBAL_VARIABLE_NAME, name),
+ initial_value=local_var.initialized_value(),
+ trainable=False,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES])
with ops.device(self._worker_device):
local_center_variable = variable_scope.variable(
- name='%s/%s' % (LOCAL_VARIABLE_NAME, name),
- initial_value=local_var.initialized_value(),
- trainable=False,
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ name='%s/%s' % (LOCAL_VARIABLE_NAME, name),
+ initial_value=local_var.initialized_value(),
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
self._local_map[local_var] = local_center_variable
self._global_map[local_var] = global_center_variable
@@ -117,16 +117,15 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
# Default value as paper described
BETA = 0.9
- def __init__(
- self,
- opt,
- num_worker,
- ea_custom_getter,
- communication_period=10,
- moving_rate=None,
- rho=None,
- use_locking=True,
- name="ElasticAverageOptimizer"):
+ def __init__(self,
+ opt,
+ num_worker,
+ ea_custom_getter,
+ communication_period=10,
+ moving_rate=None,
+ rho=None,
+ use_locking=True,
+ name='ElasticAverageOptimizer'):
"""Construct a new gradient descent optimizer.
Args:
@@ -160,13 +159,15 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
self._rho = rho
self._local_step = variable_scope.get_variable(
- initializer=0,
- trainable=False,
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- name="local_step")
+ initializer=0,
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ name='local_step')
self._opt._prepare()
- def compute_gradients(self, loss, var_list=None,
+ def compute_gradients(self,
+ loss,
+ var_list=None,
gate_gradients=optimizer.Optimizer.GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
@@ -204,16 +205,18 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
if not var_list:
var_list = variables.trainable_variables()
- elastic_difference = [math_ops.subtract(v, lv) for v, lv in zip(
- variables.trainable_variables(),
- [self._local_map[var] for var in var_list])]
+ elastic_difference = [
+ math_ops.subtract(v, lv)
+ for v, lv in zip(variables.trainable_variables(),
+ [self._local_map[var] for var in var_list])
+ ]
distance_loss = self._rho * math_ops.add_n(
- [gen_nn_ops.l2_loss(ed) for ed in elastic_difference])
+ [gen_nn_ops.l2_loss(ed) for ed in elastic_difference])
total_loss = loss + distance_loss
- return self._opt.compute_gradients(total_loss, var_list,
- gate_gradients, aggregation_method,
+ return self._opt.compute_gradients(total_loss, var_list, gate_gradients,
+ aggregation_method,
colocate_gradients_with_ops, grad_loss)
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
@@ -241,7 +244,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
apply_updates = self._opt.apply_gradients(grads_and_vars)
with ops.control_dependencies([apply_updates]):
local_update = state_ops.assign_add(
- self._local_step, 1, name='local_step_update').op
+ self._local_step, 1, name='local_step_update').op
# update global variables.
def _Update_global_variables():
@@ -259,12 +262,16 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
differences.append(math_ops.subtract(v, lv))
for lvar, diff in zip(local_vars, differences):
with ops.device(lvar.device):
- update_ops.append(state_ops.assign_sub(lvar, math_ops.multiply(
- self._moving_rate, diff)))
+ update_ops.append(
+ state_ops.assign_sub(lvar,
+ math_ops.multiply(self._moving_rate,
+ diff)))
for var, diff in zip(global_center_vars, differences):
with ops.device(var.device):
- update_ops.append(state_ops.assign_add(var, math_ops.multiply(
- self._moving_rate, diff)))
+ update_ops.append(
+ state_ops.assign_add(var,
+ math_ops.multiply(self._moving_rate,
+ diff)))
if global_step:
with ops.colocate_with(global_step):
update_ops.append(state_ops.assign_add(global_step, 1))
@@ -272,10 +279,10 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
return variable_update
with ops.control_dependencies([local_update]):
- condition = math_ops.equal(math_ops.mod(
- self._local_step, self._period), 0)
+ condition = math_ops.equal(
+ math_ops.mod(self._local_step, self._period), 0)
conditional_update = control_flow_ops.cond(
- condition, _Update_global_variables, control_flow_ops.no_op)
+ condition, _Update_global_variables, control_flow_ops.no_op)
return conditional_update
def get_init_op(self, task_index):
@@ -285,10 +292,12 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
def _Add_sync_queues_and_barrier(enqueue_after_list):
"""Adds ops to enqueu on all worker queues"""
sync_queues = [
- data_flow_ops.FIFOQueue(self._num_worker, [dtypes.bool], shapes=[[]],
- shared_name='%s%s' % (
- 'variable_init_sync_queue', i)) for i in
- range(self._num_worker)]
+ data_flow_ops.FIFOQueue(
+ self._num_worker, [dtypes.bool],
+ shapes=[[]],
+ shared_name='%s%s' % ('variable_init_sync_queue', i))
+ for i in range(self._num_worker)
+ ]
queue_ops = []
# For each other worker, add an entry in a queue
token = constant_op.constant(False)
@@ -299,7 +308,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
else:
queue_ops.append(q.enqueue(token))
queue_ops.append(
- sync_queues[task_index].dequeue_many(len(sync_queues) - 1))
+ sync_queues[task_index].dequeue_many(len(sync_queues) - 1))
return control_flow_ops.group(*queue_ops)
init_ops = []
@@ -307,11 +316,10 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
global_center_vars = [self._global_map[var] for var in local_vars]
local_center_vars = [self._local_map[var] for var in local_vars]
if not (local_vars and global_center_vars and local_center_vars):
- raise ValueError(
- 'The lists of local_variables, global_center_variables, '
- 'local_center_variables should not be empty ')
- for lvar, gc_var, lc_var in zip(
- local_vars, global_center_vars, local_center_vars):
+ raise ValueError('The lists of local_variables, global_center_variables, '
+ 'local_center_variables should not be empty ')
+ for lvar, gc_var, lc_var in zip(local_vars, global_center_vars,
+ local_center_vars):
init_ops.append(state_ops.assign(lvar, gc_var))
init_ops.append(state_ops.assign(lc_var, gc_var))
@@ -325,6 +333,7 @@ class ElasticAverageOptimizer(optimizer.Optimizer):
class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):
+
def __init__(self, ea_optimizer, is_chief, task_index):
"""Creates hook to handle ElasticAverageOptimizer initialization ops.
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
index 446e91018d..37539b9599 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
@@ -38,20 +38,20 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc"):
worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
cluster_dict = {
- "worker": ["localhost:%s" % port for port in worker_ports],
- "ps": ["localhost:%s" % port for port in ps_ports]
+ "worker": ["localhost:%s" % port for port in worker_ports],
+ "ps": ["localhost:%s" % port for port in ps_ports]
}
cs = server_lib.ClusterSpec(cluster_dict)
workers = [
- server_lib.Server(
- cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
- for ix in range(num_workers)
+ server_lib.Server(
+ cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_workers)
]
ps_servers = [
- server_lib.Server(
- cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
- for ix in range(num_ps)
+ server_lib.Server(
+ cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_ps)
]
return cluster_dict, workers, ps_servers
@@ -68,15 +68,14 @@ def _get_workers(num_workers, period, workers, moving_rate):
is_chief = (worker_id == 0)
with graph.as_default():
worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
- ea_coustom = ElasticAverageCustomGetter(
- worker_device=worker_device)
- with variable_scope.variable_scope('',
- custom_getter=ea_coustom), ops.device(
- device_setter.replica_device_setter(worker_device=worker_device,
- ps_device="/job:ps/task:0/cpu:0",
- ps_tasks=1)):
- global_step = variables.Variable(0, name='global_step',
- trainable=False)
+ ea_coustom = ElasticAverageCustomGetter(worker_device=worker_device)
+ with variable_scope.variable_scope(
+ "", custom_getter=ea_coustom), ops.device(
+ device_setter.replica_device_setter(
+ worker_device=worker_device,
+ ps_device="/job:ps/task:0/cpu:0",
+ ps_tasks=1)):
+ global_step = variables.Variable(0, name="global_step", trainable=False)
var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
@@ -86,21 +85,19 @@ def _get_workers(num_workers, period, workers, moving_rate):
sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
opt = ElasticAverageOptimizer(
- opt=sgd_opt,
- num_worker=num_workers,
- moving_rate=moving_rate,
- communication_period=period,
- ea_custom_getter=ea_coustom
- )
+ opt=sgd_opt,
+ num_worker=num_workers,
+ moving_rate=moving_rate,
+ communication_period=period,
+ ea_custom_getter=ea_coustom)
train_op = [
- opt.apply_gradients(
- ([grads_0, var_0],
- [grads_1, var_1]), global_step)
+ opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
+ global_step)
]
easgd_hook = opt.make_session_run_hook(is_chief, worker_id)
# Creates MonitoredSession
- sess = training.MonitoredTrainingSession(workers[worker_id].target,
- hooks=[easgd_hook])
+ sess = training.MonitoredTrainingSession(
+ workers[worker_id].target, hooks=[easgd_hook])
sessions.append(sess)
graphs.append(graph)
@@ -110,6 +107,7 @@ def _get_workers(num_workers, period, workers, moving_rate):
class ElasticAverageOptimizerTest(test.TestCase):
+
def _run(self, train_op, sess):
sess.run(train_op)
@@ -117,15 +115,14 @@ class ElasticAverageOptimizerTest(test.TestCase):
num_workers = 1
communication_period = 2
num_ps = 1
- cluster, workers, _ = create_local_cluster(num_workers=num_workers,
- num_ps=num_ps)
+ cluster, workers, _ = create_local_cluster(
+ num_workers=num_workers, num_ps=num_ps)
- sessions, graphs, train_ops = _get_workers(num_workers,
- communication_period,
- workers, 1.0)
+ sessions, graphs, train_ops = _get_workers(
+ num_workers, communication_period, workers, 1.0)
- var_0 = graphs[0].get_tensor_by_name('v0:0')
- var_1 = graphs[0].get_tensor_by_name('v1:0')
+ var_0 = graphs[0].get_tensor_by_name("v0:0")
+ var_1 = graphs[0].get_tensor_by_name("v1:0")
global_step = training_util.get_global_step(graphs[0])
var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")
var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0")
@@ -166,18 +163,17 @@ class ElasticAverageOptimizerTest(test.TestCase):
num_workers = 2
communication_period = 1
num_ps = 2
- cluster, workers, _ = create_local_cluster(num_workers=num_workers,
- num_ps=num_ps)
+ cluster, workers, _ = create_local_cluster(
+ num_workers=num_workers, num_ps=num_ps)
- sessions, graphs, train_ops = _get_workers(num_workers,
- communication_period,
- workers, 0.5)
+ sessions, graphs, train_ops = _get_workers(
+ num_workers, communication_period, workers, 0.5)
- var_0 = graphs[0].get_tensor_by_name('v0:0')
- var_1 = graphs[0].get_tensor_by_name('v1:0')
+ var_0 = graphs[0].get_tensor_by_name("v0:0")
+ var_1 = graphs[0].get_tensor_by_name("v1:0")
- var_0_1 = graphs[1].get_tensor_by_name('v0:0')
- var_1_1 = graphs[1].get_tensor_by_name('v1:0')
+ var_0_1 = graphs[1].get_tensor_by_name("v0:0")
+ var_1_1 = graphs[1].get_tensor_by_name("v1:0")
var_0_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v0:0")
var_1_g = graphs[0].get_tensor_by_name(GLOBAL_VARIABLE_NAME + "/v1:0")
@@ -201,25 +197,24 @@ class ElasticAverageOptimizerTest(test.TestCase):
def testPS2TasksWithClusterSpecClass(self):
cluster_spec = server_lib.ClusterSpec({
- "ps": ["ps0:2222", "ps1:2222"],
- "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
+ "ps": ["ps0:2222", "ps1:2222"],
+ "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
})
- ea_coustom = ElasticAverageCustomGetter(
- worker_device="/job:worker/task:0")
+ ea_coustom = ElasticAverageCustomGetter(worker_device="/job:worker/task:0")
from tensorflow.python.training import device_setter
with ops.device(
device_setter.replica_device_setter(cluster=cluster_spec,
worker_device="/job:worker/task:0",
ps_device="/job:ps")), \
- variable_scope.variable_scope('', custom_getter=ea_coustom):
+ variable_scope.variable_scope("", custom_getter=ea_coustom):
v = variable_scope.get_variable(initializer=[1, 2], name="v")
- w = variable_scope.get_variable(initializer=[2, 1], name='w')
- v_g, w_g = ea_coustom._global_map[v],ea_coustom._global_map[w]
+ w = variable_scope.get_variable(initializer=[2, 1], name="w")
+ v_g, w_g = ea_coustom._global_map[v], ea_coustom._global_map[w]
self.assertDeviceEqual("/job:worker/task:0", v.device)
self.assertDeviceEqual("job:ps/task:0", v_g.device)
self.assertDeviceEqual("/job:worker/task:0", w.device)
self.assertDeviceEqual("job:ps/task:1", w_g.device)
-if __name__ == '__main__':
+if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/predictor/predictor_factories_test.py b/tensorflow/contrib/predictor/predictor_factories_test.py
index e8443e718d..578d9424b2 100644
--- a/tensorflow/contrib/predictor/predictor_factories_test.py
+++ b/tensorflow/contrib/predictor/predictor_factories_test.py
@@ -50,8 +50,8 @@ class PredictorFactoriesTest(test.TestCase):
def testFromContribEstimator(self):
estimator = testing_common.get_arithmetic_estimator(core=False)
input_fn = testing_common.get_arithmetic_input_fn(core=False)
- predictor_factories.from_contrib_estimator(estimator, input_fn,
- output_alternative_key='sum')
+ predictor_factories.from_contrib_estimator(
+ estimator, input_fn, output_alternative_key='sum')
def testFromContribEstimatorWithCoreEstimatorRaises(self):
estimator = testing_common.get_arithmetic_estimator(core=True)
diff --git a/tensorflow/contrib/py2tf/converters/break_canonicalization.py b/tensorflow/contrib/py2tf/converters/break_canonicalization.py
index ef58573445..2ae65e3007 100644
--- a/tensorflow/contrib/py2tf/converters/break_canonicalization.py
+++ b/tensorflow/contrib/py2tf/converters/break_canonicalization.py
@@ -33,31 +33,25 @@ class BreakCanonicalizationTransformer(gast.NodeTransformer):
self.break_uses = []
def _create_break_check(self):
-
- def template(var_name):
- (not var_name) # pylint:disable=pointless-statement
-
- expr, = templates.replace(
- template, var_name=gast.Name(self.break_uses[-1][1], None, None))
+ template = """
+ (not var_name)
+ """
+ expr, = templates.replace(template, var_name=self.break_uses[-1][1])
return expr.value
def _create_break_trigger(self):
-
- def template(var_name): # pylint:disable=unused-argument
+ template = """
var_name = True
-
- block = templates.replace(
- template, var_name=gast.Name(self.break_uses[-1][1], None, None))
+ """
+ block = templates.replace(template, var_name=self.break_uses[-1][1])
block.append(gast.Continue())
return block
def _create_break_init(self):
-
- def template(var_name): # pylint:disable=unused-argument
+ template = """
var_name = False
-
- assign, = templates.replace(
- template, var_name=gast.Name(self.break_uses[-1][1], None, None))
+ """
+ assign, = templates.replace(template, var_name=self.break_uses[-1][1])
return assign
# TODO(mdan): Surely the transformer supports this better?
diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions.py b/tensorflow/contrib/py2tf/converters/builtin_functions.py
index b80c96c97a..7f6b64a34c 100644
--- a/tensorflow/contrib/py2tf/converters/builtin_functions.py
+++ b/tensorflow/contrib/py2tf/converters/builtin_functions.py
@@ -29,10 +29,9 @@ class BuiltinFunctionTransformer(gast.NodeTransformer):
# TODO(mdan): Bring print_functions in here.
def _convert_len(self, node):
-
- def template(args):
- tf.shape(args)[0] # pylint:disable=undefined-variable,expression-not-assigned
-
+ template = """
+ tf.shape(args)[0]
+ """
new_call = templates.replace(template, args=node.args)[0].value
return new_call
diff --git a/tensorflow/contrib/py2tf/converters/call_trees.py b/tensorflow/contrib/py2tf/converters/call_trees.py
index df071f596f..0aae030450 100644
--- a/tensorflow/contrib/py2tf/converters/call_trees.py
+++ b/tensorflow/contrib/py2tf/converters/call_trees.py
@@ -151,7 +151,7 @@ class CallTreeTransformer(gast.NodeTransformer):
else:
new_name = self.namer.compiled_function_name(
'__'.join(target_fqn), live_object=target_obj)
- node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None)
+ node.func = gast.Name(new_name, gast.Load(), None)
return node
def _rename_member_function_of_known_type(self, node):
@@ -184,26 +184,17 @@ class CallTreeTransformer(gast.NodeTransformer):
def _wrap_to_py_func_no_return(self, node):
args_scope = anno.getanno(node, 'args_scope')
# TODO(mdan): Properly handle varargs, kwargs, etc.
- args = tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used)
-
- # pylint:disable=undefined-variable,unused-argument,function-redefined
-
- def template(call, wrapper, args):
-
+ template = """
def wrapper(args):
call(args)
return 1
-
tf.py_func(wrapper, [args], [tf.int64])
-
- # pylint:enable=undefined-variable,unused-argument,function-redefined
-
- wrapper_name = self.namer.compiled_function_name(node.func.id)
+ """
wrapper_def, call_expr = templates.replace(
template,
call=node.func,
- wrapper=gast.Name(wrapper_name, gast.Load(), None),
- args=args)
+ wrapper=self.namer.compiled_function_name(node.func.id),
+ args=tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used))
anno.setanno(call_expr.value, 'args_scope', args_scope)
# TODO(mdan): Rename this annotation to 'graph_ready'
anno.setanno(wrapper_def, 'skip_processing', True)
diff --git a/tensorflow/contrib/py2tf/converters/continue_canonicalization.py b/tensorflow/contrib/py2tf/converters/continue_canonicalization.py
index 7f8ace77a8..486f0f6509 100644
--- a/tensorflow/contrib/py2tf/converters/continue_canonicalization.py
+++ b/tensorflow/contrib/py2tf/converters/continue_canonicalization.py
@@ -33,32 +33,28 @@ class ContinueCanonicalizationTransformer(gast.NodeTransformer):
self.continuation_uses = []
def _create_continuation_check(self):
-
- def template(var_name):
+ template = """
if not var_name:
pass
-
- cond, = templates.replace(
- template, var_name=gast.Name(self.continuation_uses[-1][1], None, None))
+ """
+ cond, = templates.replace(template, var_name=self.continuation_uses[-1][1])
cond.body = []
return cond
def _create_continuation_trigger(self):
-
- def template(var_name): # pylint:disable=unused-argument
+ template = """
var_name = True
-
+ """
assign, = templates.replace(
- template, var_name=gast.Name(self.continuation_uses[-1][1], None, None))
+ template, var_name=self.continuation_uses[-1][1])
return assign
def _create_continuation_init(self):
-
- def template(var_name): # pylint:disable=unused-argument
+ template = """
var_name = False
-
+ """
assign, = templates.replace(
- template, var_name=gast.Name(self.continuation_uses[-1][1], None, None))
+ template, var_name=self.continuation_uses[-1][1])
return assign
def _visit_and_reindent_if_necessary(self, nodes):
diff --git a/tensorflow/contrib/py2tf/converters/control_flow.py b/tensorflow/contrib/py2tf/converters/control_flow.py
index 8ebd9ad93d..a40c7b28f7 100644
--- a/tensorflow/contrib/py2tf/converters/control_flow.py
+++ b/tensorflow/contrib/py2tf/converters/control_flow.py
@@ -75,29 +75,6 @@ class ControlFlowTransformer(gast.NodeTransformer):
raise ValueError(
'The else branch creates new symbols that the if branch does not.')
- def template( # pylint:disable=missing-docstring
- test,
- body_name,
- body,
- orelse_name,
- orelse,
- aliased,
- aliases, # pylint:disable=unused-argument
- aliased_results,
- results): # pylint:disable=unused-argument
-
- def body_name(): # pylint:disable=function-redefined
- aliases, = aliased, # pylint:disable=unused-variable
- body # pylint:disable=pointless-statement
- return (aliased_results,)
-
- def orelse_name(): # pylint:disable=function-redefined
- aliases, = aliased, # pylint:disable=unused-variable
- orelse # pylint:disable=pointless-statement
- return (aliased_results,)
-
- results = tf.cond(test, body_name, orelse_name) # pylint:disable=undefined-variable
-
all_modified = tuple(body_scope.modified | orelse_scope.modified)
all_referenced = body_scope.referenced | orelse_scope.referenced
@@ -107,10 +84,10 @@ class ControlFlowTransformer(gast.NodeTransformer):
need_alias = (
(body_scope.modified | orelse_scope.modified) -
(body_scope.created | orelse_scope.created))
- aliased = tuple(need_alias)
- aliases = tuple(
- self.namer.new_symbol(s, all_referenced) for s in aliased)
- alias_map = dict(zip(aliased, aliases))
+ aliased_orig_names = tuple(need_alias)
+ aliased_new_names = tuple(
+ self.namer.new_symbol(s, all_referenced) for s in aliased_orig_names)
+ alias_map = dict(zip(aliased_orig_names, aliased_new_names))
node_body = node.body
node_body = [SymbolRenamer(alias_map).visit(n) for n in node_body]
node_orelse = node.orelse
@@ -122,20 +99,29 @@ class ControlFlowTransformer(gast.NodeTransformer):
results = gast.Tuple(
tuple(gast.Name(s, None, None) for s in all_modified), None)
+ template = """
+ def body_name():
+ aliased_new_names, = aliased_orig_names,
+ body
+ return (all_results,)
+ def orelse_name():
+ aliased_new_names, = aliased_orig_names,
+ orelse
+ return (all_results,)
+ results = tf.cond(test, body_name, orelse_name)
+ """
+ body_name = self.namer.new_symbol('if_true', all_referenced)
return templates.replace(
template,
test=node.test,
- body_name=gast.Name(
- self.namer.new_symbol('if_true', all_referenced), None, None),
+ body_name=body_name,
body=node_body,
- orelse_name=gast.Name(
- self.namer.new_symbol('if_false', all_referenced), None, None),
+ orelse_name=self.namer.new_symbol('if_false', all_referenced),
orelse=node_orelse,
- aliased=tuple(gast.Name(s, None, None) for s in aliased),
- aliases=tuple(gast.Name(s, None, None) for s in aliases),
- aliased_results=tuple(
- gast.Name(alias_map[s] if s in aliased else s, None, None)
- for s in all_modified),
+ aliased_orig_names=tuple(aliased_orig_names),
+ aliased_new_names=tuple(aliased_new_names),
+ all_results=tuple(alias_map[s] if s in aliased_orig_names else s
+ for s in all_modified),
results=results)
def visit_While(self, node):
@@ -144,38 +130,28 @@ class ControlFlowTransformer(gast.NodeTransformer):
body_scope = anno.getanno(node, 'body_scope')
body_closure = tuple(body_scope.modified - body_scope.created)
- def template(
- state, # pylint:disable=unused-argument
- state_ast_tuple, # pylint:disable=unused-argument
- test_name,
- test, # pylint:disable=unused-argument
- body_name,
- body):
-
- def test_name(state): # pylint:disable=function-redefined,unused-argument
- return test
-
- def body_name(state): # pylint:disable=function-redefined,unused-argument
- body # pylint:disable=pointless-statement
- return state,
-
- state_ast_tuple = tf.while_loop(test_name, body_name, [state]) # pylint:disable=undefined-variable
-
- test_name = self.namer.new_symbol('loop_test', body_scope.referenced)
- body_name = self.namer.new_symbol('loop_body', body_scope.referenced)
if len(body_closure) == 1:
- state = gast.Name(body_closure[0], None, None)
+ state = body_closure[0]
state_ast_tuple = state
else:
- state = tuple(gast.Name(n, None, None) for n in body_closure)
- state_ast_tuple = gast.Tuple(state, None)
+ state = tuple(body_closure)
+ state_ast_tuple = gast.Tuple(
+ tuple(gast.Name(n, None, None) for n in state), None)
+ template = """
+ def test_name(state):
+ return test
+ def body_name(state):
+ body
+ return state,
+ state_ast_tuple = tf.while_loop(test_name, body_name, [state])
+ """
node = templates.replace(
template,
state=state,
state_ast_tuple=state_ast_tuple,
- test_name=gast.Name(test_name, gast.Load(), None),
+ test_name=self.namer.new_symbol('loop_test', body_scope.referenced),
test=node.test,
- body_name=gast.Name(body_name, gast.Load(), None),
+ body_name=self.namer.new_symbol('loop_body', body_scope.referenced),
body=node.body)
return node
diff --git a/tensorflow/contrib/py2tf/converters/for_canonicalization.py b/tensorflow/contrib/py2tf/converters/for_canonicalization.py
index 52360789cd..c284689b90 100644
--- a/tensorflow/contrib/py2tf/converters/for_canonicalization.py
+++ b/tensorflow/contrib/py2tf/converters/for_canonicalization.py
@@ -42,46 +42,40 @@ class ForLoopCanonicalizationTransformer(gast.NodeTransformer):
# Or maybe we should replace range with tf.range?
if anno.hasanno(node, 'extra_cond'):
-
- def template(loop_iter, target, body, i, n, extra_cond): # pylint:disable=unused-argument
+ template = """
i = 0
- n = len(loop_iter) # pylint:disable=undefined-variable
+ n = len(loop_iter)
while i < n and extra_cond:
# TODO(mdan): Use TensorListFromTensor(loop_iter) here.
target = loop_iter[i]
- body # pylint:disable=pointless-statement
+ body
i += 1
-
+ """
return templates.replace(
template,
loop_iter=node.iter,
target=node.target,
body=node.body,
- i=gast.Name(
- self.namer.new_symbol('i', body_scope.referenced), None, None),
- n=gast.Name(
- self.namer.new_symbol('n', body_scope.referenced), None, None),
+ i=self.namer.new_symbol('i', body_scope.referenced),
+ n=self.namer.new_symbol('n', body_scope.referenced),
extra_cond=anno.getanno(node, 'extra_cond'))
else:
-
- def template(loop_iter, target, body, i, n): # pylint:disable=unused-argument
+ template = """
i = 0
- n = len(loop_iter) # pylint:disable=undefined-variable
+ n = len(loop_iter)
while i < n:
# TODO(mdan): Use TensorListFromTensor(loop_iter) here.
target = loop_iter[i]
body # pylint:disable=pointless-statement
i += 1
-
+ """
return templates.replace(
template,
loop_iter=node.iter,
target=node.target,
body=node.body,
- i=gast.Name(
- self.namer.new_symbol('i', body_scope.referenced), None, None),
- n=gast.Name(
- self.namer.new_symbol('n', body_scope.referenced), None, None))
+ i=self.namer.new_symbol('i', body_scope.referenced),
+ n=self.namer.new_symbol('n', body_scope.referenced))
def visit_Continue(self, node):
assert False, 'continue statement should be desugared at this point'
diff --git a/tensorflow/contrib/py2tf/converters/side_effect_guards.py b/tensorflow/contrib/py2tf/converters/side_effect_guards.py
index 83d0720b6b..4df723989d 100644
--- a/tensorflow/contrib/py2tf/converters/side_effect_guards.py
+++ b/tensorflow/contrib/py2tf/converters/side_effect_guards.py
@@ -96,12 +96,10 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
return node
def _gate_symbols(self, guard_statement, guarded_args):
-
- def template(args): # pylint:disable=unused-argument
- (args,) = (tf.identity(a) for a in (args,)) # pylint:disable=undefined-variable
-
- guards = templates.replace(
- template, args=tuple(gast.Name(a, None, None) for a in guarded_args))
+ template = """
+ (args,) = (tf.identity(a) for a in (args,))
+ """
+ guards = templates.replace(template, args=tuple(guarded_args))
guard_statement.body.extend(guards)
return guard_statement
@@ -112,29 +110,25 @@ class SideEffectGuardTransformer(gast.NodeTransformer):
# opt.minimize(loss)
# or:
# tf.py_func(...)
-
args_scope = anno.getanno(node.value, 'args_scope')
temp_name = self.namer.new_symbol('temp', args_scope.parent.referenced)
# TODO(mdan): Unsafe reference modification!
args_scope.mark_write(temp_name)
-
- def template(call, temp_result):
+ template = """
temp_result = call
if temp_result is not None:
if not isinstance(temp_result, (list, tuple)):
temp_result = (temp_result,)
- ctx = tf.control_dependencies(temp_result) # pylint:disable=undefined-variable
+ ctx = tf.control_dependencies(temp_result)
else:
ctx = contextmanager(lambda: (yield))()
with ctx:
# TODO(mdan): Also insert ops to re-fetch if variables are involved.
pass # Will be removed below.
-
- # TODO(mdan): This is brittle. Reorganize this mechanism.
+ """
+ # TODO(mdan): This is brittle. Reorganize the mechanism.
statements = templates.replace(
- template,
- call=node.value,
- temp_result=gast.Name(temp_name, None, None))
+ template, call=node.value, temp_result=temp_name)
control_deps_guard = statements[-1]
control_deps_guard.body = []
diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py
index 4fadc793e6..77c5fbe02a 100644
--- a/tensorflow/contrib/py2tf/pyct/templates.py
+++ b/tensorflow/contrib/py2tf/pyct/templates.py
@@ -80,37 +80,46 @@ class ReplaceTransformer(gast.NodeTransformer):
return node
+def _strings_to_names(n):
+ if isinstance(n, str):
+ # Note: the node will receive the ctx value from the template, see
+ # ReplaceTransformer.visit_Name.
+ return gast.Name(id=n, ctx=None, annotation=None)
+ if isinstance(n, list):
+ return [_strings_to_names(e) for e in n]
+ if isinstance(n, tuple):
+ return tuple(_strings_to_names(e) for e in n)
+ return n
+
+
def replace(template, **replacements):
"""Replace placeholders in a Python template.
+ AST Name and Tuple nodes always receive the context that inferred from
+ the template. However, when replacing more complex nodes (that can potentially
+ contain Name children), then the caller is responsible for setting the
+ appropriate context.
+
Args:
- template: A function to be used as a template. Any placeholder is expected
- to also be a function argument.
+ template: A string representing Python code. Any symbol name can be used
+ that appears in the template code can be used as placeholder.
**replacements: A mapping from placeholder names to (lists of) AST nodes
- that these placeholders will be replaced by.
+ that these placeholders will be replaced by. String values are also
+ supported as a shorthand for AST Name nodes with the respective ID.
Returns:
- body: An AST node or list of AST nodes with the replacements made. If the
- template was a function, a list will be returned. If the template was a
- node, the same node will be returned. If the template was a string, an
- AST node will be returned (a `Module` node in the case of a multi-line
- string, an `Expr` node otherwise).
+ An AST node or list of AST nodes with the replacements made. If the
+ template was a function, a list will be returned. If the template was a
+ node, the same node will be returned. If the template was a string, an
+ AST node will be returned (a `Module` node in the case of a multi-line
+ string, an `Expr` node otherwise).
Raises:
- ValueError: If a function is used as a template and an incorrect set of
- replacements was passed.
+ ValueError: if the arguments are incorrect.
"""
- tree = parser.parse_object(template).body[0]
- placeholders = set(arg.id for arg in tree.args.args)
- tree.args.args = []
- if tree.args.vararg:
- placeholders.add(tree.args.vararg)
- tree.args.vararg = None
- if set(replacements.keys()) != placeholders:
- raise ValueError(
- 'too many or few replacements. replacements: %s; placeholders: %s' %
- (replacements.keys(), placeholders))
-
- # Perform the replacement, stripping the function into which the template was
- # wrapped.
+ if not isinstance(template, str):
+ raise ValueError('Expected string template, got %s' % type(template))
+ tree = parser.parse_str(template)
+ for k in replacements:
+ replacements[k] = _strings_to_names(replacements[k])
return ReplaceTransformer(replacements).visit(tree).body
diff --git a/tensorflow/contrib/py2tf/pyct/templates_test.py b/tensorflow/contrib/py2tf/pyct/templates_test.py
index 2ad8b9317b..1143131283 100644
--- a/tensorflow/contrib/py2tf/pyct/templates_test.py
+++ b/tensorflow/contrib/py2tf/pyct/templates_test.py
@@ -28,46 +28,42 @@ from tensorflow.python.platform import test
class TemplatesTest(test.TestCase):
def test_replace_variable(self):
- def template(a): # pylint:disable=unused-argument
- def test_fn(a): # pylint:disable=unused-variable
+ template = """
+ def test_fn(a):
a += 1
a = 2 * a + 1
- return b # pylint:disable=undefined-variable
+ return b
+ """
- node = templates.replace(
- template, a=gast.Name('b', gast.Load(), None))[0]
+ node = templates.replace(template, a='b')[0]
result = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2))
def test_replace_function_name(self):
- def template(fname): # pylint:disable=unused-argument
- def fname(a): # pylint:disable=function-redefined
+ template = """
+ def fname(a):
a += 1
a = 2 * a + 1
return a
+ """
- node = templates.replace(
- template, fname=gast.Name('test_fn', gast.Load(), None))[0]
+ node = templates.replace(template, fname='test_fn')[0]
result = compiler.ast_to_object(node)
self.assertEquals(7, result.test_fn(2))
def test_code_block(self):
- def template(block): # pylint:disable=unused-argument
- def test_fn(a): # pylint:disable=unused-variable
- block # pylint:disable=pointless-statement
+ template = """
+ def test_fn(a):
+ block
return a
+ """
node = templates.replace(
template,
block=[
- gast.Assign(
- [
- gast.Name('a', gast.Store(), None)
- ],
- gast.BinOp(
- gast.Name('a', gast.Load(), None),
- gast.Add(),
- gast.Num(1))),
+ gast.Assign([
+ gast.Name('a', None, None)
+ ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
] * 2)[0]
result = compiler.ast_to_object(node)
self.assertEquals(3, result.test_fn(1))
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
index 8e6870fadd..501cddb8c8 100644
--- a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
@@ -34,9 +34,9 @@ namespace functor {
__global__ void ReduceSliceDeviceKernel##reduceop( \
Cuda3DLaunchConfig config, Index indices_width, Index bound, \
const T begin, const Index *indices, const T *input, T *out) { \
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { \
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { \
- CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { \
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) { \
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) { \
+ CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) { \
Index outidx = x * config.virtual_thread_count.y * \
config.virtual_thread_count.z + \
y * config.virtual_thread_count.z + z; \
@@ -68,8 +68,9 @@ namespace functor {
if (sizex * sizey * sizez == 0) { \
return; \
} \
- Cuda3DLaunchConfig config = GetCuda3DLaunchConfig(sizex, sizey, sizez, d,\
- ReduceSliceDeviceKernel##reduceop<T, Index>, 0, 0); \
+ Cuda3DLaunchConfig config = GetCuda3DLaunchConfig( \
+ sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop<T, Index>, \
+ 0, 0); \
\
ReduceSliceDeviceKernel##reduceop<T, Index> \
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>( \
diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py
index ef3722ee41..6d8f786223 100644
--- a/tensorflow/contrib/seq2seq/python/ops/helper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/helper.py
@@ -73,6 +73,14 @@ class Helper(object):
raise NotImplementedError("batch_size has not been implemented")
@abc.abstractproperty
+ def input_shape(self):
+ """Shape of each input element in batch.
+
+ Returns a `TensorShape`.
+ """
+ raise NotImplementedError("input_shape has not been implemented")
+
+ @abc.abstractproperty
def sample_ids_shape(self):
"""Shape of tensor returned by `sample`, excluding the batch dimension.
@@ -127,6 +135,7 @@ class CustomHelper(Helper):
self._sample_fn = sample_fn
self._next_inputs_fn = next_inputs_fn
self._batch_size = None
+ self._input_shape = None
self._sample_ids_shape = tensor_shape.TensorShape(sample_ids_shape or [])
self._sample_ids_dtype = sample_ids_dtype or dtypes.int32
@@ -149,6 +158,8 @@ class CustomHelper(Helper):
(finished, next_inputs) = self._initialize_fn()
if self._batch_size is None:
self._batch_size = array_ops.size(finished)
+ if self._input_shape is None:
+ self._input_shape = next_inputs.shape[1:]
return (finished, next_inputs)
def sample(self, time, outputs, state, name=None):
@@ -184,6 +195,7 @@ class TrainingHelper(Helper):
"""
with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
inputs = ops.convert_to_tensor(inputs, name="inputs")
+ self._inputs = inputs
if not time_major:
inputs = nest.map_structure(_transpose_batch_time, inputs)
@@ -199,12 +211,17 @@ class TrainingHelper(Helper):
lambda inp: array_ops.zeros_like(inp[0, :]), inputs)
self._batch_size = array_ops.size(sequence_length)
+ self._input_shape = inputs.shape[2:]
@property
def batch_size(self):
return self._batch_size
@property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
def sample_ids_shape(self):
return tensor_shape.TensorShape([])
@@ -212,6 +229,14 @@ class TrainingHelper(Helper):
def sample_ids_dtype(self):
return dtypes.int32
+ @property
+ def inputs(self):
+ return self._inputs
+
+ @property
+ def sequence_length(self):
+ return self._sequence_length
+
def initialize(self, name=None):
with ops.name_scope(name, "TrainingHelperInitialize"):
finished = math_ops.equal(0, self._sequence_length)
@@ -516,12 +541,17 @@ class GreedyEmbeddingHelper(Helper):
if self._end_token.get_shape().ndims != 0:
raise ValueError("end_token must be a scalar")
self._start_inputs = self._embedding_fn(self._start_tokens)
+ self._input_shape = self._start_inputs.shape[1:]
@property
def batch_size(self):
return self._batch_size
@property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
def sample_ids_shape(self):
return tensor_shape.TensorShape([])
@@ -632,6 +662,8 @@ class InferenceHelper(Helper):
self._sample_dtype = sample_dtype
self._next_inputs_fn = next_inputs_fn
self._batch_size = array_ops.shape(start_inputs)[0]
+ self._input_shape = start_inputs.shape[1:]
+
self._start_inputs = ops.convert_to_tensor(
start_inputs, name="start_inputs")
@@ -640,6 +672,10 @@ class InferenceHelper(Helper):
return self._batch_size
@property
+ def input_shape(self):
+ return self._input_shape
+
+ @property
def sample_ids_shape(self):
return self._sample_shape
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
index c8b4e472c9..360e7dbe75 100644
--- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
@@ -105,8 +105,8 @@ class SparsemaxLossTest(test.TestCase):
tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
np_loss = self._np_sparsemax_loss(z, q).astype(dtype)
- self.assertAllCloseAccordingToType(np_loss, tf_loss_out,
- half_atol=1e-2, half_rtol=5e-3)
+ self.assertAllCloseAccordingToType(
+ np_loss, tf_loss_out, half_atol=1e-2, half_rtol=5e-3)
self.assertShapeEqual(np_loss, tf_loss_op)
def _test_constant_add(self, dtype, random, use_gpu):
@@ -116,17 +116,17 @@ class SparsemaxLossTest(test.TestCase):
q = np.zeros((test_obs, 10))
q[np.arange(0, test_obs), np.random.randint(0, 10, size=test_obs)] = 1
- _, tf_loss_zpc = self._tf_sparsemax_loss(
- z + c, q, dtype, use_gpu
- )
+ _, tf_loss_zpc = self._tf_sparsemax_loss(z + c, q, dtype, use_gpu)
- _, tf_loss_z = self._tf_sparsemax_loss(
- z, q, dtype, use_gpu
- )
+ _, tf_loss_z = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
- self.assertAllCloseAccordingToType(tf_loss_zpc, tf_loss_z,
- float_atol=5e-6, float_rtol=5e-6,
- half_atol=1e-2, half_rtol=1e-2)
+ self.assertAllCloseAccordingToType(
+ tf_loss_zpc,
+ tf_loss_z,
+ float_atol=5e-6,
+ float_rtol=5e-6,
+ half_atol=1e-2,
+ half_rtol=1e-2)
def _test_sparsemax_loss_positive(self, dtype, random, use_gpu):
"""check sparsemax-loss proposition 4"""
@@ -170,10 +170,7 @@ class SparsemaxLossTest(test.TestCase):
with self.test_session(use_gpu=use_gpu):
err = gradient_checker.compute_gradient_error(
- logits, z.shape,
- loss_op, (test_obs, ),
- x_init_value=z, delta=1e-9
- )
+ logits, z.shape, loss_op, (test_obs,), x_init_value=z, delta=1e-9)
self.assertLess(err, 1e-4)
@@ -192,8 +189,8 @@ class SparsemaxLossTest(test.TestCase):
tf_grad = loss_grad_op.eval()
np_grad = self._np_sparsemax_loss_grad(z, q).astype(dtype)
- self.assertAllCloseAccordingToType(np_grad, tf_grad,
- half_atol=1e-2, half_rtol=5e-3)
+ self.assertAllCloseAccordingToType(
+ np_grad, tf_grad, half_atol=1e-2, half_rtol=5e-3)
self.assertShapeEqual(np_grad, loss_grad_op)
def _test_dtype(self, dtype):
@@ -220,5 +217,6 @@ class SparsemaxLossTest(test.TestCase):
def testDouble(self):
self._test_dtype('float64')
-if __name__ == "__main__":
+
+if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
index 82d36ee9cb..259e62bd86 100644
--- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
@@ -83,8 +83,8 @@ class SparsemaxTest(test.TestCase):
tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
p_sparemax = self._np_sparsemax(z).astype(dtype)
- self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out,
- half_atol=5e-3)
+ self.assertAllCloseAccordingToType(
+ p_sparemax, tf_sparsemax_out, half_atol=5e-3)
self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
def _test_sparsemax_of_zero(self, dtype, random, use_gpu):
@@ -111,9 +111,8 @@ class SparsemaxTest(test.TestCase):
p_expected = np.zeros((test_obs, 10), dtype=dtype)
p_expected[np.arange(0, test_obs), z_sort_arg[:, 0]] = 1
- tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(
- (1 / epsilon) * z, dtype, use_gpu
- )
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax((1 / epsilon) * z,
+ dtype, use_gpu)
self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out)
self.assertShapeEqual(p_expected, tf_sparsemax_op)
@@ -123,16 +122,12 @@ class SparsemaxTest(test.TestCase):
z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
c = random.uniform(low=-3, high=3, size=(test_obs, 1)).astype(dtype)
- _, tf_sparsemax_zpc = self._tf_sparsemax(
- z + c, dtype, use_gpu
- )
+ _, tf_sparsemax_zpc = self._tf_sparsemax(z + c, dtype, use_gpu)
- _, tf_sparsemax_z = self._tf_sparsemax(
- z, dtype, use_gpu
- )
+ _, tf_sparsemax_z = self._tf_sparsemax(z, dtype, use_gpu)
- self.assertAllCloseAccordingToType(tf_sparsemax_zpc, tf_sparsemax_z,
- half_atol=5e-3)
+ self.assertAllCloseAccordingToType(
+ tf_sparsemax_zpc, tf_sparsemax_z, half_atol=5e-3)
def _test_permutation(self, dtype, random, use_gpu):
"""check sparsemax proposition 3"""
@@ -143,12 +138,11 @@ class SparsemaxTest(test.TestCase):
per = random.permutation(10)
tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(
- z[i, per].reshape(1, -1), dtype, use_gpu
- )
+ z[i, per].reshape(1, -1), dtype, use_gpu)
p_expected = p[i, per].reshape(1, -1)
- self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out,
- half_atol=5e-3)
+ self.assertAllCloseAccordingToType(
+ p_expected, tf_sparsemax_out, half_atol=5e-3)
self.assertShapeEqual(p_expected, tf_sparsemax_op)
def _test_diffrence(self, dtype, random, use_gpu):
@@ -166,18 +160,14 @@ class SparsemaxTest(test.TestCase):
continue
self.assertTrue(
- 0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol,
- "0 <= %.10f <= %.10f" % (
- p[val, j] - p[val, i], z[val, j] - z[val, i] + etol
- )
- )
+ 0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol,
+ '0 <= %.10f <= %.10f' % (p[val, j] - p[val, i],
+ z[val, j] - z[val, i] + etol))
def _test_two_dimentional(self, dtype, random, use_gpu):
"""check two dimentation sparsemax case"""
t = np.linspace(-2, 2, test_obs, dtype=dtype)
- z = np.vstack([
- t, np.zeros(test_obs, dtype=dtype)
- ]).T
+ z = np.vstack([t, np.zeros(test_obs, dtype=dtype)]).T
tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
@@ -196,10 +186,7 @@ class SparsemaxTest(test.TestCase):
with self.test_session(use_gpu=use_gpu):
err = gradient_checker.compute_gradient_error(
- logits, z.shape,
- sparsemax_op, z.shape,
- x_init_value=z, delta=1e-9
- )
+ logits, z.shape, sparsemax_op, z.shape, x_init_value=z, delta=1e-9)
self.assertLess(err, 1e-4)
@@ -248,5 +235,6 @@ class SparsemaxTest(test.TestCase):
def testDouble(self):
self._test_dtype('float64')
-if __name__ == "__main__":
+
+if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h
index f938d08c84..02c0fc687f 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h
@@ -316,7 +316,7 @@ class DenseClassificationGrowStats : public ClassificationStats {
void PackToProto(FertileSlot* slot) const override;
void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
- LeafStat* right_stats) const;
+ LeafStat* right_stats) const override;
protected:
void ClassificationAddSplitStats() override {
@@ -383,7 +383,7 @@ class SparseClassificationGrowStats : public ClassificationStats {
void PackToProto(FertileSlot* slot) const override;
void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
- LeafStat* right_stats) const;
+ LeafStat* right_stats) const override;
protected:
void ClassificationAddSplitStats() override {
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
index 14cb19d36f..bf0fb92450 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
@@ -21,8 +21,6 @@ namespace tensorflow {
namespace tensorforest {
namespace {
-const int32 SPARSE_DEFAULT = 0;
-
bool DecideInequalityTest(const decision_trees::InequalityTest& test,
float value) {
float bias = test.threshold().float_value();
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
new file mode 100644
index 0000000000..28f571e1f0
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -0,0 +1,45 @@
+# Description:
+# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow.
+# APIs are meant to change over time.
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
+load(
+ "@local_config_tensorrt//:build_defs.bzl",
+ "if_tensorrt",
+)
+
+tf_cuda_cc_test(
+ name = "tensorrt_test_cc",
+ size = "small",
+ srcs = ["tensorrt_test.cc"],
+ tags = [
+ "manual",
+ "notap",
+ ],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ] + if_tensorrt([
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/tensorrt/tensorrt_test.cc b/tensorflow/contrib/tensorrt/tensorrt_test.cc
new file mode 100644
index 0000000000..e11522ea5b
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/tensorrt_test.cc
@@ -0,0 +1,159 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "cuda/include/cuda.h"
+#include "cuda/include/cuda_runtime_api.h"
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace {
+
+class Logger : public nvinfer1::ILogger {
+ public:
+ void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
+ switch (severity) {
+ case Severity::kINFO:
+ LOG(INFO) << msg;
+ break;
+ case Severity::kWARNING:
+ LOG(WARNING) << msg;
+ break;
+ case Severity::kINTERNAL_ERROR:
+ case Severity::kERROR:
+ LOG(ERROR) << msg;
+ break;
+ default:
+ break;
+ }
+ }
+};
+
+class ScopedWeights {
+ public:
+ ScopedWeights(float value) : value_(value) {
+ w.type = nvinfer1::DataType::kFLOAT;
+ w.values = &value_;
+ w.count = 1;
+ }
+ const nvinfer1::Weights& get() { return w; }
+
+ private:
+ float value_;
+ nvinfer1::Weights w;
+};
+
+const char* kInputTensor = "input";
+const char* kOutputTensor = "output";
+
+// Creates a network to compute y=2x+3.
+nvinfer1::IHostMemory* CreateNetwork() {
+ Logger logger;
+ nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
+ ScopedWeights weights(2.0);
+ ScopedWeights bias(3.0);
+
+ nvinfer1::INetworkDefinition* network = builder->createNetwork();
+ // Add the input.
+ auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT,
+ nvinfer1::DimsCHW{1, 1, 1});
+ EXPECT_NE(input, nullptr);
+ // Add the hidden layer.
+ auto layer = network->addFullyConnected(*input, 1, weights.get(), bias.get());
+ EXPECT_NE(layer, nullptr);
+ // Mark the output.
+ auto output = layer->getOutput(0);
+ output->setName(kOutputTensor);
+ network->markOutput(*output);
+ // Build the engine
+ builder->setMaxBatchSize(1);
+ builder->setMaxWorkspaceSize(1 << 10);
+ auto engine = builder->buildCudaEngine(*network);
+ EXPECT_NE(engine, nullptr);
+ // Serialize the engine to create a model, then close everything.
+ nvinfer1::IHostMemory* model = engine->serialize();
+ network->destroy();
+ engine->destroy();
+ builder->destroy();
+ return model;
+}
+
+// Executes the network.
+void Execute(nvinfer1::IExecutionContext& context, const float* input,
+ float* output) {
+ const nvinfer1::ICudaEngine& engine = context.getEngine();
+
+ // We have two bindings: input and output.
+ ASSERT_EQ(engine.getNbBindings(), 2);
+ const int input_index = engine.getBindingIndex(kInputTensor);
+ const int output_index = engine.getBindingIndex(kOutputTensor);
+
+ // Create GPU buffers and a stream
+ void* buffers[2];
+ ASSERT_EQ(0, cudaMalloc(&buffers[input_index], sizeof(float)));
+ ASSERT_EQ(0, cudaMalloc(&buffers[output_index], sizeof(float)));
+ cudaStream_t stream;
+ ASSERT_EQ(0, cudaStreamCreate(&stream));
+
+ // Copy the input to the GPU, execute the network, and copy the output back.
+ //
+ // Note that since the host buffer was not created as pinned memory, these
+ // async copies are turned into sync copies. So the following synchronization
+ // could be removed.
+ ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
+ cudaMemcpyHostToDevice, stream));
+ context.enqueue(1, buffers, stream, nullptr);
+ ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
+ cudaMemcpyDeviceToHost, stream));
+ cudaStreamSynchronize(stream);
+
+ // Release the stream and the buffers
+ cudaStreamDestroy(stream);
+ ASSERT_EQ(0, cudaFree(buffers[input_index]));
+ ASSERT_EQ(0, cudaFree(buffers[output_index]));
+}
+
+TEST(TensorrtTest, BasicFunctions) {
+ // Create the network model.
+ nvinfer1::IHostMemory* model = CreateNetwork();
+ // Use the model to create an engine and then an execution context.
+ Logger logger;
+ nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
+ nvinfer1::ICudaEngine* engine =
+ runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
+ model->destroy();
+ nvinfer1::IExecutionContext* context = engine->createExecutionContext();
+
+ // Execute the network.
+ float input = 1234;
+ float output;
+ Execute(*context, &input, &output);
+ EXPECT_EQ(output, input * 2 + 3);
+
+ // Destroy the engine.
+ context->destroy();
+ engine->destroy();
+ runtime->destroy();
+}
+
+} // namespace
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index 1cded9f8cf..7373d0e17c 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -47,12 +47,14 @@ string GetCurrentTimeStampAsString() {
return s;
}
-ProfileResponse Profile(const string& service_addr, int duration_ms) {
+ProfileResponse Profile(const string& service_addr, int duration_ms,
+ const ProfileOptions& opts) {
ProfileRequest request;
request.set_duration_ms(duration_ms);
request.set_max_events(kMaxEvents);
request.add_tools("input_pipeline");
request.add_tools("overview_page");
+ *request.mutable_opts() = opts;
std::cout << "Limiting the number of trace events to " << kMaxEvents
<< std::endl;
::grpc::ClientContext context;
@@ -76,6 +78,7 @@ int main(int argc, char** argv) {
tensorflow::string FLAGS_service_addr;
tensorflow::string FLAGS_logdir;
int FLAGS_duration_ms = 2000;
+ bool FLAGS_include_dataset_ops = true;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("service_addr", &FLAGS_service_addr,
"Address of TPU profiler service e.g. localhost:8466"),
@@ -83,6 +86,8 @@ int main(int argc, char** argv) {
"Path of TensorBoard log directory e.g. /tmp/tb_log"),
tensorflow::Flag("duration_ms", &FLAGS_duration_ms,
"Duration of tracing in ms. Default is 2000ms."),
+ tensorflow::Flag("include_dataset_ops", &FLAGS_include_dataset_ops,
+ "Set to false to profile longer TPU device traces."),
};
std::cout << "Welcome to the Cloud TPU Profiler v" << TPU_PROFILER_VERSION
@@ -97,8 +102,10 @@ int main(int argc, char** argv) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
int duration_ms = FLAGS_duration_ms;
+ tensorflow::ProfileOptions opts;
+ opts.set_include_dataset_ops(FLAGS_include_dataset_ops);
tensorflow::ProfileResponse response =
- tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms);
+ tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms, opts);
// Use the current timestamp as the run name.
tensorflow::string run = tensorflow::tpu::GetCurrentTimeStampAsString();
TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile(
diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
index 5440bbbfdd..2094294baa 100644
--- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
+++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
@@ -61,6 +61,11 @@ message OpMetricsResult {
message OpMetricsDbResult {
// A bunch of OpMetricsResults.
repeated OpMetricsResult metrics_db = 1;
+ // The total host infeed-enqueue duration in picoseconds.
+ optional uint64 total_host_infeed_enq_duration_ps = 2;
+ // The total of the difference between the start times of two
+ // consecutive infeed-enqueues (per host) in picoseconds.
+ optional uint64 total_host_infeed_enq_start_timestamp_ps_diff = 3;
}
// Result proto for StepInfo.
diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
index bf30d2ce09..f3f3302ceb 100644
--- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
+++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
@@ -13,6 +13,14 @@ service TPUProfiler {
}
}
+message ProfileOptions {
+ // We don't collect the dataset ops by default for better trace-viewer
+ // scalability. The caller can mannually set this field to include the ops.
+ bool include_dataset_ops = 1;
+
+ // next-field: 2
+}
+
message ProfileRequest {
// In future, the caller will be able to customize when profiling starts and
// stops. For now, it collects `duration_ms` milliseconds worth of data.
@@ -25,10 +33,13 @@ message ProfileRequest {
// required profiling tools name such as "input_pipeline_analyzer" etc
repeated string tools = 3;
+ // Optional profiling options that control how a TF session will be profiled.
+ ProfileOptions opts = 4;
+
// In future, the caller will indicate which TF session is being profiled, and
// only data relating to that program will be returned. For now, we assume
// all activity during the profiling period is relevant.
- // next-field: 4
+ // next-field: 5
}
message ProfileToolData {
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index a49a3dcf29..1c970655d0 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -47,7 +47,7 @@ if platform.system() != "Windows":
# types are supported.
_SUPPORTED_INFEED_DTYPES = set([
- dtypes.int32, dtypes.bfloat16, dtypes.float32
+ dtypes.bool, dtypes.int32, dtypes.bfloat16, dtypes.float32
])
def infeed_dequeue(dtype, shape, name=None):
diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
index 2a0ef0e6b3..dbdbb08a82 100644
--- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
+++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
@@ -53,7 +53,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
sp_tensor1 = sparse_tensor.SparseTensor(
array_ops.constant(ind1, dtypes.int64),
array_ops.constant(val1, dtypes.int64),
- array_ops.constant(shape1, dtypes.int64))
+ array_ops.placeholder_with_default(shape1, shape=[2]))
ind2 = np.array([
[0, 0, 1],
[0, 1, 0],
@@ -68,7 +68,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
sp_tensor2 = sparse_tensor.SparseTensor(
array_ops.constant(ind2, dtypes.int64),
array_ops.constant(val2, dtypes.int64),
- array_ops.constant(shape2, dtypes.int64))
+ array_ops.placeholder_with_default(shape2, shape=[3]))
sp_tensor3 = sparse_tensor.SparseTensor(
array_ops.constant([[1, 9], [2, 2], [2, 10]], dtypes.int64),
array_ops.constant([7, 15, 2], dtypes.int64),
@@ -320,6 +320,18 @@ class BatchSequencesWithStatesTest(test.TestCase):
def testNotAMultiple(self):
num_unroll = 3 # Not a divisor of value_length -
# so padding would have been necessary.
+
+ # Use placeholder_with_default in sequences to make sure we get runtime
+ # error instead of shape inference error
+ sequences = {
+ "seq1": array_ops.placeholder_with_default(self.sequences["seq1"],
+ shape=(None, 5)),
+ "seq2": array_ops.placeholder_with_default(self.sequences["seq2"],
+ shape=(None, 4, 2)),
+ "seq3": self.sequences["seq3"],
+ "seq4": self.sequences["seq4"],
+ }
+
with self.test_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
".*should be a multiple of: 3, but saw "
@@ -330,7 +342,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
with coord.stop_on_exception():
next_batch = sqss.batch_sequences_with_states(
input_key=self.key,
- input_sequences=self.sequences,
+ input_sequences=sequences,
input_context=self.context,
input_length=3,
initial_states=self.initial_states,
@@ -493,6 +505,18 @@ class BatchSequencesWithStatesTest(test.TestCase):
expected_seq4_batch2=expected_seq4_batch2)
+class BatchSequencesWithStatesTestWithCApi(BatchSequencesWithStatesTest):
+
+ def setUp(self):
+ self._prev_value = ops._USE_C_API
+ ops._USE_C_API = True
+ super(BatchSequencesWithStatesTestWithCApi, self).setUp()
+
+ def tearDown(self):
+ super(BatchSequencesWithStatesTestWithCApi, self).tearDown()
+ ops._USE_C_API = self._prev_value
+
+
class PaddingTest(test.TestCase):
def testPaddingInvalidLengths(self):
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 94973a0e52..29c515121e 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1896,6 +1896,13 @@ cc_library(
],
)
+tf_cuda_library(
+ name = "cuda_device_functions",
+ hdrs = ["util/cuda_device_functions.h"],
+ visibility = ["//visibility:public"],
+ deps = [":framework_lite"],
+)
+
# TODO(josh11b): Is this needed, or can we just use ":protos_all_cc"?
cc_library(
name = "protos_cc",
diff --git a/tensorflow/core/api_def/base_api/api_def_FusedResizeAndPadConv2D.pbtxt b/tensorflow/core/api_def/base_api/api_def_FusedResizeAndPadConv2D.pbtxt
index a72f2bfe5f..118d0e2178 100644
--- a/tensorflow/core/api_def/base_api/api_def_FusedResizeAndPadConv2D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_FusedResizeAndPadConv2D.pbtxt
@@ -30,9 +30,8 @@ END
attr {
name: "resize_align_corners"
description: <<END
-If true, rescale input by (new_height - 1) / (height - 1),
-which exactly aligns the 4 corners of images and resized images. If false, rescale
-by new_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and output tensors are
+aligned, preserving the values at the corner pixels. Defaults to false.
END
}
attr {
diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizedResizeBilinear.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizedResizeBilinear.pbtxt
index 6b3ba72e53..a08ed710b7 100644
--- a/tensorflow/core/api_def/base_api/api_def_QuantizedResizeBilinear.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_QuantizedResizeBilinear.pbtxt
@@ -23,9 +23,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale input by (new_height - 1) / (height - 1), which
-exactly aligns the 4 corners of images and resized images. If false, rescale
-by new_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and output tensors are
+aligned, preserving the values at the corner pixels. Defaults to false.
END
}
summary: "Resize quantized `images` to `size` using quantized bilinear interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeArea.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeArea.pbtxt
index 6dc321a544..317ad263cc 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeArea.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeArea.pbtxt
@@ -23,9 +23,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale input by (new_height - 1) / (height - 1), which
-exactly aligns the 4 corners of images and resized images. If false, rescale
-by new_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and output tensors are
+aligned, preserving the values at the corner pixels. Defaults to false.
END
}
summary: "Resize `images` to `size` using area interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeBicubic.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeBicubic.pbtxt
index 06e645e3ee..d4f8233d25 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeBicubic.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeBicubic.pbtxt
@@ -23,9 +23,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale input by (new_height - 1) / (height - 1), which
-exactly aligns the 4 corners of images and resized images. If false, rescale
-by new_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and output tensors are
+aligned, preserving the values at the corner pixels. Defaults to false.
END
}
summary: "Resize `images` to `size` using bicubic interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeBicubicGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeBicubicGrad.pbtxt
index bf5201d82e..eeb0680ab8 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeBicubicGrad.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeBicubicGrad.pbtxt
@@ -25,9 +25,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale grads by (orig_height - 1) / (height - 1), which
-exactly aligns the 4 corners of grads and original_image. If false, rescale by
-orig_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and grad tensors are
+aligned. Defaults to false.
END
}
summary: "Computes the gradient of bicubic interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeBilinear.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeBilinear.pbtxt
index 0768e437fa..0673baa703 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeBilinear.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeBilinear.pbtxt
@@ -23,9 +23,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale input by (new_height - 1) / (height - 1), which
-exactly aligns the 4 corners of images and resized images. If false, rescale
-by new_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and output tensors are
+aligned, preserving the values at the corner pixels. Defaults to false.
END
}
summary: "Resize `images` to `size` using bilinear interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeBilinearGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeBilinearGrad.pbtxt
index fba64203c2..9a1a5fb69a 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeBilinearGrad.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeBilinearGrad.pbtxt
@@ -25,9 +25,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale grads by (orig_height - 1) / (height - 1), which
-exactly aligns the 4 corners of grads and original_image. If false, rescale by
-orig_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and grad tensors are
+aligned. Defaults to false.
END
}
summary: "Computes the gradient of bilinear interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighbor.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighbor.pbtxt
index a74db4c9dc..e6f8dc1941 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighbor.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighbor.pbtxt
@@ -23,9 +23,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale input by (new_height - 1) / (height - 1), which
-exactly aligns the 4 corners of images and resized images. If false, rescale
-by new_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and output tensors are
+aligned, preserving the values at the corner pixels. Defaults to false.
END
}
summary: "Resize `images` to `size` using nearest neighbor interpolation."
diff --git a/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighborGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighborGrad.pbtxt
index 4ef1547eb4..8d52ca8334 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighborGrad.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResizeNearestNeighborGrad.pbtxt
@@ -24,9 +24,8 @@ END
attr {
name: "align_corners"
description: <<END
-If true, rescale grads by (orig_height - 1) / (height - 1), which
-exactly aligns the 4 corners of grads and original_image. If false, rescale by
-orig_height / height. Treat similarly the width dimension.
+If true, the centers of the 4 corner pixels of the input and grad tensors are
+aligned. Defaults to false.
END
}
summary: "Computes the gradient of nearest neighbor interpolation."
diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc
index e78b6ab5d9..870bbb141b 100644
--- a/tensorflow/core/framework/op_gen_lib.cc
+++ b/tensorflow/core/framework/op_gen_lib.cc
@@ -266,35 +266,6 @@ static void StringReplace(const string& from, const string& to, string* s) {
*s = str_util::Join(split, to.c_str());
}
-static void RenameInDocs(const string& from, const string& to, OpDef* op_def) {
- const string from_quoted = strings::StrCat("`", from, "`");
- const string to_quoted = strings::StrCat("`", to, "`");
- for (int i = 0; i < op_def->input_arg_size(); ++i) {
- if (!op_def->input_arg(i).description().empty()) {
- StringReplace(from_quoted, to_quoted,
- op_def->mutable_input_arg(i)->mutable_description());
- }
- }
- for (int i = 0; i < op_def->output_arg_size(); ++i) {
- if (!op_def->output_arg(i).description().empty()) {
- StringReplace(from_quoted, to_quoted,
- op_def->mutable_output_arg(i)->mutable_description());
- }
- }
- for (int i = 0; i < op_def->attr_size(); ++i) {
- if (!op_def->attr(i).description().empty()) {
- StringReplace(from_quoted, to_quoted,
- op_def->mutable_attr(i)->mutable_description());
- }
- }
- if (!op_def->summary().empty()) {
- StringReplace(from_quoted, to_quoted, op_def->mutable_summary());
- }
- if (!op_def->description().empty()) {
- StringReplace(from_quoted, to_quoted, op_def->mutable_description());
- }
-}
-
static void RenameInDocs(const string& from, const string& to,
ApiDef* api_def) {
const string from_quoted = strings::StrCat("`", from, "`");
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index aee3a0afbc..16bf5c256f 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -943,13 +943,6 @@ Status FindKernelRegistration(const DeviceType& device_type,
return Status::OK();
}
-Status FindKernelRegistration(const DeviceType& device_type, const Node& node,
- const KernelRegistration** reg,
- bool* was_attr_mismatch) {
- return FindKernelRegistration(device_type, node.def(), reg,
- was_attr_mismatch);
-}
-
} // namespace
// TODO(irving): Change const NodeDef& to const Node&
diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc
index b1e6cf64e8..4118f14f8b 100644
--- a/tensorflow/core/graph/costmodel.cc
+++ b/tensorflow/core/graph/costmodel.cc
@@ -57,10 +57,10 @@ void CostModel::MergeFromLocal(const Graph& g, const CostModel& cm) {
const int local_id = cm.Id(n);
const int global_id = Id(n);
if (local_id < 0 || global_id < 0) continue;
- Ensure(global_id);
+ int num_slots = cm.slot_bytes_[local_id].size();
+ Ensure(global_id, num_slots);
count_[global_id] += cm.count_[local_id];
time_[global_id] += cm.time_[local_id];
- int num_slots = cm.slot_bytes_[local_id].size();
if (num_slots > 0) {
if (slot_bytes_[global_id].empty()) {
slot_bytes_[global_id].resize(num_slots);
@@ -78,11 +78,11 @@ void CostModel::MergeFromGlobal(const CostModel& cm) {
CHECK(is_global_);
CHECK_EQ(true, cm.is_global());
const int num_nodes = cm.count_.size();
- Ensure(num_nodes);
- for (int i = 0; i < num_nodes; ++i) {
+ for (int i = num_nodes - 1; i >= 0; --i) {
count_[i] += cm.count_[i];
time_[i] += cm.time_[i];
int num_slots = cm.slot_bytes_[i].size();
+ Ensure(i, num_slots);
if (num_slots > 0) {
if (slot_bytes_[i].empty()) {
slot_bytes_[i].resize(num_slots);
@@ -106,7 +106,7 @@ void CostModel::MergeFromStats(const NodeNameToCostIdMap& map,
// copy/send/recv nodes, feed/fetch, etc.
if (iter == map.end()) continue;
int32 global_id = iter->second;
- Ensure(global_id);
+ Ensure(global_id, ns.output_size());
int64 elapsed_micros = ns.op_end_rel_micros() - ns.op_start_rel_micros();
count_[global_id]++;
time_[global_id] += elapsed_micros;
@@ -122,7 +122,7 @@ void CostModel::MergeFromStats(const NodeNameToCostIdMap& map,
}
}
-void CostModel::Ensure(int id) {
+void CostModel::Ensure(int id, int num_outputs) {
if (slot_bytes_.size() <= static_cast<size_t>(id)) {
slot_bytes_.resize(id + 1);
count_.resize(id + 1);
@@ -131,25 +131,37 @@ void CostModel::Ensure(int id) {
max_exec_time_.resize(id + 1);
output_port_alloc_ids_.resize(id + 1);
}
+ if (num_outputs > 0) {
+ auto perslot = &slot_bytes_[id];
+ auto output_port_alloc_ids = &output_port_alloc_ids_[id];
+ auto max_mem_usage = &max_mem_usage_[id];
+
+ CHECK_LE(perslot->size(), num_outputs);
+ DCHECK_EQ(output_port_alloc_ids->size(), perslot->size());
+ DCHECK_EQ(max_mem_usage->output_port_mem.size(), perslot->size());
+ DCHECK_EQ(max_mem_usage->output_port_shape.size(), perslot->size());
+ DCHECK_EQ(max_mem_usage->output_port_type.size(), perslot->size());
+
+ perslot->resize(num_outputs, Bytes(-1));
+ output_port_alloc_ids->resize(num_outputs, -1);
+ max_mem_usage->output_port_mem.resize(num_outputs, Bytes(-1));
+ max_mem_usage->output_port_shape.resize(num_outputs, unknown_shape_);
+ max_mem_usage->output_port_type.resize(num_outputs, DT_INVALID);
+ }
}
void CostModel::SetNumOutputs(const Node* node, int num_outputs) {
const int id = Id(node);
if (id < 0) return;
- Ensure(id);
+ // Do not resize the number of slots before checking its existing number of
+ // slots.
+ Ensure(id, 0);
auto perslot = &slot_bytes_[id];
- auto max_mem_usage = &max_mem_usage_[id];
- auto output_port_alloc_ids = &output_port_alloc_ids_[id];
if (!perslot->empty()) {
CHECK_EQ(num_outputs, perslot->size()) << "Cannot resize slot_bytes, node="
<< node->name();
- } else {
- perslot->resize(num_outputs, Bytes(-1));
- output_port_alloc_ids->resize(num_outputs, -1);
- max_mem_usage->output_port_mem.resize(num_outputs, Bytes(-1));
- max_mem_usage->output_port_shape.resize(num_outputs, unknown_shape_);
- max_mem_usage->output_port_type.resize(num_outputs, DT_INVALID);
}
+ Ensure(id, num_outputs);
}
void CostModel::RecordCount(const Node* node, int count) {
@@ -198,7 +210,7 @@ void CostModel::RecordTime(const Node* node, Microseconds time) {
const int id = Id(node);
if (id < 0) return;
DCHECK(node->IsOp()) << node->DebugString();
- Ensure(id);
+ Ensure(id, node->num_outputs());
time_[id] += time;
}
@@ -240,7 +252,10 @@ void CostModel::RecordMaxMemorySize(const Node* node, int output_slot,
const DataType& dtype) {
const int id = Id(node);
if (id < 0) return;
- Ensure(id);
+ CHECK_LT(output_slot, node->num_outputs())
+ << "Unexpected output slot for node " << node->DebugString() << ". Got "
+ << output_slot << " but its num_outputs is " << node->num_outputs();
+ Ensure(id, node->num_outputs());
auto& current_max = max_mem_usage_[id].output_port_mem[output_slot];
// If the memory allocator doesn't track memory usage, let's infer a lower
// bound from the tensor shape and its data type.
@@ -316,7 +331,7 @@ void CostModel::RecordMemoryStats(const Node* node,
void CostModel::RecordMaxExecutionTime(const Node* node, Microseconds time) {
const int id = Id(node);
if (id < 0) return;
- Ensure(id);
+ Ensure(id, node->num_outputs());
max_exec_time_[id] = std::max(max_exec_time_[id], time);
}
@@ -332,7 +347,7 @@ void CostModel::RecordAllocationId(const Node* node, int output_slot,
int64 alloc_id) {
const int id = Id(node);
if (id < 0) return;
- Ensure(id);
+ Ensure(id, node->num_outputs());
output_port_alloc_ids_[id][output_slot] = alloc_id;
}
diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h
index 081eb2ff4c..c60a946c2c 100644
--- a/tensorflow/core/graph/costmodel.h
+++ b/tensorflow/core/graph/costmodel.h
@@ -183,8 +183,8 @@ class CostModel {
const bool is_global_;
- // Resizes vectors so that they are large enough for "id".
- void Ensure(int id);
+ // Resizes vectors so that they are large enough for "id" and id's outputs.
+ void Ensure(int id, int num_outputs);
// Nodes and Edges whose count is < this value
// get type/byte estimates of 0.
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h
index 852e69737b..b7eaf8dc63 100644
--- a/tensorflow/core/grappler/costs/cost_estimator.h
+++ b/tensorflow/core/grappler/costs/cost_estimator.h
@@ -85,10 +85,7 @@ struct Costs {
typedef NanoSeconds Duration;
// Overall cost of running the graph; latency.
- // Mean
Duration execution_time;
- Duration min_execution_time;
- Duration max_execution_time;
// Computation cost of running the graph.
Duration compute_time;
diff --git a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc
index 8fd1801863..ea4320687a 100644
--- a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc
@@ -117,8 +117,6 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph,
LOG(ERROR) << "Failed to measure graph performance: "
<< status.error_message();
costs->execution_time = Costs::Duration::max();
- costs->max_execution_time = Costs::Duration::max();
- costs->min_execution_time = 0;
return status;
}
@@ -126,8 +124,6 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph,
// to filter out outliers.
RobustStats stats(times);
costs->execution_time = Costs::Duration(stats.mean());
- costs->max_execution_time = Costs::Duration(stats.hi());
- costs->min_execution_time = Costs::Duration(stats.lo());
return Status::OK();
}
diff --git a/tensorflow/core/grappler/costs/op_performance_data.proto b/tensorflow/core/grappler/costs/op_performance_data.proto
index 1d623b8db8..37f9ebd6a1 100644
--- a/tensorflow/core/grappler/costs/op_performance_data.proto
+++ b/tensorflow/core/grappler/costs/op_performance_data.proto
@@ -58,11 +58,18 @@ message LogNormalDistribution {
double sigma = 2;
}
+message SessionInfo {
+ int64 intra_op_parallelism = 1;
+}
+
// Performance data for tensorflow operations
message OpPerformance {
// The op
OpInfo op = 1;
+ // Information about the session configs.
+ SessionInfo session_info = 12;
+
// The node name (optional). Makes it easier to associate the performance data
// with a specific graph node.
string node = 5;
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index 8ccc51f545..9db6d46266 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -139,8 +139,8 @@ class FIFOManager : public ReadyNodeManager {
public:
FIFOManager() : ReadyNodeManager() {}
~FIFOManager() override {}
- virtual void Init(
- const std::unordered_map<const NodeDef*, NodeState>* node_state) {}
+ void Init(const std::unordered_map<const NodeDef*, NodeState>* node_state)
+ override {}
void AddNode(const NodeDef* node) override { nodes_.push_back(node); }
const NodeDef* GetCurrNode() override {
CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node";
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index 149f6fc735..2f8549cf39 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -134,6 +134,7 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
const NodeDef* node = name_to_node[NodeName(root)];
if (!node) {
*ill_formed = true;
+ VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root;
return {};
}
queue.push_back(node);
@@ -153,6 +154,7 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
for (const string& input : node->input()) {
const NodeDef* in = name_to_node[NodeName(input)];
if (!in) {
+ VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input;
*ill_formed = true;
return {};
}
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 50e6ba4a64..735d78e7ee 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -2076,6 +2076,7 @@ Status LayoutOptimizer::Tune(const GrapplerItem& item,
const TuningConfig& config, GraphDef* output) {
auto status = graph_properties.AnnotateOutputShapes(output);
if (!status.ok()) {
+ VLOG(1) << "Annotate shape return status: " << status.ToString();
*output = item.graph;
return status;
}
@@ -2100,6 +2101,7 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphProperties graph_properties(item);
auto status = graph_properties.InferStatically(false);
if (!status.ok()) {
+ VLOG(1) << "Infer shape return status: " << status.ToString();
*output = item.graph;
return status;
}
diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc
index 42f3db1d79..2ca194a77f 100644
--- a/tensorflow/core/kernels/bias_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc
@@ -173,19 +173,13 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
// Accumulate the results in the shared memory into the first element.
// No syncthreads is needed since this is only in the same warp.
int32 thread_index = threadIdx.x;
- if (thread_index < 16) {
- s_data[thread_index] += s_data[thread_index + 16];
- __syncwarp(0xFFFF);
- if (thread_index < 8) s_data[thread_index] += s_data[thread_index + 8];
- __syncwarp(0xFF);
- if (thread_index < 4) s_data[thread_index] += s_data[thread_index + 4];
- __syncwarp(0xF);
- if (thread_index < 2) s_data[thread_index] += s_data[thread_index + 2];
- __syncwarp(0x3);
+ if (thread_index < 32) {
+ AccT data = s_data[thread_index];
+ for (int32 delta = warpSize / 2; delta > 0; delta /= 2) {
+ data += CudaShuffleXorSync(kCudaWarpAll, data, delta);
+ }
if (thread_index == 0) {
- T val = T(s_data[0] + s_data[1]);
- // The first thread writes out the accumulated result to global location.
- CudaAtomicAdd(bias_backprop + bias_index, val);
+ CudaAtomicAdd(bias_backprop + bias_index, T(data));
}
}
}
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 903aac5d68..5493e33532 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -34,6 +34,7 @@ limitations under the License.
namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
using Eigen::GpuDevice;
// Returns whether depthwise convolution forward or backward input pass can be
@@ -1028,7 +1029,7 @@ __device__ __forceinline__ T WarpSumReduce(T val) {
int zeros = sub_warp * kWidth;
unsigned mask = ((1UL << kWidth) - 1) << zeros;
for (int delta = kWidth / 2; delta > 0; delta /= 2) {
- val += CudaShuffleXor(mask, val, delta);
+ val += CudaShuffleXorSync(mask, val, delta);
}
return val;
}
@@ -1145,7 +1146,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
- unsigned active_threads = CudaBallot(CUDA_WARP_ALL, depth_in_range);
+ unsigned active_threads = CudaBallotSync(kCudaWarpAll, depth_in_range);
if (depth_in_range) {
const T* const out_ptr = inout_offset + output;
@@ -1159,7 +1160,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16; delta >= kBlockSlices; delta /= 2) {
- val += CudaShuffleDown(active_threads, val, delta);
+ val += CudaShuffleXorSync(active_threads, val, delta);
}
if (!(thread_idx & 32 - kBlockSlices) /* lane_idx < kBlockSlices */) {
*accum_ptr = val;
@@ -1399,7 +1400,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
// Note: the condition to reach this is uniform across the entire block.
__syncthreads();
- unsigned active_threads = CudaBallot(CUDA_WARP_ALL, slice_in_range);
+ unsigned active_threads = CudaBallotSync(kCudaWarpAll, slice_in_range);
if (slice_in_range) {
const T* const out_ptr = inout_offset + output;
@@ -1413,10 +1414,10 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall(
T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset];
// Warp-accumulate pixels of the same depth and write to accumulator.
for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) {
- val += CudaShuffleDown(active_threads, val, delta);
+ val += CudaShuffleXorSync(active_threads, val, delta);
}
if (!(thread_idx & 32 / kBlockSlices - 1)) {
- *accum_ptr = val;
+ *accum_ptr = val; // kBlockSlices threads per warp.
}
++shared_offset;
accum_ptr += accum_increment;
diff --git a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
index 31f74671ca..a3c21edc15 100644
--- a/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/scatter_nd_op_gpu.cu.cc
@@ -55,6 +55,27 @@ struct LeftUpdate<T, scatter_nd_op::UpdateOp::SUB> {
}
};
+// Specializations for std::complex, updating real and imaginary part
+// individually. Even though this is not an atomic op anymore, it is safe
+// because there is only one type of op per kernel.
+template <typename T>
+struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD> {
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
+ std::complex<T>* out, const std::complex<T>& val) {
+ T* ptr = reinterpret_cast<T*>(out);
+ CudaAtomicAdd(ptr, val.real());
+ CudaAtomicAdd(ptr, val.imag());
+ }
+};
+
+template <typename T>
+struct LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::SUB> {
+ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void operator()(
+ std::complex<T>* out, const std::complex<T>& val) {
+ LeftUpdate<std::complex<T>, scatter_nd_op::UpdateOp::ADD>()(out, -val);
+ }
+};
+
} // namespace
template <typename T, typename Index, scatter_nd_op::UpdateOp op, int IXDIM>
diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc
index 41cbece1d6..d317a8d33d 100644
--- a/tensorflow/core/kernels/summary_kernels.cc
+++ b/tensorflow/core/kernels/summary_kernels.cc
@@ -42,11 +42,16 @@ class CreateSummaryFileWriterOp : public OpKernel {
const int32 flush_millis = tmp->scalar<int32>()();
OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp));
const string filename_suffix = tmp->scalar<string>()();
- SummaryWriterInterface* s;
- OP_REQUIRES_OK(ctx,
- CreateSummaryFileWriter(max_queue, flush_millis, logdir,
- filename_suffix, ctx->env(), &s));
- OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s));
+
+ SummaryWriterInterface* s = nullptr;
+ OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>(
+ ctx, HandleFromInput(ctx, 0), &s,
+ [max_queue, flush_millis, logdir, filename_suffix,
+ ctx](SummaryWriterInterface** s) {
+ return CreateSummaryFileWriter(
+ max_queue, flush_millis, logdir,
+ filename_suffix, ctx->env(), s);
+ }));
}
};
REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
@@ -66,17 +71,23 @@ class CreateSummaryDbWriterOp : public OpKernel {
const string run_name = tmp->scalar<string>()();
OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
const string user_name = tmp->scalar<string>()();
- SummaryWriterInterface* s;
- Sqlite* db;
- OP_REQUIRES_OK(ctx, Sqlite::Open(db_uri,
- SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
- &db));
- core::ScopedUnref unref(db);
- OP_REQUIRES_OK(ctx, SetupTensorboardSqliteDb(db));
+
+ SummaryWriterInterface* s = nullptr;
OP_REQUIRES_OK(
- ctx, CreateSummaryDbWriter(db, experiment_name,
- run_name, user_name, ctx->env(), &s));
- OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s));
+ ctx,
+ LookupOrCreateResource<SummaryWriterInterface>(
+ ctx, HandleFromInput(ctx, 0), &s,
+ [db_uri, experiment_name, run_name, user_name,
+ ctx](SummaryWriterInterface** s) {
+ Sqlite* db;
+ TF_RETURN_IF_ERROR(Sqlite::Open(
+ db_uri, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db));
+ core::ScopedUnref unref(db);
+ TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db));
+ TF_RETURN_IF_ERROR(CreateSummaryDbWriter(
+ db, experiment_name, run_name, user_name, ctx->env(), s));
+ return Status::OK();
+ }));
}
};
REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
@@ -267,8 +278,6 @@ class WriteAudioSummaryOp : public OpKernel {
private:
int max_outputs_;
- bool has_sample_rate_attr_;
- float sample_rate_attr_;
};
REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU),
WriteAudioSummaryOp);
diff --git a/tensorflow/core/kernels/svd_op_gpu.cu.cc b/tensorflow/core/kernels/svd_op_gpu.cu.cc
index dedc2da60b..8c3a58b108 100644
--- a/tensorflow/core/kernels/svd_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/svd_op_gpu.cu.cc
@@ -63,8 +63,8 @@ __global__ void ComputeValueOfVKernel(Cuda2DLaunchConfig config, int64 m,
int64 ldu, const Scalar* M,
const Scalar* U, const Scalar* S,
Scalar* V) {
- CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count, x) {
- CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count, y) {
+ CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) {
+ CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) {
Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch];
CudaAtomicAdd(V + batch, v);
}
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 12c27c7984..4f946fb3ca 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -171,29 +171,10 @@ Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
return Status::OK();
}
-Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
- ShapeHandle handle;
- DimensionHandle unused_handle;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
- for (int i = 1; i < c->num_inputs(); ++i) {
- TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
- }
- for (int i = 0; i < c->num_outputs(); ++i) {
- c->set_output(i, c->Scalar());
- }
- return Status::OK();
-}
-
Status TwoElementOutput(InferenceContext* c) {
c->set_output(0, c->Vector(2));
return Status::OK();
}
-
-Status ScalarOutput(InferenceContext* c) {
- c->set_output(0, c->Scalar());
- return Status::OK();
-}
} // namespace
REGISTER_OP("RandomShuffleQueue")
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 7484ebb078..ef2ac267cc 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -25,42 +25,6 @@ using shape_inference::ShapeHandle;
namespace {
-const char kDecodeJpegCommonDocStr[] = R"doc(
-The attr `channels` indicates the desired number of color channels for the
-decoded image.
-
-Accepted values are:
-
-* 0: Use the number of channels in the JPEG-encoded image.
-* 1: output a grayscale image.
-* 3: output an RGB image.
-
-If needed, the JPEG-encoded image is transformed to match the requested number
-of color channels.
-
-The attr `ratio` allows downscaling the image by an integer factor during
-decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than
-downscaling the image later.
-
-)doc";
-
-const char kDecodeJpegCommonParamsDocStr[] = R"doc(
-channels: Number of color channels for the decoded image.
-ratio: Downscaling ratio.
-fancy_upscaling: If true use a slower but nicer upscaling of the
- chroma planes (yuv420/422 only).
-try_recover_truncated: If true try to recover an image from truncated input.
-acceptable_fraction: The minimum required fraction of lines before a truncated
- input is accepted.
-dct_method: string specifying a hint about the algorithm used for
- decompression. Defaults to "" which maps to a system-specific
- default. Currently valid values are ["INTEGER_FAST",
- "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal
- jpeg library changes to a version that does not have that specific
- option.)
-image: 3-D with shape `[height, width, channels]`..
-)doc";
-
// Sets output[0] to shape [batch_dim,height,width,channel_dim], where
// height and width come from the size_tensor.
Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index e8d03877c9..6ce9595fb6 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -22,48 +22,6 @@ using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-const char kAddSignCommonDocStr[] = R"doc(
-Update '*var' according to the AddSign update.
-
-m_t <- beta1 * m_{t-1} + (1 - beta1) * g
-update <- (alpha + sign_decay * sign(g) *sign(m)) * g
-variable <- variable - lr_t * update
-
-var: Should be from a Variable().
-m: Should be from a Variable().
-lr: Scaling factor. Must be a scalar.
-sign_decay: Must be a scalar.
-alpha: Must be a scalar.
-beta: Must be a scalar.
-grad: The gradient.
-)doc";
-
-const char kPowerSignCommonDocStr[] = R"doc(
-Update '*var' according to the AddSign update.
-
-m_t <- beta1 * m_{t-1} + (1 - beta1) * g
-update <- exp(logbase * sign_decay * sign(g) * sign(m_t)) * g
-variable <- variable - lr_t * update
-
-var: Should be from a Variable().
-m: Should be from a Variable().
-lr: Scaling factor. Must be a scalar.
-logbase: Must be a scalar.
-sign_decay: Must be a scalar.
-beta: Must be a scalar.
-grad: The gradient.
-)doc";
-
-const char kOutDocStr[] = R"doc(
-out: Same as "var".
-)doc";
-
-const char kLockDocStr[] = R"doc(
-use_locking: If `True`, updating of the var and m tensors is
- protected by a lock; otherwise the behavior is undefined, but may exhibit less
- contention.
-)doc";
-
static ShapeHandle ShapeOrHandleShape(InferenceContext* c, int input) {
auto* handle_data = c->input_handle_shapes_and_types(input);
if (handle_data != nullptr && !handle_data->empty() &&
diff --git a/tensorflow/core/profiler/internal/tfprof_timeline.h b/tensorflow/core/profiler/internal/tfprof_timeline.h
index 4428ab571f..651ad3f0c1 100644
--- a/tensorflow/core/profiler/internal/tfprof_timeline.h
+++ b/tensorflow/core/profiler/internal/tfprof_timeline.h
@@ -178,7 +178,6 @@ class Timeline {
int64 step_;
const string outfile_;
int64 next_pid_ = 0;
- int64 allocator_pid_ = -1;
MemoryTracker mem_tracker_;
ChromeTraceFormatter chrome_formatter_;
std::map<string, int64> device_pids_;
diff --git a/tensorflow/core/profiler/internal/tfprof_utils.cc b/tensorflow/core/profiler/internal/tfprof_utils.cc
index 2813bb46fa..7712ebd926 100644
--- a/tensorflow/core/profiler/internal/tfprof_utils.cc
+++ b/tensorflow/core/profiler/internal/tfprof_utils.cc
@@ -355,9 +355,6 @@ static const char* const kOpTypes =
static const char* const kScope =
"scope: The nodes in the model graph are organized by their names, which "
"is hierarchical like filesystem.";
-static const char* const kGraph =
- "graph: The nodes in the model graph are organized by their operation "
- "input and output.";
static const char* const kCode =
"code: When python trace is available, the nodes are python lines and "
"their are organized by the python call stack.";
diff --git a/tensorflow/core/util/cuda_device_functions.h b/tensorflow/core/util/cuda_device_functions.h
new file mode 100644
index 0000000000..f787687f66
--- /dev/null
+++ b/tensorflow/core/util/cuda_device_functions.h
@@ -0,0 +1,499 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
+#define TENSORFLOW_CORE_UTIL_CUDA_DEVICE_FUNCTIONS_H_
+
+/**
+ * Wrappers and helpers for CUDA device code.
+ *
+ * Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide
+ * backwards compatibility, see go/volta-porting for details.
+ * Provides atomic operations on types that aren't natively supported.
+ */
+
+#if GOOGLE_CUDA
+
+#include <algorithm>
+#include <complex>
+#include "cuda/include/cuda.h"
+#include "cuda/include/device_functions.h"
+#include "tensorflow/core/platform/types.h"
+
+#if CUDA_VERSION >= 7050
+#include "cuda/include/cuda_fp16.h"
+#endif // CUDA_VERSION >= 7050
+
+namespace tensorflow {
+
+namespace detail {
+
+// Helper for range-based for loop using 'delta' increments.
+// Usage: see CudaGridRange?() functions below.
+template <typename T>
+class CudaGridRange {
+ struct Iterator {
+ __device__ Iterator(T index, T delta) : index_(index), delta_(delta) {}
+ __device__ T operator*() const { return index_; }
+ __device__ Iterator& operator++() {
+ index_ += delta_;
+ return *this;
+ }
+ __device__ bool operator!=(const Iterator& other) const {
+ bool greater = index_ > other.index_;
+ bool less = index_ < other.index_;
+ // Anything past an end iterator (delta_ == 0) is equal.
+ // In range-based for loops, this optimizes to 'return less'.
+ if (!other.delta_) {
+ return less;
+ }
+ if (!delta_) {
+ return greater;
+ }
+ return less || greater;
+ }
+
+ private:
+ T index_;
+ const T delta_;
+ };
+
+ public:
+ __device__ CudaGridRange(T begin, T delta, T end)
+ : begin_(begin), delta_(delta), end_(end) {}
+
+ __device__ Iterator begin() const { return Iterator{begin_, delta_}; }
+ __device__ Iterator end() const { return Iterator{end_, 0}; }
+
+ private:
+ T begin_;
+ T delta_;
+ T end_;
+};
+
+} // namespace detail
+
+// Helper to visit indices in the range 0 <= i < count, using the x-coordinate
+// of the global thread index. That is, each index i is visited by all threads
+// with the same x-coordinate.
+// Usage: for(int i : CudaGridRangeX(count)) { visit(i); }
+template <typename T>
+__device__ detail::CudaGridRange<T> CudaGridRangeX(T count) {
+ return detail::CudaGridRange<T>(blockIdx.x * blockDim.x + threadIdx.x,
+ gridDim.x * blockDim.x, count);
+}
+
+// Helper to visit indices in the range 0 <= i < count using the y-coordinate.
+// Usage: for(int i : CudaGridRangeY(count)) { visit(i); }
+template <typename T>
+__device__ detail::CudaGridRange<T> CudaGridRangeY(T count) {
+ return detail::CudaGridRange<T>(blockIdx.y * blockDim.y + threadIdx.y,
+ gridDim.y * blockDim.y, count);
+}
+
+// Helper to visit indices in the range 0 <= i < count using the z-coordinate.
+// Usage: for(int i : CudaGridRangeZ(count)) { visit(i); }
+template <typename T>
+__device__ detail::CudaGridRange<T> CudaGridRangeZ(T count) {
+ return detail::CudaGridRange<T>(blockIdx.z * blockDim.z + threadIdx.z,
+ gridDim.z * blockDim.z, count);
+}
+
+// Mask for all 32 threads in a warp.
+const unsigned kCudaWarpAll = 0xffffffff;
+
+// Returns the warp lane ID of the calling thread
+__device__ inline unsigned CudaLaneId() {
+ unsigned int lane_id;
+ asm("mov.u32 %0, %%laneid;" : "=r"(lane_id));
+ return lane_id;
+}
+
+namespace detail {
+// Returns true if mask is a valid parameter for __shfl*sync to return a well
+// defined value, assuming the calling lane will read from src_lane as part of
+// the shuffle operation.
+//
+// Specifically, returns true iff mask has the calling lane bit and the src_lane
+// bit set, and the src_lane calls this function with the same mask value
+// (required for the two threads to wait for each other).
+//
+// On Volta, for some invalid masks, this function hangs or returns false
+// positives, because the implementation shuffles with the same mask that
+// we are validating. Run on Pascal if you suspect that the mask is incorrect.
+__device__ inline bool CudaValidateShuffleSyncMask(unsigned mask,
+ unsigned src_lane) {
+ unsigned src_dst_mask = 1u << CudaLaneId() | 1u << src_lane;
+#if CUDA_VERSION >= 9000
+ unsigned src_lane_mask = __shfl_sync(mask, mask, src_lane);
+#else
+ unsigned src_lane_mask = __shfl(mask, src_lane);
+#endif
+ return (src_dst_mask & ~mask) == 0 && src_lane_mask == mask;
+}
+
+// Returns the actual source lane for shuffle.
+__device__ inline unsigned CudaShuffleGetSrcLane(int src_lane, int width) {
+ int lane_id = CudaLaneId();
+ int lane_base = lane_id & ~width + 1;
+ int lane_offset = src_lane & width - 1;
+ return lane_base + lane_offset;
+}
+
+// Returns the source lane for shuffle up.
+__device__ inline unsigned CudaShuffleUpGetSrcLane(unsigned delta, int width) {
+ unsigned lane_id = CudaLaneId();
+ if ((lane_id & width - 1) < delta) {
+ return lane_id;
+ }
+ return lane_id - delta;
+}
+
+// Returns the source lane for shuffle down.
+__device__ inline unsigned CudaShuffleDownGetSrcLane(unsigned delta,
+ int width) {
+ unsigned lane_id = CudaLaneId();
+ if ((lane_id & width - 1) + delta >= width) {
+ return lane_id;
+ }
+ return lane_id + delta;
+}
+
+// Returns the source lane for shuffle xor.
+__device__ inline unsigned CudaShuffleXorGetSrcLane(int lane_mask, int width) {
+ int lane_id = CudaLaneId();
+ int src_lane = lane_id ^ lane_mask;
+ if (src_lane > (lane_id | width - 1)) {
+ return lane_id;
+ }
+ return src_lane;
+}
+} // namespace detail
+
+// For all *_sync wrappers below, it is illegal to synchronize threads from
+// different program locations, because that is not supported before sm_70.
+// In other words, all threads in 'mask' must call the functions in convergence.
+// Code that requires sm_70 (and CUDA 9) may use the intrinsic directly.
+//
+// It is also illegal to shuffle with a mask that produces an undefined result
+// for any of the threads. Specifically, all source threads of the shuffle
+// must have their corresponding bit in 'mask' set.
+
+// Wrapper for __syncwarp. No-op for CUDA 8 and earlier.
+__device__ inline void CudaSyncWarp(unsigned mask = kCudaWarpAll) {
+ assert(mask & 1u << CudaLaneId());
+#if CUDA_VERSION >= 9000
+ __syncwarp(mask);
+#endif
+}
+
+// Wrapper for __ballot_sync. All threads in 'mask' must call this function in
+// convergence, see comment above for details.
+__device__ inline unsigned CudaBallotSync(unsigned mask, int pred) {
+ assert(mask & 1u << CudaLaneId());
+#if CUDA_VERSION >= 9000
+ return __ballot_sync(mask, pred);
+#else
+ return __ballot(pred) & mask; // Apply mask to match __ballot_sync's spec.
+#endif
+}
+
+// Wrapper for __any_sync. All threads in 'mask' must call this function in
+// convergence, see comment above for details.
+__device__ inline int CudaAnySync(unsigned mask, int pred) {
+ assert(mask & 1u << CudaLaneId());
+#if CUDA_VERSION >= 9000
+ return __any_sync(mask, pred);
+#else
+ return __any(pred);
+#endif
+}
+
+// Wrapper for __all_sync. All threads in 'mask' must call this function in
+// convergence, see comment above for details.
+__device__ inline int CudaAllSync(unsigned mask, int pred) {
+ assert(mask & 1u << CudaLaneId());
+#if CUDA_VERSION >= 9000
+ return __all_sync(mask, pred);
+#else
+ return __all(pred);
+#endif
+}
+
+// Wrapper for __shfl_sync. All threads in 'mask' must call this function in
+// convergence, see comment above for details.
+template <typename T>
+__device__ T CudaShuffleSync(unsigned mask, T value, int src_lane,
+ int width = warpSize) {
+ assert(!(width & width - 1));
+ assert(detail::CudaValidateShuffleSyncMask(
+ mask, detail::CudaShuffleGetSrcLane(src_lane, width)));
+#if CUDA_VERSION >= 9000
+ return __shfl_sync(mask, value, src_lane, width);
+#else
+ return __shfl(value, src_lane, width);
+#endif
+}
+
+// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
+// instead of float for lo and hi (which is incorrect with ftz, for example).
+// See b/69446944.
+__device__ inline double CudaShuffleSync(unsigned mask, double value,
+ int src_lane, int width = warpSize) {
+ unsigned lo, hi;
+ asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
+ hi = CudaShuffleSync(mask, hi, src_lane, width);
+ lo = CudaShuffleSync(mask, lo, src_lane, width);
+ asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
+ return value;
+}
+
+// Wrapper for __shfl_up_sync. All threads in 'mask' must call this function in
+// convergence, see comment above for details.
+template <typename T>
+__device__ inline T CudaShuffleUpSync(unsigned mask, T value, unsigned delta,
+ int width = warpSize) {
+ assert(!(width & width - 1));
+ assert(detail::CudaValidateShuffleSyncMask(
+ mask, detail::CudaShuffleUpGetSrcLane(delta, width)));
+#if CUDA_VERSION >= 9000
+ return __shfl_up_sync(mask, value, delta, width);
+#else
+ return __shfl_up(value, delta, width);
+#endif
+}
+
+// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
+// instead of float for lo and hi (which is incorrect with ftz, for example).
+// See b/69446944.
+__device__ inline double CudaShuffleUpSync(unsigned mask, double value,
+ unsigned delta,
+ int width = warpSize) {
+ unsigned lo, hi;
+ asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
+ hi = CudaShuffleUpSync(mask, hi, delta, width);
+ lo = CudaShuffleUpSync(mask, lo, delta, width);
+ asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
+ return value;
+}
+
+// Wrapper for __shfl_down_sync. All threads in 'mask' must call this function
+// in convergence, see comment above for details.
+template <typename T>
+__device__ inline T CudaShuffleDownSync(unsigned mask, T value, unsigned delta,
+ int width = warpSize) {
+ assert(!(width & width - 1));
+ assert(detail::CudaValidateShuffleSyncMask(
+ mask, detail::CudaShuffleDownGetSrcLane(delta, width)));
+#if CUDA_VERSION >= 9000
+ return __shfl_down_sync(mask, value, delta, width);
+#else
+ return __shfl_down(value, delta, width);
+#endif
+}
+
+// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
+// instead of float for lo and hi (which is incorrect with ftz, for example).
+// See b/69446944.
+__device__ inline double CudaShuffleDownSync(unsigned mask, double value,
+ unsigned delta,
+ int width = warpSize) {
+ unsigned lo, hi;
+ asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
+ hi = CudaShuffleDownSync(mask, hi, delta, width);
+ lo = CudaShuffleDownSync(mask, lo, delta, width);
+ asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
+ return value;
+}
+
+// Wrapper for __shfl_xor_sync. All threads in 'mask' must call this function in
+// convergence, see comment above for details.
+template <typename T>
+__device__ T CudaShuffleXorSync(unsigned mask, T value, int lane_mask,
+ int width = warpSize) {
+ assert(!(width & width - 1));
+ assert(detail::CudaValidateShuffleSyncMask(
+ mask, detail::CudaShuffleXorGetSrcLane(lane_mask, width)));
+#if CUDA_VERSION >= 9000
+ return __shfl_xor_sync(mask, value, lane_mask, width);
+#else
+ return __shfl_xor(value, lane_mask, width);
+#endif
+}
+
+// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
+// instead of float for lo and hi (which is incorrect with ftz, for example).
+// See b/69446944.
+__device__ inline double CudaShuffleXorSync(unsigned mask, double value,
+ int lane_mask,
+ int width = warpSize) {
+ unsigned lo, hi;
+ asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
+ hi = CudaShuffleXorSync(mask, hi, lane_mask, width);
+ lo = CudaShuffleXorSync(mask, lo, lane_mask, width);
+ asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
+ return value;
+}
+
+// Wrapper for __ldg.
+template <typename T>
+__host__ __device__ T CudaLdg(const T* address) {
+#if __CUDA_ARCH__ >= 350
+ return __ldg(address);
+#else
+ return *address;
+#endif
+}
+
+__host__ __device__ inline bool CudaLdg(const bool* address) {
+ return CudaLdg(reinterpret_cast<const char*>(address)) != 0;
+}
+
+__host__ __device__ inline std::complex<float> CudaLdg(
+ const std::complex<float>* address) {
+#if __CUDA_ARCH__ >= 350
+ float2 mem = __ldg(reinterpret_cast<const float2*>(address));
+ return std::complex<float>(mem.x, mem.y);
+#else
+ return *address;
+#endif
+}
+
+__host__ __device__ inline std::complex<double> CudaLdg(
+ const std::complex<double>* address) {
+#if __CUDA_ARCH__ >= 350
+ double2 mem = __ldg(reinterpret_cast<const double2*>(address));
+ return std::complex<double>(mem.x, mem.y);
+#else
+ return *address;
+#endif
+}
+
+// Zeroes count elements starting at ptr using all threads of a 1-D grid.
+// Note: this function does not synchronize, and therefore the memory range is
+// not guaranteed to be zero until the next kernel launch.
+template <typename T>
+__global__ void SetZero(const int count, T* ptr) {
+ // Check that the grid is one dimensional and index doesn't overflow.
+ assert(blockDim.y == 1 && blockDim.z == 1);
+ assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
+ for (int i : CudaGridRangeX(count)) {
+ ptr[i] = T(0);
+ }
+}
+
+namespace detail {
+// Helper function for atomic accumulation implemented as CAS.
+template <typename T, typename F>
+__device__ T CudaAtomicCasHelper(T* ptr, F accumulate) {
+ T old = *ptr;
+ T assumed;
+ do {
+ assumed = old;
+ old = atomicCAS(ptr, assumed, accumulate(assumed));
+ } while (assumed != old);
+ return old;
+}
+
+// Overload for floating point (using integer comparison to handle NaN
+// correctly).
+template <typename F>
+__device__ float CudaAtomicCasHelper(float* ptr, F accumulate) {
+ return __float_as_int(
+ CudaAtomicCasHelper(reinterpret_cast<int32*>(ptr), [accumulate](int32 a) {
+ return __float_as_int(accumulate(__int_as_float(a)));
+ }));
+}
+template <typename F>
+__device__ double CudaAtomicCasHelper(double* ptr, F accumulate) {
+ return __longlong_as_double(CudaAtomicCasHelper(
+ reinterpret_cast<tensorflow::uint64*>(ptr),
+ [accumulate](tensorflow::uint64 a) {
+ return __double_as_longlong(accumulate(__longlong_as_double(a)));
+ }));
+}
+
+template <typename From, typename To>
+using ToTypeIfConvertible =
+ typename std::enable_if<std::is_convertible<From, To>::value, To>::type;
+
+} // namespace detail
+
+// CUDA provides atomic ops, but not for all types. We provide wrappers
+// for some ops and provide implementation for all reasonable types.
+
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicAdd(T* ptr, U value) {
+ return atomicAdd(ptr, value);
+}
+#if __CUDA_ARCH__ < 600
+__device__ inline double CudaAtomicAdd(double* ptr, double value) {
+ return detail::CudaAtomicCasHelper(ptr,
+ [value](double a) { return a + value; });
+}
+#elif __clang__
+// Clang cannot compile __nvvm_atom_add_gen_d builtin yet, use inline PTX.
+// see https://reviews.llvm.org/D39638
+__device__ inline double CudaAtomicAdd(double* ptr, double value) {
+ double result;
+ asm volatile("atom.add.f64 %0, [%1], %2;"
+ : "=d"(result)
+ : "l"(ptr), "d"(value)
+ : "memory");
+ return result;
+}
+#endif
+
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicSub(T* ptr, U value) {
+ return atomicSub(ptr, value);
+}
+// Specializations of substraction which add the negative value.
+__device__ inline float CudaAtomicSub(float* ptr, float value) {
+ return CudaAtomicAdd(ptr, -value);
+}
+__device__ inline double CudaAtomicSub(double* ptr, double value) {
+ return CudaAtomicAdd(ptr, -value);
+}
+__device__ inline tensorflow::uint64 CudaAtomicSub(tensorflow::uint64* ptr,
+ tensorflow::uint64 value) {
+ return CudaAtomicAdd(ptr, -value);
+}
+
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMax(T* ptr, U value) {
+ return atomicMax(ptr, value);
+}
+#if __CUDA_ARCH__ < 320
+__device__ inline tensorflow::uint64 CudaAtomicMax(tensorflow::uint64* ptr,
+ tensorflow::uint64 value) {
+ return detail::CudaAtomicCasHelper(
+ ptr, [value](tensorflow::uint64 a) { return max(a, value); });
+}
+#endif
+
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicMul(T* ptr, U value) {
+ return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a * value; });
+}
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> CudaAtomicDiv(T* ptr, U value) {
+ return detail::CudaAtomicCasHelper(ptr, [value](T a) { return a / value; });
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 3e32ec7973..18a4c008f1 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -18,299 +18,133 @@ limitations under the License.
#if GOOGLE_CUDA
-#include <algorithm>
+#include "tensorflow/core/util/cuda_device_functions.h"
+#include "tensorflow/core/util/cuda_launch_config.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "cuda/include/cuda.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/stream_executor.h"
-#include "tensorflow/core/platform/types.h"
+// Deprecated, use 'for(int i : CudaGridRangeX(n))' instead.
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+ for (int i : ::tensorflow::CudaGridRangeX<int>(n))
+// Deprecated, use 'for(int i : CudaGridRange?(n))' instead.
+#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
+ for (int i : ::tensorflow::CudaGridRange##axis<int>(n))
-// Mask for all 32 threads in a warp.
-#define CUDA_WARP_ALL 0xFFFFFFFF
-
-#if defined(CUDA_VERSION) && CUDA_VERSION < 9000
-// CUDA 9.0 introduces a new, light-weight barrier synchronization primitive
-// that operates at the warp-scope. This is required to ensure visibility of
-// reads/writes among threads that can make indepenent progress on Volta.
-// For previous CUDA versions these synchronizations not necessary, and we
-// define an empty function as a convenience for backward compatibility.
-__device__ inline void __syncwarp(unsigned mask = CUDA_WARP_ALL) {}
-
-// CUDA 9.0 deprecates the warp-intrinsic functions (shfl, ballot, etc.) in
-// favor of synchronizing versions. These ensure that all warp lanes specified
-// in mask execute the intrinsic in convergence. Here we provide legacy mappings
-// to the less-verbose routines provided in previous versions of CUDA.
-#define __ballot_sync(mask, predicate) __ballot(predicate)
-#define __shfl_sync(mask, val, srcLane, width) __shfl(val, srcLane, width)
-#define __shfl_down_sync(mask, val, delta, width) __shfl_down(val, delta, width)
-#define __shfl_up_sync(mask, val, delta, width) __shfl_up(val, delta, width)
-#define __shfl_xor_sync(mask, val, laneMask, width) \
- __shfl_xor(val, laneMask, width)
-#endif
-
-// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
-// GetCuda3DLaunchConfig:
-//
-// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one
-// version uses heuristics without any knowledge of the device kernel, the other
-// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical
-// launch parameters that maximize occupancy. Currently, only the maximum
-// occupancy version of GetCuda3DLaunchConfig is available.
-//
-// For large number of work elements, the convention is that each kernel would
-// iterate through its assigned range. The return value of GetCudaLaunchConfig
-// is struct CudaLaunchConfig, which contains all the information needed for the
-// kernel launch, including: virtual number of threads, the number of threads
-// per block and number of threads per block used inside <<< >>> of a kernel
-// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig does the same thing
-// as CudaLaunchConfig. The only difference is the dimension. The macros
-// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP might be used to do inner loop.
-//
-/* Sample code:
-
-__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) {
- CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
- do_your_job_here;
- }
+namespace tensorflow {
+__host__ __device__ inline tensorflow::bfloat16 CudaLdg(
+ const tensorflow::bfloat16* address) {
+ tensorflow::bfloat16 return_value;
+ return_value.value = CudaLdg(reinterpret_cast<const uint16_t*>(address));
+ return return_value;
}
-__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) {
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
- do_your_job_here;
- }
- }
+template <typename T>
+__host__ __device__ inline T ldg(const T* ptr) {
+ return CudaLdg(ptr);
}
-__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) {
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
- CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
- do_your_job_here;
- }
- }
- }
+template <typename T>
+__host__ __device__ inline const T& tf_min(const T& x, const T& y) {
+ return x < y ? x : y;
}
-void MyDriverFunc(const GPUDevice &d) {
- // use heuristics
- CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d);
- MyKernel1D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
- Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d);
- MyKernel2D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
- Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d);
- MyKernel3D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);
-
- // maximize occupancy
- CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 );
- MyKernel1D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
- Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d,
- MyKernel1D, 0, 0);
- MyKernel2D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
- Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d,
- MyKernel1D, 0, 0);
- MyKernel3D <<<config.block_count,
- config.thread_per_block, 0, d.stream()>>> (cfg6, other_args...);
+template <typename T>
+__host__ __device__ inline const T& tf_max(const T& x, const T& y) {
+ return x < y ? y : x;
}
-// See the test for this for more example:
-//
-https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
-
-*/
-
-#define CUDA_1D_KERNEL_LOOP(i, n) \
- for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
- i += blockDim.x * gridDim.x)
-
-#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
- for (int i = blockIdx.axis * blockDim.axis + threadIdx.axis; i < n.axis; \
- i += blockDim.axis * gridDim.axis)
-
-#define DIV_UP(a, b) (((a) + (b)-1) / (b))
-
-namespace tensorflow {
-
-typedef Eigen::GpuDevice GPUDevice;
-
-struct CudaLaunchConfig {
- // Logical number of thread that works on the elements. If each logical
- // thread works on exactly a single element, this is the same as the working
- // element count.
- int virtual_thread_count = -1;
- // Number of threads per block.
- int thread_per_block = -1;
- // Number of blocks for Cuda kernel launch.
- int block_count = -1;
-};
-
-// Calculate the Cuda launch config we should use for a kernel launch.
-// This is assuming the kernel is quite simple and will largely be
-// memory-limited.
-// REQUIRES: work_element_count > 0.
-inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
- const GPUDevice& d) {
- CHECK_GT(work_element_count, 0);
- CudaLaunchConfig config;
- const int virtual_thread_count = work_element_count;
- const int physical_thread_count = std::min(
- d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
- virtual_thread_count);
- const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
- const int block_count =
- std::min(DIV_UP(physical_thread_count, thread_per_block),
- d.getNumCudaMultiProcessors());
-
- config.virtual_thread_count = virtual_thread_count;
- config.thread_per_block = thread_per_block;
- config.block_count = block_count;
- return config;
+// Overloads of the above functions for float and double.
+__host__ __device__ inline float tf_min(float x, float y) {
+ return fminf(x, y);
}
-
-// Calculate the Cuda launch config we should use for a kernel launch. This
-// variant takes the resource limits of func into account to maximize occupancy.
-// REQUIRES: work_element_count > 0.
-template <typename DeviceFunc>
-inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
- const GPUDevice& d, DeviceFunc func,
- size_t dynamic_shared_memory_size,
- int block_size_limit) {
- CHECK_GT(work_element_count, 0);
- CudaLaunchConfig config;
- int block_count = 0;
- int thread_per_block = 0;
-
- cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
- &block_count, &thread_per_block, func, dynamic_shared_memory_size,
- block_size_limit);
- CHECK_EQ(err, cudaSuccess);
-
- block_count =
- std::min(block_count, DIV_UP(work_element_count, thread_per_block));
-
- config.virtual_thread_count = work_element_count;
- config.thread_per_block = thread_per_block;
- config.block_count = block_count;
- return config;
+__host__ __device__ inline double tf_min(double x, double y) {
+ return fmin(x, y);
+}
+__host__ __device__ inline float tf_max(float x, float y) {
+ return fmaxf(x, y);
+}
+__host__ __device__ inline double tf_max(double x, double y) {
+ return fmax(x, y);
}
-struct Cuda2DLaunchConfig {
- dim3 virtual_thread_count = dim3(0, 0, 0);
- dim3 thread_per_block = dim3(0, 0, 0);
- dim3 block_count = dim3(0, 0, 0);
-};
-
-inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
- const GPUDevice& d) {
- Cuda2DLaunchConfig config;
-
- if (xdim <= 0 || ydim <= 0) {
- return config;
- }
-
- const int kThreadsPerBlock = 256;
- int block_cols = std::min(xdim, kThreadsPerBlock);
- // ok to round down here and just do more loops in the kernel
- int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
-
- const int physical_thread_count =
- d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();
-
- const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
-
- config.virtual_thread_count = dim3(xdim, ydim, 1);
- config.thread_per_block = dim3(block_cols, block_rows, 1);
-
- int grid_x = std::min(DIV_UP(xdim, block_cols), max_blocks);
+__device__ inline Eigen::half CudaShuffleSync(unsigned mask, Eigen::half value,
+ int src_lane,
+ int width = warpSize) {
+ return Eigen::half(
+ CudaShuffleSync(mask, static_cast<uint16>(value), src_lane, width));
+}
- config.block_count = dim3(
- grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
- return config;
+__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleUpSync(
+ unsigned mask, Eigen::half value, int delta, int width = warpSize) {
+ return Eigen::half(
+ CudaShuffleUpSync(mask, static_cast<uint16>(value), delta, width));
}
-// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch.
-// This variant takes the resource limits of func into account to maximize
-// occupancy.
-using Cuda3DLaunchConfig = Cuda2DLaunchConfig;
+__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDownSync(
+ unsigned mask, Eigen::half value, int delta, int width = warpSize) {
+ return Eigen::half(
+ CudaShuffleDownSync(mask, static_cast<uint16>(value), delta, width));
+}
-template <typename DeviceFunc>
-inline Cuda3DLaunchConfig GetCuda3DLaunchConfig(
- int xdim, int ydim, int zdim, const GPUDevice& d, DeviceFunc func,
- size_t dynamic_shared_memory_size, int block_size_limit) {
- Cuda3DLaunchConfig config;
+__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXorSync(
+ unsigned mask, Eigen::half value, int lane_mask, int width = warpSize) {
+ return Eigen::half(
+ CudaShuffleXorSync(mask, static_cast<uint16>(value), lane_mask, width));
+}
- if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
- return config;
+namespace detail {
+// Overload of above function for half. Note that we don't have
+// atomicCAS() for anything less than 32 bits, so we need to include the
+// other 16 bits in the operation.
+//
+// This version is going to be very slow
+// under high concurrency, since most threads will be spinning on failing
+// their compare-and-swap tests. (The fact that we get false sharing on the
+// neighboring fp16 makes this even worse.) If you are doing a large reduction,
+// you are much better off with doing the intermediate steps in fp32 and then
+// switching to fp16 as late as you can in the calculations.
+//
+// Note: Assumes little endian.
+template <typename F>
+__device__ Eigen::half CudaAtomicCasHelper(Eigen::half* ptr, F accumulate) {
+#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__)
+ static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian");
+#endif
+ namespace half_impl = Eigen::half_impl;
+ intptr_t intptr = reinterpret_cast<intptr_t>(ptr);
+ assert(!(intptr & 0x1)); // should be 2-aligned.
+ if (intptr & 0x2) {
+ // The half is in the second part of the uint32 (upper 16 bits).
+ uint32* address = reinterpret_cast<uint32*>(intptr - 2);
+ uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) {
+ unsigned short high = static_cast<unsigned short>(arg >> 16);
+ Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(high));
+ return (static_cast<uint32>(acc.x) << 16) | (arg & 0xffff);
+ });
+ return half_impl::raw_uint16_to_half(static_cast<uint16>(result >> 16));
+ } else {
+ // The half is in the first part of the uint32 (lower 16 bits).
+ uint32* address = reinterpret_cast<uint32*>(intptr);
+ uint32 result = CudaAtomicCasHelper(address, [accumulate](uint32 arg) {
+ unsigned short low = static_cast<unsigned short>(arg & 0xffff);
+ Eigen::half acc = accumulate(half_impl::raw_uint16_to_half(low));
+ return (arg & 0xffff0000) | static_cast<uint32>(acc.x);
+ });
+ return half_impl::raw_uint16_to_half(static_cast<uint16>(result & 0xffff));
}
-
- int dev;
- cudaGetDevice(&dev);
- cudaDeviceProp deviceProp;
- cudaGetDeviceProperties(&deviceProp, dev);
- int xthreadlimit = deviceProp.maxThreadsDim[0];
- int ythreadlimit = deviceProp.maxThreadsDim[1];
- int zthreadlimit = deviceProp.maxThreadsDim[2];
- int xgridlimit = deviceProp.maxGridSize[0];
- int ygridlimit = deviceProp.maxGridSize[1];
- int zgridlimit = deviceProp.maxGridSize[2];
-
- int block_count = 0;
- int thread_per_block = 0;
- cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
- &block_count, &thread_per_block, func, dynamic_shared_memory_size,
- block_size_limit);
- CHECK_EQ(err, cudaSuccess);
-
-#define MIN3(a, b, c) std::min((a), std::min((b), (c)))
- int threadsx = MIN3(xdim, thread_per_block, xthreadlimit);
- int threadsy =
- MIN3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit);
- int threadsz =
- MIN3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
- zthreadlimit);
-
- int blocksx = MIN3(block_count, DIV_UP(xdim, threadsx), xgridlimit);
- int blocksy =
- MIN3(DIV_UP(block_count, blocksx), DIV_UP(ydim, threadsy), ygridlimit);
- int blocksz = MIN3(DIV_UP(block_count, (blocksx * blocksy)),
- DIV_UP(zdim, threadsz), zgridlimit);
-#undef MIN3
-
- config.virtual_thread_count = dim3(xdim, ydim, zdim);
- config.thread_per_block = dim3(threadsx, threadsy, threadsz);
- config.block_count = dim3(blocksx, blocksy, blocksz);
- return config;
}
+} // namespace detail
-template <typename DeviceFunc>
-inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(
- int xdim, int ydim, const GPUDevice& d, DeviceFunc func,
- size_t dynamic_shared_memory_size, int block_size_limit) {
- return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func,
- dynamic_shared_memory_size, block_size_limit);
+__device__ inline Eigen::half CudaAtomicAdd(Eigen::half* ptr,
+ Eigen::half value) {
+ return detail::CudaAtomicCasHelper(
+ ptr, [value](Eigen::half a) { return a + value; });
}
-
-// Returns a raw reference to the current cuda stream. Required by a
-// number of kernel calls (for which StreamInterface* does not work), i.e.
-// CUB and certain cublas primitives.
-inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
- const cudaStream_t* ptr = CHECK_NOTNULL(
- reinterpret_cast<const cudaStream_t*>(context->op_device_context()
- ->stream()
- ->implementation()
- ->CudaStreamMemberHack()));
- return *ptr;
+__device__ inline Eigen::half CudaAtomicSub(Eigen::half* ptr,
+ Eigen::half value) {
+ return detail::CudaAtomicCasHelper(
+ ptr, [value](Eigen::half a) { return a - value; });
}
namespace cuda_helper {
-
template <typename IntType>
__device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
IntType* orig = first;
@@ -330,495 +164,8 @@ __device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
return first - orig;
}
-
} // namespace cuda_helper
-
-template <typename T>
-__device__ __host__ inline T ldg(const T* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- return __ldg(address);
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline std::complex<float> ldg(
- const std::complex<float>* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- float2 mem = __ldg(reinterpret_cast<const float2*>(address));
- return std::complex<float>(mem.x, mem.y);
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline std::complex<double> ldg(
- const std::complex<double>* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- double2 mem = __ldg(reinterpret_cast<const double2*>(address));
- return std::complex<double>(mem.x, mem.y);
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline Eigen::half ldg(const Eigen::half* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- return Eigen::half_impl::raw_uint16_to_half(
- __ldg(reinterpret_cast<const uint16_t*>(address)));
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline tensorflow::bfloat16 ldg(
- const tensorflow::bfloat16* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- tensorflow::bfloat16 return_value;
- asm volatile("ld.global.nc.u16 %0, [%1];"
- : "=h"(return_value.value)
- : "l"(address));
- return return_value;
-#else
- return *address;
-#endif
-}
-
-template <>
-__device__ __host__ inline bool ldg(const bool* address) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
- return *reinterpret_cast<const bool*>(
- __ldg(reinterpret_cast<const char*>(address)));
-#else
- return *address;
-#endif
-}
-
-// CUDA provides atomic ops, but not for all types. We provide wrappers
-// for some ops and provide implementation for all reasonable types.
-#define CUDA_ATOMIC_WRAPPER(op, T) \
- __device__ __forceinline__ T CudaAtomic##op(T* address, T val)
-
-#define USE_CUDA_ATOMIC(op, T) \
- CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
-
-// For atomicAdd.
-USE_CUDA_ATOMIC(Add, int32);
-USE_CUDA_ATOMIC(Add, uint32);
-USE_CUDA_ATOMIC(Add, uint64);
-USE_CUDA_ATOMIC(Add, float);
-
-// For atomicMax.
-USE_CUDA_ATOMIC(Max, int32);
-USE_CUDA_ATOMIC(Max, uint32);
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
-USE_CUDA_ATOMIC(Max, uint64);
-#else
-// The uint64 overload of atomicMax() is only available for __CUDA_ARCH__ >=
-// 350. If not satisfied, we provide a custom implementation using atomicCAS().
-CUDA_ATOMIC_WRAPPER(Max, uint64) {
- uint64* address_as_ull = reinterpret_cast<uint64*>(address);
- uint64 old = *address_as_ull, assumed;
-
- do {
- assumed = old;
- old = atomicCAS(address_as_ull, assumed, max(val, assumed));
- } while (assumed != old);
-
- return old;
-}
-#endif
-
-// Custom implementation of atomicAdd for double.
-// This implementation is copied from CUDA manual.
-CUDA_ATOMIC_WRAPPER(Add, double) {
- uint64* address_as_ull = reinterpret_cast<uint64*>(address);
- uint64 old = *address_as_ull, assumed;
-
- do {
- assumed = old;
- old = atomicCAS(address_as_ull, assumed,
- __double_as_longlong(val + __longlong_as_double(assumed)));
-
- // Note: uses integer comparison to avoid hang in case of NaN
- } while (assumed != old);
-
- return __longlong_as_double(old);
-}
-
-// Custom implementation of atomicAdd for std::complex<float>.
-// This implementation performs to atomic additions on the components.
-CUDA_ATOMIC_WRAPPER(Add, std::complex<float>) {
-#if defined(__CUDA_ARCH__)
-#if __CUDA_ARCH__ >= 350
- float2* addr_as_float2 = reinterpret_cast<float2*>(address);
- float2* val_as_float2 = reinterpret_cast<float2*>(&val);
- CudaAtomicAdd(&(addr_as_float2->x), val_as_float2->x);
- CudaAtomicAdd(&(addr_as_float2->y), val_as_float2->y);
-#else
- static_assert(sizeof(std::complex<float>) == 2 * sizeof(float),
- "Unable to compile CudaAtomicAdd for complex64 because "
- "sizeof(complex64) != 2*sizeof(float32)");
- float* addr_as_float = reinterpret_cast<float*>(address);
- float* val_as_float = reinterpret_cast<float*>(&val);
- CudaAtomicAdd(addr_as_float, *val_as_float);
- CudaAtomicAdd(addr_as_float + 1, *(val_as_float + 1));
-#endif
-#endif
- return *address;
-}
-
-// Custom implementation of atomicAdd for std::complex<double>.
-// This implementation performs to atomic additions on the components
-// using the double atomic wrapper above.
-CUDA_ATOMIC_WRAPPER(Add, complex128) {
-#if defined(__CUDA_ARCH__)
-#if __CUDA_ARCH__ >= 350
- double2* addr_as_double2 = reinterpret_cast<double2*>(address);
- double2* val_as_double2 = reinterpret_cast<double2*>(&val);
- CudaAtomicAdd(&(addr_as_double2->x), val_as_double2->x);
- CudaAtomicAdd(&(addr_as_double2->y), val_as_double2->y);
-#else
- static_assert(sizeof(std::complex<double>) == 2 * sizeof(double),
- "Unable to compile CudaAtomicAdd for complex128 because "
- "sizeof(complex128) != 2*sizeof(float64)");
- double* addr_as_double = reinterpret_cast<double*>(address);
- double* val_as_double = reinterpret_cast<double*>(&val);
- CudaAtomicAdd(addr_as_double, *val_as_double);
- CudaAtomicAdd(addr_as_double + 1, *(val_as_double + 1));
-#endif
-#endif
- return *address;
-}
-
-// Helper functions for CudaAtomicAdd(half*, half), below.
-//
-// Note that if __CUDA_ARCH__ >= 530, we could probably use __hadd2()
-// for a more efficient implementation, assuming that adding -0.0
-// will never harm the neighboring value. In this version, we take special
-// care to guarantee the bits of the untouched value are unchanged.
-inline __device__ uint32 add_to_low_half(uint32 val, float x) {
- Eigen::half low_half;
- low_half.x = static_cast<uint16>(val & 0xffffu);
- low_half = static_cast<Eigen::half>(static_cast<float>(low_half) + x);
- return (val & 0xffff0000u) | low_half.x;
-}
-
-inline __device__ uint32 add_to_high_half(uint32 val, float x) {
- Eigen::half high_half;
- high_half.x = static_cast<uint16>(val >> 16);
- high_half = static_cast<Eigen::half>(static_cast<float>(high_half) + x);
- return (val & 0xffffu) | (high_half.x << 16);
-}
-
-// Custom implementation of atomicAdd for half. Note that we don't have
-// atomicCAS() for anything less than 32 bits, so we need to include the
-// other 16 bits in the operation.
-//
-// Unlike the other atomic adds, this version is going to be very slow
-// under high concurrency, since most threads will be spinning on failing
-// their compare-and-swap tests. (The fact that we get false sharing on the
-// neighboring fp16 makes this even worse.) If you are doing a large reduction,
-// you are much better off with doing the intermediate steps in fp32 and then
-// switching to fp16 as late as you can in the calculations.
-//
-// Note: Assumes little endian.
-CUDA_ATOMIC_WRAPPER(Add, Eigen::half) {
- float val_as_float(val);
- intptr_t address_int = reinterpret_cast<intptr_t>(address);
- if ((address_int & 0x2) == 0) {
- // The half is in the first part of the uint32 (lower 16 bits).
- uint32* address_as_uint32 = reinterpret_cast<uint32*>(address);
- assert(((intptr_t)address_as_uint32 & 0x3) == 0);
- uint32 old = *address_as_uint32, assumed;
-
- do {
- assumed = old;
- old = atomicCAS(address_as_uint32, assumed,
- add_to_low_half(assumed, val_as_float));
-
- // Note: uses integer comparison to avoid hang in case of NaN
- } while (assumed != old);
-
- Eigen::half ret;
- ret.x = old & 0xffffu;
- return ret;
- } else {
- // The half is in the second part of the uint32 (upper 16 bits).
- uint32* address_as_uint32 = reinterpret_cast<uint32*>(address_int - 2);
- assert(((intptr_t)address_as_uint32 & 0x3) == 0);
- uint32 old = *address_as_uint32, assumed;
-
- do {
- assumed = old;
- old = atomicCAS(address_as_uint32, assumed,
- add_to_high_half(assumed, val_as_float));
-
- // Note: uses integer comparison to avoid hang in case of NaN
- } while (assumed != old);
-
- Eigen::half ret;
- ret.x = old >> 16;
- return ret;
- }
-}
-
-template <typename T>
-__global__ void SetZero(const int nthreads, T* bottom_diff) {
- CUDA_1D_KERNEL_LOOP(index, nthreads) { *(bottom_diff + index) = T(0); }
-}
-
-// For atomicSub.
-
-// Custom implementation for sub by just negating the value.
-#define WRAPPED_ATOMIC_SUB(T) \
- CUDA_ATOMIC_WRAPPER(Sub, T) { return CudaAtomicAdd(address, -val); }
-
-WRAPPED_ATOMIC_SUB(uint64);
-WRAPPED_ATOMIC_SUB(int32);
-WRAPPED_ATOMIC_SUB(uint32);
-WRAPPED_ATOMIC_SUB(Eigen::half);
-WRAPPED_ATOMIC_SUB(float);
-WRAPPED_ATOMIC_SUB(double);
-
-CUDA_ATOMIC_WRAPPER(Sub, complex64) {
- const std::complex<float> Tneg(-val.real(), -val.imag());
- return CudaAtomicAdd(address, Tneg);
-}
-
-CUDA_ATOMIC_WRAPPER(Sub, complex128) {
- const std::complex<double> Tneg(-val.real(), -val.imag());
- return CudaAtomicAdd(address, Tneg);
-}
-
-#undef WRAPPED_ATOMIC_SUB
-
-// For atomicMul.
-CUDA_ATOMIC_WRAPPER(Mul, int32) {
- int32 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, val * assumed);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Mul, uint32) {
- uint32 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, val * assumed);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Mul, uint64) {
- uint64 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, val * assumed);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Mul, float) {
- int32* address_as_int = reinterpret_cast<int32*>(address);
- int32 old = *address_as_int, assumed;
- do {
- assumed = old;
- old = atomicCAS(address_as_int, assumed,
- __float_as_int(val * __int_as_float(assumed)));
- } while (assumed != old);
- return __int_as_float(old);
-}
-
-CUDA_ATOMIC_WRAPPER(Mul, double) {
- uint64* address_as_ull = reinterpret_cast<uint64*>(address);
- uint64 old = *address_as_ull, assumed;
- do {
- assumed = old;
- old = atomicCAS(address_as_ull, assumed,
- __double_as_longlong(val * __longlong_as_double(assumed)));
- } while (assumed != old);
- return __longlong_as_double(old);
-}
-
-// For atomicDiv.
-CUDA_ATOMIC_WRAPPER(Div, int32) {
- int32 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, assumed / val);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Div, uint32) {
- uint32 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, assumed / val);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Div, uint64) {
- uint64 old = *address, assumed;
- do {
- assumed = old;
- old = atomicCAS(address, assumed, assumed / val);
- } while (assumed != old);
- return old;
-}
-
-CUDA_ATOMIC_WRAPPER(Div, float) {
- int32* address_as_int = reinterpret_cast<int32*>(address);
- int32 old = *address_as_int, assumed;
- do {
- assumed = old;
- old = atomicCAS(address_as_int, assumed,
- __float_as_int(__int_as_float(assumed) / val));
- } while (assumed != old);
- return __int_as_float(old);
-}
-
-CUDA_ATOMIC_WRAPPER(Div, double) {
- uint64* address_as_ull = reinterpret_cast<uint64*>(address);
- uint64 old = *address_as_ull, assumed;
- do {
- assumed = old;
- old = atomicCAS(address_as_ull, assumed,
- __double_as_longlong(__longlong_as_double(assumed) / val));
- } while (assumed != old);
- return __longlong_as_double(old);
-}
-
-#undef USE_CUDA_ATOMIC
-#undef CUDA_ATOMIC_WRAPPER
-
-template <typename T>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_min(const T& x, const T& y) {
- return x > y ? y : x;
-}
-
-template <typename T>
-EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T tf_max(const T& x, const T& y) {
- return x < y ? y : x;
-}
-
-__device__ EIGEN_ALWAYS_INLINE unsigned CudaBallot(unsigned mask,
- int predicate) {
- return __ballot_sync(mask, predicate);
-}
-
-template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffle(unsigned mask, T value,
- int srcLane,
- int width = warpSize) {
- return __shfl_sync(mask, value, srcLane, width);
-}
-
-// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
-// instead of float for lo and hi (which is incorrect with ftz, for example).
-// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
-// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffle(unsigned mask, double value,
- int srcLane,
- int width = warpSize) {
- unsigned lo, hi;
- asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_sync(mask, hi, srcLane, width);
- lo = __shfl_sync(mask, lo, srcLane, width);
- asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
- return value;
-}
-
-template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleUp(unsigned mask, T value,
- int delta,
- int width = warpSize) {
- return __shfl_up_sync(mask, value, delta, width);
-}
-
-// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
-// instead of float for lo and hi (which is incorrect with ftz, for example).
-// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
-// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleUp(unsigned mask, double value,
- int delta,
- int width = warpSize) {
- unsigned lo, hi;
- asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_up_sync(mask, hi, delta, width);
- lo = __shfl_up_sync(mask, lo, delta, width);
- asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
- return value;
-}
-
-template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleDown(unsigned mask, T value,
- int delta,
- int width = warpSize) {
- return __shfl_down_sync(mask, value, delta, width);
-}
-
-__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleDown(
- unsigned mask, Eigen::half value, int delta, int width = warpSize) {
- return Eigen::half(
- __shfl_down_sync(mask, static_cast<uint16>(value), delta, width));
-}
-
-// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
-// instead of float for lo and hi (which is incorrect with ftz, for example).
-// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
-// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleDown(unsigned mask,
- double value, int delta,
- int width = warpSize) {
- unsigned lo, hi;
- asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_down_sync(mask, hi, delta, width);
- lo = __shfl_down_sync(mask, lo, delta, width);
- asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
- return value;
-}
-
-template <typename T>
-__device__ EIGEN_ALWAYS_INLINE T CudaShuffleXor(unsigned mask, T value,
- int laneMask,
- int width = warpSize) {
- return __shfl_xor_sync(mask, value, laneMask, width);
-}
-
-__device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXor(
- unsigned mask, Eigen::half value, int laneMask, int width = warpSize) {
- return Eigen::half(
- __shfl_xor_sync(mask, static_cast<uint16>(value), laneMask, width));
-}
-
-// Variant of the (undocumented) version from the CUDA SDK, but using unsigned
-// instead of float for lo and hi (which is incorrect with ftz, for example).
-// A bug has been filed with NVIDIA and will be fixed in the next CUDA release.
-// TODO(csigg): remove when the bug is fixed in the next CUDA release.
-__device__ EIGEN_ALWAYS_INLINE double CudaShuffleXor(unsigned mask,
- double value, int laneMask,
- int width = warpSize) {
- unsigned lo, hi;
- asm volatile("mov.b64 {%0,%1}, %2;" : "=r"(lo), "=r"(hi) : "d"(value));
- hi = __shfl_xor_sync(mask, hi, laneMask, width);
- lo = __shfl_xor_sync(mask, lo, laneMask, width);
- asm volatile("mov.b64 %0, {%1,%2};" : "=d"(value) : "r"(lo), "r"(hi));
- return value;
-}
-
} // namespace tensorflow
-#undef DIV_UP
-
#endif // GOOGLE_CUDA
-
#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
diff --git a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
index 6991554eff..bd4c356ea0 100644
--- a/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
+++ b/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
@@ -52,11 +52,11 @@ __global__ void Count1D(CudaLaunchConfig config, int bufsize, int* outbuf) {
}
}
__global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) {
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) {
if (x < 0) { // x might overflow when testing extreme case
break;
}
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) {
if (y < 0) { // y might overflow when testing extreme case
break;
}
@@ -66,15 +66,15 @@ __global__ void Count2D(Cuda2DLaunchConfig config, int bufsize, int* outbuf) {
}
}
__global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) {
- CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count.x, X) {
if (x < 0) { // x might overflow when testing extreme case
break;
}
- CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count.y, Y) {
if (y < 0) { // y might overflow when testing extreme case
break;
}
- CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
+ CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count.z, Z) {
if (z < 0) { // z might overflow when testing extreme case
break;
}
@@ -87,6 +87,44 @@ __global__ void Count3D(Cuda3DLaunchConfig config, int bufsize, int* outbuf) {
}
}
+__global__ void CudaShuffleGetSrcLaneTest(unsigned* failure_count) {
+ unsigned lane_id = CudaLaneId();
+ for (int width = warpSize; width > 1; width /= 2) {
+ auto check_result = [&](const char* op_name, int param, unsigned actual,
+ unsigned expected) {
+ if (actual != expected) {
+ printf("Cuda%sGetSrcLane(%d, %d) for lane %d returned %d, not %d\n",
+ op_name, param, width, lane_id, actual, expected);
+ CudaAtomicAdd(failure_count, 1);
+ }
+ };
+ for (int src_lane = -warpSize; src_lane <= warpSize; ++src_lane) {
+ unsigned actual_lane = detail::CudaShuffleGetSrcLane(src_lane, width);
+ unsigned expect_lane =
+ CudaShuffleSync(kCudaWarpAll, lane_id, src_lane, width);
+ check_result("Shuffle", src_lane, actual_lane, expect_lane);
+ }
+ for (unsigned delta = 0; delta <= warpSize; ++delta) {
+ unsigned actual_lane = detail::CudaShuffleUpGetSrcLane(delta, width);
+ unsigned expect_lane =
+ CudaShuffleUpSync(kCudaWarpAll, lane_id, delta, width);
+ check_result("ShuffleUp", delta, actual_lane, expect_lane);
+ }
+ for (unsigned delta = 0; delta <= warpSize; ++delta) {
+ unsigned actual_lane = detail::CudaShuffleDownGetSrcLane(delta, width);
+ unsigned expect_lane =
+ CudaShuffleDownSync(kCudaWarpAll, lane_id, delta, width);
+ check_result("ShuffleDown", delta, actual_lane, expect_lane);
+ }
+ for (int lane_lane = warpSize; lane_lane > 0; lane_lane /= 2) {
+ unsigned actual_lane = detail::CudaShuffleXorGetSrcLane(lane_lane, width);
+ unsigned expect_lane =
+ CudaShuffleXorSync(kCudaWarpAll, lane_id, lane_lane, width);
+ check_result("ShuffleXor", lane_lane, actual_lane, expect_lane);
+ }
+ }
+}
+
} // namespace
class CudaLaunchConfigTest : public ::testing::Test {
@@ -94,7 +132,7 @@ class CudaLaunchConfigTest : public ::testing::Test {
const int bufsize = 1024;
int* outbuf = nullptr;
Eigen::CudaStreamDevice stream;
- GPUDevice d = GPUDevice(&stream);
+ Eigen::GpuDevice d = Eigen::GpuDevice(&stream);
virtual void SetUp() {
cudaError_t err = cudaMallocManaged(&outbuf, sizeof(int) * bufsize);
@@ -229,6 +267,16 @@ TEST_F(CudaLaunchConfigTest, GetCuda3DLaunchConfig) {
#undef TEST_LAUNCH_PARAMETER
}
+TEST(CudaDeviceFunctionsTest, ShuffleGetSrcLane) {
+ unsigned* failure_count;
+ ASSERT_EQ(cudaMallocManaged(&failure_count, sizeof(unsigned)), cudaSuccess);
+ *failure_count = 0;
+ CudaShuffleGetSrcLaneTest<<<1, 32>>>(failure_count);
+ ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);
+ ASSERT_EQ(*failure_count, 0);
+ cudaFree(failure_count);
+}
+
} // namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/util/cuda_launch_config.h b/tensorflow/core/util/cuda_launch_config.h
new file mode 100644
index 0000000000..3ea33ee6cf
--- /dev/null
+++ b/tensorflow/core/util/cuda_launch_config.h
@@ -0,0 +1,284 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_
+#define TENSORFLOW_CORE_UTIL_CUDA_LAUNCH_CONFIG_H_
+
+#if GOOGLE_CUDA
+
+#include <algorithm>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "cuda/include/cuda.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/core/platform/types.h"
+
+// Usage of GetCudaLaunchConfig, GetCuda2DLaunchConfig, and
+// GetCuda3DLaunchConfig:
+//
+// There are two versions of GetCudaLaunchConfig and GetCuda2DLaunchConfig, one
+// version uses heuristics without any knowledge of the device kernel, the other
+// version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical
+// launch parameters that maximize occupancy. Currently, only the maximum
+// occupancy version of GetCuda3DLaunchConfig is available.
+//
+// For large number of work elements, the convention is that each kernel would
+// iterate through its assigned range. The return value of GetCudaLaunchConfig
+// is struct CudaLaunchConfig, which contains all the information needed for the
+// kernel launch, including: virtual number of threads, the number of threads
+// per block and number of threads per block used inside <<< >>> of a kernel
+// launch. GetCuda2DLaunchConfig and GetCuda3DLaunchConfig does the same thing
+// as CudaLaunchConfig. The only difference is the dimension. The macros
+// CUDA_1D_KERNEL_LOOP and CUDA_AXIS_KERNEL_LOOP might be used to do inner loop.
+//
+/* Sample code:
+
+__global__ void MyKernel1D(CudaLaunchConfig config, other_args...) {
+ CUDA_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
+ do_your_job_here;
+ }
+}
+
+__global__ void MyKernel2D(Cuda2DLaunchConfig config, other_args...) {
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
+ do_your_job_here;
+ }
+ }
+}
+
+__global__ void MyKernel3D(Cuda3DLaunchConfig config, other_args...) {
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
+ CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
+ do_your_job_here;
+ }
+ }
+ }
+}
+
+void MyDriverFunc(const Eigen::GpuDevice &d) {
+ // use heuristics
+ CudaLaunchConfig cfg1 = GetCudaLaunchConfig(10240, d);
+ MyKernel1D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
+ Cuda2DLaunchConfig cfg2 = GetCuda2DLaunchConfig(10240, 10240, d);
+ MyKernel2D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
+ Cuda3DLaunchConfig cfg3 = GetCuda3DLaunchConfig(4096, 4096, 100, d);
+ MyKernel3D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);
+
+ // maximize occupancy
+ CudaLaunchConfig cfg4 = GetCudaLaunchConfig(10240, d, MyKernel1D, 0, 0 );
+ MyKernel1D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
+ Cuda2DLaunchConfig cfg5 = GetCuda2DLaunchConfig(10240, 10240, d,
+ MyKernel1D, 0, 0);
+ MyKernel2D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
+ Cuda3DLaunchConfig cfg6 = GetCuda3DLaunchConfig(4096, 4096, 100, d,
+ MyKernel1D, 0, 0);
+ MyKernel3D <<<config.block_count,
+ config.thread_per_block, 0, d.stream()>>> (cfg6, other_args...);
+}
+
+// See the test for this for more example:
+//
+https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/cuda_kernel_helper_test.cu.cc
+
+*/
+
+namespace tensorflow {
+
+inline int DivUp(int a, int b) { return (a + b - 1) / b; }
+
+struct CudaLaunchConfig {
+ // Logical number of thread that works on the elements. If each logical
+ // thread works on exactly a single element, this is the same as the working
+ // element count.
+ int virtual_thread_count = -1;
+ // Number of threads per block.
+ int thread_per_block = -1;
+ // Number of blocks for Cuda kernel launch.
+ int block_count = -1;
+};
+
+// Calculate the Cuda launch config we should use for a kernel launch.
+// This is assuming the kernel is quite simple and will largely be
+// memory-limited.
+// REQUIRES: work_element_count > 0.
+inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
+ const Eigen::GpuDevice& d) {
+ CHECK_GT(work_element_count, 0);
+ CudaLaunchConfig config;
+ const int virtual_thread_count = work_element_count;
+ const int physical_thread_count = std::min(
+ d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
+ virtual_thread_count);
+ const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
+ const int block_count =
+ std::min(DivUp(physical_thread_count, thread_per_block),
+ d.getNumCudaMultiProcessors());
+
+ config.virtual_thread_count = virtual_thread_count;
+ config.thread_per_block = thread_per_block;
+ config.block_count = block_count;
+ return config;
+}
+
+// Calculate the Cuda launch config we should use for a kernel launch. This
+// variant takes the resource limits of func into account to maximize occupancy.
+// REQUIRES: work_element_count > 0.
+template <typename DeviceFunc>
+inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
+ const Eigen::GpuDevice& d,
+ DeviceFunc func,
+ size_t dynamic_shared_memory_size,
+ int block_size_limit) {
+ CHECK_GT(work_element_count, 0);
+ CudaLaunchConfig config;
+ int block_count = 0;
+ int thread_per_block = 0;
+
+ cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
+ &block_count, &thread_per_block, func, dynamic_shared_memory_size,
+ block_size_limit);
+ CHECK_EQ(err, cudaSuccess);
+
+ block_count =
+ std::min(block_count, DivUp(work_element_count, thread_per_block));
+
+ config.virtual_thread_count = work_element_count;
+ config.thread_per_block = thread_per_block;
+ config.block_count = block_count;
+ return config;
+}
+
+struct Cuda2DLaunchConfig {
+ dim3 virtual_thread_count = dim3(0, 0, 0);
+ dim3 thread_per_block = dim3(0, 0, 0);
+ dim3 block_count = dim3(0, 0, 0);
+};
+
+inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
+ const Eigen::GpuDevice& d) {
+ Cuda2DLaunchConfig config;
+
+ if (xdim <= 0 || ydim <= 0) {
+ return config;
+ }
+
+ const int kThreadsPerBlock = 256;
+ int block_cols = std::min(xdim, kThreadsPerBlock);
+ // ok to round down here and just do more loops in the kernel
+ int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
+
+ const int physical_thread_count =
+ d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor();
+
+ const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
+
+ config.virtual_thread_count = dim3(xdim, ydim, 1);
+ config.thread_per_block = dim3(block_cols, block_rows, 1);
+
+ int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);
+
+ config.block_count = dim3(
+ grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
+ return config;
+}
+
+// Calculate the Cuda 2D and 3D launch config we should use for a kernel launch.
+// This variant takes the resource limits of func into account to maximize
+// occupancy.
+using Cuda3DLaunchConfig = Cuda2DLaunchConfig;
+
+template <typename DeviceFunc>
+inline Cuda3DLaunchConfig GetCuda3DLaunchConfig(
+ int xdim, int ydim, int zdim, const Eigen::GpuDevice& d, DeviceFunc func,
+ size_t dynamic_shared_memory_size, int block_size_limit) {
+ Cuda3DLaunchConfig config;
+
+ if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
+ return config;
+ }
+
+ int dev;
+ cudaGetDevice(&dev);
+ cudaDeviceProp deviceProp;
+ cudaGetDeviceProperties(&deviceProp, dev);
+ int xthreadlimit = deviceProp.maxThreadsDim[0];
+ int ythreadlimit = deviceProp.maxThreadsDim[1];
+ int zthreadlimit = deviceProp.maxThreadsDim[2];
+ int xgridlimit = deviceProp.maxGridSize[0];
+ int ygridlimit = deviceProp.maxGridSize[1];
+ int zgridlimit = deviceProp.maxGridSize[2];
+
+ int block_count = 0;
+ int thread_per_block = 0;
+ cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
+ &block_count, &thread_per_block, func, dynamic_shared_memory_size,
+ block_size_limit);
+ CHECK_EQ(err, cudaSuccess);
+
+ auto min3 = [](int a, int b, int c) { return std::min(a, std::min(b, c)); };
+
+ int threadsx = min3(xdim, thread_per_block, xthreadlimit);
+ int threadsy =
+ min3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit);
+ int threadsz =
+ min3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
+ zthreadlimit);
+
+ int blocksx = min3(block_count, DivUp(xdim, threadsx), xgridlimit);
+ int blocksy =
+ min3(DivUp(block_count, blocksx), DivUp(ydim, threadsy), ygridlimit);
+ int blocksz = min3(DivUp(block_count, (blocksx * blocksy)),
+ DivUp(zdim, threadsz), zgridlimit);
+
+ config.virtual_thread_count = dim3(xdim, ydim, zdim);
+ config.thread_per_block = dim3(threadsx, threadsy, threadsz);
+ config.block_count = dim3(blocksx, blocksy, blocksz);
+ return config;
+}
+
+template <typename DeviceFunc>
+inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(
+ int xdim, int ydim, const Eigen::GpuDevice& d, DeviceFunc func,
+ size_t dynamic_shared_memory_size, int block_size_limit) {
+ return GetCuda3DLaunchConfig(xdim, ydim, 1, d, func,
+ dynamic_shared_memory_size, block_size_limit);
+}
+
+// Returns a raw reference to the current cuda stream. Required by a
+// number of kernel calls (for which StreamInterface* does not work), i.e.
+// CUB and certain cublas primitives.
+inline const cudaStream_t& GetCudaStream(OpKernelContext* context) {
+ const cudaStream_t* ptr = CHECK_NOTNULL(
+ reinterpret_cast<const cudaStream_t*>(context->op_device_context()
+ ->stream()
+ ->implementation()
+ ->CudaStreamMemberHack()));
+ return *ptr;
+}
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
index fa4c1c0da5..461fb1c517 100644
--- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Train and Eval the MNIST network.
This version is like fully_connected_feed.py but uses data converted
@@ -65,6 +64,7 @@ def decode(serialized_example):
return image, label
+
def augment(image, label):
# OPTIONAL: Could reshape into a 28x28 image and apply distortions
# here. Since we are not applying any distortions in this
@@ -72,12 +72,14 @@ def augment(image, label):
# into a vector, we don't bother.
return image, label
+
def normalize(image, label):
# Convert from [0, 255] -> [-0.5, 0.5] floats.
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
return image, label
+
def inputs(train, batch_size, num_epochs):
"""Reads input data num_epochs times.
@@ -98,9 +100,10 @@ def inputs(train, batch_size, num_epochs):
over the dataset once. On the other hand there is no special initialization
required.
"""
- if not num_epochs: num_epochs = None
- filename = os.path.join(FLAGS.train_dir,
- TRAIN_FILE if train else VALIDATION_FILE)
+ if not num_epochs:
+ num_epochs = None
+ filename = os.path.join(FLAGS.train_dir, TRAIN_FILE
+ if train else VALIDATION_FILE)
with tf.name_scope('input'):
# TFRecordDataset opens a protobuf and reads entries line by line
@@ -127,13 +130,11 @@ def run_training():
# Tell TensorFlow that the model will be built into the default Graph.
with tf.Graph().as_default():
# Input images and labels.
- image_batch, label_batch = inputs(train=True, batch_size=FLAGS.batch_size,
- num_epochs=FLAGS.num_epochs)
+ image_batch, label_batch = inputs(
+ train=True, batch_size=FLAGS.batch_size, num_epochs=FLAGS.num_epochs)
# Build a Graph that computes predictions from the inference model.
- logits = mnist.inference(image_batch,
- FLAGS.hidden1,
- FLAGS.hidden2)
+ logits = mnist.inference(image_batch, FLAGS.hidden1, FLAGS.hidden2)
# Add to the Graph the loss calculation.
loss = mnist.loss(logits, label_batch)
@@ -152,7 +153,7 @@ def run_training():
sess.run(init_op)
try:
step = 0
- while True: #train until OutOfRangeError
+ while True: #train until OutOfRangeError
start_time = time.time()
# Run one step of the model. The return values are
@@ -168,10 +169,12 @@ def run_training():
# Print an overview fairly often.
if step % 100 == 0:
print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
- duration))
+ duration))
step += 1
except tf.errors.OutOfRangeError:
- print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
+ print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs,
+ step))
+
def main(_):
run_training()
@@ -183,37 +186,27 @@ if __name__ == '__main__':
'--learning_rate',
type=float,
default=0.01,
- help='Initial learning rate.'
- )
+ help='Initial learning rate.')
parser.add_argument(
'--num_epochs',
type=int,
default=2,
- help='Number of epochs to run trainer.'
- )
+ help='Number of epochs to run trainer.')
parser.add_argument(
'--hidden1',
type=int,
default=128,
- help='Number of units in hidden layer 1.'
- )
+ help='Number of units in hidden layer 1.')
parser.add_argument(
'--hidden2',
type=int,
default=32,
- help='Number of units in hidden layer 2.'
- )
- parser.add_argument(
- '--batch_size',
- type=int,
- default=100,
- help='Batch size.'
- )
+ help='Number of units in hidden layer 2.')
+ parser.add_argument('--batch_size', type=int, default=100, help='Batch size.')
parser.add_argument(
'--train_dir',
type=str,
default='/tmp/data',
- help='Directory with the training data.'
- )
+ help='Directory with the training data.')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/examples/label_image/label_image.py b/tensorflow/examples/label_image/label_image.py
index d62b73384c..1c1bd57d71 100644
--- a/tensorflow/examples/label_image/label_image.py
+++ b/tensorflow/examples/label_image/label_image.py
@@ -23,6 +23,7 @@ import sys
import numpy as np
import tensorflow as tf
+
def load_graph(model_file):
graph = tf.Graph()
graph_def = tf.GraphDef()
@@ -34,22 +35,26 @@ def load_graph(model_file):
return graph
-def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
- input_mean=0, input_std=255):
+
+def read_tensor_from_image_file(file_name,
+ input_height=299,
+ input_width=299,
+ input_mean=0,
+ input_std=255):
input_name = "file_reader"
output_name = "normalized"
file_reader = tf.read_file(file_name, input_name)
if file_name.endswith(".png"):
- image_reader = tf.image.decode_png(file_reader, channels = 3,
- name='png_reader')
+ image_reader = tf.image.decode_png(
+ file_reader, channels=3, name="png_reader")
elif file_name.endswith(".gif"):
- image_reader = tf.squeeze(tf.image.decode_gif(file_reader,
- name='gif_reader'))
+ image_reader = tf.squeeze(
+ tf.image.decode_gif(file_reader, name="gif_reader"))
elif file_name.endswith(".bmp"):
- image_reader = tf.image.decode_bmp(file_reader, name='bmp_reader')
+ image_reader = tf.image.decode_bmp(file_reader, name="bmp_reader")
else:
- image_reader = tf.image.decode_jpeg(file_reader, channels = 3,
- name='jpeg_reader')
+ image_reader = tf.image.decode_jpeg(
+ file_reader, channels=3, name="jpeg_reader")
float_caster = tf.cast(image_reader, tf.float32)
dims_expander = tf.expand_dims(float_caster, 0)
resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
@@ -59,6 +64,7 @@ def read_tensor_from_image_file(file_name, input_height=299, input_width=299,
return result
+
def load_labels(label_file):
label = []
proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
@@ -66,6 +72,7 @@ def load_labels(label_file):
label.append(l.rstrip())
return label
+
if __name__ == "__main__":
file_name = "tensorflow/examples/label_image/data/grace_hopper.jpg"
model_file = \
@@ -110,11 +117,12 @@ if __name__ == "__main__":
output_layer = args.output_layer
graph = load_graph(model_file)
- t = read_tensor_from_image_file(file_name,
- input_height=input_height,
- input_width=input_width,
- input_mean=input_mean,
- input_std=input_std)
+ t = read_tensor_from_image_file(
+ file_name,
+ input_height=input_height,
+ input_width=input_width,
+ input_mean=input_mean,
+ input_std=input_std)
input_name = "import/" + input_layer
output_name = "import/" + output_layer
@@ -122,8 +130,9 @@ if __name__ == "__main__":
output_operation = graph.get_operation_by_name(output_name)
with tf.Session(graph=graph) as sess:
- results = sess.run(output_operation.outputs[0],
- {input_operation.outputs[0]: t})
+ results = sess.run(output_operation.outputs[0], {
+ input_operation.outputs[0]: t
+ })
results = np.squeeze(results)
top_k = results.argsort()[-5:][::-1]
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 1481a4d035..e6f94396b8 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""A client interface for TensorFlow."""
from __future__ import absolute_import
@@ -71,8 +70,9 @@ def _get_indexed_slices_value_from_fetches(fetched_vals):
def _get_feeds_for_indexed_slices(feed, feed_val):
- return list(zip([feed.values, feed.indices] if feed.dense_shape is None else
- [feed.values, feed.indices, feed.dense_shape], feed_val))
+ return list(
+ zip([feed.values, feed.indices] if feed.dense_shape is None else
+ [feed.values, feed.indices, feed.dense_shape], feed_val))
# List of extensions supported to convert run arguments into actual fetches and
@@ -124,6 +124,7 @@ _REGISTERED_EXPANSIONS = [
lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
lambda feed, feed_val: [(feed, feed_val)],
lambda feed: [feed])]
+
# pylint: enable=g-long-lambda
@@ -132,8 +133,11 @@ def _convert_to_numpy_obj(numpy_dtype, obj):
return numpy_dtype(obj) if numpy_dtype is not object else str(obj)
-def register_session_run_conversion_functions(tensor_type, fetch_function,
- feed_function=None, feed_function_for_partial_run=None):
+def register_session_run_conversion_functions(
+ tensor_type,
+ fetch_function,
+ feed_function=None,
+ feed_function_for_partial_run=None):
"""Register fetch and feed conversion functions for `tf.Session.run()`.
This function registers a triple of conversion functions for fetching and/or
@@ -174,11 +178,11 @@ def register_session_run_conversion_functions(tensor_type, fetch_function,
"""
for conversion_function in _REGISTERED_EXPANSIONS:
if issubclass(conversion_function[0], tensor_type):
- raise ValueError(
- '%s has already been registered so ignore it.', tensor_type)
+ raise ValueError('%s has already been registered so ignore it.',
+ tensor_type)
return
- _REGISTERED_EXPANSIONS.insert(0,
- (tensor_type, fetch_function, feed_function, feed_function_for_partial_run))
+ _REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function,
+ feed_function_for_partial_run))
class _FetchMapper(object):
@@ -233,8 +237,8 @@ class _FetchMapper(object):
An instance of a subclass of `_FetchMapper` that handles the shape.
"""
if fetch is None:
- raise TypeError('Fetch argument %r has invalid type %r' %
- (fetch, type(fetch)))
+ raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
+ type(fetch)))
elif isinstance(fetch, (list, tuple)):
# NOTE(touts): This is also the code path for namedtuples.
return _ListFetchMapper(fetch)
@@ -247,8 +251,8 @@ class _FetchMapper(object):
fetches, contraction_fn = fetch_fn(fetch)
return _ElementFetchMapper(fetches, contraction_fn)
# Did not find anything.
- raise TypeError('Fetch argument %r has invalid type %r' %
- (fetch, type(fetch)))
+ raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
+ type(fetch)))
class _ElementFetchMapper(_FetchMapper):
@@ -277,8 +281,8 @@ class _ElementFetchMapper(_FetchMapper):
fetch, allow_tensor=True, allow_operation=True))
except TypeError as e:
raise TypeError('Fetch argument %r has invalid type %r, '
- 'must be a string or Tensor. (%s)'
- % (fetch, type(fetch), str(e)))
+ 'must be a string or Tensor. (%s)' %
+ (fetch, type(fetch), str(e)))
except ValueError as e:
raise ValueError('Fetch argument %r cannot be interpreted as a '
'Tensor. (%s)' % (fetch, str(e)))
@@ -376,8 +380,9 @@ class _DictFetchMapper(_FetchMapper):
"""
self._fetch_type = type(fetches)
self._keys = fetches.keys()
- self._mappers = [_FetchMapper.for_fetch(fetch)
- for fetch in fetches.values()]
+ self._mappers = [
+ _FetchMapper.for_fetch(fetch) for fetch in fetches.values()
+ ]
self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
def unique_fetches(self):
@@ -401,6 +406,7 @@ class _FetchHandler(object):
result structure matching the user-provided structure for fetches, but
containing the corresponding results.
"""
+
# TODO(touts): Make this class also take care of destructuring the feed
# dict instead of doing it in the callers.
@@ -551,8 +557,11 @@ class _DeviceAttributes(object):
return self._memory_limit_bytes
def __repr__(self):
- return '_DeviceAttributes(%s, %s, %d)' % (self.name, self.device_type,
- self.memory_limit_bytes,)
+ return '_DeviceAttributes(%s, %s, %d)' % (
+ self.name,
+ self.device_type,
+ self.memory_limit_bytes,
+ )
class BaseSession(SessionInterface):
@@ -601,8 +610,8 @@ class BaseSession(SessionInterface):
if config is not None:
if not isinstance(config, config_pb2.ConfigProto):
- raise TypeError('config must be a tf.ConfigProto, but got %s'
- % type(config))
+ raise TypeError(
+ 'config must be a tf.ConfigProto, but got %s' % type(config))
self._config = config
self._add_shapes = config.graph_options.infer_shapes
else:
@@ -976,8 +985,8 @@ class BaseSession(SessionInterface):
for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS:
if isinstance(feed, tensor_type):
return feed_fn(feed)
- raise TypeError('Feed argument %r has invalid type %r'
- % (feed, type(feed)))
+ raise TypeError('Feed argument %r has invalid type %r' % (feed,
+ type(feed)))
# Check session.
if self._closed:
@@ -998,8 +1007,8 @@ class BaseSession(SessionInterface):
for feed in feeds:
for subfeed in _feed_fn(feed):
try:
- subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
- allow_operation=False)
+ subfeed_t = self.graph.as_graph_element(
+ subfeed, allow_tensor=True, allow_operation=False)
if self._created_with_new_api:
# pylint: disable=protected-access
feed_list.append(subfeed_t._as_tf_output())
@@ -1007,8 +1016,7 @@ class BaseSession(SessionInterface):
else:
feed_list.append(compat.as_bytes(subfeed_t.name))
except Exception as e:
- e.message = ('Cannot interpret feed_list key as Tensor: '
- + e.message)
+ e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message)
e.args = (e.message,)
raise e
@@ -1041,12 +1049,13 @@ class BaseSession(SessionInterface):
def _run(self, handle, fetches, feed_dict, options, run_metadata):
"""Perform either run or partial_run, depending the presence of `handle`."""
+
def _feed_fn(feed, feed_val):
for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS:
if isinstance(feed, tensor_type):
return feed_fn(feed, feed_val)
- raise TypeError('Feed argument %r has invalid type %r'
- % (feed, type(feed)))
+ raise TypeError('Feed argument %r has invalid type %r' % (feed,
+ type(feed)))
# Check session.
if self._closed:
@@ -1066,11 +1075,11 @@ class BaseSession(SessionInterface):
for feed, feed_val in feed_dict.items():
for subfeed, subfeed_val in _feed_fn(feed, feed_val):
try:
- subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True,
- allow_operation=False)
+ subfeed_t = self.graph.as_graph_element(
+ subfeed, allow_tensor=True, allow_operation=False)
except Exception as e:
- raise TypeError('Cannot interpret feed_dict key as Tensor: '
- + e.args[0])
+ raise TypeError(
+ 'Cannot interpret feed_dict key as Tensor: ' + e.args[0])
if isinstance(subfeed_val, ops.Tensor):
raise TypeError('The value of a feed cannot be a tf.Tensor object. '
@@ -1081,10 +1090,9 @@ class BaseSession(SessionInterface):
if isinstance(subfeed_val, int) and _convert_to_numpy_obj(
subfeed_dtype, subfeed_val) != subfeed_val:
raise TypeError(
- 'Type of feed value ' + str(subfeed_val) + ' with type ' +
- str(type(subfeed_val)) +
- ' is not compatible with Tensor type ' +
- str(subfeed_dtype) +
+ 'Type of feed value ' + str(subfeed_val) + ' with type ' + str(
+ type(subfeed_val)) +
+ ' is not compatible with Tensor type ' + str(subfeed_dtype) +
'. Try explicitly setting the type of the feed tensor'
' to a larger type (e.g. int64).')
@@ -1098,10 +1106,10 @@ class BaseSession(SessionInterface):
if (not is_tensor_handle_feed and
not subfeed_t.get_shape().is_compatible_with(np_val.shape)):
- raise ValueError(
- 'Cannot feed value of shape %r for Tensor %r, '
- 'which has shape %r'
- % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
+ raise ValueError('Cannot feed value of shape %r for Tensor %r, '
+ 'which has shape %r' %
+ (np_val.shape, subfeed_t.name,
+ str(subfeed_t.get_shape())))
if not self.graph.is_feedable(subfeed_t):
raise ValueError('Tensor %s may not be fed.' % subfeed_t)
@@ -1130,10 +1138,7 @@ class BaseSession(SessionInterface):
results = []
return fetch_handler.build_results(self, results)
- def make_callable(self,
- fetches,
- feed_list=None,
- accept_options=False):
+ def make_callable(self, fetches, feed_list=None, accept_options=False):
"""Returns a Python callable that runs a particular step.
The returned callable will take `len(feed_list)` arguments whose types
@@ -1176,9 +1181,12 @@ class BaseSession(SessionInterface):
# `Session._run()` so that we can convert the feeds to a list of
# strings here.
def _generic_run(*feed_args, **kwargs):
- feed_dict = {feed: feed_val
- for feed, feed_val in zip(feed_list, feed_args)}
+ feed_dict = {
+ feed: feed_val
+ for feed, feed_val in zip(feed_list, feed_args)
+ }
return self.run(fetches, feed_dict=feed_dict, **kwargs)
+
return _generic_run
# Ensure any changes to the graph are reflected in the runtime.
@@ -1198,12 +1206,11 @@ class BaseSession(SessionInterface):
fetch_list = _name_list(fetch_handler.fetches())
target_list = _name_list(fetch_handler.targets())
- def _callable_template_with_options_and_metadata(
- fetch_list,
- target_list,
- fetch_handler,
- options=None,
- run_metadata=None):
+ def _callable_template_with_options_and_metadata(fetch_list,
+ target_list,
+ fetch_handler,
+ options=None,
+ run_metadata=None):
"""Template callable that accepts RunOptions and RunMetadata."""
options_ptr = tf_session.TF_NewBufferFromString(
compat.as_bytes(options.SerializeToString())) if options else None
@@ -1215,9 +1222,9 @@ class BaseSession(SessionInterface):
self._session, options_ptr, {}, fetch_list, target_list,
run_metadata_ptr, status)
else:
- results = tf_session.TF_Run(
- self._session, options_ptr, {}, fetch_list, target_list, status,
- run_metadata_ptr)
+ results = tf_session.TF_Run(self._session, options_ptr, {},
+ fetch_list, target_list, status,
+ run_metadata_ptr)
if fetch_handler:
results = fetch_handler.build_results(self, results)
else:
@@ -1233,37 +1240,40 @@ class BaseSession(SessionInterface):
return results
if accept_options:
- return functools.partial(
- _callable_template_with_options_and_metadata, fetch_list,
- target_list, fetch_handler)
+ return functools.partial(_callable_template_with_options_and_metadata,
+ fetch_list, target_list, fetch_handler)
elif isinstance(fetches, ops.Operation):
# Special case for fetching a single operation, because the
# function will have no return value.
assert not fetch_list
assert len(target_list) == 1
+
def _single_operation_run():
with errors.raise_exception_on_not_ok_status() as status:
if self._created_with_new_api:
- tf_session.TF_SessionRun_wrapper(
- self._session, None, {}, [], target_list, None, status)
+ tf_session.TF_SessionRun_wrapper(self._session, None, {}, [],
+ target_list, None, status)
else:
- tf_session.TF_Run(
- self._session, None, {}, [], target_list, status, None)
+ tf_session.TF_Run(self._session, None, {}, [], target_list, status,
+ None)
+
return _single_operation_run
elif isinstance(fetches, ops.Tensor):
# Special case for fetching a single tensor, because the
# function can return the result of `TF_Run()` directly.
assert len(fetch_list) == 1
assert not target_list
+
def _single_tensor_run():
with errors.raise_exception_on_not_ok_status() as status:
if self._created_with_new_api:
results = tf_session.TF_SessionRun_wrapper(
self._session, None, {}, fetch_list, [], None, status)
else:
- results = tf_session.TF_Run(
- self._session, None, {}, fetch_list, [], status, None)
+ results = tf_session.TF_Run(self._session, None, {}, fetch_list, [],
+ status, None)
return results[0]
+
return _single_tensor_run
else:
# In all other cases, we must use `fetch_handler` to build the
@@ -1274,16 +1284,17 @@ class BaseSession(SessionInterface):
results = tf_session.TF_SessionRun_wrapper(
self._session, None, {}, fetch_list, target_list, None, status)
else:
- results = tf_session.TF_Run(
- self._session, None, {}, fetch_list, target_list, status, None)
+ results = tf_session.TF_Run(self._session, None, {}, fetch_list,
+ target_list, status, None)
return fetch_handler.build_results(self, results)
+
return _fetch_handler_run
# Captures the name of a node in an error status.
_NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')
- def _do_run(self, handle, target_list, fetch_list, feed_dict,
- options, run_metadata):
+ def _do_run(self, handle, target_list, fetch_list, feed_dict, options,
+ run_metadata):
"""Runs a step based on the given fetches and feeds.
Args:
@@ -1320,13 +1331,12 @@ class BaseSession(SessionInterface):
self._extend_graph()
with errors.raise_exception_on_not_ok_status() as status:
if self._created_with_new_api:
- return tf_session.TF_SessionRun_wrapper(
- session, options, feed_dict, fetch_list, target_list,
- run_metadata, status)
+ return tf_session.TF_SessionRun_wrapper(session, options, feed_dict,
+ fetch_list, target_list,
+ run_metadata, status)
else:
- return tf_session.TF_Run(session, options,
- feed_dict, fetch_list, target_list,
- status, run_metadata)
+ return tf_session.TF_Run(session, options, feed_dict, fetch_list,
+ target_list, status, run_metadata)
def _prun_fn(session, handle, feed_dict, fetch_list):
if target_list:
@@ -1365,20 +1375,20 @@ class BaseSession(SessionInterface):
def _extend_graph(self):
# Nothing to do if we're using the new session interface
# TODO(skyewm): remove this function altogether eventually
- if self._created_with_new_api: return
+ if self._created_with_new_api:
+ return
# Ensure any changes to the graph are reflected in the runtime.
with self._extend_lock:
if self._graph.version > self._current_version:
# pylint: disable=protected-access
graph_def, self._current_version = self._graph._as_graph_def(
- from_version=self._current_version,
- add_shapes=self._add_shapes)
+ from_version=self._current_version, add_shapes=self._add_shapes)
# pylint: enable=protected-access
with errors.raise_exception_on_not_ok_status() as status:
- tf_session.TF_ExtendGraph(
- self._session, graph_def.SerializeToString(), status)
+ tf_session.TF_ExtendGraph(self._session,
+ graph_def.SerializeToString(), status)
self._opened = True
# The threshold to run garbage collection to delete dead tensors.
@@ -1398,9 +1408,8 @@ class BaseSession(SessionInterface):
feeds = {}
fetches = []
for deleter_key, tensor_handle in enumerate(tensors_to_delete):
- holder, deleter = session_ops._get_handle_deleter(self.graph,
- deleter_key,
- tensor_handle)
+ holder, deleter = session_ops._get_handle_deleter(
+ self.graph, deleter_key, tensor_handle)
feeds[holder] = tensor_handle
fetches.append(deleter)
self.run(fetches, feed_dict=feeds)
@@ -1471,7 +1480,8 @@ class Session(BaseSession):
sess.run(...)
```
- The [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
+ The
+ [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
protocol buffer exposes various configuration options for a
session. For example, to create a session that uses soft constraints
for device placement, and log the resulting placement decisions,
@@ -1502,7 +1512,8 @@ class Session(BaseSession):
@{$distributed$Distributed TensorFlow}
for more examples.
graph: (Optional.) The `Graph` to be launched (described above).
- config: (Optional.) A [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
+ config: (Optional.) A
+ [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
protocol buffer with configuration options for the session.
"""
@@ -1526,8 +1537,8 @@ class Session(BaseSession):
def __exit__(self, exec_type, exec_value, exec_tb):
if exec_type is errors.OpError:
logging.error('Session closing due to OpError: %s', (exec_value,))
- self._default_session_context_manager.__exit__(
- exec_type, exec_value, exec_tb)
+ self._default_session_context_manager.__exit__(exec_type, exec_value,
+ exec_tb)
self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb)
self._default_session_context_manager = None
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index c579fba339..768a5db88a 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Tests for tensorflow.python.client.session.Session."""
from __future__ import absolute_import
from __future__ import division
@@ -57,7 +56,6 @@ from tensorflow.python.platform import googletest
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
-
# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
# don't have C++ op registrations on which to attach C++ shape fns.
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
@@ -95,14 +93,18 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertAllEqual(arr, copy_val)
# Test without feed.
copy_val = copy.eval()
- self.assertAllEqual(np.asarray([[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]],
- dtype=np.float32), copy_val)
+ self.assertAllEqual(
+ np.asarray(
+ [[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]], dtype=np.float32),
+ copy_val)
def testManyCPUs(self):
# TODO(keveman): Implement ListDevices and test for the number of
# devices returned by ListDevices.
with session.Session(
- config=config_pb2.ConfigProto(device_count={'CPU': 2})):
+ config=config_pb2.ConfigProto(device_count={
+ 'CPU': 2
+ })):
inp = constant_op.constant(10.0, name='W1')
self.assertAllEqual(inp.eval(), 10.0)
@@ -161,20 +163,23 @@ class SessionTest(test_util.TensorFlowTestCase):
def exc_predicate(e):
return (e.op is None and e.node_def is None and
e.error_code == error_codes_pb2.INVALID_ARGUMENT)
+
with self.assertRaisesOpError(exc_predicate):
# Run with a bogus handle.
s.partial_run('foo', r1, feed_dict={a: 1, b: 2})
def testOpConstructionErrorPayload(self):
- if ops._USE_C_API: return # No shape registration for 'ConstructionFails'
+ if ops._USE_C_API:
+ return # No shape registration for 'ConstructionFails'
with session.Session():
failing_op = ops.get_default_graph().create_op(
'ConstructionFails', [], [], name='f')
def exc_predicate(e):
- return (e.op == failing_op
- and e.error_code == error_codes_pb2.INVALID_ARGUMENT)
+ return (e.op == failing_op and
+ e.error_code == error_codes_pb2.INVALID_ARGUMENT)
+
with self.assertRaisesOpError(exc_predicate):
failing_op.run()
@@ -191,9 +196,9 @@ class SessionTest(test_util.TensorFlowTestCase):
# pylint: enable=protected-access
def exc_predicate(e):
- return (e.op == c.op
- and e.op._original_op == b.op
- and e.op._original_op._original_op == a.op)
+ return (e.op == c.op and e.op._original_op == b.op and
+ e.op._original_op._original_op == a.op)
+
with self.assertRaisesOpError(exc_predicate):
c.eval()
@@ -341,8 +346,12 @@ class SessionTest(test_util.TensorFlowTestCase):
b = control_flow_ops.no_op() # An op, not a tensor.
c = constant_op.constant(c_val)
# List of lists, tuples, namedtuple, and dict
- res = sess.run([[a, b, c], (a, b, c), ABC(a=a, b=b, c=c),
- {'a': a.name, 'c': c, 'b': b}])
+ res = sess.run([[a, b, c], (a, b, c),
+ ABC(a=a, b=b, c=c), {
+ 'a': a.name,
+ 'c': c,
+ 'b': b
+ }])
self.assertTrue(isinstance(res, list))
self.assertEqual(4, len(res))
self.assertTrue(isinstance(res[0], list))
@@ -365,8 +374,11 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(b_val, res[3]['b'])
self.assertEqual(c_val, res[3]['c'])
# Tuple of lists, tuples, namedtuple, and dict
- res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c),
- {'a': a, 'c': c, 'b': b}))
+ res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, c=c), {
+ 'a': a,
+ 'c': c,
+ 'b': b
+ }))
self.assertTrue(isinstance(res, tuple))
self.assertEqual(4, len(res))
self.assertTrue(isinstance(res[0], list))
@@ -389,10 +401,16 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(b_val, res[3]['b'])
self.assertEqual(c_val, res[3]['c'])
# Namedtuple of lists, tuples, namedtuples, and dict
- res = sess.run(DEFG(d=[a, b, c],
- e=(a, b, c),
- f=ABC(a=a.name, b=b, c=c),
- g={'a': a, 'c': c, 'b': b}))
+ res = sess.run(
+ DEFG(
+ d=[a, b, c],
+ e=(a, b, c),
+ f=ABC(a=a.name, b=b, c=c),
+ g={
+ 'a': a,
+ 'c': c,
+ 'b': b
+ }))
self.assertTrue(isinstance(res, DEFG))
self.assertTrue(isinstance(res.d, list))
self.assertEqual(3, len(res.d))
@@ -414,10 +432,16 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(b_val, res.g['b'])
self.assertEqual(c_val, res.g['c'])
# Dict of lists, tuples, namedtuples, and dict
- res = sess.run({'d': [a, b, c],
- 'e': (a, b, c),
- 'f': ABC(a=a, b=b, c=c),
- 'g': {'a': a.name, 'c': c, 'b': b}})
+ res = sess.run({
+ 'd': [a, b, c],
+ 'e': (a, b, c),
+ 'f': ABC(a=a, b=b, c=c),
+ 'g': {
+ 'a': a.name,
+ 'c': c,
+ 'b': b
+ }
+ })
self.assertTrue(isinstance(res, dict))
self.assertEqual(4, len(res))
self.assertTrue(isinstance(res['d'], list))
@@ -516,8 +540,7 @@ class SessionTest(test_util.TensorFlowTestCase):
values = np.array([1.0, 2.0]).astype(np.float32)
shape = np.array([7, 9, 2]).astype(np.int64)
sp = sparse_tensor.SparseTensor(
- constant_op.constant(indices),
- constant_op.constant(values),
+ constant_op.constant(indices), constant_op.constant(values),
constant_op.constant(shape))
# Single fetch, use as tuple
sp_out = s.run(sp)
@@ -587,14 +610,17 @@ class SessionTest(test_util.TensorFlowTestCase):
sp = sparse_tensor.SparseTensor(
array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
array_ops.placeholder(dtype=np.float32, shape=(2,)),
- array_ops.placeholder(dtype=np.int64, shape=(3,)),)
+ array_ops.placeholder(dtype=np.int64, shape=(3,)),
+ )
sp_indices = array_ops.identity(sp.indices)
sp_values = array_ops.identity(sp.values)
sp_shape = array_ops.identity(sp.dense_shape)
sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
# Feed with tuple
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: (indices, values, shape)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
@@ -605,20 +631,23 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertAllEqual(sp_out.dense_shape, shape)
# Feed with SparseTensorValue
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape],
- {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
# Feed with SparseTensorValue, fetch SparseTensorValue
- sp2_out = s.run(
- sp2, {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ sp2_out = s.run(sp2, {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(sp2_out.indices, indices)
self.assertAllEqual(sp2_out.values, values)
self.assertAllEqual(sp2_out.dense_shape, shape)
# Feed SparseTensorValue and fetch sp directly.
- sp_out = s.run(
- sp, {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ sp_out = s.run(sp, {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(sp_out.indices, indices)
self.assertAllEqual(sp_out.values, values)
self.assertAllEqual(sp_out.dense_shape, shape)
@@ -635,20 +664,24 @@ class SessionTest(test_util.TensorFlowTestCase):
sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
# Feed with tuple
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: (indices, values, shape)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
# Feed with SparseTensorValue
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape],
- {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
# Feed with SparseTensorValue, fetch SparseTensorValue
- sp2_out = s.run(
- sp2, {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ sp2_out = s.run(sp2, {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(sp2_out.indices, indices)
self.assertAllEqual(sp2_out.values, values)
self.assertAllEqual(sp2_out.dense_shape, shape)
@@ -666,20 +699,24 @@ class SessionTest(test_util.TensorFlowTestCase):
sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
# Feed with tuple
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: (indices, values, shape)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
# Feed with SparseTensorValue
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape],
- {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
# Feed with SparseTensorValue, fetch SparseTensorValue
- sp2_out = s.run(
- sp2, {sp: sparse_tensor.SparseTensorValue(indices, values, shape)})
+ sp2_out = s.run(sp2, {
+ sp: sparse_tensor.SparseTensorValue(indices, values, shape)
+ })
self.assertAllEqual(sp2_out.indices, indices)
self.assertAllEqual(sp2_out.values, values)
self.assertAllEqual(sp2_out.dense_shape, shape)
@@ -689,9 +726,8 @@ class SessionTest(test_util.TensorFlowTestCase):
indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
values = np.array([1.0, 2.0]).astype(np.float32)
shape = np.array([7, 9, 2]).astype(np.int64)
- sp = array_ops.sparse_placeholder(dtype=np.float32,
- shape=shape,
- name='placeholder1')
+ sp = array_ops.sparse_placeholder(
+ dtype=np.float32, shape=shape, name='placeholder1')
self.assertAllEqual(sp.dense_shape.eval(session=s), shape)
self.assertAllEqual(tensor_util.constant_value(sp.dense_shape), shape)
sp_indices = array_ops.identity(sp.indices)
@@ -699,7 +735,9 @@ class SessionTest(test_util.TensorFlowTestCase):
sp_shape = array_ops.identity(sp.dense_shape)
# Feed with tuple
indices_out, values_out, shape_out = s.run(
- [sp_indices, sp_values, sp_shape], {sp: (indices, values)})
+ [sp_indices, sp_values, sp_shape], {
+ sp: (indices, values)
+ })
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(values_out, values)
self.assertAllEqual(shape_out, shape)
@@ -745,33 +783,34 @@ class SessionTest(test_util.TensorFlowTestCase):
indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
dense_shape = np.array([7, 9, 2]).astype(np.int64)
ind = ops.IndexedSlices(
- array_ops.placeholder(dtype=np.float32,
- shape=(2,)),
- array_ops.placeholder(dtype=np.int64,
- shape=(2, 3)),
- array_ops.placeholder(dtype=np.int64,
- shape=(3,)),)
+ array_ops.placeholder(dtype=np.float32, shape=(2,)),
+ array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
+ array_ops.placeholder(dtype=np.int64, shape=(3,)),
+ )
ind_values = array_ops.identity(ind.values)
ind_indices = array_ops.identity(ind.indices)
ind_dense_shape = array_ops.identity(ind.dense_shape)
ind2 = ops.IndexedSlices(ind_values, ind_indices, ind_dense_shape)
# Feed with tuple
values_out, indices_out, dense_shape_out = s.run(
- [ind_values, ind_indices, ind_dense_shape],
- {ind: (values, indices, dense_shape)})
+ [ind_values, ind_indices, ind_dense_shape], {
+ ind: (values, indices, dense_shape)
+ })
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(dense_shape_out, dense_shape)
# Feed with IndexedSlicesValue
values_out, indices_out, dense_shape_out = s.run(
- [ind_values, ind_indices, ind_dense_shape],
- {ind: ops.IndexedSlicesValue(values, indices, dense_shape)})
+ [ind_values, ind_indices, ind_dense_shape], {
+ ind: ops.IndexedSlicesValue(values, indices, dense_shape)
+ })
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
self.assertAllEqual(dense_shape_out, dense_shape)
# Feed with IndexedSlicesValue, fetch IndexedSlicesValue
- ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices,
- dense_shape)})
+ ind2_out = s.run(ind2, {
+ ind: ops.IndexedSlicesValue(values, indices, dense_shape)
+ })
self.assertAllEqual(ind2_out.values, values)
self.assertAllEqual(ind2_out.indices, indices)
self.assertAllEqual(ind2_out.dense_shape, dense_shape)
@@ -816,28 +855,27 @@ class SessionTest(test_util.TensorFlowTestCase):
indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
dense_shape = None
ind = ops.IndexedSlices(
- array_ops.placeholder(dtype=np.float32,
- shape=(2,)),
- array_ops.placeholder(dtype=np.int64,
- shape=(2, 3)),
- None)
+ array_ops.placeholder(dtype=np.float32, shape=(2,)),
+ array_ops.placeholder(dtype=np.int64, shape=(2, 3)), None)
ind_values = array_ops.identity(ind.values)
ind_indices = array_ops.identity(ind.indices)
ind2 = ops.IndexedSlices(ind_values, ind_indices)
# Feed with tuple
- values_out, indices_out = s.run(
- [ind_values, ind_indices], {ind: (values, indices)})
+ values_out, indices_out = s.run([ind_values, ind_indices], {
+ ind: (values, indices)
+ })
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
# Feed with IndexedSlicesValue
- values_out, indices_out = s.run(
- [ind_values, ind_indices],
- {ind: ops.IndexedSlicesValue(values, indices, dense_shape)})
+ values_out, indices_out = s.run([ind_values, ind_indices], {
+ ind: ops.IndexedSlicesValue(values, indices, dense_shape)
+ })
self.assertAllEqual(values_out, values)
self.assertAllEqual(indices_out, indices)
# Feed with IndexedSlicesValue, fetch IndexedSlicesValue
- ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices,
- dense_shape)})
+ ind2_out = s.run(ind2, {
+ ind: ops.IndexedSlicesValue(values, indices, dense_shape)
+ })
self.assertAllEqual(ind2_out.values, values)
self.assertAllEqual(ind2_out.indices, indices)
self.assertAllEqual(ind2_out.dense_shape, dense_shape)
@@ -986,8 +1024,9 @@ class SessionTest(test_util.TensorFlowTestCase):
constructed_events = [threading.Event() for _ in range(10)]
continue_event = threading.Event()
for i, constructed_event in enumerate(constructed_events):
- t = self.checkedThread(target=self._testDefaultGraphInThread,
- args=(constructed_event, continue_event, i))
+ t = self.checkedThread(
+ target=self._testDefaultGraphInThread,
+ args=(constructed_event, continue_event, i))
threads.append(t)
for t in threads:
t.start()
@@ -1006,6 +1045,7 @@ class SessionTest(test_util.TensorFlowTestCase):
ev.wait()
val = c.eval(session=sess)
self.assertEqual(val, 5.0)
+
threads = [self.checkedThread(target=run_step) for _ in range(100)]
for t in threads:
t.start()
@@ -1038,11 +1078,10 @@ class SessionTest(test_util.TensorFlowTestCase):
def testGraphDef(self):
with session.Session() as sess:
- self.assertProtoEquals(
- 'versions { producer: %d min_consumer: %d }' % (
- versions.GRAPH_DEF_VERSION,
- versions.GRAPH_DEF_VERSION_MIN_CONSUMER),
- sess.graph_def)
+ self.assertProtoEquals('versions { producer: %d min_consumer: %d }' %
+ (versions.GRAPH_DEF_VERSION,
+ versions.GRAPH_DEF_VERSION_MIN_CONSUMER),
+ sess.graph_def)
c = constant_op.constant(5.0, name='c')
self.assertEquals(len(sess.graph_def.node), 1)
d = constant_op.constant(6.0, name='d')
@@ -1072,6 +1111,7 @@ class SessionTest(test_util.TensorFlowTestCase):
lambda e: 'Attempted to use a closed Session.' in str(e)):
while True:
sess.run(c)
+
t = threading.Thread(target=update_thread)
t.start()
time.sleep(0.1)
@@ -1177,17 +1217,11 @@ class SessionTest(test_util.TensorFlowTestCase):
def testFeedAndFetch(self):
with session.Session() as sess:
- for dtype in [dtypes.float16,
- dtypes.float32,
- dtypes.float64,
- dtypes.int32,
- dtypes.uint8,
- dtypes.int16,
- dtypes.int8,
- dtypes.int64,
- dtypes.bool,
- dtypes.complex64,
- dtypes.complex128]:
+ for dtype in [
+ dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
+ dtypes.uint8, dtypes.int16, dtypes.int8, dtypes.int64, dtypes.bool,
+ dtypes.complex64, dtypes.complex128
+ ]:
for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
np_dtype = dtype.as_numpy_dtype
@@ -1206,13 +1240,19 @@ class SessionTest(test_util.TensorFlowTestCase):
np_array = np_array.astype(np_dtype)
self.assertAllEqual(np_array,
- sess.run(out_t, feed_dict={feed_t: np_array}))
+ sess.run(out_t, feed_dict={
+ feed_t: np_array
+ }))
# Check that we can also get the feed back.
self.assertAllEqual(np_array,
- sess.run(feed_t, feed_dict={feed_t: np_array}))
+ sess.run(feed_t, feed_dict={
+ feed_t: np_array
+ }))
# Also check that we can get both back.
- out_v, feed_v = sess.run([out_t, feed_t],
- feed_dict={feed_t: np_array})
+ out_v, feed_v = sess.run(
+ [out_t, feed_t], feed_dict={
+ feed_t: np_array
+ })
self.assertAllEqual(np_array, out_v)
self.assertAllEqual(np_array, feed_v)
@@ -1257,9 +1297,11 @@ class SessionTest(test_util.TensorFlowTestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
- self.assertAllClose(
- 42.0,
- tensor_runner(41.0, options=run_options, run_metadata=run_metadata))
+ self.assertAllClose(42.0,
+ tensor_runner(
+ 41.0,
+ options=run_options,
+ run_metadata=run_metadata))
self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
def testFeedError(self):
@@ -1296,8 +1338,9 @@ class SessionTest(test_util.TensorFlowTestCase):
size = 1
for s in shape:
size *= s
- c_list = np.array([compat.as_bytes(str(i)) for i in xrange(size)],
- dtype=np.object).reshape(shape) if size > 0 else []
+ c_list = np.array(
+ [compat.as_bytes(str(i)) for i in xrange(size)],
+ dtype=np.object).reshape(shape) if size > 0 else []
c = constant_op.constant(c_list)
self.assertAllEqual(c.eval(), c_list)
@@ -1307,13 +1350,16 @@ class SessionTest(test_util.TensorFlowTestCase):
size = 1
for s in shape:
size *= s
- c_list = np.array([compat.as_bytes(str(i)) for i in xrange(size)],
- dtype=np.object).reshape(shape)
+ c_list = np.array(
+ [compat.as_bytes(str(i)) for i in xrange(size)],
+ dtype=np.object).reshape(shape)
feed_t = array_ops.placeholder(dtype=dtypes.string, shape=shape)
c = array_ops.identity(feed_t)
self.assertAllEqual(sess.run(c, feed_dict={feed_t: c_list}), c_list)
- self.assertAllEqual(sess.run(feed_t, feed_dict={feed_t: c_list}),
- c_list)
+ self.assertAllEqual(
+ sess.run(feed_t, feed_dict={
+ feed_t: c_list
+ }), c_list)
c_v, feed_v = sess.run([c, feed_t], feed_dict={feed_t: c_list})
self.assertAllEqual(c_v, c_list)
self.assertAllEqual(feed_v, c_list)
@@ -1329,8 +1375,10 @@ class SessionTest(test_util.TensorFlowTestCase):
def testStringFeedWithUnicode(self):
with session.Session():
- c_list = [u'\n\x01\x00', u'\n\x00\x01',
- u'\u26a3 unicode', u'\U0001f60e deal with it']
+ c_list = [
+ u'\n\x01\x00', u'\n\x00\x01', u'\u26a3 unicode',
+ u'\U0001f60e deal with it'
+ ]
feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[len(c_list)])
c = array_ops.identity(feed_t)
@@ -1423,9 +1471,10 @@ class SessionTest(test_util.TensorFlowTestCase):
sess.run(constant_op.constant(1.0), run_metadata=run_metadata)
self.assertTrue(not run_metadata.HasField('step_stats'))
- sess.run(constant_op.constant(1.0),
- options=run_options,
- run_metadata=run_metadata)
+ sess.run(
+ constant_op.constant(1.0),
+ options=run_options,
+ run_metadata=run_metadata)
self.assertTrue(run_metadata.HasField('step_stats'))
self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
@@ -1439,23 +1488,26 @@ class SessionTest(test_util.TensorFlowTestCase):
with session.Session() as sess:
# all combinations are valid
sess.run(constant_op.constant(1.0), options=None, run_metadata=None)
- sess.run(constant_op.constant(1.0), options=None,
- run_metadata=run_metadata)
+ sess.run(
+ constant_op.constant(1.0), options=None, run_metadata=run_metadata)
self.assertTrue(not run_metadata.HasField('step_stats'))
- sess.run(constant_op.constant(1.0), options=run_options,
- run_metadata=None)
+ sess.run(
+ constant_op.constant(1.0), options=run_options, run_metadata=None)
self.assertTrue(not run_metadata.HasField('step_stats'))
- sess.run(constant_op.constant(1.0), options=run_options,
- run_metadata=run_metadata)
+ sess.run(
+ constant_op.constant(1.0),
+ options=run_options,
+ run_metadata=run_metadata)
self.assertTrue(run_metadata.HasField('step_stats'))
self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
def testFeedShapeCompatibility(self):
# TODO(nolivia): C API doesn't yet handle marking nodes as not feedable.
- if ops._USE_C_API: return
+ if ops._USE_C_API:
+ return
with session.Session() as sess:
some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
@@ -1499,8 +1551,11 @@ class SessionTest(test_util.TensorFlowTestCase):
d = math_ops.multiply(c, c)
for step in xrange(120):
run_metadata = config_pb2.RunMetadata()
- sess.run(d, feed_dict={a: 1.0},
- options=run_options, run_metadata=run_metadata)
+ sess.run(
+ d,
+ feed_dict={a: 1.0},
+ options=run_options,
+ run_metadata=run_metadata)
if step == 99:
self.assertTrue(run_metadata.HasField('cost_graph'))
else:
@@ -1569,8 +1624,7 @@ class SessionTest(test_util.TensorFlowTestCase):
def testTimeoutWithShortOperations(self):
num_epochs = 5
- q = data_flow_ops.FIFOQueue(
- capacity=50, dtypes=[dtypes.int32], shapes=[()])
+ q = data_flow_ops.FIFOQueue(capacity=50, dtypes=[dtypes.int32], shapes=[()])
enqueue_op = q.enqueue_many(constant_op.constant([1, 2]))
# Use a 10-second timeout, which should be longer than any
@@ -1582,7 +1636,9 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(sess.run(q.size()), num_epochs * 2)
def testRegisterFetchAndFeedConversionFunctions(self):
+
class SquaredTensor(object):
+
def __init__(self, tensor):
self.sq = math_ops.square(tensor)
@@ -1591,24 +1647,27 @@ class SessionTest(test_util.TensorFlowTestCase):
feed_fn2 = lambda feed: [feed.sq]
session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
- feed_fn1, feed_fn2)
+ feed_fn1, feed_fn2)
with self.assertRaises(ValueError):
- session.register_session_run_conversion_functions(SquaredTensor,
- fetch_fn, feed_fn1, feed_fn2)
+ session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
+ feed_fn1, feed_fn2)
with self.test_session() as sess:
np1 = np.array([1.0, 1.5, 2.0, 2.5])
np2 = np.array([3.0, 3.5, 4.0, 4.5])
squared_tensor = SquaredTensor(np2)
squared_eval = sess.run(squared_tensor)
self.assertAllClose(np2 * np2, squared_eval)
- squared_eval = sess.run(squared_tensor, feed_dict={
- squared_tensor : np1 * np1})
+ squared_eval = sess.run(
+ squared_tensor, feed_dict={
+ squared_tensor: np1 * np1
+ })
self.assertAllClose(np1 * np1, squared_eval)
partial_run = sess.partial_run_setup([squared_tensor], [])
squared_eval = sess.partial_run(partial_run, squared_tensor)
self.assertAllClose(np2 * np2, squared_eval)
def testDefaultLogDevicePlacement(self):
+
class CaptureStderr(str):
"""Class to capture stderr from C++ shared library."""
@@ -1719,6 +1778,7 @@ class SessionTest(test_util.TensorFlowTestCase):
def runTestAddFunctionToSession(self, target=''):
"""Add a function to a session after the graph has already been run."""
+
@function.Defun(dtypes.float32)
def foo(x):
return x + 1
@@ -1753,6 +1813,7 @@ class SessionTest(test_util.TensorFlowTestCase):
TypeError, 'Type of feed value 1 with type <(\w+) \'int\'> is not'):
sess.run(a, feed_dict={a: 1})
+
class GraphMutationTest(test_util.TensorFlowTestCase):
def setUp(self):
@@ -1803,8 +1864,7 @@ class GraphMutationTest(test_util.TensorFlowTestCase):
with session.Session(graph=g) as sess:
self.assertAllEqual(1.0, sess.run(b))
- b.op._set_attr('DstT',
- attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
+ b.op._set_attr('DstT', attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
with self.assertRaisesRegexp(
errors.FailedPreconditionError,
'Cast.*was changed by setting attribute after it was run'):
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 5fb389cf92..43cbde69d9 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -59,7 +59,7 @@ tf_py_test(
tf_py_test(
name = "dataset_from_generator_op_test",
- size = "small",
+ size = "medium",
srcs = ["dataset_from_generator_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 6162644036..647f03351d 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -766,6 +766,9 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors,
return;
}
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
+ if (PyErr_Occurred()) {
+ return;
+ }
std::vector<tensorflow::eager::TapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 41f55b12af..c519fd557a 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -604,6 +604,7 @@ py_library(
":metric_keys",
":model_fn",
":prediction_keys",
+ ":util",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:control_flow_ops",
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 94a5d3a342..cb9e3fc6ca 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -24,6 +24,7 @@ import collections
import six
from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export_output
@@ -371,6 +372,64 @@ def _check_logits_final_dim(logits, expected_logits_dimension):
return array_ops.identity(logits, name=scope)
+def _validate_loss_fn_args(loss_fn):
+ """Validates loss_fn arguments.
+
+ Required arguments: labels, logits.
+ Optional arguments: features.
+
+ Args:
+ loss_fn: The loss function.
+ Raises:
+ ValueError: If the signature is unexpected.
+ """
+ loss_fn_args = util.fn_args(loss_fn)
+ for required_arg in ['labels', 'logits']:
+ if required_arg not in loss_fn_args:
+ raise ValueError(
+ 'loss_fn must contain argument: {}. '
+ 'Given arguments: {}'.format(required_arg, loss_fn_args))
+ invalid_args = list(set(loss_fn_args) - set(['labels', 'logits', 'features']))
+ if invalid_args:
+ raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args))
+
+
+def _call_loss_fn(loss_fn, labels, logits, features, expected_loss_dim=1):
+ """Calls loss_fn and checks the returned shape.
+
+ Args:
+ loss_fn: The loss function.
+ labels: Processed labels Tensor.
+ logits: Logits Tensor of shape [D0, D1, ... DN, logits_dimension].
+ features: Features dict.
+ expected_loss_dim: The expected last dimension of loss Tensor.
+ Returns:
+ Loss Tensor with shape [D0, D1, ... DN, expected_loss_dim].
+ """
+ loss_fn_args = util.fn_args(loss_fn)
+ kwargs = {}
+ if 'features' in loss_fn_args:
+ kwargs['features'] = features
+ with ops.name_scope(
+ None, 'call_loss_fn',
+ values=[labels, logits] + list(six.itervalues(features))):
+ unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs)
+ logits_shape = array_ops.shape(logits, name='logits_shape')
+ expected_loss_shape = array_ops.concat(
+ [logits_shape[:-1], [expected_loss_dim]], axis=0,
+ name='expected_loss_shape')
+ loss_shape = array_ops.shape(unweighted_loss, name='loss_shape')
+ check_loss_shape_op = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(loss_shape, expected_loss_shape)),
+ data=[
+ 'loss_fn must return Tensor of shape '
+ '[D0, D1, ... DN, {}]. '.format(expected_loss_dim),
+ 'logits_shape: ', logits_shape, 'loss_shape: ', loss_shape],
+ name='check_loss_shape')
+ with ops.control_dependencies([check_loss_shape_op]):
+ return array_ops.identity(unweighted_loss)
+
+
def _indicator_labels_mean(labels, weights=None, name=None):
with ops.name_scope(name, 'labels_mean', (labels, weights)) as scope:
labels = math_ops.to_float(labels, name='labels')
@@ -467,6 +526,7 @@ def _multi_class_head_with_softmax_cross_entropy_loss(
weight_column=None,
label_vocabulary=None,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
"""Creates a '_Head' for multi class classification.
@@ -485,6 +545,12 @@ def _multi_class_head_with_softmax_cross_entropy_loss(
labels have shape `[batch_size, 1]`, the loss is the weighted sum over
`batch_size`.
+ Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+ `(labels, logits, features)` as arguments and returns unreduced loss with
+ shape `[D0, D1, ... DN, 1]`. `loss_fn` must support integer `labels` with
+ shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
+ the input labels before passing them to `loss_fn`.
+
Args:
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
`_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`).
@@ -499,6 +565,7 @@ def _multi_class_head_with_softmax_cross_entropy_loss(
`label_vocabulary` is not provided but labels are strings.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
reduce training loss over batch. Defaults to `SUM`.
+ loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -517,11 +584,14 @@ def _multi_class_head_with_softmax_cross_entropy_loss(
if (loss_reduction not in losses.Reduction.all() or
loss_reduction == losses.Reduction.NONE):
raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
+ if loss_fn:
+ _validate_loss_fn_args(loss_fn)
return _MultiClassHeadWithSoftmaxCrossEntropyLoss(
n_classes=n_classes,
weight_column=weight_column,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction,
+ loss_fn=loss_fn,
name=name)
@@ -533,6 +603,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
weight_column=None,
label_vocabulary=None,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
if (n_classes is None) or (n_classes <= 2):
raise ValueError('n_classes must be > 2: %s.' % n_classes)
@@ -540,6 +611,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
self._weight_column = weight_column
self._label_vocabulary = label_vocabulary
self._loss_reduction = loss_reduction
+ self._loss_fn = loss_fn
self._name = name
@property
@@ -602,10 +674,15 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
labels = _check_dense_labels_match_logits_and_reshape(
labels=labels, logits=logits, expected_labels_dimension=1)
label_ids = self._label_ids(labels)
- unweighted_loss = losses.sparse_softmax_cross_entropy(
- labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
- # Restore the squeezed dim, so unweighted_loss matches the weights shape.
- unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=-1)
+ if self._loss_fn:
+ unweighted_loss = _call_loss_fn(
+ loss_fn=self._loss_fn, labels=label_ids, logits=logits,
+ features=features, expected_loss_dim=1)
+ else:
+ unweighted_loss = losses.sparse_softmax_cross_entropy(
+ labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
+ # Restore the squeezed dim, so unweighted_loss matches the weights shape.
+ unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=-1)
weights = _get_weights_and_check_match_logits(
features=features, weight_column=self._weight_column, logits=logits)
training_loss = losses.compute_weighted_loss(
@@ -734,8 +811,12 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
- weight_column=None, thresholds=None, label_vocabulary=None,
- loss_reduction=losses.Reduction.SUM, name=None):
+ weight_column=None,
+ thresholds=None,
+ label_vocabulary=None,
+ loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
+ name=None):
"""Creates a `_Head` for single label binary classification.
This head uses `sigmoid_cross_entropy_with_logits` loss.
@@ -755,6 +836,12 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
labels have shape `[batch_size, 1]`, the loss is the weighted sum over
`batch_size`.
+ Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+ `(labels, logits, features)` as arguments and returns unreduced loss with
+ shape `[D0, D1, ... DN, 1]`. `loss_fn` must support float `labels` with
+ shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
+ the input labels before passing them to `loss_fn`.
+
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -772,6 +859,7 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
is not provided but labels are strings.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
reduce training loss over batch. Defaults to `SUM`.
+ loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -795,11 +883,14 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
if (loss_reduction not in losses.Reduction.all() or
loss_reduction == losses.Reduction.NONE):
raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
+ if loss_fn:
+ _validate_loss_fn_args(loss_fn)
return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(
weight_column=weight_column,
thresholds=thresholds,
label_vocabulary=label_vocabulary,
loss_reduction=loss_reduction,
+ loss_fn=loss_fn,
name=name)
@@ -811,11 +902,13 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
thresholds=None,
label_vocabulary=None,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
self._weight_column = weight_column
self._thresholds = thresholds
self._label_vocabulary = label_vocabulary
self._loss_reduction = loss_reduction
+ self._loss_fn = loss_fn
self._name = name
@property
@@ -916,8 +1009,13 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
name='class_id_lookup').lookup(labels)
labels = math_ops.to_float(labels)
labels = _assert_range(labels, 2)
- unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
- labels=labels, logits=logits)
+ if self._loss_fn:
+ unweighted_loss = _call_loss_fn(
+ loss_fn=self._loss_fn, labels=labels, logits=logits,
+ features=features, expected_loss_dim=1)
+ else:
+ unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
+ labels=labels, logits=logits)
weights = _get_weights_and_check_match_logits(
features=features, weight_column=self._weight_column, logits=logits)
training_loss = losses.compute_weighted_loss(
@@ -1057,6 +1155,7 @@ def _regression_head_with_mean_squared_error_loss(
weight_column=None,
label_dimension=1,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
"""Creates a `_Head` for regression using the `mean_squared_error` loss.
@@ -1075,6 +1174,10 @@ def _regression_head_with_mean_squared_error_loss(
`[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or
`[D0, D1, ... DN, label_dimension]`.
+ Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
+ `(labels, logits, features)` as arguments and returns unreduced loss with
+ shape `[D0, D1, ... DN, label_dimension]`.
+
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -1085,6 +1188,7 @@ def _regression_head_with_mean_squared_error_loss(
`[batch_size, label_dimension]`).
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
reduce training loss over batch. Defaults to `SUM`.
+ loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -1097,10 +1201,13 @@ def _regression_head_with_mean_squared_error_loss(
if (loss_reduction not in losses.Reduction.all() or
loss_reduction == losses.Reduction.NONE):
raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
+ if loss_fn:
+ _validate_loss_fn_args(loss_fn)
return _RegressionHeadWithMeanSquaredErrorLoss(
weight_column=weight_column,
label_dimension=label_dimension,
loss_reduction=loss_reduction,
+ loss_fn=loss_fn,
name=name)
@@ -1112,6 +1219,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
label_dimension,
weight_column=None,
loss_reduction=losses.Reduction.SUM,
+ loss_fn=None,
name=None):
"""`Head` for regression."""
if label_dimension < 1:
@@ -1119,6 +1227,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
self._logits_dimension = label_dimension
self._weight_column = weight_column
self._loss_reduction = loss_reduction
+ self._loss_fn = loss_fn
self._name = name
@property
@@ -1137,8 +1246,13 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
labels=labels, logits=logits,
expected_labels_dimension=self._logits_dimension)
labels = math_ops.to_float(labels)
- unweighted_loss = losses.mean_squared_error(
- labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
+ if self._loss_fn:
+ unweighted_loss = _call_loss_fn(
+ loss_fn=self._loss_fn, labels=labels, logits=logits,
+ features=features, expected_loss_dim=self._logits_dimension)
+ else:
+ unweighted_loss = losses.mean_squared_error(
+ labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
weights = _get_weights_and_check_match_logits(
features=features, weight_column=self._weight_column, logits=logits,
allow_per_logit_weights=True)
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index 4e871e8f37..3a03770af4 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -111,6 +111,41 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
head_lib._multi_class_head_with_softmax_cross_entropy_loss(
n_classes=3, loss_reduction=losses.Reduction.NONE)
+ def test_loss_fn_arg_labels_missing(self):
+ def _loss_fn(logits):
+ del logits # Unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn must contain argument: labels\. '
+ r'Given arguments: \(\'logits\',\)'):
+ head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_logits_missing(self):
+ def _loss_fn(labels):
+ del labels # unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn must contain argument: logits\. '
+ r'Given arguments: \(\'labels\',\)'):
+ head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_features_ok(self):
+ def _loss_fn(labels, logits, features):
+ del labels, logits, features # Unused
+ head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_invalid(self):
+ def _loss_fn(labels, logits, name=None):
+ del labels, logits, name # Unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn has unexpected args: \[\'name\'\]'):
+ head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_fn=_loss_fn)
+
def test_invalid_logits_shape(self):
n_classes = 3
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)
@@ -406,6 +441,56 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)
+ def test_eval_create_loss_loss_fn(self):
+ """Tests head.create_loss for eval mode and custom loss_fn."""
+ loss = np.array([[1.], [2.]], dtype=np.float32)
+ logits_input = np.array([[-10., 10., 0.], [-15., 10., 0]], dtype=np.float32)
+ labels_input = np.array([[1], [2]], dtype=np.int64)
+ def _loss_fn(labels, logits):
+ check_labels = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(labels, labels_input)),
+ data=[labels])
+ check_logits = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(logits, logits_input)),
+ data=[logits])
+ with ops.control_dependencies([check_labels, check_logits]):
+ return constant_op.constant(loss)
+ head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_fn=_loss_fn)
+
+ actual_training_loss = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits_input,
+ labels=labels_input)[0]
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ self.assertAllClose(np.sum(loss), actual_training_loss.eval())
+
+ def test_eval_create_loss_loss_fn_wrong_shape(self):
+ """Tests custom loss_fn that returns Tensor of unexpected shape."""
+ loss = np.array([1., 2.], dtype=np.float32)
+ def _loss_fn(labels, logits):
+ del labels, logits # Unused
+ return constant_op.constant(loss)
+ head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
+ n_classes=3, loss_fn=_loss_fn)
+
+ logits = np.array([[-10., 10., 0.], [-15., 10., 0.]], dtype=np.float32)
+ labels = np.array([[1], [2]], dtype=np.int64)
+ actual_training_loss = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)[0]
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 1\]\. \] '
+ r'\[logits_shape: \] \[2 3\] \[loss_shape: \] \[2\]'):
+ actual_training_loss.eval()
+
def test_eval_labels_none(self):
"""Tests that error is raised when labels is None."""
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
@@ -1204,6 +1289,41 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
loss_reduction=losses.Reduction.NONE)
+ def test_loss_fn_arg_labels_missing(self):
+ def _loss_fn(logits):
+ del logits # Unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn must contain argument: labels\. '
+ r'Given arguments: \(\'logits\',\)'):
+ head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_logits_missing(self):
+ def _loss_fn(labels):
+ del labels # unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn must contain argument: logits\. '
+ r'Given arguments: \(\'labels\',\)'):
+ head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_features_ok(self):
+ def _loss_fn(labels, logits, features):
+ del labels, logits, features # Unused
+ head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_invalid(self):
+ def _loss_fn(labels, logits, name=None):
+ del labels, logits, name # Unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn has unexpected args: \[\'name\'\]'):
+ head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_fn=_loss_fn)
+
def test_invalid_logits_shape(self):
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()
self.assertEqual(1, head.logits_dimension)
@@ -1699,6 +1819,56 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())
self.assertAllClose(expected_weights, actual_weights)
+ def test_eval_create_loss_loss_fn(self):
+ """Tests head.create_loss for eval mode and custom loss_fn."""
+ loss = np.array([[1.], [2.]], dtype=np.float32)
+ logits_input = np.array([[-10.], [10.]], dtype=np.float32)
+ labels_input = np.array([[1], [0]], dtype=np.int64)
+ def _loss_fn(labels, logits):
+ check_labels = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(labels, labels_input)),
+ data=[labels])
+ check_logits = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(logits, logits_input)),
+ data=[logits])
+ with ops.control_dependencies([check_labels, check_logits]):
+ return constant_op.constant(loss)
+ head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_fn=_loss_fn)
+
+ actual_training_loss = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits_input,
+ labels=labels_input)[0]
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ self.assertAllClose(np.sum(loss), actual_training_loss.eval())
+
+ def test_eval_create_loss_loss_fn_wrong_shape(self):
+ """Tests custom loss_fn that returns Tensor of unexpected shape."""
+ loss = np.array([1., 2.], dtype=np.float32)
+ def _loss_fn(labels, logits):
+ del labels, logits # Unused
+ return constant_op.constant(loss)
+ head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_fn=_loss_fn)
+
+ logits = np.array([[-10.], [10.]], dtype=np.float32)
+ labels = np.array([[1], [0]], dtype=np.int64)
+ actual_training_loss = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)[0]
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 1\]\. \] '
+ r'\[logits_shape: \] \[2 1\] \[loss_shape: \] \[2\]'):
+ actual_training_loss.eval()
+
def test_train_labels_none(self):
"""Tests that error is raised when labels is None."""
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()
@@ -2355,6 +2525,37 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
head_lib._regression_head_with_mean_squared_error_loss(
loss_reduction=losses.Reduction.NONE)
+ def test_loss_fn_arg_labels_missing(self):
+ def _loss_fn(logits):
+ del logits # Unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn must contain argument: labels\. '
+ r'Given arguments: \(\'logits\',\)'):
+ head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_logits_missing(self):
+ def _loss_fn(labels):
+ del labels # unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn must contain argument: logits\. '
+ r'Given arguments: \(\'labels\',\)'):
+ head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_features_ok(self):
+ def _loss_fn(labels, logits, features):
+ del labels, logits, features # Unused
+ head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn)
+
+ def test_loss_fn_arg_invalid(self):
+ def _loss_fn(labels, logits, name=None):
+ del labels, logits, name # Unused
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'loss_fn has unexpected args: \[\'name\'\]'):
+ head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn)
+
def test_invalid_logits(self):
head = head_lib._regression_head_with_mean_squared_error_loss(
label_dimension=3)
@@ -2530,6 +2731,56 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
# loss = [(43-45)^2, (44-41)] = [4, 9]
self.assertAllClose(13., training_loss.eval())
+ def test_eval_create_loss_loss_fn(self):
+ """Tests head.create_loss for eval mode and custom loss_fn."""
+ loss = np.array([[0., 1.], [2., 3.]], dtype=np.float32)
+ logits_input = np.array([[-1., 1.], [-2., 2.]], dtype=np.float32)
+ labels_input = np.array([[1., 0.], [2., -1.]], dtype=np.float32)
+ def _loss_fn(labels, logits):
+ check_labels = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(labels, labels_input)),
+ data=[labels])
+ check_logits = control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(logits, logits_input)),
+ data=[logits])
+ with ops.control_dependencies([check_labels, check_logits]):
+ return constant_op.constant(loss)
+ head = head_lib._regression_head_with_mean_squared_error_loss(
+ label_dimension=2, loss_fn=_loss_fn)
+
+ actual_training_loss = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits_input,
+ labels=labels_input)[0]
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ self.assertAllClose(np.sum(loss), actual_training_loss.eval())
+
+ def test_eval_create_loss_loss_fn_wrong_shape(self):
+ """Tests custom loss_fn that returns Tensor of unexpected shape."""
+ loss = np.array([[1.], [2.]], dtype=np.float32)
+ def _loss_fn(labels, logits):
+ del labels, logits # Unused
+ return constant_op.constant(loss)
+ head = head_lib._regression_head_with_mean_squared_error_loss(
+ label_dimension=2, loss_fn=_loss_fn)
+
+ logits = np.array([[-1., 1.], [-2., 2.]], dtype=np.float32)
+ labels = np.array([[1., 0.], [2., -1.]], dtype=np.float32)
+ actual_training_loss = head.create_loss(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)[0]
+ with self.test_session():
+ _initialize_variables(self, monitored_session.Scaffold())
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 2\]\. \] '
+ r'\[logits_shape: \] \[2 2\] \[loss_shape: \] \[2 1\]'):
+ actual_training_loss.eval()
+
def test_eval_labels_none(self):
"""Tests that error is raised when labels is None."""
head = head_lib._regression_head_with_mean_squared_error_loss()
diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
index 75c0e61d47..8e5d8141a1 100644
--- a/tensorflow/python/estimator/inputs/queues/feeding_functions.py
+++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
@@ -47,10 +47,9 @@ except ImportError:
def _fill_array(arr, seq, fillvalue=0):
- """
- Recursively fills padded arr with elements from seq.
- If length of seq is less than arr padded length, fillvalue used.
+ """Recursively fills padded arr with elements from seq.
+ If length of seq is less than arr padded length, fillvalue used.
Args:
arr: Padded tensor of shape [batch_size, ..., max_padded_dim_len].
seq: Non-padded list of data sampels of shape
@@ -84,28 +83,30 @@ def _pad_if_needed(batch_key_item, fillvalue=0):
Raises:
ValueError if data samples have different shapes (except last padded dim).
"""
- shapes = [seq.shape[:-1] if len(seq.shape) > 0 else -1
- for seq in batch_key_item]
+ shapes = [
+ seq.shape[:-1] if len(seq.shape) > 0 else -1 for seq in batch_key_item
+ ]
if not all(shapes[0] == x for x in shapes):
raise ValueError("Array shapes must match.")
- last_length = [seq.shape[-1] if len(seq.shape) > 0 else 0
- for seq in batch_key_item]
+ last_length = [
+ seq.shape[-1] if len(seq.shape) > 0 else 0 for seq in batch_key_item
+ ]
if all([x == last_length[0] for x in last_length]):
return batch_key_item
batch_size = len(batch_key_item)
max_sequence_length = max(last_length)
result_batch = np.zeros(
- shape=[batch_size] + list(shapes[0]) + [max_sequence_length],
- dtype=batch_key_item[0].dtype)
+ shape=[batch_size] + list(shapes[0]) + [max_sequence_length],
+ dtype=batch_key_item[0].dtype)
_fill_array(result_batch, batch_key_item, fillvalue)
return result_batch
-def _get_integer_indices_for_next_batch(
- batch_indices_start, batch_size, epoch_end, array_length,
- current_epoch, total_epochs):
+def _get_integer_indices_for_next_batch(batch_indices_start, batch_size,
+ epoch_end, array_length, current_epoch,
+ total_epochs):
"""Returns the integer indices for next batch.
If total epochs is not None and current epoch is the final epoch, the end
@@ -135,8 +136,9 @@ def _get_integer_indices_for_next_batch(
"Already emitted %s epochs." % current_epoch)
batch_indices_end = batch_indices_start + batch_size
- batch_indices = [j % array_length for j in
- range(batch_indices_start, batch_indices_end)]
+ batch_indices = [
+ j % array_length for j in range(batch_indices_start, batch_indices_end)
+ ]
epoch_end_indices = [i for i, x in enumerate(batch_indices) if x == epoch_end]
current_epoch += len(epoch_end_indices)
@@ -320,16 +322,20 @@ class _GeneratorFeedFn(object):
raise KeyError("key mismatch between dicts emitted by GenFun "
"Expected {} keys; got {}".format(
self._keys, data_row.keys()))
- list_dict.setdefault(self._col_placeholders[index],
- list()).append(data_row[key])
+ list_dict.setdefault(self._col_placeholders[index], list()).append(
+ data_row[key])
list_dict_size += 1
if self._pad_value is not None:
- feed_dict = {key: np.asarray(_pad_if_needed(item, self._pad_value))
- for key, item in list(list_dict.items())}
+ feed_dict = {
+ key: np.asarray(_pad_if_needed(item, self._pad_value))
+ for key, item in list(list_dict.items())
+ }
else:
- feed_dict = {key: np.asarray(item)
- for key, item in list(list_dict.items())}
+ feed_dict = {
+ key: np.asarray(item)
+ for key, item in list(list_dict.items())
+ }
return feed_dict
@@ -382,9 +388,8 @@ def _enqueue_data(data,
queue_shapes = [(), data.shape[1:]]
get_feed_fn = _ArrayFeedFn
elif isinstance(data, collections.OrderedDict):
- types = [dtypes.int64] + [
- dtypes.as_dtype(col.dtype) for col in data.values()
- ]
+ types = [dtypes.int64
+ ] + [dtypes.as_dtype(col.dtype) for col in data.values()]
queue_shapes = [()] + [col.shape[1:] for col in data.values()]
get_feed_fn = _OrderedDictNumpyFeedFn
elif isinstance(data, tp.FunctionType):
@@ -447,11 +452,11 @@ def _enqueue_data(data,
seed=seed)
elif pad_data:
min_after_dequeue = 0 # just for the summary text
- queue_shapes = list(map(
- lambda x: tuple(list(x[:-1]) + [None]) if len(x) > 0 else x,
- queue_shapes))
+ queue_shapes = list(
+ map(lambda x: tuple(list(x[:-1]) + [None]) if len(x) > 0 else x,
+ queue_shapes))
queue = data_flow_ops.PaddingFIFOQueue(
- capacity, dtypes=types, shapes=queue_shapes)
+ capacity, dtypes=types, shapes=queue_shapes)
else:
min_after_dequeue = 0 # just for the summary text
queue = data_flow_ops.FIFOQueue(
@@ -470,31 +475,35 @@ def _enqueue_data(data,
if not pad_data:
feed_fns.append(
- get_feed_fn(
- placeholders,
- data,
- enqueue_size,
- random_start=shuffle,
- seed=seed_i,
- num_epochs=num_epochs))
+ get_feed_fn(
+ placeholders,
+ data,
+ enqueue_size,
+ random_start=shuffle,
+ seed=seed_i,
+ num_epochs=num_epochs))
else:
feed_fns.append(
- get_feed_fn(
- placeholders,
- data,
- enqueue_size,
- random_start=shuffle,
- seed=seed_i,
- num_epochs=num_epochs,
- pad_value=pad_value))
+ get_feed_fn(
+ placeholders,
+ data,
+ enqueue_size,
+ random_start=shuffle,
+ seed=seed_i,
+ num_epochs=num_epochs,
+ pad_value=pad_value))
runner = fqr._FeedingQueueRunner( # pylint: disable=protected-access
- queue=queue, enqueue_ops=enqueue_ops, feed_fns=feed_fns)
+ queue=queue,
+ enqueue_ops=enqueue_ops,
+ feed_fns=feed_fns)
queue_runner.add_queue_runner(runner)
- full = (math_ops.cast(
- math_ops.maximum(0, queue.size() - min_after_dequeue),
- dtypes.float32) * (1. / (capacity - min_after_dequeue)))
+ full = (
+ math_ops.cast(
+ math_ops.maximum(0,
+ queue.size() - min_after_dequeue), dtypes.float32)
+ * (1. / (capacity - min_after_dequeue)))
# Note that name contains a '/' at the end so we intentionally do not place
# a '/' after %s below.
summary_name = ("queue/%sfraction_over_%d_of_%d_full" %
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 0133318456..6a7e1d0c89 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -53,6 +53,7 @@ from tensorflow.python.eager import tape
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import versions
@@ -1460,3 +1461,14 @@ def get_node_def_from_graph(node_name, graph_def):
if node_def.name == node_name:
return node_def
return None
+
+
+def set_producer_version(graph, producer_version):
+ """Sets graph.graph_def_versions.producer to `producer_version`."""
+ # The C API doesn't expose altering GraphDefVersions. We can indirectly set
+ # it via import_graph_def though.
+ graph_def = graph_pb2.GraphDef()
+ graph_def.versions.producer = producer_version
+ with graph.as_default():
+ importer.import_graph_def(graph_def)
+ assert graph.graph_def_versions.producer, producer_version
diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py
index 146bb4311c..61dc4e2afb 100644
--- a/tensorflow/python/grappler/cost_analyzer_tool.py
+++ b/tensorflow/python/grappler/cost_analyzer_tool.py
@@ -23,18 +23,33 @@ import sys
from google.protobuf import text_format
+from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
from tensorflow.python.grappler import cost_analyzer
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
+from tensorflow.python.training import saver
def main(_):
- with gfile.GFile(FLAGS.input) as input_file:
- metagraph = meta_graph_pb2.MetaGraphDef()
- metagraph.ParseFromString(input_file.read())
+ if FLAGS.metagraphdef:
+ with gfile.GFile(FLAGS.metagraphdef) as meta_file:
+ metagraph = meta_graph_pb2.MetaGraphDef()
+ metagraph.ParseFromString(meta_file.read())
+ else:
+ with gfile.GFile(FLAGS.graphdef) as graph_file:
+ graph_def = graph_pb2.GraphDef()
+ graph_def.ParseFromString(graph_file.read())
+ importer.import_graph_def(graph_def, name="")
+ graph = ops.get_default_graph()
+ fetch = graph.get_operation_by_name(FLAGS.fetch)
+ graph.add_to_collection("train_op", fetch)
+ metagraph = saver.export_meta_graph(
+ graph_def=graph.as_graph_def(), graph=graph)
if FLAGS.rewriter_config is not None:
rewriter_config = rewriter_config_pb2.RewriterConfig()
@@ -49,7 +64,25 @@ def main(_):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- "--input", type=str, default=None, help="Input .meta file path.")
+ "--metagraphdef",
+ type=str,
+ default=None,
+ help="Input .meta MetaGraphDef file path.")
+ parser.add_argument(
+ "--graphdef",
+ type=str,
+ default=None,
+ help="Input .pb GraphDef file path.")
+ # Consider making flag fetch work together with flag metagraphdef. As some
+ # MetaGraphDef files don't have collection train_op.
+ parser.add_argument(
+ "--fetch",
+ type=str,
+ default=None,
+ help=
+ "The name of the fetch node. This flag is ignored if flag "
+ "metagraphdef is used."
+ )
parser.add_argument(
"--rewriter_config",
type=str,
diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i
index f0dd4483a6..1b657983a4 100644
--- a/tensorflow/python/grappler/tf_optimizer.i
+++ b/tensorflow/python/grappler/tf_optimizer.i
@@ -103,6 +103,11 @@ PyObject* TF_OptimizeGraph(
std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
+ if (!grappler_item) {
+ TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Failed to import metagraph, check error log for more info.");
+ return nullptr;
+ }
+
tensorflow::DeviceBase* cpu_device = nullptr;
tensorflow::GraphDef out_graph;
tensorflow::grappler::MetaOptimizer optimizer(cpu_device, rewriter_config);
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 1f20b3ae0e..6125755775 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -14,10 +14,12 @@ py_library(
"_impl/keras/__init__.py",
"_impl/keras/activations.py",
"_impl/keras/applications/__init__.py",
+ "_impl/keras/applications/densenet.py",
"_impl/keras/applications/imagenet_utils.py",
"_impl/keras/applications/inception_resnet_v2.py",
"_impl/keras/applications/inception_v3.py",
"_impl/keras/applications/mobilenet.py",
+ "_impl/keras/applications/nasnet.py",
"_impl/keras/applications/resnet50.py",
"_impl/keras/applications/vgg16.py",
"_impl/keras/applications/vgg19.py",
@@ -76,9 +78,11 @@ py_library(
"_impl/keras/wrappers/scikit_learn.py",
"activations/__init__.py",
"applications/__init__.py",
+ "applications/densenet/__init__.py",
"applications/inception_resnet_v2/__init__.py",
"applications/inception_v3/__init__.py",
"applications/mobilenet/__init__.py",
+ "applications/nasnet/__init__.py",
"applications/resnet50/__init__.py",
"applications/vgg16/__init__.py",
"applications/vgg19/__init__.py",
@@ -257,6 +261,18 @@ py_test(
)
py_test(
+ name = "densenet_test",
+ size = "large",
+ srcs = ["_impl/keras/applications/densenet_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "inception_resnet_v2_test",
size = "medium",
srcs = ["_impl/keras/applications/inception_resnet_v2_test.py"],
@@ -293,6 +309,18 @@ py_test(
)
py_test(
+ name = "nasnet_test",
+ size = "large",
+ srcs = ["_impl/keras/applications/nasnet_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "resnet50_test",
size = "small",
srcs = ["_impl/keras/applications/resnet50_test.py"],
@@ -504,7 +532,7 @@ py_test(
py_test(
name = "recurrent_test",
- size = "small",
+ size = "medium",
srcs = ["_impl/keras/layers/recurrent_test.py"],
srcs_version = "PY2AND3",
deps = [
@@ -527,7 +555,7 @@ py_test(
py_test(
name = "wrappers_test",
- size = "small",
+ size = "medium",
srcs = ["_impl/keras/layers/wrappers_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"],
diff --git a/tensorflow/python/keras/_impl/keras/__init__.py b/tensorflow/python/keras/_impl/keras/__init__.py
index a70250d796..7311353932 100644
--- a/tensorflow/python/keras/_impl/keras/__init__.py
+++ b/tensorflow/python/keras/_impl/keras/__init__.py
@@ -40,4 +40,4 @@ from tensorflow.python.keras._impl.keras.layers import Input
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.models import Sequential
-__version__ = '2.1.2-tf'
+__version__ = '2.1.3-tf'
diff --git a/tensorflow/python/keras/_impl/keras/activations.py b/tensorflow/python/keras/_impl/keras/activations.py
index f017d2ae85..4852b8c36a 100644
--- a/tensorflow/python/keras/_impl/keras/activations.py
+++ b/tensorflow/python/keras/_impl/keras/activations.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Keras built-in activation functions.
+"""Built-in activation functions.
"""
from __future__ import absolute_import
from __future__ import division
@@ -61,10 +61,12 @@ def selu(x):
x: A tensor or variable to compute the activation function for.
Returns:
- Tensor with the same shape and dtype as `x`.
+ Tensor with the same shape and dtype as `x`.
+
+ # Note
+ - To be used together with the initialization "lecun_normal".
+ - To be used together with the dropout variant "AlphaDropout".
- References:
- - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
"""
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
diff --git a/tensorflow/python/keras/_impl/keras/applications/__init__.py b/tensorflow/python/keras/_impl/keras/applications/__init__.py
index c11c52b71e..206a769b37 100644
--- a/tensorflow/python/keras/_impl/keras/applications/__init__.py
+++ b/tensorflow/python/keras/_impl/keras/applications/__init__.py
@@ -18,9 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet121
+from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet169
+from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet201
from tensorflow.python.keras._impl.keras.applications.inception_resnet_v2 import InceptionResNetV2
from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3
from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet
+from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetLarge
+from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetMobile
from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50
from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16
from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19
diff --git a/tensorflow/python/keras/_impl/keras/applications/densenet.py b/tensorflow/python/keras/_impl/keras/applications/densenet.py
new file mode 100644
index 0000000000..9e40d34930
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/applications/densenet.py
@@ -0,0 +1,346 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=invalid-name
+# pylint: disable=unused-import
+"""DenseNet models for Keras.
+
+# Reference paper
+
+- [Densely Connected Convolutional Networks]
+ (https://arxiv.org/abs/1608.06993) (CVPR 2017 Best Paper Award)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.applications import imagenet_utils
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
+from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.layers import Activation
+from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import BatchNormalization
+from tensorflow.python.keras._impl.keras.layers import Concatenate
+from tensorflow.python.keras._impl.keras.layers import Conv2D
+from tensorflow.python.keras._impl.keras.layers import Dense
+from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import Input
+from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D
+from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+
+
+DENSENET121_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet121_weights_tf_dim_ordering_tf_kernels.h5'
+DENSENET121_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5'
+DENSENET169_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet169_weights_tf_dim_ordering_tf_kernels.h5'
+DENSENET169_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5'
+DENSENET201_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet201_weights_tf_dim_ordering_tf_kernels.h5'
+DENSENET201_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5'
+
+
+def dense_block(x, blocks, name):
+ """A dense block.
+
+ Arguments:
+ x: input tensor.
+ blocks: integer, the number of building blocks.
+ name: string, block label.
+
+ Returns:
+ output tensor for the block.
+ """
+ for i in range(blocks):
+ x = conv_block(x, 32, name=name + '_block' + str(i + 1))
+ return x
+
+
+def transition_block(x, reduction, name):
+ """A transition block.
+
+ Arguments:
+ x: input tensor.
+ reduction: float, compression rate at transition layers.
+ name: string, block label.
+
+ Returns:
+ output tensor for the block.
+ """
+ bn_axis = 3 if K.image_data_format() == 'channels_last' else 1
+ x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + '_bn')(x)
+ x = Activation('relu', name=name + '_relu')(x)
+ x = Conv2D(
+ int(K.int_shape(x)[bn_axis] * reduction),
+ 1,
+ use_bias=False,
+ name=name + '_conv')(
+ x)
+ x = AveragePooling2D(2, strides=2, name=name + '_pool')(x)
+ return x
+
+
+def conv_block(x, growth_rate, name):
+ """A building block for a dense block.
+
+ Arguments:
+ x: input tensor.
+ growth_rate: float, growth rate at dense layers.
+ name: string, block label.
+
+ Returns:
+ output tensor for the block.
+ """
+ bn_axis = 3 if K.image_data_format() == 'channels_last' else 1
+ x1 = BatchNormalization(
+ axis=bn_axis, epsilon=1.001e-5, name=name + '_0_bn')(
+ x)
+ x1 = Activation('relu', name=name + '_0_relu')(x1)
+ x1 = Conv2D(4 * growth_rate, 1, use_bias=False, name=name + '_1_conv')(x1)
+ x1 = BatchNormalization(
+ axis=bn_axis, epsilon=1.001e-5, name=name + '_1_bn')(
+ x1)
+ x1 = Activation('relu', name=name + '_1_relu')(x1)
+ x1 = Conv2D(
+ growth_rate, 3, padding='same', use_bias=False, name=name + '_2_conv')(
+ x1)
+ x = Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])
+ return x
+
+
+def DenseNet(blocks,
+ include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000):
+ """Instantiates the DenseNet architecture.
+
+ Optionally loads weights pre-trained
+ on ImageNet. Note that when using TensorFlow,
+ for best performance you should set
+ `image_data_format='channels_last'` in your Keras config
+ at ~/.keras/keras.json.
+
+ The model and the weights are compatible with
+ TensorFlow, Theano, and CNTK. The data format
+ convention used by the model is the one
+ specified in your Keras config file.
+
+ Arguments:
+ blocks: numbers of building blocks for the four dense layers.
+ include_top: whether to include the fully-connected
+ layer at the top of the network.
+ weights: one of `None` (random initialization),
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
+ input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
+ to use as image input for the model.
+ input_shape: optional shape tuple, only to be specified
+ if `include_top` is False (otherwise the input shape
+ has to be `(224, 224, 3)` (with `channels_last` data format)
+ or `(3, 224, 224)` (with `channels_first` data format).
+ It should have exactly 3 inputs channels.
+ pooling: optional pooling mode for feature extraction
+ when `include_top` is `False`.
+ - `None` means that the output of the model will be
+ the 4D tensor output of the
+ last convolutional layer.
+ - `avg` means that global average pooling
+ will be applied to the output of the
+ last convolutional layer, and thus
+ the output of the model will be a 2D tensor.
+ - `max` means that global max pooling will
+ be applied.
+ classes: optional number of classes to classify images
+ into, only to be specified if `include_top` is True, and
+ if no `weights` argument is specified.
+
+ Returns:
+ A Keras model instance.
+
+ Raises:
+ ValueError: in case of invalid argument for `weights`,
+ or invalid input shape.
+ """
+ if not (weights in {'imagenet', None} or os.path.exists(weights)):
+ raise ValueError('The `weights` argument should be either '
+ '`None` (random initialization), `imagenet` '
+ '(pre-training on ImageNet), '
+ 'or the path to the weights file to be loaded.')
+
+ if weights == 'imagenet' and include_top and classes != 1000:
+ raise ValueError('If using `weights` as imagenet with `include_top`'
+ ' as true, `classes` should be 1000')
+
+ # Determine proper input shape
+ input_shape = _obtain_input_shape(
+ input_shape,
+ default_size=224,
+ min_size=221,
+ data_format=K.image_data_format(),
+ require_flatten=include_top,
+ weights=weights)
+
+ if input_tensor is None:
+ img_input = Input(shape=input_shape)
+ else:
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
+
+ bn_axis = 3 if K.image_data_format() == 'channels_last' else 1
+
+ x = ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)
+ x = Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x)
+ x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(x)
+ x = Activation('relu', name='conv1/relu')(x)
+ x = ZeroPadding2D(padding=((1, 1), (1, 1)))(x)
+ x = MaxPooling2D(3, strides=2, name='pool1')(x)
+
+ x = dense_block(x, blocks[0], name='conv2')
+ x = transition_block(x, 0.5, name='pool2')
+ x = dense_block(x, blocks[1], name='conv3')
+ x = transition_block(x, 0.5, name='pool3')
+ x = dense_block(x, blocks[2], name='conv4')
+ x = transition_block(x, 0.5, name='pool4')
+ x = dense_block(x, blocks[3], name='conv5')
+
+ x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x)
+
+ if include_top:
+ x = GlobalAveragePooling2D(name='avg_pool')(x)
+ x = Dense(classes, activation='softmax', name='fc1000')(x)
+ else:
+ if pooling == 'avg':
+ x = GlobalAveragePooling2D(name='avg_pool')(x)
+ elif pooling == 'max':
+ x = GlobalMaxPooling2D(name='max_pool')(x)
+
+ # Ensure that the model takes into account
+ # any potential predecessors of `input_tensor`.
+ if input_tensor is not None:
+ inputs = get_source_inputs(input_tensor)
+ else:
+ inputs = img_input
+
+ # Create model.
+ if blocks == [6, 12, 24, 16]:
+ model = Model(inputs, x, name='densenet121')
+ elif blocks == [6, 12, 32, 32]:
+ model = Model(inputs, x, name='densenet169')
+ elif blocks == [6, 12, 48, 32]:
+ model = Model(inputs, x, name='densenet201')
+ else:
+ model = Model(inputs, x, name='densenet')
+
+ # Load weights.
+ if weights == 'imagenet':
+ if include_top:
+ if blocks == [6, 12, 24, 16]:
+ weights_path = get_file(
+ 'densenet121_weights_tf_dim_ordering_tf_kernels.h5',
+ DENSENET121_WEIGHT_PATH,
+ cache_subdir='models',
+ file_hash='0962ca643bae20f9b6771cb844dca3b0')
+ elif blocks == [6, 12, 32, 32]:
+ weights_path = get_file(
+ 'densenet169_weights_tf_dim_ordering_tf_kernels.h5',
+ DENSENET169_WEIGHT_PATH,
+ cache_subdir='models',
+ file_hash='bcf9965cf5064a5f9eb6d7dc69386f43')
+ elif blocks == [6, 12, 48, 32]:
+ weights_path = get_file(
+ 'densenet201_weights_tf_dim_ordering_tf_kernels.h5',
+ DENSENET201_WEIGHT_PATH,
+ cache_subdir='models',
+ file_hash='7bb75edd58cb43163be7e0005fbe95ef')
+ else:
+ if blocks == [6, 12, 24, 16]:
+ weights_path = get_file(
+ 'densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5',
+ DENSENET121_WEIGHT_PATH_NO_TOP,
+ cache_subdir='models',
+ file_hash='4912a53fbd2a69346e7f2c0b5ec8c6d3')
+ elif blocks == [6, 12, 32, 32]:
+ weights_path = get_file(
+ 'densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5',
+ DENSENET169_WEIGHT_PATH_NO_TOP,
+ cache_subdir='models',
+ file_hash='50662582284e4cf834ce40ab4dfa58c6')
+ elif blocks == [6, 12, 48, 32]:
+ weights_path = get_file(
+ 'densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5',
+ DENSENET201_WEIGHT_PATH_NO_TOP,
+ cache_subdir='models',
+ file_hash='1c2de60ee40562448dbac34a0737e798')
+ model.load_weights(weights_path)
+ elif weights is not None:
+ model.load_weights(weights)
+
+ return model
+
+
+def DenseNet121(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000):
+ return DenseNet([6, 12, 24, 16], include_top, weights, input_tensor,
+ input_shape, pooling, classes)
+
+
+def DenseNet169(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000):
+ return DenseNet([6, 12, 32, 32], include_top, weights, input_tensor,
+ input_shape, pooling, classes)
+
+
+def DenseNet201(include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ input_shape=None,
+ pooling=None,
+ classes=1000):
+ return DenseNet([6, 12, 48, 32], include_top, weights, input_tensor,
+ input_shape, pooling, classes)
+
+
+def preprocess_input(x, data_format=None):
+ """Preprocesses a numpy array encoding a batch of images.
+
+ Arguments:
+ x: a 3D or 4D numpy array consists of RGB values within [0, 255].
+ data_format: data format of the image tensor.
+
+ Returns:
+ Preprocessed array.
+ """
+ return imagenet_utils.preprocess_input(x, data_format, mode='torch')
+
+
+setattr(DenseNet121, '__doc__', DenseNet.__doc__)
+setattr(DenseNet169, '__doc__', DenseNet.__doc__)
+setattr(DenseNet201, '__doc__', DenseNet.__doc__)
diff --git a/tensorflow/python/keras/_impl/keras/applications/densenet_test.py b/tensorflow/python/keras/_impl/keras/applications/densenet_test.py
new file mode 100644
index 0000000000..3b92287a1e
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/applications/densenet_test.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for DenseNet application."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.platform import test
+
+
+class DenseNet121Test(test.TestCase):
+
+ def test_with_top(self):
+ model = keras.applications.DenseNet121(weights=None)
+ self.assertEqual(model.output_shape, (None, 1000))
+
+ def test_no_top(self):
+ model = keras.applications.DenseNet121(weights=None, include_top=False)
+ self.assertEqual(model.output_shape, (None, None, None, 1024))
+
+ def test_with_pooling(self):
+ model = keras.applications.DenseNet121(weights=None,
+ include_top=False,
+ pooling='avg')
+ self.assertEqual(model.output_shape, (None, 1024))
+
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.DenseNet121(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.DenseNet121(weights='imagenet',
+ classes=2000)
+
+
+class DenseNet169Test(test.TestCase):
+
+ def test_with_top(self):
+ model = keras.applications.DenseNet169(weights=None)
+ self.assertEqual(model.output_shape, (None, 1000))
+
+ def test_no_top(self):
+ model = keras.applications.DenseNet169(weights=None, include_top=False)
+ self.assertEqual(model.output_shape, (None, None, None, 1664))
+
+ def test_with_pooling(self):
+ model = keras.applications.DenseNet169(weights=None,
+ include_top=False,
+ pooling='max')
+ self.assertEqual(model.output_shape, (None, 1664))
+
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.DenseNet169(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.DenseNet169(weights='imagenet',
+ classes=2000)
+
+
+class DenseNet201(test.TestCase):
+
+ def test_with_top(self):
+ model = keras.applications.DenseNet201(weights=None)
+ self.assertEqual(model.output_shape, (None, 1000))
+
+ def test_no_top(self):
+ model = keras.applications.DenseNet201(weights=None, include_top=False)
+ self.assertEqual(model.output_shape, (None, None, None, 1920))
+
+ def test_with_pooling(self):
+ model = keras.applications.DenseNet201(weights=None,
+ include_top=False,
+ pooling='avg')
+ self.assertEqual(model.output_shape, (None, 1920))
+
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.DenseNet201(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.DenseNet201(weights='imagenet',
+ classes=2000)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py
index 63ee83cb51..f1f20f12a8 100644
--- a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py
+++ b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Utilities used by models pre-trained on ImageNet.
+"""Utilities for ImageNet data preprocessing & prediction decoding.
"""
from __future__ import absolute_import
from __future__ import division
@@ -35,63 +35,92 @@ _IMAGENET_MEAN = None
def _preprocess_numpy_input(x, data_format, mode):
- """Preprocesses a image tensor as a Numpy array.
+ """Preprocesses a Numpy array encoding a batch of images.
Arguments:
- x: input Numpy, 3D or 4D.
- data_format: data format of the image tensor.
- mode: One of "caffe", "tf".
+ x: Input array, 3D or 4D.
+ data_format: Data format of the image array.
+ mode: One of "caffe", "tf" or "torch".
- caffe: will convert the images from RGB to BGR,
then will zero-center each color channel with
respect to the ImageNet dataset,
without scaling.
- tf: will scale pixels between -1 and 1,
sample-wise.
+ - torch: will scale pixels between 0 and 1 and then
+ will normalize each channel with respect to the
+ ImageNet dataset.
Returns:
- Preprocessed array.
+ Preprocessed Numpy array.
"""
if mode == 'tf':
x /= 127.5
x -= 1.
return x
+ if mode == 'torch':
+ x /= 255.
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ else:
+ if data_format == 'channels_first':
+ # 'RGB'->'BGR'
+ if x.ndim == 3:
+ x = x[::-1, ...]
+ else:
+ x = x[:, ::-1, ...]
+ else:
+ # 'RGB'->'BGR'
+ x = x[..., ::-1]
+ mean = [103.939, 116.779, 123.68]
+ std = None
+
+ # Zero-center by mean pixel
if data_format == 'channels_first':
if x.ndim == 3:
- # 'RGB'->'BGR'
- x = x[::-1, ...]
- # Zero-center by mean pixel
- x[0, :, :] -= 103.939
- x[1, :, :] -= 116.779
- x[2, :, :] -= 123.68
+ x[0, :, :] -= mean[0]
+ x[1, :, :] -= mean[1]
+ x[2, :, :] -= mean[2]
+ if std is not None:
+ x[0, :, :] /= std[0]
+ x[1, :, :] /= std[1]
+ x[2, :, :] /= std[2]
else:
- x = x[:, ::-1, ...]
- x[:, 0, :, :] -= 103.939
- x[:, 1, :, :] -= 116.779
- x[:, 2, :, :] -= 123.68
+ x[:, 0, :, :] -= mean[0]
+ x[:, 1, :, :] -= mean[1]
+ x[:, 2, :, :] -= mean[2]
+ if std is not None:
+ x[:, 0, :, :] /= std[0]
+ x[:, 1, :, :] /= std[1]
+ x[:, 2, :, :] /= std[2]
else:
- # 'RGB'->'BGR'
- x = x[..., ::-1]
- # Zero-center by mean pixel
- x[..., 0] -= 103.939
- x[..., 1] -= 116.779
- x[..., 2] -= 123.68
+ x[..., 0] -= mean[0]
+ x[..., 1] -= mean[1]
+ x[..., 2] -= mean[2]
+ if std is not None:
+ x[..., 0] /= std[0]
+ x[..., 1] /= std[1]
+ x[..., 2] /= std[2]
return x
def _preprocess_symbolic_input(x, data_format, mode):
- """Preprocesses a symbolic image tensor.
+ """Preprocesses a tensor encoding a batch of images.
Arguments:
- x: symoblic tensor, 3D or 4D.
- data_format: data format of the image tensor.
- mode: One of "caffe", "tf".
+ x: Input tensor, 3D or 4D.
+ data_format: Data format of the image tensor.
+ mode: One of "caffe", "tf" or "torch".
- caffe: will convert the images from RGB to BGR,
then will zero-center each color channel with
respect to the ImageNet dataset,
without scaling.
- tf: will scale pixels between -1 and 1,
sample-wise.
+ - torch: will scale pixels between 0 and 1 and then
+ will normalize each channel with respect to the
+ ImageNet dataset.
Returns:
Preprocessed tensor.
@@ -103,32 +132,42 @@ def _preprocess_symbolic_input(x, data_format, mode):
x -= 1.
return x
- if data_format == 'channels_first':
- # 'RGB'->'BGR'
- if K.ndim(x) == 3:
- x = x[::-1, ...]
- else:
- x = x[:, ::-1, ...]
+ if mode == 'torch':
+ x /= 255.
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
else:
- # 'RGB'->'BGR'
- x = x[..., ::-1]
+ if data_format == 'channels_first':
+ # 'RGB'->'BGR'
+ if K.ndim(x) == 3:
+ x = x[::-1, ...]
+ else:
+ x = x[:, ::-1, ...]
+ else:
+ # 'RGB'->'BGR'
+ x = x[..., ::-1]
+ mean = [103.939, 116.779, 123.68]
+ std = None
if _IMAGENET_MEAN is None:
- _IMAGENET_MEAN = K.constant(-np.array([103.939, 116.779, 123.68]))
+ _IMAGENET_MEAN = K.constant(-np.array(mean))
+
# Zero-center by mean pixel
if K.dtype(x) != K.dtype(_IMAGENET_MEAN):
x = K.bias_add(x, K.cast(_IMAGENET_MEAN, K.dtype(x)), data_format)
else:
x = K.bias_add(x, _IMAGENET_MEAN, data_format)
+ if std is not None:
+ x /= std
return x
def preprocess_input(x, data_format=None, mode='caffe'):
- """Preprocesses a tensor encoding a batch of images.
+ """Preprocesses a tensor or Numpy array encoding a batch of images.
Arguments:
- x: input Numpy or symoblic tensor, 3D or 4D.
- data_format: data format of the image tensor.
+ x: Input Numpy or symbolic tensor, 3D or 4D.
+ data_format: Data format of the image tensor/array.
mode: One of "caffe", "tf".
- caffe: will convert the images from RGB to BGR,
then will zero-center each color channel with
@@ -138,10 +177,10 @@ def preprocess_input(x, data_format=None, mode='caffe'):
sample-wise.
Returns:
- Preprocessed tensor.
+ Preprocessed tensor or Numpy array.
Raises:
- ValueError: in case of incorrect data_format.
+ ValueError: In case of unknown `data_format` argument.
"""
if data_format is None:
data_format = K.image_data_format()
@@ -159,7 +198,7 @@ def decode_predictions(preds, top=5):
Arguments:
preds: Numpy tensor encoding a batch of predictions.
- top: integer, how many top-guesses to return.
+ top: Integer, how many top-guesses to return.
Returns:
A list of lists of top class prediction tuples
@@ -167,7 +206,7 @@ def decode_predictions(preds, top=5):
One list of tuples per sample in batch input.
Raises:
- ValueError: in case of invalid shape of the `pred` array
+ ValueError: In case of invalid shape of the `pred` array
(must be 2D).
"""
global CLASS_INDEX
@@ -177,10 +216,11 @@ def decode_predictions(preds, top=5):
'(i.e. a 2D array of shape (samples, 1000)). '
'Found array with shape: ' + str(preds.shape))
if CLASS_INDEX is None:
- fpath = get_file('imagenet_class_index.json',
- CLASS_INDEX_PATH,
- cache_subdir='models',
- file_hash='c2c37ea517e94d9795004a39431a14cb')
+ fpath = get_file(
+ 'imagenet_class_index.json',
+ CLASS_INDEX_PATH,
+ cache_subdir='models',
+ file_hash='c2c37ea517e94d9795004a39431a14cb')
CLASS_INDEX = json.load(open(fpath))
results = []
for pred in preds:
@@ -197,17 +237,17 @@ def _obtain_input_shape(input_shape,
data_format,
require_flatten,
weights=None):
- """Internal utility to compute/validate an ImageNet model's input shape.
+ """Internal utility to compute/validate a model's input shape.
Arguments:
- input_shape: either None (will return the default network input shape),
+ input_shape: Either None (will return the default network input shape),
or a user-provided shape to be validated.
- default_size: default input width/height for the model.
- min_size: minimum input width/height accepted by the model.
- data_format: image data format to use.
- require_flatten: whether the model is expected to
+ default_size: Default input width/height for the model.
+ min_size: Minimum input width/height accepted by the model.
+ data_format: Image data format to use.
+ require_flatten: Whether the model is expected to
be linked to a classifier via a Flatten layer.
- weights: one of `None` (random initialization)
+ weights: One of `None` (random initialization)
or 'imagenet' (pre-training on ImageNet).
If weights='imagenet' input channels must be equal to 3.
@@ -215,7 +255,7 @@ def _obtain_input_shape(input_shape,
An integer shape tuple (may include None entries).
Raises:
- ValueError: in case of invalid argument values.
+ ValueError: In case of invalid argument values.
"""
if weights != 'imagenet' and input_shape and len(input_shape) == 3:
if data_format == 'channels_first':
@@ -252,8 +292,8 @@ def _obtain_input_shape(input_shape,
'`input_shape=' + str(input_shape) + '`')
if ((input_shape[1] is not None and input_shape[1] < min_size) or
(input_shape[2] is not None and input_shape[2] < min_size)):
- raise ValueError('Input size must be at least ' + str(min_size) + 'x'
- + str(min_size) + '; got '
+ raise ValueError('Input size must be at least ' + str(min_size) +
+ 'x' + str(min_size) + '; got '
'`input_shape=' + str(input_shape) + '`')
else:
if input_shape is not None:
@@ -264,8 +304,8 @@ def _obtain_input_shape(input_shape,
'`input_shape=' + str(input_shape) + '`')
if ((input_shape[0] is not None and input_shape[0] < min_size) or
(input_shape[1] is not None and input_shape[1] < min_size)):
- raise ValueError('Input size must be at least ' + str(min_size) + 'x'
- + str(min_size) + '; got '
+ raise ValueError('Input size must be at least ' + str(min_size) +
+ 'x' + str(min_size) + '; got '
'`input_shape=' + str(input_shape) + '`')
else:
if require_flatten:
diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py
index 2e73cefb6c..1dc15b5b34 100644
--- a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py
+++ b/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""Inception-ResNet V2 model for Keras.
# Reference
@@ -28,7 +30,7 @@ import os
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
@@ -43,6 +45,8 @@ from tensorflow.python.keras._impl.keras.layers import Lambda
from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
+
BASE_WEIGHT_URL = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.7/'
@@ -116,7 +120,8 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
scale: scaling factor to scale the residuals (i.e., the output of
passing `x` through an inception module) before adding them
to the shortcut branch. Let `r` be the output from the residual
- branch, the output of this block will be `x + scale * r`.
+ branch,
+ the output of this block will be `x + scale * r`.
block_type: `'block35'`, `'block17'` or `'block8'`, determines
the network structure in the residual branch.
block_idx: an `int` used for generating layer names. The Inception-ResNet
@@ -128,8 +133,7 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
will have `block_type='block35', block_idx=0`, ane the layer names
will have
a common prefix `'block35_0'`.
- activation: activation function to use at the end of the block
- (see [activations](../activations.md)).
+ activation: activation function to use at the end of the block.
When `activation=None`, no activation is applied
(i.e., "linear" activation: `a(x) = x`).
@@ -178,6 +182,7 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
x = Lambda(
lambda inputs, scale: inputs[0] + inputs[1] * scale,
+ output_shape=K.int_shape(x)[1:],
arguments={'scale': scale},
name=block_name)([x, up])
if activation is not None:
@@ -185,7 +190,7 @@ def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
return x
-def InceptionResNetV2(include_top=True, # pylint: disable=invalid-name
+def InceptionResNetV2(include_top=True,
weights='imagenet',
input_tensor=None,
input_shape=None,
@@ -211,8 +216,8 @@ def InceptionResNetV2(include_top=True, # pylint: disable=invalid-name
include_top: whether to include the fully-connected
layer at the top of the network.
weights: one of `None` (random initialization),
- 'imagenet' (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py
index 4424b92804..ff57116f2d 100644
--- a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py
+++ b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""Inception V3 model for Keras.
Note that the input image format for this model is different than for
@@ -35,7 +36,7 @@ from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import layers
from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
@@ -48,6 +49,7 @@ from tensorflow.python.keras._impl.keras.layers import Input
from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.5/inception_v3_weights_tf_dim_ordering_tf_kernels.h5'
@@ -92,7 +94,8 @@ def conv2d_bn(x,
strides=strides,
padding=padding,
use_bias=False,
- name=conv_name)(x)
+ name=conv_name)(
+ x)
x = BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x)
x = Activation('relu', name=name)(x)
return x
@@ -109,7 +112,7 @@ def InceptionV3(include_top=True,
Optionally loads weights pre-trained
on ImageNet. Note that when using TensorFlow,
for best performance you should set
- `image_data_format="channels_last"` in your Keras config
+ `image_data_format='channels_last'` in your Keras config
at ~/.keras/keras.json.
The model and the weights are compatible with both
TensorFlow and Theano. The data format
@@ -121,15 +124,15 @@ def InceptionV3(include_top=True,
include_top: whether to include the fully-connected
layer at the top of the network.
weights: one of `None` (random initialization),
- "imagenet" (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
has to be `(299, 299, 3)` (with `channels_last` data format)
or `(3, 299, 299)` (with `channels_first` data format).
- It should have exactly 3 input channels,
+ It should have exactly 3 inputs channels,
and width and height should be no smaller than 139.
E.g. `(150, 150, 3)` would be one valid value.
pooling: Optional pooling mode for feature extraction
@@ -176,7 +179,10 @@ def InceptionV3(include_top=True,
if input_tensor is None:
img_input = Input(shape=input_shape)
else:
- img_input = Input(tensor=input_tensor, shape=input_shape)
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
if K.image_data_format() == 'channels_first':
channel_axis = 1
@@ -389,6 +395,7 @@ def InceptionV3(include_top=True,
model.load_weights(weights_path)
elif weights is not None:
model.load_weights(weights)
+
return model
diff --git a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
index 5f97c138fc..790bf8cead 100644
--- a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""MobileNet v1 models for Keras.
MobileNet is a general architecture and can be used for multiple use cases.
@@ -56,7 +58,7 @@ the 100 % MobileNet on various input sizes:
------------------------------------------------------------------------
The weights for all 16 models are obtained and translated
-from Tensorflow checkpoints found at
+from TensorFlow checkpoints found at
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md
# Reference
@@ -75,9 +77,10 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import BatchNormalization
from tensorflow.python.keras._impl.keras.layers import Conv2D
@@ -91,6 +94,7 @@ from tensorflow.python.keras._impl.keras.utils import conv_utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
from tensorflow.python.platform import tf_logging as logging
+
BASE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.6/'
@@ -130,7 +134,7 @@ class DepthwiseConv2D(Conv2D):
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
- padding: one of `"valid"` or `"same"` (case-insensitive).
+ padding: one of `'valid'` or `'same'` (case-insensitive).
depth_multiplier: The number of depthwise convolution output channels
for each input channel.
The total number of depthwise convolution output
@@ -144,29 +148,21 @@ class DepthwiseConv2D(Conv2D):
`(batch, channels, height, width)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
- If you never set it, then it will be "channels_last".
- activation: Activation function to use
- (see [activations](../activations.md)).
+ If you never set it, then it will be 'channels_last'.
+ activation: Activation function to use.
If you don't specify anything, no activation is applied
- (ie. "linear" activation: `a(x) = x`).
+ (ie. 'linear' activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
- depthwise_initializer: Initializer for the depthwise kernel matrix
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ depthwise_initializer: Initializer for the depthwise kernel matrix.
+ bias_initializer: Initializer for the bias vector.
depthwise_regularizer: Regularizer function applied to
- the depthwise kernel matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the depthwise kernel matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation").
- (see [regularizer](../regularizers.md)).
+ the output of the layer (its 'activation')..
depthwise_constraint: Constraint function applied to
- the depthwise kernel matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the depthwise kernel matrix.
+ bias_constraint: Constraint function applied to the bias vector.
Input shape:
4D tensor with shape:
@@ -216,6 +212,7 @@ class DepthwiseConv2D(Conv2D):
self.depthwise_constraint = constraints.get(depthwise_constraint)
self.bias_initializer = initializers.get(bias_initializer)
+ @shape_type_conversion
def build(self, input_shape):
if len(input_shape) < 4:
raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. '
@@ -269,6 +266,7 @@ class DepthwiseConv2D(Conv2D):
return outputs
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if self.data_format == 'channels_first':
rows = input_shape[2]
@@ -305,7 +303,7 @@ class DepthwiseConv2D(Conv2D):
return config
-def MobileNet(input_shape=None, # pylint: disable=invalid-name
+def MobileNet(input_shape=None,
alpha=1.0,
depth_multiplier=1,
dropout=1e-3,
@@ -334,7 +332,7 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
if `include_top` is False (otherwise the input shape
has to be `(224, 224, 3)` (with `channels_last` data format)
or (3, 224, 224) (with `channels_first` data format).
- It should have exactly 3 input channels,
+ It should have exactly 3 inputs channels,
and width and height should be no smaller than 32.
E.g. `(200, 200, 3)` would be one valid value.
alpha: controls the width of the network.
@@ -350,8 +348,8 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
include_top: whether to include the fully-connected
layer at the top of the network.
weights: one of `None` (random initialization),
- 'imagenet' (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of
`layers.Input()`)
to use as image input for the model.
@@ -380,6 +378,12 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
RuntimeError: If attempting to run this model with a
backend that does not support separable convolutions.
"""
+
+ if K.backend() != 'tensorflow':
+ raise RuntimeError('Only TensorFlow backend is currently supported, '
+ 'as other backends do not support '
+ 'depthwise convolution.')
+
if not (weights in {'imagenet', None} or os.path.exists(weights)):
raise ValueError('The `weights` argument should be either '
'`None` (random initialization), `imagenet` '
@@ -390,7 +394,7 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
raise ValueError('If using `weights` as ImageNet with `include_top` '
'as true, `classes` should be 1000')
- # Determine proper input shape.
+ # Determine proper input shape and default size.
if input_shape is None:
default_size = 224
else:
@@ -400,10 +404,12 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
else:
rows = input_shape[0]
cols = input_shape[1]
+
if rows == cols and rows in [128, 160, 192, 224]:
default_size = rows
else:
default_size = 224
+
input_shape = _obtain_input_shape(
input_shape,
default_size=default_size,
@@ -411,6 +417,7 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
data_format=K.image_data_format(),
require_flatten=include_top,
weights=weights)
+
if K.image_data_format() == 'channels_last':
row_axis, col_axis = (0, 1)
else:
@@ -536,8 +543,6 @@ def MobileNet(input_shape=None, # pylint: disable=invalid-name
if old_data_format:
K.set_image_data_format(old_data_format)
- elif weights is not None:
- model.load_weights(weights)
return model
@@ -595,7 +600,8 @@ def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
padding='same',
use_bias=False,
strides=strides,
- name='conv1')(inputs)
+ name='conv1')(
+ inputs)
x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x)
return Activation(relu6, name='conv1_relu')(x)
@@ -662,7 +668,8 @@ def _depthwise_conv_block(inputs,
depth_multiplier=depth_multiplier,
strides=strides,
use_bias=False,
- name='conv_dw_%d' % block_id)(inputs)
+ name='conv_dw_%d' % block_id)(
+ inputs)
x = BatchNormalization(axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x)
x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)
@@ -671,6 +678,7 @@ def _depthwise_conv_block(inputs,
padding='same',
use_bias=False,
strides=(1, 1),
- name='conv_pw_%d' % block_id)(x)
+ name='conv_pw_%d' % block_id)(
+ x)
x = BatchNormalization(axis=channel_axis, name='conv_pw_%d_bn' % block_id)(x)
return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x)
diff --git a/tensorflow/python/keras/_impl/keras/applications/nasnet.py b/tensorflow/python/keras/_impl/keras/applications/nasnet.py
new file mode 100644
index 0000000000..5dd038c096
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/applications/nasnet.py
@@ -0,0 +1,783 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=line-too-long
+# pylint: disable=invalid-name
+# pylint: disable=unused-import
+"""NASNet-A models for Keras.
+
+NASNet refers to Neural Architecture Search Network, a family of models
+that were designed automatically by learning the model architectures
+directly on the dataset of interest.
+
+Here we consider NASNet-A, the highest performance model that was found
+for the CIFAR-10 dataset, and then extended to ImageNet 2012 dataset,
+obtaining state of the art performance on CIFAR-10 and ImageNet 2012.
+Only the NASNet-A models, and their respective weights, which are suited
+for ImageNet 2012 are provided.
+
+The below table describes the performance on ImageNet 2012:
+--------------------------------------------------------------------------------
+ Architecture | Top-1 Acc | Top-5 Acc | Multiply-Adds | Params (M)
+--------------------------------------------------------------------------------
+| NASNet-A (4 @ 1056) | 74.0 % | 91.6 % | 564 M | 5.3 |
+| NASNet-A (6 @ 4032) | 82.7 % | 96.2 % | 23.8 B | 88.9 |
+--------------------------------------------------------------------------------
+
+References:
+ - [Learning Transferable Architectures for Scalable Image Recognition]
+ (https://arxiv.org/abs/1707.07012)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input
+from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.layers import Activation
+from tensorflow.python.keras._impl.keras.layers import add
+from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import BatchNormalization
+from tensorflow.python.keras._impl.keras.layers import concatenate
+from tensorflow.python.keras._impl.keras.layers import Conv2D
+from tensorflow.python.keras._impl.keras.layers import Cropping2D
+from tensorflow.python.keras._impl.keras.layers import Dense
+from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import Input
+from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import SeparableConv2D
+from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D
+from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
+
+
+NASNET_MOBILE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-mobile.h5'
+NASNET_MOBILE_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-mobile-no-top.h5'
+NASNET_LARGE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-large.h5'
+NASNET_LARGE_WEIGHT_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.8/NASNet-large-no-top.h5'
+
+
+def NASNet(input_shape=None,
+ penultimate_filters=4032,
+ num_blocks=6,
+ stem_block_filters=96,
+ skip_reduction=True,
+ filter_multiplier=2,
+ include_top=True,
+ weights=None,
+ input_tensor=None,
+ pooling=None,
+ classes=1000,
+ default_size=None):
+ """Instantiates a NASNet model.
+
+ Note that only TensorFlow is supported for now,
+ therefore it only works with the data format
+ `image_data_format='channels_last'` in your Keras config
+ at `~/.keras/keras.json`.
+
+ Arguments:
+ input_shape: Optional shape tuple, only to be specified
+ if `include_top` is False (otherwise the input shape
+ has to be `(331, 331, 3)` for NASNetLarge or
+ `(224, 224, 3)` for NASNetMobile
+ It should have exactly 3 inputs channels,
+ and width and height should be no smaller than 32.
+ E.g. `(224, 224, 3)` would be one valid value.
+ penultimate_filters: Number of filters in the penultimate layer.
+ NASNet models use the notation `NASNet (N @ P)`, where:
+ - N is the number of blocks
+ - P is the number of penultimate filters
+ num_blocks: Number of repeated blocks of the NASNet model.
+ NASNet models use the notation `NASNet (N @ P)`, where:
+ - N is the number of blocks
+ - P is the number of penultimate filters
+ stem_block_filters: Number of filters in the initial stem block
+ skip_reduction: Whether to skip the reduction step at the tail
+ end of the network. Set to `False` for CIFAR models.
+ filter_multiplier: Controls the width of the network.
+ - If `filter_multiplier` < 1.0, proportionally decreases the number
+ of filters in each layer.
+ - If `filter_multiplier` > 1.0, proportionally increases the number
+ of filters in each layer.
+ - If `filter_multiplier` = 1, default number of filters from the
+ paper are used at each layer.
+ include_top: Whether to include the fully-connected
+ layer at the top of the network.
+ weights: `None` (random initialization) or
+ `imagenet` (ImageNet weights)
+ input_tensor: Optional Keras tensor (i.e. output of
+ `layers.Input()`)
+ to use as image input for the model.
+ pooling: Optional pooling mode for feature extraction
+ when `include_top` is `False`.
+ - `None` means that the output of the model
+ will be the 4D tensor output of the
+ last convolutional layer.
+ - `avg` means that global average pooling
+ will be applied to the output of the
+ last convolutional layer, and thus
+ the output of the model will be a
+ 2D tensor.
+ - `max` means that global max pooling will
+ be applied.
+ classes: Optional number of classes to classify images
+ into, only to be specified if `include_top` is True, and
+ if no `weights` argument is specified.
+ default_size: Specifies the default image size of the model
+
+ Returns:
+ A Keras model instance.
+
+ Raises:
+ ValueError: In case of invalid argument for `weights`,
+ invalid input shape or invalid `penultimate_filters` value.
+ RuntimeError: If attempting to run this model with a
+ backend that does not support separable convolutions.
+ """
+ if K.backend() != 'tensorflow':
+ raise RuntimeError('Only Tensorflow backend is currently supported, '
+ 'as other backends do not support '
+ 'separable convolution.')
+
+ if not (weights in {'imagenet', None} or os.path.exists(weights)):
+ raise ValueError('The `weights` argument should be either '
+ '`None` (random initialization), `imagenet` '
+ '(pre-training on ImageNet), '
+ 'or the path to the weights file to be loaded.')
+
+ if weights == 'imagenet' and include_top and classes != 1000:
+ raise ValueError('If using `weights` as ImageNet with `include_top` '
+ 'as true, `classes` should be 1000')
+
+ if default_size is None:
+ default_size = 331
+
+ # Determine proper input shape and default size.
+ input_shape = _obtain_input_shape(
+ input_shape,
+ default_size=default_size,
+ min_size=32,
+ data_format=K.image_data_format(),
+ require_flatten=include_top or weights,
+ weights=weights)
+
+ if K.image_data_format() != 'channels_last':
+ logging.warning('The NASNet family of models is only available '
+ 'for the input data format "channels_last" '
+ '(width, height, channels). '
+ 'However your settings specify the default '
+ 'data format "channels_first" (channels, width, height).'
+ ' You should set `image_data_format="channels_last"` '
+ 'in your Keras config located at ~/.keras/keras.json. '
+ 'The model being returned right now will expect inputs '
+ 'to follow the "channels_last" data format.')
+ K.set_image_data_format('channels_last')
+ old_data_format = 'channels_first'
+ else:
+ old_data_format = None
+
+ if input_tensor is None:
+ img_input = Input(shape=input_shape)
+ else:
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
+
+ if penultimate_filters % 24 != 0:
+ raise ValueError(
+ 'For NASNet-A models, the value of `penultimate_filters` '
+ 'needs to be divisible by 24. Current value: %d' % penultimate_filters)
+
+ channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
+ filters = penultimate_filters // 24
+
+ if not skip_reduction:
+ x = Conv2D(
+ stem_block_filters, (3, 3),
+ strides=(2, 2),
+ padding='valid',
+ use_bias=False,
+ name='stem_conv1',
+ kernel_initializer='he_normal')(
+ img_input)
+ else:
+ x = Conv2D(
+ stem_block_filters, (3, 3),
+ strides=(1, 1),
+ padding='same',
+ use_bias=False,
+ name='stem_conv1',
+ kernel_initializer='he_normal')(
+ img_input)
+
+ x = BatchNormalization(
+ axis=channel_dim, momentum=0.9997, epsilon=1e-3, name='stem_bn1')(
+ x)
+
+ p = None
+ if not skip_reduction: # imagenet / mobile mode
+ x, p = _reduction_a_cell(
+ x, p, filters // (filter_multiplier**2), block_id='stem_1')
+ x, p = _reduction_a_cell(
+ x, p, filters // filter_multiplier, block_id='stem_2')
+
+ for i in range(num_blocks):
+ x, p = _normal_a_cell(x, p, filters, block_id='%d' % (i))
+
+ x, p0 = _reduction_a_cell(
+ x, p, filters * filter_multiplier, block_id='reduce_%d' % (num_blocks))
+
+ p = p0 if not skip_reduction else p
+
+ for i in range(num_blocks):
+ x, p = _normal_a_cell(
+ x, p, filters * filter_multiplier, block_id='%d' % (num_blocks + i + 1))
+
+ x, p0 = _reduction_a_cell(
+ x,
+ p,
+ filters * filter_multiplier**2,
+ block_id='reduce_%d' % (2 * num_blocks))
+
+ p = p0 if not skip_reduction else p
+
+ for i in range(num_blocks):
+ x, p = _normal_a_cell(
+ x,
+ p,
+ filters * filter_multiplier**2,
+ block_id='%d' % (2 * num_blocks + i + 1))
+
+ x = Activation('relu')(x)
+
+ if include_top:
+ x = GlobalAveragePooling2D()(x)
+ x = Dense(classes, activation='softmax', name='predictions')(x)
+ else:
+ if pooling == 'avg':
+ x = GlobalAveragePooling2D()(x)
+ elif pooling == 'max':
+ x = GlobalMaxPooling2D()(x)
+
+ # Ensure that the model takes into account
+ # any potential predecessors of `input_tensor`.
+ if input_tensor is not None:
+ inputs = get_source_inputs(input_tensor)
+ else:
+ inputs = img_input
+
+ model = Model(inputs, x, name='NASNet')
+
+ # load weights
+ if weights == 'imagenet':
+ if default_size == 224: # mobile version
+ if include_top:
+ weight_path = NASNET_MOBILE_WEIGHT_PATH
+ model_name = 'nasnet_mobile.h5'
+ else:
+ weight_path = NASNET_MOBILE_WEIGHT_PATH_NO_TOP
+ model_name = 'nasnet_mobile_no_top.h5'
+
+ weights_file = get_file(model_name, weight_path, cache_subdir='models')
+ model.load_weights(weights_file)
+
+ elif default_size == 331: # large version
+ if include_top:
+ weight_path = NASNET_LARGE_WEIGHT_PATH
+ model_name = 'nasnet_large.h5'
+ else:
+ weight_path = NASNET_LARGE_WEIGHT_PATH_NO_TOP
+ model_name = 'nasnet_large_no_top.h5'
+
+ weights_file = get_file(model_name, weight_path, cache_subdir='models')
+ model.load_weights(weights_file)
+ else:
+ raise ValueError('ImageNet weights can only be loaded with NASNetLarge'
+ ' or NASNetMobile')
+ elif weights is not None:
+ model.load_weights(weights)
+
+ if old_data_format:
+ K.set_image_data_format(old_data_format)
+
+ return model
+
+
+def NASNetLarge(input_shape=None,
+ include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ pooling=None,
+ classes=1000):
+ """Instantiates a NASNet model in ImageNet mode.
+
+ Note that only TensorFlow is supported for now,
+ therefore it only works with the data format
+ `image_data_format='channels_last'` in your Keras config
+ at `~/.keras/keras.json`.
+
+ Arguments:
+ input_shape: Optional shape tuple, only to be specified
+ if `include_top` is False (otherwise the input shape
+ has to be `(331, 331, 3)` for NASNetLarge.
+ It should have exactly 3 inputs channels,
+ and width and height should be no smaller than 32.
+ E.g. `(224, 224, 3)` would be one valid value.
+ include_top: Whether to include the fully-connected
+ layer at the top of the network.
+ weights: `None` (random initialization) or
+ `imagenet` (ImageNet weights)
+ input_tensor: Optional Keras tensor (i.e. output of
+ `layers.Input()`)
+ to use as image input for the model.
+ pooling: Optional pooling mode for feature extraction
+ when `include_top` is `False`.
+ - `None` means that the output of the model
+ will be the 4D tensor output of the
+ last convolutional layer.
+ - `avg` means that global average pooling
+ will be applied to the output of the
+ last convolutional layer, and thus
+ the output of the model will be a
+ 2D tensor.
+ - `max` means that global max pooling will
+ be applied.
+ classes: Optional number of classes to classify images
+ into, only to be specified if `include_top` is True, and
+ if no `weights` argument is specified.
+
+ Returns:
+ A Keras model instance.
+
+ Raises:
+ ValueError: in case of invalid argument for `weights`,
+ or invalid input shape.
+ RuntimeError: If attempting to run this model with a
+ backend that does not support separable convolutions.
+ """
+ return NASNet(
+ input_shape,
+ penultimate_filters=4032,
+ num_blocks=6,
+ stem_block_filters=96,
+ skip_reduction=False,
+ filter_multiplier=2,
+ include_top=include_top,
+ weights=weights,
+ input_tensor=input_tensor,
+ pooling=pooling,
+ classes=classes,
+ default_size=331)
+
+
+def NASNetMobile(input_shape=None,
+ include_top=True,
+ weights='imagenet',
+ input_tensor=None,
+ pooling=None,
+ classes=1000):
+ """Instantiates a Mobile NASNet model in ImageNet mode.
+
+ Note that only TensorFlow is supported for now,
+ therefore it only works with the data format
+ `image_data_format='channels_last'` in your Keras config
+ at `~/.keras/keras.json`.
+
+ Arguments:
+ input_shape: Optional shape tuple, only to be specified
+ if `include_top` is False (otherwise the input shape
+ has to be `(224, 224, 3)` for NASNetMobile
+ It should have exactly 3 inputs channels,
+ and width and height should be no smaller than 32.
+ E.g. `(224, 224, 3)` would be one valid value.
+ include_top: Whether to include the fully-connected
+ layer at the top of the network.
+ weights: `None` (random initialization) or
+ `imagenet` (ImageNet weights)
+ input_tensor: Optional Keras tensor (i.e. output of
+ `layers.Input()`)
+ to use as image input for the model.
+ pooling: Optional pooling mode for feature extraction
+ when `include_top` is `False`.
+ - `None` means that the output of the model
+ will be the 4D tensor output of the
+ last convolutional layer.
+ - `avg` means that global average pooling
+ will be applied to the output of the
+ last convolutional layer, and thus
+ the output of the model will be a
+ 2D tensor.
+ - `max` means that global max pooling will
+ be applied.
+ classes: Optional number of classes to classify images
+ into, only to be specified if `include_top` is True, and
+ if no `weights` argument is specified.
+
+ Returns:
+ A Keras model instance.
+
+ Raises:
+ ValueError: In case of invalid argument for `weights`,
+ or invalid input shape.
+ RuntimeError: If attempting to run this model with a
+ backend that does not support separable convolutions.
+ """
+ return NASNet(
+ input_shape,
+ penultimate_filters=1056,
+ num_blocks=4,
+ stem_block_filters=32,
+ skip_reduction=False,
+ filter_multiplier=2,
+ include_top=include_top,
+ weights=weights,
+ input_tensor=input_tensor,
+ pooling=pooling,
+ classes=classes,
+ default_size=224)
+
+
+def _separable_conv_block(ip,
+ filters,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ block_id=None):
+ """Adds 2 blocks of [relu-separable conv-batchnorm].
+
+ Arguments:
+ ip: Input tensor
+ filters: Number of output filters per layer
+ kernel_size: Kernel size of separable convolutions
+ strides: Strided convolution for downsampling
+ block_id: String block_id
+
+ Returns:
+ A Keras tensor
+ """
+ channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
+
+ with K.name_scope('separable_conv_block_%s' % block_id):
+ x = Activation('relu')(ip)
+ x = SeparableConv2D(
+ filters,
+ kernel_size,
+ strides=strides,
+ name='separable_conv_1_%s' % block_id,
+ padding='same',
+ use_bias=False,
+ kernel_initializer='he_normal')(
+ x)
+ x = BatchNormalization(
+ axis=channel_dim,
+ momentum=0.9997,
+ epsilon=1e-3,
+ name='separable_conv_1_bn_%s' % (block_id))(
+ x)
+ x = Activation('relu')(x)
+ x = SeparableConv2D(
+ filters,
+ kernel_size,
+ name='separable_conv_2_%s' % block_id,
+ padding='same',
+ use_bias=False,
+ kernel_initializer='he_normal')(
+ x)
+ x = BatchNormalization(
+ axis=channel_dim,
+ momentum=0.9997,
+ epsilon=1e-3,
+ name='separable_conv_2_bn_%s' % (block_id))(
+ x)
+ return x
+
+
+def _adjust_block(p, ip, filters, block_id=None):
+ """Adjusts the input `previous path` to match the shape of the `input`.
+
+ Used in situations where the output number of filters needs to be changed.
+
+ Arguments:
+ p: Input tensor which needs to be modified
+ ip: Input tensor whose shape needs to be matched
+ filters: Number of output filters to be matched
+ block_id: String block_id
+
+ Returns:
+ Adjusted Keras tensor
+ """
+ channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
+ img_dim = 2 if K.image_data_format() == 'channels_first' else -2
+
+ ip_shape = K.int_shape(ip)
+
+ if p is not None:
+ p_shape = K.int_shape(p)
+
+ with K.name_scope('adjust_block'):
+ if p is None:
+ p = ip
+
+ elif p_shape[img_dim] != ip_shape[img_dim]:
+ with K.name_scope('adjust_reduction_block_%s' % block_id):
+ p = Activation('relu', name='adjust_relu_1_%s' % block_id)(p)
+
+ p1 = AveragePooling2D(
+ (1, 1),
+ strides=(2, 2),
+ padding='valid',
+ name='adjust_avg_pool_1_%s' % block_id)(
+ p)
+ p1 = Conv2D(
+ filters // 2, (1, 1),
+ padding='same',
+ use_bias=False,
+ name='adjust_conv_1_%s' % block_id,
+ kernel_initializer='he_normal')(
+ p1)
+
+ p2 = ZeroPadding2D(padding=((0, 1), (0, 1)))(p)
+ p2 = Cropping2D(cropping=((1, 0), (1, 0)))(p2)
+ p2 = AveragePooling2D(
+ (1, 1),
+ strides=(2, 2),
+ padding='valid',
+ name='adjust_avg_pool_2_%s' % block_id)(
+ p2)
+ p2 = Conv2D(
+ filters // 2, (1, 1),
+ padding='same',
+ use_bias=False,
+ name='adjust_conv_2_%s' % block_id,
+ kernel_initializer='he_normal')(
+ p2)
+
+ p = concatenate([p1, p2], axis=channel_dim)
+ p = BatchNormalization(
+ axis=channel_dim,
+ momentum=0.9997,
+ epsilon=1e-3,
+ name='adjust_bn_%s' % block_id)(
+ p)
+
+ elif p_shape[channel_dim] != filters:
+ with K.name_scope('adjust_projection_block_%s' % block_id):
+ p = Activation('relu')(p)
+ p = Conv2D(
+ filters, (1, 1),
+ strides=(1, 1),
+ padding='same',
+ name='adjust_conv_projection_%s' % block_id,
+ use_bias=False,
+ kernel_initializer='he_normal')(
+ p)
+ p = BatchNormalization(
+ axis=channel_dim,
+ momentum=0.9997,
+ epsilon=1e-3,
+ name='adjust_bn_%s' % block_id)(
+ p)
+ return p
+
+
+def _normal_a_cell(ip, p, filters, block_id=None):
+ """Adds a Normal cell for NASNet-A (Fig. 4 in the paper).
+
+ Arguments:
+ ip: Input tensor `x`
+ p: Input tensor `p`
+ filters: Number of output filters
+ block_id: String block_id
+
+ Returns:
+ A Keras tensor
+ """
+ channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
+
+ with K.name_scope('normal_A_block_%s' % block_id):
+ p = _adjust_block(p, ip, filters, block_id)
+
+ h = Activation('relu')(ip)
+ h = Conv2D(
+ filters, (1, 1),
+ strides=(1, 1),
+ padding='same',
+ name='normal_conv_1_%s' % block_id,
+ use_bias=False,
+ kernel_initializer='he_normal')(
+ h)
+ h = BatchNormalization(
+ axis=channel_dim,
+ momentum=0.9997,
+ epsilon=1e-3,
+ name='normal_bn_1_%s' % block_id)(
+ h)
+
+ with K.name_scope('block_1'):
+ x1_1 = _separable_conv_block(
+ h, filters, kernel_size=(5, 5), block_id='normal_left1_%s' % block_id)
+ x1_2 = _separable_conv_block(
+ p, filters, block_id='normal_right1_%s' % block_id)
+ x1 = add([x1_1, x1_2], name='normal_add_1_%s' % block_id)
+
+ with K.name_scope('block_2'):
+ x2_1 = _separable_conv_block(
+ p, filters, (5, 5), block_id='normal_left2_%s' % block_id)
+ x2_2 = _separable_conv_block(
+ p, filters, (3, 3), block_id='normal_right2_%s' % block_id)
+ x2 = add([x2_1, x2_2], name='normal_add_2_%s' % block_id)
+
+ with K.name_scope('block_3'):
+ x3 = AveragePooling2D(
+ (3, 3),
+ strides=(1, 1),
+ padding='same',
+ name='normal_left3_%s' % (block_id))(
+ h)
+ x3 = add([x3, p], name='normal_add_3_%s' % block_id)
+
+ with K.name_scope('block_4'):
+ x4_1 = AveragePooling2D(
+ (3, 3),
+ strides=(1, 1),
+ padding='same',
+ name='normal_left4_%s' % (block_id))(
+ p)
+ x4_2 = AveragePooling2D(
+ (3, 3),
+ strides=(1, 1),
+ padding='same',
+ name='normal_right4_%s' % (block_id))(
+ p)
+ x4 = add([x4_1, x4_2], name='normal_add_4_%s' % block_id)
+
+ with K.name_scope('block_5'):
+ x5 = _separable_conv_block(
+ h, filters, block_id='normal_left5_%s' % block_id)
+ x5 = add([x5, h], name='normal_add_5_%s' % block_id)
+
+ x = concatenate(
+ [p, x1, x2, x3, x4, x5],
+ axis=channel_dim,
+ name='normal_concat_%s' % block_id)
+ return x, ip
+
+
+def _reduction_a_cell(ip, p, filters, block_id=None):
+ """Adds a Reduction cell for NASNet-A (Fig. 4 in the paper).
+
+ Arguments:
+ ip: Input tensor `x`
+ p: Input tensor `p`
+ filters: Number of output filters
+ block_id: String block_id
+
+ Returns:
+ A Keras tensor
+ """
+ channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
+
+ with K.name_scope('reduction_A_block_%s' % block_id):
+ p = _adjust_block(p, ip, filters, block_id)
+
+ h = Activation('relu')(ip)
+ h = Conv2D(
+ filters, (1, 1),
+ strides=(1, 1),
+ padding='same',
+ name='reduction_conv_1_%s' % block_id,
+ use_bias=False,
+ kernel_initializer='he_normal')(
+ h)
+ h = BatchNormalization(
+ axis=channel_dim,
+ momentum=0.9997,
+ epsilon=1e-3,
+ name='reduction_bn_1_%s' % block_id)(
+ h)
+
+ with K.name_scope('block_1'):
+ x1_1 = _separable_conv_block(
+ h,
+ filters, (5, 5),
+ strides=(2, 2),
+ block_id='reduction_left1_%s' % block_id)
+ x1_2 = _separable_conv_block(
+ p,
+ filters, (7, 7),
+ strides=(2, 2),
+ block_id='reduction_1_%s' % block_id)
+ x1 = add([x1_1, x1_2], name='reduction_add_1_%s' % block_id)
+
+ with K.name_scope('block_2'):
+ x2_1 = MaxPooling2D(
+ (3, 3),
+ strides=(2, 2),
+ padding='same',
+ name='reduction_left2_%s' % block_id)(
+ h)
+ x2_2 = _separable_conv_block(
+ p,
+ filters, (7, 7),
+ strides=(2, 2),
+ block_id='reduction_right2_%s' % block_id)
+ x2 = add([x2_1, x2_2], name='reduction_add_2_%s' % block_id)
+
+ with K.name_scope('block_3'):
+ x3_1 = AveragePooling2D(
+ (3, 3),
+ strides=(2, 2),
+ padding='same',
+ name='reduction_left3_%s' % block_id)(
+ h)
+ x3_2 = _separable_conv_block(
+ p,
+ filters, (5, 5),
+ strides=(2, 2),
+ block_id='reduction_right3_%s' % block_id)
+ x3 = add([x3_1, x3_2], name='reduction_add3_%s' % block_id)
+
+ with K.name_scope('block_4'):
+ x4 = AveragePooling2D(
+ (3, 3),
+ strides=(1, 1),
+ padding='same',
+ name='reduction_left4_%s' % block_id)(
+ x1)
+ x4 = add([x2, x4])
+
+ with K.name_scope('block_5'):
+ x5_1 = _separable_conv_block(
+ x1, filters, (3, 3), block_id='reduction_left4_%s' % block_id)
+ x5_2 = MaxPooling2D(
+ (3, 3),
+ strides=(2, 2),
+ padding='same',
+ name='reduction_right5_%s' % block_id)(
+ h)
+ x5 = add([x5_1, x5_2], name='reduction_add4_%s' % block_id)
+
+ x = concatenate(
+ [x2, x3, x4, x5],
+ axis=channel_dim,
+ name='reduction_concat_%s' % block_id)
+ return x, ip
diff --git a/tensorflow/python/keras/_impl/keras/applications/nasnet_test.py b/tensorflow/python/keras/_impl/keras/applications/nasnet_test.py
new file mode 100644
index 0000000000..aa1dec670c
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/applications/nasnet_test.py
@@ -0,0 +1,76 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Nasnet application."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.platform import test
+
+
+class NASNetMobileTest(test.TestCase):
+
+ def test_with_top(self):
+ model = keras.applications.NASNetMobile(weights=None)
+ self.assertEqual(model.output_shape, (None, 1000))
+
+ def test_no_top(self):
+ model = keras.applications.NASNetMobile(weights=None, include_top=False)
+ self.assertEqual(model.output_shape, (None, None, None, 1056))
+
+ def test_with_pooling(self):
+ model = keras.applications.NASNetMobile(weights=None,
+ include_top=False,
+ pooling='avg')
+ self.assertEqual(model.output_shape, (None, 1056))
+
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.NASNetMobile(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.NASNetMobile(weights='imagenet',
+ classes=2000)
+
+
+class NASNetLargeTest(test.TestCase):
+
+ def test_with_top(self):
+ model = keras.applications.NASNetLarge(weights=None)
+ self.assertEqual(model.output_shape, (None, 1000))
+
+ def test_no_top(self):
+ model = keras.applications.NASNetLarge(weights=None, include_top=False)
+ self.assertEqual(model.output_shape, (None, None, None, 4032))
+
+ def test_with_pooling(self):
+ model = keras.applications.NASNetLarge(weights=None,
+ include_top=False,
+ pooling='avg')
+ self.assertEqual(model.output_shape, (None, 4032))
+
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.NASNetLarge(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.NASNetLarge(weights='imagenet',
+ classes=2000)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/_impl/keras/applications/resnet50.py b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
index 8ab46693aa..5705b3481a 100644
--- a/tensorflow/python/keras/_impl/keras/applications/resnet50.py
+++ b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""ResNet50 model for Keras.
# Reference:
@@ -31,8 +32,8 @@ import os
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import layers
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
@@ -45,7 +46,9 @@ from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
from tensorflow.python.keras._impl.keras.layers import Input
from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.utils import layer_utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5'
@@ -78,7 +81,8 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
x = Activation('relu')(x)
x = Conv2D(
- filters2, kernel_size, padding='same', name=conv_name_base + '2b')(x)
+ filters2, kernel_size, padding='same', name=conv_name_base + '2b')(
+ x)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
x = Activation('relu')(x)
@@ -92,7 +96,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2,
2)):
- """conv_block is the block that has a conv layer at shortcut.
+ """A block that has a conv layer at shortcut.
Arguments:
input_tensor: input tensor
@@ -100,14 +104,14 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2,
filters: list of integers, the filters of 3 conv layer at main path
stage: integer, current stage label, used for generating layer names
block: 'a','b'..., current block label, used for generating layer names
- strides: Tuple of integers.
+ strides: Strides for the first conv layer in the block.
Returns:
Output tensor for the block.
- Note that from stage 3, the first conv layer at main path is with
- strides=(2,2)
- And the shortcut should have strides=(2,2) as well
+ Note that from stage 3,
+ the first conv layer at main path is with strides=(2, 2)
+ And the shortcut should have strides=(2, 2) as well
"""
filters1, filters2, filters3 = filters
if K.image_data_format() == 'channels_last':
@@ -118,13 +122,14 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2,
bn_name_base = 'bn' + str(stage) + block + '_branch'
x = Conv2D(
- filters1, (1, 1), strides=strides,
- name=conv_name_base + '2a')(input_tensor)
+ filters1, (1, 1), strides=strides, name=conv_name_base + '2a')(
+ input_tensor)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
x = Activation('relu')(x)
x = Conv2D(
- filters2, kernel_size, padding='same', name=conv_name_base + '2b')(x)
+ filters2, kernel_size, padding='same', name=conv_name_base + '2b')(
+ x)
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
x = Activation('relu')(x)
@@ -132,8 +137,8 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2,
x = BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
shortcut = Conv2D(
- filters3, (1, 1), strides=strides,
- name=conv_name_base + '1')(input_tensor)
+ filters3, (1, 1), strides=strides, name=conv_name_base + '1')(
+ input_tensor)
shortcut = BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)
x = layers.add([x, shortcut])
@@ -152,7 +157,7 @@ def ResNet50(include_top=True,
Optionally loads weights pre-trained
on ImageNet. Note that when using TensorFlow,
for best performance you should set
- `image_data_format="channels_last"` in your Keras config
+ `image_data_format='channels_last'` in your Keras config
at ~/.keras/keras.json.
The model and the weights are compatible with both
@@ -164,15 +169,15 @@ def ResNet50(include_top=True,
include_top: whether to include the fully-connected
layer at the top of the network.
weights: one of `None` (random initialization),
- 'imagenet' (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
has to be `(224, 224, 3)` (with `channels_last` data format)
or `(3, 224, 224)` (with `channels_first` data format).
- It should have exactly 3 input channels,
+ It should have exactly 3 inputs channels,
and width and height should be no smaller than 197.
E.g. `(200, 200, 3)` would be one valid value.
pooling: Optional pooling mode for feature extraction
@@ -219,15 +224,18 @@ def ResNet50(include_top=True,
if input_tensor is None:
img_input = Input(shape=input_shape)
else:
- img_input = Input(tensor=input_tensor, shape=input_shape)
-
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
if K.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1
- x = Conv2D(64, (7, 7),
- strides=(2, 2), padding='same', name='conv1')(img_input)
+ x = Conv2D(
+ 64, (7, 7), strides=(2, 2), padding='same', name='conv1')(
+ img_input)
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3), strides=(2, 2))(x)
@@ -289,4 +297,5 @@ def ResNet50(include_top=True,
model.load_weights(weights_path)
elif weights is not None:
model.load_weights(weights)
+
return model
diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg16.py b/tensorflow/python/keras/_impl/keras/applications/vgg16.py
index 38dbbdc809..c91c24e6fb 100644
--- a/tensorflow/python/keras/_impl/keras/applications/vgg16.py
+++ b/tensorflow/python/keras/_impl/keras/applications/vgg16.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""VGG16 model for Keras.
# Reference
@@ -29,8 +30,8 @@ import os
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Conv2D
from tensorflow.python.keras._impl.keras.layers import Dense
@@ -42,6 +43,7 @@ from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.utils import layer_utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5'
@@ -59,7 +61,7 @@ def VGG16(include_top=True,
Optionally loads weights pre-trained
on ImageNet. Note that when using TensorFlow,
for best performance you should set
- `image_data_format="channels_last"` in your Keras config
+ `image_data_format='channels_last'` in your Keras config
at ~/.keras/keras.json.
The model and the weights are compatible with both
@@ -71,8 +73,8 @@ def VGG16(include_top=True,
include_top: whether to include the 3 fully-connected
layers at the top of the network.
weights: one of `None` (random initialization),
- 'imagenet' (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
@@ -125,48 +127,62 @@ def VGG16(include_top=True,
if input_tensor is None:
img_input = Input(shape=input_shape)
else:
- img_input = Input(tensor=input_tensor, shape=input_shape)
-
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
# Block 1
x = Conv2D(
- 64, (3, 3), activation='relu', padding='same',
- name='block1_conv1')(img_input)
+ 64, (3, 3), activation='relu', padding='same', name='block1_conv1')(
+ img_input)
x = Conv2D(
- 64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
+ 64, (3, 3), activation='relu', padding='same', name='block1_conv2')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
# Block 2
x = Conv2D(
- 128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
+ 128, (3, 3), activation='relu', padding='same', name='block2_conv1')(
+ x)
x = Conv2D(
- 128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
+ 128, (3, 3), activation='relu', padding='same', name='block2_conv2')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
# Block 3
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv1')(
+ x)
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv2')(
+ x)
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv3')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
# Block 4
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv1')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv2')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv3')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
# Block 5
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv1')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv2')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv3')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
if include_top:
@@ -215,6 +231,8 @@ def VGG16(include_top=True,
dense = model.get_layer(name='fc1')
layer_utils.convert_dense_weights_data_format(dense, shape,
'channels_first')
+
elif weights is not None:
model.load_weights(weights)
+
return model
diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg19.py b/tensorflow/python/keras/_impl/keras/applications/vgg19.py
index 126c64260b..223cd79d7b 100644
--- a/tensorflow/python/keras/_impl/keras/applications/vgg19.py
+++ b/tensorflow/python/keras/_impl/keras/applications/vgg19.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""VGG19 model for Keras.
# Reference
@@ -29,8 +30,8 @@ import os
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Conv2D
from tensorflow.python.keras._impl.keras.layers import Dense
@@ -42,6 +43,7 @@ from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.utils import layer_utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5'
@@ -59,7 +61,7 @@ def VGG19(include_top=True,
Optionally loads weights pre-trained
on ImageNet. Note that when using TensorFlow,
for best performance you should set
- `image_data_format="channels_last"` in your Keras config
+ `image_data_format='channels_last'` in your Keras config
at ~/.keras/keras.json.
The model and the weights are compatible with both
@@ -71,15 +73,15 @@ def VGG19(include_top=True,
include_top: whether to include the 3 fully-connected
layers at the top of the network.
weights: one of `None` (random initialization),
- 'imagenet' (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
has to be `(224, 224, 3)` (with `channels_last` data format)
or `(3, 224, 224)` (with `channels_first` data format).
- It should have exactly 3 input channels,
+ It should have exactly 3 inputs channels,
and width and height should be no smaller than 48.
E.g. `(200, 200, 3)` would be one valid value.
pooling: Optional pooling mode for feature extraction
@@ -125,54 +127,71 @@ def VGG19(include_top=True,
if input_tensor is None:
img_input = Input(shape=input_shape)
else:
- img_input = Input(tensor=input_tensor, shape=input_shape)
-
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
# Block 1
x = Conv2D(
- 64, (3, 3), activation='relu', padding='same',
- name='block1_conv1')(img_input)
+ 64, (3, 3), activation='relu', padding='same', name='block1_conv1')(
+ img_input)
x = Conv2D(
- 64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
+ 64, (3, 3), activation='relu', padding='same', name='block1_conv2')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
# Block 2
x = Conv2D(
- 128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
+ 128, (3, 3), activation='relu', padding='same', name='block2_conv1')(
+ x)
x = Conv2D(
- 128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
+ 128, (3, 3), activation='relu', padding='same', name='block2_conv2')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
# Block 3
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv1')(
+ x)
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv2')(
+ x)
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv3')(
+ x)
x = Conv2D(
- 256, (3, 3), activation='relu', padding='same', name='block3_conv4')(x)
+ 256, (3, 3), activation='relu', padding='same', name='block3_conv4')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
# Block 4
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv1')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv2')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv3')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block4_conv4')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block4_conv4')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
# Block 5
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv1')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv2')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv3')(
+ x)
x = Conv2D(
- 512, (3, 3), activation='relu', padding='same', name='block5_conv4')(x)
+ 512, (3, 3), activation='relu', padding='same', name='block5_conv4')(
+ x)
x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
if include_top:
@@ -211,6 +230,8 @@ def VGG19(include_top=True,
cache_subdir='models',
file_hash='253f8cb515780f3b799900260a226db6')
model.load_weights(weights_path)
+ if K.backend() == 'theano':
+ layer_utils.convert_all_kernels_in_model(model)
if K.image_data_format() == 'channels_first':
if include_top:
@@ -219,6 +240,8 @@ def VGG19(include_top=True,
dense = model.get_layer(name='fc1')
layer_utils.convert_dense_weights_data_format(dense, shape,
'channels_first')
+
elif weights is not None:
model.load_weights(weights)
+
return model
diff --git a/tensorflow/python/keras/_impl/keras/applications/xception.py b/tensorflow/python/keras/_impl/keras/applications/xception.py
index 8219831408..0a6eb4953a 100644
--- a/tensorflow/python/keras/_impl/keras/applications/xception.py
+++ b/tensorflow/python/keras/_impl/keras/applications/xception.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
+# pylint: disable=unused-import
"""Xception V1 model for Keras.
On ImageNet, this model gets to a top-1 validation accuracy of 0.790
@@ -42,7 +43,7 @@ from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import layers
from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import BatchNormalization
@@ -74,7 +75,7 @@ def Xception(include_top=True,
on ImageNet. This model is available for TensorFlow only,
and can only be used with inputs following the TensorFlow
data format `(width, height, channels)`.
- You should set `image_data_format="channels_last"` in your Keras config
+ You should set `image_data_format='channels_last'` in your Keras config
located at ~/.keras/keras.json.
Note that the default input image size for this model is 299x299.
@@ -83,14 +84,14 @@ def Xception(include_top=True,
include_top: whether to include the fully-connected
layer at the top of the network.
weights: one of `None` (random initialization),
- 'imagenet' (pre-training on ImageNet),
- or the path to the weights file to be loaded.
+ 'imagenet' (pre-training on ImageNet),
+ or the path to the weights file to be loaded.
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
input_shape: optional shape tuple, only to be specified
if `include_top` is False (otherwise the input shape
has to be `(299, 299, 3)`.
- It should have exactly 3 input channels,
+ It should have exactly 3 inputs channels,
and width and height should be no smaller than 71.
E.g. `(150, 150, 3)` would be one valid value.
pooling: Optional pooling mode for feature extraction
@@ -155,11 +156,14 @@ def Xception(include_top=True,
if input_tensor is None:
img_input = Input(shape=input_shape)
else:
- img_input = Input(tensor=input_tensor, shape=input_shape)
+ if not K.is_keras_tensor(input_tensor):
+ img_input = Input(tensor=input_tensor, shape=input_shape)
+ else:
+ img_input = input_tensor
x = Conv2D(
- 32, (3, 3), strides=(2, 2), use_bias=False,
- name='block1_conv1')(img_input)
+ 32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')(
+ img_input)
x = BatchNormalization(name='block1_conv1_bn')(x)
x = Activation('relu', name='block1_conv1_act')(x)
x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
@@ -167,53 +171,65 @@ def Xception(include_top=True,
x = Activation('relu', name='block1_conv2_act')(x)
residual = Conv2D(
- 128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
+ 128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(
+ x)
residual = BatchNormalization()(residual)
x = SeparableConv2D(
- 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
+ 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(
+ x)
x = BatchNormalization(name='block2_sepconv1_bn')(x)
x = Activation('relu', name='block2_sepconv2_act')(x)
x = SeparableConv2D(
- 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
+ 128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(
+ x)
x = BatchNormalization(name='block2_sepconv2_bn')(x)
x = MaxPooling2D(
- (3, 3), strides=(2, 2), padding='same', name='block2_pool')(x)
+ (3, 3), strides=(2, 2), padding='same', name='block2_pool')(
+ x)
x = layers.add([x, residual])
residual = Conv2D(
- 256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
+ 256, (1, 1), strides=(2, 2), padding='same', use_bias=False)(
+ x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block3_sepconv1_act')(x)
x = SeparableConv2D(
- 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
+ 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(
+ x)
x = BatchNormalization(name='block3_sepconv1_bn')(x)
x = Activation('relu', name='block3_sepconv2_act')(x)
x = SeparableConv2D(
- 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
+ 256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(
+ x)
x = BatchNormalization(name='block3_sepconv2_bn')(x)
x = MaxPooling2D(
- (3, 3), strides=(2, 2), padding='same', name='block3_pool')(x)
+ (3, 3), strides=(2, 2), padding='same', name='block3_pool')(
+ x)
x = layers.add([x, residual])
residual = Conv2D(
- 728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
+ 728, (1, 1), strides=(2, 2), padding='same', use_bias=False)(
+ x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block4_sepconv1_act')(x)
x = SeparableConv2D(
- 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
+ 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(
+ x)
x = BatchNormalization(name='block4_sepconv1_bn')(x)
x = Activation('relu', name='block4_sepconv2_act')(x)
x = SeparableConv2D(
- 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
+ 728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(
+ x)
x = BatchNormalization(name='block4_sepconv2_bn')(x)
x = MaxPooling2D(
- (3, 3), strides=(2, 2), padding='same', name='block4_pool')(x)
+ (3, 3), strides=(2, 2), padding='same', name='block4_pool')(
+ x)
x = layers.add([x, residual])
for i in range(8):
@@ -222,46 +238,52 @@ def Xception(include_top=True,
x = Activation('relu', name=prefix + '_sepconv1_act')(x)
x = SeparableConv2D(
- 728, (3, 3), padding='same', use_bias=False,
- name=prefix + '_sepconv1')(x)
+ 728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')(
+ x)
x = BatchNormalization(name=prefix + '_sepconv1_bn')(x)
x = Activation('relu', name=prefix + '_sepconv2_act')(x)
x = SeparableConv2D(
- 728, (3, 3), padding='same', use_bias=False,
- name=prefix + '_sepconv2')(x)
+ 728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')(
+ x)
x = BatchNormalization(name=prefix + '_sepconv2_bn')(x)
x = Activation('relu', name=prefix + '_sepconv3_act')(x)
x = SeparableConv2D(
- 728, (3, 3), padding='same', use_bias=False,
- name=prefix + '_sepconv3')(x)
+ 728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')(
+ x)
x = BatchNormalization(name=prefix + '_sepconv3_bn')(x)
x = layers.add([x, residual])
residual = Conv2D(
- 1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
+ 1024, (1, 1), strides=(2, 2), padding='same', use_bias=False)(
+ x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block13_sepconv1_act')(x)
x = SeparableConv2D(
- 728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
+ 728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(
+ x)
x = BatchNormalization(name='block13_sepconv1_bn')(x)
x = Activation('relu', name='block13_sepconv2_act')(x)
x = SeparableConv2D(
- 1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
+ 1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(
+ x)
x = BatchNormalization(name='block13_sepconv2_bn')(x)
x = MaxPooling2D(
- (3, 3), strides=(2, 2), padding='same', name='block13_pool')(x)
+ (3, 3), strides=(2, 2), padding='same', name='block13_pool')(
+ x)
x = layers.add([x, residual])
x = SeparableConv2D(
- 1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
+ 1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(
+ x)
x = BatchNormalization(name='block14_sepconv1_bn')(x)
x = Activation('relu', name='block14_sepconv1_act')(x)
x = SeparableConv2D(
- 2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
+ 2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(
+ x)
x = BatchNormalization(name='block14_sepconv2_bn')(x)
x = Activation('relu', name='block14_sepconv2_act')(x)
@@ -303,8 +325,6 @@ def Xception(include_top=True,
if old_data_format:
K.set_image_data_format(old_data_format)
- elif weights is not None:
- model.load_weights(weights)
return model
diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
index 9476085bd8..460c0dc5f3 100644
--- a/tensorflow/python/keras/_impl/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -85,7 +85,7 @@ _MANUAL_VAR_INIT = False
_FLOATX = 'float32'
# Epsilon fuzz factor used throughout the codebase.
-_EPSILON = 10e-8
+_EPSILON = 1e-7
# Default image data format, one of "channels_last", "channels_first".
_IMAGE_DATA_FORMAT = 'channels_last'
@@ -116,7 +116,7 @@ def epsilon():
Example:
```python
>>> keras.backend.epsilon()
- 1e-08
+ 1e-07
```
"""
return _EPSILON
@@ -132,7 +132,7 @@ def set_epsilon(value):
```python
>>> from keras import backend as K
>>> K.epsilon()
- 1e-08
+ 1e-07
>>> K.set_epsilon(1e-05)
>>> K.epsilon()
1e-05
@@ -295,7 +295,8 @@ def clear_session():
ops.reset_default_graph()
reset_uids()
_SESSION = None
- phase = array_ops.placeholder(dtype='bool', name='keras_learning_phase')
+ phase = array_ops.placeholder_with_default(
+ False, shape=(), name='keras_learning_phase')
_GRAPH_LEARNING_PHASES = {}
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = phase
@@ -328,7 +329,8 @@ def learning_phase():
"""
graph = ops.get_default_graph()
if graph not in _GRAPH_LEARNING_PHASES:
- phase = array_ops.placeholder(dtype='bool', name='keras_learning_phase')
+ phase = array_ops.placeholder_with_default(
+ False, shape=(), name='keras_learning_phase')
_GRAPH_LEARNING_PHASES[graph] = phase
return _GRAPH_LEARNING_PHASES[graph]
@@ -876,6 +878,8 @@ def zeros(shape, dtype=None, name=None):
Returns:
A variable (including Keras metadata), filled with `0.0`.
+ Note that if `shape` was symbolic, we cannot return a variable,
+ and will return a dynamically-shaped tensor instead.
Example:
```python
@@ -890,12 +894,14 @@ def zeros(shape, dtype=None, name=None):
if dtype is None:
dtype = floatx()
tf_dtype = dtypes_module.as_dtype(dtype)
- return variable(
- init_ops.constant_initializer(0., dtype=tf_dtype)(shape), dtype, name)
+ v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
+ if py_all(v.get_shape().as_list()):
+ return variable(v, dtype=dtype, name=name)
+ return v
def ones(shape, dtype=None, name=None):
- """Instantiates an all-ones tensor variable and returns it.
+ """Instantiates an all-ones variable and returns it.
Arguments:
shape: Tuple of integers, shape of returned Keras variable.
@@ -904,6 +910,8 @@ def ones(shape, dtype=None, name=None):
Returns:
A Keras variable, filled with `1.0`.
+ Note that if `shape` was symbolic, we cannot return a variable,
+ and will return a dynamically-shaped tensor instead.
Example:
```python
@@ -918,8 +926,10 @@ def ones(shape, dtype=None, name=None):
if dtype is None:
dtype = floatx()
tf_dtype = dtypes_module.as_dtype(dtype)
- return variable(
- init_ops.constant_initializer(1., dtype=tf_dtype)(shape), dtype, name)
+ v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
+ if py_all(v.get_shape().as_list()):
+ return variable(v, dtype=dtype, name=name)
+ return v
def eye(size, dtype=None, name=None):
@@ -1185,7 +1195,7 @@ def moving_average_update(x, value, momentum):
An Operation to update the variable.
"""
return moving_averages.assign_moving_average(
- x, value, momentum, zero_debias=False)
+ x, value, momentum, zero_debias=True)
# LINEAR ALGEBRA
@@ -1419,7 +1429,7 @@ def max(x, axis=None, keepdims=False):
Returns:
A tensor with maximum values of `x`.
"""
- return math_ops.reduce_max(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_max(x, axis, keepdims)
def min(x, axis=None, keepdims=False):
@@ -1436,7 +1446,7 @@ def min(x, axis=None, keepdims=False):
Returns:
A tensor with miminum values of `x`.
"""
- return math_ops.reduce_min(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_min(x, axis, keepdims)
def sum(x, axis=None, keepdims=False):
@@ -1453,7 +1463,7 @@ def sum(x, axis=None, keepdims=False):
Returns:
A tensor with sum of `x`.
"""
- return math_ops.reduce_sum(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_sum(x, axis, keepdims)
def prod(x, axis=None, keepdims=False):
@@ -1470,7 +1480,7 @@ def prod(x, axis=None, keepdims=False):
Returns:
A tensor with the product of elements of `x`.
"""
- return math_ops.reduce_prod(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_prod(x, axis, keepdims)
def cumsum(x, axis=0):
@@ -1515,10 +1525,10 @@ def var(x, axis=None, keepdims=False):
"""
if x.dtype.base_dtype == dtypes_module.bool:
x = math_ops.cast(x, floatx())
- m = math_ops.reduce_mean(x, axis=axis, keep_dims=True)
+ m = math_ops.reduce_mean(x, axis, True)
devs_squared = math_ops.square(x - m)
return math_ops.reduce_mean(
- devs_squared, axis=axis, keep_dims=keepdims)
+ devs_squared, axis, keepdims)
def std(x, axis=None, keepdims=False):
@@ -1546,7 +1556,7 @@ def mean(x, axis=None, keepdims=False):
axis: A list of integer. Axes to compute the mean.
keepdims: A boolean, whether to keep the dimensions or not.
If `keepdims` is `False`, the rank of the tensor is reduced
- by 1 for each entry in `axis`. If `keep_dims` is `True`,
+ by 1 for each entry in `axis`. If `keepdims` is `True`,
the reduced dimensions are retained with length 1.
Returns:
@@ -1554,7 +1564,7 @@ def mean(x, axis=None, keepdims=False):
"""
if x.dtype.base_dtype == dtypes_module.bool:
x = math_ops.cast(x, floatx())
- return math_ops.reduce_mean(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_mean(x, axis, keepdims)
def any(x, axis=None, keepdims=False):
@@ -1569,7 +1579,7 @@ def any(x, axis=None, keepdims=False):
A uint8 tensor (0s and 1s).
"""
x = math_ops.cast(x, dtypes_module.bool)
- return math_ops.reduce_any(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_any(x, axis, keepdims)
def all(x, axis=None, keepdims=False):
@@ -1584,7 +1594,7 @@ def all(x, axis=None, keepdims=False):
A uint8 tensor (0s and 1s).
"""
x = math_ops.cast(x, dtypes_module.bool)
- return math_ops.reduce_all(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_all(x, axis, keepdims)
def argmax(x, axis=-1):
@@ -1694,7 +1704,7 @@ def logsumexp(x, axis=None, keepdims=False):
Returns:
The reduced tensor.
"""
- return math_ops.reduce_logsumexp(x, axis=axis, keep_dims=keepdims)
+ return math_ops.reduce_logsumexp(x, axis, keepdims)
def round(x):
@@ -1884,6 +1894,108 @@ def cos(x):
return math_ops.cos(x)
+def _regular_normalize_batch_in_training(x,
+ gamma,
+ beta,
+ reduction_axes,
+ epsilon=1e-3):
+ """Non-fused version of `normalize_batch_in_training`.
+
+ Arguments:
+ x: Input tensor or variable.
+ gamma: Tensor by which to scale the input.
+ beta: Tensor with which to center the input.
+ reduction_axes: iterable of integers,
+ axes over which to normalize.
+ epsilon: Fuzz factor.
+
+ Returns:
+ A tuple length of 3, `(normalized_tensor, mean, variance)`.
+ """
+ mean, var = nn.moments(x, reduction_axes, None, None, False)
+ normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
+ return normed, mean, var
+
+
+def _broadcast_normalize_batch_in_training(x,
+ gamma,
+ beta,
+ reduction_axes,
+ epsilon=1e-3):
+ """Non-fused, broadcast version of `normalize_batch_in_training`.
+
+ Arguments:
+ x: Input tensor or variable.
+ gamma: Tensor by which to scale the input.
+ beta: Tensor with which to center the input.
+ reduction_axes: iterable of integers,
+ axes over which to normalize.
+ epsilon: Fuzz factor.
+
+ Returns:
+ A tuple length of 3, `(normalized_tensor, mean, variance)`.
+ """
+ mean, var = nn.moments(x, reduction_axes, None, None, False)
+ target_shape = []
+ for axis in range(ndim(x)):
+ if axis in reduction_axes:
+ target_shape.append(1)
+ else:
+ target_shape.append(array_ops.shape(x)[axis])
+ target_shape = array_ops.stack(target_shape)
+
+ broadcast_mean = array_ops.reshape(mean, target_shape)
+ broadcast_var = array_ops.reshape(var, target_shape)
+ if gamma is None:
+ broadcast_gamma = None
+ else:
+ broadcast_gamma = array_ops.reshape(gamma, target_shape)
+ if beta is None:
+ broadcast_beta = None
+ else:
+ broadcast_beta = array_ops.reshape(beta, target_shape)
+
+ normed = nn.batch_normalization(x, broadcast_mean, broadcast_var,
+ broadcast_beta, broadcast_gamma, epsilon)
+ return normed, mean, var
+
+
+def _fused_normalize_batch_in_training(x,
+ gamma,
+ beta,
+ reduction_axes,
+ epsilon=1e-3):
+ """Fused version of `normalize_batch_in_training`.
+
+ Arguments:
+ x: Input tensor or variable.
+ gamma: Tensor by which to scale the input.
+ beta: Tensor with which to center the input.
+ reduction_axes: iterable of integers,
+ axes over which to normalize.
+ epsilon: Fuzz factor.
+
+ Returns:
+ A tuple length of 3, `(normalized_tensor, mean, variance)`.
+ """
+ if list(reduction_axes) == [0, 1, 2]:
+ normalization_axis = 3
+ tf_data_format = 'NHWC'
+ else:
+ normalization_axis = 1
+ tf_data_format = 'NCHW'
+
+ if gamma is None:
+ gamma = constant_op.constant(
+ 1.0, dtype=x.dtype, shape=[x.get_shape()[normalization_axis]])
+ if beta is None:
+ beta = constant_op.constant(
+ 0.0, dtype=x.dtype, shape=[x.get_shape()[normalization_axis]])
+
+ return nn.fused_batch_norm(
+ x, gamma, beta, epsilon=epsilon, data_format=tf_data_format)
+
+
def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
"""Computes mean and std for batch then apply batch_normalization on batch.
@@ -1898,33 +2010,19 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
Returns:
A tuple length of 3, `(normalized_tensor, mean, variance)`.
"""
- mean, var = nn.moments(
- x, reduction_axes, shift=None, name=None, keep_dims=False)
- if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
- normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
+ if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]:
+ if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
+ return _broadcast_normalize_batch_in_training(
+ x, gamma, beta, reduction_axes, epsilon=epsilon)
+ return _fused_normalize_batch_in_training(
+ x, gamma, beta, reduction_axes, epsilon=epsilon)
else:
- # need broadcasting
- target_shape = []
- for axis in range(ndim(x)):
- if axis in reduction_axes:
- target_shape.append(1)
- else:
- target_shape.append(array_ops.shape(x)[axis])
- target_shape = array_ops.stack(target_shape)
-
- broadcast_mean = array_ops.reshape(mean, target_shape)
- broadcast_var = array_ops.reshape(var, target_shape)
- if gamma is None:
- broadcast_gamma = None
+ if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
+ return _regular_normalize_batch_in_training(
+ x, gamma, beta, reduction_axes, epsilon=epsilon)
else:
- broadcast_gamma = array_ops.reshape(gamma, target_shape)
- if beta is None:
- broadcast_beta = None
- else:
- broadcast_beta = array_ops.reshape(beta, target_shape)
- normed = nn.batch_normalization(x, broadcast_mean, broadcast_var,
- broadcast_beta, broadcast_gamma, epsilon)
- return normed, mean, var
+ return _broadcast_normalize_batch_in_training(
+ x, gamma, beta, reduction_axes, epsilon=epsilon)
def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
@@ -2619,7 +2717,8 @@ def rnn(step_function,
go_backwards=False,
mask=None,
constants=None,
- unroll=False):
+ unroll=False,
+ input_length=None):
"""Iterates over the time dimension of a tensor.
Arguments:
@@ -2648,6 +2747,7 @@ def rnn(step_function,
constants: a list of constant values passed at each step.
unroll: whether to unroll the RNN or to use a symbolic loop
(`while_loop` or `scan` depending on backend).
+ input_length: Unused; exists for API compatibility.
Returns:
A tuple, `(last_output, outputs, new_states)`.
@@ -2665,6 +2765,7 @@ def rnn(step_function,
ValueError: if `mask` is provided (not `None`) but states is not provided
(`len(states)` == 0).
"""
+ del input_length
ndim = len(inputs.get_shape())
if ndim < 3:
raise ValueError('Input should be at least 3D.')
@@ -3016,7 +3117,7 @@ def elu(x, alpha=1.):
Arguments:
x: A tensor or variable to compute the activation function for.
- alpha: A scalar, slope of positive section.
+ alpha: A scalar, slope of negative section.
Returns:
A tensor.
@@ -3083,7 +3184,7 @@ def categorical_crossentropy(target, output, from_logits=False):
if not from_logits:
# scale preds so that the class probas of each sample sum to 1
output /= math_ops.reduce_sum(
- output, axis=len(output.get_shape()) - 1, keep_dims=True)
+ output, len(output.get_shape()) - 1, True)
# manual computation of crossentropy
epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype)
output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
@@ -3248,6 +3349,25 @@ def in_top_k(predictions, targets, k):
# CONVOLUTIONS
+def _preprocess_conv1d_input(x, data_format):
+ """Transpose and cast the input before the conv1d.
+
+ Arguments:
+ x: input tensor.
+ data_format: string, `"channels_last"` or `"channels_first"`.
+
+ Returns:
+ A tensor.
+ """
+ tf_data_format = 'NHWC' # to pass TF Conv2dNative operations
+ if data_format == 'channels_first':
+ if not _has_nchw_support():
+ x = array_ops.transpose(x, (0, 2, 1)) # NCW -> NWC
+ else:
+ tf_data_format = 'NCHW'
+ return x, tf_data_format
+
+
def _preprocess_conv2d_input(x, data_format):
"""Transpose and cast the input before the conv2d.
@@ -3461,6 +3581,66 @@ def conv2d_transpose(x,
return x
+def separable_conv1d(x,
+ depthwise_kernel,
+ pointwise_kernel,
+ strides=1,
+ padding='valid',
+ data_format=None,
+ dilation_rate=1):
+ """1D convolution with separable filters.
+
+ Arguments:
+ x: input tensor
+ depthwise_kernel: convolution kernel for the depthwise convolution.
+ pointwise_kernel: kernel for the 1x1 convolution.
+ strides: stride integer.
+ padding: string, `"same"` or `"valid"`.
+ data_format: string, `"channels_last"` or `"channels_first"`.
+ dilation_rate: integer dilation rate.
+
+ Returns:
+ Output tensor.
+
+ Raises:
+ ValueError: if `data_format` is neither `channels_last` or
+ `channels_first`.
+ """
+ if data_format is None:
+ data_format = image_data_format()
+ if data_format not in {'channels_first', 'channels_last'}:
+ raise ValueError('Unknown data_format ' + str(data_format))
+
+ x, tf_data_format = _preprocess_conv1d_input(x, data_format)
+ padding = _preprocess_padding(padding)
+ if tf_data_format == 'NHWC':
+ spatial_start_dim = 1
+ strides = (1, 1) + strides + (1,)
+ else:
+ spatial_start_dim = 2
+ strides = (1, 1, 1) + strides
+ x = array_ops.expand_dims(x, spatial_start_dim)
+ depthwise_kernel = array_ops.expand_dims(depthwise_kernel, 0)
+ pointwise_kernel = array_ops.expand_dims(pointwise_kernel, 0)
+ dilation_rate = (1,) + dilation_rate
+
+ x = nn.separable_conv2d(
+ x,
+ depthwise_kernel,
+ pointwise_kernel,
+ strides=strides,
+ padding=padding,
+ rate=dilation_rate,
+ data_format=tf_data_format)
+
+ x = array_ops.squeeze(x, [spatial_start_dim])
+
+ if data_format == 'channels_first' and tf_data_format == 'NHWC':
+ x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW
+
+ return x
+
+
def separable_conv2d(x,
depthwise_kernel,
pointwise_kernel,
@@ -3921,7 +4101,10 @@ def bias_add(x, bias, data_format=None):
elif ndim(x) == 4:
if data_format == 'channels_first':
if len(bias_shape) == 1:
- x += reshape(bias, (1, bias_shape[0], 1, 1))
+ if _has_nchw_support():
+ x = nn.bias_add(x, bias, data_format='NCHW')
+ else:
+ x += reshape(bias, (1, bias_shape[0], 1, 1))
else:
x += reshape(bias, (1, bias_shape[2]) + bias_shape[:2])
elif data_format == 'channels_last':
@@ -4113,7 +4296,7 @@ def ctc_batch_cost(y_true, y_pred, input_length, label_length):
sparse_labels = math_ops.to_int32(
ctc_label_dense_to_sparse(y_true, label_length))
- y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + 1e-8)
+ y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
return array_ops.expand_dims(
ctc.ctc_loss(
@@ -4148,7 +4331,7 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
Tensor `(top_paths, )` that contains
the log probability of each decoded sequence.
"""
- y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + 1e-8)
+ y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + epsilon())
input_length = math_ops.to_int32(input_length)
if greedy:
diff --git a/tensorflow/python/keras/_impl/keras/backend_test.py b/tensorflow/python/keras/_impl/keras/backend_test.py
index e34f1b6926..27833e368d 100644
--- a/tensorflow/python/keras/_impl/keras/backend_test.py
+++ b/tensorflow/python/keras/_impl/keras/backend_test.py
@@ -954,7 +954,6 @@ class BackendNNOpsTest(test.TestCase):
x = keras.backend.variable(val)
reduction_axes = (0, 2, 3)
- # case: need broadcasting
g_val = np.random.random((3,))
b_val = np.random.random((3,))
gamma = keras.backend.variable(g_val)
@@ -965,17 +964,6 @@ class BackendNNOpsTest(test.TestCase):
self.assertEqual(mean.get_shape().as_list(), [3,])
self.assertEqual(var.get_shape().as_list(), [3,])
- # case: doesn't need broadcasting
- g_val = np.random.random((1, 3, 1, 1))
- b_val = np.random.random((1, 3, 1, 1))
- gamma = keras.backend.variable(g_val)
- beta = keras.backend.variable(b_val)
- normed, mean, var = keras.backend.normalize_batch_in_training(
- x, gamma, beta, reduction_axes, epsilon=1e-3)
- self.assertEqual(normed.get_shape().as_list(), [10, 3, 10, 10])
- self.assertEqual(mean.get_shape().as_list(), [3,])
- self.assertEqual(var.get_shape().as_list(), [3,])
-
# case: gamma=None
gamma = None
normed, mean, var = keras.backend.normalize_batch_in_training(
diff --git a/tensorflow/python/keras/_impl/keras/callbacks.py b/tensorflow/python/keras/_impl/keras/callbacks.py
index 8da3b85718..f0d9e0b0f5 100644
--- a/tensorflow/python/keras/_impl/keras/callbacks.py
+++ b/tensorflow/python/keras/_impl/keras/callbacks.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Keras callbacks: utilities called at certain points during model training.
+# pylint: disable=g-import-not-at-top
+"""Callbacks: utilities called at certain points during model training.
"""
from __future__ import absolute_import
from __future__ import division
@@ -36,12 +37,10 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as tf_summary
-# pylint: disable=g-import-not-at-top
try:
import requests
except ImportError:
requests = None
-# pylint: enable=g-import-not-at-top
class CallbackList(object):
@@ -109,9 +108,9 @@ class CallbackList(object):
delta_t_median = np.median(self._delta_ts_batch_begin)
if (self._delta_t_batch > 0. and
delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1):
- logging.warning(
- 'Method on_batch_begin() is slow compared '
- 'to the batch update (%f). Check your callbacks.' % delta_t_median)
+ logging.warning('Method on_batch_begin() is slow compared '
+ 'to the batch update (%f). Check your callbacks.',
+ delta_t_median)
self._t_enter_batch = time.time()
def on_batch_end(self, batch, logs=None):
@@ -132,9 +131,9 @@ class CallbackList(object):
delta_t_median = np.median(self._delta_ts_batch_end)
if (self._delta_t_batch > 0. and
(delta_t_median > 0.95 * self._delta_t_batch and delta_t_median > 0.1)):
- logging.warning(
- 'Method on_batch_end() is slow compared '
- 'to the batch update (%f). Check your callbacks.' % delta_t_median)
+ logging.warning('Method on_batch_end() is slow compared '
+ 'to the batch update (%f). Check your callbacks.',
+ delta_t_median)
def on_train_begin(self, logs=None):
"""Called at the beginning of training.
@@ -246,7 +245,8 @@ class BaseLogger(Callback):
class TerminateOnNaN(Callback):
- """Callback that terminates training when a NaN loss is encountered."""
+ """Callback that terminates training when a NaN loss is encountered.
+ """
def __init__(self):
super(TerminateOnNaN, self).__init__()
@@ -396,7 +396,7 @@ class ModelCheckpoint(Callback):
if mode not in ['auto', 'min', 'max']:
logging.warning('ModelCheckpoint mode %s is unknown, '
- 'fallback to auto mode.' % mode)
+ 'fallback to auto mode.', (mode), RuntimeWarning)
mode = 'auto'
if mode == 'min':
@@ -423,11 +423,11 @@ class ModelCheckpoint(Callback):
current = logs.get(self.monitor)
if current is None:
logging.warning('Can save best model only with %s available, '
- 'skipping.' % (self.monitor))
+ 'skipping.', self.monitor, RuntimeWarning)
else:
if self.monitor_op(current, self.best):
if self.verbose > 0:
- print('Epoch %05d: %s improved from %0.5f to %0.5f,'
+ print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
' saving model to %s' % (epoch + 1, self.monitor, self.best,
current, filepath))
self.best = current
@@ -437,11 +437,11 @@ class ModelCheckpoint(Callback):
self.model.save(filepath, overwrite=True)
else:
if self.verbose > 0:
- print('Epoch %05d: %s did not improve' % (epoch + 1,
- self.monitor))
+ print('\nEpoch %05d: %s did not improve' % (epoch + 1,
+ self.monitor))
else:
if self.verbose > 0:
- print('Epoch %05d: saving model to %s' % (epoch + 1, filepath))
+ print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
if self.save_weights_only:
self.model.save_weights(filepath, overwrite=True)
else:
@@ -486,7 +486,7 @@ class EarlyStopping(Callback):
if mode not in ['auto', 'min', 'max']:
logging.warning('EarlyStopping mode %s is unknown, '
- 'fallback to auto mode.' % mode)
+ 'fallback to auto mode.', mode, RuntimeWarning)
mode = 'auto'
if mode == 'min':
@@ -514,8 +514,8 @@ class EarlyStopping(Callback):
current = logs.get(self.monitor)
if current is None:
logging.warning('Early stopping conditioned on metric `%s` '
- 'which is not available. Available metrics are: %s' %
- (self.monitor, ','.join(list(logs.keys()))))
+ 'which is not available. Available metrics are: %s',
+ self.monitor, ','.join(list(logs.keys())), RuntimeWarning)
return
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
@@ -544,8 +544,6 @@ class RemoteMonitor(Callback):
path: String; path relative to `root` to which the events will be sent.
field: String; JSON field under which the data will be stored.
headers: Dictionary; optional custom HTTP headers.
- Defaults to:
- `{'Accept': 'application/json', 'Content-Type': 'application/json'}`
"""
def __init__(self,
@@ -554,11 +552,7 @@ class RemoteMonitor(Callback):
field='data',
headers=None):
super(RemoteMonitor, self).__init__()
- if headers is None:
- headers = {
- 'Accept': 'application/json',
- 'Content-Type': 'application/json'
- }
+
self.root = root
self.path = path
self.field = field
@@ -588,11 +582,13 @@ class LearningRateScheduler(Callback):
schedule: a function that takes an epoch index as input
(integer, indexed from 0) and returns a new
learning rate as output (float).
+ verbose: int. 0: quiet, 1: update messages.
"""
- def __init__(self, schedule):
+ def __init__(self, schedule, verbose=0):
super(LearningRateScheduler, self).__init__()
self.schedule = schedule
+ self.verbose = verbose
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, 'lr'):
@@ -602,6 +598,9 @@ class LearningRateScheduler(Callback):
raise ValueError('The output of the "schedule" function '
'should be float.')
K.set_value(self.model.optimizer.lr, lr)
+ if self.verbose > 0:
+ print('\nEpoch %05d: LearningRateScheduler reducing learning '
+ 'rate to %s.' % (epoch + 1, lr))
class TensorBoard(Callback):
@@ -842,7 +841,7 @@ class ReduceLROnPlateau(Callback):
"""
if self.mode not in ['auto', 'min', 'max']:
logging.warning('Learning Rate Plateau Reducing mode %s is unknown, '
- 'fallback to auto mode.' % (self.mode))
+ 'fallback to auto mode.', self.mode, RuntimeWarning)
self.mode = 'auto'
if (self.mode == 'min' or
(self.mode == 'auto' and 'acc' not in self.monitor)):
@@ -853,7 +852,6 @@ class ReduceLROnPlateau(Callback):
self.best = -np.Inf
self.cooldown_counter = 0
self.wait = 0
- self.lr_epsilon = self.min_lr * 1e-4
def on_train_begin(self, logs=None):
self._reset()
@@ -864,8 +862,9 @@ class ReduceLROnPlateau(Callback):
current = logs.get(self.monitor)
if current is None:
logging.warning('Reduce LR on plateau conditioned on metric `%s` '
- 'which is not available. Available metrics are: %s' %
- (self.monitor, ','.join(list(logs.keys()))))
+ 'which is not available. Available metrics are: %s',
+ self.monitor, ','.join(list(logs.keys())), RuntimeWarning)
+
else:
if self.in_cooldown():
self.cooldown_counter -= 1
@@ -877,13 +876,13 @@ class ReduceLROnPlateau(Callback):
elif not self.in_cooldown():
if self.wait >= self.patience:
old_lr = float(K.get_value(self.model.optimizer.lr))
- if old_lr > self.min_lr + self.lr_epsilon:
+ if old_lr > self.min_lr:
new_lr = old_lr * self.factor
new_lr = max(new_lr, self.min_lr)
K.set_value(self.model.optimizer.lr, new_lr)
if self.verbose > 0:
- print('\nEpoch %05d: reducing learning rate to %s.' % (epoch,
- new_lr))
+ print('\nEpoch %05d: ReduceLROnPlateau reducing learning '
+ 'rate to %s.' % (epoch + 1, new_lr))
self.cooldown_counter = self.cooldown
self.wait = 0
self.wait += 1
@@ -899,10 +898,11 @@ class CSVLogger(Callback):
including 1D iterables such as np.ndarray.
Example:
- ```python
- csv_logger = CSVLogger('training.log')
- model.fit(X_train, Y_train, callbacks=[csv_logger])
- ```
+
+ ```python
+ csv_logger = CSVLogger('training.log')
+ model.fit(X_train, Y_train, callbacks=[csv_logger])
+ ```
Arguments:
filename: filename of the csv file, e.g. 'run/log.csv'.
@@ -942,12 +942,14 @@ class CSVLogger(Callback):
else:
return k
+ if self.keys is None:
+ self.keys = sorted(logs.keys())
+
if self.model.stop_training:
# We set NA so that csv parsers do not fail for this last epoch.
logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])
if not self.writer:
- self.keys = sorted(logs.keys())
class CustomDialect(csv.excel):
delimiter = self.sep
@@ -993,32 +995,32 @@ class LambdaCallback(Callback):
Example:
- ```python
- # Print the batch number at the beginning of every batch.
- batch_print_callback = LambdaCallback(
- on_batch_begin=lambda batch,logs: print(batch))
-
- # Stream the epoch loss to a file in JSON format. The file content
- # is not well-formed JSON but rather has a JSON object per line.
- import json
- json_log = open('loss_log.json', mode='wt', buffering=1)
- json_logging_callback = LambdaCallback(
- on_epoch_end=lambda epoch, logs: json_log.write(
- json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
- on_train_end=lambda logs: json_log.close()
- )
-
- # Terminate some processes after having finished model training.
- processes = ...
- cleanup_callback = LambdaCallback(
- on_train_end=lambda logs: [
- p.terminate() for p in processes if p.is_alive()])
-
- model.fit(...,
- callbacks=[batch_print_callback,
- json_logging_callback,
- cleanup_callback])
- ```
+ ```python
+ # Print the batch number at the beginning of every batch.
+ batch_print_callback = LambdaCallback(
+ on_batch_begin=lambda batch,logs: print(batch))
+
+ # Stream the epoch loss to a file in JSON format. The file content
+ # is not well-formed JSON but rather has a JSON object per line.
+ import json
+ json_log = open('loss_log.json', mode='wt', buffering=1)
+ json_logging_callback = LambdaCallback(
+ on_epoch_end=lambda epoch, logs: json_log.write(
+ json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
+ on_train_end=lambda logs: json_log.close()
+ )
+
+ # Terminate some processes after having finished model training.
+ processes = ...
+ cleanup_callback = LambdaCallback(
+ on_train_end=lambda logs: [
+ p.terminate() for p in processes if p.is_alive()])
+
+ model.fit(...,
+ callbacks=[batch_print_callback,
+ json_logging_callback,
+ cleanup_callback])
+ ```
"""
def __init__(self,
diff --git a/tensorflow/python/keras/_impl/keras/constraints.py b/tensorflow/python/keras/_impl/keras/constraints.py
index e58e3b0377..4b051c93f3 100644
--- a/tensorflow/python/keras/_impl/keras/constraints.py
+++ b/tensorflow/python/keras/_impl/keras/constraints.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Constraints: functions that impose constraints on weights values.
+# pylint: disable=invalid-name
+"""Constraints: functions that impose constraints on weight values.
"""
from __future__ import absolute_import
from __future__ import division
@@ -54,10 +55,6 @@ class MaxNorm(Constraint):
to constrain the weights of each filter tensor of size
`(rows, cols, input_depth)`.
- References:
- - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting
- Srivastava, Hinton, et al.
- 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
"""
def __init__(self, max_value=2, axis=0):
@@ -79,7 +76,7 @@ class NonNeg(Constraint):
"""
def __call__(self, w):
- w *= K.cast(w >= 0., K.floatx())
+ w *= K.cast(K.greater_equal(w, 0.), K.floatx())
return w
@@ -132,7 +129,7 @@ class MinMaxNorm(Constraint):
has shape `(input_dim, output_dim)`,
set `axis` to `0` to constrain each weight vector
of length `(input_dim,)`.
- In a `Conv2D` layer with `dim_ordering="channels_last"`,
+ In a `Conv2D` layer with `data_format="channels_last"`,
the weight tensor has shape
`(rows, cols, input_depth, output_depth)`,
set `axis` to `[0, 1, 2]`
@@ -148,8 +145,9 @@ class MinMaxNorm(Constraint):
def __call__(self, w):
norms = K.sqrt(K.sum(K.square(w), axis=self.axis, keepdims=True))
- desired = (self.rate * K.clip(norms, self.min_value, self.max_value) +
- (1 - self.rate) * norms)
+ desired = (
+ self.rate * K.clip(norms, self.min_value, self.max_value) +
+ (1 - self.rate) * norms)
w *= (desired / (K.epsilon() + norms))
return w
@@ -164,13 +162,15 @@ class MinMaxNorm(Constraint):
# Aliases.
-# pylint: disable=invalid-name
max_norm = MaxNorm
non_neg = NonNeg
unit_norm = UnitNorm
min_max_norm = MinMaxNorm
-# pylint: enable=invalid-name
+# Legacy aliases.
+maxnorm = max_norm
+nonneg = non_neg
+unitnorm = unit_norm
def serialize(constraint):
diff --git a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py
index 0570e9bc0c..cfd7df61d5 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py
@@ -21,29 +21,27 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
-from tensorflow.python.util.tf_export import tf_export
-@tf_export('keras.datasets.boston_housing.load_data')
-def load_data(path='boston_housing.npz', seed=113, test_split=0.2):
+def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
"""Loads the Boston Housing dataset.
Arguments:
path: path where to cache the dataset locally
(relative to ~/.keras/datasets).
+ test_split: fraction of the data to reserve as test set.
seed: Random seed for shuffling the data
before computing the test split.
- test_split: fraction of the data to reserve as test set.
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
"""
assert 0 <= test_split < 1
- fh = 'f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5'
path = get_file(
path,
origin='https://s3.amazonaws.com/keras-datasets/boston_housing.npz',
- file_hash=fh)
+ file_hash=
+ 'f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5')
f = np.load(path)
x = f['x']
y = f['y']
diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar.py b/tensorflow/python/keras/_impl/keras/datasets/cifar.py
index 564709c0ee..7ada3340a5 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/cifar.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/cifar.py
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Utilities used by the CIFAR10 and CIFAR100 datasets.
+"""Utilities common to CIFAR10 and CIFAR100 datasets.
"""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py
index 1971f434b9..fb9d98d42c 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/cifar10.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""CIFAR10 small image classification dataset.
+"""CIFAR10 small images classification dataset.
"""
from __future__ import absolute_import
from __future__ import division
@@ -25,10 +25,8 @@ import numpy as np
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.datasets.cifar import load_batch
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
-from tensorflow.python.util.tf_export import tf_export
-@tf_export('keras.datasets.cifar10.load_data')
def load_data():
"""Loads CIFAR10 dataset.
diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py
index f4039e9350..95aace599a 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/cifar100.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""CIFAR100 small image classification dataset.
+"""CIFAR100 small images classification dataset.
"""
from __future__ import absolute_import
from __future__ import division
@@ -25,10 +25,8 @@ import numpy as np
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.datasets.cifar import load_batch
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
-from tensorflow.python.util.tf_export import tf_export
-@tf_export('keras.datasets.cifar100.load_data')
def load_data(label_mode='fine'):
"""Loads CIFAR100 dataset.
@@ -42,7 +40,7 @@ def load_data(label_mode='fine'):
ValueError: in case of invalid `label_mode`.
"""
if label_mode not in ['fine', 'coarse']:
- raise ValueError('label_mode must be one of "fine" "coarse".')
+ raise ValueError('`label_mode` must be one of `"fine"`, `"coarse"`.')
dirname = 'cifar-100-python'
origin = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
diff --git a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py
index 17be684e4f..b9ae41a0d4 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/fashion_mnist.py
@@ -20,7 +20,9 @@ from __future__ import print_function
import gzip
import os
+
import numpy as np
+
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
@@ -38,9 +40,8 @@ def load_data():
]
paths = []
- for given_file in files:
- paths.append(
- get_file(given_file, origin=base + given_file, cache_subdir=dirname))
+ for fname in files:
+ paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname))
with gzip.open(paths[0], 'rb') as lbpath:
y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
diff --git a/tensorflow/python/keras/_impl/keras/datasets/imdb.py b/tensorflow/python/keras/_impl/keras/datasets/imdb.py
index 7946c46960..880c9c821b 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/imdb.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/imdb.py
@@ -1,4 +1,4 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""IMDB movie review sentiment classification dataset.
+"""IMDB sentiment classification dataset.
"""
from __future__ import absolute_import
from __future__ import division
@@ -21,13 +21,12 @@ from __future__ import print_function
import json
import numpy as np
-from six.moves import zip # pylint: disable=redefined-builtin
+from tensorflow.python.keras._impl.keras.preprocessing.sequence import _remove_long_seq
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.platform import tf_logging as logging
-@tf_export('keras.datasets.imdb.load_data')
def load_data(path='imdb.npz',
num_words=None,
skip_top=0,
@@ -35,7 +34,8 @@ def load_data(path='imdb.npz',
seed=113,
start_char=1,
oov_char=2,
- index_from=3):
+ index_from=3,
+ **kwargs):
"""Loads the IMDB dataset.
Arguments:
@@ -52,6 +52,7 @@ def load_data(path='imdb.npz',
oov_char: words that were cut out because of the `num_words`
or `skip_top` limit will be replaced with this character.
index_from: index actual words with this index and higher.
+ **kwargs: Used for backwards compatibility.
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -66,14 +67,21 @@ def load_data(path='imdb.npz',
Words that were not seen in the training set but are in the test set
have simply been skipped.
"""
+ # Legacy support
+ if 'nb_words' in kwargs:
+ logging.warning('The `nb_words` argument in `load_data` '
+ 'has been renamed `num_words`.')
+ num_words = kwargs.pop('nb_words')
+ if kwargs:
+ raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+
path = get_file(
path,
origin='https://s3.amazonaws.com/text-datasets/imdb.npz',
file_hash='599dadb1135973df5b59232a0e9a887c')
- f = np.load(path)
- x_train, labels_train = f['x_train'], f['y_train']
- x_test, labels_test = f['x_test'], f['y_test']
- f.close()
+ with np.load(path) as f:
+ x_train, labels_train = f['x_train'], f['y_train']
+ x_test, labels_test = f['x_test'], f['y_test']
np.random.seed(seed)
indices = np.arange(len(x_train))
@@ -95,14 +103,7 @@ def load_data(path='imdb.npz',
xs = [[w + index_from for w in x] for x in xs]
if maxlen:
- new_xs = []
- new_labels = []
- for x, y in zip(xs, labels):
- if len(x) < maxlen:
- new_xs.append(x)
- new_labels.append(y)
- xs = new_xs
- labels = new_labels
+ xs, labels = _remove_long_seq(maxlen, xs, labels)
if not xs:
raise ValueError('After filtering for sequences shorter than maxlen=' +
str(maxlen) + ', no sequence was kept. '
@@ -114,28 +115,19 @@ def load_data(path='imdb.npz',
# reserve 'index_from' (=3 by default) characters:
# 0 (padding), 1 (start), 2 (OOV)
if oov_char is not None:
- xs = [[oov_char if (w >= num_words or w < skip_top) else w for w in x]
- for x in xs]
+ xs = [
+ [w if (skip_top <= w < num_words) else oov_char for w in x] for x in xs
+ ]
else:
- new_xs = []
- for x in xs:
- nx = []
- for w in x:
- if skip_top <= w < num_words:
- nx.append(w)
- new_xs.append(nx)
- xs = new_xs
-
- x_train = np.array(xs[:len(x_train)])
- y_train = np.array(labels[:len(x_train)])
+ xs = [[w for w in x if skip_top <= w < num_words] for x in xs]
- x_test = np.array(xs[len(x_train):])
- y_test = np.array(labels[len(x_train):])
+ idx = len(x_train)
+ x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
+ x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])
return (x_train, y_train), (x_test, y_test)
-@tf_export('keras.datasets.imdb.get_word_index')
def get_word_index(path='imdb_word_index.json'):
"""Retrieves the dictionary mapping word indices back to words.
@@ -147,7 +139,8 @@ def get_word_index(path='imdb_word_index.json'):
"""
path = get_file(
path,
- origin='https://s3.amazonaws.com/text-datasets/imdb_word_index.json')
+ origin='https://s3.amazonaws.com/text-datasets/imdb_word_index.json',
+ file_hash='bfafd718b763782e994055a2d397834f')
f = open(path)
data = json.load(f)
f.close()
diff --git a/tensorflow/python/keras/_impl/keras/datasets/mnist.py b/tensorflow/python/keras/_impl/keras/datasets/mnist.py
index e9f5348015..ec12a31dcf 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/mnist.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/mnist.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""MNIST handwritten digits classification dataset.
+"""MNIST handwritten digits dataset.
"""
from __future__ import absolute_import
from __future__ import division
@@ -21,10 +21,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
-from tensorflow.python.util.tf_export import tf_export
-@tf_export('keras.datasets.mnist.load_data')
def load_data(path='mnist.npz'):
"""Loads the MNIST dataset.
@@ -40,9 +38,7 @@ def load_data(path='mnist.npz'):
origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
file_hash='8a61469f7ea1b51cbae51d4f78837e45')
f = np.load(path)
- x_train = f['x_train']
- y_train = f['y_train']
- x_test = f['x_test']
- y_test = f['y_test']
+ x_train, y_train = f['x_train'], f['y_train']
+ x_test, y_test = f['x_test'], f['y_test']
f.close()
return (x_train, y_train), (x_test, y_test)
diff --git a/tensorflow/python/keras/_impl/keras/datasets/reuters.py b/tensorflow/python/keras/_impl/keras/datasets/reuters.py
index 6da5aa4b5e..95cf8852a9 100644
--- a/tensorflow/python/keras/_impl/keras/datasets/reuters.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/reuters.py
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Reuters newswire topic classification dataset.
+"""Reuters topic classification dataset.
"""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -22,13 +21,12 @@ from __future__ import print_function
import json
import numpy as np
-from six.moves import zip # pylint: disable=redefined-builtin
+from tensorflow.python.keras._impl.keras.preprocessing.sequence import _remove_long_seq
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.platform import tf_logging as logging
-@tf_export('keras.datasets.reuters.load_data')
def load_data(path='reuters.npz',
num_words=None,
skip_top=0,
@@ -37,7 +35,8 @@ def load_data(path='reuters.npz',
seed=113,
start_char=1,
oov_char=2,
- index_from=3):
+ index_from=3,
+ **kwargs):
"""Loads the Reuters newswire classification dataset.
Arguments:
@@ -55,6 +54,7 @@ def load_data(path='reuters.npz',
oov_char: words that were cut out because of the `num_words`
or `skip_top` limit will be replaced with this character.
index_from: index actual words with this index and higher.
+ **kwargs: Used for backwards compatibility.
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
@@ -65,14 +65,20 @@ def load_data(path='reuters.npz',
Words that were not seen in the training set but are in the test set
have simply been skipped.
"""
+ # Legacy support
+ if 'nb_words' in kwargs:
+ logging.warning('The `nb_words` argument in `load_data` '
+ 'has been renamed `num_words`.')
+ num_words = kwargs.pop('nb_words')
+ if kwargs:
+ raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+
path = get_file(
path,
origin='https://s3.amazonaws.com/text-datasets/reuters.npz',
file_hash='87aedbeb0cb229e378797a632c1997b6')
- npzfile = np.load(path)
- xs = npzfile['x']
- labels = npzfile['y']
- npzfile.close()
+ with np.load(path) as f:
+ xs, labels = f['x'], f['y']
np.random.seed(seed)
indices = np.arange(len(xs))
@@ -80,22 +86,13 @@ def load_data(path='reuters.npz',
xs = xs[indices]
labels = labels[indices]
- np.random.shuffle(labels)
-
if start_char is not None:
xs = [[start_char] + [w + index_from for w in x] for x in xs]
elif index_from:
xs = [[w + index_from for w in x] for x in xs]
if maxlen:
- new_xs = []
- new_labels = []
- for x, y in zip(xs, labels):
- if len(x) < maxlen:
- new_xs.append(x)
- new_labels.append(y)
- xs = new_xs
- labels = new_labels
+ xs, labels = _remove_long_seq(maxlen, xs, labels)
if not num_words:
num_words = max([max(x) for x in xs])
@@ -104,28 +101,17 @@ def load_data(path='reuters.npz',
# reserve 'index_from' (=3 by default) characters:
# 0 (padding), 1 (start), 2 (OOV)
if oov_char is not None:
- xs = [[oov_char if (w >= num_words or w < skip_top) else w for w in x]
- for x in xs]
+ xs = [[w if skip_top <= w < num_words else oov_char for w in x] for x in xs]
else:
- new_xs = []
- for x in xs:
- nx = []
- for w in x:
- if skip_top <= w < num_words:
- nx.append(w)
- new_xs.append(nx)
- xs = new_xs
-
- x_train = np.array(xs[:int(len(xs) * (1 - test_split))])
- y_train = np.array(labels[:int(len(xs) * (1 - test_split))])
+ xs = [[w for w in x if skip_top <= w < num_words] for x in xs]
- x_test = np.array(xs[int(len(xs) * (1 - test_split)):])
- y_test = np.array(labels[int(len(xs) * (1 - test_split)):])
+ idx = int(len(xs) * (1 - test_split))
+ x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
+ x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])
return (x_train, y_train), (x_test, y_test)
-@tf_export('keras.datasets.reuters.get_word_index')
def get_word_index(path='reuters_word_index.json'):
"""Retrieves the dictionary mapping word indices back to words.
diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py
index d6e0be8e43..64aa868f38 100644
--- a/tensorflow/python/keras/_impl/keras/engine/topology.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology.py
@@ -27,6 +27,7 @@ import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
@@ -712,8 +713,8 @@ class Network(tf_network.GraphNetwork, Layer):
for layer in self._output_layers:
self.output_names.append(layer.name)
- self.internal_input_shapes = [K.int_shape(x) for x in self.inputs]
- self.internal_output_shapes = [K.int_shape(x) for x in self.outputs]
+ self._internal_input_shapes = [K.int_shape(x) for x in self.inputs]
+ self._internal_output_shapes = [K.int_shape(x) for x in self.outputs]
@property
def uses_learning_phase(self):
@@ -1303,18 +1304,17 @@ def preprocess_weights_for_loading(layer,
Returns:
A list of weights values (Numpy arrays).
"""
- if original_keras_version == '1':
- if layer.__class__.__name__ == 'Bidirectional':
- num_weights_per_layer = len(weights) // 2
-
- forward_weights = preprocess_weights_for_loading(
- layer.forward_layer, weights[:num_weights_per_layer],
- original_keras_version, original_backend)
- backward_weights = preprocess_weights_for_loading(
- layer.backward_layer, weights[num_weights_per_layer:],
- original_keras_version, original_backend)
- weights = forward_weights + backward_weights
+ if layer.__class__.__name__ == 'Bidirectional':
+ num_weights_per_layer = len(weights) // 2
+ forward_weights = preprocess_weights_for_loading(
+ layer.forward_layer, weights[:num_weights_per_layer],
+ original_keras_version, original_backend)
+ backward_weights = preprocess_weights_for_loading(
+ layer.backward_layer, weights[num_weights_per_layer:],
+ original_keras_version, original_backend)
+ weights = forward_weights + backward_weights
+ if original_keras_version == '1':
if layer.__class__.__name__ == 'TimeDistributed':
weights = preprocess_weights_for_loading(
layer.layer, weights, original_keras_version, original_backend)
@@ -1418,7 +1418,7 @@ def preprocess_weights_for_loading(layer,
conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D']
if layer.__class__.__name__ in conv_layers:
- if original_backend and K.backend() != original_backend:
+ if original_backend == 'theano':
weights[0] = conv_utils.convert_kernel(weights[0])
if layer.__class__.__name__ == 'ConvLSTM2D':
weights[1] = conv_utils.convert_kernel(weights[1])
@@ -1427,10 +1427,9 @@ def preprocess_weights_for_loading(layer,
if layer.__class__.__name__ == 'ConvLSTM2D':
weights[1] = np.transpose(weights[1], (3, 2, 0, 1))
- # convert the weights of CuDNNLSTM so that they could be loaded into LSTM
+ # Convert the weights of CuDNNLSTM so that they could be loaded into LSTM
if layer.__class__.__name__ == 'LSTM' and len(weights) == 3:
- # determine if we're loading a CuDNNLSTM layer from the number of bias
- # weights:
+ # Determine if loading a CuDNNLSTM layer from the number of bias weights:
# CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4)
# if there's no bias weight in the file, skip this conversion
units = weights[1].shape[0]
@@ -1572,3 +1571,31 @@ def load_weights_from_hdf5_group_by_name(f, layers):
for i in range(len(weight_values)):
weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
K.batch_set_value(weight_value_tuples)
+
+
+def shape_type_conversion(fn):
+ """Decorator that handles tuple/TensorShape conversion.
+
+ Used in `compute_output_shape` and `build`.
+
+ Arguments:
+ fn: function to wrap.
+
+ Returns:
+ Wrapped function.
+ """
+
+ def wrapper(instance, input_shape):
+ if input_shape is not None:
+ if isinstance(input_shape, list):
+ input_shape = [
+ tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape]
+ else:
+ input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
+ output_shape = fn(instance, input_shape)
+ if output_shape is not None:
+ if isinstance(output_shape, list):
+ return [tensor_shape.TensorShape(x) for x in output_shape]
+ return tensor_shape.TensorShape(output_shape)
+
+ return wrapper
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index debea2503e..699ae2edf0 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Keras training and evaluation routines.
+"""Training-related part of the Keras engine.
"""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -35,6 +34,11 @@ from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence
from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
from tensorflow.python.platform import tf_logging as logging
+try:
+ from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
+except ImportError:
+ issparse = None
+
def _standardize_input_data(data,
names,
@@ -70,89 +74,72 @@ def _standardize_input_data(data,
return []
if data is None:
return [None for _ in range(len(names))]
+
if isinstance(data, dict):
- for key, value in data.items():
- if value.__class__.__name__ == 'DataFrame':
- data[key] = value.values
- arrays = []
- for name in names:
- if name not in data:
- raise ValueError('No data provided for "' + name +
- '". Need data for each key in: ' + str(names))
- arrays.append(data[name])
+ try:
+ data = [
+ data[x].values
+ if data[x].__class__.__name__ == 'DataFrame' else data[x]
+ for x in names
+ ]
+ data = [np.expand_dims(x, 1) if x.ndim == 1 else x for x in data]
+ except KeyError as e:
+ raise ValueError('No data provided for "' + e.args[0] + '". Need data '
+ 'for each key in: ' + str(names))
elif isinstance(data, list):
- for key, value in enumerate(data):
- if value.__class__.__name__ == 'DataFrame':
- data[key] = value.values
- if len(data) != len(names):
- if data and hasattr(data[0], 'shape'):
- raise ValueError(
- 'Error when checking model ' + exception_prefix +
- ': the list of Numpy arrays '
- 'that you are passing to your model '
- 'is not the size the model expected. '
- 'Expected to see ' + str(len(names)) + ' array(s), but instead got '
- 'the following list of ' + str(len(data)) + ' arrays: ' +
- str(data)[:200] + '...')
- else:
- if len(names) == 1:
- data = [np.asarray(data)]
- else:
- raise ValueError('Error when checking model ' + exception_prefix +
- ': you are passing a list as '
- 'input to your model, '
- 'but the model expects '
- 'a list of ' + str(len(names)) +
- ' Numpy arrays instead. '
- 'The list you passed was: ' + str(data)[:200])
- arrays = data
- elif data.__class__.__name__ == 'DataFrame':
- # test if data is a DataFrame, without pandas installed
- arrays = data.values
+ data = [
+ x.values if x.__class__.__name__ == 'DataFrame' else x for x in data
+ ]
+ data = [
+ np.expand_dims(x, 1) if x is not None and x.ndim == 1 else x
+ for x in data
+ ]
else:
- if not hasattr(data, 'shape'):
+ data = data.values if data.__class__.__name__ == 'DataFrame' else data
+ data = [np.expand_dims(data, 1)] if data.ndim == 1 else [data]
+
+ if len(data) != len(names):
+ if data and hasattr(data[0], 'shape'):
+ raise ValueError('Error when checking model ' + exception_prefix +
+ ': the list of Numpy arrays that you are passing to '
+ 'your model is not the size the model expected. '
+ 'Expected to see ' + str(len(names)) + ' array(s), '
+ 'but instead got the following list of ' +
+ str(len(data)) + ' arrays: ' + str(data)[:200] + '...')
+ elif len(names) > 1:
+ raise ValueError(
+ 'Error when checking model ' + exception_prefix +
+ ': you are passing a list as input to your model, '
+ 'but the model expects a list of ' + str(len(names)) +
+ ' Numpy arrays instead. The list you passed was: ' + str(data)[:200])
+ elif len(data) == 1 and not hasattr(data[0], 'shape'):
raise TypeError('Error when checking model ' + exception_prefix +
- ': data should be a Numpy array, '
- 'or list/dict of Numpy arrays. '
- 'Found: ' + str(data)[:200] + '...')
- if len(names) > 1:
- # Case: model expects multiple inputs but only received
- # a single Numpy array.
- raise ValueError('The model expects ' + str(len(names)) + ' ' +
- exception_prefix +
- ' arrays, but only received one array. '
- 'Found: array with shape ' + str(data.shape))
- arrays = [data]
-
- # Make arrays at least 2D.
- for i in range(len(names)):
- array = arrays[i]
- if len(array.shape) == 1:
- array = np.expand_dims(array, 1)
- arrays[i] = array
+ ': data should be a Numpy array, or list/dict of '
+ 'Numpy arrays. Found: ' + str(data)[:200] + '...')
+ elif len(names) == 1:
+ data = [np.asarray(data)]
# Check shapes compatibility.
if shapes:
for i in range(len(names)):
- if shapes[i] is None:
- continue
- array = arrays[i]
- if len(array.shape) != len(shapes[i]):
- raise ValueError(
- 'Error when checking ' + exception_prefix + ': expected ' + names[i]
- + ' to have ' + str(len(shapes[i])) +
- ' dimensions, but got array with shape ' + str(array.shape))
- for j, (dim, ref_dim) in enumerate(zip(array.shape, shapes[i])):
- if not j and not check_batch_axis:
- # skip the first axis
- continue
- if ref_dim:
- if ref_dim != dim:
- raise ValueError('Error when checking ' + exception_prefix +
- ': expected ' + names[i] + ' to have shape ' +
- str(shapes[i]) + ' but got array with shape ' +
- str(array.shape))
- return arrays
+ if shapes[i] is not None:
+ data_shape = data[i].shape
+ shape = shapes[i]
+ if data[i].ndim != len(shape):
+ raise ValueError('Error when checking ' + exception_prefix +
+ ': expected ' + names[i] + ' to have ' +
+ str(len(shape)) + ' dimensions, but got array '
+ 'with shape ' + str(data_shape))
+ if not check_batch_axis:
+ data_shape = data_shape[1:]
+ shape = shape[1:]
+ for dim, ref_dim in zip(data_shape, shape):
+ if ref_dim != dim and ref_dim:
+ raise ValueError(
+ 'Error when checking ' + exception_prefix + ': expected ' +
+ names[i] + ' to have shape ' + str(shape) +
+ ' but got array with shape ' + str(data_shape))
+ return data
def _standardize_sample_or_class_weights(x_weight, output_names, weight_type):
@@ -193,10 +180,10 @@ def _standardize_sample_or_class_weights(x_weight, output_names, weight_type):
x_weights.append(x_weight.get(name))
return x_weights
else:
- raise TypeError('The model has multiple outputs, so `' + weight_type + '` '
- 'should be either a list or a dict. '
- 'Provided `' + weight_type + '` type not understood: ' +
- str(x_weight))
+ raise TypeError(
+ 'The model has multiple outputs, so `' + weight_type + '` '
+ 'should be either a list or a dict. '
+ 'Provided `' + weight_type + '` type not understood: ' + str(x_weight))
def _standardize_class_weights(class_weight, output_names):
@@ -234,12 +221,12 @@ def _check_array_lengths(inputs, targets, weights=None):
set_w = set_of_lengths(weights)
if len(set_x) > 1:
raise ValueError('All input arrays (x) should have '
- 'the same number of samples. Got array shapes: ' + str(
- [x.shape for x in inputs]))
+ 'the same number of samples. Got array shapes: ' +
+ str([x.shape for x in inputs]))
if len(set_y) > 1:
raise ValueError('All target arrays (y) should have '
- 'the same number of samples. Got array shapes: ' + str(
- [y.shape for y in targets]))
+ 'the same number of samples. Got array shapes: ' +
+ str([y.shape for y in targets]))
if set_x and set_y and list(set_x)[0] != list(set_y)[0]:
raise ValueError('Input arrays should have '
'the same number of samples as target arrays. '
@@ -247,8 +234,8 @@ def _check_array_lengths(inputs, targets, weights=None):
'and ' + str(list(set_y)[0]) + ' target samples.')
if len(set_w) > 1:
raise ValueError('All sample_weight arrays should have '
- 'the same number of samples. Got array shapes: ' + str(
- [w.shape for w in weights]))
+ 'the same number of samples. Got array shapes: ' +
+ str([w.shape for w in weights]))
if set_y and set_w and list(set_y)[0] != list(set_w)[0]:
raise ValueError('Sample_weight arrays should have '
'the same number of samples as target arrays. Got ' +
@@ -528,16 +515,16 @@ def _standardize_weights(y,
if sample_weight is not None:
if len(sample_weight.shape) > len(y.shape):
- raise ValueError('Found a sample_weight with shape' +
- str(sample_weight.shape) + '.'
- 'Expected sample_weight with rank '
- 'less than or equal to ' + str(len(y.shape)))
+ raise ValueError(
+ 'Found a sample_weight with shape' + str(sample_weight.shape) + '.'
+ 'Expected sample_weight with rank '
+ 'less than or equal to ' + str(len(y.shape)))
if y.shape[:sample_weight.ndim] != sample_weight.shape:
- raise ValueError('Found a sample_weight array with shape ' +
- str(sample_weight.shape) + ' for an input with shape ' +
- str(y.shape) + '. '
- 'sample_weight cannot be broadcast.')
+ raise ValueError(
+ 'Found a sample_weight array with shape ' + str(sample_weight.shape) +
+ ' for an input with shape ' + str(y.shape) + '. '
+ 'sample_weight cannot be broadcast.')
return sample_weight
elif isinstance(class_weight, dict):
if len(y.shape) > 2:
@@ -632,7 +619,6 @@ class Model(Network):
"""
loss = loss or {}
self.optimizer = optimizers.get(optimizer)
- self.sample_weight_mode = sample_weight_mode
self.loss = loss
self.loss_weights = loss_weights
self.sample_weight_mode = sample_weight_mode
@@ -641,10 +627,10 @@ class Model(Network):
if isinstance(loss, dict):
for name in loss:
if name not in self.output_names:
- raise ValueError('Unknown entry in loss '
- 'dictionary: "' + name + '". '
- 'Only expected the following keys: ' +
- str(self.output_names))
+ raise ValueError(
+ 'Unknown entry in loss '
+ 'dictionary: "' + name + '". '
+ 'Only expected the following keys: ' + str(self.output_names))
loss_functions = []
for name in self.output_names:
if name not in loss:
@@ -657,7 +643,7 @@ class Model(Network):
elif isinstance(loss, list):
if len(loss) != len(self.outputs):
raise ValueError('When passing a list as loss, '
- 'it should have one entry per model output. '
+ 'it should have one entry per model outputs. '
'The model has ' + str(len(self.outputs)) +
' outputs, but you passed loss=' + str(loss))
loss_functions = [losses.get(l) for l in loss]
@@ -690,20 +676,20 @@ class Model(Network):
elif isinstance(loss_weights, dict):
for name in loss_weights:
if name not in self.output_names:
- raise ValueError('Unknown entry in loss_weights '
- 'dictionary: "' + name + '". '
- 'Only expected the following keys: ' +
- str(self.output_names))
+ raise ValueError(
+ 'Unknown entry in loss_weights '
+ 'dictionary: "' + name + '". '
+ 'Only expected the following keys: ' + str(self.output_names))
loss_weights_list = []
for name in self.output_names:
loss_weights_list.append(loss_weights.get(name, 1.))
elif isinstance(loss_weights, list):
if len(loss_weights) != len(self.outputs):
- raise ValueError('When passing a list as loss_weights, '
- 'it should have one entry per model output. '
- 'The model has ' + str(len(self.outputs)) +
- ' outputs, but you passed loss_weights=' +
- str(loss_weights))
+ raise ValueError(
+ 'When passing a list as loss_weights, '
+ 'it should have one entry per model output. '
+ 'The model has ' + str(len(self.outputs)) +
+ ' outputs, but you passed loss_weights=' + str(loss_weights))
loss_weights_list = loss_weights
else:
raise TypeError('Could not interpret loss_weights argument: ' +
@@ -715,22 +701,22 @@ class Model(Network):
if target_tensors is not None:
if isinstance(target_tensors, list):
if len(target_tensors) != len(self.outputs):
- raise ValueError('When passing a list as `target_tensors`, '
- 'it should have one entry per model output. '
- 'The model has ' + str(len(self.outputs)) +
- ' outputs, but you passed target_tensors=' +
- str(target_tensors))
+ raise ValueError(
+ 'When passing a list as `target_tensors`, '
+ 'it should have one entry per model output. '
+ 'The model has ' + str(len(self.outputs)) +
+ ' outputs, but you passed target_tensors=' + str(target_tensors))
elif isinstance(target_tensors, dict):
for name in target_tensors:
if name not in self.output_names:
- raise ValueError('Unknown entry in `target_tensors` '
- 'dictionary: "' + name + '". '
- 'Only expected the following keys: ' +
- str(self.output_names))
- target_tensors_ = []
+ raise ValueError(
+ 'Unknown entry in `target_tensors` '
+ 'dictionary: "' + name + '". '
+ 'Only expected the following keys: ' + str(self.output_names))
+ tmp_target_tensors = []
for name in self.output_names:
- target_tensors_.append(target_tensors.get(name, None))
- target_tensors = target_tensors_
+ tmp_target_tensors.append(target_tensors.get(name, None))
+ target_tensors = tmp_target_tensors
else:
raise TypeError('Expected `target_tensors` to be '
'a list or dict, but got:', target_tensors)
@@ -738,7 +724,7 @@ class Model(Network):
if i in skip_target_indices:
self.targets.append(None)
else:
- shape = self.internal_output_shapes[i]
+ shape = self._internal_output_shapes[i]
name = self.output_names[i]
if target_tensors is not None:
target = target_tensors[i]
@@ -766,19 +752,19 @@ class Model(Network):
if isinstance(sample_weight_mode, dict):
for name in sample_weight_mode:
if name not in self.output_names:
- raise ValueError('Unknown entry in '
- 'sample_weight_mode dictionary: "' + name + '". '
- 'Only expected the following keys: ' +
- str(self.output_names))
+ raise ValueError(
+ 'Unknown entry in '
+ 'sample_weight_mode dictionary: "' + name + '". '
+ 'Only expected the following keys: ' + str(self.output_names))
for i, name in enumerate(self.output_names):
if i in skip_target_weighing_indices:
weight = None
sample_weight_modes.append(None)
else:
if name not in sample_weight_mode:
- raise ValueError('Output "' + name +
- '" missing from sample_weight_modes '
- 'dictionary')
+ raise ValueError(
+ 'Output "' + name + '" missing from sample_weight_modes '
+ 'dictionary')
if sample_weight_mode.get(name) == 'temporal':
weight = K.placeholder(ndim=2, name=name + '_sample_weights')
sample_weight_modes.append('temporal')
@@ -894,23 +880,36 @@ class Model(Network):
metric_name_prefix = 'weighted_' if weights is not None else ''
for metric in metrics:
- if metric == 'accuracy' or metric == 'acc':
- # custom handling of accuracy
+ if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
+ # custom handling of accuracy/crossentropy
# (because of class mode duality)
- output_shape = self.internal_output_shapes[i]
+ output_shape = self._internal_output_shapes[i]
if (output_shape[-1] == 1 or
self.loss_functions[i] == losses.binary_crossentropy):
- # case: binary accuracy
- acc_fn = metrics_module.binary_accuracy
+ # case: binary accuracy/crossentropy
+ if metric in ('accuracy', 'acc'):
+ acc_fn = metrics_module.binary_accuracy
+ elif metric in ('crossentropy', 'ce'):
+ acc_fn = metrics_module.binary_crossentropy
elif self.loss_functions[
i] == losses.sparse_categorical_crossentropy:
- # case: categorical accuracy with sparse targets
- acc_fn = metrics_module.sparse_categorical_accuracy
+ # case: categorical accuracy/crossentropy with sparse targets
+ if metric in ('accuracy', 'acc'):
+ acc_fn = metrics_module.sparse_categorical_accuracy
+ elif metric in ('crossentropy', 'ce'):
+ acc_fn = metrics_module.sparse_categorical_crossentropy
else:
- acc_fn = metrics_module.categorical_accuracy
-
+ # case: categorical accuracy/crossentropy
+ if metric in ('accuracy', 'acc'):
+ acc_fn = metrics_module.categorical_accuracy
+ elif metric in ('crossentropy', 'ce'):
+ acc_fn = metrics_module.categorical_crossentropy
+ if metric in ('accuracy', 'acc'):
+ suffix = 'acc'
+ elif metric in ('crossentropy', 'ce'):
+ suffix = 'ce'
weighted_metric_fn = _weighted_masked_objective(acc_fn)
- metric_name = metric_name_prefix + 'acc'
+ metric_name = metric_name_prefix + suffix
else:
metric_fn = metrics_module.get(metric)
weighted_metric_fn = _weighted_masked_objective(metric_fn)
@@ -949,7 +948,7 @@ class Model(Network):
"""Check trainable weights count consistency.
This will raise a warning if `trainable_weights` and
- `_collected_trainable_weights` are consistent (i.e. have the same
+ `_collected_trainable_weights` are inconsistent (i.e. have different
number of parameters).
Inconsistency will typically arise when one modifies `model.trainable`
without calling `model.compile` again.
@@ -959,9 +958,10 @@ class Model(Network):
if len(self.trainable_weights) != len(self._collected_trainable_weights):
logging.warning(
- 'Discrepancy between trainable weights and collected trainable'
- ' weights, did you set `model.trainable` without calling'
- ' `model.compile` after ?')
+ UserWarning(
+ 'Discrepancy between trainable weights and collected trainable'
+ ' weights, did you set `model.trainable` without calling'
+ ' `model.compile` after ?'))
def _make_train_function(self):
if not hasattr(self, 'train_function'):
@@ -1050,18 +1050,21 @@ class Model(Network):
processed based on the size of the first dimension of the
first input numpy array. When steps is not `None` and
`batch_size` is `None`, returns `None`.
+
+ Raises:
+ ValueError: In case of invalid arguments.
"""
if steps is not None:
num_samples = None
if batch_size is not None:
- raise ValueError('If ' + steps_name +
- ' is set, the `batch_size` must be None.')
+ raise ValueError(
+ 'If ' + steps_name + ' is set, the `batch_size` must be None.')
elif ins and hasattr(ins[0], 'shape'):
num_samples = ins[0].shape[0]
else:
- raise ValueError('Either the input data should have '
- 'a defined shape, or ' + steps_name +
- ' should be specified.')
+ raise ValueError(
+ 'Either the input data should have '
+ 'a defined shape, or ' + steps_name + ' should be specified.')
return num_samples
def _fit_loop(self,
@@ -1104,31 +1107,33 @@ class Model(Network):
steps_per_epoch: Total number of steps (batches of samples)
before declaring one epoch finished and starting the
next epoch. Ignored with the default value of `None`.
- validation_steps: Number of steps to run validation for (only if doing
- validation from data tensors). Ignored with default value of `None`.
+ validation_steps: Number of steps to run validation for
+ (only if doing validation from data tensors).
+ Ignored with the default value of `None`.
Returns:
`History` object.
Raises:
- ValueError: In case of invalid argument values.
+ ValueError: in case of invalid arguments.
"""
do_validation = False
if val_f and val_ins:
do_validation = True
- if (verbose and ins and
- hasattr(ins[0], 'shape') and hasattr(val_ins[0], 'shape')):
+ if verbose and ins and hasattr(ins[0], 'shape') and hasattr(
+ val_ins[0], 'shape'):
print('Train on %d samples, validate on %d samples' %
(ins[0].shape[0], val_ins[0].shape[0]))
if validation_steps:
- if steps_per_epoch is None:
- raise ValueError('Can only use `validation_steps` when doing step-wise '
- 'training, i.e. `steps_per_epoch` must be set.')
do_validation = True
+ if steps_per_epoch is None:
+ raise ValueError('Can only use `validation_steps` '
+ 'when doing step-wise '
+ 'training, i.e. `steps_per_epoch` '
+ 'must be set.')
num_train_samples = self._check_num_samples(
ins, batch_size, steps_per_epoch, 'steps_per_epoch')
-
if num_train_samples is not None:
index_array = np.arange(num_train_samples)
@@ -1165,6 +1170,13 @@ class Model(Network):
for cbk in callbacks:
cbk.validation_data = val_ins
+ # To prevent a slowdown, we find beforehand the arrays that need conversion.
+ feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights
+ indices_for_conversion_to_dense = []
+ for i in range(len(feed)):
+ if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]):
+ indices_for_conversion_to_dense.append(i)
+
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
@@ -1220,6 +1232,9 @@ class Model(Network):
batch_logs['batch'] = batch_index
batch_logs['size'] = len(batch_ids)
callbacks.on_batch_begin(batch_index, batch_logs)
+ for i in indices_for_conversion_to_dense:
+ ins_batch[i] = ins_batch[i].toarray()
+
outs = f(ins_batch)
if not isinstance(outs, list):
outs = [outs]
@@ -1268,6 +1283,13 @@ class Model(Network):
progbar = Progbar(target=steps)
else:
progbar = Progbar(target=num_samples)
+
+ indices_for_conversion_to_dense = []
+ for i in range(len(self._feed_inputs)):
+ if (issparse is not None and issparse(ins[i]) and
+ not K.is_sparse(self._feed_inputs[i])):
+ indices_for_conversion_to_dense.append(i)
+
if steps is not None:
# Step-based predictions.
# Since we do not know how many samples
@@ -1305,6 +1327,9 @@ class Model(Network):
ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
else:
ins_batch = _slice_arrays(ins, batch_ids)
+ for i in indices_for_conversion_to_dense:
+ ins_batch[i] = ins_batch[i].toarray()
+
batch_outs = f(ins_batch)
if not isinstance(batch_outs, list):
batch_outs = [batch_outs]
@@ -1341,12 +1366,19 @@ class Model(Network):
"""
num_samples = self._check_num_samples(ins, batch_size, steps, 'steps')
outs = []
-
if verbose == 1:
if steps is not None:
progbar = Progbar(target=steps)
else:
progbar = Progbar(target=num_samples)
+
+ # To prevent a slowdown, we find beforehand the arrays that need conversion.
+ feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights
+ indices_for_conversion_to_dense = []
+ for i in range(len(feed)):
+ if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]):
+ indices_for_conversion_to_dense.append(i)
+
if steps is not None:
for step in range(steps):
batch_outs = f(ins)
@@ -1365,8 +1397,6 @@ class Model(Network):
for i in range(len(outs)):
outs[i] /= steps
else:
- if verbose == 1:
- progbar = Progbar(target=num_samples)
batches = _make_batches(num_samples, batch_size)
index_array = np.arange(num_samples)
for batch_index, (batch_start, batch_end) in enumerate(batches):
@@ -1376,6 +1406,8 @@ class Model(Network):
ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
else:
ins_batch = _slice_arrays(ins, batch_ids)
+ for i in indices_for_conversion_to_dense:
+ ins_batch[i] = ins_batch[i].toarray()
batch_outs = f(ins_batch)
if isinstance(batch_outs, list):
@@ -1484,7 +1516,8 @@ class Model(Network):
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
- validation_steps=None):
+ validation_steps=None,
+ **kwargs):
"""Trains the model for a fixed number of epochs (iterations on a dataset).
Arguments:
@@ -1501,10 +1534,9 @@ class Model(Network):
dictionary mapping output names to Numpy arrays.
`y` can be `None` (default) if feeding from
TensorFlow data tensors.
- Can be `None` (default) if feeding from framework-native tensors.
batch_size: Integer or `None`.
Number of samples per gradient update.
- If unspecified, it will default to 32.
+ If unspecified, `batch_size` will default to 32.
epochs: Integer. Number of epochs to train the model.
An epoch is an iteration over the entire `x` and `y`
data provided.
@@ -1513,7 +1545,7 @@ class Model(Network):
The model is not trained for a number of iterations
given by `epochs`, but merely until the epoch
of index `epochs` is reached.
- verbose: 0, 1, or 2. Verbosity mode.
+ verbose: Integer. 0, 1, or 2. Verbosity mode.
0 = silent, 1 = progress bar, 2 = one line per epoch.
callbacks: List of `keras.callbacks.Callback` instances.
List of callbacks to apply during training.
@@ -1530,7 +1562,7 @@ class Model(Network):
`(x_val, y_val, val_sample_weights)` on which to evaluate
the loss and any model metrics at the end of each epoch.
The model will not be trained on this data.
- This will override `validation_split`.
+ `validation_data` will override `validation_split`.
shuffle: Boolean (whether to shuffle the training data
before each epoch) or str (for 'batch').
'batch' is a special option for dealing with the
@@ -1553,17 +1585,20 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`.
- initial_epoch: Epoch at which to start training
+ initial_epoch: Integer.
+ Epoch at which to start training
(useful for resuming a previous training run).
- steps_per_epoch: Total number of steps (batches of samples)
+ steps_per_epoch: Integer or `None`.
+ Total number of steps (batches of samples)
before declaring one epoch finished and starting the
next epoch. When training with input tensors such as
TensorFlow data tensors, the default `None` is equal to
- the number of unique samples in your dataset divided by
+ the number of samples in your dataset divided by
the batch size, or 1 if that cannot be determined.
validation_steps: Only relevant if `steps_per_epoch`
is specified. Total number of steps (batches of samples)
to validate before stopping.
+ **kwargs: Used for backwards compatibility.
Returns:
A `History` object. Its `History.history` attribute is
@@ -1572,12 +1607,21 @@ class Model(Network):
and validation metrics values (if applicable).
Raises:
+ RuntimeError: If the model was never compiled.
ValueError: In case of mismatch between the provided input data
and what the model expects.
"""
# Backwards compatibility
if batch_size is None and steps_per_epoch is None:
batch_size = 32
+ # Legacy support
+ if 'nb_epoch' in kwargs:
+ logging.warning(
+ 'The `nb_epoch` argument in `fit` '
+ 'has been renamed `epochs`.')
+ epochs = kwargs.pop('nb_epoch')
+ if kwargs:
+ raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
if x is None and y is None and steps_per_epoch is None:
raise ValueError('If fitting from data tensors, '
'you should specify the `steps_per_epoch` '
@@ -1590,10 +1634,8 @@ class Model(Network):
class_weight=class_weight,
check_batch_axis=False,
batch_size=batch_size)
-
# Prepare validation data.
do_validation = False
- val_ins = []
if validation_data:
do_validation = True
if len(validation_data) == 2:
@@ -1657,8 +1699,9 @@ class Model(Network):
'val_' + n for n in out_labels
]
else:
- val_f = None
callback_metrics = copy.copy(out_labels)
+ val_f = None
+ val_ins = []
# Delegate logic to `_fit_loop`.
return self._fit_loop(
@@ -1694,14 +1737,14 @@ class Model(Network):
If input layers in the model are named, you can also pass a
dictionary mapping input names to Numpy arrays.
`x` can be `None` (default) if feeding from
- framework-native tensors (e.g. TensorFlow data tensors).
+ TensorFlow data tensors.
y: Numpy array of target (label) data
(if the model has a single output),
or list of Numpy arrays (if the model has multiple outputs).
If output layers in the model are named, you can also pass a
dictionary mapping output names to Numpy arrays.
`y` can be `None` (default) if feeding from
- framework-native tensors (e.g. TensorFlow data tensors).
+ TensorFlow data tensors.
batch_size: Integer or `None`.
Number of samples per evaluation step.
If unspecified, `batch_size` will default to 32.
@@ -1721,8 +1764,7 @@ class Model(Network):
steps: Integer or `None`.
Total number of steps (batches of samples)
before declaring the evaluation round finished.
- The default `None` is equal to the number of unique samples in
- your dataset divided by the batch size.
+ Ignored with the default value of `None`.
Returns:
Scalar test loss (if the model has a single output and no metrics)
@@ -1731,7 +1773,7 @@ class Model(Network):
the display labels for the scalar outputs.
Raises:
- ValueError: In case of invalid arguments.
+ ValueError: in case of invalid arguments.
"""
# Backwards compatibility.
if batch_size is None and steps is None:
@@ -1890,6 +1932,9 @@ class Model(Network):
or list of scalars (if the model has multiple outputs
and/or metrics). The attribute `model.metrics_names` will give you
the display labels for the scalar outputs.
+
+ Raises:
+ ValueError: in case of invalid arguments.
"""
x, y, sample_weights = self._standardize_user_data(
x, y, sample_weight=sample_weight, check_batch_axis=True)
@@ -1937,8 +1982,7 @@ class Model(Network):
workers=1,
use_multiprocessing=False,
shuffle=True,
- initial_epoch=0,
- **kwargs):
+ initial_epoch=0):
"""Fits the model on data yielded batch-by-batch by a Python generator.
The generator is run in parallel to the model, for efficiency.
@@ -1950,22 +1994,31 @@ class Model(Network):
using `use_multiprocessing=True`.
Arguments:
- generator: A generator or an instance of Sequence (keras.utils.Sequence)
- object in order to avoid duplicate data when using multiprocessing.
+ generator: A generator or an instance of `Sequence`
+ (`keras.utils.Sequence`)
+ object in order to avoid duplicate data
+ when using multiprocessing.
The output of the generator must be either
- - a tuple (inputs, targets)
- - a tuple (inputs, targets, sample_weights).
- All arrays should contain the same number of samples.
+ - a tuple `(inputs, targets)`
+ - a tuple `(inputs, targets, sample_weights)`.
+ This tuple (a single output of the generator) makes a single batch.
+ Therefore, all arrays in this tuple must have the same length (equal
+ to the size of this batch). Different batches may have different
+ sizes.
+ For example, the last batch of the epoch is commonly smaller than
+ the
+ others, if the size of the dataset is not divisible by the batch
+ size.
The generator is expected to loop over its data
indefinitely. An epoch finishes when `steps_per_epoch`
batches have been seen by the model.
steps_per_epoch: Total number of steps (batches of samples)
to yield from `generator` before declaring one epoch
finished and starting the next epoch. It should typically
- be equal to the number of unique samples of your dataset
+ be equal to the number of samples of your dataset
divided by the batch size.
Optional for `Sequence`: if unspecified, will use
- `len(generator)` as a number of steps.
+ the `len(generator)` as a number of steps.
epochs: Integer, total number of iterations on the data.
verbose: Verbosity mode, 0, 1, or 2.
callbacks: List of callbacks to be called during training.
@@ -1977,27 +2030,28 @@ class Model(Network):
is a generator. Total number of steps (batches of samples)
to yield from `generator` before stopping.
Optional for `Sequence`: if unspecified, will use
- `len(generator)` as a number of steps.
+ the `len(validation_data)` as a number of steps.
class_weight: Dictionary mapping class indices to a weight
for the class.
- max_queue_size: Maximum size for the generator queue.
+ max_queue_size: Integer. Maximum size for the generator queue.
+ If unspecified, `max_queue_size` will default to 10.
workers: Integer. Maximum number of processes to spin up
when using process based threading.
If unspecified, `workers` will default to 1. If 0, will
execute the generator on the main thread.
- use_multiprocessing: If True, use process based threading.
+ use_multiprocessing: Boolean. If True, use process based threading.
+ If unspecified, `workers` will default to False.
Note that because
this implementation relies on multiprocessing,
you should not pass
non picklable arguments to the generator
as they can't be passed
easily to children processes.
- shuffle: Whether to shuffle the data at the beginning of each
- epoch. Only used with instances of `Sequence`
- (`keras.utils.Sequence`).
+ shuffle: Whether to shuffle the order of the batches at
+ the beginning of each epoch. Only used with instances
+ of `Sequence` (keras.utils.Sequence).
initial_epoch: Epoch at which to start training
(useful for resuming a previous training run)
- **kwargs: support for legacy arguments.
Returns:
A `History` object.
@@ -2023,19 +2077,6 @@ class Model(Network):
ValueError: In case the generator yields
data in an invalid format.
"""
- # Legacy support
- if 'max_q_size' in kwargs:
- max_queue_size = kwargs.pop('max_q_size')
- logging.warning('The argument `max_q_size` has been renamed '
- '`max_queue_size`. Update your method calls accordingly.')
- if 'pickle_safe' in kwargs:
- use_multiprocessing = kwargs.pop('pickle_safe')
- logging.warning('The argument `pickle_safe` has been renamed '
- '`use_multiprocessing`. '
- 'Update your method calls accordingly.')
- if kwargs:
- raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
-
wait_time = 0.01 # in seconds
epoch = initial_epoch
@@ -2046,10 +2087,11 @@ class Model(Network):
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
- logging.warning('Using a generator with `use_multiprocessing=True`'
+ logging.warning(
+ UserWarning('Using a generator with `use_multiprocessing=True`'
' and multiple workers may duplicate your data.'
' Please consider using the`keras.utils.Sequence'
- ' class.')
+ ' class.'))
if steps_per_epoch is None:
if is_sequence:
steps_per_epoch = len(generator)
@@ -2098,26 +2140,47 @@ class Model(Network):
})
callbacks.on_train_begin()
- if do_validation and not val_gen:
- if len(validation_data) == 2:
- val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
- val_sample_weight = None
- elif len(validation_data) == 3:
- val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
- else:
- raise ValueError('`validation_data` should be a tuple '
- '`(val_x, val_y, val_sample_weight)` '
- 'or `(val_x, val_y)`. Found: ' + str(validation_data))
- val_x, val_y, val_sample_weights = self._standardize_user_data(
- val_x, val_y, val_sample_weight)
- val_data = val_x + val_y + val_sample_weights
- if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
- val_data += [0.]
- for cbk in callbacks:
- cbk.validation_data = val_data
enqueuer = None
+ val_enqueuer = None
try:
+ if do_validation:
+ if val_gen:
+ if workers > 0:
+ if isinstance(validation_data, Sequence):
+ val_enqueuer = OrderedEnqueuer(
+ validation_data, use_multiprocessing=use_multiprocessing)
+ if validation_steps is None:
+ validation_steps = len(validation_data)
+ else:
+ val_enqueuer = GeneratorEnqueuer(
+ validation_data,
+ use_multiprocessing=use_multiprocessing,
+ wait_time=wait_time)
+ val_enqueuer.start(workers=workers, max_queue_size=max_queue_size)
+ validation_generator = val_enqueuer.get()
+ else:
+ validation_generator = validation_data
+ else:
+ if len(validation_data) == 2:
+ val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
+ val_sample_weight = None
+ elif len(validation_data) == 3:
+ val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
+ else:
+ raise ValueError(
+ '`validation_data` should be a tuple '
+ '`(val_x, val_y, val_sample_weight)` '
+ 'or `(val_x, val_y)`. Found: ' + str(validation_data))
+ val_x, val_y, val_sample_weights = self._standardize_user_data(
+ val_x, val_y, val_sample_weight)
+ val_data = val_x + val_y + val_sample_weights
+ if self.uses_learning_phase and not isinstance(
+ K.learning_phase(), int):
+ val_data += [0.]
+ for cbk in callbacks:
+ cbk.validation_data = val_data
+
if workers > 0:
if is_sequence:
enqueuer = OrderedEnqueuer(
@@ -2135,6 +2198,8 @@ class Model(Network):
output_generator = generator
callback_model.stop_training = False
+ # Construct epoch logs.
+ epoch_logs = {}
while epoch < epochs:
callbacks.on_epoch_begin(epoch)
steps_done = 0
@@ -2178,8 +2243,6 @@ class Model(Network):
callbacks.on_batch_end(batch_index, batch_logs)
- # Construct epoch logs.
- epoch_logs = {}
batch_index += 1
steps_done += 1
@@ -2187,11 +2250,7 @@ class Model(Network):
if steps_done >= steps_per_epoch and do_validation:
if val_gen:
val_outs = self.evaluate_generator(
- validation_data,
- validation_steps,
- max_queue_size=max_queue_size,
- workers=workers,
- use_multiprocessing=use_multiprocessing)
+ validation_generator, validation_steps, workers=0)
else:
# No need for try/except because
# data has already been validated.
@@ -2216,8 +2275,12 @@ class Model(Network):
break
finally:
- if enqueuer is not None:
- enqueuer.stop()
+ try:
+ if enqueuer is not None:
+ enqueuer.stop()
+ finally:
+ if val_enqueuer is not None:
+ val_enqueuer.stop()
callbacks.on_train_end()
return self.history
@@ -2227,8 +2290,7 @@ class Model(Network):
steps=None,
max_queue_size=10,
workers=1,
- use_multiprocessing=False,
- **kwargs):
+ use_multiprocessing=False):
"""Evaluates the model on a data generator.
The generator should return the same kind of data
@@ -2256,7 +2318,6 @@ class Model(Network):
non picklable arguments to the generator
as they can't be passed
easily to children processes.
- **kwargs: support for legacy arguments.
Returns:
Scalar test loss (if the model has a single output and no metrics)
@@ -2265,22 +2326,12 @@ class Model(Network):
the display labels for the scalar outputs.
Raises:
+ ValueError: in case of invalid arguments.
+
+ Raises:
ValueError: In case the generator yields
data in an invalid format.
"""
- # Legacy support
- if 'max_q_size' in kwargs:
- max_queue_size = kwargs.pop('max_q_size')
- logging.warning('The argument `max_q_size` has been renamed '
- '`max_queue_size`. Update your method calls accordingly.')
- if 'pickle_safe' in kwargs:
- use_multiprocessing = kwargs.pop('pickle_safe')
- logging.warning('The argument `pickle_safe` has been renamed '
- '`use_multiprocessing`. '
- 'Update your method calls accordingly.')
- if kwargs:
- raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
-
self._make_test_function()
steps_done = 0
@@ -2289,10 +2340,11 @@ class Model(Network):
batch_sizes = []
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
- logging.warning('Using a generator with `use_multiprocessing=True`'
+ logging.warning(
+ UserWarning('Using a generator with `use_multiprocessing=True`'
' and multiple workers may duplicate your data.'
' Please consider using the`keras.utils.Sequence'
- ' class.')
+ ' class.'))
if steps is None:
if is_sequence:
steps = len(generator)
@@ -2368,8 +2420,7 @@ class Model(Network):
max_queue_size=10,
workers=1,
use_multiprocessing=False,
- verbose=0,
- **kwargs):
+ verbose=0):
"""Generates predictions for the input samples from a data generator.
The generator should return the same kind of data as accepted by
@@ -2377,9 +2428,9 @@ class Model(Network):
Arguments:
generator: Generator yielding batches of input samples
- or an instance of Sequence (keras.utils.Sequence)
- object in order to avoid duplicate data
- when using multiprocessing.
+ or an instance of Sequence (keras.utils.Sequence)
+ object in order to avoid duplicate data
+ when using multiprocessing.
steps: Total number of steps (batches of samples)
to yield from `generator` before stopping.
Optional for `Sequence`: if unspecified, will use
@@ -2397,7 +2448,6 @@ class Model(Network):
as they can't be passed
easily to children processes.
verbose: verbosity mode, 0 or 1.
- **kwargs: support for legacy arguments.
Returns:
Numpy array(s) of predictions.
@@ -2406,17 +2456,6 @@ class Model(Network):
ValueError: In case the generator yields
data in an invalid format.
"""
- # Legacy support
- if 'max_q_size' in kwargs:
- max_queue_size = kwargs.pop('max_q_size')
- logging.warning('The argument `max_q_size` has been renamed '
- '`max_queue_size`. Update your method calls accordingly.')
- if 'pickle_safe' in kwargs:
- use_multiprocessing = kwargs.pop('pickle_safe')
- logging.warning('The argument `pickle_safe` has been renamed '
- '`use_multiprocessing`. '
- 'Update your method calls accordingly.')
-
self._make_predict_function()
steps_done = 0
@@ -2424,10 +2463,11 @@ class Model(Network):
all_outs = []
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
- logging.warn('Using a generator with `use_multiprocessing=True`'
- ' and multiple workers may duplicate your data.'
- ' Please consider using the`keras.utils.Sequence'
- ' class.')
+ logging.warning(
+ UserWarning('Using a generator with `use_multiprocessing=True`'
+ ' and multiple workers may duplicate your data.'
+ ' Please consider using the`keras.utils.Sequence'
+ ' class.'))
if steps is None:
if is_sequence:
steps = len(generator)
@@ -2498,6 +2538,6 @@ class Model(Network):
else:
return np.concatenate(all_outs[0])
if steps_done == 1:
- return [out for out in all_outs]
+ return [out[0] for out in all_outs]
else:
return [np.concatenate(out) for out in all_outs]
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py
index 7650bfb6e8..5a033a04ad 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py
@@ -28,6 +28,11 @@ from tensorflow.python.keras._impl.keras import testing_utils
from tensorflow.python.keras._impl.keras.engine.training import _weighted_masked_objective
from tensorflow.python.platform import test
+try:
+ import scipy.sparse as scipy_sparse # pylint: disable=g-import-not-at-top
+except ImportError:
+ scipy_sparse = None
+
class TrainingTest(test.TestCase):
@@ -169,7 +174,7 @@ class TrainingTest(test.TestCase):
with self.assertRaises(ValueError):
model.train_on_batch({'input_a': input_a_np},
[output_d_np, output_e_np])
- with self.assertRaises(TypeError):
+ with self.assertRaises(AttributeError):
model.fit(
[input_a_np, input_b_np], [output_d_np, output_e_np],
epochs=1,
@@ -177,7 +182,7 @@ class TrainingTest(test.TestCase):
verbose=0)
with self.assertRaises(ValueError):
model.train_on_batch([input_a_np], [output_d_np, output_e_np])
- with self.assertRaises(TypeError):
+ with self.assertRaises(AttributeError):
model.train_on_batch(1, [output_d_np, output_e_np])
with self.assertRaises(ValueError):
model.train_on_batch(input_a_np, [output_d_np, output_e_np])
@@ -312,6 +317,63 @@ class TrainingTest(test.TestCase):
model.compile(loss=None,
optimizer='rmsprop')
+ def test_training_on_sparse_data_with_dense_placeholders(self):
+ if scipy_sparse is None:
+ return
+
+ test_inputs = [
+ scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)]
+ test_outputs = [
+ scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)]
+ in1 = keras.layers.Input(shape=(3,))
+ in2 = keras.layers.Input(shape=(3,))
+ out1 = keras.layers.Dropout(0.5, name='dropout')(in1)
+ out2 = keras.layers.Dense(4, name='dense_1')(in2)
+ model = keras.Model([in1, in2], [out1, out2])
+ model.predict(test_inputs, batch_size=2)
+ model.compile('rmsprop', 'mse')
+ model.fit(test_inputs, test_outputs,
+ epochs=1, batch_size=2, validation_split=0.5)
+ model.evaluate(test_inputs, test_outputs, batch_size=2)
+
+ def test_that_trainable_disables_updates(self):
+ val_a = np.random.random((10, 4))
+ val_out = np.random.random((10, 4))
+
+ with self.test_session():
+ a = keras.layers.Input(shape=(4,))
+ layer = keras.layers.BatchNormalization(input_shape=(4,))
+ b = layer(a)
+ model = keras.Model(a, b)
+
+ model.trainable = False
+ assert not model.updates
+
+ model.compile('sgd', 'mse')
+ assert not model.updates
+
+ x1 = model.predict(val_a)
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ self.assertAllClose(x1, x2, atol=1e-7)
+
+ model.trainable = True
+ model.compile('sgd', 'mse')
+ assert model.updates
+
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ assert np.abs(np.sum(x1 - x2)) > 1e-5
+
+ layer.trainable = False
+ model.compile('sgd', 'mse')
+ assert not model.updates
+
+ x1 = model.predict(val_a)
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ self.assertAllClose(x1, x2, atol=1e-7)
+
class LossWeightingTest(test.TestCase):
@@ -869,25 +931,6 @@ class TestGeneratorMethods(test.TestCase):
use_multiprocessing=False,
workers=0)
- # Test legacy API
- model.fit_generator(custom_generator(),
- steps_per_epoch=5,
- epochs=1,
- verbose=1,
- max_q_size=10,
- workers=4,
- pickle_safe=True)
- model.predict_generator(custom_generator(),
- steps=5,
- max_q_size=10,
- workers=2,
- pickle_safe=True)
- model.evaluate_generator(custom_generator(),
- steps=5,
- max_q_size=10,
- workers=2,
- pickle_safe=True)
-
def test_generator_methods_with_sample_weights(self):
arr_data = np.random.random((50, 2))
arr_labels = np.random.random((50,))
@@ -960,7 +1003,7 @@ class TestGeneratorMethods(test.TestCase):
use_multiprocessing=False,
validation_data=custom_generator(),
validation_steps=10)
- with self.assertRaises(TypeError):
+ with self.assertRaises(AttributeError):
model.predict_generator(custom_generator(),
steps=5,
max_queue_size=10,
diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
index e4b9afd38a..ffbf77c4b8 100644
--- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
@@ -14,18 +14,18 @@
# ==============================================================================
"""Layers that act as activation functions.
"""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import activations
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
class LeakyReLU(Layer):
@@ -61,6 +61,7 @@ class LeakyReLU(Layer):
base_config = super(LeakyReLU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape
@@ -114,9 +115,9 @@ class PReLU(Layer):
else:
self.shared_axes = list(shared_axes)
+ @shape_type_conversion
def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- param_shape = input_shape[1:]
+ param_shape = list(input_shape[1:])
self.param_broadcast = [False] * len(param_shape)
if self.shared_axes is not None:
for i in self.shared_axes:
@@ -140,15 +141,13 @@ class PReLU(Layer):
def call(self, inputs, mask=None):
pos = K.relu(inputs)
if K.backend() == 'theano':
- neg = (K.pattern_broadcast(self.alpha, self.param_broadcast) *
- (inputs - K.abs(inputs)) * 0.5)
+ neg = (
+ K.pattern_broadcast(self.alpha, self.param_broadcast) *
+ (inputs - K.abs(inputs)) * 0.5)
else:
neg = -self.alpha * K.relu(-inputs)
return pos + neg
- def compute_output_shape(self, input_shape):
- return input_shape
-
def get_config(self):
config = {
'alpha_initializer': initializers.serialize(self.alpha_initializer),
@@ -159,6 +158,10 @@ class PReLU(Layer):
base_config = super(PReLU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
class ELU(Layer):
"""Exponential Linear Unit.
@@ -188,14 +191,15 @@ class ELU(Layer):
def call(self, inputs):
return K.elu(inputs, self.alpha)
- def compute_output_shape(self, input_shape):
- return input_shape
-
def get_config(self):
config = {'alpha': float(self.alpha)}
base_config = super(ELU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
class ThresholdedReLU(Layer):
"""Thresholded Rectified Linear Unit.
@@ -223,12 +227,46 @@ class ThresholdedReLU(Layer):
self.theta = K.cast_to_floatx(theta)
def call(self, inputs, mask=None):
- return inputs * K.cast(inputs > self.theta, K.floatx())
+ return inputs * K.cast(K.greater(inputs, self.theta), K.floatx())
+
+ def get_config(self):
+ config = {'theta': float(self.theta)}
+ base_config = super(ThresholdedReLU, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape
+
+class Softmax(Layer):
+ """Softmax activation function.
+
+ Input shape:
+ Arbitrary. Use the keyword argument `input_shape`
+ (tuple of integers, does not include the samples axis)
+ when using this layer as the first layer in a model.
+
+ Output shape:
+ Same shape as the input.
+
+ Arguments:
+ axis: Integer, axis along which the softmax normalization is applied.
+ """
+
+ def __init__(self, axis=-1, **kwargs):
+ super(Softmax, self).__init__(**kwargs)
+ self.supports_masking = True
+ self.axis = axis
+
+ def call(self, inputs):
+ return activations.softmax(inputs, axis=self.axis)
+
def get_config(self):
- config = {'theta': float(self.theta)}
- base_config = super(ThresholdedReLU, self).get_config()
+ config = {'axis': self.axis}
+ base_config = super(Softmax, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape
diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py
index 91efab30ed..343b7949ac 100644
--- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py
@@ -56,6 +56,12 @@ class AdvancedActivationsTest(test.TestCase):
kwargs={'theta': 0.5},
input_shape=(2, 3, 4))
+ def test_softmax(self):
+ with self.test_session():
+ testing_utils.layer_test(keras.layers.Softmax,
+ kwargs={'axis': 1},
+ input_shape=(2, 3, 4))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
index f0f5e1fb46..2ee0732775 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
@@ -711,6 +711,144 @@ class Conv3DTranspose(tf_convolutional_layers.Conv3DTranspose, Layer):
return dict(list(base_config.items()) + list(config.items()))
+class SeparableConv1D(tf_convolutional_layers.SeparableConv1D, Layer):
+ """Depthwise separable 1D convolution.
+
+ This layer performs a depthwise convolution that acts separately on
+ channels, followed by a pointwise convolution that mixes channels.
+ If `use_bias` is True and a bias initializer is provided,
+ it adds a bias vector to the output.
+ It then optionally applies an activation function to produce the final output.
+
+ Arguments:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: A single integer specifying the spatial
+ dimensions of the filters.
+ strides: A single integer specifying the strides
+ of the convolution.
+ Specifying any `stride` value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, length, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, length)`.
+ dilation_rate: A single integer, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any stride value != 1.
+ depth_multiplier: The number of depthwise convolution output channels for
+ each input channel. The total number of depthwise convolution output
+ channels will be equal to `num_filters_in * depth_multiplier`.
+ activation: Activation function. Set it to None to maintain a
+ linear activation.
+ use_bias: Boolean, whether the layer uses a bias.
+ depthwise_initializer: An initializer for the depthwise convolution kernel.
+ pointwise_initializer: An initializer for the pointwise convolution kernel.
+ bias_initializer: An initializer for the bias vector. If None, the default
+ initializer will be used.
+ depthwise_regularizer: Optional regularizer for the depthwise
+ convolution kernel.
+ pointwise_regularizer: Optional regularizer for the pointwise
+ convolution kernel.
+ bias_regularizer: Optional regularizer for the bias vector.
+ activity_regularizer: Optional regularizer function for the output.
+ depthwise_constraint: Optional projection function to be applied to the
+ depthwise kernel after being updated by an `Optimizer` (e.g. used for
+ norm constraints or value constraints for layer weights). The function
+ must take as input the unprojected variable and must return the
+ projected variable (which must have the same shape). Constraints are
+ not safe to use when doing asynchronous distributed training.
+ pointwise_constraint: Optional projection function to be applied to the
+ pointwise kernel after being updated by an `Optimizer`.
+ bias_constraint: Optional projection function to be applied to the
+ bias after being updated by an `Optimizer`.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: A string, the name of the layer.
+ """
+
+ def __init__(self,
+ filters,
+ kernel_size,
+ strides=1,
+ padding='valid',
+ data_format=None,
+ dilation_rate=1,
+ depth_multiplier=1,
+ activation=None,
+ use_bias=True,
+ depthwise_initializer='glorot_uniform',
+ pointwise_initializer='glorot_uniform',
+ bias_initializer='zeros',
+ depthwise_regularizer=None,
+ pointwise_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ depthwise_constraint=None,
+ pointwise_constraint=None,
+ bias_constraint=None,
+ **kwargs):
+ if data_format is None:
+ data_format = K.image_data_format()
+ super(SeparableConv1D, self).__init__(
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ activation=activations.get(activation),
+ use_bias=use_bias,
+ depthwise_initializer=initializers.get(depthwise_initializer),
+ pointwise_initializer=initializers.get(pointwise_initializer),
+ bias_initializer=initializers.get(bias_initializer),
+ depthwise_regularizer=regularizers.get(depthwise_regularizer),
+ pointwise_regularizer=regularizers.get(pointwise_regularizer),
+ bias_regularizer=regularizers.get(bias_regularizer),
+ activity_regularizer=regularizers.get(activity_regularizer),
+ depthwise_constraint=constraints.get(depthwise_constraint),
+ pointwise_constraint=constraints.get(pointwise_constraint),
+ bias_constraint=constraints.get(bias_constraint),
+ **kwargs)
+
+ def get_config(self):
+ config = {
+ 'filters': self.filters,
+ 'kernel_size': self.kernel_size,
+ 'strides': self.strides,
+ 'padding': self.padding,
+ 'data_format': self.data_format,
+ 'dilation_rate': self.dilation_rate,
+ 'activation': activations.serialize(self.activation),
+ 'use_bias': self.use_bias,
+ 'depthwise_initializer':
+ initializers.serialize(self.depthwise_initializer),
+ 'pointwise_initializer':
+ initializers.serialize(self.pointwise_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'depthwise_regularizer':
+ regularizers.serialize(self.depthwise_regularizer),
+ 'pointwise_regularizer':
+ regularizers.serialize(self.pointwise_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
+ 'activity_regularizer':
+ regularizers.serialize(self.activity_regularizer),
+ 'depthwise_constraint':
+ constraints.serialize(self.depthwise_constraint),
+ 'pointwise_constraint':
+ constraints.serialize(self.pointwise_constraint),
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint)
+ }
+ base_config = super(SeparableConv1D, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+
class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer):
"""Depthwise separable 2D convolution.
@@ -1663,6 +1801,7 @@ class Cropping3D(Layer):
Convolution1D = Conv1D
Convolution2D = Conv2D
Convolution3D = Conv3D
+SeparableConvolution1D = SeparableConv1D
SeparableConvolution2D = SeparableConv2D
Convolution2DTranspose = Conv2DTranspose
Convolution3DTranspose = Conv3DTranspose
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
index 4f0e9fc691..565db19e41 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
@@ -20,13 +20,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import activations
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
from tensorflow.python.keras._impl.keras.layers.recurrent import Recurrent
from tensorflow.python.keras._impl.keras.utils import conv_utils
@@ -127,10 +127,10 @@ class ConvRecurrent2D(Recurrent):
self.input_spec = [InputSpec(ndim=5)]
self.state_spec = None
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_first':
rows = input_shape[3]
cols = input_shape[4]
@@ -151,30 +151,28 @@ class ConvRecurrent2D(Recurrent):
dilation=self.dilation_rate[1])
if self.return_sequences:
if self.data_format == 'channels_first':
- output_shape = [input_shape[0], input_shape[1],
- self.filters, rows, cols]
+ output_shape = (input_shape[0], input_shape[1], self.filters, rows,
+ cols)
elif self.data_format == 'channels_last':
- output_shape = [input_shape[0], input_shape[1],
- rows, cols, self.filters]
+ output_shape = (input_shape[0], input_shape[1], rows, cols,
+ self.filters)
else:
if self.data_format == 'channels_first':
- output_shape = [input_shape[0], self.filters, rows, cols]
+ output_shape = (input_shape[0], self.filters, rows, cols)
elif self.data_format == 'channels_last':
- output_shape = [input_shape[0], rows, cols, self.filters]
+ output_shape = (input_shape[0], rows, cols, self.filters)
if self.return_state:
if self.data_format == 'channels_first':
- output_shapes = [output_shape] + [(input_shape[0],
- self.filters,
- rows,
- cols) for _ in range(2)]
+ output_shape = [output_shape] + [
+ (input_shape[0], self.filters, rows, cols) for _ in range(2)
+ ]
elif self.data_format == 'channels_last':
- output_shapes = [output_shape] + [(input_shape[0],
- rows,
- cols,
- self.filters) for _ in range(2)]
- return [tensor_shape.TensorShape(shape) for shape in output_shapes]
- return tensor_shape.TensorShape(output_shape)
+ output_shape = [output_shape] + [
+ (input_shape[0], rows, cols, self.filters) for _ in range(2)
+ ]
+
+ return output_shape
def get_config(self):
config = {
@@ -294,11 +292,6 @@ class ConvLSTM2D(ConvRecurrent2D):
Raises:
ValueError: in case of invalid constructor arguments.
- References:
- - [Convolutional LSTM Network: A Machine Learning Approach for
- Precipitation Nowcasting](http://arxiv.org/abs/1506.04214v1)
- The current implementation does not include the feedback loop on the
- cells output
"""
def __init__(self,
@@ -338,7 +331,6 @@ class ConvLSTM2D(ConvRecurrent2D):
return_sequences=return_sequences,
go_backwards=go_backwards,
stateful=stateful,
- activity_regularizer=regularizers.get(activity_regularizer),
**kwargs)
self.activation = activations.get(activation)
self.recurrent_activation = activations.get(recurrent_activation)
@@ -352,6 +344,7 @@ class ConvLSTM2D(ConvRecurrent2D):
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
+ self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.recurrent_constraint = constraints.get(recurrent_constraint)
@@ -361,13 +354,12 @@ class ConvLSTM2D(ConvRecurrent2D):
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
self.state_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]
+ @shape_type_conversion
def build(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
- input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
batch_size = input_shape[0] if self.stateful else None
self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:])
-
if self.stateful:
self.reset_states()
else:
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
index be7da6f2b4..39c9d4f0fb 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
@@ -311,6 +311,72 @@ class Conv3DTransposeTest(test.TestCase):
self.assertEqual(layer.bias.constraint, b_constraint)
+class SeparableConv1DTest(test.TestCase):
+
+ def test_separable_conv_1d(self):
+ num_samples = 2
+ filters = 6
+ stack_size = 3
+ length = 7
+ strides = 1
+
+ for padding in ['valid', 'same']:
+ for multiplier in [1, 2]:
+ if padding == 'same' and strides != 1:
+ continue
+
+ with self.test_session(use_gpu=True):
+ testing_utils.layer_test(
+ keras.layers.SeparableConv1D,
+ kwargs={
+ 'filters': filters,
+ 'kernel_size': 3,
+ 'padding': padding,
+ 'strides': strides,
+ 'depth_multiplier': multiplier
+ },
+ input_shape=(num_samples, length, stack_size))
+
+ def test_separable_conv1d_regularizers(self):
+ kwargs = {
+ 'filters': 3,
+ 'kernel_size': 3,
+ 'padding': 'valid',
+ 'depthwise_regularizer': 'l2',
+ 'pointwise_regularizer': 'l2',
+ 'bias_regularizer': 'l2',
+ 'activity_regularizer': 'l2',
+ 'strides': 1
+ }
+ with self.test_session(use_gpu=True):
+ layer = keras.layers.SeparableConv1D(**kwargs)
+ layer.build((None, 5, 2))
+ self.assertEqual(len(layer.losses), 3)
+ layer(keras.backend.variable(np.ones((1, 5, 2))))
+ self.assertEqual(len(layer.losses), 4)
+
+ def test_separable_conv1d_constraints(self):
+ d_constraint = lambda x: x
+ p_constraint = lambda x: x
+ b_constraint = lambda x: x
+
+ kwargs = {
+ 'filters': 3,
+ 'kernel_size': 3,
+ 'padding': 'valid',
+ 'pointwise_constraint': p_constraint,
+ 'depthwise_constraint': d_constraint,
+ 'bias_constraint': b_constraint,
+ 'strides': 1
+ }
+ with self.test_session(use_gpu=True):
+ layer = keras.layers.SeparableConv1D(**kwargs)
+ layer.build((None, 5, 2))
+ self.assertEqual(layer.depthwise_kernel.constraint, d_constraint)
+ self.assertEqual(layer.pointwise_kernel.constraint, p_constraint)
+ self.assertEqual(layer.bias.constraint, b_constraint)
+
+
class SeparableConv2DTest(test.TestCase):
def test_separable_conv_2d(self):
diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings.py b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
index 51c520be38..f8e31068f8 100644
--- a/tensorflow/python/keras/_impl/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
@@ -18,12 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
class Embedding(Layer):
@@ -58,13 +58,13 @@ class Embedding(Layer):
output_dim: int >= 0. Dimension of the dense embedding.
embeddings_initializer: Initializer for the `embeddings` matrix.
embeddings_regularizer: Regularizer function applied to
- the `embeddings` matrix.
+ the `embeddings` matrix.
embeddings_constraint: Constraint function applied to
- the `embeddings` matrix.
+ the `embeddings` matrix.
mask_zero: Whether or not the input value 0 is a special "padding"
value that should be masked out.
- This is useful when using recurrent layers,
- which may take variable length inputs.
+ This is useful when using recurrent layers
+ which may take variable length input.
If this is `True` then all subsequent layers
in the model need to support masking or an exception will be raised.
If mask_zero is set to True, as a consequence, index 0 cannot be
@@ -81,9 +81,6 @@ class Embedding(Layer):
Output shape:
3D tensor with shape: `(batch_size, sequence_length, output_dim)`.
- References:
- - [A Theoretically Grounded Application of Dropout in Recurrent Neural
- Networks](http://arxiv.org/abs/1512.05287)
"""
def __init__(self,
@@ -101,19 +98,19 @@ class Embedding(Layer):
kwargs['input_shape'] = (input_length,)
else:
kwargs['input_shape'] = (None,)
- super(Embedding, self).__init__(
- activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+ super(Embedding, self).__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.embeddings_initializer = initializers.get(embeddings_initializer)
self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
+ self.activity_regularizer = regularizers.get(activity_regularizer)
self.embeddings_constraint = constraints.get(embeddings_constraint)
self.mask_zero = mask_zero
self.input_length = input_length
+ @shape_type_conversion
def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
self.embeddings = self.add_weight(
shape=(self.input_dim, self.output_dim),
initializer=self.embeddings_initializer,
@@ -129,10 +126,10 @@ class Embedding(Layer):
else:
return K.not_equal(inputs, 0)
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.input_length is None:
- return tensor_shape.TensorShape(input_shape + [self.output_dim])
+ return input_shape + (self.output_dim,)
else:
# input_length can be tuple if input is 3D or higher
if isinstance(self.input_length, (list, tuple)):
@@ -149,8 +146,7 @@ class Embedding(Layer):
(str(self.input_length), str(input_shape)))
elif s1 is None:
in_lens[i] = s2
- return tensor_shape.TensorShape(
- (input_shape[0],) + tuple(in_lens) + (self.output_dim,))
+ return (input_shape[0],) + tuple(in_lens) + (self.output_dim,)
def call(self, inputs):
if K.dtype(inputs) != 'int32':
diff --git a/tensorflow/python/keras/_impl/keras/layers/local.py b/tensorflow/python/keras/_impl/keras/layers/local.py
index 0a31b87fb5..b844b071e0 100644
--- a/tensorflow/python/keras/_impl/keras/layers/local.py
+++ b/tensorflow/python/keras/_impl/keras/layers/local.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import activations
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras import constraints
@@ -26,6 +25,7 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
from tensorflow.python.keras._impl.keras.utils import conv_utils
@@ -98,8 +98,7 @@ class LocallyConnected1D(Layer):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
- super(LocallyConnected1D, self).__init__(
- activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+ super(LocallyConnected1D, self).__init__(**kwargs)
self.filters = filters
self.kernel_size = conv_utils.normalize_tuple(kernel_size, 1, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, 1, 'strides')
@@ -114,12 +113,13 @@ class LocallyConnected1D(Layer):
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
+ self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(ndim=3)
+ @shape_type_conversion
def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
input_dim = input_shape[2]
if input_dim is None:
raise ValueError('Axis 2 of input should be fully-defined. '
@@ -146,15 +146,14 @@ class LocallyConnected1D(Layer):
self.input_spec = InputSpec(ndim=3, axes={2: input_dim})
self.built = True
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
length = conv_utils.conv_output_length(input_shape[1], self.kernel_size[0],
self.padding, self.strides[0])
- return tensor_shape.TensorShape([input_shape[0], length, self.filters])
+ return (input_shape[0], length, self.filters)
def call(self, inputs):
output = K.local_conv1d(inputs, self.kernel, self.kernel_size, self.strides)
-
if self.use_bias:
output = K.bias_add(output, self.bias)
if self.activation is not None:
@@ -163,20 +162,32 @@ class LocallyConnected1D(Layer):
def get_config(self):
config = {
- 'filters': self.filters,
- 'kernel_size': self.kernel_size,
- 'strides': self.strides,
- 'padding': self.padding,
- 'activation': activations.serialize(self.activation),
- 'use_bias': self.use_bias,
- 'kernel_initializer': initializers.serialize(self.kernel_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
- 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'filters':
+ self.filters,
+ 'kernel_size':
+ self.kernel_size,
+ 'strides':
+ self.strides,
+ 'padding':
+ self.padding,
+ 'activation':
+ activations.serialize(self.activation),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint': constraints.serialize(self.kernel_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint)
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint)
}
base_config = super(LocallyConnected1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -273,8 +284,7 @@ class LocallyConnected2D(Layer):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
- super(LocallyConnected2D, self).__init__(
- activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+ super(LocallyConnected2D, self).__init__(**kwargs)
self.filters = filters
self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
@@ -289,12 +299,13 @@ class LocallyConnected2D(Layer):
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
+ self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(ndim=4)
+ @shape_type_conversion
def build(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_last':
input_row, input_col = input_shape[1:-1]
input_filter = input_shape[3]
@@ -306,7 +317,6 @@ class LocallyConnected2D(Layer):
' a LocallyConnected2D layer '
'should be fully-defined, but layer received '
'the inputs shape ' + str(input_shape))
-
output_row = conv_utils.conv_output_length(input_row, self.kernel_size[0],
self.padding, self.strides[0])
output_col = conv_utils.conv_output_length(input_col, self.kernel_size[1],
@@ -337,33 +347,30 @@ class LocallyConnected2D(Layer):
self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
self.built = True
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_first':
rows = input_shape[2]
cols = input_shape[3]
elif self.data_format == 'channels_last':
rows = input_shape[1]
cols = input_shape[2]
+
rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
self.padding, self.strides[0])
cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
self.padding, self.strides[1])
if self.data_format == 'channels_first':
- return tensor_shape.TensorShape(
- [input_shape[0], self.filters, rows, cols])
+ return (input_shape[0], self.filters, rows, cols)
elif self.data_format == 'channels_last':
- return tensor_shape.TensorShape(
- [input_shape[0], rows, cols, self.filters])
+ return (input_shape[0], rows, cols, self.filters)
def call(self, inputs):
- output = K.local_conv2d(inputs,
- self.kernel,
- self.kernel_size,
- self.strides,
+ output = K.local_conv2d(inputs, self.kernel, self.kernel_size, self.strides,
(self.output_row, self.output_col),
self.data_format)
+
if self.use_bias:
output = K.bias_add(output, self.bias, data_format=self.data_format)
@@ -372,21 +379,34 @@ class LocallyConnected2D(Layer):
def get_config(self):
config = {
- 'filters': self.filters,
- 'kernel_size': self.kernel_size,
- 'strides': self.strides,
- 'padding': self.padding,
- 'data_format': self.data_format,
- 'activation': activations.serialize(self.activation),
- 'use_bias': self.use_bias,
- 'kernel_initializer': initializers.serialize(self.kernel_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
- 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'filters':
+ self.filters,
+ 'kernel_size':
+ self.kernel_size,
+ 'strides':
+ self.strides,
+ 'padding':
+ self.padding,
+ 'data_format':
+ self.data_format,
+ 'activation':
+ activations.serialize(self.activation),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint': constraints.serialize(self.kernel_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint)
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint)
}
base_config = super(LocallyConnected2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
diff --git a/tensorflow/python/keras/_impl/keras/layers/merge.py b/tensorflow/python/keras/_impl/keras/layers/merge.py
index 76eb03cf27..38b0b30297 100644
--- a/tensorflow/python/keras/_impl/keras/layers/merge.py
+++ b/tensorflow/python/keras/_impl/keras/layers/merge.py
@@ -14,15 +14,15 @@
# ==============================================================================
# pylint: disable=not-callable
# pylint: disable=redefined-builtin
-"""Layers can merge several input tensors into a single output tensor.
+"""Layers that can merge several inputs into one.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine.topology import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
class _Merge(Layer):
@@ -73,12 +73,13 @@ class _Merge(Layer):
output_shape.append(i)
else:
if i != j:
- raise ValueError('Operands could not be broadcast '
- 'together with shapes ' + str(shape1) + ' ' +
- str(shape2))
+ raise ValueError(
+ 'Operands could not be broadcast '
+ 'together with shapes ' + str(shape1) + ' ' + str(shape2))
output_shape.append(i)
return tuple(output_shape)
+ @shape_type_conversion
def build(self, input_shape):
# Used purely for shape validation.
if not isinstance(input_shape, list):
@@ -87,14 +88,13 @@ class _Merge(Layer):
raise ValueError('A merge layer should be called '
'on a list of at least 2 inputs. '
'Got ' + str(len(input_shape)) + ' inputs.')
- input_shape = [tensor_shape.TensorShape(s).as_list() for s in input_shape]
batch_sizes = [s[0] for s in input_shape if s is not None]
batch_sizes = set(batch_sizes)
batch_sizes -= set([None])
if len(batch_sizes) > 1:
- raise ValueError('Can not merge tensors with different '
- 'batch sizes. Got tensors with shapes : ' +
- str(input_shape))
+ raise ValueError(
+ 'Can not merge tensors with different '
+ 'batch sizes. Got tensors with shapes : ' + str(input_shape))
if input_shape[0] is None:
output_shape = None
else:
@@ -111,9 +111,10 @@ class _Merge(Layer):
self._reshape_required = False
else:
self._reshape_required = True
- self.built = True
def call(self, inputs):
+ if not isinstance(inputs, list):
+ raise ValueError('A merge layer should be called ' 'on a list of inputs.')
if self._reshape_required:
reshaped_inputs = []
input_ndims = list(map(K.ndim, inputs))
@@ -172,6 +173,7 @@ class _Merge(Layer):
else:
return self._merge_function(inputs)
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if input_shape[0] is None:
output_shape = None
@@ -214,6 +216,22 @@ class Add(_Merge):
It takes as input a list of tensors,
all of the same shape, and returns
a single tensor (also of the same shape).
+
+ Examples:
+
+ ```python
+ import keras
+
+ input1 = keras.layers.Input(shape=(16,))
+ x1 = keras.layers.Dense(8, activation='relu')(input1)
+ input2 = keras.layers.Input(shape=(32,))
+ x2 = keras.layers.Dense(8, activation='relu')(input2)
+ added = keras.layers.Add()([x1, x2]) # equivalent to added =
+ keras.layers.add([x1, x2])
+
+ out = keras.layers.Dense(4)(added)
+ model = keras.models.Model(inputs=[input1, input2], outputs=out)
+ ```
"""
def _merge_function(self, inputs):
@@ -247,10 +265,17 @@ class Subtract(_Merge):
```
"""
+ @shape_type_conversion
+ def build(self, input_shape):
+ super(Subtract, self).build(input_shape)
+ if len(input_shape) != 2:
+ raise ValueError('A `Subtract` layer should be called '
+ 'on exactly 2 inputs')
+
def _merge_function(self, inputs):
if len(inputs) != 2:
- raise ValueError('`Subtract` layer should be called '
- 'on exactly 2 inputs. Received: %s' % inputs)
+ raise ValueError('A `Subtract` layer should be called '
+ 'on exactly 2 inputs')
return inputs[0] - inputs[1]
@@ -330,47 +355,43 @@ class Concatenate(_Merge):
super(Concatenate, self).__init__(**kwargs)
self.axis = axis
self.supports_masking = True
+ self._reshape_required = False
+ @shape_type_conversion
def build(self, input_shape):
# Used purely for shape validation.
- if not (isinstance(input_shape, list) and len(input_shape) > 1):
- raise ValueError('`Concatenate` layer should be called '
- 'on a list containing at least two inputs')
+ if not isinstance(input_shape, list) or len(input_shape) < 2:
+ raise ValueError('A `Concatenate` layer should be called '
+ 'on a list of at least 2 inputs')
if all([shape is None for shape in input_shape]):
return
- reduced_inputs_shapes = [
- tensor_shape.TensorShape(shape).as_list() for shape in input_shape
- ]
+ reduced_inputs_shapes = [list(shape) for shape in input_shape]
shape_set = set()
for i in range(len(reduced_inputs_shapes)):
del reduced_inputs_shapes[i][self.axis]
shape_set.add(tuple(reduced_inputs_shapes[i]))
if len(shape_set) > 1:
- raise ValueError('`Concatenate` layer requires '
+ raise ValueError('A `Concatenate` layer requires '
'inputs with matching shapes '
'except for the concat axis. '
'Got inputs shapes: %s' % (input_shape))
- self.built = True
- def call(self, inputs):
- if not isinstance(inputs, list):
- raise ValueError('A `Concatenate` layer should be called '
- 'on a list of inputs.')
+ def _merge_function(self, inputs):
return K.concatenate(inputs, axis=self.axis)
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if not isinstance(input_shape, list):
raise ValueError('A `Concatenate` layer should be called '
'on a list of inputs.')
input_shapes = input_shape
- output_shape = tensor_shape.TensorShape(input_shapes[0]).as_list()
+ output_shape = list(input_shapes[0])
for shape in input_shapes[1:]:
- shape = tensor_shape.TensorShape(shape).as_list()
if output_shape[self.axis] is None or shape[self.axis] is None:
output_shape[self.axis] = None
break
output_shape[self.axis] += shape[self.axis]
- return tensor_shape.TensorShape(output_shape)
+ return tuple(output_shape)
def compute_mask(self, inputs, mask=None):
if mask is None:
@@ -390,7 +411,7 @@ class Concatenate(_Merge):
masks = []
for input_i, mask_i in zip(inputs, mask):
if mask_i is None:
- # Input is unmasked. Append all 1s to masks
+ # Input is unmasked. Append all 1s to masks,
masks.append(K.ones_like(input_i, dtype='bool'))
elif K.ndim(mask_i) < K.ndim(input_i):
# Mask is smaller than the input, expand it
@@ -441,14 +462,16 @@ class Dot(_Merge):
self.axes = axes
self.normalize = normalize
self.supports_masking = True
+ self._reshape_required = False
+ @shape_type_conversion
def build(self, input_shape):
# Used purely for shape validation.
if not isinstance(input_shape, list) or len(input_shape) != 2:
raise ValueError('A `Dot` layer should be called '
'on a list of 2 inputs.')
- shape1 = tensor_shape.TensorShape(input_shape[0]).as_list()
- shape2 = tensor_shape.TensorShape(input_shape[1]).as_list()
+ shape1 = input_shape[0]
+ shape2 = input_shape[1]
if shape1 is None or shape2 is None:
return
if isinstance(self.axes, int):
@@ -462,9 +485,10 @@ class Dot(_Merge):
raise ValueError('Dimension incompatibility '
'%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) +
'Layer shapes: %s, %s' % (shape1, shape2))
- self.built = True
- def call(self, inputs):
+ def _merge_function(self, inputs):
+ if len(inputs) != 2:
+ raise ValueError('A `Dot` layer should be called ' 'on exactly 2 inputs')
x1 = inputs[0]
x2 = inputs[1]
if isinstance(self.axes, int):
@@ -485,12 +509,13 @@ class Dot(_Merge):
output = K.batch_dot(x1, x2, axes)
return output
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if not isinstance(input_shape, list) or len(input_shape) != 2:
raise ValueError('A `Dot` layer should be called '
'on a list of 2 inputs.')
- shape1 = tensor_shape.TensorShape(input_shape[0]).as_list()
- shape2 = tensor_shape.TensorShape(input_shape[1]).as_list()
+ shape1 = list(input_shape[0])
+ shape2 = list(input_shape[1])
if isinstance(self.axes, int):
if self.axes < 0:
axes = [self.axes % len(shape1), self.axes % len(shape2)]
@@ -504,7 +529,7 @@ class Dot(_Merge):
output_shape = shape1 + shape2
if len(output_shape) == 1:
output_shape += [1]
- return tensor_shape.TensorShape(output_shape)
+ return tuple(output_shape)
def compute_mask(self, inputs, mask=None):
return None
@@ -527,6 +552,21 @@ def add(inputs, **kwargs):
Returns:
A tensor, the sum of the inputs.
+
+ Examples:
+
+ ```python
+ import keras
+
+ input1 = keras.layers.Input(shape=(16,))
+ x1 = keras.layers.Dense(8, activation='relu')(input1)
+ input2 = keras.layers.Input(shape=(32,))
+ x2 = keras.layers.Dense(8, activation='relu')(input2)
+ added = keras.layers.add([x1, x2])
+
+ out = keras.layers.Dense(4)(added)
+ model = keras.models.Model(inputs=[input1, input2], outputs=out)
+ ```
"""
return Add(**kwargs)(inputs)
diff --git a/tensorflow/python/keras/_impl/keras/layers/noise.py b/tensorflow/python/keras/_impl/keras/layers/noise.py
index 459f13145f..04fffcc384 100644
--- a/tensorflow/python/keras/_impl/keras/layers/noise.py
+++ b/tensorflow/python/keras/_impl/keras/layers/noise.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Layers for regularization models via the addition of noise.
+"""Layers that operate regularization via the addition of noise.
"""
from __future__ import absolute_import
from __future__ import division
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
class GaussianNoise(Layer):
@@ -59,14 +60,15 @@ class GaussianNoise(Layer):
return K.in_train_phase(noised, inputs, training=training)
- def compute_output_shape(self, input_shape):
- return input_shape
-
def get_config(self):
config = {'stddev': self.stddev}
base_config = super(GaussianNoise, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
class GaussianDropout(Layer):
"""Apply multiplicative 1-centered Gaussian noise.
@@ -86,10 +88,6 @@ class GaussianDropout(Layer):
Output shape:
Same shape as input.
- References:
- - [Dropout: A Simple Way to Prevent Neural Networks from Overfitting
- Srivastava, Hinton, et al.
- 2014](http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf)
"""
def __init__(self, rate, **kwargs):
@@ -108,14 +106,15 @@ class GaussianDropout(Layer):
return K.in_train_phase(noised, inputs, training=training)
return inputs
- def compute_output_shape(self, input_shape):
- return input_shape
-
def get_config(self):
config = {'rate': self.rate}
base_config = super(GaussianDropout, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
class AlphaDropout(Layer):
"""Applies Alpha Dropout to the input.
@@ -140,8 +139,6 @@ class AlphaDropout(Layer):
Output shape:
Same shape as input.
- References:
- - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
"""
def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
@@ -157,26 +154,34 @@ class AlphaDropout(Layer):
def call(self, inputs, training=None):
if 0. < self.rate < 1.:
noise_shape = self._get_noise_shape(inputs)
- alpha = 1.6732632423543772848170429916717
- scale = 1.0507009873554804934193349852946
- def dropped_inputs(inputs=inputs, rate=self.rate, seed=self.seed):
+ def dropped_inputs(inputs=inputs, rate=self.rate, seed=self.seed): # pylint: disable=missing-docstring
+ alpha = 1.6732632423543772848170429916717
+ scale = 1.0507009873554804934193349852946
alpha_p = -alpha * scale
- kept_idx = K.greater_equal(K.random_uniform(noise_shape, seed=seed),
- rate)
+
+ kept_idx = K.greater_equal(
+ K.random_uniform(noise_shape, seed=seed), rate)
kept_idx = K.cast(kept_idx, K.floatx())
- a = ((1 - rate) * (1 + rate * alpha_p ** 2)) ** -0.5
+
+ # Get affine transformation params
+ a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5
b = -a * alpha_p * rate
+
+ # Apply mask
x = inputs * kept_idx + alpha_p * (1 - kept_idx)
+
+ # Do affine transformation
return a * x + b
return K.in_train_phase(dropped_inputs, inputs, training=training)
return inputs
- def compute_output_shape(self, input_shape):
- return input_shape
-
def get_config(self):
config = {'rate': self.rate}
base_config = super(AlphaDropout, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index 9ea21c9c36..1b0f6cb6cf 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -1,4 +1,4 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
-"""Recurrent layers.
+"""Recurrent layers and their base classes.
"""
from __future__ import absolute_import
from __future__ import division
@@ -29,6 +29,7 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
from tensorflow.python.platform import tf_logging as logging
@@ -109,6 +110,7 @@ class StackedRNNCells(Layer):
states += cell_states
return inputs, states
+ @shape_type_conversion
def build(self, input_shape):
for cell in self.cells:
if isinstance(cell, Layer):
@@ -117,7 +119,7 @@ class StackedRNNCells(Layer):
output_dim = cell.state_size[0]
else:
output_dim = cell.state_size
- input_shape = (input_shape[0], input_shape[1], output_dim)
+ input_shape = (input_shape[0], output_dim)
self.built = True
def get_config(self):
@@ -262,8 +264,7 @@ class RNN(Layer):
(e.g. via the `input_shape` argument)
Input shape:
- 3D tensor with shape `(batch_size, timesteps, input_dim)`,
- (Optional) 2D tensors with shape `(batch_size, output_dim)`.
+ 3D tensor with shape `(batch_size, timesteps, input_dim)`.
Output shape:
- if `return_state`: a list of tensors. The first tensor is
@@ -370,7 +371,6 @@ class RNN(Layer):
go_backwards=False,
stateful=False,
unroll=False,
- activity_regularizer=None,
**kwargs):
if isinstance(cell, (list, tuple)):
cell = StackedRNNCells(cell)
@@ -382,8 +382,7 @@ class RNN(Layer):
'an attribute `state_size` '
'(tuple of integers, '
'one integer per RNN state).')
- super(RNN, self).__init__(
- activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+ super(RNN, self).__init__(**kwargs)
self.cell = cell
self.return_sequences = return_sequences
self.return_state = return_state
@@ -412,15 +411,16 @@ class RNN(Layer):
def states(self, states):
self._states = states
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
if hasattr(self.cell.state_size, '__len__'):
- output_dim = self.cell.state_size[0]
+ state_size = self.cell.state_size
else:
- output_dim = self.cell.state_size
+ state_size = [self.cell.state_size]
+ output_dim = state_size[0]
if self.return_sequences:
output_shape = (input_shape[0], input_shape[1], output_dim)
@@ -428,11 +428,10 @@ class RNN(Layer):
output_shape = (input_shape[0], output_dim)
if self.return_state:
- state_shape = [(input_shape[0], output_dim) for _ in self.states]
- output_shape = [output_shape] + state_shape
+ state_shape = [(input_shape[0], dim) for dim in state_size]
+ return [output_shape] + state_shape
else:
- output_shape = output_shape
- return tensor_shape.TensorShape(output_shape)
+ return output_shape
def compute_mask(self, inputs, mask):
if isinstance(mask, list):
@@ -444,6 +443,7 @@ class RNN(Layer):
else:
return output_mask
+ @shape_type_conversion
def build(self, input_shape):
# Note input_shape will be list of shapes of initial states and
# constants if these are passed in __call__.
@@ -454,7 +454,6 @@ class RNN(Layer):
if isinstance(input_shape, list):
input_shape = input_shape[0]
- input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
batch_size = input_shape[0] if self.stateful else None
input_dim = input_shape[-1]
@@ -478,9 +477,9 @@ class RNN(Layer):
# initial_state was passed in call, check compatibility
if [spec.shape[-1] for spec in self.state_spec] != state_size:
raise ValueError(
- 'An initial_state was passed that is not compatible with '
+ 'An `initial_state` was passed that is not compatible with '
'`cell.state_size`. Received `state_spec`={}; '
- 'However `cell.state_size` is '
+ 'however `cell.state_size` is '
'{}'.format(self.state_spec, self.cell.state_size))
else:
self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]
@@ -610,7 +609,8 @@ class RNN(Layer):
constants=constants,
go_backwards=self.go_backwards,
mask=mask,
- unroll=self.unroll)
+ unroll=self.unroll,
+ input_length=timesteps)
if self.stateful:
updates = []
for i in range(len(states)):
@@ -625,6 +625,8 @@ class RNN(Layer):
# Properly set learning phase
if getattr(last_output, '_uses_learning_phase', False):
output._uses_learning_phase = True
+ for state in states:
+ state._uses_learning_phase = True
if self.return_state:
if not isinstance(states, (list, tuple)):
@@ -636,7 +638,7 @@ class RNN(Layer):
return output
def _standardize_args(self, inputs, initial_state, constants):
- """Standardize `__call__` arguments to a single list of tensor inputs.
+ """Standardize `__call__` to a single list of tensor inputs.
When running a model loaded from file, the input tensors
`initial_state` and `constants` can be passed to `RNN.__call__` as part
@@ -688,7 +690,7 @@ class RNN(Layer):
'a `batch_input_shape` '
'argument to your first layer.\n'
'- If using the functional API, specify '
- 'the time dimension by passing a '
+ 'the batch size by passing a '
'`batch_shape` argument to your Input layer.')
# initialize state if None
if self.states[0] is None:
@@ -788,37 +790,26 @@ class SimpleRNNCell(Layer):
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
- Default: hyperbolic tangent (`tanh`).
- If you pass `None`, no activation is applied
+ activation: Activation function to use.
+ If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
@@ -866,6 +857,7 @@ class SimpleRNNCell(Layer):
self._dropout_mask = None
self._recurrent_dropout_mask = None
+ @shape_type_conversion
def build(self, input_shape):
self.kernel = self.add_weight(
shape=(input_shape[-1], self.units),
@@ -890,33 +882,21 @@ class SimpleRNNCell(Layer):
self.bias = None
self.built = True
- def _generate_dropout_mask(self, inputs, training=None):
- if 0 < self.dropout < 1:
- ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._dropout_mask = K.in_train_phase(
- dropped_inputs, ones, training=training)
- else:
- self._dropout_mask = None
-
- def _generate_recurrent_dropout_mask(self, inputs, training=None):
- if 0 < self.recurrent_dropout < 1:
- ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
- ones = K.tile(ones, (1, self.units))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._recurrent_dropout_mask = K.in_train_phase(
- dropped_inputs, ones, training=training)
- else:
- self._recurrent_dropout_mask = None
-
def call(self, inputs, states, training=None):
prev_output = states[0]
+ if 0 < self.dropout < 1 and self._dropout_mask is None:
+ self._dropout_mask = _generate_dropout_mask(
+ _generate_dropout_ones(inputs,
+ K.shape(inputs)[-1]),
+ self.dropout,
+ training=training)
+ if (0 < self.recurrent_dropout < 1 and
+ self._recurrent_dropout_mask is None):
+ self._recurrent_dropout_mask = _generate_dropout_mask(
+ _generate_dropout_ones(inputs, self.units),
+ self.recurrent_dropout,
+ training=training)
+
dp_mask = self._dropout_mask
rec_dp_mask = self._recurrent_dropout_mask
@@ -939,46 +919,68 @@ class SimpleRNNCell(Layer):
output._uses_learning_phase = True
return output, [output]
+ def get_config(self):
+ config = {
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
+ 'recurrent_initializer':
+ initializers.serialize(self.recurrent_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
+ 'recurrent_regularizer':
+ regularizers.serialize(self.recurrent_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
+ 'recurrent_constraint':
+ constraints.serialize(self.recurrent_constraint),
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout
+ }
+ base_config = super(SimpleRNNCell, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
class SimpleRNN(RNN):
"""Fully-connected RNN where the output is to be fed back to input.
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
- Default: hyperbolic tangent (`tanh`).
- If you pass `None`, no activation is applied
+ activation: Activation function to use.
+ If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation").
- (see [regularizer](../regularizers.md)).
+ the output of the layer (its "activation")..
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
@@ -1052,12 +1054,12 @@ class SimpleRNN(RNN):
go_backwards=go_backwards,
stateful=stateful,
unroll=unroll,
- activity_regularizer=regularizers.get(activity_regularizer),
**kwargs)
+ self.activity_regularizer = regularizers.get(activity_regularizer)
def call(self, inputs, mask=None, training=None, initial_state=None):
- self.cell._generate_dropout_mask(inputs, training=training)
- self.cell._generate_recurrent_dropout_mask(inputs, training=training)
+ self.cell._dropout_mask = None
+ self.cell._recurrent_dropout_mask = None
return super(SimpleRNN, self).call(
inputs, mask=mask, training=training, initial_state=initial_state)
@@ -1119,25 +1121,36 @@ class SimpleRNN(RNN):
def get_config(self):
config = {
- 'units': self.units,
- 'activation': activations.serialize(self.activation),
- 'use_bias': self.use_bias,
- 'kernel_initializer': initializers.serialize(self.kernel_initializer),
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
- 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint': constraints.serialize(self.kernel_constraint),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint),
- 'dropout': self.dropout,
- 'recurrent_dropout': self.recurrent_dropout
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout
}
base_config = super(SimpleRNN, self).get_config()
del base_config['cell']
@@ -1155,43 +1168,28 @@ class GRUCell(Layer):
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
- Default: hyperbolic tangent (`tanh`).
- If you pass `None`, no activation is applied
+ activation: Activation function to use.
+ If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
- for the recurrent step
- (see [activations](../activations.md)).
- Default: hard sigmoid (`hard_sigmoid`).
- If you pass `None`, no activation is applied
- (ie. "linear" activation: `a(x) = x`).
+ for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
@@ -1249,6 +1247,7 @@ class GRUCell(Layer):
self._dropout_mask = None
self._recurrent_dropout_mask = None
+ @shape_type_conversion
def build(self, input_shape):
input_dim = input_shape[-1]
self.kernel = self.add_weight(
@@ -1292,38 +1291,24 @@ class GRUCell(Layer):
self.bias_h = None
self.built = True
- def _generate_dropout_mask(self, inputs, training=None):
- if 0 < self.dropout < 1:
- ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._dropout_mask = [
- K.in_train_phase(dropped_inputs, ones, training=training)
- for _ in range(3)
- ]
- else:
- self._dropout_mask = None
-
- def _generate_recurrent_dropout_mask(self, inputs, training=None):
- if 0 < self.recurrent_dropout < 1:
- ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
- ones = K.tile(ones, (1, self.units))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._recurrent_dropout_mask = [
- K.in_train_phase(dropped_inputs, ones, training=training)
- for _ in range(3)
- ]
- else:
- self._recurrent_dropout_mask = None
-
def call(self, inputs, states, training=None):
h_tm1 = states[0] # previous memory
+ if 0 < self.dropout < 1 and self._dropout_mask is None:
+ self._dropout_mask = _generate_dropout_mask(
+ _generate_dropout_ones(inputs,
+ K.shape(inputs)[-1]),
+ self.dropout,
+ training=training,
+ count=3)
+ if (0 < self.recurrent_dropout < 1 and
+ self._recurrent_dropout_mask is None):
+ self._recurrent_dropout_mask = _generate_dropout_mask(
+ _generate_dropout_ones(inputs, self.units),
+ self.recurrent_dropout,
+ training=training,
+ count=3)
+
# dropout matrices for input units
dp_mask = self._dropout_mask
# dropout matrices for recurrent units
@@ -1387,55 +1372,76 @@ class GRUCell(Layer):
h._uses_learning_phase = True
return h, [h]
+ def get_config(self):
+ config = {
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
+ 'recurrent_activation':
+ activations.serialize(self.recurrent_activation),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
+ 'recurrent_initializer':
+ initializers.serialize(self.recurrent_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
+ 'recurrent_regularizer':
+ regularizers.serialize(self.recurrent_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
+ 'recurrent_constraint':
+ constraints.serialize(self.recurrent_constraint),
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout,
+ 'implementation':
+ self.implementation
+ }
+ base_config = super(GRUCell, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
class GRU(RNN):
- # pylint: disable=line-too-long
"""Gated Recurrent Unit - Cho et al.
2014.
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
- Default: hyperbolic tangent (`tanh`).
- If you pass `None`, no activation is applied
+ activation: Activation function to use.
+ If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
- for the recurrent step
- (see [activations](../activations.md)).
- Default: hard sigmoid (`hard_sigmoid`).
- If you pass `None`, no activation is applied
- (ie. "linear" activation: `a(x) = x`).
+ for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation").
- (see [regularizer](../regularizers.md)).
+ the output of the layer (its "activation")..
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
@@ -1465,12 +1471,7 @@ class GRU(RNN):
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
- References:
- - [On the Properties of Neural Machine Translation: Encoder-Decoder Approaches](https://arxiv.org/abs/1409.1259)
- - [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](http://arxiv.org/abs/1412.3555v1)
- - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
"""
- # pylint: enable=line-too-long
def __init__(self,
units,
@@ -1528,8 +1529,8 @@ class GRU(RNN):
self.activity_regularizer = regularizers.get(activity_regularizer)
def call(self, inputs, mask=None, training=None, initial_state=None):
- self.cell._generate_dropout_mask(inputs, training=training)
- self.cell._generate_recurrent_dropout_mask(inputs, training=training)
+ self.cell._dropout_mask = None
+ self.cell._recurrent_dropout_mask = None
return super(GRU, self).call(
inputs, mask=mask, training=training, initial_state=initial_state)
@@ -1599,28 +1600,40 @@ class GRU(RNN):
def get_config(self):
config = {
- 'units': self.units,
- 'activation': activations.serialize(self.activation),
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
'recurrent_activation':
activations.serialize(self.recurrent_activation),
- 'use_bias': self.use_bias,
- 'kernel_initializer': initializers.serialize(self.kernel_initializer),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
- 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint': constraints.serialize(self.kernel_constraint),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint),
- 'dropout': self.dropout,
- 'recurrent_dropout': self.recurrent_dropout,
- 'implementation': self.implementation
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout,
+ 'implementation':
+ self.implementation
}
base_config = super(GRU, self).get_config()
del base_config['cell']
@@ -1638,48 +1651,33 @@ class LSTMCell(Layer):
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
- Default: hyperbolic tangent (`tanh`).
- If you pass `None`, no activation is applied
+ activation: Activation function to use.
+ If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
- for the recurrent step
- (see [activations](../activations.md)).
- Default: hard sigmoid (`hard_sigmoid`).
- If you pass `None`, no activation is applied
- (ie. "linear" activation: `a(x) = x`).
+ for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ bias_initializer: Initializer for the bias vector.
unit_forget_bias: Boolean.
If True, add 1 to the bias of the forget gate at initialization.
Setting it to true will also force `bias_initializer="zeros"`.
This is recommended in [Jozefowicz et
al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
@@ -1739,6 +1737,7 @@ class LSTMCell(Layer):
self._dropout_mask = None
self._recurrent_dropout_mask = None
+ @shape_type_conversion
def build(self, input_shape):
input_dim = input_shape[-1]
self.kernel = self.add_weight(
@@ -1798,36 +1797,22 @@ class LSTMCell(Layer):
self.bias_o = None
self.built = True
- def _generate_dropout_mask(self, inputs, training=None):
- if 0 < self.dropout < 1:
- ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._dropout_mask = [
- K.in_train_phase(dropped_inputs, ones, training=training)
- for _ in range(4)
- ]
- else:
- self._dropout_mask = None
-
- def _generate_recurrent_dropout_mask(self, inputs, training=None):
- if 0 < self.recurrent_dropout < 1:
- ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
- ones = K.tile(ones, (1, self.units))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._recurrent_dropout_mask = [
- K.in_train_phase(dropped_inputs, ones, training=training)
- for _ in range(4)
- ]
- else:
- self._recurrent_dropout_mask = None
-
def call(self, inputs, states, training=None):
+ if 0 < self.dropout < 1 and self._dropout_mask is None:
+ self._dropout_mask = _generate_dropout_mask(
+ _generate_dropout_ones(inputs,
+ K.shape(inputs)[-1]),
+ self.dropout,
+ training=training,
+ count=4)
+ if (0 < self.recurrent_dropout < 1 and
+ self._recurrent_dropout_mask is None):
+ self._recurrent_dropout_mask = _generate_dropout_mask(
+ _generate_dropout_ones(inputs, self.units),
+ self.recurrent_dropout,
+ training=training,
+ count=4)
+
# dropout matrices for input units
dp_mask = self._dropout_mask
# dropout matrices for recurrent units
@@ -1901,59 +1886,81 @@ class LSTMCell(Layer):
h._uses_learning_phase = True
return h, [h, c]
+ def get_config(self):
+ config = {
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
+ 'recurrent_activation':
+ activations.serialize(self.recurrent_activation),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
+ 'recurrent_initializer':
+ initializers.serialize(self.recurrent_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'unit_forget_bias':
+ self.unit_forget_bias,
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
+ 'recurrent_regularizer':
+ regularizers.serialize(self.recurrent_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
+ 'recurrent_constraint':
+ constraints.serialize(self.recurrent_constraint),
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout,
+ 'implementation':
+ self.implementation
+ }
+ base_config = super(LSTMCell, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
class LSTM(RNN):
- # pylint: disable=line-too-long
"""Long-Short Term Memory layer - Hochreiter 1997.
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
- Default: hyperbolic tangent (`tanh`).
- If you pass `None`, no activation is applied
+ activation: Activation function to use.
+ If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
- for the recurrent step
- (see [activations](../activations.md)).
- Default: hyperbolic tangent (`tanh`).
- Default: hard sigmoid (`hard_sigmoid`).
- If you pass `None`, no activation is applied
- (ie. "linear" activation: `a(x) = x`).
+ for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
- used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
+ used for the linear transformation of the inputs..
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
- used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ used for the linear transformation of the recurrent state..
+ bias_initializer: Initializer for the bias vector.
unit_forget_bias: Boolean.
If True, add 1 to the bias of the forget gate at initialization.
Setting it to true will also force `bias_initializer="zeros"`.
This is recommended in [Jozefowicz et
al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation").
- (see [regularizer](../regularizers.md)).
+ the output of the layer (its "activation")..
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
@@ -1983,13 +1990,7 @@ class LSTM(RNN):
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
- References:
- - [Long short-term memory](http://www.bioinf.jku.at/publications/older/2604.pdf)
- - [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015)
- - [Supervised sequence labeling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf)
- - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
"""
- # pylint: enable=line-too-long
def __init__(self,
units,
@@ -2049,8 +2050,8 @@ class LSTM(RNN):
self.activity_regularizer = regularizers.get(activity_regularizer)
def call(self, inputs, mask=None, training=None, initial_state=None):
- self.cell._generate_dropout_mask(inputs, training=training)
- self.cell._generate_recurrent_dropout_mask(inputs, training=training)
+ self.cell._dropout_mask = None
+ self.cell._recurrent_dropout_mask = None
return super(LSTM, self).call(
inputs, mask=mask, training=training, initial_state=initial_state)
@@ -2124,29 +2125,42 @@ class LSTM(RNN):
def get_config(self):
config = {
- 'units': self.units,
- 'activation': activations.serialize(self.activation),
+ 'units':
+ self.units,
+ 'activation':
+ activations.serialize(self.activation),
'recurrent_activation':
activations.serialize(self.recurrent_activation),
- 'use_bias': self.use_bias,
- 'kernel_initializer': initializers.serialize(self.kernel_initializer),
+ 'use_bias':
+ self.use_bias,
+ 'kernel_initializer':
+ initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
- 'unit_forget_bias': self.unit_forget_bias,
- 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
+ 'unit_forget_bias':
+ self.unit_forget_bias,
+ 'kernel_regularizer':
+ regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint': constraints.serialize(self.kernel_constraint),
+ 'kernel_constraint':
+ constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint),
- 'dropout': self.dropout,
- 'recurrent_dropout': self.recurrent_dropout,
- 'implementation': self.implementation
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint),
+ 'dropout':
+ self.dropout,
+ 'recurrent_dropout':
+ self.recurrent_dropout,
+ 'implementation':
+ self.implementation
}
base_config = super(LSTM, self).get_config()
del base_config['cell']
@@ -2159,6 +2173,23 @@ class LSTM(RNN):
return cls(**config)
+def _generate_dropout_ones(inputs, dims):
+ return K.ones((K.shape(inputs)[0], dims))
+
+
+def _generate_dropout_mask(ones, rate, training=None, count=1):
+
+ def dropped_inputs():
+ return K.dropout(ones, rate)
+
+ if count > 1:
+ return [
+ K.in_train_phase(dropped_inputs, ones, training=training)
+ for _ in range(count)
+ ]
+ return K.in_train_phase(dropped_inputs, ones, training=training)
+
+
class Recurrent(Layer):
"""Deprecated abstract base class for recurrent layers.
@@ -2285,6 +2316,7 @@ class Recurrent(Layer):
self.dropout = 0
self.recurrent_dropout = 0
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
index 7dc4c1db9b..a1407a24ea 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
@@ -392,6 +392,105 @@ class RNNTest(test.TestCase):
self.assertEqual(len(layer.trainable_weights), 3)
self.assertEqual(len(layer.non_trainable_weights), 0)
+ def test_state_reuse_with_dropout(self):
+ layer_class = keras.layers.SimpleRNN
+ embedding_dim = 4
+ units = 3
+ timesteps = 2
+ num_samples = 2
+
+ with self.test_session():
+ input1 = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
+ layer = layer_class(units,
+ return_state=True,
+ return_sequences=True,
+ dropout=0.2)
+ state = layer(input1)[1:]
+
+ input2 = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
+ output = layer_class(units)(input2, initial_state=state)
+ model = keras.Model([input1, input2], output)
+
+ inputs = [np.random.random((num_samples, timesteps, embedding_dim)),
+ np.random.random((num_samples, timesteps, embedding_dim))]
+ model.predict(inputs)
+
+ def test_builtin_rnn_cell_serialization(self):
+ for cell_class in [keras.layers.SimpleRNNCell,
+ keras.layers.GRUCell,
+ keras.layers.LSTMCell]:
+ with self.test_session():
+ # Test basic case.
+ x = keras.Input((None, 5))
+ cell = cell_class(32)
+ layer = keras.layers.RNN(cell)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+
+ # Test basic case serialization.
+ x_np = np.random.random((6, 5, 5))
+ y_np = model.predict(x_np)
+ weights = model.get_weights()
+ config = layer.get_config()
+ layer = keras.layers.RNN.from_config(config)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.set_weights(weights)
+ y_np_2 = model.predict(x_np)
+ self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+ # Test stacking.
+ cells = [cell_class(8),
+ cell_class(12),
+ cell_class(32)]
+ layer = keras.layers.RNN(cells)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+
+ # Test stacked RNN serialization.
+ x_np = np.random.random((6, 5, 5))
+ y_np = model.predict(x_np)
+ weights = model.get_weights()
+ config = layer.get_config()
+ layer = keras.layers.RNN.from_config(config)
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.set_weights(weights)
+ y_np_2 = model.predict(x_np)
+ self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+ def test_stacked_rnn_dropout(self):
+ cells = [keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1),
+ keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1)]
+ layer = keras.layers.RNN(cells)
+
+ with self.test_session():
+ x = keras.Input((None, 5))
+ y = layer(x)
+ model = keras.models.Model(x, y)
+ model.compile('sgd', 'mse')
+ x_np = np.random.random((6, 5, 5))
+ y_np = np.random.random((6, 3))
+ model.train_on_batch(x_np, y_np)
+
+ def test_stacked_rnn_compute_output_shape(self):
+ cells = [keras.layers.LSTMCell(3),
+ keras.layers.LSTMCell(6)]
+ embedding_dim = 4
+ timesteps = 2
+ layer = keras.layers.RNN(cells, return_state=True, return_sequences=True)
+ output_shape = layer.compute_output_shape((None, timesteps, embedding_dim))
+ expected_output_shape = [(None, timesteps, 6),
+ (None, 6),
+ (None, 6),
+ (None, 3),
+ (None, 3)]
+ self.assertEqual(
+ [tuple(o.as_list()) for o in output_shape],
+ expected_output_shape)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
index 452801b656..3667956f80 100644
--- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion
from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
from tensorflow.python.layers import utils as tf_layers_util
@@ -291,6 +292,7 @@ class Bidirectional(Wrapper):
self.backward_layer.initial_weights = weights[nw // 2:]
self.stateful = layer.stateful
self.return_sequences = layer.return_sequences
+ self.return_state = layer.return_state
self.supports_masking = True
def get_weights(self):
@@ -301,27 +303,54 @@ class Bidirectional(Wrapper):
self.forward_layer.set_weights(weights[:nw // 2])
self.backward_layer.set_weights(weights[nw // 2:])
+ @shape_type_conversion
def compute_output_shape(self, input_shape):
- input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
- if self.merge_mode in ['sum', 'ave', 'mul']:
- return self.forward_layer.compute_output_shape(input_shape)
- elif self.merge_mode == 'concat':
- shape = self.forward_layer.compute_output_shape(input_shape).as_list()
- shape[-1] *= 2
- return tensor_shape.TensorShape(shape)
+ output_shape = tuple(self.forward_layer.compute_output_shape(
+ input_shape).as_list())
+ if self.return_state:
+ state_shape = output_shape[1:]
+ output_shape = output_shape[0]
+
+ if self.merge_mode == 'concat':
+ output_shape = list(output_shape)
+ output_shape[-1] *= 2
+ output_shape = tuple(output_shape)
elif self.merge_mode is None:
- shape = self.forward_layer.compute_output_shape(input_shape)
- return [shape, copy.copy(shape)]
+ output_shape = [output_shape, copy.copy(output_shape)]
- def call(self, inputs, training=None, mask=None):
+ if self.return_state:
+ if self.merge_mode is None:
+ return output_shape + state_shape + copy.copy(state_shape)
+ return [output_shape] + state_shape + copy.copy(state_shape)
+ return output_shape
+
+ def call(self, inputs, training=None, mask=None, initial_state=None):
kwargs = {}
if has_arg(self.layer.call, 'training'):
kwargs['training'] = training
if has_arg(self.layer.call, 'mask'):
kwargs['mask'] = mask
- y = self.forward_layer.call(inputs, **kwargs)
- y_rev = self.backward_layer.call(inputs, **kwargs)
+ if initial_state is not None and has_arg(self.layer.call, 'initial_state'):
+ if not isinstance(initial_state, list):
+ raise ValueError(
+ 'When passing `initial_state` to a Bidirectional RNN, the state '
+ 'should be a list containing the states of the underlying RNNs. '
+ 'Found: ' + str(initial_state))
+ forward_state = initial_state[:len(initial_state) // 2]
+ backward_state = initial_state[len(initial_state) // 2:]
+ y = self.forward_layer.call(inputs, initial_state=forward_state, **kwargs)
+ y_rev = self.backward_layer.call(
+ inputs, initial_state=backward_state, **kwargs)
+ else:
+ y = self.forward_layer.call(inputs, **kwargs)
+ y_rev = self.backward_layer.call(inputs, **kwargs)
+
+ if self.return_state:
+ states = y[1:] + y_rev[1:]
+ y = y[0]
+ y_rev = y_rev[0]
+
if self.return_sequences:
y_rev = K.reverse(y_rev, 1)
if self.merge_mode == 'concat':
@@ -343,6 +372,11 @@ class Bidirectional(Wrapper):
out._uses_learning_phase = True
else:
output._uses_learning_phase = True
+
+ if self.return_state:
+ if self.merge_mode is None:
+ return output + states
+ return [output] + states
return output
def reset_states(self):
diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
index 0866c4b0ae..f48c8919a1 100644
--- a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
@@ -238,6 +238,131 @@ class BidirectionalTest(test.TestCase):
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y, epochs=1, batch_size=1)
+ def test_Bidirectional_merged_value(self):
+ rnn = keras.layers.LSTM
+ samples = 2
+ dim = 5
+ timesteps = 3
+ units = 3
+ x = [np.random.rand(samples, timesteps, dim)]
+
+ with self.test_session():
+ for merge_mode in ['sum', 'mul', 'ave', 'concat', None]:
+ if merge_mode == 'sum':
+ merge_func = lambda y, y_rev: y + y_rev
+ elif merge_mode == 'mul':
+ merge_func = lambda y, y_rev: y * y_rev
+ elif merge_mode == 'ave':
+ merge_func = lambda y, y_rev: (y + y_rev) / 2
+ elif merge_mode == 'concat':
+ merge_func = lambda y, y_rev: np.concatenate((y, y_rev), axis=-1)
+ else:
+ merge_func = lambda y, y_rev: [y, y_rev]
+
+ # basic case
+ inputs = keras.Input((timesteps, dim))
+ layer = keras.layers.Bidirectional(
+ rnn(units, return_sequences=True), merge_mode=merge_mode)
+ f_merged = keras.backend.function([inputs], _to_list(layer(inputs)))
+ f_forward = keras.backend.function([inputs],
+ [layer.forward_layer.call(inputs)])
+ f_backward = keras.backend.function(
+ [inputs],
+ [keras.backend.reverse(layer.backward_layer.call(inputs), 1)])
+
+ y_merged = f_merged(x)
+ y_expected = _to_list(merge_func(f_forward(x)[0], f_backward(x)[0]))
+ assert len(y_merged) == len(y_expected)
+ for x1, x2 in zip(y_merged, y_expected):
+ self.assertAllClose(x1, x2, atol=1e-5)
+
+ # test return_state
+ inputs = keras.Input((timesteps, dim))
+ layer = keras.layers.Bidirectional(
+ rnn(units, return_state=True), merge_mode=merge_mode)
+ f_merged = keras.backend.function([inputs], layer(inputs))
+ f_forward = keras.backend.function([inputs],
+ layer.forward_layer.call(inputs))
+ f_backward = keras.backend.function([inputs],
+ layer.backward_layer.call(inputs))
+ n_states = len(layer.layer.states)
+
+ y_merged = f_merged(x)
+ y_forward = f_forward(x)
+ y_backward = f_backward(x)
+ y_expected = _to_list(merge_func(y_forward[0], y_backward[0]))
+ assert len(y_merged) == len(y_expected) + n_states * 2
+ for x1, x2 in zip(y_merged, y_expected):
+ self.assertAllClose(x1, x2, atol=1e-5)
+
+ y_merged = y_merged[-n_states * 2:]
+ y_forward = y_forward[-n_states:]
+ y_backward = y_backward[-n_states:]
+ for state_birnn, state_inner in zip(y_merged, y_forward + y_backward):
+ self.assertAllClose(state_birnn, state_inner, atol=1e-5)
+
+ def test_Bidirectional_dropout(self):
+ rnn = keras.layers.LSTM
+ samples = 2
+ dim = 5
+ timesteps = 3
+ units = 3
+ merge_mode = 'sum'
+ x = [np.random.rand(samples, timesteps, dim)]
+
+ with self.test_session():
+ inputs = keras.Input((timesteps, dim))
+ wrapped = keras.layers.Bidirectional(
+ rnn(units, dropout=0.2, recurrent_dropout=0.2), merge_mode=merge_mode)
+ outputs = _to_list(wrapped(inputs, training=True))
+ assert all(not getattr(x, '_uses_learning_phase') for x in outputs)
+
+ inputs = keras.Input((timesteps, dim))
+ wrapped = keras.layers.Bidirectional(
+ rnn(units, dropout=0.2, return_state=True), merge_mode=merge_mode)
+ outputs = _to_list(wrapped(inputs))
+ assert all(x._uses_learning_phase for x in outputs)
+
+ model = keras.Model(inputs, outputs)
+ assert model.uses_learning_phase
+ y1 = _to_list(model.predict(x))
+ y2 = _to_list(model.predict(x))
+ for x1, x2 in zip(y1, y2):
+ self.assertAllClose(x1, x2, atol=1e-5)
+
+ def test_Bidirectional_state_reuse(self):
+ rnn = keras.layers.LSTM
+ samples = 2
+ dim = 5
+ timesteps = 3
+ units = 3
+
+ with self.test_session():
+ inputs = keras.Input((timesteps, dim))
+ layer = keras.layers.Bidirectional(
+ rnn(units, return_state=True, return_sequences=True))
+ outputs = layer(inputs)
+ output, state = outputs[0], outputs[1:]
+
+ # test passing invalid initial_state: passing a tensor
+ with self.assertRaises(ValueError):
+ output = keras.layers.Bidirectional(
+ rnn(units))(output, initial_state=state[0])
+
+ # test valid usage: passing a list
+ output = keras.layers.Bidirectional(
+ rnn(units))(output, initial_state=state)
+ model = keras.Model(inputs, output)
+ inputs = np.random.rand(samples, timesteps, dim)
+ outputs = model.predict(inputs)
+
+
+def _to_list(ls):
+ if isinstance(ls, list):
+ return ls
+ else:
+ return [ls]
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/losses.py b/tensorflow/python/keras/_impl/keras/losses.py
index 1d6319abb1..fe0ef54360 100644
--- a/tensorflow/python/keras/_impl/keras/losses.py
+++ b/tensorflow/python/keras/_impl/keras/losses.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Built-in Keras loss functions.
+# pylint: disable=unused-import
+"""Built-in loss functions.
"""
from __future__ import absolute_import
from __future__ import division
@@ -34,7 +35,6 @@ def mean_absolute_error(y_true, y_pred):
def mean_absolute_percentage_error(y_true, y_pred):
- # Equivalent to MAE, but sometimes easier to interpret.
diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true), K.epsilon(), None))
return 100. * K.mean(diff, axis=-1)
@@ -56,10 +56,24 @@ def hinge(y_true, y_pred):
def categorical_hinge(y_true, y_pred):
pos = K.sum(y_true * y_pred, axis=-1)
neg = K.max((1. - y_true) * y_pred, axis=-1)
- return K.maximum(neg - pos + 1., 0.)
+ return K.maximum(0., neg - pos + 1.)
def logcosh(y_true, y_pred):
+ """Logarithm of the hyperbolic cosine of the prediction error.
+
+ `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and
+ to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly
+ like the mean squared error, but will not be so strongly affected by the
+ occasional wildly incorrect prediction.
+
+ Arguments:
+ y_true: tensor of true targets.
+ y_pred: tensor of predicted targets.
+
+ Returns:
+ Tensor with one scalar loss entry per sample.
+ """
def _logcosh(x):
return x + K.softplus(-2. * x) - K.log(2.)
diff --git a/tensorflow/python/keras/_impl/keras/metrics.py b/tensorflow/python/keras/_impl/keras/metrics.py
index 202048f26d..3c18e68260 100644
--- a/tensorflow/python/keras/_impl/keras/metrics.py
+++ b/tensorflow/python/keras/_impl/keras/metrics.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Built-in Keras metrics functions.
+# pylint: disable=unused-import
+"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
@@ -21,7 +22,6 @@ from __future__ import print_function
import six
from tensorflow.python.keras._impl.keras import backend as K
-# pylint: disable=unused-import
from tensorflow.python.keras._impl.keras.losses import binary_crossentropy
from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy
from tensorflow.python.keras._impl.keras.losses import cosine_proximity
@@ -35,7 +35,6 @@ from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_
from tensorflow.python.keras._impl.keras.losses import poisson
from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy
from tensorflow.python.keras._impl.keras.losses import squared_hinge
-# pylint: disable=unused-import
from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
@@ -60,8 +59,8 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5):
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
- return K.mean(K.in_top_k(y_pred,
- K.cast(K.max(y_true, axis=-1), 'int32'), k), axis=-1)
+ return K.mean(
+ K.in_top_k(y_pred, K.cast(K.max(y_true, axis=-1), 'int32'), k), axis=-1)
# Aliases
diff --git a/tensorflow/python/keras/_impl/keras/models.py b/tensorflow/python/keras/_impl/keras/models.py
index e262cc8c8e..9cd547200d 100644
--- a/tensorflow/python/keras/_impl/keras/models.py
+++ b/tensorflow/python/keras/_impl/keras/models.py
@@ -492,13 +492,13 @@ class Sequential(Model):
# to the input layer we just created.
layer(x)
- if len(layer.inbound_nodes[-1].output_tensors) != 1:
+ if len(layer._inbound_nodes[-1].output_tensors) != 1:
raise ValueError('All layers in a Sequential model '
'should have a single output tensor. '
'For multi-output layers, '
'use the functional API.')
- self.outputs = [layer.inbound_nodes[-1].output_tensors[0]]
+ self.outputs = [layer._inbound_nodes[-1].output_tensors[0]]
self.inputs = topology.get_source_inputs(self.outputs[0])
# We create an input node, which we will keep updated
diff --git a/tensorflow/python/keras/_impl/keras/models_test.py b/tensorflow/python/keras/_impl/keras/models_test.py
index edfc0ce0eb..04017e4b28 100644
--- a/tensorflow/python/keras/_impl/keras/models_test.py
+++ b/tensorflow/python/keras/_impl/keras/models_test.py
@@ -340,6 +340,35 @@ class TestSequential(test.TestCase):
inner_model.trainable = True
self.assertEqual(len(model.trainable_weights), 4)
+ def test_sequential_update_disabling(self):
+ val_a = np.random.random((10, 4))
+ val_out = np.random.random((10, 4))
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.BatchNormalization(input_shape=(4,)))
+
+ model.trainable = False
+ assert not model.updates
+
+ model.compile('sgd', 'mse')
+ assert not model.updates
+ assert not model.model.updates
+
+ x1 = model.predict(val_a)
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ self.assertAllClose(x1, x2, atol=1e-7)
+
+ model.trainable = True
+ model.compile('sgd', 'mse')
+ assert model.updates
+ assert model.model.updates
+
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ assert np.abs(np.sum(x1 - x2)) > 1e-5
+
class TestModelCloning(test.TestCase):
diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py
index a08073fa86..e47987aadc 100644
--- a/tensorflow/python/keras/_impl/keras/optimizers.py
+++ b/tensorflow/python/keras/_impl/keras/optimizers.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Keras optimizer classes (will eventually be replaced with core optimizers).
+# pylint: disable=invalid-name
+"""Built-in optimizer classes.
"""
from __future__ import absolute_import
from __future__ import division
@@ -121,9 +122,9 @@ class Optimizer(object):
param_values = K.batch_get_value(params)
for pv, p, w in zip(param_values, params, weights):
if pv.shape != w.shape:
- raise ValueError('Optimizer weight shape ' + str(pv.shape) +
- ' not compatible with '
- 'provided weight shape ' + str(w.shape))
+ raise ValueError(
+ 'Optimizer weight shape ' + str(pv.shape) + ' not compatible with '
+ 'provided weight shape ' + str(w.shape))
weight_value_tuples.append((p, w))
K.batch_set_value(weight_value_tuples)
@@ -156,7 +157,8 @@ class SGD(Optimizer):
Arguments:
lr: float >= 0. Learning rate.
- momentum: float >= 0. Parameter updates momentum.
+ momentum: float >= 0. Parameter that accelerates SGD
+ in the relevant direction and dampens oscillations.
decay: float >= 0. Learning rate decay over each update.
nesterov: boolean. Whether to apply Nesterov momentum.
"""
@@ -177,9 +179,8 @@ class SGD(Optimizer):
lr = self.lr
if self.initial_decay > 0:
- lr *= (1. / (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
-
+ lr *= (1. /
+ (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
# momentum
shapes = [K.int_shape(p) for p in params]
moments = [K.zeros(shape) for shape in shapes]
@@ -224,32 +225,33 @@ class RMSprop(Optimizer):
Arguments:
lr: float >= 0. Learning rate.
rho: float >= 0.
- epsilon: float >= 0. Fuzz factor.
+ epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
+
"""
- def __init__(self, lr=0.001, rho=0.9, epsilon=1e-8, decay=0., **kwargs):
+ def __init__(self, lr=0.001, rho=0.9, epsilon=None, decay=0., **kwargs):
super(RMSprop, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.lr = K.variable(lr, name='lr')
self.rho = K.variable(rho, name='rho')
self.decay = K.variable(decay, name='decay')
self.iterations = K.variable(0, dtype='int64', name='iterations')
+ if epsilon is None:
+ epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
- accumulators = [
- K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params
- ]
+ accumulators = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
self.weights = accumulators
self.updates = [K.update_add(self.iterations, 1)]
lr = self.lr
if self.initial_decay > 0:
- lr *= (1. / (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr *= (1. /
+ (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
for p, g, a in zip(params, grads, accumulators):
# update accumulator
@@ -283,20 +285,19 @@ class Adagrad(Optimizer):
Arguments:
lr: float >= 0. Learning rate.
- epsilon: float >= 0.
+ epsilon: float >= 0. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
- References:
- - [Adaptive Subgradient Methods for Online Learning and Stochastic
- Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
"""
- def __init__(self, lr=0.01, epsilon=1e-8, decay=0., **kwargs):
+ def __init__(self, lr=0.01, epsilon=None, decay=0., **kwargs):
super(Adagrad, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.lr = K.variable(lr, name='lr')
self.decay = K.variable(decay, name='decay')
self.iterations = K.variable(0, dtype='int64', name='iterations')
+ if epsilon is None:
+ epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
@@ -309,8 +310,8 @@ class Adagrad(Optimizer):
lr = self.lr
if self.initial_decay > 0:
- lr *= (1. / (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr *= (1. /
+ (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
for p, g, a in zip(params, grads, accumulators):
new_a = a + K.square(g) # update accumulator
@@ -344,20 +345,19 @@ class Adadelta(Optimizer):
lr: float >= 0. Learning rate.
It is recommended to leave it at the default value.
rho: float >= 0.
- epsilon: float >= 0. Fuzz factor.
+ epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
- References:
- - [Adadelta - an adaptive learning rate
- method](http://arxiv.org/abs/1212.5701)
"""
- def __init__(self, lr=1.0, rho=0.95, epsilon=1e-8, decay=0., **kwargs):
+ def __init__(self, lr=1.0, rho=0.95, epsilon=None, decay=0., **kwargs):
super(Adadelta, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.lr = K.variable(lr, name='lr')
self.decay = K.variable(decay, name='decay')
self.iterations = K.variable(0, dtype='int64', name='iterations')
+ if epsilon is None:
+ epsilon = K.epsilon()
self.rho = rho
self.epsilon = epsilon
self.initial_decay = decay
@@ -372,8 +372,8 @@ class Adadelta(Optimizer):
lr = self.lr
if self.initial_decay > 0:
- lr *= (1. / (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr *= (1. /
+ (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):
# update accumulator
@@ -415,20 +415,21 @@ class Adam(Optimizer):
lr: float >= 0. Learning rate.
beta_1: float, 0 < beta < 1. Generally close to 1.
beta_2: float, 0 < beta < 1. Generally close to 1.
- epsilon: float >= 0. Fuzz factor.
+ epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
+ amsgrad: boolean. Whether to apply the AMSGrad variant of this
+ algorithm from the paper "On the Convergence of Adam and
+ Beyond".
- References:
- - [Adam - A Method for Stochastic
- Optimization](http://arxiv.org/abs/1412.6980v8)
"""
def __init__(self,
lr=0.001,
beta_1=0.9,
beta_2=0.999,
- epsilon=1e-8,
+ epsilon=None,
decay=0.,
+ amsgrad=False,
**kwargs):
super(Adam, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
@@ -437,8 +438,11 @@ class Adam(Optimizer):
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.decay = K.variable(decay, name='decay')
+ if epsilon is None:
+ epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
+ self.amsgrad = amsgrad
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
@@ -446,21 +450,30 @@ class Adam(Optimizer):
lr = self.lr
if self.initial_decay > 0:
- lr *= (1. / (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr *= (1. /
+ (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
t = K.cast(self.iterations, K.floatx()) + 1
- lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
- (1. - K.pow(self.beta_1, t)))
+ lr_t = lr * (
+ K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t)))
ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
- self.weights = [self.iterations] + ms + vs
+ if self.amsgrad:
+ vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
+ else:
+ vhats = [K.zeros(1) for _ in params]
+ self.weights = [self.iterations] + ms + vs + vhats
- for p, g, m, v in zip(params, grads, ms, vs):
+ for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
- p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
+ if self.amsgrad:
+ vhat_t = K.maximum(vhat, v_t)
+ p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
+ self.updates.append(K.update(vhat, vhat_t))
+ else:
+ p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
self.updates.append(K.update(m, m_t))
self.updates.append(K.update(v, v_t))
@@ -479,7 +492,8 @@ class Adam(Optimizer):
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'decay': float(K.get_value(self.decay)),
- 'epsilon': self.epsilon
+ 'epsilon': self.epsilon,
+ 'amsgrad': self.amsgrad
}
base_config = super(Adam, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -494,19 +508,16 @@ class Adamax(Optimizer):
Arguments:
lr: float >= 0. Learning rate.
beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
- epsilon: float >= 0. Fuzz factor.
+ epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
decay: float >= 0. Learning rate decay over each update.
- References:
- - [Adam - A Method for Stochastic
- Optimization](http://arxiv.org/abs/1412.6980v8)
"""
def __init__(self,
lr=0.002,
beta_1=0.9,
beta_2=0.999,
- epsilon=1e-8,
+ epsilon=None,
decay=0.,
**kwargs):
super(Adamax, self).__init__(**kwargs)
@@ -516,6 +527,8 @@ class Adamax(Optimizer):
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.decay = K.variable(decay, name='decay')
+ if epsilon is None:
+ epsilon = K.epsilon()
self.epsilon = epsilon
self.initial_decay = decay
@@ -525,8 +538,8 @@ class Adamax(Optimizer):
lr = self.lr
if self.initial_decay > 0:
- lr *= (1. / (1. + self.decay * K.cast(self.iterations,
- K.dtype(self.decay))))
+ lr *= (1. /
+ (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
t = K.cast(self.iterations, K.floatx()) + 1
lr_t = lr / (1. - K.pow(self.beta_1, t))
@@ -580,19 +593,15 @@ class Nadam(Optimizer):
Arguments:
lr: float >= 0. Learning rate.
beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
- epsilon: float >= 0. Fuzz factor.
+ epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
- References:
- - [Nadam report](http://cs229.stanford.edu/proj2015/054_report.pdf)
- - [On the importance of initialization and momentum in deep
- learning](http://www.cs.toronto.edu/~fritz/absps/momentum.pdf)
"""
def __init__(self,
lr=0.002,
beta_1=0.9,
beta_2=0.999,
- epsilon=1e-8,
+ epsilon=None,
schedule_decay=0.004,
**kwargs):
super(Nadam, self).__init__(**kwargs)
@@ -602,12 +611,15 @@ class Nadam(Optimizer):
self.lr = K.variable(lr, name='lr')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
+ if epsilon is None:
+ epsilon = K.epsilon()
self.epsilon = epsilon
self.schedule_decay = schedule_decay
def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = [K.update_add(self.iterations, 1)]
+
t = K.cast(self.iterations, K.floatx()) + 1
# Due to the recommendations in [2], i.e. warming momentum schedule
@@ -691,7 +703,6 @@ class TFOptimizer(Optimizer):
# Aliases.
-# pylint: disable=invalid-name
sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
@@ -700,8 +711,6 @@ adam = Adam
adamax = Adamax
nadam = Nadam
-# pylint: enable=invalid-name
-
def serialize(optimizer):
return serialize_keras_object(optimizer)
diff --git a/tensorflow/python/keras/_impl/keras/optimizers_test.py b/tensorflow/python/keras/_impl/keras/optimizers_test.py
index 6e9e4e6c99..57636afbf0 100644
--- a/tensorflow/python/keras/_impl/keras/optimizers_test.py
+++ b/tensorflow/python/keras/_impl/keras/optimizers_test.py
@@ -102,6 +102,7 @@ class KerasOptimizersTest(test.TestCase):
with self.test_session():
_test_optimizer(keras.optimizers.Adam())
_test_optimizer(keras.optimizers.Adam(decay=1e-3))
+ _test_optimizer(keras.optimizers.Adam(amsgrad=True))
def test_adamax(self):
with self.test_session():
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image.py b/tensorflow/python/keras/_impl/keras/preprocessing/image.py
index 82441de592..db1fdd4e6b 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/image.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/image.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=g-import-not-at-top
"""Fairly basic set of tools for real-time data augmentation on image data.
Can easily be extended to include new transformations,
@@ -28,25 +29,22 @@ import re
import threading
import numpy as np
-from six.moves import range # pylint: disable=redefined-builtin
-
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence
from tensorflow.python.platform import tf_logging as logging
-
-# pylint: disable=g-import-not-at-top
-try:
- from PIL import Image as pil_image
-except ImportError:
- pil_image = None
try:
from scipy import linalg
import scipy.ndimage as ndi
except ImportError:
linalg = None
ndi = None
-# pylint: enable=g-import-not-at-top
+
+
+try:
+ from PIL import Image as pil_image
+except ImportError:
+ pil_image = None
if pil_image is not None:
_PIL_INTERPOLATION_METHODS = {
@@ -88,7 +86,7 @@ def random_rotation(x,
Returns:
Rotated Numpy image tensor.
"""
- theta = np.pi / 180 * np.random.uniform(-rg, rg)
+ theta = np.deg2rad(np.random.uniform(-rg, rg))
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
[np.sin(theta), np.cos(theta), 0], [0, 0, 1]])
@@ -145,7 +143,7 @@ def random_shear(x,
Arguments:
x: Input tensor. Must be 3D.
- intensity: Transformation intensity.
+ intensity: Transformation intensity in degrees.
row_axis: Index of axis for rows in the input tensor.
col_axis: Index of axis for columns in the input tensor.
channel_axis: Index of axis for channels in the input tensor.
@@ -158,7 +156,7 @@ def random_shear(x,
Returns:
Sheared Numpy image tensor.
"""
- shear = np.random.uniform(-intensity, intensity)
+ shear = np.deg2rad(np.random.uniform(-intensity, intensity))
shear_matrix = np.array([[1, -np.sin(shear), 0], [0, np.cos(shear), 0],
[0, 0, 1]])
@@ -188,8 +186,10 @@ def random_zoom(x,
(one of `{'constant', 'nearest', 'reflect', 'wrap'}`).
cval: Value used for points outside the boundaries
of the input if `mode='constant'`.
+
Returns:
Zoomed Numpy image tensor.
+
Raises:
ValueError: if `zoom_range` isn't a tuple.
"""
@@ -366,7 +366,7 @@ def load_img(path, grayscale=False, target_size=None, interpolation='nearest'):
grayscale: Boolean, whether to load the image as grayscale.
target_size: Either `None` (default to original size)
or tuple of ints `(img_height, img_width)`.
- interpolation: Interpolation method used to resample the image if the
+ interpolation: Interpolation method used to resample the image if the
target size is different from that of the loaded image.
Supported methods are "nearest", "bilinear", and "bicubic".
If PIL version 1.1.3 or newer is installed, "lanczos" is also
@@ -394,11 +394,9 @@ def load_img(path, grayscale=False, target_size=None, interpolation='nearest'):
width_height_tuple = (target_size[1], target_size[0])
if img.size != width_height_tuple:
if interpolation not in _PIL_INTERPOLATION_METHODS:
- raise ValueError(
- 'Invalid interpolation method {} specified. Supported '
- 'methods are {}'.format(
- interpolation,
- ', '.join(_PIL_INTERPOLATION_METHODS.keys())))
+ raise ValueError('Invalid interpolation method {} specified. Supported '
+ 'methods are {}'.format(interpolation, ', '.join(
+ _PIL_INTERPOLATION_METHODS.keys())))
resample = _PIL_INTERPOLATION_METHODS[interpolation]
img = img.resize(width_height_tuple, resample)
return img
@@ -407,7 +405,8 @@ def load_img(path, grayscale=False, target_size=None, interpolation='nearest'):
def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm'):
return [
os.path.join(root, f)
- for root, _, files in os.walk(directory) for f in files
+ for root, _, files in os.walk(directory)
+ for f in files
if re.match(r'([\w]+\.(?:' + ext + '))', f)
]
@@ -423,9 +422,9 @@ class ImageDataGenerator(object):
zca_whitening: apply ZCA whitening.
zca_epsilon: epsilon for ZCA whitening. Default is 1e-6.
rotation_range: degrees (0 to 180).
- width_shift_range: fraction of total width.
- height_shift_range: fraction of total height.
- shear_range: shear intensity (shear angle in radians).
+ width_shift_range: fraction of total width, if < 1, or pixels if >= 1.
+ height_shift_range: fraction of total height, if < 1, or pixels if >= 1.
+ shear_range: shear intensity (shear angle in degrees).
zoom_range: amount of zoom. if scalar z, zoom will be randomly picked
in the range [1-z, 1+z]. A sequence of two can be passed instead
to select this range.
@@ -433,6 +432,12 @@ class ImageDataGenerator(object):
fill_mode: points outside the boundaries are filled according to the
given mode ('constant', 'nearest', 'reflect' or 'wrap'). Default
is 'nearest'.
+ Points outside the boundaries of the input are filled according to the
+ given mode:
+ 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
+ 'nearest': aaaaaaaa|abcd|dddddddd
+ 'reflect': abcddcba|abcd|dcbaabcd
+ 'wrap': abcdabcd|abcd|abcdabcd
cval: value used for points outside the boundaries when fill_mode is
'constant'. Default is 0.
horizontal_flip: whether to randomly flip images horizontally.
@@ -522,6 +527,32 @@ class ImageDataGenerator(object):
raise ValueError('`zoom_range` should be a float or '
'a tuple or list of two floats. '
'Received arg: ', zoom_range)
+ if zca_whitening:
+ if not featurewise_center:
+ self.featurewise_center = True
+ logging.warning('This ImageDataGenerator specifies '
+ '`zca_whitening`, which overrides '
+ 'setting of `featurewise_center`.')
+ if featurewise_std_normalization:
+ self.featurewise_std_normalization = False
+ logging.warning('This ImageDataGenerator specifies '
+ '`zca_whitening` '
+ 'which overrides setting of'
+ '`featurewise_std_normalization`.')
+ if featurewise_std_normalization:
+ if not featurewise_center:
+ self.featurewise_center = True
+ logging.warning('This ImageDataGenerator specifies '
+ '`featurewise_std_normalization`, '
+ 'which overrides setting of '
+ '`featurewise_center`.')
+ if samplewise_std_normalization:
+ if not samplewise_center:
+ self.samplewise_center = True
+ logging.warning('This ImageDataGenerator specifies '
+ '`samplewise_std_normalization`, '
+ 'which overrides setting of '
+ '`samplewise_center`.')
def flow(self,
x,
@@ -591,7 +622,7 @@ class ImageDataGenerator(object):
if self.samplewise_center:
x -= np.mean(x, keepdims=True)
if self.samplewise_std_normalization:
- x /= np.std(x, keepdims=True) + 1e-7
+ x /= (np.std(x, keepdims=True) + K.epsilon())
if self.featurewise_center:
if self.mean is not None:
@@ -603,7 +634,7 @@ class ImageDataGenerator(object):
'first by calling `.fit(numpy_data)`.')
if self.featurewise_std_normalization:
if self.std is not None:
- x /= (self.std + 1e-7)
+ x /= (self.std + K.epsilon())
else:
logging.warning('This ImageDataGenerator specifies '
'`featurewise_std_normalization`, but it hasn\'t '
@@ -636,7 +667,6 @@ class ImageDataGenerator(object):
"""
if ndi is None:
raise ImportError('Scipy is required for image transformations.')
-
# x is a single image, so it doesn't have image number at index 0
img_row_axis = self.row_axis - 1
img_col_axis = self.col_axis - 1
@@ -648,25 +678,27 @@ class ImageDataGenerator(object):
# use composition of homographies
# to generate final transform that needs to be applied
if self.rotation_range:
- theta = np.pi / 180 * np.random.uniform(-self.rotation_range,
- self.rotation_range)
+ theta = np.deg2rad(
+ np.random.uniform(-self.rotation_range, self.rotation_range))
else:
theta = 0
if self.height_shift_range:
- tx = np.random.uniform(-self.height_shift_range,
- self.height_shift_range) * x.shape[img_row_axis]
+ tx = np.random.uniform(-self.height_shift_range, self.height_shift_range)
+ if self.height_shift_range < 1:
+ tx *= x.shape[img_row_axis]
else:
tx = 0
if self.width_shift_range:
- ty = np.random.uniform(-self.width_shift_range,
- self.width_shift_range) * x.shape[img_col_axis]
+ ty = np.random.uniform(-self.width_shift_range, self.width_shift_range)
+ if self.width_shift_range < 1:
+ ty *= x.shape[img_col_axis]
else:
ty = 0
if self.shear_range:
- shear = np.random.uniform(-self.shear_range, self.shear_range)
+ shear = np.deg2rad(np.random.uniform(-self.shear_range, self.shear_range))
else:
shear = 0
@@ -744,7 +776,7 @@ class ImageDataGenerator(object):
if x.ndim != 4:
raise ValueError('Input to `.fit()` should have rank 4. '
'Got array with shape: ' + str(x.shape))
- if x.shape[self.channel_axis] not in {3, 4}:
+ if x.shape[self.channel_axis] not in {1, 3, 4}:
logging.warning(
'Expected input to be images (as Numpy array) '
'following the data format convention "' + self.data_format + '" '
@@ -784,10 +816,12 @@ class ImageDataGenerator(object):
raise ImportError('Scipy is required for zca_whitening.')
flat_x = np.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]))
- sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0]
- u, s, _ = linalg.svd(sigma)
- self.principal_components = np.dot(
- np.dot(u, np.diag(1. / np.sqrt(s + self.zca_epsilon))), u.T)
+ num_examples = flat_x.shape[0]
+ _, s, vt = linalg.svd(flat_x / np.sqrt(num_examples))
+ s_expand = np.hstack(
+ (s, np.zeros(vt.shape[0] - num_examples, dtype=flat_x.dtype)))
+ self.principal_components = (
+ vt.T / np.sqrt(s_expand**2 + self.zca_epsilon)).dot(vt)
class Iterator(Sequence):
@@ -797,10 +831,10 @@ class Iterator(Sequence):
method.
Arguments:
- n: Integer, total number of samples in the dataset to loop over.
- batch_size: Integer, size of a batch.
- shuffle: Boolean, whether to shuffle the data between epochs.
- seed: Random seeding for data shuffling.
+ n: Integer, total number of samples in the dataset to loop over.
+ batch_size: Integer, size of a batch.
+ shuffle: Boolean, whether to shuffle the data between epochs.
+ seed: Random seeding for data shuffling.
"""
def __init__(self, n, batch_size, shuffle, seed):
@@ -823,15 +857,14 @@ class Iterator(Sequence):
if idx >= len(self):
raise ValueError('Asked to retrieve element {idx}, '
'but the Sequence '
- 'has length {length}'.format(idx=idx,
- length=len(self)))
+ 'has length {length}'.format(idx=idx, length=len(self)))
if self.seed is not None:
np.random.seed(self.seed + self.total_batches_seen)
self.total_batches_seen += 1
if self.index_array is None:
self._set_index_array()
- index_array = self.index_array[self.batch_size * idx:self.batch_size *
- (idx + 1)]
+ index_array = self.index_array[self.batch_size * idx:self.batch_size * (
+ idx + 1)]
return self._get_batches_of_transformed_samples(index_array)
def __len__(self):
@@ -873,6 +906,7 @@ class Iterator(Sequence):
Arguments:
index_array: array of sample indices to include in batch.
+
Returns:
A batch of transformed samples.
"""
@@ -948,8 +982,8 @@ class NumpyArrayIterator(Iterator):
seed)
def _get_batches_of_transformed_samples(self, index_array):
- batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]),
- dtype=K.floatx())
+ batch_x = np.zeros(
+ tuple([len(index_array)] + list(self.x.shape)[1:]), dtype=K.floatx())
for i, j in enumerate(index_array):
x = self.x[j]
x = self.image_data_generator.random_transform(x.astype(K.floatx()))
@@ -959,7 +993,9 @@ class NumpyArrayIterator(Iterator):
for i, j in enumerate(index_array):
img = array_to_img(batch_x[i], self.data_format, scale=True)
fname = '{prefix}_{index}_{hash}.{format}'.format(
- prefix=self.save_prefix, index=j, hash=np.random.randint(1e4),
+ prefix=self.save_prefix,
+ index=j,
+ hash=np.random.randint(1e4),
format=self.save_format)
img.save(os.path.join(self.save_to_dir, fname))
if self.y is None:
@@ -984,10 +1020,11 @@ class NumpyArrayIterator(Iterator):
def _count_valid_files_in_directory(directory, white_list_formats,
follow_links):
- """Count files with extension in `white_list_formats` in a directory.
+ """Count files with extension in `white_list_formats` contained in directory.
Arguments:
- directory: absolute path to the directory containing files to be counted
+ directory: absolute path to the directory
+ containing files to be counted
white_list_formats: set of strings containing allowed extensions for
the files to be counted.
follow_links: boolean.
@@ -1003,7 +1040,7 @@ def _count_valid_files_in_directory(directory, white_list_formats,
samples = 0
for _, _, files in _recursive_list(directory):
- for fname in sorted(files):
+ for fname in files:
is_valid = False
for extension in white_list_formats:
if fname.lower().endswith('.' + extension):
@@ -1043,7 +1080,7 @@ def _list_valid_filenames_in_directory(directory, white_list_formats,
subdir = os.path.basename(directory)
basedir = os.path.dirname(directory)
for root, _, files in _recursive_list(directory):
- for fname in files:
+ for fname in sorted(files):
is_valid = False
for extension in white_list_formats:
if fname.lower().endswith('.' + extension):
@@ -1167,8 +1204,8 @@ class DirectoryIterator(Iterator):
white_list_formats=white_list_formats,
follow_links=follow_links)
self.samples = sum(
- pool.map(function_partial, (os.path.join(directory, subdir)
- for subdir in classes)))
+ pool.map(function_partial,
+ (os.path.join(directory, subdir) for subdir in classes)))
print('Found %d images belonging to %d classes.' % (self.samples,
self.num_classes))
@@ -1181,8 +1218,9 @@ class DirectoryIterator(Iterator):
i = 0
for dirpath in (os.path.join(directory, subdir) for subdir in classes):
results.append(
- pool.apply_async(_list_valid_filenames_in_directory, (
- dirpath, white_list_formats, self.class_indices, follow_links)))
+ pool.apply_async(
+ _list_valid_filenames_in_directory,
+ (dirpath, white_list_formats, self.class_indices, follow_links)))
for res in results:
classes, filenames = res.get()
self.classes[i:i + len(classes)] = classes
@@ -1199,10 +1237,11 @@ class DirectoryIterator(Iterator):
# build batch of image data
for i, j in enumerate(index_array):
fname = self.filenames[j]
- img = load_img(os.path.join(self.directory, fname),
- grayscale=grayscale,
- target_size=self.target_size,
- interpolation=self.interpolation)
+ img = load_img(
+ os.path.join(self.directory, fname),
+ grayscale=grayscale,
+ target_size=self.target_size,
+ interpolation=self.interpolation)
x = img_to_array(img, data_format=self.data_format)
x = self.image_data_generator.random_transform(x)
x = self.image_data_generator.standardize(x)
@@ -1212,7 +1251,9 @@ class DirectoryIterator(Iterator):
for i, j in enumerate(index_array):
img = array_to_img(batch_x[i], self.data_format, scale=True)
fname = '{prefix}_{index}_{hash}.{format}'.format(
- prefix=self.save_prefix, index=j, hash=np.random.randint(1e7),
+ prefix=self.save_prefix,
+ index=j,
+ hash=np.random.randint(1e7),
format=self.save_format)
img.save(os.path.join(self.save_to_dir, fname))
# build batch of labels
@@ -1241,4 +1282,3 @@ class DirectoryIterator(Iterator):
# The transformation of images is not under thread lock
# so it can be done in parallel
return self._get_batches_of_transformed_samples(index_array)
-
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py
index 642f4f2fac..4d59250af0 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Preprocessing utilities for sequence data.
+"""Utilities for preprocessing sequence data.
"""
from __future__ import absolute_import
from __future__ import division
@@ -129,7 +129,7 @@ def make_sampling_table(size, sampling_factor=1e-5):
is the probability that a word of rank i should be sampled.
"""
gamma = 0.577
- rank = np.array(list(range(size)))
+ rank = np.arange(size)
rank[0] = 1
inv_fq = rank * (np.log(rank) + gamma) + 0.5 - 1. / (12. * rank)
f = sampling_factor * inv_fq
@@ -170,7 +170,7 @@ def skipgrams(sequence,
if True labels will be categorical eg. [[1,0],[0,1],[0,1] .. ]
sampling_table: 1D array of size `vocabulary_size` where the entry i
encodes the probability to sample a word of rank i.
- seed: Random seed.
+ seed: random seed.
Returns:
couples, labels: where `couples` are int pairs and
@@ -224,3 +224,22 @@ def skipgrams(sequence,
random.shuffle(labels)
return couples, labels
+
+
+def _remove_long_seq(maxlen, seq, label):
+ """Removes sequences that exceed the maximum length.
+
+ Arguments:
+ maxlen: int, maximum length
+ seq: list of lists where each sublist is a sequence
+ label: list where each element is an integer
+
+ Returns:
+ new_seq, new_label: shortened lists for `seq` and `label`.
+ """
+ new_seq, new_label = [], []
+ for x, y in zip(seq, label):
+ if len(x) < maxlen:
+ new_seq.append(x)
+ new_label.append(y)
+ return new_seq, new_label
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text.py b/tensorflow/python/keras/_impl/keras/preprocessing/text.py
index 47e5aa064f..8f7f25dc0a 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/text.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/text.py
@@ -13,8 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Utilities for text input preprocessing.
-
-May benefit from a fast Cython rewrite.
"""
from __future__ import absolute_import
from __future__ import division
@@ -29,6 +27,9 @@ import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
from six.moves import zip # pylint: disable=redefined-builtin
+from tensorflow.python.platform import tf_logging as logging
+
+
if sys.version_info < (3,):
maketrans = string.maketrans
else:
@@ -68,6 +69,21 @@ def one_hot(text,
filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
lower=True,
split=' '):
+ """One-hot encodes a text into a list of word indexes of size n.
+
+ This is a wrapper to the `hashing_trick` function using `hash` as the
+ hashing function; unicity of word to index mapping non-guaranteed.
+
+ Arguments:
+ text: Input text (string).
+ n: Dimension of the hashing space.
+ filters: Sequence of characters to filter out.
+ lower: Whether to convert the input to lowercase.
+ split: Sentence split marker (string).
+
+ Returns:
+ A list of integer word indices (unicity non-guaranteed).
+ """
return hashing_trick(
text, n, hash_function=hash, filters=filters, lower=lower, split=split)
@@ -99,6 +115,10 @@ def hashing_trick(text,
Two or more words may be assigned to the same index, due to possible
collisions by the hashing function.
+ The
+ probability
+ of a collision is in relation to the dimension of the hashing space and
+ the number of distinct objects.
"""
if hash_function is None:
hash_function = hash
@@ -127,6 +147,8 @@ class Tokenizer(object):
lower: boolean. Whether to convert the texts to lowercase.
split: character or string to use for token splitting.
char_level: if True, every character will be treated as a token.
+ oov_token: if given, it will be added to word_index and used to
+ replace out-of-vocabulary words during text_to_sequence calls
By default, all punctuation is removed, turning the texts into
space-separated sequences of words
@@ -141,7 +163,17 @@ class Tokenizer(object):
filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
lower=True,
split=' ',
- char_level=False):
+ char_level=False,
+ oov_token=None,
+ **kwargs):
+ # Legacy support
+ if 'nb_words' in kwargs:
+ logging.warning('The `nb_words` argument in `Tokenizer` '
+ 'has been renamed `num_words`.')
+ num_words = kwargs.pop('nb_words')
+ if kwargs:
+ raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+
self.word_counts = OrderedDict()
self.word_docs = {}
self.filters = filters
@@ -150,6 +182,7 @@ class Tokenizer(object):
self.num_words = num_words
self.document_count = 0
self.char_level = char_level
+ self.oov_token = oov_token
def fit_on_texts(self, texts):
"""Updates internal vocabulary based on a list of texts.
@@ -181,7 +214,13 @@ class Tokenizer(object):
sorted_voc = [wc[0] for wc in wcounts]
# note that index 0 is reserved, never assigned to an existing word
self.word_index = dict(
- list(zip(sorted_voc, list(range(1, len(sorted_voc) + 1)))))
+ list(zip(sorted_voc, list(range(1,
+ len(sorted_voc) + 1)))))
+
+ if self.oov_token is not None:
+ i = self.word_index.get(self.oov_token)
+ if i is None:
+ self.word_index[self.oov_token] = len(self.word_index) + 1
self.index_docs = {}
for w, c in list(self.word_docs.items()):
@@ -248,6 +287,10 @@ class Tokenizer(object):
continue
else:
vect.append(i)
+ elif self.oov_token is not None:
+ i = self.word_index.get(self.oov_token)
+ if i is not None:
+ vect.append(i)
yield vect
def texts_to_matrix(self, texts, mode='binary'):
diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py
index 17ab48ba3f..a934e331c4 100644
--- a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py
@@ -76,6 +76,22 @@ class TestText(test.TestCase):
self.assertLessEqual(np.max(encoded), 4)
self.assertGreaterEqual(np.min(encoded), 1)
+ def test_tokenizer_oov_flag(self):
+ x_train = ['This text has only known words']
+ x_test = ['This text has some unknown words'] # 2 OOVs: some, unknown
+
+ # Defalut, without OOV flag
+ tokenizer = keras.preprocessing.text.Tokenizer()
+ tokenizer.fit_on_texts(x_train)
+ x_test_seq = tokenizer.texts_to_sequences(x_test)
+ assert len(x_test_seq[0]) == 4 # discards 2 OOVs
+
+ # With OOV feature
+ tokenizer = keras.preprocessing.text.Tokenizer(oov_token='<unk>')
+ tokenizer.fit_on_texts(x_train)
+ x_test_seq = tokenizer.texts_to_sequences(x_test)
+ assert len(x_test_seq[0]) == 6 # OOVs marked in place
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/regularizers.py b/tensorflow/python/keras/_impl/keras/regularizers.py
index 161ff9bf5b..c53ee8a1ae 100644
--- a/tensorflow/python/keras/_impl/keras/regularizers.py
+++ b/tensorflow/python/keras/_impl/keras/regularizers.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Keras built-in regularizers.
+"""Built-in regularizers.
"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/keras/_impl/keras/utils/data_utils.py b/tensorflow/python/keras/_impl/keras/utils/data_utils.py
index d9e8f37e36..fcee9fbcc3 100644
--- a/tensorflow/python/keras/_impl/keras/utils/data_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/data_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=g-import-not-at-top
"""Utilities for file download and caching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from abc import abstractmethod
+from contextlib import closing
import hashlib
import multiprocessing
from multiprocessing.pool import ThreadPool
@@ -38,12 +40,12 @@ from six.moves.urllib.error import URLError
from six.moves.urllib.request import urlopen
from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
-from tensorflow.python.util.tf_export import tf_export
+
try:
- import queue # pylint:disable=g-import-not-at-top
+ import queue
except ImportError:
- import Queue as queue # pylint:disable=g-import-not-at-top
+ import Queue as queue
if sys.version_info[0] == 2:
@@ -87,7 +89,7 @@ if sys.version_info[0] == 2:
for chunk in chunk_read(response, reporthook=reporthook):
fd.write(chunk)
else:
- from six.moves.urllib.request import urlretrieve # pylint: disable=g-import-not-at-top
+ from six.moves.urllib.request import urlretrieve
def _extract_archive(file_path, path='.', archive_format='auto'):
@@ -136,7 +138,6 @@ def _extract_archive(file_path, path='.', archive_format='auto'):
return False
-@tf_export('keras.utils.get_file')
def get_file(fname,
origin,
untar=False,
@@ -188,7 +189,7 @@ def get_file(fname,
Path to the downloaded file
"""
if cache_dir is None:
- cache_dir = os.path.expanduser(os.path.join('~', '.keras'))
+ cache_dir = os.path.join(os.path.expanduser('~'), '.keras')
if md5_hash is not None and file_hash is None:
file_hash = md5_hash
hash_algorithm = 'md5'
@@ -317,37 +318,46 @@ def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
return False
-@tf_export('keras.utils.Sequence')
class Sequence(object):
"""Base object for fitting to a sequence of data, such as a dataset.
Every `Sequence` must implements the `__getitem__` and the `__len__` methods.
If you want to modify your dataset between epochs you may implement
- `on_epoch_end`. The method `__getitem__` should return a complete batch.
+ `on_epoch_end`.
+ The method `__getitem__` should return a complete batch.
+
+ # Notes
- Notes:
`Sequence` are a safer way to do multiprocessing. This structure guarantees
- that the network will only train once on each sample per epoch which is not
- the case with generators.
+ that the network will only train once
+ on each sample per epoch which is not the case with generators.
+
Examples:
+
```python
from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math
+
# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.
+
class CIFAR10Sequence(Sequence):
+
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
+
def __len__(self):
return math.ceil(len(self.x) / self.batch_size)
+
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) *
- self.batch_size]
+ self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) *
- self.batch_size]
+ self.batch_size]
+
return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)
@@ -375,7 +385,6 @@ class Sequence(object):
"""
raise NotImplementedError
- @abstractmethod
def on_epoch_end(self):
"""Method called at the end of every epoch.
"""
@@ -405,7 +414,6 @@ def get_index(uid, i):
return _SHARED_SEQUENCES[uid][i]
-@tf_export('keras.utils.SequenceEnqueuer')
class SequenceEnqueuer(object):
"""Base class to enqueue inputs.
@@ -474,35 +482,36 @@ class OrderedEnqueuer(SequenceEnqueuer):
Arguments:
sequence: A `keras.utils.data_utils.Sequence` object.
- use_multiprocessing: Use multiprocessing if True, otherwise threading
- shuffle: Whether to shuffle the data at the beginning of each epoch
+ use_multiprocessing: use multiprocessing if True, otherwise threading
+ shuffle: whether to shuffle the data at the beginning of each epoch
"""
def __init__(self, sequence, use_multiprocessing=False, shuffle=False):
self.sequence = sequence
self.use_multiprocessing = use_multiprocessing
- # Doing Multiprocessing.Value += x is not process-safe.
global _SEQUENCE_COUNTER
if _SEQUENCE_COUNTER is None:
- if self.use_multiprocessing:
+ try:
_SEQUENCE_COUNTER = multiprocessing.Value('i', 0)
- else:
+ except OSError:
+ # In this case the OS does not allow us to use
+ # multiprocessing. We resort to an int
+ # for enqueuer indexing.
_SEQUENCE_COUNTER = 0
- if self.use_multiprocessing:
+ if isinstance(_SEQUENCE_COUNTER, int):
+ self.uid = _SEQUENCE_COUNTER
+ _SEQUENCE_COUNTER += 1
+ else:
+ # Doing Multiprocessing.Value += x is not process-safe.
with _SEQUENCE_COUNTER.get_lock():
self.uid = _SEQUENCE_COUNTER.value
_SEQUENCE_COUNTER.value += 1
- else:
- self.uid = _SEQUENCE_COUNTER
- if isinstance(_SEQUENCE_COUNTER, int):
- _SEQUENCE_COUNTER += 1
- else:
- _SEQUENCE_COUNTER.value += 1
+
self.shuffle = shuffle
self.workers = 0
- self.executor = None
+ self.executor_fn = None
self.queue = None
self.run_thread = None
self.stop_signal = None
@@ -519,9 +528,9 @@ class OrderedEnqueuer(SequenceEnqueuer):
(when full, workers could block on `put()`)
"""
if self.use_multiprocessing:
- self.executor = multiprocessing.Pool(workers)
+ self.executor_fn = lambda: multiprocessing.Pool(workers)
else:
- self.executor = ThreadPool(workers)
+ self.executor_fn = lambda: ThreadPool(workers)
self.workers = workers
self.queue = queue.Queue(max_queue_size)
self.stop_signal = threading.Event()
@@ -537,24 +546,26 @@ class OrderedEnqueuer(SequenceEnqueuer):
return
def _run(self):
- """Function to submit request to the executor & queue `Future` objects."""
+ """Submits request to the executor and queue the `Future` objects."""
sequence = list(range(len(self.sequence)))
self._send_sequence() # Share the initial sequence
while True:
if self.shuffle:
random.shuffle(sequence)
- for i in sequence:
- if self.stop_signal.is_set():
- return
- self.queue.put(
- self.executor.apply_async(get_index, (self.uid, i)), block=True)
- # Done with the current epoch, waiting for the final batches
- self._wait_queue()
+ with closing(self.executor_fn()) as executor:
+ for i in sequence:
+ if self.stop_signal.is_set():
+ return
+ self.queue.put(
+ executor.apply_async(get_index, (self.uid, i)), block=True)
- if self.stop_signal.is_set():
- # We're done
- return
+ # Done with the current epoch, waiting for the final batches
+ self._wait_queue()
+
+ if self.stop_signal.is_set():
+ # We're done
+ return
# Call the internal on epoch end.
self.sequence.on_epoch_end()
@@ -566,8 +577,9 @@ class OrderedEnqueuer(SequenceEnqueuer):
Skip the data if it is `None`.
Yields:
- Tuples (inputs, targets)
- or (inputs, targets, sample_weights)
+ The next element in the queue, i.e. a tuple
+ `(inputs, targets)` or
+ `(inputs, targets, sample_weights)`.
"""
try:
while self.is_running():
@@ -581,14 +593,8 @@ class OrderedEnqueuer(SequenceEnqueuer):
def _send_sequence(self):
"""Send current Sequence to all workers."""
- _SHARED_SEQUENCES[
- self.uid] = self.sequence # For new processes that may spawn
-
- self._close_pool()
- if self.use_multiprocessing:
- self.executor = multiprocessing.Pool(self.workers)
- else:
- self.executor = ThreadPool(self.workers)
+ # For new processes that may spawn
+ _SHARED_SEQUENCES[self.uid] = self.sequence
def stop(self, timeout=None):
"""Stops running threads and wait for them to exit, if necessary.
@@ -603,16 +609,10 @@ class OrderedEnqueuer(SequenceEnqueuer):
self.queue.queue.clear()
self.queue.unfinished_tasks = 0
self.queue.not_full.notify()
- self._close_pool()
self.run_thread.join(timeout)
_SHARED_SEQUENCES[self.uid] = None
- def _close_pool(self):
- self.executor.close()
- self.executor.join()
-
-@tf_export('keras.utils.GeneratorEnqueuer')
class GeneratorEnqueuer(SequenceEnqueuer):
"""Builds a queue out of a data generator.
@@ -636,26 +636,53 @@ class GeneratorEnqueuer(SequenceEnqueuer):
seed=None):
self.wait_time = wait_time
self._generator = generator
- self._use_multiprocessing = use_multiprocessing
+ if os.name is 'nt' and use_multiprocessing is True:
+ # On Windows, avoid **SYSTEMATIC** error in `multiprocessing`:
+ # `TypeError: can't pickle generator objects`
+ # => Suggest multithreading instead of multiprocessing on Windows
+ raise ValueError('Using a generator with `use_multiprocessing=True`'
+ ' is not supported on Windows (no marshalling of'
+ ' generators across process boundaries). Instead,'
+ ' use single thread/process or multithreading.')
+ else:
+ self._use_multiprocessing = use_multiprocessing
self._threads = []
self._stop_event = None
self._manager = None
self.queue = None
self.seed = seed
- def start(self, workers=1, max_queue_size=10):
- """Kicks off threads which add data from the generator into the queue.
-
- Arguments:
- workers: number of worker threads
- max_queue_size: queue size
- (when full, threads could block on `put()`)
- """
-
- def data_generator_task():
+ def _data_generator_task(self):
+ if self._use_multiprocessing is False:
+ while not self._stop_event.is_set():
+ with self.genlock:
+ try:
+ if (self.queue is not None and
+ self.queue.qsize() < self.max_queue_size):
+ # On all OSes, avoid **SYSTEMATIC** error
+ # in multithreading mode:
+ # `ValueError: generator already executing`
+ # => Serialize calls to
+ # infinite iterator/generator's next() function
+ generator_output = next(self._generator)
+ self.queue.put((True, generator_output))
+ else:
+ time.sleep(self.wait_time)
+ except StopIteration:
+ break
+ except Exception as e: # pylint: disable=broad-except
+ # Can't pickle tracebacks.
+ # As a compromise, print the traceback and pickle None instead.
+ if not hasattr(e, '__traceback__'):
+ setattr(e, '__traceback__', sys.exc_info()[2])
+ self.queue.put((False, e))
+ self._stop_event.set()
+ break
+ else:
while not self._stop_event.is_set():
try:
- if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
+ if (self.queue is not None and
+ self.queue.qsize() < self.max_queue_size):
generator_output = next(self._generator)
self.queue.put((True, generator_output))
else:
@@ -663,24 +690,34 @@ class GeneratorEnqueuer(SequenceEnqueuer):
except StopIteration:
break
except Exception as e: # pylint: disable=broad-except
- # Can't pick tracebacks.
+ # Can't pickle tracebacks.
# As a compromise, print the traceback and pickle None instead.
- if self._use_multiprocessing:
- traceback.print_exc()
- setattr(e, '__traceback__', None)
- elif not hasattr(e, '__traceback__'):
- setattr(e, '__traceback__', sys.exc_info()[2])
+ traceback.print_exc()
+ setattr(e, '__traceback__', None)
self.queue.put((False, e))
self._stop_event.set()
break
+ def start(self, workers=1, max_queue_size=10):
+ """Kicks off threads which add data from the generator into the queue.
+
+ Arguments:
+ workers: number of worker threads
+ max_queue_size: queue size
+ (when full, threads could block on `put()`)
+ """
try:
+ self.max_queue_size = max_queue_size
if self._use_multiprocessing:
self._manager = multiprocessing.Manager()
self.queue = self._manager.Queue(maxsize=max_queue_size)
self._stop_event = multiprocessing.Event()
else:
- self.queue = queue.Queue()
+ # On all OSes, avoid **SYSTEMATIC** error in multithreading mode:
+ # `ValueError: generator already executing`
+ # => Serialize calls to infinite iterator/generator's next() function
+ self.genlock = threading.Lock()
+ self.queue = queue.Queue(maxsize=max_queue_size)
self._stop_event = threading.Event()
for _ in range(workers):
@@ -688,12 +725,12 @@ class GeneratorEnqueuer(SequenceEnqueuer):
# Reset random seed else all children processes
# share the same seed
np.random.seed(self.seed)
- thread = multiprocessing.Process(target=data_generator_task)
+ thread = multiprocessing.Process(target=self._data_generator_task)
thread.daemon = True
if self.seed is not None:
self.seed += 1
else:
- thread = threading.Thread(target=data_generator_task)
+ thread = threading.Thread(target=self._data_generator_task)
self._threads.append(thread)
thread.start()
except:
@@ -715,11 +752,15 @@ class GeneratorEnqueuer(SequenceEnqueuer):
self._stop_event.set()
for thread in self._threads:
- if thread.is_alive():
- if self._use_multiprocessing:
+ if self._use_multiprocessing:
+ if thread.is_alive():
thread.terminate()
- else:
- thread.join(timeout)
+ else:
+ # The thread.is_alive() test is subject to a race condition:
+ # the thread could terminate right after the test and before the
+ # join, rendering this test meaningless -> Call thread.join()
+ # always, which is ok no matter what the status of the thread.
+ thread.join(timeout)
if self._manager:
self._manager.shutdown()
@@ -734,7 +775,9 @@ class GeneratorEnqueuer(SequenceEnqueuer):
Skip the data if it is `None`.
Yields:
- Data arrays.
+ The next element in the queue, i.e. a tuple
+ `(inputs, targets)` or
+ `(inputs, targets, sample_weights)`.
"""
while self.is_running():
if not self.queue.empty():
@@ -752,7 +795,7 @@ class GeneratorEnqueuer(SequenceEnqueuer):
else:
time.sleep(self.wait_time)
- # Make sure to rethrow the first exception in the queue, if any
+ # Make sure to rethrow the first exception in the queue, if any
while not self.queue.empty():
success, value = self.queue.get()
if not success:
diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
index a805315c94..adbe6c3288 100644
--- a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import binascii
import codecs
import marshal
import os
@@ -255,7 +256,10 @@ def func_load(code, defaults=None, closure=None, globs=None):
if closure is not None:
closure = tuple(ensure_value_to_cell(_) for _ in closure)
- raw_code = codecs.decode(code.encode('ascii'), 'base64')
+ try:
+ raw_code = codecs.decode(code.encode('ascii'), 'base64')
+ except (UnicodeEncodeError, binascii.Error):
+ raw_code = code.encode('raw_unicode_escape')
code = marshal.loads(raw_code)
if globs is None:
globs = globals()
diff --git a/tensorflow/python/keras/_impl/keras/utils/io_utils.py b/tensorflow/python/keras/_impl/keras/utils/io_utils.py
index e123339f5a..b36c769843 100644
--- a/tensorflow/python/keras/_impl/keras/utils/io_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/io_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=g-import-not-at-top
"""Utilities related to disk I/O."""
from __future__ import absolute_import
from __future__ import division
@@ -21,16 +22,14 @@ from collections import defaultdict
import sys
import numpy as np
-from tensorflow.python.util.tf_export import tf_export
try:
- import h5py # pylint:disable=g-import-not-at-top
+ import h5py
except ImportError:
h5py = None
-@tf_export('keras.utils.HDF5Matrix')
class HDF5Matrix(object):
"""Representation of HDF5 dataset to be used instead of a Numpy array.
@@ -65,11 +64,11 @@ class HDF5Matrix(object):
'HDF5 and h5py installed.')
if datapath not in list(self.refs.keys()):
- self._f = h5py.File(datapath)
- self.refs[datapath] = self._f
+ f = h5py.File(datapath)
+ self.refs[datapath] = f
else:
- self._f = self.refs[datapath]
- self.data = self._f[dataset]
+ f = self.refs[datapath]
+ self.data = f[dataset]
self.start = start
if end is None:
self.end = self.data.shape[0]
@@ -80,9 +79,6 @@ class HDF5Matrix(object):
def __len__(self):
return self.end - self.start
- def __del__(self):
- self._f.close()
-
def __getitem__(self, key):
if isinstance(key, slice):
start, stop = key.start, key.stop
diff --git a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py
index 30af285cbf..a2d32424b5 100644
--- a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Utilities related to Keras layers.
+# pylint: disable=protected-access
+"""Utilities related to layer/model functionality.
"""
from __future__ import absolute_import
from __future__ import division
@@ -22,17 +23,16 @@ import numpy as np
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.utils.conv_utils import convert_kernel
-from tensorflow.python.util.tf_export import tf_export
def count_params(weights):
"""Count the total number of scalars composing the weights.
Arguments:
- weights: An iterable containing the weights on which to compute params
+ weights: An iterable containing the weights on which to compute params
Returns:
- The total number of scalars composing the weights
+ The total number of scalars composing the weights
"""
return int(np.sum([K.count_params(p) for p in set(weights)]))
@@ -47,10 +47,11 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
terminal window sizes).
positions: Relative or absolute positions of log elements in each line.
If not provided, defaults to `[.33, .55, .67, 1.]`.
- print_fn: Print function to use (defaults to `print`).
+ print_fn: Print function to use.
It will be called on each line of the summary.
You can set it to a custom function
in order to capture the string summary.
+ It defaults to `print` (prints to stdout).
"""
if print_fn is None:
print_fn = print
@@ -59,12 +60,13 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
sequential_like = True
else:
sequential_like = True
- nodes_by_depth = model._nodes_by_depth.values() # pylint: disable=protected-access
+ nodes_by_depth = model._nodes_by_depth.values()
nodes = []
for v in nodes_by_depth:
if (len(v) > 1) or (len(v) == 1 and len(v[0].inbound_layers) > 1):
- # If the model has multiple nodes or if the nodes have
- # multiple inbound_layers, the model is no longer sequential.
+ # if the model has multiple nodes
+ # or if the nodes have multiple inbound_layers
+ # the model is no longer sequential
sequential_like = False
break
nodes += v
@@ -72,7 +74,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
# search for shared layers
for layer in model.layers:
flag = False
- for node in layer.inbound_nodes:
+ for node in layer._inbound_nodes:
if node in nodes:
if flag:
sequential_like = False
@@ -97,7 +99,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
# header names for the different log elements
to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
relevant_nodes = []
- for v in model._nodes_by_depth.values(): # pylint: disable=protected-access
+ for v in model._nodes_by_depth.values():
relevant_nodes += v
def print_row(fields, positions):
@@ -135,7 +137,7 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
except AttributeError:
output_shape = 'multiple'
connections = []
- for node in layer._inbound_nodes: # pylint: disable=protected-access
+ for node in layer._inbound_nodes:
if relevant_nodes and node not in relevant_nodes:
# node is not part of the current network
continue
@@ -143,8 +145,8 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
inbound_layer = node.inbound_layers[i].name
inbound_node_index = node.node_indices[i]
inbound_tensor_index = node.tensor_indices[i]
- connections.append(inbound_layer + '[' + str(inbound_node_index) + ']['
- + str(inbound_tensor_index) + ']')
+ connections.append(inbound_layer + '[' + str(inbound_node_index) +
+ '][' + str(inbound_tensor_index) + ']')
name = layer.name
cls_name = layer.__class__.__name__
@@ -173,9 +175,9 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
else:
print_fn('_' * line_length)
- model._check_trainable_weights_consistency() # pylint: disable=protected-access
+ model._check_trainable_weights_consistency()
if hasattr(model, '_collected_trainable_weights'):
- trainable_count = count_params(model._collected_trainable_weights) # pylint: disable=protected-access
+ trainable_count = count_params(model._collected_trainable_weights)
else:
trainable_count = count_params(model.trainable_weights)
@@ -188,7 +190,6 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
print_fn('_' * line_length)
-@tf_export('keras.utils.convert_all_kernels_in_model')
def convert_all_kernels_in_model(model):
"""Converts all convolution kernels in a model from Theano to TensorFlow.
diff --git a/tensorflow/python/keras/_impl/keras/utils/np_utils.py b/tensorflow/python/keras/_impl/keras/utils/np_utils.py
index 3dddb99191..231833e776 100644
--- a/tensorflow/python/keras/_impl/keras/utils/np_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/np_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,10 +18,8 @@ from __future__ import division
from __future__ import print_function
import numpy as np
-from tensorflow.python.util.tf_export import tf_export
-@tf_export('keras.utils.to_categorical')
def to_categorical(y, num_classes=None):
"""Converts a class vector (integers) to binary class matrix.
@@ -50,7 +48,6 @@ def to_categorical(y, num_classes=None):
return categorical
-@tf_export('keras.utils.normalize')
def normalize(x, axis=-1, order=2):
"""Normalizes a Numpy array.
diff --git a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py
index 1ec8e3a2bf..0c5f2c19c7 100644
--- a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,31 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=protected-access
+# pylint: disable=g-import-not-at-top
"""Utilities related to model visualization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
-import sys
-from tensorflow.python.util.tf_export import tf_export
+
try:
# pydot-ng is a fork of pydot that is better maintained.
- import pydot_ng as pydot # pylint: disable=g-import-not-at-top
+ import pydot_ng as pydot
except ImportError:
- # Fall back on pydot if necessary.
- # Silence a `print` statement that occurs in case of import error,
- # by temporarily replacing sys.stdout.
- _stdout = sys.stdout
- sys.stdout = sys.stderr
+ # pydotplus is an improved version of pydot
try:
- import pydot # pylint: disable=g-import-not-at-top
+ import pydotplus as pydot
except ImportError:
- pydot = None
- finally:
- # Restore sys.stdout.
- sys.stdout = _stdout
+ # Fall back on pydot if necessary.
+ try:
+ import pydot
+ except ImportError:
+ pydot = None
def _check_pydot():
@@ -66,8 +64,8 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'):
Returns:
A `pydot.Dot` instance representing the Keras model.
"""
- from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper # pylint: disable=g-import-not-at-top
- from tensorflow.python.keras._impl.keras.models import Sequential # pylint: disable=g-import-not-at-top
+ from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper
+ from tensorflow.python.keras._impl.keras.models import Sequential
_check_pydot()
dot = pydot.Dot()
@@ -119,9 +117,9 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'):
# Connect nodes with edges.
for layer in layers:
layer_id = str(id(layer))
- for i, node in enumerate(layer._inbound_nodes): # pylint: disable=protected-access
+ for i, node in enumerate(layer._inbound_nodes):
node_key = layer.name + '_ib-' + str(i)
- if node_key in model._network_nodes: # pylint: disable=protected-access
+ if node_key in model._container_nodes:
for inbound_layer in node.inbound_layers:
inbound_layer_id = str(id(inbound_layer))
layer_id = str(id(layer))
@@ -129,7 +127,6 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'):
return dot
-@tf_export('keras.utils.plot_model')
def plot_model(model,
to_file='model.png',
show_shapes=False,
diff --git a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py
index bc788d874f..223ceac3de 100644
--- a/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py
+++ b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""API wrapper allowing to use certain Keras models with the Scikit-Learn API.
+"""Wrapper for using the Scikit-Learn API with Keras models.
"""
from __future__ import absolute_import
from __future__ import division
@@ -24,8 +24,8 @@ import types
import numpy as np
from tensorflow.python.keras._impl.keras.models import Sequential
+from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical
-from tensorflow.python.util import tf_inspect
class BaseWrapper(object):
@@ -75,7 +75,7 @@ class BaseWrapper(object):
self.check_params(sk_params)
def check_params(self, params):
- """Checks for user typos in "params".
+ """Checks for user typos in `params`.
Arguments:
params: dictionary; the parameters to be checked
@@ -95,13 +95,11 @@ class BaseWrapper(object):
else:
legal_params_fns.append(self.build_fn)
- legal_params = []
- for fn in legal_params_fns:
- legal_params += tf_inspect.getargspec(fn)[0]
- legal_params = set(legal_params)
-
for params_name in params:
- if params_name not in legal_params:
+ for fn in legal_params_fns:
+ if has_arg(fn, params_name):
+ break
+ else:
if params_name != 'nb_epoch':
raise ValueError('{} is not a legal parameter'.format(params_name))
@@ -136,10 +134,10 @@ class BaseWrapper(object):
Arguments:
x : array-like, shape `(n_samples, n_features)`
- Training samples where n_samples in the number of samples
- and n_features is the number of features.
+ Training samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
- True labels for X.
+ True labels for `x`.
**kwargs: dictionary arguments
Legal arguments are the arguments of `Sequential.fit`
@@ -170,21 +168,20 @@ class BaseWrapper(object):
return history
def filter_sk_params(self, fn, override=None):
- """Filters `sk_params` and return those in `fn`'s arguments.
+ """Filters `sk_params` and returns those in `fn`'s arguments.
Arguments:
fn : arbitrary function
- override: dictionary, values to override sk_params
+ override: dictionary, values to override `sk_params`
Returns:
- res : dictionary dictionary containing variables
- in both sk_params and fn's arguments.
+ res : dictionary containing variables
+ in both `sk_params` and `fn`'s arguments.
"""
override = override or {}
res = {}
- fn_args = tf_inspect.getargspec(fn)[0]
for name, value in self.sk_params.items():
- if name in fn_args:
+ if has_arg(fn, name):
res.update({name: value})
res.update(override)
return res
@@ -199,10 +196,10 @@ class KerasClassifier(BaseWrapper):
Arguments:
x : array-like, shape `(n_samples, n_features)`
- Training samples where n_samples in the number of samples
- and n_features is the number of features.
+ Training samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
- True labels for X.
+ True labels for `x`.
**kwargs: dictionary arguments
Legal arguments are the arguments of `Sequential.fit`
@@ -229,8 +226,8 @@ class KerasClassifier(BaseWrapper):
Arguments:
x: array-like, shape `(n_samples, n_features)`
- Test samples where n_samples in the number of samples
- and n_features is the number of features.
+ Test samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
**kwargs: dictionary arguments
Legal arguments are the arguments
of `Sequential.predict_classes`.
@@ -248,8 +245,8 @@ class KerasClassifier(BaseWrapper):
Arguments:
x: array-like, shape `(n_samples, n_features)`
- Test samples where n_samples in the number of samples
- and n_features is the number of features.
+ Test samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
**kwargs: dictionary arguments
Legal arguments are the arguments
of `Sequential.predict_classes`.
@@ -258,8 +255,8 @@ class KerasClassifier(BaseWrapper):
proba: array-like, shape `(n_samples, n_outputs)`
Class probability estimates.
In the case of binary classification,
- tp match the scikit-learn API,
- will return an array of shape '(n_samples, 2)'
+ to match the scikit-learn API,
+ will return an array of shape `(n_samples, 2)`
(instead of `(n_sample, 1)` as in Keras).
"""
kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs)
@@ -276,16 +273,16 @@ class KerasClassifier(BaseWrapper):
Arguments:
x: array-like, shape `(n_samples, n_features)`
- Test samples where n_samples in the number of samples
- and n_features is the number of features.
+ Test samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
y: array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
- True labels for x.
+ True labels for `x`.
**kwargs: dictionary arguments
Legal arguments are the arguments of `Sequential.evaluate`.
Returns:
score: float
- Mean accuracy of predictions on X wrt. y.
+ Mean accuracy of predictions on `x` wrt. `y`.
Raises:
ValueError: If the underlying model isn't configured to
@@ -321,8 +318,8 @@ class KerasRegressor(BaseWrapper):
Arguments:
x: array-like, shape `(n_samples, n_features)`
- Test samples where n_samples in the number of samples
- and n_features is the number of features.
+ Test samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
**kwargs: dictionary arguments
Legal arguments are the arguments of `Sequential.predict`.
@@ -338,16 +335,16 @@ class KerasRegressor(BaseWrapper):
Arguments:
x: array-like, shape `(n_samples, n_features)`
- Test samples where n_samples in the number of samples
- and n_features is the number of features.
+ Test samples where `n_samples` is the number of samples
+ and `n_features` is the number of features.
y: array-like, shape `(n_samples,)`
- True labels for X.
+ True labels for `x`.
**kwargs: dictionary arguments
Legal arguments are the arguments of `Sequential.evaluate`.
Returns:
score: float
- Mean accuracy of predictions on X wrt. y.
+ Mean accuracy of predictions on `x` wrt. `y`.
"""
kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)
loss = self.model.evaluate(x, y, **kwargs)
diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py
index 34f1435ffb..fccedf919a 100644
--- a/tensorflow/python/keras/applications/__init__.py
+++ b/tensorflow/python/keras/applications/__init__.py
@@ -18,16 +18,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.keras.applications import densenet
from tensorflow.python.keras.applications import inception_resnet_v2
from tensorflow.python.keras.applications import inception_v3
from tensorflow.python.keras.applications import mobilenet
+from tensorflow.python.keras.applications import nasnet
from tensorflow.python.keras.applications import resnet50
from tensorflow.python.keras.applications import vgg16
from tensorflow.python.keras.applications import vgg19
from tensorflow.python.keras.applications import xception
+from tensorflow.python.keras.applications.densenet import DenseNet121
+from tensorflow.python.keras.applications.densenet import DenseNet169
+from tensorflow.python.keras.applications.densenet import DenseNet201
from tensorflow.python.keras.applications.inception_resnet_v2 import InceptionResNetV2
from tensorflow.python.keras.applications.inception_v3 import InceptionV3
from tensorflow.python.keras.applications.mobilenet import MobileNet
+from tensorflow.python.keras.applications.nasnet import NASNetLarge
+from tensorflow.python.keras.applications.nasnet import NASNetMobile
from tensorflow.python.keras.applications.resnet50 import ResNet50
from tensorflow.python.keras.applications.vgg16 import VGG16
from tensorflow.python.keras.applications.vgg19 import VGG19
diff --git a/tensorflow/python/keras/applications/densenet/__init__.py b/tensorflow/python/keras/applications/densenet/__init__.py
new file mode 100644
index 0000000000..6b8ea83920
--- /dev/null
+++ b/tensorflow/python/keras/applications/densenet/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""DenseNet Keras applications."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.applications.densenet import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet121
+from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet169
+from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet201
+from tensorflow.python.keras._impl.keras.applications.densenet import preprocess_input
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/applications/nasnet/__init__.py b/tensorflow/python/keras/applications/nasnet/__init__.py
new file mode 100644
index 0000000000..94eb145b85
--- /dev/null
+++ b/tensorflow/python/keras/applications/nasnet/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""NASNet Keras applications."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.applications.nasnet import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetLarge
+from tensorflow.python.keras._impl.keras.applications.nasnet import NASNetMobile
+from tensorflow.python.keras._impl.keras.applications.nasnet import preprocess_input
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
index b94bf8f0f6..84ee5040dc 100644
--- a/tensorflow/python/keras/layers/__init__.py
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -30,6 +30,7 @@ from tensorflow.python.keras._impl.keras.layers.advanced_activations import Leak
from tensorflow.python.keras._impl.keras.layers.advanced_activations import PReLU
from tensorflow.python.keras._impl.keras.layers.advanced_activations import ELU
from tensorflow.python.keras._impl.keras.layers.advanced_activations import ThresholdedReLU
+from tensorflow.python.keras._impl.keras.layers.advanced_activations import Softmax
# Convolution layers.
from tensorflow.python.keras._impl.keras.layers.convolutional import Conv1D
@@ -37,6 +38,7 @@ from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2D
from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3D
from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2DTranspose
from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3DTranspose
+from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv1D
from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv2D
# Convolution layer aliases.
@@ -45,6 +47,7 @@ from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution
from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3D
from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2DTranspose
from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose
+from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution1D
from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D
# Image processing layers.
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index ec6184aacd..a96b88d96f 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -82,7 +82,9 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
matrix_ph = array_ops.placeholder(dtypes.int32)
transposed = array_ops.matrix_transpose(matrix_ph)
self.assertAllEqual(
- expected_transposed, transposed.eval(feed_dict={matrix_ph: matrix}))
+ expected_transposed, transposed.eval(feed_dict={
+ matrix_ph: matrix
+ }))
def testBatchMatrixDynamicallyDefined(self):
matrix_0 = [[1, 2, 3], [4, 5, 6]]
@@ -96,7 +98,9 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
transposed = array_ops.matrix_transpose(batch_matrix_ph)
self.assertAllEqual(
expected_transposed,
- transposed.eval(feed_dict={batch_matrix_ph: batch_matrix}))
+ transposed.eval(feed_dict={
+ batch_matrix_ph: batch_matrix
+ }))
def testTensorWithStaticRankLessThanTwoRaisesBecauseNotAMatrix(self):
vector = [1, 2, 3]
@@ -203,8 +207,10 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
masked_tensor = sess.run(
array_ops.boolean_mask(ph_tensor, ph_mask),
- feed_dict={ph_tensor: arr,
- ph_mask: mask})
+ feed_dict={
+ ph_tensor: arr,
+ ph_mask: mask
+ })
np.testing.assert_allclose(masked_tensor, arr[mask])
def testMaskDimensionsSetToNoneRaises(self):
@@ -280,7 +286,8 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
for axis_dtype in [dtypes.int32, dtypes.int64]:
with self.test_session(use_gpu=use_gpu):
x_tf = array_ops.reverse_v2(x_np,
- constant_op.constant([0], dtype=axis_dtype)).eval()
+ constant_op.constant(
+ [0], dtype=axis_dtype)).eval()
self.assertAllEqual(x_tf, np.asarray(x_np)[::-1])
def _reverse2DimAuto(self, np_dtype):
@@ -290,16 +297,17 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
for use_gpu in [False, True]:
for axis_dtype in [dtypes.int32, dtypes.int64]:
with self.test_session(use_gpu=use_gpu):
- x_tf_1 = reverse_f(x_np,
- constant_op.constant([0], dtype=axis_dtype)).eval()
- x_tf_2 = reverse_f(x_np,
- constant_op.constant([-2], dtype=axis_dtype)).eval()
- x_tf_3 = reverse_f(x_np,
- constant_op.constant([1], dtype=axis_dtype)).eval()
- x_tf_4 = reverse_f(x_np,
- constant_op.constant([-1], dtype=axis_dtype)).eval()
+ x_tf_1 = reverse_f(x_np, constant_op.constant(
+ [0], dtype=axis_dtype)).eval()
+ x_tf_2 = reverse_f(x_np, constant_op.constant(
+ [-2], dtype=axis_dtype)).eval()
+ x_tf_3 = reverse_f(x_np, constant_op.constant(
+ [1], dtype=axis_dtype)).eval()
+ x_tf_4 = reverse_f(x_np, constant_op.constant(
+ [-1], dtype=axis_dtype)).eval()
x_tf_5 = reverse_f(x_np,
- constant_op.constant([1, 0], dtype=axis_dtype)).eval()
+ constant_op.constant([1, 0],
+ dtype=axis_dtype)).eval()
self.assertAllEqual(x_tf_1, np.asarray(x_np)[::-1, :])
self.assertAllEqual(x_tf_2, np.asarray(x_np)[::-1, :])
self.assertAllEqual(x_tf_3, np.asarray(x_np)[:, ::-1])
@@ -324,18 +332,16 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
def testReverse1DimAuto(self):
for dtype in [
- np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64,
- np.bool, np.float16, np.float32,
- np.float64, np.complex64, np.complex128,
+ np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64, np.bool,
+ np.float16, np.float32, np.float64, np.complex64, np.complex128,
np.array(b"").dtype.type
]:
self._reverse1DimAuto(dtype)
def testReverse2DimAuto(self):
for dtype in [
- np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64,
- np.bool, np.float16, np.float32,
- np.float64, np.complex64, np.complex128,
+ np.uint8, np.int8, np.uint16, np.int16, np.int32, np.int64, np.bool,
+ np.float16, np.float32, np.float64, np.complex64, np.complex128,
np.array(b"").dtype.type
]:
self._reverse2DimAuto(dtype)
@@ -711,8 +717,8 @@ class GradSliceChecker(object):
slice_val_grad2, = gradients_impl.gradients(
slice_val_grad, dy, grad_ys=self.var)
self.sess.run(assign)
- slice_val_grad_evaled, slice_val_grad2_evaled = (self.sess.run(
- [slice_val_grad, slice_val_grad2]))
+ slice_val_grad_evaled, slice_val_grad2_evaled = (
+ self.sess.run([slice_val_grad, slice_val_grad2]))
analytic_grad2_evaled = analytic_grad2.eval()
self.test.assertAllEqual(slice_val_grad2_evaled, analytic_grad2_evaled)
@@ -987,9 +993,10 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
with self.test_session():
res = array_ops.sequence_mask(constant_op.constant([1, 3, 2]), 5)
self.assertAllEqual(res.get_shape(), [3, 5])
- self.assertAllEqual(res.eval(), [[True, False, False, False, False],
- [True, True, True, False, False],
- [True, True, False, False, False]])
+ self.assertAllEqual(
+ res.eval(),
+ [[True, False, False, False, False], [True, True, True, False, False],
+ [True, True, False, False, False]])
# test dtype and default maxlen:
res = array_ops.sequence_mask(
@@ -998,17 +1005,17 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
self.assertAllEqual(res.get_shape().as_list(), [3, 4])
else:
self.assertAllEqual(res.get_shape().as_list(), [3, None])
- self.assertAllEqual(res.eval(), [[0.0, 0.0, 0.0,
- 0.0], [1.0, 0.0, 0.0, 0.0],
- [1.0, 1.0, 1.0, 1.0]])
+ self.assertAllEqual(
+ res.eval(),
+ [[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]])
def testTwoDimensional(self):
with self.test_session():
res = array_ops.sequence_mask(constant_op.constant([[1, 3, 2]]), 5)
self.assertAllEqual(res.get_shape(), [1, 3, 5])
- self.assertAllEqual(res.eval(), [[[True, False, False, False, False],
- [True, True, True, False, False],
- [True, True, False, False, False]]])
+ self.assertAllEqual(res.eval(), [[[True, False, False, False, False], [
+ True, True, True, False, False
+ ], [True, True, False, False, False]]])
# test dtype and default maxlen:
res = array_ops.sequence_mask(
@@ -1017,12 +1024,10 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
self.assertAllEqual(res.get_shape().as_list(), [2, 3, 4])
else:
self.assertAllEqual(res.get_shape().as_list(), [2, 3, None])
- self.assertAllEqual(res.eval(), [[[0.0, 0.0, 0.0, 0.0],
- [1.0, 0.0, 0.0, 0.0],
- [1.0, 1.0, 1.0, 1.0]],
- [[1.0, 0.0, 0.0, 0.0],
- [1.0, 1.0, 0.0, 0.0],
- [1.0, 1.0, 1.0, 0.0]]])
+ self.assertAllEqual(
+ res.eval(),
+ [[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]],
+ [[1.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 0.0]]])
def testDtypes(self):
@@ -1031,9 +1036,10 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
constant_op.constant([1, 3, 2], dtype=lengths_dtype),
constant_op.constant(5, dtype=maxlen_dtype))
self.assertAllEqual(res.get_shape(), [3, 5])
- self.assertAllEqual(res.eval(), [[True, False, False, False, False],
- [True, True, True, False, False],
- [True, True, False, False, False]])
+ self.assertAllEqual(
+ res.eval(),
+ [[True, False, False, False, False], [True, True, True, False, False],
+ [True, True, False, False, False]])
with self.test_session():
check_dtypes(dtypes.int32, dtypes.int32)
@@ -1088,13 +1094,14 @@ class PadTest(test_util.TensorFlowTestCase):
def testEager(self):
with context.eager_mode():
t = constant_op.constant([[1, 2, 3], [4, 5, 6]])
- paddings = constant_op.constant([[1, 1,], [2, 2]])
+ paddings = constant_op.constant([[
+ 1,
+ 1,
+ ], [2, 2]])
padded = array_ops.pad(t, paddings, "CONSTANT")
self.assertAllEqual(padded.numpy(),
- [[0, 0, 0, 0, 0, 0, 0],
- [0, 0, 1, 2, 3, 0, 0],
- [0, 0, 4, 5, 6, 0, 0],
- [0, 0, 0, 0, 0, 0, 0]])
+ [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 0, 0],
+ [0, 0, 4, 5, 6, 0, 0], [0, 0, 0, 0, 0, 0, 0]])
class InvertPermutationTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py
index 6cfa9b37fe..0825d8fc6b 100644
--- a/tensorflow/python/kernel_tests/diag_op_test.py
+++ b/tensorflow/python/kernel_tests/diag_op_test.py
@@ -84,11 +84,8 @@ class MatrixSetDiagTest(test.TestCase):
def testSquare(self):
with self.test_session(use_gpu=True):
v = np.array([1.0, 2.0, 3.0])
- mat = np.array([[0.0, 1.0, 0.0],
- [1.0, 0.0, 1.0],
- [1.0, 1.0, 1.0]])
- mat_set_diag = np.array([[1.0, 1.0, 0.0],
- [1.0, 2.0, 1.0],
+ mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
+ mat_set_diag = np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0],
[1.0, 1.0, 3.0]])
output = array_ops.matrix_set_diag(mat, v)
self.assertEqual((3, 3), output.get_shape())
@@ -135,19 +132,12 @@ class MatrixSetDiagTest(test.TestCase):
def testRectangularBatch(self):
with self.test_session(use_gpu=True):
- v_batch = np.array([[-1.0, -2.0],
- [-4.0, -5.0]])
- mat_batch = np.array(
- [[[1.0, 0.0, 3.0],
- [0.0, 2.0, 0.0]],
- [[4.0, 0.0, 4.0],
- [0.0, 5.0, 0.0]]])
-
- mat_set_diag_batch = np.array(
- [[[-1.0, 0.0, 3.0],
- [0.0, -2.0, 0.0]],
- [[-4.0, 0.0, 4.0],
- [0.0, -5.0, 0.0]]])
+ v_batch = np.array([[-1.0, -2.0], [-4.0, -5.0]])
+ mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
+ [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]])
+
+ mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
+ [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]])
output = array_ops.matrix_set_diag(mat_batch, v_batch)
self.assertEqual((2, 2, 3), output.get_shape())
self.assertAllEqual(mat_set_diag_batch, output.eval())
@@ -178,10 +168,14 @@ class MatrixSetDiagTest(test.TestCase):
np.random.rand(*diag_shape), dtype=dtypes_lib.float32)
y = array_ops.matrix_set_diag(x, x_diag)
error_x = gradient_checker.compute_gradient_error(
- x, x.get_shape().as_list(), y, y.get_shape().as_list())
+ x,
+ x.get_shape().as_list(), y,
+ y.get_shape().as_list())
self.assertLess(error_x, 1e-4)
error_x_diag = gradient_checker.compute_gradient_error(
- x_diag, x_diag.get_shape().as_list(), y, y.get_shape().as_list())
+ x_diag,
+ x_diag.get_shape().as_list(), y,
+ y.get_shape().as_list())
self.assertLess(error_x_diag, 1e-4)
def testGradWithNoShapeInformation(self):
@@ -192,12 +186,13 @@ class MatrixSetDiagTest(test.TestCase):
output = array_ops.matrix_set_diag(mat, v)
grads = gradients_impl.gradients(output, [mat, v], grad_ys=grad_input)
grad_input_val = np.random.rand(3, 3).astype(np.float32)
- grad_vals = sess.run(grads,
- feed_dict={
- v: 2 * np.ones(3),
- mat: np.ones((3, 3)),
- grad_input: grad_input_val
- })
+ grad_vals = sess.run(
+ grads,
+ feed_dict={
+ v: 2 * np.ones(3),
+ mat: np.ones((3, 3)),
+ grad_input: grad_input_val
+ })
self.assertAllEqual(np.diag(grad_input_val), grad_vals[1])
self.assertAllEqual(grad_input_val - np.diag(np.diag(grad_input_val)),
grad_vals[0])
@@ -242,13 +237,9 @@ class MatrixDiagPartTest(test.TestCase):
def testRectangularBatch(self):
with self.test_session(use_gpu=True):
- v_batch = np.array([[1.0, 2.0],
- [4.0, 5.0]])
- mat_batch = np.array(
- [[[1.0, 0.0, 0.0],
- [0.0, 2.0, 0.0]],
- [[4.0, 0.0, 0.0],
- [0.0, 5.0, 0.0]]])
+ v_batch = np.array([[1.0, 2.0], [4.0, 5.0]])
+ mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
+ [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0]]])
self.assertEqual(mat_batch.shape, (2, 2, 3))
mat_batch_diag = array_ops.matrix_diag_part(mat_batch)
self.assertEqual((2, 2), mat_batch_diag.get_shape())
@@ -301,19 +292,13 @@ class DiagTest(test.TestCase):
def testRankOneIntTensor(self):
x = np.array([1, 2, 3])
- expected_ans = np.array(
- [[1, 0, 0],
- [0, 2, 0],
- [0, 0, 3]])
+ expected_ans = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
self.diagOp(x, np.int32, expected_ans)
self.diagOp(x, np.int64, expected_ans)
def testRankOneFloatTensor(self):
x = np.array([1.1, 2.2, 3.3])
- expected_ans = np.array(
- [[1.1, 0, 0],
- [0, 2.2, 0],
- [0, 0, 3.3]])
+ expected_ans = np.array([[1.1, 0, 0], [0, 2.2, 0], [0, 0, 3.3]])
self.diagOp(x, np.float32, expected_ans)
self.diagOp(x, np.float64, expected_ans)
@@ -321,123 +306,105 @@ class DiagTest(test.TestCase):
for dtype in [np.complex64, np.complex128]:
x = np.array([1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], dtype=dtype)
expected_ans = np.array(
- [[1.1 + 1.1j, 0 + 0j, 0 + 0j],
- [0 + 0j, 2.2 + 2.2j, 0 + 0j],
- [0 + 0j, 0 + 0j, 3.3 + 3.3j]], dtype=dtype)
+ [[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 2.2 + 2.2j, 0 + 0j],
+ [0 + 0j, 0 + 0j, 3.3 + 3.3j]],
+ dtype=dtype)
self.diagOp(x, dtype, expected_ans)
def testRankTwoIntTensor(self):
x = np.array([[1, 2, 3], [4, 5, 6]])
- expected_ans = np.array(
- [[[[1, 0, 0], [0, 0, 0]],
- [[0, 2, 0], [0, 0, 0]],
- [[0, 0, 3], [0, 0, 0]]],
- [[[0, 0, 0], [4, 0, 0]],
- [[0, 0, 0], [0, 5, 0]],
- [[0, 0, 0], [0, 0, 6]]]])
+ expected_ans = np.array([[[[1, 0, 0], [0, 0, 0]], [[0, 2, 0], [0, 0, 0]],
+ [[0, 0, 3], [0, 0, 0]]],
+ [[[0, 0, 0], [4, 0, 0]], [[0, 0, 0], [0, 5, 0]],
+ [[0, 0, 0], [0, 0, 6]]]])
self.diagOp(x, np.int32, expected_ans)
self.diagOp(x, np.int64, expected_ans)
def testRankTwoFloatTensor(self):
x = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]])
expected_ans = np.array(
- [[[[1.1, 0, 0], [0, 0, 0]],
- [[0, 2.2, 0], [0, 0, 0]],
- [[0, 0, 3.3], [0, 0, 0]]],
- [[[0, 0, 0], [4.4, 0, 0]],
- [[0, 0, 0], [0, 5.5, 0]],
- [[0, 0, 0], [0, 0, 6.6]]]])
+ [[[[1.1, 0, 0], [0, 0, 0]], [[0, 2.2, 0], [0, 0, 0]],
+ [[0, 0, 3.3], [0, 0, 0]]], [[[0, 0, 0], [4.4, 0, 0]],
+ [[0, 0, 0], [0, 5.5, 0]], [[0, 0, 0],
+ [0, 0, 6.6]]]])
self.diagOp(x, np.float32, expected_ans)
self.diagOp(x, np.float64, expected_ans)
def testRankTwoComplexTensor(self):
for dtype in [np.complex64, np.complex128]:
- x = np.array([[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j],
- [4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]], dtype=dtype)
+ x = np.array(
+ [[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j],
+ [4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]],
+ dtype=dtype)
expected_ans = np.array(
- [[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]],
- [[0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]],
- [[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]],
- [[[0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j]],
- [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j]],
- [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]],
- dtype=dtype)
+ [[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]], [
+ [0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]
+ ], [[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]], [[
+ [0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j]
+ ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j]
+ ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]],
+ dtype=dtype)
self.diagOp(x, dtype, expected_ans)
def testRankThreeFloatTensor(self):
- x = np.array([[[1.1, 2.2], [3.3, 4.4]],
- [[5.5, 6.6], [7.7, 8.8]]])
- expected_ans = np.array(
- [[[[[[1.1, 0], [0, 0]], [[0, 0], [0, 0]]],
- [[[0, 2.2], [0, 0]], [[0, 0], [0, 0]]]],
- [[[[0, 0], [3.3, 0]], [[0, 0], [0, 0]]],
- [[[0, 0], [0, 4.4]], [[0, 0], [0, 0]]]]],
- [[[[[0, 0], [0, 0]], [[5.5, 0], [0, 0]]],
- [[[0, 0], [0, 0]], [[0, 6.6], [0, 0]]]],
- [[[[0, 0], [0, 0]], [[0, 0], [7.7, 0]]],
- [[[0, 0], [0, 0]], [[0, 0], [0, 8.8]]]]]])
+ x = np.array([[[1.1, 2.2], [3.3, 4.4]], [[5.5, 6.6], [7.7, 8.8]]])
+ expected_ans = np.array([[[[[[1.1, 0], [0, 0]], [[0, 0], [0, 0]]],
+ [[[0, 2.2], [0, 0]], [[0, 0], [0, 0]]]],
+ [[[[0, 0], [3.3, 0]], [[0, 0], [0, 0]]],
+ [[[0, 0], [0, 4.4]], [[0, 0], [0, 0]]]]],
+ [[[[[0, 0], [0, 0]], [[5.5, 0], [0, 0]]],
+ [[[0, 0], [0, 0]], [[0, 6.6], [0, 0]]]],
+ [[[[0, 0], [0, 0]], [[0, 0], [7.7, 0]]],
+ [[[0, 0], [0, 0]], [[0, 0], [0, 8.8]]]]]])
self.diagOp(x, np.float32, expected_ans)
self.diagOp(x, np.float64, expected_ans)
def testRankThreeComplexTensor(self):
for dtype in [np.complex64, np.complex128]:
- x = np.array([[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]],
- [[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]],
- dtype=dtype)
+ x = np.array(
+ [[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]],
+ [[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]],
+ dtype=dtype)
expected_ans = np.array(
- [[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]],
- [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]],
- [[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]],
- [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]],
- [[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]],
- [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]],
- [[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]],
- [[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]]]]],
- [[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
- [[5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j]]],
- [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
- [[0 + 0j, 6.6 + 6.6j], [0 + 0j, 0 + 0j]]]],
- [[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
- [[0 + 0j, 0 + 0j], [7.7 + 7.7j, 0 + 0j]]],
- [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
- [[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]],
+ [[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
+ 0 + 0j, 0 + 0j
+ ]]], [[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
+ 0 + 0j, 0 + 0j
+ ]]]], [[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
+ 0 + 0j, 0 + 0j
+ ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]], [[0 + 0j, 0 + 0j], [
+ 0 + 0j, 0 + 0j
+ ]]]]], [[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [
+ [5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j]
+ ]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 6.6 + 6.6j], [
+ 0 + 0j, 0 + 0j
+ ]]]], [[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
+ 7.7 + 7.7j, 0 + 0j
+ ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
+ [[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]],
dtype=dtype)
self.diagOp(x, dtype, expected_ans)
def testRankFourNumberTensor(self):
for dtype in [np.float32, np.float64, np.int64, np.int32]:
# Input with shape [2, 1, 2, 3]
- x = np.array([[[[ 1, 2, 3],
- [ 4, 5, 6]]],
- [[[ 7, 8, 9],
- [10, 11, 12]]]], dtype=dtype)
+ x = np.array(
+ [[[[1, 2, 3], [4, 5, 6]]], [[[7, 8, 9], [10, 11, 12]]]], dtype=dtype)
# Output with shape [2, 1, 2, 3, 2, 1, 2, 3]
expected_ans = np.array(
- [[[[[[[[1, 0, 0], [0, 0, 0]]],
- [[[0, 0, 0], [0, 0, 0]]]],
- [[[[0, 2, 0], [0, 0, 0]]],
- [[[0, 0, 0], [0, 0, 0]]]],
- [[[[0, 0, 3], [0, 0, 0]]],
- [[[0, 0, 0], [0, 0, 0]]]]],
- [[[[[0, 0, 0], [4, 0, 0]]],
- [[[0, 0, 0], [0, 0, 0]]]],
- [[[[0, 0, 0], [0, 5, 0]]],
- [[[0, 0, 0], [0, 0, 0]]]],
- [[[[0, 0, 0], [0, 0, 6]]],
- [[[0, 0, 0], [0, 0, 0]]]]]]],
-
- [[[[[[[0, 0, 0], [0, 0, 0]]],
- [[[7, 0, 0], [0, 0, 0]]]],
- [[[[0, 0, 0], [0, 0, 0]]],
- [[[0, 8, 0], [0, 0, 0]]]],
- [[[[0, 0, 0], [0, 0, 0]]],
- [[[0, 0, 9], [0, 0, 0]]]]],
- [[[[[0, 0, 0], [0, 0, 0]]],
- [[[0, 0, 0], [10, 0, 0]]]],
- [[[[0, 0, 0], [0, 0, 0]]],
- [[[0, 0, 0], [0, 11, 0]]]],
- [[[[0, 0, 0], [0, 0, 0]]],
- [[[0, 0, 0], [0, 0, 12]]]]]]]], dtype=dtype)
+ [[[[[[[[1, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [
+ [[[0, 2, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]
+ ], [[[[0, 0, 3], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]]], [[
+ [[[0, 0, 0], [4, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]
+ ], [[[[0, 0, 0], [0, 5, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [
+ [[[0, 0, 0], [0, 0, 6]]], [[[0, 0, 0], [0, 0, 0]]]
+ ]]]], [[[[[[[0, 0, 0], [0, 0, 0]]], [[[7, 0, 0], [0, 0, 0]]]], [
+ [[[0, 0, 0], [0, 0, 0]]], [[[0, 8, 0], [0, 0, 0]]]
+ ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 9], [0, 0, 0]]]]], [[
+ [[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [10, 0, 0]]]
+ ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 11, 0]]]
+ ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 12]]]]]]]],
+ dtype=dtype)
self.diagOp(x, dtype, expected_ans)
def testInvalidRank(self):
@@ -537,7 +504,9 @@ class DiagGradOpTest(test.TestCase):
x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype)
y = array_ops.diag(x1)
error = gradient_checker.compute_gradient_error(
- x1, x1.get_shape().as_list(), y, y.get_shape().as_list())
+ x1,
+ x1.get_shape().as_list(), y,
+ y.get_shape().as_list())
tf_logging.info("error = %f", error)
self.assertLess(error, 1e-4)
@@ -555,7 +524,9 @@ class DiagGradPartOpTest(test.TestCase):
x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype)
y = array_ops.diag_part(x1)
error = gradient_checker.compute_gradient_error(
- x1, x1.get_shape().as_list(), y, y.get_shape().as_list())
+ x1,
+ x1.get_shape().as_list(), y,
+ y.get_shape().as_list())
tf_logging.info("error = %f", error)
self.assertLess(error, 1e-4)
diff --git a/tensorflow/python/kernel_tests/map_stage_op_test.py b/tensorflow/python/kernel_tests/map_stage_op_test.py
index 8b66945059..acfafde9e0 100644
--- a/tensorflow/python/kernel_tests/map_stage_op_test.py
+++ b/tensorflow/python/kernel_tests/map_stage_op_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.platform import test
TIMEOUT = 1
+
class MapStageTest(test.TestCase):
def testSimple(self):
@@ -83,7 +84,7 @@ class MapStageTest(test.TestCase):
[dtypes.float32, dtypes.float32],
shapes=[[], [128, 128]],
names=['x', 'v'])
- stage = stager.put(pi,{'x': x, 'v': v})
+ stage = stager.put(pi, {'x': x, 'v': v})
key, ret = stager.get(gi)
z = ret['x']
y = ret['v']
@@ -128,8 +129,11 @@ class MapStageTest(test.TestCase):
gi = array_ops.placeholder(dtypes.int64)
p = array_ops.placeholder(dtypes.int32, name='p')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.MapStagingArea([dtypes.int32, ], shapes=[[]])
- stage = stager.put(pi,[x], [0])
+ stager = data_flow_ops.MapStagingArea(
+ [
+ dtypes.int32,
+ ], shapes=[[]])
+ stage = stager.put(pi, [x], [0])
peek = stager.peek(gi)
size = stager.size()
@@ -158,7 +162,7 @@ class MapStageTest(test.TestCase):
[dtypes.float32, dtypes.float32],
shapes=[[], [128, 128]],
names=['x', 'v'])
- stage = stager.put(pi,{'x': x, 'v': v})
+ stage = stager.put(pi, {'x': x, 'v': v})
size = stager.size()
clear = stager.clear()
@@ -172,7 +176,6 @@ class MapStageTest(test.TestCase):
sess.run(clear)
self.assertEqual(sess.run(size), 0)
-
def testCapacity(self):
capacity = 3
@@ -182,8 +185,10 @@ class MapStageTest(test.TestCase):
pi = array_ops.placeholder(dtypes.int64, name='pi')
gi = array_ops.placeholder(dtypes.int64, name='gi')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.MapStagingArea([dtypes.int32, ],
- capacity=capacity, shapes=[[]])
+ stager = data_flow_ops.MapStagingArea(
+ [
+ dtypes.int32,
+ ], capacity=capacity, shapes=[[]])
stage = stager.put(pi, [x], [0])
get = stager.get()
@@ -222,9 +227,8 @@ class MapStageTest(test.TestCase):
self.fail("Expected to timeout on iteration '{}' "
"but instead timed out on iteration '{}' "
"Staging Area size is '{}' and configured "
- "capacity is '{}'.".format(capacity, i,
- sess.run(size),
- capacity))
+ "capacity is '{}'.".format(capacity, i, sess.run(size),
+ capacity))
# Should have capacity elements in the staging area
self.assertTrue(sess.run(size) == capacity)
@@ -236,8 +240,8 @@ class MapStageTest(test.TestCase):
self.assertTrue(sess.run(size) == 0)
def testMemoryLimit(self):
- memory_limit = 512*1024 # 512K
- chunk = 200*1024 # 256K
+ memory_limit = 512 * 1024 # 512K
+ chunk = 200 * 1024 # 256K
capacity = memory_limit // chunk
with ops.Graph().as_default() as G:
@@ -246,8 +250,8 @@ class MapStageTest(test.TestCase):
pi = array_ops.placeholder(dtypes.int64, name='pi')
gi = array_ops.placeholder(dtypes.int64, name='gi')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.MapStagingArea([dtypes.uint8],
- memory_limit=memory_limit, shapes=[[]])
+ stager = data_flow_ops.MapStagingArea(
+ [dtypes.uint8], memory_limit=memory_limit, shapes=[[]])
stage = stager.put(pi, [x], [0])
get = stager.get()
size = stager.size()
@@ -287,9 +291,8 @@ class MapStageTest(test.TestCase):
self.fail("Expected to timeout on iteration '{}' "
"but instead timed out on iteration '{}' "
"Staging Area size is '{}' and configured "
- "capacity is '{}'.".format(capacity, i,
- sess.run(size),
- capacity))
+ "capacity is '{}'.".format(capacity, i, sess.run(size),
+ capacity))
# Should have capacity elements in the staging area
self.assertTrue(sess.run(size) == capacity)
@@ -310,8 +313,10 @@ class MapStageTest(test.TestCase):
pi = array_ops.placeholder(dtypes.int64, name='pi')
gi = array_ops.placeholder(dtypes.int64, name='gi')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.MapStagingArea([dtypes.int32, ],
- shapes=[[]], ordered=True)
+ stager = data_flow_ops.MapStagingArea(
+ [
+ dtypes.int32,
+ ], shapes=[[]], ordered=True)
stage = stager.put(pi, [x], [0])
get = stager.get()
size = stager.size()
@@ -349,7 +354,7 @@ class MapStageTest(test.TestCase):
stager = data_flow_ops.MapStagingArea(
[dtypes.float32, dtypes.float32, dtypes.float32],
names=['x', 'v', 'f'])
- stage_xf = stager.put(pi,{'x': x, 'f': f})
+ stage_xf = stager.put(pi, {'x': x, 'f': f})
stage_v = stager.put(pi, {'v': v})
key, ret = stager.get(gi)
size = stager.size()
@@ -373,12 +378,13 @@ class MapStageTest(test.TestCase):
self.assertTrue(sess.run([size, isize]) == [1, 1])
# We can now obtain tuple associated with key 0
self.assertTrue(
- sess.run([key, ret],
- feed_dict={gi: 0}) == [0, {
- 'x': 1,
- 'f': 2,
- 'v': 1
- }])
+ sess.run([key, ret], feed_dict={
+ gi: 0
+ }) == [0, {
+ 'x': 1,
+ 'f': 2,
+ 'v': 1
+ }])
# 0 complete and 1 incomplete entry
self.assertTrue(sess.run([size, isize]) == [0, 1])
@@ -386,12 +392,13 @@ class MapStageTest(test.TestCase):
sess.run(stage_v, feed_dict={pi: 1, v: 3})
# We can now obtain tuple associated with key 1
self.assertTrue(
- sess.run([key, ret],
- feed_dict={gi: 1}) == [1, {
- 'x': 1,
- 'f': 2,
- 'v': 3
- }])
+ sess.run([key, ret], feed_dict={
+ gi: 1
+ }) == [1, {
+ 'x': 1,
+ 'f': 2,
+ 'v': 3
+ }])
def testPartialIndexInsert(self):
with ops.Graph().as_default() as G:
@@ -450,7 +457,7 @@ class MapStageTest(test.TestCase):
stager = data_flow_ops.MapStagingArea(
[dtypes.float32, dtypes.float32, dtypes.float32],
names=['x', 'v', 'f'])
- stage_xf = stager.put(pi,{'x': x, 'f': f})
+ stage_xf = stager.put(pi, {'x': x, 'f': f})
stage_v = stager.put(pi, {'v': v})
peek_xf = stager.peek(pei, ['x', 'f'])
peek_v = stager.peek(pei, ['v'])
@@ -487,11 +494,12 @@ class MapStageTest(test.TestCase):
# We can now obtain 'x' and 'f' values associated with key 0
self.assertTrue(
- sess.run([key_xf, get_xf],
- feed_dict={gi: 0}) == [0, {
- 'x': 1,
- 'f': 2
- }])
+ sess.run([key_xf, get_xf], feed_dict={
+ gi: 0
+ }) == [0, {
+ 'x': 1,
+ 'f': 2
+ }])
# Still have 1 complete and 1 incomplete entry
self.assertTrue(sess.run([size, isize]) == [1, 1])
@@ -499,14 +507,15 @@ class MapStageTest(test.TestCase):
with self.assertRaises(errors.InvalidArgumentError) as cm:
sess.run([key_xf, get_xf], feed_dict={gi: 0})
- exc_str = ("Tensor at index '0' for key '0' "
- "has already been removed.")
+ exc_str = ("Tensor at index '0' for key '0' " 'has already been removed.')
self.assertTrue(exc_str in cm.exception.message)
# Obtain 'v' value associated with key 0
self.assertTrue(
- sess.run([key_v, get_v], feed_dict={gi: 0}) == [0, {
+ sess.run([key_v, get_v], feed_dict={
+ gi: 0
+ }) == [0, {
'v': 1
}])
# 0 complete and 1 incomplete entry
@@ -523,7 +532,9 @@ class MapStageTest(test.TestCase):
self.assertTrue(sess.run([size, isize]) == [1, 0])
# We can now obtain 'x' and 'f' values associated with key 1
self.assertTrue(
- sess.run([pop_key_v, pop_v], feed_dict={pi: 1}) == [1, {
+ sess.run([pop_key_v, pop_v], feed_dict={
+ pi: 1
+ }) == [1, {
'v': 1
}])
# Nothing is left
@@ -557,18 +568,20 @@ class MapStageTest(test.TestCase):
self.assertTrue(sess.run([size, isize]) == [1, 0])
# Partial get using indices
- self.assertTrue(sess.run([key_xf, get_xf],
- feed_dict={gi: 0}) == [0, [1, 2]])
+ self.assertTrue(
+ sess.run([key_xf, get_xf], feed_dict={
+ gi: 0
+ }) == [0, [1, 2]])
# Still some of key 0 left
self.assertTrue(sess.run([size, isize]) == [1, 0])
# Partial get of remaining index
- self.assertTrue(sess.run([key_v, get_v],
- feed_dict={gi: 0}) == [0, [3]])
+ self.assertTrue(sess.run([key_v, get_v], feed_dict={gi: 0}) == [0, [3]])
# All gone
self.assertTrue(sess.run([size, isize]) == [0, 0])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 5c0ea8ec8e..3263ed1a60 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -159,8 +159,10 @@ class PoolingTest(test.TestCase):
elif data_format == "NCHW":
t = test_util.NCHWToNHWC(t)
if v2:
- actual = t.eval(feed_dict={ksize_placeholder: ksize,
- strides_placeholder: strides})
+ actual = t.eval(feed_dict={
+ ksize_placeholder: ksize,
+ strides_placeholder: strides
+ })
else:
actual = t.eval()
self.assertShapeEqual(actual, t)
@@ -195,8 +197,15 @@ class PoolingTest(test.TestCase):
self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
data_format, dtypes.float16, expected, use_gpu, v2)
- def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding,
- expected, use_gpu, v2=False):
+ def _VerifyValues(self,
+ pool_func,
+ input_sizes,
+ ksize,
+ strides,
+ padding,
+ expected,
+ use_gpu,
+ v2=False):
"""Verifies the output values of the pooling function.
Args:
@@ -1148,16 +1157,16 @@ class PoolingTest(test.TestCase):
def _testMaxPoolGradSamePadding3_1(self, data_format, use_gpu):
for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
self._ConstructAndTestGradient(
- pool_func,
- input_sizes=[1, 7, 7, 1],
- output_sizes=[1, 7, 7, 1],
- window_rows=3,
- window_cols=3,
- row_stride=1,
- col_stride=1,
- padding="SAME",
- data_format=data_format,
- use_gpu=use_gpu)
+ pool_func,
+ input_sizes=[1, 7, 7, 1],
+ output_sizes=[1, 7, 7, 1],
+ window_rows=3,
+ window_cols=3,
+ row_stride=1,
+ col_stride=1,
+ padding="SAME",
+ data_format=data_format,
+ use_gpu=use_gpu)
def testMaxPoolGrad(self):
for (data_format, use_gpu) in GetTestConfigs():
@@ -1202,17 +1211,14 @@ class PoolingTest(test.TestCase):
pool_func = gen_nn_ops._max_pool_v2 if v2 else nn_ops.max_pool
with self.test_session(use_gpu=use_gpu):
input_tensor = constant_op.constant(input_data, shape=input_sizes)
- output_tensor = pool_func(input_tensor,
- [1, window_rows, window_cols, 1],
+ output_tensor = pool_func(input_tensor, [1, window_rows, window_cols, 1],
[1, row_stride, col_stride, 1], padding)
output_backprop_tensor = constant_op.constant(
output_backprop, shape=output_sizes)
- input_backprop_tensor = self._MaxPoolGrad(input_tensor, output_tensor,
- output_backprop_tensor,
- window_rows, window_cols,
- row_stride, col_stride,
- padding, v2)
+ input_backprop_tensor = self._MaxPoolGrad(
+ input_tensor, output_tensor, output_backprop_tensor, window_rows,
+ window_cols, row_stride, col_stride, padding, v2)
actual_input_backprop = input_backprop_tensor.eval()
self.assertShapeEqual(actual_input_backprop, input_backprop_tensor)
@@ -1414,13 +1420,15 @@ class PoolingTest(test.TestCase):
def _testMaxPoolGradDirectWithNans2_2(self):
input_data = [float("nan")] * 16
output_backprop = [
- float("nan"), 12.0, 13.0, 15.0, float("nan"), 17.0, 19.0, 20.0,
+ float("nan"), 12.0, 13.0, 15.0,
+ float("nan"), 17.0, 19.0, 20.0,
float("nan")
]
# Test the CPU implementation, which propagates diffs in case of NaN
expected_input_backprop_tf_cpu = [
- float("nan"), 12.0, 13.0, 0.0, 15.0, float("nan"), 17.0, 0.0, 19.0,
- 20.0, float("nan"), 0.0, 0.0, 0.0, 0.0, 0.0
+ float("nan"), 12.0, 13.0, 0.0, 15.0,
+ float("nan"), 17.0, 0.0, 19.0, 20.0,
+ float("nan"), 0.0, 0.0, 0.0, 0.0, 0.0
]
for v2 in [True, False]:
self._testMaxPoolGradDirect(
@@ -1636,10 +1644,9 @@ class PoolingTest(test.TestCase):
Returns:
A Tensor.
"""
- return gen_nn_ops._max_pool_grad_grad(orig_input, orig_output, grad,
- [1, window_rows, window_cols,
- 1], [1, row_stride, col_stride,
- 1], padding)
+ return gen_nn_ops._max_pool_grad_grad(
+ orig_input, orig_output, grad, [1, window_rows, window_cols, 1],
+ [1, row_stride, col_stride, 1], padding)
def testAvgPoolGrad(self):
for (data_format, use_gpu) in GetTestConfigs():
@@ -1793,8 +1800,7 @@ class PoolingTest(test.TestCase):
]:
with self.assertRaises(ValueError):
pool_func(
- array_ops.placeholder(
- dtypes.float32, shape=[1, 3]),
+ array_ops.placeholder(dtypes.float32, shape=[1, 3]),
ksize=[1, 1, 1, 1],
strides=[1, 1, 1, 1],
padding="SAME")
@@ -1820,15 +1826,13 @@ class PoolingTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
sess.run(
pool_func(
- array_ops.placeholder(
- dtypes.float32, shape=[32, 20, 20, 3]),
+ array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
ksize=[1, 20, 21, 1],
strides=[1, 1, 1, 1],
padding="VALID"))
with self.assertRaisesRegexp(ValueError, "Negative dimension size"):
pool_func(
- array_ops.placeholder(
- dtypes.float32, shape=[32, 20, 20, 3]),
+ array_ops.placeholder(dtypes.float32, shape=[32, 20, 20, 3]),
ksize=[1, 21, 20, 1],
strides=[1, 1, 1, 1],
padding="VALID")
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index 223a4b2c87..82a27eebee 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -428,7 +428,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
for i in range(self._num_files):
fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
filenames.append(fn)
- with open(fn+".tmp", "wb") as f:
+ with open(fn + ".tmp", "wb") as f:
f.write(b"H" * self._header_bytes)
if num_records > 0:
f.write(self._Record(i, 0))
@@ -437,7 +437,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
f.write(b"G" * gap_bytes)
f.write(self._Record(i, j))
f.write(b"F" * self._footer_bytes)
- with open(fn+".tmp", "rb") as f:
+ with open(fn + ".tmp", "rb") as f:
cdata = zlib.compress(f.read())
with open(fn, "wb") as zf:
zf.write(cdata)
@@ -455,7 +455,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
all_records_str = "".join([
str(i)[0]
for i in range(self._record_bytes + self._hop_bytes *
- (num_overlapped_records - 1))
+ (num_overlapped_records - 1))
])
f.write(compat.as_bytes(all_records_str))
f.write(b"F" * self._footer_bytes)
@@ -467,7 +467,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
fn = os.path.join(self.get_temp_dir(),
"fixed_length_overlapped_record.%d.txt" % i)
filenames.append(fn)
- with open(fn+".tmp", "wb") as f:
+ with open(fn + ".tmp", "wb") as f:
f.write(b"H" * self._header_bytes)
if num_overlapped_records > 0:
all_records_str = "".join([
@@ -477,7 +477,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
])
f.write(compat.as_bytes(all_records_str))
f.write(b"F" * self._footer_bytes)
- with open(fn+".tmp", "rb") as f:
+ with open(fn + ".tmp", "rb") as f:
cdata = zlib.compress(f.read())
with open(fn, "wb") as zf:
zf.write(cdata)
@@ -509,7 +509,10 @@ class FixedLengthRecordReaderTest(test.TestCase):
"\\(requested 1, current size 0\\)"):
k, v = sess.run([key, value])
- def _TestOneEpochWithHopBytes(self, files, num_overlapped_records, encoding=None):
+ def _TestOneEpochWithHopBytes(self,
+ files,
+ num_overlapped_records,
+ encoding=None):
with self.test_session() as sess:
reader = io_ops.FixedLengthRecordReader(
header_bytes=self._header_bytes,
@@ -565,13 +568,15 @@ class FixedLengthRecordReaderTest(test.TestCase):
def testGzipOneEpochWithHopBytes(self):
for num_overlapped_records in [0, 2]:
- files = self._CreateGzipOverlappedRecordFiles(num_overlapped_records, )
- self._TestOneEpochWithHopBytes(files, num_overlapped_records, encoding="GZIP")
+ files = self._CreateGzipOverlappedRecordFiles(num_overlapped_records,)
+ self._TestOneEpochWithHopBytes(
+ files, num_overlapped_records, encoding="GZIP")
def testZlibOneEpochWithHopBytes(self):
for num_overlapped_records in [0, 2]:
files = self._CreateZlibOverlappedRecordFiles(num_overlapped_records)
- self._TestOneEpochWithHopBytes(files, num_overlapped_records, encoding="ZLIB")
+ self._TestOneEpochWithHopBytes(
+ files, num_overlapped_records, encoding="ZLIB")
class TFRecordReaderTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index dd11ba700d..6b4091ae5d 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -48,8 +48,8 @@ class ReluTest(test.TestCase):
self.assertAllClose(
np.array([[0.0, 0.7, 0.0, 0.3, 0.0], [0.1, 0.0, 0.5, 0.0, 0.9]]),
self._npRelu(
- np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]
- ])))
+ np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
+ 0.9]])))
def _testRelu(self, np_features, use_gpu=False):
np_relu = self._npRelu(np_features)
@@ -163,8 +163,8 @@ class Relu6Test(test.TestCase):
self.assertAllClose(
np.array([[0.0, 0.7, 0.0, 0.3, 6.0], [0.1, 0.0, 6.0, 0.0, 0.9]]),
self._npRelu6(
- np.array([[-0.9, 0.7, -0.5, 0.3, 6.0], [0.1, -0.3, 6.5, -0.7, 0.9]
- ])))
+ np.array([[-0.9, 0.7, -0.5, 0.3, 6.0], [0.1, -0.3, 6.5, -0.7,
+ 0.9]])))
def _testRelu6(self, np_features, use_gpu=False):
np_relu6 = self._npRelu6(np_features)
@@ -231,8 +231,8 @@ class EluTest(test.TestCase):
np.array([[-0.59343034025, 0.7, -0.39346934028, 0.3, -0.09516258196],
[0.1, -0.25918177931, 0.5, -0.5034146962, 0.9]]),
self._npElu(
- np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]
- ])))
+ np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
+ 0.9]])))
def _testElu(self, np_features, use_gpu=False):
np_elu = self._npElu(np_features)
@@ -330,11 +330,11 @@ class SeluTest(test.TestCase):
def testNpSelu(self):
self.assertAllClose(
- np.array([[-1.0433095, 0.73549069, -0.6917582, 0.3152103 , -0.16730527],
- [0.1050701 , -0.45566732, 0.5253505, -0.88505305, 0.9456309]]),
+ np.array([[-1.0433095, 0.73549069, -0.6917582, 0.3152103, -0.16730527],
+ [0.1050701, -0.45566732, 0.5253505, -0.88505305, 0.9456309]]),
self._npSelu(
- np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]
- ])))
+ np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
+ 0.9]])))
def _testSelu(self, np_features, use_gpu=False):
np_selu = self._npSelu(np_features)
diff --git a/tensorflow/python/kernel_tests/scalar_test.py b/tensorflow/python/kernel_tests/scalar_test.py
index b34426cc21..e65241981e 100644
--- a/tensorflow/python/kernel_tests/scalar_test.py
+++ b/tensorflow/python/kernel_tests/scalar_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
@@ -30,6 +31,7 @@ import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
+@test_util.with_c_api
class ScalarTest(test.TestCase):
def check(self, op, args, error, correct=None):
@@ -51,7 +53,7 @@ class ScalarTest(test.TestCase):
# Test various GraphDef versions
for version in strict + lenient:
with ops.Graph().as_default() as g:
- g.graph_def_versions.producer = version
+ test_util.set_producer_version(g, version)
with self.test_session(graph=g) as sess:
feed = {}
xs = placeholders(args, feed)
diff --git a/tensorflow/python/kernel_tests/sparse_slice_op_test.py b/tensorflow/python/kernel_tests/sparse_slice_op_test.py
index 762e400447..da116601f8 100644
--- a/tensorflow/python/kernel_tests/sparse_slice_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_slice_op_test.py
@@ -32,11 +32,12 @@ class SparseSliceOpTest(test.TestCase):
# [ |11| |13|14| ]
# [20| | |23| |25]
# [30| |32|33| |35]
- ind = np.array([[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4],
- [2, 0], [2, 3], [2, 5], [3, 0], [3, 2], [3, 3],
- [3, 5]]).astype(np.int64)
- val = np.array(
- [0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype(np.int64)
+ ind = np.array([[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1,
+ 4], [2, 0],
+ [2, 3], [2, 5], [3, 0], [3, 2], [3, 3], [3, 5]]).astype(
+ np.int64)
+ val = np.array([0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype(
+ np.int64)
shape = np.array([4, 6]).astype(np.int64)
return sparse_tensor.SparseTensor(ind, val, shape)
@@ -65,50 +66,49 @@ class SparseSliceOpTest(test.TestCase):
# [ |'c1'| |'d1']
# [ | |'e1'| ]
ind = np.array([[0, 0, 0], [0, 0, 1], [0, 2, 0], [0, 2, 1], [1, 1, 0],
- [1, 1, 1], [1, 3, 0], [1, 3, 1], [2, 2, 0],
- [2, 2, 1]]).astype(np.int64)
+ [1, 1, 1], [1, 3, 0], [1, 3, 1], [2, 2, 0], [2, 2,
+ 1]]).astype(
+ np.int64)
val = np.array(['a0', 'a1', 'b0', 'b1', 'c0', 'c1', 'd0', 'd1', 'e0', 'e1'])
shape = np.array([3, 4, 2]).astype(np.int64)
return sparse_tensor.SparseTensorValue(ind, val, shape)
def _SparseTensor_3x4x2(self):
- return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x4x2(
- ))
+ return sparse_tensor.SparseTensor.from_value(
+ self._SparseTensorValue_3x4x2())
def testSliceMatrixRows(self):
with self.test_session(use_gpu=False):
- sp_input=self._SparseTensor_4x6()
+ sp_input = self._SparseTensor_4x6()
sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [2, 6])
sp_tensor1 = sparse_ops.sparse_slice(sp_input, [2, 0], [3, 7])
- self.assertAllEqual(sp_tensor0.indices.eval(), [[0, 0], [0, 2], [0, 4],
- [0, 5], [1, 1], [1, 3],
- [1, 4]])
+ self.assertAllEqual(
+ sp_tensor0.indices.eval(),
+ [[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4]])
self.assertAllEqual(sp_tensor0.values.eval(), [0, 2, 4, 5, 11, 13, 14])
self.assertAllEqual(sp_tensor0.dense_shape.eval(), [2, 6])
- self.assertAllEqual(sp_tensor1.indices.eval(), [[0, 0], [0, 3], [0, 5],
- [1, 0], [1, 2], [1, 3],
- [1, 5]])
+ self.assertAllEqual(
+ sp_tensor1.indices.eval(),
+ [[0, 0], [0, 3], [0, 5], [1, 0], [1, 2], [1, 3], [1, 5]])
self.assertAllEqual(sp_tensor1.values.eval(),
[20, 23, 25, 30, 32, 33, 35])
self.assertAllEqual(sp_tensor1.dense_shape.eval(), [2, 6])
def testSliceMatrixUnevenCols(self):
with self.test_session(use_gpu=False):
- sp_input=self._SparseTensor_5x7()
+ sp_input = self._SparseTensor_5x7()
sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [5, 3])
sp_tensor1 = sparse_ops.sparse_slice(sp_input, [0, 3], [5, 2])
sp_tensor2 = sparse_ops.sparse_slice(sp_input, [0, 5], [5, 2])
- self.assertAllEqual(sp_tensor0.indices.eval(),
- [[0, 0], [0, 2], [1, 1], [2, 0], [3, 0], [3, 2],
- [4, 1]])
- self.assertAllEqual(sp_tensor0.values.eval(),
- [0, 2, 11, 20, 30, 32, 41])
+ self.assertAllEqual(
+ sp_tensor0.indices.eval(),
+ [[0, 0], [0, 2], [1, 1], [2, 0], [3, 0], [3, 2], [4, 1]])
+ self.assertAllEqual(sp_tensor0.values.eval(), [0, 2, 11, 20, 30, 32, 41])
self.assertAllEqual(sp_tensor0.dense_shape.eval(), [5, 3])
self.assertAllEqual(sp_tensor1.indices.eval(),
[[0, 1], [1, 0], [1, 1], [2, 0], [3, 0], [4, 1]])
- self.assertAllEqual(sp_tensor1.values.eval(),
- [4, 13, 14, 23, 33, 44])
+ self.assertAllEqual(sp_tensor1.values.eval(), [4, 13, 14, 23, 33, 44])
self.assertAllEqual(sp_tensor1.dense_shape.eval(), [5, 2])
self.assertAllEqual(sp_tensor2.indices.eval(),
[[0, 0], [1, 1], [2, 0], [3, 0], [4, 1]])
@@ -137,7 +137,7 @@ class SparseSliceOpTest(test.TestCase):
def testSliceMatrixUnevenRows(self):
with self.test_session(use_gpu=False):
- sp_input=self._SparseTensor_5x7()
+ sp_input = self._SparseTensor_5x7()
sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [3, 7])
sp_tensor1 = sparse_ops.sparse_slice(sp_input, [3, 0], [3, 7])
self.assertAllEqual(sp_tensor0.indices.eval(),
@@ -146,9 +146,9 @@ class SparseSliceOpTest(test.TestCase):
self.assertAllEqual(sp_tensor0.values.eval(),
[0, 2, 4, 5, 11, 13, 14, 16, 20, 23, 25])
self.assertAllEqual(sp_tensor0.dense_shape.eval(), [3, 7])
- self.assertAllEqual(sp_tensor1.indices.eval(),
- [[0, 0], [0, 2], [0, 3], [0, 5], [1, 1], [1, 4],
- [1, 6]])
+ self.assertAllEqual(
+ sp_tensor1.indices.eval(),
+ [[0, 0], [0, 2], [0, 3], [0, 5], [1, 1], [1, 4], [1, 6]])
self.assertAllEqual(sp_tensor1.values.eval(),
[30, 32, 33, 35, 41, 44, 46])
self.assertAllEqual(sp_tensor1.dense_shape.eval(), [2, 7])
@@ -156,9 +156,9 @@ class SparseSliceOpTest(test.TestCase):
sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [2, 7])
sp_tensor1 = sparse_ops.sparse_slice(sp_input, [2, 0], [2, 7])
sp_tensor2 = sparse_ops.sparse_slice(sp_input, [4, 0], [2, 7])
- self.assertAllEqual(sp_tensor0.indices.eval(),
- [[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3],
- [1, 4], [1, 6]])
+ self.assertAllEqual(
+ sp_tensor0.indices.eval(),
+ [[0, 0], [0, 2], [0, 4], [0, 5], [1, 1], [1, 3], [1, 4], [1, 6]])
self.assertAllEqual(sp_tensor0.values.eval(),
[0, 2, 4, 5, 11, 13, 14, 16])
self.assertAllEqual(sp_tensor0.dense_shape.eval(), [2, 7])
@@ -166,45 +166,42 @@ class SparseSliceOpTest(test.TestCase):
self.assertAllEqual(sp_tensor1.values.eval(),
[20, 23, 25, 30, 32, 33, 35])
self.assertAllEqual(sp_tensor1.dense_shape.eval(), [2, 7])
- self.assertAllEqual(sp_tensor2.indices.eval(), [[0, 1], [0, 4],
- [0, 6]])
+ self.assertAllEqual(sp_tensor2.indices.eval(), [[0, 1], [0, 4], [0, 6]])
self.assertAllEqual(sp_tensor2.values.eval(), [41, 44, 46])
self.assertAllEqual(sp_tensor2.dense_shape.eval(), [1, 7])
return
def testSliceAllRows(self):
with self.test_session(use_gpu=False):
- sp_input=self._SparseTensor_4x6()
+ sp_input = self._SparseTensor_4x6()
sp_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [1, 6])
sp_tensor1 = sparse_ops.sparse_slice(sp_input, [1, 0], [1, 6])
sp_tensor2 = sparse_ops.sparse_slice(sp_input, [2, 0], [1, 7])
sp_tensor3 = sparse_ops.sparse_slice(sp_input, [3, 0], [2, 7])
- self.assertAllEqual(sp_tensor0.indices.eval(), [[0, 0], [0, 2], [0, 4],
- [0, 5]])
+ self.assertAllEqual(sp_tensor0.indices.eval(),
+ [[0, 0], [0, 2], [0, 4], [0, 5]])
self.assertAllEqual(sp_tensor0.values.eval(), [0, 2, 4, 5])
self.assertAllEqual(sp_tensor0.dense_shape.eval(), [1, 6])
- self.assertAllEqual(sp_tensor1.indices.eval(), [[0, 1], [0, 3], [0,
- 4]])
+ self.assertAllEqual(sp_tensor1.indices.eval(), [[0, 1], [0, 3], [0, 4]])
self.assertAllEqual(sp_tensor1.values.eval(), [11, 13, 14])
self.assertAllEqual(sp_tensor1.dense_shape.eval(), [1, 6])
- self.assertAllEqual(sp_tensor2.indices.eval(), [[0, 0], [0, 3], [0,
- 5]])
+ self.assertAllEqual(sp_tensor2.indices.eval(), [[0, 0], [0, 3], [0, 5]])
self.assertAllEqual(sp_tensor2.values.eval(), [20, 23, 25])
self.assertAllEqual(sp_tensor2.dense_shape.eval(), [1, 6])
- self.assertAllEqual(sp_tensor3.indices.eval(), [[0, 0], [0, 2], [0, 3],
- [0, 5]])
+ self.assertAllEqual(sp_tensor3.indices.eval(),
+ [[0, 0], [0, 2], [0, 3], [0, 5]])
self.assertAllEqual(sp_tensor3.values.eval(), [30, 32, 33, 35])
self.assertAllEqual(sp_tensor3.dense_shape.eval(), [1, 6])
def testSliceColumns(self):
with self.test_session(use_gpu=False):
- sp_input=self._SparseTensor_4x6()
+ sp_input = self._SparseTensor_4x6()
sparse_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [4, 2])
sparse_tensor1 = sparse_ops.sparse_slice(sp_input, [0, 2], [5, 2])
sparse_tensor2 = sparse_ops.sparse_slice(sp_input, [0, 4], [5, 3])
- self.assertAllEqual(sparse_tensor0.indices.eval(), [[0, 0], [1, 1],
- [2, 0], [3, 0]])
+ self.assertAllEqual(sparse_tensor0.indices.eval(),
+ [[0, 0], [1, 1], [2, 0], [3, 0]])
self.assertAllEqual(sparse_tensor0.values.eval(), [0, 11, 20, 30])
self.assertAllEqual(sparse_tensor0.dense_shape.eval(), [4, 2])
self.assertAllEqual(sparse_tensor1.indices.eval(),
@@ -218,15 +215,15 @@ class SparseSliceOpTest(test.TestCase):
def testSliceAllColumns(self):
with self.test_session(use_gpu=False):
- sp_input=self._SparseTensor_4x6()
+ sp_input = self._SparseTensor_4x6()
sparse_tensor0 = sparse_ops.sparse_slice(sp_input, [0, 0], [4, 1])
sparse_tensor1 = sparse_ops.sparse_slice(sp_input, [0, 1], [4, 1])
sparse_tensor2 = sparse_ops.sparse_slice(sp_input, [0, 2], [4, 1])
sparse_tensor3 = sparse_ops.sparse_slice(sp_input, [0, 3], [4, 1])
sparse_tensor4 = sparse_ops.sparse_slice(sp_input, [0, 4], [5, 1])
sparse_tensor5 = sparse_ops.sparse_slice(sp_input, [0, 5], [6, 3])
- self.assertAllEqual(sparse_tensor0.indices.eval(), [[0, 0], [2, 0],
- [3, 0]])
+ self.assertAllEqual(sparse_tensor0.indices.eval(),
+ [[0, 0], [2, 0], [3, 0]])
self.assertAllEqual(sparse_tensor0.values.eval(), [0, 20, 30])
self.assertAllEqual(sparse_tensor0.dense_shape.eval(), [4, 1])
self.assertAllEqual(sparse_tensor1.indices.eval(), [[1, 0]])
@@ -235,17 +232,18 @@ class SparseSliceOpTest(test.TestCase):
self.assertAllEqual(sparse_tensor2.indices.eval(), [[0, 0], [3, 0]])
self.assertAllEqual(sparse_tensor2.values.eval(), [2, 32])
self.assertAllEqual(sparse_tensor2.dense_shape.eval(), [4, 1])
- self.assertAllEqual(sparse_tensor3.indices.eval(), [[1, 0], [2, 0],
- [3, 0]])
+ self.assertAllEqual(sparse_tensor3.indices.eval(),
+ [[1, 0], [2, 0], [3, 0]])
self.assertAllEqual(sparse_tensor3.dense_shape.eval(), [4, 1])
self.assertAllEqual(sparse_tensor3.values.eval(), [13, 23, 33])
self.assertAllEqual(sparse_tensor4.indices.eval(), [[0, 0], [1, 0]])
self.assertAllEqual(sparse_tensor4.values.eval(), [4, 14])
self.assertAllEqual(sparse_tensor4.dense_shape.eval(), [4, 1])
- self.assertAllEqual(sparse_tensor5.indices.eval(), [[0, 0], [2, 0],
- [3, 0]])
+ self.assertAllEqual(sparse_tensor5.indices.eval(),
+ [[0, 0], [2, 0], [3, 0]])
self.assertAllEqual(sparse_tensor5.values.eval(), [5, 25, 35])
self.assertAllEqual(sparse_tensor5.dense_shape.eval(), [4, 1])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/kernel_tests/stage_op_test.py b/tensorflow/python/kernel_tests/stage_op_test.py
index 64b3388c5c..dd06d30391 100644
--- a/tensorflow/python/kernel_tests/stage_op_test.py
+++ b/tensorflow/python/kernel_tests/stage_op_test.py
@@ -25,8 +25,8 @@ from tensorflow.python.platform import test
TIMEOUT = 1
-class StageTest(test.TestCase):
+class StageTest(test.TestCase):
def testSimple(self):
with ops.Graph().as_default() as G:
@@ -116,7 +116,10 @@ class StageTest(test.TestCase):
x = array_ops.placeholder(dtypes.int32, name='x')
p = array_ops.placeholder(dtypes.int32, name='p')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.StagingArea([dtypes.int32, ], shapes=[[]])
+ stager = data_flow_ops.StagingArea(
+ [
+ dtypes.int32,
+ ], shapes=[[]])
stage = stager.put([x])
peek = stager.peek(p)
ret = stager.get()
@@ -162,8 +165,10 @@ class StageTest(test.TestCase):
with ops.device('/cpu:0'):
x = array_ops.placeholder(dtypes.int32, name='x')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.StagingArea([dtypes.int32, ],
- capacity=capacity, shapes=[[]])
+ stager = data_flow_ops.StagingArea(
+ [
+ dtypes.int32,
+ ], capacity=capacity, shapes=[[]])
stage = stager.put([x])
ret = stager.get()
size = stager.size()
@@ -201,9 +206,8 @@ class StageTest(test.TestCase):
self.fail("Expected to timeout on iteration '{}' "
"but instead timed out on iteration '{}' "
"Staging Area size is '{}' and configured "
- "capacity is '{}'.".format(capacity, i,
- sess.run(size),
- capacity))
+ "capacity is '{}'.".format(capacity, i, sess.run(size),
+ capacity))
# Should have capacity elements in the staging area
self.assertTrue(sess.run(size) == capacity)
@@ -216,16 +220,18 @@ class StageTest(test.TestCase):
self.assertTrue(sess.run(size) == 0)
def testMemoryLimit(self):
- memory_limit = 512*1024 # 512K
- chunk = 200*1024 # 256K
+ memory_limit = 512 * 1024 # 512K
+ chunk = 200 * 1024 # 256K
capacity = memory_limit // chunk
with ops.Graph().as_default() as G:
with ops.device('/cpu:0'):
x = array_ops.placeholder(dtypes.uint8, name='x')
with ops.device(test.gpu_device_name()):
- stager = data_flow_ops.StagingArea([dtypes.uint8, ],
- memory_limit=memory_limit, shapes=[[]])
+ stager = data_flow_ops.StagingArea(
+ [
+ dtypes.uint8,
+ ], memory_limit=memory_limit, shapes=[[]])
stage = stager.put([x])
ret = stager.get()
size = stager.size()
@@ -264,9 +270,8 @@ class StageTest(test.TestCase):
self.fail("Expected to timeout on iteration '{}' "
"but instead timed out on iteration '{}' "
"Staging Area size is '{}' and configured "
- "capacity is '{}'.".format(capacity, i,
- sess.run(size),
- capacity))
+ "capacity is '{}'.".format(capacity, i, sess.run(size),
+ capacity))
# Should have capacity elements in the staging area
self.assertTrue(sess.run(size) == capacity)
@@ -277,5 +282,6 @@ class StageTest(test.TestCase):
self.assertTrue(sess.run(size) == 0)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 00faf3faa1..5d9feb07b4 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -99,8 +99,16 @@ class Layer(object):
raise TypeError('Keyword argument not understood:', kwarg)
# Mutable properties
+ # Indicates whether the layer's weights are updated during training
+ # and whether the layer's updates are run during training
self.trainable = trainable
+ # A stateful layer is a layer whose updates are run during inference too,
+ # for instance stateful RNNs.
+ self.stateful = False
+ # Indicates whether `build` needs to be called upon layer call, to create
+ # the layer's weights.
self.built = False
+ # Provides information about which inputs are compatible with the layer.
self.input_spec = None
if activity_regularizer and context.in_eager_mode():
@@ -223,6 +231,8 @@ class Layer(object):
def updates(self):
if context.in_eager_mode():
raise RuntimeError('Layer.updates not supported in Eager mode.')
+ if not self.trainable and not self.stateful:
+ return []
return self._updates
def add_update(self, updates, inputs=None):
@@ -284,6 +294,8 @@ class Layer(object):
"""
if context.in_eager_mode():
raise RuntimeError('Layer.get_updates_for not supported in Eager mode.')
+ if not self.trainable and not self.stateful:
+ return []
if inputs is not None:
inputs = nest.flatten(inputs)
if not inputs:
@@ -500,13 +512,30 @@ class Layer(object):
instance is returned.
Raises:
- RuntimeError: If called in Eager mode with partioned variable
- regularization.
+ RuntimeError: If called with partioned variable regularization and
+ eager execution is enabled.
"""
- in_graph_mode = context.in_graph_mode()
- if in_graph_mode:
- existing_variables = set(tf_variables.global_variables())
+ # `init_graph` should point to the graph in which variable initialization
+ # will occur; it should be None if and only if initialization will take
+ # place in the eager context.
+ init_graph = None
+ if context.in_graph_mode():
+ default_graph = ops.get_default_graph()
+ if default_graph.building_function:
+ with ops.init_scope():
+ # Retrieve the variables from the graph into which variables
+ # will be lifted; if initialization ops will be lifted into
+ # the eager context, then there is nothing to retrieve, since variable
+ # collections are not supported when eager execution is enabled.
+ if context.in_graph_mode():
+ init_graph = ops.get_default_graph()
+ existing_variables = set(tf_variables.global_variables())
+ else:
+ # Initialization ops will not be lifted out of the default graph.
+ init_graph = default_graph
+ existing_variables = set(tf_variables.global_variables())
+
if dtype is None:
dtype = self.dtype or dtypes.float32
@@ -523,54 +552,51 @@ class Layer(object):
trainable=trainable and self.trainable,
partitioner=partitioner)
- if in_graph_mode:
- if (trainable and self.trainable
- and variable not in tf_variables.trainable_variables()):
- # A custom getter / variable scope overrode the trainable flag.
- trainable = False
+ if init_graph is not None: # pylint: disable=protected-access
+ # The variable was created and initialized in a graph.
+
if variable in existing_variables:
# To match the behavior of tf.get_variable(), we only apply
# regularization if the variable is newly created.
return variable
- if regularizer:
- def regularizer_factory():
- if context.in_graph_mode():
- with vs.variable_scope(scope, reuse=reuse,
- auxiliary_name_scope=False):
- with ops.name_scope(self._name_scope_name(scope)):
- if isinstance(variable, tf_variables.PartitionedVariable):
- for v in variable:
- with ops.colocate_with(v.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(v)
- if regularization is not None:
- self.add_loss(regularization)
- else:
- with ops.colocate_with(variable.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(variable)
- if regularization is not None:
- self.add_loss(regularization)
+ with init_graph.as_default():
+ trainable_variables = tf_variables.trainable_variables()
+ if (trainable and self.trainable and
+ variable not in trainable_variables):
+ # A custom getter / variable scope overrode the trainable flag.
+ trainable = False
+
+ if regularizer:
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ for v in variable:
+ with ops.colocate_with(v.op):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(v)
+ if regularization is not None:
+ self.add_loss(regularization)
else:
- if isinstance(variable, tf_variables.PartitionedVariable):
- raise RuntimeError(
- 'Partitioned variable regularization is not yet '
- 'supported when executing eagerly. File a feature request'
- 'if this is important to you.')
- # Save a zero-argument lambda which runs the regularizer on the
- # variable, to be executed when `Layer.losses` is requested.
- # This makes losses responsive to variable updates when
- # executing eagerly.
- self._losses.append(lambda: regularizer(variable))
-
- if hasattr(self, '_defer_regularizers') and self._defer_regularizers:
- # _defer_regularizers exists and is set to True if `build` was
- # invoked in `__call__`: deferring regularizer construction
- # prevents the regularizer from being created in an `init_scope`.
- self._get_regularizer_factories().append(regularizer_factory)
- else:
- regularizer_factory()
+ with ops.colocate_with(variable.op):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(variable)
+ if regularization is not None:
+ self.add_loss(regularization)
+ elif regularizer: # and initialization took place in an eager context
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ raise RuntimeError(
+ 'Partitioned variable regularization is not yet '
+ 'supported when executing eagerly. File a feature request'
+ 'if this is important to you.')
+ # Save a zero-argument lambda which runs the regularizer on the
+ # variable, to be executed when `Layer.losses` is requested.
+ # This makes losses responsive to variable updates when executing
+ # eagerly.
+ #
+ # TODO(akshayka): Do the same for graphs as well, so that losses
+ # collected in a while_loop can be run outside its control flow
+ # context and so that losses won't be swallowed up by graph functions
+ # (i.e., `.losses()` should always create regularizers).
+ self._losses.append(lambda: regularizer(variable))
if trainable:
self._trainable_weights.append(variable)
@@ -670,15 +696,7 @@ class Layer(object):
except AttributeError:
pass
input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
-
- # Signal to `add_variable` that regularizer construction should be
- # deferred.
- self._defer_regularizers = True
- with ops.init_scope():
- self.build(input_shapes)
- # Create any regularizers added by `build`.
- self._maybe_create_variable_regularizers()
- self._defer_regularizers = False
+ self.build(input_shapes)
try:
# Note: not all sub-classes of Layer call Layer.__init__ (especially
# the ones under tensorflow/python/keras). Hence we recompute this
@@ -1263,6 +1281,15 @@ class InputSpec(object):
self.min_ndim = min_ndim
self.axes = axes or {}
+ def __repr__(self):
+ spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
+ ('shape=' + str(self.shape)) if self.shape else '',
+ ('ndim=' + str(self.ndim)) if self.ndim else '',
+ ('max_ndim=' + str(self.max_ndim)) if self.max_ndim else '',
+ ('min_ndim=' + str(self.min_ndim)) if self.min_ndim else '',
+ ('axes=' + str(self.axes)) if self.axes else '']
+ return 'InputSpec(%s)' % ', '.join(x for x in spec if x)
+
class Node(object):
"""A `Node` describes the connectivity between two layers.
diff --git a/tensorflow/python/layers/maxout.py b/tensorflow/python/layers/maxout.py
index ed048845a0..20ce6c9770 100644
--- a/tensorflow/python/layers/maxout.py
+++ b/tensorflow/python/layers/maxout.py
@@ -31,15 +31,18 @@ from tensorflow.python.layers import base
def maxout(inputs, num_units, axis=-1, name=None):
"""Adds a maxout op from https://arxiv.org/abs/1302.4389
- "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville,
+ "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron
+ Courville,
Yoshua Bengio
- Usually the operation is performed in the filter/channel dimension. This can also be
+ Usually the operation is performed in the filter/channel dimension. This can
+ also be
used after fully-connected layers to reduce number of features.
Arguments:
inputs: Tensor input
- num_units: Specifies how many features will remain after maxout in the `axis` dimension
+ num_units: Specifies how many features will remain after maxout in the `axis`
+ dimension
(usually channel). This must be multiple of number of `axis`.
axis: The dimension where max pooling will be performed. Default is the
last dimension.
@@ -57,15 +60,18 @@ def maxout(inputs, num_units, axis=-1, name=None):
class MaxOut(base.Layer):
"""Adds a maxout op from https://arxiv.org/abs/1302.4389
- "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, Yoshua
+ "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron
+ Courville, Yoshua
Bengio
- Usually the operation is performed in the filter/channel dimension. This can also be
+ Usually the operation is performed in the filter/channel dimension. This can
+ also be
used after fully-connected layers to reduce number of features.
Arguments:
inputs: Tensor input
- num_units: Specifies how many features will remain after maxout in the `axis` dimension
+ num_units: Specifies how many features will remain after maxout in the
+ `axis` dimension
(usually channel).
This must be multiple of number of `axis`.
axis: The dimension where max pooling will be performed. Default is the
@@ -79,13 +85,8 @@ class MaxOut(base.Layer):
ValueError: if num_units is not multiple of number of features.
"""
- def __init__(self,
- num_units,
- axis=-1,
- name=None,
- **kwargs):
- super(MaxOut, self).__init__(
- name=name, trainable=False, **kwargs)
+ def __init__(self, num_units, axis=-1, name=None, **kwargs):
+ super(MaxOut, self).__init__(name=name, trainable=False, **kwargs)
self.axis = axis
self.num_units = num_units
@@ -95,8 +96,8 @@ class MaxOut(base.Layer):
num_channels = shape[self.axis]
if num_channels % self.num_units:
raise ValueError('number of features({}) is not '
- 'a multiple of num_units({})'
- .format(num_channels, self.num_units))
+ 'a multiple of num_units({})'.format(
+ num_channels, self.num_units))
shape[self.axis] = -1
shape += [num_channels // self.num_units]
@@ -104,6 +105,7 @@ class MaxOut(base.Layer):
for i in range(len(shape)):
if shape[i] is None:
shape[i] = gen_array_ops.shape(inputs)[i]
- outputs = math_ops.reduce_max(gen_array_ops.reshape(inputs, shape), -1, keep_dims=False)
+ outputs = math_ops.reduce_max(
+ gen_array_ops.reshape(inputs, shape), -1, keep_dims=False)
return outputs
diff --git a/tensorflow/python/layers/network.py b/tensorflow/python/layers/network.py
index ade57da411..0a5dd57621 100644
--- a/tensorflow/python/layers/network.py
+++ b/tensorflow/python/layers/network.py
@@ -575,6 +575,11 @@ class GraphNetwork(base.Layer):
raise ValueError('No such layer: ' + name)
@property
+ def stateful(self):
+ return any([(hasattr(layer, 'stateful') and layer.stateful)
+ for layer in self.layers])
+
+ @property
def updates(self):
"""Retrieve the network's updates.
@@ -586,6 +591,8 @@ class GraphNetwork(base.Layer):
Returns:
A list of update ops.
"""
+ if not self.trainable and not self.stateful:
+ return []
updates = []
for layer in self.layers:
if hasattr(layer, 'updates'):
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 55cae0bcbf..c9292184e6 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Gradients for operators defined in array_ops.py."""
from __future__ import absolute_import
@@ -131,8 +130,8 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
# extract the size of each input along the concat dimension
sizes = array_ops.squeeze(
array_ops.slice(
- array_ops.stack(
- sizes, axis=1), [non_neg_concat_dim, 0], [1, -1]))
+ array_ops.stack(sizes, axis=1), [non_neg_concat_dim, 0],
+ [1, -1]))
out_grads = array_ops.split(grad, sizes, non_neg_concat_dim)
else:
offset = gen_array_ops._concat_offset(non_neg_concat_dim, sizes)
@@ -167,8 +166,7 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
new_values = array_ops.slice(
grad.values, begin,
array_ops.concat([[-1], array_ops.slice(size, [1], [-1])], 0))
- out_grads.append(
- ops.IndexedSlices(new_values, grad.indices, size))
+ out_grads.append(ops.IndexedSlices(new_values, grad.indices, size))
# Lint complains begin = begin + ...
begin = math_ops.add(begin, size * mask)
else:
@@ -178,30 +176,33 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
for size in sizes:
size_concat_dim = array_ops.gather(size, non_neg_concat_dim)
if size_concat_dim.dtype != grad.indices.dtype:
- size_concat_dim = math_ops.cast(size_concat_dim,
- dtype=grad.indices.dtype)
+ size_concat_dim = math_ops.cast(
+ size_concat_dim, dtype=grad.indices.dtype)
end = start + size_concat_dim
# Compute the 1-D Tensor of indices relevant for this input.
indices_to_select = array_ops.squeeze(
- array_ops.where(math_ops.logical_and(grad.indices >= start,
- grad.indices < end)),
+ array_ops.where(
+ math_ops.logical_and(grad.indices >= start,
+ grad.indices < end)),
squeeze_dims=[1])
new_indices = array_ops.gather(grad.indices, indices_to_select) - start
new_values = array_ops.gather(grad.values, indices_to_select)
- out_grads.append(
- ops.IndexedSlices(new_values, new_indices, size))
+ out_grads.append(ops.IndexedSlices(new_values, new_indices, size))
start = end
else:
raise TypeError("Expected Tensor or IndexedSlices, got %s" % type(grad))
- return (out_grads + [None] if end_value_index <= dim_index
- else [None] + out_grads)
+ return (out_grads + [None]
+ if end_value_index <= dim_index else [None] + out_grads)
@ops.RegisterGradient("Concat")
def _ConcatGrad(op, grad):
return _ConcatGradHelper(
- op, grad, start_value_index=1, end_value_index=len(op.inputs),
+ op,
+ grad,
+ start_value_index=1,
+ end_value_index=len(op.inputs),
dim_index=0)
@@ -287,9 +288,13 @@ def _SplitGrad(op, *grads):
@ops.RegisterGradient("SplitV")
def _SplitVGrad(op, *grads):
returnval = array_ops.concat(list(grads), op.inputs[2])
- returnval = [returnval] + [None,] * (len(op.inputs) - 1)
+ returnval = [returnval] + [
+ None,
+ ] * (
+ len(op.inputs) - 1)
return returnval
+
ops.NotDifferentiable("Const")
@@ -334,9 +339,9 @@ def _MatrixSetDiagGrad(op, grad):
matrix_shape = array_ops.slice(grad_shape, [grad_rank - 2], [2])
min_dim = math_ops.reduce_min(matrix_shape)
diag_shape = array_ops.concat([batch_shape, [min_dim]], 0)
- grad_input = array_ops.matrix_set_diag(
- grad, array_ops.zeros(
- diag_shape, dtype=grad.dtype))
+ grad_input = array_ops.matrix_set_diag(grad,
+ array_ops.zeros(
+ diag_shape, dtype=grad.dtype))
grad_diag = array_ops.matrix_diag_part(grad)
return (grad_input, grad_diag)
@@ -444,8 +449,8 @@ def _GatherV2Grad(op, grad):
values_transpose = array_ops.transpose(values, transpose_dims)
num_segments = params_shape[axis]
- params_grad = math_ops.unsorted_segment_sum(
- values_transpose, indices, num_segments)
+ params_grad = math_ops.unsorted_segment_sum(values_transpose, indices,
+ num_segments)
# Inverts the above transpose by moving dimension 0 back to its original
# position.
@@ -536,13 +541,10 @@ def _ConjugateTransposeGrad(op, grad):
ops.NotDifferentiable("Shape")
-
ops.NotDifferentiable("ShapeN")
-
ops.NotDifferentiable("Rank")
-
ops.NotDifferentiable("Size")
@@ -590,6 +592,7 @@ def _PadGrad(op, grad):
else:
return x_grad, None
+
ops.RegisterGradient("Pad")(_PadGrad)
ops.RegisterGradient("PadV2")(_PadGrad)
@@ -625,30 +628,34 @@ def _ReverseV2Grad(op, grad):
def _SpaceToBatchGrad(op, grad):
# Its gradient is the opposite op: BatchToSpace.
block_size = op.get_attr("block_size")
- return [array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size),
- None]
+ return [
+ array_ops.batch_to_space(grad, op.inputs[1], block_size=block_size), None
+ ]
@ops.RegisterGradient("SpaceToBatchND")
def _SpaceToBatchNDGrad(op, grad):
# Its gradient is the opposite op: BatchToSpaceND.
- return [array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]),
- None, None]
+ return [
+ array_ops.batch_to_space_nd(grad, op.inputs[1], op.inputs[2]), None, None
+ ]
@ops.RegisterGradient("BatchToSpace")
def _BatchToSpaceGrad(op, grad):
# Its gradient is the opposite op: SpaceToBatch.
block_size = op.get_attr("block_size")
- return [array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size),
- None]
+ return [
+ array_ops.space_to_batch(grad, op.inputs[1], block_size=block_size), None
+ ]
@ops.RegisterGradient("BatchToSpaceND")
def _BatchToSpaceNDGrad(op, grad):
# Its gradient is the opposite op: SpaceToBatchND.
- return [array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]),
- None, None]
+ return [
+ array_ops.space_to_batch_nd(grad, op.inputs[1], op.inputs[2]), None, None
+ ]
@ops.RegisterGradient("SpaceToDepth")
@@ -712,30 +719,28 @@ def _QuantizeAndDequantizeV3Grad(_, grad):
def _ExtractImagePatchesGrad(op, grad):
batch_size, rows_in, cols_in, channels = [
- dim.value for dim in op.inputs[0].get_shape()
+ dim.value for dim in op.inputs[0].get_shape()
]
input_bhwc = array_ops.shape(op.inputs[0])
batch_size = input_bhwc[0]
channels = input_bhwc[3]
- _, rows_out, cols_out, _ = [
- dim.value for dim in op.outputs[0].get_shape()
- ]
- _, ksize_r, ksize_c, _ = op.get_attr('ksizes')
- _, stride_r, stride_h, _ = op.get_attr('strides')
- _, rate_r, rate_c, _ = op.get_attr('rates')
- padding = op.get_attr('padding')
+ _, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].get_shape()]
+ _, ksize_r, ksize_c, _ = op.get_attr("ksizes")
+ _, stride_r, stride_h, _ = op.get_attr("strides")
+ _, rate_r, rate_c, _ = op.get_attr("rates")
+ padding = op.get_attr("padding")
ksize_r_eff = ksize_r + (ksize_r - 1) * (rate_r - 1)
ksize_c_eff = ksize_c + (ksize_c - 1) * (rate_c - 1)
- if padding == b'SAME':
+ if padding == b"SAME":
rows_out = int(ceil(rows_in / stride_r))
cols_out = int(ceil(cols_in / stride_h))
pad_rows = ((rows_out - 1) * stride_r + ksize_r_eff - rows_in) // 2
pad_cols = ((cols_out - 1) * stride_h + ksize_c_eff - cols_in) // 2
- elif padding == b'VALID':
+ elif padding == b"VALID":
rows_out = int(ceil((rows_in - ksize_r_eff + 1) / stride_r))
cols_out = int(ceil((cols_in - ksize_c_eff + 1) / stride_h))
pad_rows = (rows_out - 1) * stride_r + ksize_r_eff - rows_in
@@ -744,10 +749,9 @@ def _ExtractImagePatchesGrad(op, grad):
pad_rows, pad_cols = max(0, pad_rows), max(0, pad_cols)
grad_expanded = array_ops.transpose(
- array_ops.reshape(grad, (batch_size, rows_out,
- cols_out, ksize_r, ksize_c, channels)),
- (1, 2, 3, 4, 0, 5)
- )
+ array_ops.reshape(
+ grad, (batch_size, rows_out, cols_out, ksize_r, ksize_c, channels)),
+ (1, 2, 3, 4, 0, 5))
grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))
row_steps = range(0, rows_out * stride_r, stride_r)
@@ -759,29 +763,21 @@ def _ExtractImagePatchesGrad(op, grad):
r_low, c_low = row_steps[i] - pad_rows, col_steps[j] - pad_cols
r_high, c_high = r_low + ksize_r_eff, c_low + ksize_c_eff
- idx.extend([(r * (cols_in) + c,
- i * (cols_out * ksize_r * ksize_c) +
- j * (ksize_r * ksize_c) +
- ri * (ksize_c) + ci)
+ idx.extend([(r * (cols_in) + c, i * (cols_out * ksize_r * ksize_c) + j *
+ (ksize_r * ksize_c) + ri * (ksize_c) + ci)
for (ri, r) in enumerate(range(r_low, r_high, rate_r))
for (ci, c) in enumerate(range(c_low, c_high, rate_c))
- if 0 <= r and r < rows_in and 0 <= c and c < cols_in
- ])
+ if 0 <= r and r < rows_in and 0 <= c and c < cols_in])
- sp_shape = (rows_in * cols_in,
- rows_out * cols_out * ksize_r * ksize_c)
+ sp_shape = (rows_in * cols_in, rows_out * cols_out * ksize_r * ksize_c)
sp_mat = sparse_tensor.SparseTensor(
- array_ops.constant(idx, dtype=ops.dtypes.int64),
- array_ops.ones((len(idx),), dtype=ops.dtypes.float32),
- sp_shape
- )
+ array_ops.constant(idx, dtype=ops.dtypes.int64),
+ array_ops.ones((len(idx),), dtype=ops.dtypes.float32), sp_shape)
jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)
- grad_out = array_ops.reshape(
- jac, (rows_in, cols_in, batch_size, channels)
- )
+ grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
grad_out = array_ops.transpose(grad_out, (2, 0, 1, 3))
return [grad_out]
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index d379eccc20..49191c647d 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Control Flow Operations.
See the @{$python/control_flow_ops} guide.
@@ -84,7 +83,6 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
-
# We override the 'tuple' for a control flow op, so we keep python's
# existing 'tuple' for later use in this module.
_basetuple = tuple
@@ -156,9 +154,10 @@ def Assert(condition, data, summarize=None, name=None):
xs = ops.convert_n_to_tensor(data)
data_str = [_summarize_eager(x, summarize) for x in xs]
raise errors.InvalidArgumentError(
- node_def=None, op=None,
- message="Expected '%s' to be true. Summarized data: %s" % (
- condition, "\n".join(data_str)))
+ node_def=None,
+ op=None,
+ message="Expected '%s' to be true. Summarized data: %s" %
+ (condition, "\n".join(data_str)))
return
with ops.name_scope(name, "Assert", [condition, data]) as name:
@@ -167,15 +166,15 @@ def Assert(condition, data, summarize=None, name=None):
# As a simple heuristic, we assume that string and int32 are
# on host to avoid the need to use cond. If it is not case,
# we will pay the price copying the tensor to host memory.
- return gen_logging_ops._assert(
- condition, data, summarize, name="Assert")
+ return gen_logging_ops._assert(condition, data, summarize, name="Assert")
else:
condition = ops.convert_to_tensor(condition, name="Condition")
+
def true_assert():
return gen_logging_ops._assert(
condition, data, summarize, name="Assert")
- guarded_assert = cond(
- condition, no_op, true_assert, name="AssertGuard")
+
+ guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
if context.in_eager_mode():
return
return guarded_assert.op
@@ -215,7 +214,7 @@ def _Identity(data, name=None):
def _NextIteration(data, name=None):
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
if isinstance(data, ops.Tensor):
- if data.dtype._is_ref_dtype: # pylint: disable=protected-access
+ if data.dtype._is_ref_dtype: # pylint: disable=protected-access
return ref_next_iteration(data, name=name)
else:
return next_iteration(data, name=name)
@@ -234,8 +233,13 @@ def _NextIteration(data, name=None):
return sparse_tensor.SparseTensor(indices, values, dense_shape)
-def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
- use_ref=True, use_input_shape=True, name=None):
+def _Enter(data,
+ frame_name,
+ is_constant=False,
+ parallel_iterations=10,
+ use_ref=True,
+ use_input_shape=True,
+ name=None):
"""Creates or finds a child frame, and makes `data` available to it.
The unique `frame_name` is used by the `Executor` to identify frames. If
@@ -257,35 +261,51 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
if isinstance(data, ops.Tensor):
if data.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access
- result = ref_enter(data, frame_name, is_constant, parallel_iterations,
- name=name)
+ result = ref_enter(
+ data, frame_name, is_constant, parallel_iterations, name=name)
else:
- result = enter(data, frame_name, is_constant, parallel_iterations,
- name=name)
+ result = enter(
+ data, frame_name, is_constant, parallel_iterations, name=name)
if use_input_shape:
result.set_shape(data.get_shape())
return result
else:
if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
raise TypeError("Type %s not supported" % type(data))
- values = _Enter(data.values, frame_name, is_constant,
- parallel_iterations=parallel_iterations,
- use_input_shape=use_input_shape, name=name)
- indices = enter(data.indices, frame_name, is_constant,
- parallel_iterations, name="indices")
+ values = _Enter(
+ data.values,
+ frame_name,
+ is_constant,
+ parallel_iterations=parallel_iterations,
+ use_input_shape=use_input_shape,
+ name=name)
+ indices = enter(
+ data.indices,
+ frame_name,
+ is_constant,
+ parallel_iterations,
+ name="indices")
if use_input_shape:
indices.set_shape(data.indices.get_shape())
if isinstance(data, ops.IndexedSlices):
dense_shape = data.dense_shape
if dense_shape is not None:
- dense_shape = enter(dense_shape, frame_name, is_constant,
- parallel_iterations, name="dense_shape")
+ dense_shape = enter(
+ dense_shape,
+ frame_name,
+ is_constant,
+ parallel_iterations,
+ name="dense_shape")
if use_input_shape:
dense_shape.set_shape(data.dense_shape.get_shape())
return ops.IndexedSlices(values, indices, dense_shape)
else:
- dense_shape = enter(data.dense_shape, frame_name, is_constant,
- parallel_iterations, name="dense_shape")
+ dense_shape = enter(
+ data.dense_shape,
+ frame_name,
+ is_constant,
+ parallel_iterations,
+ name="dense_shape")
if use_input_shape:
dense_shape.set_shape(data.dense_shape.get_shape())
return sparse_tensor.SparseTensor(indices, values, dense_shape)
@@ -444,8 +464,10 @@ def merge(inputs, name=None):
if any([inp is None for inp in inputs]):
raise ValueError("At least one of the merge inputs is None: %s" % inputs)
with ops.name_scope(name, "Merge", inputs) as name:
- inputs = [ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref=True)
- for inp in inputs]
+ inputs = [
+ ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref=True)
+ for inp in inputs
+ ]
if all([isinstance(v, ops.Tensor) for v in inputs]):
if all([v.dtype._is_ref_dtype for v in inputs]): # pylint: disable=protected-access
return gen_control_flow_ops._ref_merge(inputs, name)
@@ -475,6 +497,8 @@ def merge(inputs, name=None):
else:
dense_shape = None
return ops.IndexedSlices(values, indices, dense_shape), chosen_index
+
+
# pylint: enable=protected-access
@@ -488,7 +512,9 @@ def _convert_tensorarray_to_flow(tensor_or_tensor_array):
def _make_tensor_array(ta, t_or_flow):
# pylint: disable=protected-access
new_ta = tensor_array_ops.TensorArray(
- dtype=ta.dtype, handle=ta.handle, flow=t_or_flow,
+ dtype=ta.dtype,
+ handle=ta.handle,
+ flow=t_or_flow,
infer_shape=ta._infer_shape,
colocate_with_first_write_call=ta._colocate_with_first_write_call)
new_ta._colocate_with = ta._colocate_with
@@ -500,13 +526,13 @@ def _make_tensor_array(ta, t_or_flow):
def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
if len(tensors_or_tensorarrays) != len(tensors_or_flows):
raise ValueError(
- "Lengths of original Tensor list and new list do not match: %d vs. %d"
- % (len(tensors_or_tensorarrays), len(tensors_or_flows)))
+ "Lengths of original Tensor list and new list do not match: %d vs. %d" %
+ (len(tensors_or_tensorarrays), len(tensors_or_flows)))
return [
_make_tensor_array(ta, t_or_flow)
- if isinstance(ta, tensor_array_ops.TensorArray)
- else t_or_flow
- for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)]
+ if isinstance(ta, tensor_array_ops.TensorArray) else t_or_flow
+ for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
+ ]
def _ShapeLessThanOrEqual(shape1, shape2):
@@ -545,8 +571,8 @@ def _SetShapeInvariants(input_vars, enter_vars, shapes):
raise ValueError(
"The shape invariant specified for %s is not compatible with "
"the initial shape of the loop variable. It enters the loop "
- "with shape %s, but the specified shape invariant is %s."
- % (inp.name, inp.get_shape(), shape))
+ "with shape %s, but the specified shape invariant is %s." %
+ (inp.name, inp.get_shape(), shape))
var.set_shape(shape)
else:
if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
@@ -557,8 +583,8 @@ def _SetShapeInvariants(input_vars, enter_vars, shapes):
"The shape invariant specified for %s is not compatible with "
"the initial shape of the values tensor of this IndexedSlices. "
"It enters the loop with shape %s, but the specified shape "
- "invariant is %s."
- % (inp.values.name, inp.values.get_shape(), shape))
+ "invariant is %s." % (inp.values.name, inp.values.get_shape(),
+ shape))
var.values.set_shape(shape)
var.indices.set_shape(tensor_shape.TensorShape([shape[0]]))
if var.dense_shape is not None:
@@ -569,8 +595,8 @@ def _SetShapeInvariants(input_vars, enter_vars, shapes):
"The shape invariant specified for %s is not compatible with "
"the initial shape of the shape tensor of this SparseTensor. "
"It enters the loop with shape %s, but the specified shape "
- "invariant is %s."
- % (inp.dense_shape.name, inp.dense_shape.get_shape(), shape))
+ "invariant is %s." % (inp.dense_shape.name,
+ inp.dense_shape.get_shape(), shape))
var.values.set_shape(tensor_shape.TensorShape([None]))
var.indices.set_shape(tensor_shape.TensorShape([None, shape.ndims]))
var.dense_shape.set_shape(shape)
@@ -599,8 +625,8 @@ def _EnforceShapeInvariant(merge_var, next_var):
"The shape for %s is not an invariant for the loop. It enters "
"the loop with shape %s, but has shape %s after one iteration. "
"Provide shape invariants using either the `shape_invariants` "
- "argument of tf.while_loop or set_shape() on the loop variables."
- % (merge_var.name, m_shape, n_shape))
+ "argument of tf.while_loop or set_shape() on the loop variables." %
+ (merge_var.name, m_shape, n_shape))
else:
if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
raise TypeError("Type %s not supported" % type(var))
@@ -623,9 +649,9 @@ def _EnforceShapeInvariant(merge_var, next_var):
"the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "
"after one iteration. Provide shape invariants using either the "
"`shape_invariants` argument of tf.while_loop or set_shape() "
- "on the loop variables."
- % (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
- n_values_shape, n_indices_shape, n_shape_shape))
+ "on the loop variables." %
+ (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
+ n_values_shape, n_indices_shape, n_shape_shape))
else:
m_values_shape = merge_var.values.get_shape()
m_indices_shape = merge_var.indices.get_shape()
@@ -637,12 +663,12 @@ def _EnforceShapeInvariant(merge_var, next_var):
not _ShapeLessThanOrEqual(n_indices_shape, m_indices_shape) or
not _ShapeLessThanOrEqual(n_shape_shape, m_shape_shape)):
raise ValueError(
- "The shape for %s is not an invariant for the loop. It enters "
- "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "
- "after one iteration. Provide shape invariants using either "
- "the `shape_invariants` argument of tf.while_loop or set_shape() "
- "on the loop variables."
- % (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
+ "The shape for %s is not an invariant for the loop. It enters "
+ "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "
+ "after one iteration. Provide shape invariants using either "
+ "the `shape_invariants` argument of tf.while_loop or set_shape() "
+ "on the loop variables." %
+ (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
n_values_shape, n_indices_shape, n_shape_shape))
@@ -657,7 +683,7 @@ def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
# the types don't match.
# TODO(skyewm): call this for other cases below (needs testing)
_EnforceShapeInvariant(m, v)
- m.op._update_input(1, v) # pylint: disable=protected-access
+ m.op._update_input(1, v) # pylint: disable=protected-access
elif isinstance(m, ops.IndexedSlices):
# pylint: disable=protected-access
v = math_ops._as_indexed_slices(v, optimize=False)
@@ -720,8 +746,7 @@ def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
raise ValueError(
"Cannot create a gradient accumulator for tensor '%s' inside "
"XLA while_loop because maximum_iterations was not passed to "
- "the tf.while_loop call ('%s')."
- % (value_name, while_ctxt.name))
+ "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))
# pylint: disable=protected-access
max_iter_ctxt = max_iter.op._get_control_flow_context()
@@ -742,9 +767,9 @@ def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
"while_loop. maximum_iterations tensor '%s' for while_loop context "
"'%s' must be statically known (e.g. a constant value or known "
"shape dimension), or be defined at or outside the while loop "
- "context '%s' (currently defined in '%s')." % (
- value_name, max_iter.name, while_ctxt.name,
- curr_ctxt_name, max_iter_ctxt.name))
+ "context '%s' (currently defined in '%s')." %
+ (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name,
+ max_iter_ctxt.name))
max_size *= const_max_iter
# Find the next outer WhileContext (or stop if we reach the
@@ -808,9 +833,11 @@ class GradLoopState(object):
outer_forward_ctxt = forward_ctxt.outer_context
# Add the forward loop counter.
- if outer_forward_ctxt: outer_forward_ctxt.Enter()
+ if outer_forward_ctxt:
+ outer_forward_ctxt.Enter()
cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
- if outer_forward_ctxt: outer_forward_ctxt.Exit()
+ if outer_forward_ctxt:
+ outer_forward_ctxt.Exit()
self._forward_context = forward_ctxt
self._forward_index = forward_index
@@ -835,7 +862,8 @@ class GradLoopState(object):
real_cnt, outer_grad_state)
outer_grad_ctxt.Exit()
else:
- if outer_forward_ctxt: outer_forward_ctxt.Enter()
+ if outer_forward_ctxt:
+ outer_forward_ctxt.Enter()
self._grad_context = WhileContext(
maximum_iterations=forward_ctxt.maximum_iterations,
parallel_iterations=forward_ctxt.parallel_iterations,
@@ -845,7 +873,8 @@ class GradLoopState(object):
grad_state=self)
self._grad_index = self._grad_context.AddBackpropLoopCounter(
cnt, outer_grad_state)
- if outer_forward_ctxt: outer_forward_ctxt.Exit()
+ if outer_forward_ctxt:
+ outer_forward_ctxt.Exit()
@property
def outer_grad_state(self):
@@ -973,7 +1002,8 @@ class GradLoopState(object):
# curr_ctxt is the context that tf.gradients was called in.
curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
with ops.control_dependencies(None):
- if curr_ctxt: curr_ctxt.Enter()
+ if curr_ctxt:
+ curr_ctxt.Enter()
with ops.colocate_with(value):
# We only need to pass maximum_iterations to the stack if
# we're inside an XLA context.
@@ -984,11 +1014,10 @@ class GradLoopState(object):
value, self.forward_context)
# pylint: disable=protected-access
acc = gen_data_flow_ops._stack_v2(
- max_size=max_size,
- elem_type=value.dtype.base_dtype,
- name="f_acc")
+ max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
# pylint: enable=protected-access
- if curr_ctxt: curr_ctxt.Exit()
+ if curr_ctxt:
+ curr_ctxt.Exit()
# Make acc available in the forward context.
enter_acc = self.forward_context.AddValue(acc)
@@ -1009,8 +1038,7 @@ class GradLoopState(object):
else:
# value is in a cond context within the forward context.
if not isinstance(value_ctxt, CondContext):
- raise TypeError(
- "value_ctxt is not a CondContext: %s" % value_ctxt)
+ raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
if dead_branch:
# The special case for creating a zero tensor for a dead
# branch of a switch. See ControlFlowState.ZerosLike().
@@ -1134,8 +1162,8 @@ class GradLoopState(object):
if real_value is None:
# Add the stack pop op in the grad context.
- real_value = cur_grad_state.AddBackpropAccumulatedValue(history_value,
- cur_value)
+ real_value = cur_grad_state.AddBackpropAccumulatedValue(
+ history_value, cur_value)
if cur_grad_state != self:
real_value = self._grad_context.AddValue(real_value)
self._history_map[value.name] = real_value
@@ -1154,7 +1182,7 @@ class ControlFlowState(object):
"""Maintain the mapping from the loops to their grad states."""
def __init__(self):
- self._map = {} # maps forward loop context to GradLoopState
+ self._map = {} # maps forward loop context to GradLoopState
def GetGradState(self, op, before):
"""Return the grad state for this op if it's in a forward loop context."""
@@ -1318,7 +1346,8 @@ class ControlFlowState(object):
Returns:
A zero tensor of the same shape of op.outputs[index].
"""
- if util.IsLoopSwitch(op): return None
+ if util.IsLoopSwitch(op):
+ return None
dead_branch = util.IsSwitch(op)
forward_ctxt = _GetWhileContext(op)
grad_state = self._map.get(forward_ctxt)
@@ -1361,8 +1390,8 @@ class ControlFlowState(object):
grad_state.grad_context.Enter()
# Create a zero tensor with the right shape.
- shape = grad_state.AddBackpropAccumulatedValue(
- history_zeros_shape, zeros_shape, dead_branch)
+ shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
+ zeros_shape, dead_branch)
result = array_ops.zeros(shape, val.dtype)
return result
@@ -1393,12 +1422,14 @@ class ControlFlowState(object):
else:
# Create a zeros in the outer grad context.
outer_grad_ctxt = grad_state.grad_context.outer_context
- if outer_grad_ctxt: outer_grad_ctxt.Enter()
+ if outer_grad_ctxt:
+ outer_grad_ctxt.Enter()
enter_grad_op = b_merge.op.inputs[0].op
enter_grad = enter_grad_op.inputs[0]
grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
grad_val = array_ops.zeros(grad_shape)
- if outer_grad_ctxt: outer_grad_ctxt.Exit()
+ if outer_grad_ctxt:
+ outer_grad_ctxt.Exit()
# Use the zeros for iterations > 0.
grad_state.grad_context.Enter()
next_grad_val = _NextIteration(grad_val)
@@ -1470,8 +1501,7 @@ class ControlFlowContext(object):
self._outer_context = ops.get_default_graph()._get_control_flow_context()
self._context_stack = []
if values_def:
- self._init_values_from_proto(values_def,
- import_scope=import_scope)
+ self._init_values_from_proto(values_def, import_scope=import_scope)
else:
# Values that have been already seen in this context.
self._values = set()
@@ -1532,19 +1562,16 @@ class ControlFlowContext(object):
"""
values_def = control_flow_pb2.ValuesDef()
values_def.values.extend(
- [ops.strip_name_scope(v, export_scope)
- for v in sorted(self._values)])
+ [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
for k, v in self._external_values.items():
k = ops.strip_name_scope(k, export_scope)
- values_def.external_values[k] = ops.strip_name_scope(
- v.name, export_scope)
+ values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
return values_def
@staticmethod
def _from_proto(values_def, import_scope=None):
"""Returns a `ControlFlowContext` created from `values_def`."""
- return ControlFlowContext(values_def=values_def,
- import_scope=import_scope)
+ return ControlFlowContext(values_def=values_def, import_scope=import_scope)
def AddName(self, name):
self._values.add(name)
@@ -1599,6 +1626,7 @@ class ControlFlowContext(object):
op._remove_all_control_inputs()
op._add_control_inputs(internal_control_inputs)
return internal_control_inputs
+
# pylint: enable=protected-access
def AddInnerOp(self, op):
@@ -1626,8 +1654,13 @@ class ControlFlowContext(object):
class CondContext(ControlFlowContext):
"""The context for the conditional construct."""
- def __init__(self, pred=None, pivot=None, branch=None,
- name="cond_text", context_def=None, import_scope=None):
+ def __init__(self,
+ pred=None,
+ pivot=None,
+ branch=None,
+ name="cond_text",
+ context_def=None,
+ import_scope=None):
"""Creates a `CondContext`.
Args:
@@ -1647,9 +1680,9 @@ class CondContext(ControlFlowContext):
else:
# Initializes the default fields.
ControlFlowContext.__init__(self)
- self._pred = pred # The boolean tensor for the cond predicate
- self._pivot = pivot # The predicate tensor in this branch
- self._branch = branch # 0 or 1 representing this branch
+ self._pred = pred # The boolean tensor for the cond predicate
+ self._pivot = pivot # The predicate tensor in this branch
+ self._branch = branch # 0 or 1 representing this branch
# Values considered to have been already seen in this context.
self._values.add(pred.name)
@@ -1665,15 +1698,14 @@ class CondContext(ControlFlowContext):
assert isinstance(context_def, control_flow_pb2.CondContextDef)
# Create from context_def.
g = ops.get_default_graph()
- self._name = ops.prepend_name_scope(
- context_def.context_name, import_scope)
- self._pred = g.as_graph_element(ops.prepend_name_scope(
- context_def.pred_name, import_scope))
- self._pivot = g.as_graph_element(ops.prepend_name_scope(
- context_def.pivot_name, import_scope))
+ self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
+ self._pred = g.as_graph_element(
+ ops.prepend_name_scope(context_def.pred_name, import_scope))
+ self._pivot = g.as_graph_element(
+ ops.prepend_name_scope(context_def.pivot_name, import_scope))
self._branch = context_def.branch
- super(CondContext, self).__init__(values_def=context_def.values_def,
- import_scope=import_scope)
+ super(CondContext, self).__init__(
+ values_def=context_def.values_def, import_scope=import_scope)
@property
def pred(self):
@@ -1711,18 +1743,16 @@ class CondContext(ControlFlowContext):
Returns:
A `CondContextDef` protocol buffer.
"""
- if (export_scope is None or
- self.name.startswith(export_scope)):
+ if (export_scope is None or self.name.startswith(export_scope)):
context_def = control_flow_pb2.CondContextDef()
- context_def.context_name = ops.strip_name_scope(
- self.name, export_scope)
- context_def.pred_name = ops.strip_name_scope(
- self._pred.name, export_scope)
- context_def.pivot_name = ops.strip_name_scope(
- self._pivot.name, export_scope)
+ context_def.context_name = ops.strip_name_scope(self.name, export_scope)
+ context_def.pred_name = ops.strip_name_scope(self._pred.name,
+ export_scope)
+ context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
+ export_scope)
context_def.branch = self._branch
- context_def.values_def.MergeFrom(super(CondContext, self)._to_proto(
- export_scope))
+ context_def.values_def.MergeFrom(
+ super(CondContext, self)._to_proto(export_scope))
return context_def
else:
@@ -1731,8 +1761,7 @@ class CondContext(ControlFlowContext):
@staticmethod
def from_proto(context_def, import_scope=None):
"""Returns a `CondContext` object created from `context_def`."""
- return CondContext(context_def=context_def,
- import_scope=import_scope)
+ return CondContext(context_def=context_def, import_scope=import_scope)
def AddValue(self, val):
"""Add `val` to the current context and its outer context recursively."""
@@ -1846,8 +1875,8 @@ class CondContext(ControlFlowContext):
if original_result is None:
return no_op(), None
else:
- original_result = nest.map_structure(
- array_ops.identity, original_result)
+ original_result = nest.map_structure(array_ops.identity,
+ original_result)
if original_result is None:
return None, None
@@ -1871,11 +1900,15 @@ def _UnpackIfSingleton(res):
# pylint: disable=g-doc-args
@tf_export("cond")
@deprecation.deprecated_args(
- None,
- "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
+ None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
"fn1", "fn2")
-def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,
- fn1=None, fn2=None):
+def cond(pred,
+ true_fn=None,
+ false_fn=None,
+ strict=False,
+ name=None,
+ fn1=None,
+ fn2=None):
"""Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
`true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
@@ -2044,6 +2077,8 @@ def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,
if not strict:
merges = _UnpackIfSingleton(merges)
return merges
+
+
# pylint: enable=g-doc-args
# pylint: enable=redefined-outer-name
@@ -2139,8 +2174,7 @@ class WhileContext(ControlFlowContext):
assert isinstance(context_def, control_flow_pb2.WhileContextDef)
# Create from context_def.
g = ops.get_default_graph()
- self._name = ops.prepend_name_scope(
- context_def.context_name, import_scope)
+ self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
if context_def.maximum_iterations_name:
self._maximum_iterations = g.as_graph_element(
ops.prepend_name_scope(context_def.maximum_iterations_name,
@@ -2150,25 +2184,27 @@ class WhileContext(ControlFlowContext):
self._parallel_iterations = context_def.parallel_iterations
self._back_prop = context_def.back_prop
self._swap_memory = context_def.swap_memory
- self._pivot_for_pred = g.as_graph_element(ops.prepend_name_scope(
- context_def.pivot_for_pred_name, import_scope))
+ self._pivot_for_pred = g.as_graph_element(
+ ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
# We use this node to control constants created by the body lambda.
- self._pivot_for_body = g.as_graph_element(ops.prepend_name_scope(
- context_def.pivot_for_body_name, import_scope))
+ self._pivot_for_body = g.as_graph_element(
+ ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
# The boolean tensor for loop termination condition. Used in code
# generation for gradient computation.
self._pivot = g.as_graph_element(
ops.prepend_name_scope(context_def.pivot_name, import_scope))
# The list of exit tensors for loop variables.
- self._loop_exits = [g.as_graph_element(
- ops.prepend_name_scope(exit_name, import_scope))
- for exit_name in context_def.loop_exit_names]
+ self._loop_exits = [
+ g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
+ for exit_name in context_def.loop_exit_names
+ ]
# The list of enter tensors for loop variables.
- self._loop_enters = [g.as_graph_element(
- ops.prepend_name_scope(enter_name, import_scope))
- for enter_name in context_def.loop_enter_names]
- super(WhileContext, self).__init__(values_def=context_def.values_def,
- import_scope=import_scope)
+ self._loop_enters = [
+ g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
+ for enter_name in context_def.loop_enter_names
+ ]
+ super(WhileContext, self).__init__(
+ values_def=context_def.values_def, import_scope=import_scope)
@property
def maximum_iterations(self):
@@ -2219,11 +2255,9 @@ class WhileContext(ControlFlowContext):
Returns:
A `WhileContextDef` protocol buffer.
"""
- if (export_scope is None or
- self.name.startswith(export_scope)):
+ if (export_scope is None or self.name.startswith(export_scope)):
context_def = control_flow_pb2.WhileContextDef()
- context_def.context_name = ops.strip_name_scope(
- self.name, export_scope)
+ context_def.context_name = ops.strip_name_scope(self.name, export_scope)
context_def.parallel_iterations = self._parallel_iterations
if self._maximum_iterations is not None:
context_def.maximum_iterations_name = ops.strip_name_scope(
@@ -2234,17 +2268,16 @@ class WhileContext(ControlFlowContext):
self._pivot_for_pred.name, export_scope)
context_def.pivot_for_body_name = ops.strip_name_scope(
self._pivot_for_body.name, export_scope)
- context_def.pivot_name = ops.strip_name_scope(
- self._pivot.name, export_scope)
- context_def.loop_exit_names.extend(
- [ops.strip_name_scope(l.name, export_scope)
- for l in self._loop_exits])
- context_def.loop_enter_names.extend(
- [ops.strip_name_scope(l.name, export_scope)
- for l in self._loop_enters])
+ context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
+ export_scope)
+ context_def.loop_exit_names.extend([
+ ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits
+ ])
+ context_def.loop_enter_names.extend([
+ ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
+ ])
context_def.values_def.MergeFrom(
- super(WhileContext, self)._to_proto(
- export_scope=export_scope))
+ super(WhileContext, self)._to_proto(export_scope=export_scope))
return context_def
else:
@@ -2261,8 +2294,7 @@ class WhileContext(ControlFlowContext):
Returns:
A `WhileContext` Python object.
"""
- return WhileContext(context_def=context_def,
- import_scope=import_scope)
+ return WhileContext(context_def=context_def, import_scope=import_scope)
def GetWhileContext(self):
return self
@@ -2299,8 +2331,11 @@ class WhileContext(ControlFlowContext):
result = self._outer_context.AddValue(val)
# Create an Enter to make `result` known to this loop context.
with ops.control_dependencies(None):
- enter = _Enter(result, self._name, is_constant=True,
- parallel_iterations=self._parallel_iterations)
+ enter = _Enter(
+ result,
+ self._name,
+ is_constant=True,
+ parallel_iterations=self._parallel_iterations)
enter.graph.prevent_feeding(enter)
if self._outer_context:
self._outer_context.AddInnerOp(enter.op)
@@ -2378,6 +2413,7 @@ class WhileContext(ControlFlowContext):
def _MaybeAddControlDependency(self, op):
"""Add a control input to the op if it only depends on loop invariants."""
+
def _IsOpFree(op):
"""Determines if `op` needs a control dependency."""
if op.control_inputs:
@@ -2390,6 +2426,7 @@ class WhileContext(ControlFlowContext):
if not util.IsLoopConstantEnter(x.op):
return False
return True
+
if _IsOpFree(op):
# pylint: disable=protected-access
op._add_control_input(self.GetControlPivot().op)
@@ -2423,9 +2460,12 @@ class WhileContext(ControlFlowContext):
self.Enter()
self.AddName(n.name)
- enter_n = _Enter(n, self._name, is_constant=False,
- parallel_iterations=self._parallel_iterations,
- name="f_count")
+ enter_n = _Enter(
+ n,
+ self._name,
+ is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="f_count")
self.loop_enters.append(enter_n)
merge_n = merge([enter_n, enter_n])[0]
@@ -2465,9 +2505,12 @@ class WhileContext(ControlFlowContext):
self.Enter()
self.AddName(count.name)
- enter_count = _Enter(count, self._name, is_constant=False,
- parallel_iterations=self._parallel_iterations,
- name="b_count")
+ enter_count = _Enter(
+ count,
+ self._name,
+ is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="b_count")
self.loop_enters.append(enter_count)
merge_count = merge([enter_count, enter_count])[0]
@@ -2525,9 +2568,11 @@ class WhileContext(ControlFlowContext):
# without running any iterations.
shape = grad.get_shape()
if shape.is_fully_defined():
- if self.outer_context: self.outer_context.Enter()
+ if self.outer_context:
+ self.outer_context.Enter()
acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
- if self.outer_context: self.outer_context.Exit()
+ if self.outer_context:
+ self.outer_context.Exit()
else:
value = op.inputs[0]
if (isinstance(self.outer_context, WhileContext) and
@@ -2546,16 +2591,21 @@ class WhileContext(ControlFlowContext):
acc = array_ops.zeros(real_shape, grad.dtype)
self.outer_context.Exit()
else:
- if self.outer_context: self.outer_context.Enter()
+ if self.outer_context:
+ self.outer_context.Enter()
zeros_shape = array_ops.shape_internal(value, optimize=False)
acc = array_ops.zeros(zeros_shape, grad.dtype)
- if self.outer_context: self.outer_context.Exit()
+ if self.outer_context:
+ self.outer_context.Exit()
self.Enter()
self.AddName(acc.name)
- enter_acc = _Enter(acc, self._name, is_constant=False,
- parallel_iterations=self._parallel_iterations,
- name="b_acc")
+ enter_acc = _Enter(
+ acc,
+ self._name,
+ is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ name="b_acc")
self.loop_enters.append(enter_acc)
merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
@@ -2588,14 +2638,17 @@ class WhileContext(ControlFlowContext):
dense_shape = grad.dense_shape
self.Exit()
- if self.outer_context: self.outer_context.Enter()
+ if self.outer_context:
+ self.outer_context.Enter()
if values.get_shape().is_fully_defined():
values_shape = tensor_shape.TensorShape(
[tensor_shape.Dimension(1)] + values.get_shape().dims[1:])
- if self.outer_context: self.outer_context.Enter()
- values_acc = constant_op.constant(0, values.dtype, shape=values_shape,
- name="b_acc")
- if self.outer_context: self.outer_context.Exit()
+ if self.outer_context:
+ self.outer_context.Enter()
+ values_acc = constant_op.constant(
+ 0, values.dtype, shape=values_shape, name="b_acc")
+ if self.outer_context:
+ self.outer_context.Exit()
else:
values_shape = _resource_safe_shape(op.inputs[0])[1:]
values_shape = array_ops.concat([[1], values_shape], 0)
@@ -2604,16 +2657,19 @@ class WhileContext(ControlFlowContext):
shape_acc = None
if dense_shape is not None:
if dense_shape.get_shape().is_fully_defined():
- if self.outer_context: self.outer_context.Enter()
- shape_acc = constant_op.constant(0, dense_shape.dtype,
- shape=dense_shape.get_shape())
- if self.outer_context: self.outer_context.Exit()
+ if self.outer_context:
+ self.outer_context.Enter()
+ shape_acc = constant_op.constant(
+ 0, dense_shape.dtype, shape=dense_shape.get_shape())
+ if self.outer_context:
+ self.outer_context.Exit()
else:
shape_acc = array_ops.zeros_like(
array_ops.shape_internal(op.inputs[0], optimize=False),
optimize=False)
- if self.outer_context: self.outer_context.Exit()
+ if self.outer_context:
+ self.outer_context.Exit()
self.Enter()
self.AddName(values_acc.name)
@@ -2626,9 +2682,15 @@ class WhileContext(ControlFlowContext):
# Set use_input_shape=False since the accumulator tensors will grow in
# size. If use_input_shape=True, the _update_input call below will result in
# incompatible shapes.
- enter_acc = [_Enter(x, self._name, is_constant=False,
- parallel_iterations=self._parallel_iterations,
- use_input_shape=False, name="b_acc") for x in init_acc]
+ enter_acc = [
+ _Enter(
+ x,
+ self._name,
+ is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ use_input_shape=False,
+ name="b_acc") for x in init_acc
+ ]
# Manually set appropriate partial shapes.
enter_acc[0].set_shape([None])
if values_acc.shape.dims is not None:
@@ -2645,8 +2707,7 @@ class WhileContext(ControlFlowContext):
]
if shape_acc is not None:
# For the shape we just keep the maximum
- acc_indexed_slices.append(
- math_ops.maximum(dense_shape, switch_acc[2][1]))
+ acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))
next_acc = [_NextIteration(x) for x in acc_indexed_slices]
for xm, xn in zip(merge_acc, next_acc):
@@ -2657,7 +2718,8 @@ class WhileContext(ControlFlowContext):
self.ExitResult(exit_acc)
return ops.IndexedSlices(
- indices=exit_acc[0], values=exit_acc[1],
+ indices=exit_acc[0],
+ values=exit_acc[1],
dense_shape=exit_acc[2] if shape_acc is not None else None)
def _InitializeValues(self, values):
@@ -2690,10 +2752,14 @@ class WhileContext(ControlFlowContext):
if self._outer_context:
real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
with ops.control_dependencies(None):
- enter_vars = [_Enter(x, self._name, is_constant=False,
- parallel_iterations=self._parallel_iterations,
- use_input_shape=(shape_invariants is None))
- for x in real_vars]
+ enter_vars = [
+ _Enter(
+ x,
+ self._name,
+ is_constant=False,
+ parallel_iterations=self._parallel_iterations,
+ use_input_shape=(shape_invariants is None)) for x in real_vars
+ ]
for x in enter_vars:
x.graph.prevent_feeding(x)
if self._outer_context:
@@ -2754,11 +2820,13 @@ class WhileContext(ControlFlowContext):
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
summary_ref[:] = pre_summaries
with ops.control_dependencies(new_summaries):
+
def map_fn(x):
# TODO(apassos) figure out how to trigger with tensor arrays as well
if isinstance(x, tensor_array_ops.TensorArray):
return x
return array_ops.identity(x)
+
body_result = nest.map_structure(map_fn, body_result)
# Compare the structure types of input and output of body.
@@ -2815,8 +2883,7 @@ class WhileContext(ControlFlowContext):
packed_exit_vars = nest.pack_sequence_as(
structure=original_body_result,
flat_sequence=exit_vars_with_tensor_arrays)
- return (packed_exit_vars[0] if len(exit_vars) == 1
- else packed_exit_vars)
+ return (packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars)
def _FixControlInputsAndContext(self, enters):
graph = ops.get_default_graph()
@@ -2834,8 +2901,9 @@ class WhileContext(ControlFlowContext):
for x in xs:
inp_op = x.op.inputs[0].op
control_inputs = graph._control_dependencies_for_inputs([inp_op])
- outer_control_inputs = [op for op in control_inputs
- if self._IsInOuterContext(op)]
+ outer_control_inputs = [
+ op for op in control_inputs if self._IsInOuterContext(op)
+ ]
x.op._set_control_flow_context(self)
x.op._add_control_inputs(outer_control_inputs)
graph._record_op_seen_by_control_dependencies(x.op)
@@ -2847,9 +2915,15 @@ class WhileContext(ControlFlowContext):
# pylint: disable=redefined-outer-name
@tf_export("while_loop")
-def while_loop(cond, body, loop_vars, shape_invariants=None,
- parallel_iterations=10, back_prop=True, swap_memory=False,
- name=None, maximum_iterations=None):
+def while_loop(cond,
+ body,
+ loop_vars,
+ shape_invariants=None,
+ parallel_iterations=10,
+ back_prop=True,
+ swap_memory=False,
+ name=None,
+ maximum_iterations=None):
"""Repeat `body` while the condition `cond` is true.
`cond` is a callable returning a boolean scalar tensor. `body` is a callable
@@ -3024,6 +3098,8 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
return result[1]
else:
return result
+
+
# pylint: enable=redefined-outer-name
@@ -3051,8 +3127,9 @@ def _AsTensorList(x, p):
if isinstance(v, ops.Tensor):
l.append(array_ops.identity(v))
else:
- l.append(ops.IndexedSlices(array_ops.identity(v.values),
- array_ops.identity(v.indices)))
+ l.append(
+ ops.IndexedSlices(
+ array_ops.identity(v.values), array_ops.identity(v.indices)))
return l
@@ -3062,8 +3139,7 @@ def _CheckResults(a, b):
for x, y in zip(a, b):
assert x.dtype == y.dtype, (
"Values returned by a() [%s] and b() [%s] must have "
- "the same type: %s, %s." %
- (x.name, y.name, x.dtype.name, y.dtype.name))
+ "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))
def with_dependencies(dependencies, output_tensor, name=None):
@@ -3099,9 +3175,9 @@ def with_dependencies(dependencies, output_tensor, name=None):
if isinstance(output_tensor, ops.Tensor):
return _Identity(output_tensor, name=name)
else:
- return ops.IndexedSlices(_Identity(output_tensor.values, name=name),
- output_tensor.indices,
- output_tensor.dense_shape)
+ return ops.IndexedSlices(
+ _Identity(output_tensor.values, name=name), output_tensor.indices,
+ output_tensor.dense_shape)
def _GroupControlDeps(dev, deps, name=None):
@@ -3173,6 +3249,7 @@ def group(*inputs, **kwargs):
def device_key(dev):
"""A sort key that allows None to be compared to strings."""
return "" if dev is None else dev
+
for dev in sorted(six.iterkeys(ops_on_device), key=device_key):
deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
@@ -3463,12 +3540,14 @@ class XLAControlFlowContext(ControlFlowContext):
return x
-ops.register_proto_function(ops.GraphKeys.COND_CONTEXT,
- proto_type=control_flow_pb2.CondContextDef,
- to_proto=CondContext.to_proto,
- from_proto=CondContext.from_proto)
+ops.register_proto_function(
+ ops.GraphKeys.COND_CONTEXT,
+ proto_type=control_flow_pb2.CondContextDef,
+ to_proto=CondContext.to_proto,
+ from_proto=CondContext.from_proto)
-ops.register_proto_function(ops.GraphKeys.WHILE_CONTEXT,
- proto_type=control_flow_pb2.WhileContextDef,
- to_proto=WhileContext.to_proto,
- from_proto=WhileContext.from_proto)
+ops.register_proto_function(
+ ops.GraphKeys.WHILE_CONTEXT,
+ proto_type=control_flow_pb2.WhileContextDef,
+ to_proto=WhileContext.to_proto,
+ from_proto=WhileContext.from_proto)
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 34f0bf7b78..95e45bff06 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#==============================================================================
-
"""Data Flow Operations."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
@@ -40,6 +39,7 @@ from tensorflow.python.ops import math_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
from tensorflow.python.util.tf_export import tf_export
+
# pylint: enable=wildcard-import
@@ -54,17 +54,19 @@ def _as_type_list(dtypes):
return list(dtypes)
-def _as_shape_list(shapes, dtypes, unknown_dim_allowed=False,
+def _as_shape_list(shapes,
+ dtypes,
+ unknown_dim_allowed=False,
unknown_rank_allowed=False):
"""Convert shapes to a list of tuples of int (or None)."""
del dtypes
if unknown_dim_allowed:
- if (not isinstance(shapes, collections.Sequence)
- or not shapes
- or any(shape is None or isinstance(shape, int) for shape in shapes)):
+ if (not isinstance(shapes, collections.Sequence) or not shapes or
+ any(shape is None or isinstance(shape, int) for shape in shapes)):
raise ValueError(
"When providing partial shapes, a list of shapes must be provided.")
- if shapes is None: return None
+ if shapes is None:
+ return None
if isinstance(shapes, tensor_shape.TensorShape):
shapes = [shapes]
if not isinstance(shapes, (tuple, list)):
@@ -103,7 +105,8 @@ def _shape_common(s1, s2):
return tensor_shape.unknown_shape()
d = [
d1 if d1 is not None and d1 == d2 else None
- for (d1, d2) in zip(s1.as_list(), s2.as_list())]
+ for (d1, d2) in zip(s1.as_list(), s2.as_list())
+ ]
return tensor_shape.TensorShape(d)
@@ -195,8 +198,7 @@ class QueueBase(object):
TypeError: When `queues` is not a list of `QueueBase` objects,
or when the data types of `queues` are not all the same.
"""
- if ((not queues) or
- (not isinstance(queues, list)) or
+ if ((not queues) or (not isinstance(queues, list)) or
(not all(isinstance(x, QueueBase) for x in queues))):
raise TypeError("A list of queues expected")
@@ -210,12 +212,16 @@ class QueueBase(object):
queue_shapes = [q.shapes for q in queues]
reduced_shapes = [
- six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes)]
+ six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes)
+ ]
queue_refs = array_ops.stack([x.queue_ref for x in queues])
selected_queue = array_ops.gather(queue_refs, index)
- return QueueBase(dtypes=dtypes, shapes=reduced_shapes, names=names,
- queue_ref=selected_queue)
+ return QueueBase(
+ dtypes=dtypes,
+ shapes=reduced_shapes,
+ names=names,
+ queue_ref=selected_queue)
@property
def queue_ref(self):
@@ -282,8 +288,8 @@ class QueueBase(object):
tensors = []
for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
- tensors.append(ops.convert_to_tensor(val, dtype=dtype,
- name="component_%d" % i))
+ tensors.append(
+ ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
return tensors
@@ -555,11 +561,13 @@ class QueueBase(object):
name = "%s_Close" % self._name
if self._queue_ref.dtype == _dtypes.resource:
return gen_data_flow_ops._queue_close_v2(
- self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues,
+ self._queue_ref,
+ cancel_pending_enqueues=cancel_pending_enqueues,
name=name)
else:
return gen_data_flow_ops._queue_close(
- self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues,
+ self._queue_ref,
+ cancel_pending_enqueues=cancel_pending_enqueues,
name=name)
def is_closed(self, name=None):
@@ -577,9 +585,9 @@ class QueueBase(object):
if name is None:
name = "%s_Is_Closed" % self._name
if self._queue_ref.dtype == _dtypes.resource:
- return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref,name=name)
+ return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name)
else:
- return gen_data_flow_ops.queue_is_closed_(self._queue_ref,name=name)
+ return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name)
def size(self, name=None):
"""Compute the number of elements in this queue.
@@ -611,8 +619,14 @@ class RandomShuffleQueue(QueueBase):
@end_compatibility
"""
- def __init__(self, capacity, min_after_dequeue, dtypes, shapes=None,
- names=None, seed=None, shared_name=None,
+ def __init__(self,
+ capacity,
+ min_after_dequeue,
+ dtypes,
+ shapes=None,
+ names=None,
+ seed=None,
+ shared_name=None,
name="random_shuffle_queue"):
"""Create a queue that dequeues elements in a random order.
@@ -670,9 +684,14 @@ class RandomShuffleQueue(QueueBase):
string = (str(seed1) + shared_name).encode("utf-8")
seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
queue_ref = gen_data_flow_ops._random_shuffle_queue_v2(
- component_types=dtypes, shapes=shapes, capacity=capacity,
- min_after_dequeue=min_after_dequeue, seed=seed1, seed2=seed2,
- shared_name=shared_name, name=name)
+ component_types=dtypes,
+ shapes=shapes,
+ capacity=capacity,
+ min_after_dequeue=min_after_dequeue,
+ seed=seed1,
+ seed2=seed2,
+ shared_name=shared_name,
+ name=name)
super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -690,8 +709,13 @@ class FIFOQueue(QueueBase):
@end_compatibility
"""
- def __init__(self, capacity, dtypes, shapes=None, names=None,
- shared_name=None, name="fifo_queue"):
+ def __init__(self,
+ capacity,
+ dtypes,
+ shapes=None,
+ names=None,
+ shared_name=None,
+ name="fifo_queue"):
"""Creates a queue that dequeues elements in a first-in first-out order.
A `FIFOQueue` has bounded capacity; supports multiple concurrent
@@ -725,8 +749,11 @@ class FIFOQueue(QueueBase):
shapes = _as_shape_list(shapes, dtypes)
names = _as_name_list(names, dtypes)
queue_ref = gen_data_flow_ops._fifo_queue_v2(
- component_types=dtypes, shapes=shapes, capacity=capacity,
- shared_name=shared_name, name=name)
+ component_types=dtypes,
+ shapes=shapes,
+ capacity=capacity,
+ shared_name=shared_name,
+ name=name)
super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -747,7 +774,12 @@ class PaddingFIFOQueue(QueueBase):
@end_compatibility
"""
- def __init__(self, capacity, dtypes, shapes, names=None, shared_name=None,
+ def __init__(self,
+ capacity,
+ dtypes,
+ shapes,
+ names=None,
+ shared_name=None,
name="padding_fifo_queue"):
"""Creates a queue that dequeues elements in a first-in first-out order.
@@ -792,12 +824,15 @@ class PaddingFIFOQueue(QueueBase):
names = _as_name_list(names, dtypes)
if len(dtypes) != len(shapes):
raise ValueError("Shapes must be provided for all components, "
- "but received %d dtypes and %d shapes."
- % (len(dtypes), len(shapes)))
+ "but received %d dtypes and %d shapes." % (len(dtypes),
+ len(shapes)))
queue_ref = gen_data_flow_ops._padding_fifo_queue_v2(
- component_types=dtypes, shapes=shapes, capacity=capacity,
- shared_name=shared_name, name=name)
+ component_types=dtypes,
+ shapes=shapes,
+ capacity=capacity,
+ shared_name=shared_name,
+ name=name)
super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -815,7 +850,12 @@ class PriorityQueue(QueueBase):
@end_compatibility
"""
- def __init__(self, capacity, types, shapes=None, names=None, shared_name=None,
+ def __init__(self,
+ capacity,
+ types,
+ shapes=None,
+ names=None,
+ shared_name=None,
name="priority_queue"):
"""Creates a queue that dequeues elements in a first-in first-out order.
@@ -856,14 +896,17 @@ class PriorityQueue(QueueBase):
shapes = _as_shape_list(shapes, types)
queue_ref = gen_data_flow_ops._priority_queue_v2(
- component_types=types, shapes=shapes, capacity=capacity,
- shared_name=shared_name, name=name)
+ component_types=types,
+ shapes=shapes,
+ capacity=capacity,
+ shared_name=shared_name,
+ name=name)
priority_dtypes = [_dtypes.int64] + types
priority_shapes = [()] + shapes if shapes else shapes
- super(PriorityQueue, self).__init__(
- priority_dtypes, priority_shapes, names, queue_ref)
+ super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names,
+ queue_ref)
# TODO(josh11b): class BatchQueue(QueueBase):
@@ -943,8 +986,10 @@ class Barrier(object):
self._shapes = [tensor_shape.unknown_shape() for _ in self._types]
self._barrier_ref = gen_data_flow_ops._barrier(
- component_types=self._types, shapes=self._shapes,
- shared_name=shared_name, name=name)
+ component_types=self._types,
+ shapes=self._shapes,
+ shared_name=shared_name,
+ name=name)
if context.in_graph_mode():
self._name = self._barrier_ref.op.name.split("/")[-1]
else:
@@ -1028,12 +1073,13 @@ class Barrier(object):
"""
if name is None:
name = "%s_BarrierTakeMany" % self._name
- ret = gen_data_flow_ops._barrier_take_many(self._barrier_ref,
- num_elements,
- self._types,
- allow_small_batch,
- timeout,
- name=name)
+ ret = gen_data_flow_ops._barrier_take_many(
+ self._barrier_ref,
+ num_elements,
+ self._types,
+ allow_small_batch,
+ timeout,
+ name=name)
# NOTE(mrry): Not using a shape function because we need access to
# the Barrier object.
@@ -1048,8 +1094,7 @@ class Barrier(object):
op.outputs[1].set_shape(tensor_shape.vector(batch_dim)) # keys
for output, shape in zip(op.outputs[2:], self._shapes): # value_list
output.set_shape(
- tensor_shape.TensorShape([batch_dim]).concatenate(
- shape))
+ tensor_shape.TensorShape([batch_dim]).concatenate(shape))
return ret
@@ -1298,8 +1343,8 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase):
name="sparse_conditional_accumulator"):
accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
dtype=dtype, shape=shape, shared_name=shared_name, name=name)
- super(SparseConditionalAccumulator,
- self).__init__(dtype, shape, accumulator_ref)
+ super(SparseConditionalAccumulator, self).__init__(dtype, shape,
+ accumulator_ref)
def apply_indexed_slices_grad(self, grad, local_step=0, name=None):
"""Attempts to apply a gradient to the accumulator.
@@ -1368,8 +1413,8 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase):
local_step=local_step,
gradient_indices=math_ops.to_int64(grad_indices),
gradient_values=grad_values,
- gradient_shape=math_ops.to_int64([] if grad_shape is None else
- grad_shape),
+ gradient_shape=math_ops.to_int64([]
+ if grad_shape is None else grad_shape),
has_known_shape=(grad_shape is not None),
name=name)
@@ -1431,11 +1476,16 @@ class BaseStagingArea(object):
_identifier = 0
_lock = threading.Lock()
- def __init__(self, dtypes, shapes=None, names=None, shared_name=None,
- capacity=0, memory_limit=0):
+ def __init__(self,
+ dtypes,
+ shapes=None,
+ names=None,
+ shared_name=None,
+ capacity=0,
+ memory_limit=0):
if shared_name is None:
- self._name = (ops.get_default_graph()
- .unique_name(self.__class__.__name__))
+ self._name = (
+ ops.get_default_graph().unique_name(self.__class__.__name__))
elif isinstance(shared_name, six.string_types):
self._name = shared_name
else:
@@ -1532,8 +1582,9 @@ class BaseStagingArea(object):
(sorted(vals.keys()), sorted(self._names)))
# The order of values in `self._names` indicates the order in which the
# tensors in the dictionary `vals` must be listed.
- vals, indices, n = zip(*[(vals[k], i, k) for i, k in enumerate(self._names)
- if k in vals])
+ vals, indices, n = zip(*[(vals[k], i, k)
+ for i, k in enumerate(self._names)
+ if k in vals])
else:
if self._names:
raise ValueError("You must enqueue a dictionary in a staging area "
@@ -1541,7 +1592,7 @@ class BaseStagingArea(object):
if indices is None:
raise ValueError("Indices must be supplied when inserting a list "
- "of tensors")
+ "of tensors")
if len(indices) != len(vals):
raise ValueError("Number of indices '%s' doesn't match "
@@ -1553,8 +1604,8 @@ class BaseStagingArea(object):
# Sanity check number of values
if not len(vals) <= len(self._dtypes):
- raise ValueError("Unexpected number of inputs '%s' vs '%s'" % (
- len(vals), len(self._dtypes)))
+ raise ValueError("Unexpected number of inputs '%s' vs '%s'" %
+ (len(vals), len(self._dtypes)))
tensors = []
@@ -1562,14 +1613,14 @@ class BaseStagingArea(object):
dtype, shape = self._dtypes[i], self._shapes[i]
# Check dtype
if not val.dtype == dtype:
- raise ValueError("Datatypes do not match. '%s' != '%s'" %(
- str(val.dtype), str(dtype)))
+ raise ValueError("Datatypes do not match. '%s' != '%s'" %
+ (str(val.dtype), str(dtype)))
# Check shape
val.get_shape().assert_is_compatible_with(shape)
- tensors.append(ops.convert_to_tensor(val, dtype=dtype,
- name="component_%d" % i))
+ tensors.append(
+ ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
return tensors, indices
@@ -1632,6 +1683,7 @@ class BaseStagingArea(object):
else:
return [vals]
+
class StagingArea(BaseStagingArea):
"""Class for staging inputs. No ordering guarantees.
@@ -1666,8 +1718,13 @@ class StagingArea(BaseStagingArea):
"""
- def __init__(self, dtypes, shapes=None, names=None, shared_name=None,
- capacity=0, memory_limit=0):
+ def __init__(self,
+ dtypes,
+ shapes=None,
+ names=None,
+ shared_name=None,
+ capacity=0,
+ memory_limit=0):
"""Constructs a staging area object.
The two optional lists, `shapes` and `names`, must be of the same length
@@ -1702,9 +1759,8 @@ class StagingArea(BaseStagingArea):
ValueError: If one of the arguments is invalid.
"""
- super(StagingArea, self).__init__(dtypes, shapes,
- names, shared_name,
- capacity, memory_limit)
+ super(StagingArea, self).__init__(dtypes, shapes, names, shared_name,
+ capacity, memory_limit)
def put(self, values, name=None):
"""Create an op that places a value into the staging area.
@@ -1726,14 +1782,18 @@ class StagingArea(BaseStagingArea):
self._scope_vals(values)) as scope:
# Hard-code indices for this staging area
- indices = (list(six.moves.range(len(values)))
- if isinstance(values, (list, tuple)) else None)
+ indices = (
+ list(six.moves.range(len(values)))
+ if isinstance(values, (list, tuple)) else None)
vals, _ = self._check_put_dtypes(values, indices)
with ops.colocate_with(self._coloc_op):
- op = gen_data_flow_ops.stage(values=vals, shared_name=self._name,
- name=scope, capacity=self._capacity,
- memory_limit=self._memory_limit)
+ op = gen_data_flow_ops.stage(
+ values=vals,
+ shared_name=self._name,
+ name=scope,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
return op
@@ -1741,7 +1801,7 @@ class StagingArea(BaseStagingArea):
with ops.colocate_with(self._coloc_op):
ret = get_fn()
- indices = list(six.moves.range(len(self._dtypes))) # Hard coded
+ indices = list(six.moves.range(len(self._dtypes))) # Hard coded
return self._get_return_value(ret, indices)
def get(self, name=None):
@@ -1769,10 +1829,12 @@ class StagingArea(BaseStagingArea):
if name is None:
name = "%s_get" % self._name
+ # pylint: disable=bad-continuation
fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes,
shared_name=self._name, name=name,
capacity=self._capacity,
memory_limit=self._memory_limit)
+ # pylint: enable=bad-continuation
return self.__internal_get(fn, name)
@@ -1797,10 +1859,12 @@ class StagingArea(BaseStagingArea):
if name is None:
name = "%s_peek" % self._name
+ # pylint: disable=bad-continuation
fn = lambda: gen_data_flow_ops.stage_peek(index,
dtypes=self._dtypes, shared_name=self._name,
name=name, capacity=self._capacity,
memory_limit=self._memory_limit)
+ # pylint: enable=bad-continuation
return self.__internal_get(fn, name)
@@ -1816,9 +1880,12 @@ class StagingArea(BaseStagingArea):
if name is None:
name = "%s_size" % self._name
- return gen_data_flow_ops.stage_size(name=name, shared_name=self._name,
- dtypes=self._dtypes, capacity=self._capacity,
- memory_limit=self._memory_limit)
+ return gen_data_flow_ops.stage_size(
+ name=name,
+ shared_name=self._name,
+ dtypes=self._dtypes,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
def clear(self, name=None):
"""Clears the staging area.
@@ -1832,14 +1899,16 @@ class StagingArea(BaseStagingArea):
if name is None:
name = "%s_clear" % self._name
- return gen_data_flow_ops.stage_clear(name=name, shared_name=self._name,
- dtypes=self._dtypes, capacity=self._capacity,
- memory_limit=self._memory_limit)
+ return gen_data_flow_ops.stage_clear(
+ name=name,
+ shared_name=self._name,
+ dtypes=self._dtypes,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
+
class MapStagingArea(BaseStagingArea):
- """
- A `MapStagingArea` is a TensorFlow data structure that stores tensors across
- multiple steps, and exposes operations that can put and get tensors.
+ """A `MapStagingArea` is a TensorFlow data structure that stores tensors across multiple steps, and exposes operations that can put and get tensors.
Each `MapStagingArea` element is a (key, value) pair.
Only int64 keys are supported, other types should be
@@ -1852,7 +1921,8 @@ class MapStagingArea(BaseStagingArea):
It supports multiple concurrent producers and consumers; and
provides exactly-once delivery.
- Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors whose
+ Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors
+ whose
dtypes are described by `dtypes`, and whose shapes are optionally described
by the `shapes` argument.
@@ -1896,10 +1966,16 @@ class MapStagingArea(BaseStagingArea):
associated with it are removed.
"""
- def __init__(self, dtypes, shapes=None, names=None, shared_name=None,
- ordered=False, capacity=0, memory_limit=0):
- """
- Args:
+ def __init__(self,
+ dtypes,
+ shapes=None,
+ names=None,
+ shared_name=None,
+ ordered=False,
+ capacity=0,
+ memory_limit=0):
+ """Args:
+
dtypes: A list of types. The length of dtypes must equal the number
of tensors in each element.
capacity: (Optional.) Maximum number of elements.
@@ -1925,9 +2001,8 @@ class MapStagingArea(BaseStagingArea):
"""
- super(MapStagingArea, self).__init__(dtypes, shapes,
- names, shared_name,
- capacity, memory_limit)
+ super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name,
+ capacity, memory_limit)
# Defer to different methods depending if the map is ordered
self._ordered = ordered
@@ -1950,8 +2025,7 @@ class MapStagingArea(BaseStagingArea):
self._clear_fn = gen_data_flow_ops.map_clear
def put(self, key, vals, indices=None, name=None):
- """
- Create an op that stores the (key, vals) pair in the staging area.
+ """Create an op that stores the (key, vals) pair in the staging area.
Incomplete puts are possible, preferably using a dictionary for vals
as the appropriate dtypes and shapes can be inferred from the value names
@@ -1973,7 +2047,8 @@ class MapStagingArea(BaseStagingArea):
The created op
Raises:
- ValueError: If the number or type of inputs don't match the staging area.
+ ValueError: If the number or type of inputs don't match the staging
+ area.
"""
with ops.name_scope(name, "%s_put" % self._name,
@@ -1982,10 +2057,15 @@ class MapStagingArea(BaseStagingArea):
vals, indices = self._check_put_dtypes(vals, indices)
with ops.colocate_with(self._coloc_op):
- op = self._put_fn(key, indices, vals, dtypes=self._dtypes,
- shared_name=self._name, name=scope,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
+ op = self._put_fn(
+ key,
+ indices,
+ vals,
+ dtypes=self._dtypes,
+ shared_name=self._name,
+ name=scope,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
return op
def _get_indices_and_dtypes(self, indices=None):
@@ -2001,13 +2081,13 @@ class MapStagingArea(BaseStagingArea):
if all(isinstance(i, str) for i in indices):
if self._names is None:
raise ValueError("String indices provided '%s', but this Staging Area "
- "was not created with names." % indices)
+ "was not created with names." % indices)
try:
indices = [self._names.index(n) for n in indices]
except ValueError:
raise ValueError("Named index '%s' not in "
- "Staging Area names '%s'" % (n, self._names))
+ "Staging Area names '%s'" % (n, self._names))
elif all(isinstance(i, int) for i in indices):
pass
else:
@@ -2018,10 +2098,8 @@ class MapStagingArea(BaseStagingArea):
return indices, dtypes
-
def peek(self, key, indices=None, name=None):
- """
- Peeks at staging area data associated with the key.
+ """Peeks at staging area data associated with the key.
If the key is not in the staging area, it will block
until the associated (key, value) is inserted.
@@ -2044,22 +2122,22 @@ class MapStagingArea(BaseStagingArea):
indices, dtypes = self._get_indices_and_dtypes(indices)
with ops.colocate_with(self._coloc_op):
- result = self._peek_fn(key, shared_name=self._name,
- indices=indices,
- dtypes=dtypes,
- name=name,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
+ result = self._peek_fn(
+ key,
+ shared_name=self._name,
+ indices=indices,
+ dtypes=dtypes,
+ name=name,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
return self._get_return_value(result, indices)
def get(self, key=None, indices=None, name=None):
- """
- If the key is provided, the associated (key, value)
- is returned from the staging area. If the key is not
- in the staging area, this method will block until
- the associated (key, value) is inserted.
+ """If the key is provided, the associated (key, value) is returned from the staging area.
+ If the key is not in the staging area, this method will block until
+ the associated (key, value) is inserted.
If no key is provided and the staging area is ordered,
the (key, value) with the smallest key will be returned.
Otherwise, a random (key, value) will be returned.
@@ -2084,12 +2162,10 @@ class MapStagingArea(BaseStagingArea):
return self._pop(key, indices=indices, name=name)
def _pop(self, key, indices=None, name=None):
- """
- Remove and return the associated (key, value)
- is returned from the staging area. If the key is not
- in the staging area, this method will block until
- the associated (key, value) is inserted.
+ """Remove and return the associated (key, value) is returned from the staging area.
+ If the key is not in the staging area, this method will block until
+ the associated (key, value) is inserted.
Args:
key: Key associated with the required data
indices: Partial list of tensors to retrieve (optional).
@@ -2107,21 +2183,21 @@ class MapStagingArea(BaseStagingArea):
indices, dtypes = self._get_indices_and_dtypes(indices)
with ops.colocate_with(self._coloc_op):
- result = self._pop_fn(key, shared_name=self._name,
- indices=indices,
- dtypes=dtypes,
- name=name,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
+ result = self._pop_fn(
+ key,
+ shared_name=self._name,
+ indices=indices,
+ dtypes=dtypes,
+ name=name,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
return key, self._get_return_value(result, indices)
def _popitem(self, indices=None, name=None):
- """
- If the staging area is ordered,
- the (key, value) with the smallest key will be returned.
- Otherwise, a random (key, value) will be returned.
+ """If the staging area is ordered, the (key, value) with the smallest key will be returned.
+ Otherwise, a random (key, value) will be returned.
If the staging area is empty when this operation executes,
it will block until there is an element to dequeue.
@@ -2142,12 +2218,13 @@ class MapStagingArea(BaseStagingArea):
indices, dtypes = self._get_indices_and_dtypes(indices)
with ops.colocate_with(self._coloc_op):
- key, result = self._popitem_fn(shared_name=self._name,
- indices=indices,
- dtypes=dtypes,
- name=name,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
+ key, result = self._popitem_fn(
+ shared_name=self._name,
+ indices=indices,
+ dtypes=dtypes,
+ name=name,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
# Separate keys and results out from
# underlying namedtuple
@@ -2157,8 +2234,7 @@ class MapStagingArea(BaseStagingArea):
return key, result
def size(self, name=None):
- """
- Returns the number of elements in the staging area.
+ """Returns the number of elements in the staging area.
Args:
name: A name for the operation (optional)
@@ -2169,14 +2245,15 @@ class MapStagingArea(BaseStagingArea):
if name is None:
name = "%s_size" % self._name
- return self._size_fn(shared_name=self._name,
- name=name, dtypes=self._dtypes,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
+ return self._size_fn(
+ shared_name=self._name,
+ name=name,
+ dtypes=self._dtypes,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
def incomplete_size(self, name=None):
- """
- Returns the number of incomplete elements in the staging area.
+ """Returns the number of incomplete elements in the staging area.
Args:
name: A name for the operation (optional)
@@ -2187,16 +2264,15 @@ class MapStagingArea(BaseStagingArea):
if name is None:
name = "%s_incomplete_size" % self._name
- return self._incomplete_size_fn(shared_name=self._name,
- name=name, dtypes=self._dtypes,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
-
-
+ return self._incomplete_size_fn(
+ shared_name=self._name,
+ name=name,
+ dtypes=self._dtypes,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
def clear(self, name=None):
- """
- Clears the staging area.
+ """Clears the staging area.
Args:
name: A name for the operation (optional)
@@ -2207,10 +2283,12 @@ class MapStagingArea(BaseStagingArea):
if name is None:
name = "%s_clear" % self._name
- return self._clear_fn(shared_name=self._name,
- name=name, dtypes=self._dtypes,
- capacity=self._capacity,
- memory_limit=self._memory_limit)
+ return self._clear_fn(
+ shared_name=self._name,
+ name=name,
+ dtypes=self._dtypes,
+ capacity=self._capacity,
+ memory_limit=self._memory_limit)
class RecordInput(object):
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 5d4b9ecd8b..314726ede6 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -52,7 +52,6 @@ from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
-
# Warn the user if we convert a sparse representation to dense with at
# least this number of elements.
_LARGE_SPARSE_NUM_ELEMENTS = 100000000
@@ -235,9 +234,10 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
raise TypeError(
"Gradients of complex tensors must set grad_ys (y.dtype = %r)" %
y.dtype)
- new_grad_ys.append(array_ops.fill(
- array_ops.shape(y), constant_op.constant(
- 1, dtype=y.dtype, name="grad_ys_%d" % i)))
+ new_grad_ys.append(
+ array_ops.fill(
+ array_ops.shape(y),
+ constant_op.constant(1, dtype=y.dtype, name="grad_ys_%d" % i)))
continue
if y.dtype.is_floating or y.dtype.is_integer:
if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
@@ -492,11 +492,12 @@ def gradients(ys,
name, "gradients",
list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
- xs = [x.handle if isinstance(x, resource_variable_ops.ResourceVariable)
- else x
- for x in xs]
- xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name="x",
- as_ref=True)
+ xs = [
+ x.handle if isinstance(x, resource_variable_ops.ResourceVariable) else x
+ for x in xs
+ ]
+ xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
+ xs, name="x", as_ref=True)
grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)
# The approach we take here is as follows: Create a list of all ops in the
@@ -513,9 +514,8 @@ def gradients(ys,
to_ops = [t.op for t in ys]
from_ops = [t.op for t in xs]
stop_gradient_ops = [t.op for t in stop_gradients]
- pending_count, loop_state = _PendingCount(ops.get_default_graph(), to_ops,
- from_ops,
- colocate_gradients_with_ops)
+ pending_count, loop_state = _PendingCount(
+ ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops)
# Iterate over the collected ops.
#
@@ -588,9 +588,8 @@ def gradients(ys,
# output, it means that the cost does not depend on output[i],
# therefore dC/doutput[i] is 0.
for i, out_grad in enumerate(out_grads):
- if (not isinstance(out_grad, ops.Tensor) and
- not out_grad) and ((not grad_fn and is_func_call) or
- _IsTrainable(op.outputs[i])):
+ if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
+ (not grad_fn and is_func_call) or _IsTrainable(op.outputs[i])):
# Only trainable outputs or outputs for a function call that
# will use SymbolicGradient get a zero gradient. Gradient
# functions should ignore the gradient for other outputs.
@@ -607,17 +606,17 @@ def gradients(ys,
if grad_fn:
# If grad_fn was found, do not use SymbolicGradient even for
# functions.
- in_grads = _MaybeCompile(
- grad_scope, op, func_call, lambda: grad_fn(op, *out_grads))
+ in_grads = _MaybeCompile(grad_scope, op, func_call,
+ lambda: grad_fn(op, *out_grads))
else:
# For function call ops, we add a 'SymbolicGradient'
# node to the graph to compute gradients.
- in_grads = _MaybeCompile(
- grad_scope, op, func_call, lambda: _SymGrad(op, out_grads))
+ in_grads = _MaybeCompile(grad_scope, op, func_call,
+ lambda: _SymGrad(op, out_grads))
in_grads = _AsList(in_grads)
_VerifyGeneratedGradients(in_grads, op)
- if gate_gradients and len(
- [x for x in in_grads if x is not None]) > 1:
+ if gate_gradients and len([x for x in in_grads
+ if x is not None]) > 1:
with ops.device(None):
with ops.colocate_with(None, ignore_existing=True):
in_grads = control_flow_ops.tuple(in_grads)
@@ -637,8 +636,8 @@ def gradients(ys,
"Incompatible shapes between op input and calculated "
"input gradient. Forward operation: %s. Input index: %d. "
"Original input shape: %s. "
- "Calculated input gradient shape: %s"
- % (op.name, i, t_in.shape, in_grad.shape))
+ "Calculated input gradient shape: %s" %
+ (op.name, i, t_in.shape, in_grad.shape))
_SetGrad(grads, t_in, in_grad)
if loop_state:
loop_state.ExitGradWhileContext(op, before=False)
@@ -670,8 +669,8 @@ def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
pending_count[x.op._id] -= 1
ready = (pending_count[x.op._id] == 0)
if loop_state and not ready:
- ready = (pending_count[x.op._id] > 0 and
- control_flow_util.IsLoopSwitch(x.op))
+ ready = (
+ pending_count[x.op._id] > 0 and control_flow_util.IsLoopSwitch(x.op))
# pylint: enable=protected-access
if ready:
if control_flow_util.IsLoopExit(x.op):
@@ -725,8 +724,8 @@ def _GetGrad(grads, t):
if not op_grads:
return None
t_grad = op_grads[t.value_index]
- assert not isinstance(t_grad, list), (
- "gradients list should have been aggregated by now.")
+ assert not isinstance(
+ t_grad, list), ("gradients list should have been aggregated by now.")
return t_grad
@@ -745,9 +744,8 @@ def _HandleNestedIndexedSlices(grad):
else:
assert isinstance(grad.values, ops.IndexedSlices)
g = _HandleNestedIndexedSlices(grad.values)
- return ops.IndexedSlices(g.values,
- array_ops.gather(grad.indices, g.indices),
- g.dense_shape)
+ return ops.IndexedSlices(g.values, array_ops.gather(
+ grad.indices, g.indices), g.dense_shape)
def _AccumulatorShape(inputs):
@@ -849,8 +847,8 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
]:
- raise ValueError("Invalid aggregation_method specified %s." %
- aggregation_method)
+ raise ValueError(
+ "Invalid aggregation_method specified %s." % aggregation_method)
out_grads = _GetGrads(grads, op)
for i, out_grad in enumerate(out_grads):
if loop_state:
@@ -859,7 +857,8 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
continue
# Grads have to be Tensors or IndexedSlices
if (isinstance(out_grad, collections.Sequence) and not all([
- isinstance(g, (ops.Tensor, ops.IndexedSlices)) for g in out_grad
+ isinstance(g, (ops.Tensor, ops.IndexedSlices))
+ for g in out_grad
if g is not None
])):
raise TypeError("gradients have to be either all Tensors "
@@ -903,8 +902,8 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
else:
used = "add_n"
out_grads[i] = _MultiDeviceAddN(out_grad)
- logging.vlog(2, " _AggregatedGrads %d x %s using %s",
- len(out_grad), tensor_shape, used)
+ logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad),
+ tensor_shape, used)
else:
out_grad = math_ops._as_indexed_slices_list(
[g for g in out_grad if g is not None])
@@ -967,7 +966,8 @@ def _hessian_vector_product(ys, xs, v):
assert len(grads) == length
elemwise_products = [
math_ops.multiply(grad_elem, array_ops.stop_gradient(v_elem))
- for grad_elem, v_elem in zip(grads, v) if grad_elem is not None
+ for grad_elem, v_elem in zip(grads, v)
+ if grad_elem is not None
]
# Second backprop
@@ -975,8 +975,12 @@ def _hessian_vector_product(ys, xs, v):
@tf_export("hessians")
-def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
- gate_gradients=False, aggregation_method=None):
+def hessians(ys,
+ xs,
+ name="hessians",
+ colocate_gradients_with_ops=False,
+ gate_gradients=False,
+ aggregation_method=None):
"""Constructs the Hessian of sum of `ys` with respect to `x` in `xs`.
`hessians()` adds ops to the graph to output the Hessian matrix of `ys`
@@ -1004,9 +1008,9 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
"""
xs = _AsList(xs)
kwargs = {
- 'colocate_gradients_with_ops': colocate_gradients_with_ops,
- 'gate_gradients': gate_gradients,
- 'aggregation_method': aggregation_method
+ "colocate_gradients_with_ops": colocate_gradients_with_ops,
+ "gate_gradients": gate_gradients,
+ "aggregation_method": aggregation_method
}
# Compute first-order derivatives and iterate for each x in xs.
hessians = []
@@ -1031,8 +1035,7 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
)
_shape = array_ops.shape(x)
- _reshaped_hessian = array_ops.reshape(
- hessian.stack(), array_ops.concat((_shape, _shape), 0)
- )
+ _reshaped_hessian = array_ops.reshape(hessian.stack(),
+ array_ops.concat((_shape, _shape), 0))
hessians.append(_reshaped_hessian)
return hessians
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 721efcf786..cab1025df1 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -770,8 +770,9 @@ def resize_images(images,
size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.
method: ResizeMethod. Defaults to `ResizeMethod.BILINEAR`.
- align_corners: bool. If true, exactly align all 4 corners of the input and
- output. Defaults to `false`.
+ align_corners: bool. If True, the centers of the 4 corner pixels of the
+ input and output tensors are aligned, preserving the values at the
+ corner pixels. Defaults to `False`.
Raises:
ValueError: if the shape of `images` is incompatible with the
diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py
index fc013b565b..eebfb17085 100644
--- a/tensorflow/python/ops/nn_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_batchnorm_test.py
@@ -21,10 +21,8 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -40,15 +38,6 @@ from tensorflow.python.platform import test
@test_util.with_c_api
class BatchNormalizationTest(test.TestCase):
- def SetProducerVersion(self, graph, producer_version):
- # The C API doesn't expose altering GraphDefVersions. We can indirectly set
- # it via import_graph_def though.
- graph_def = graph_pb2.GraphDef()
- graph_def.versions.producer = producer_version
- with graph.as_default():
- importer.import_graph_def(graph_def)
- assert graph.graph_def_versions.producer, producer_version
-
def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
scale_after_normalization, shift_after_normalization):
y = (x - m) / np.sqrt(v + epsilon)
@@ -65,7 +54,7 @@ class BatchNormalizationTest(test.TestCase):
def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
scale_after_normalization):
"""Original implementation."""
- self.SetProducerVersion(ops.get_default_graph(), 8)
+ test_util.set_producer_version(ops.get_default_graph(), 8)
return gen_nn_ops._batch_norm_with_global_normalization(
x, m, v, beta, gamma, epsilon, scale_after_normalization)
# pylint: enable=protected-access
@@ -233,7 +222,7 @@ class BatchNormalizationTest(test.TestCase):
epsilon = 0.001
for scale_after_normalization in [True, False]:
# _batch_norm_with_global_normalization_grad is deprecated in v9
- self.SetProducerVersion(ops.get_default_graph(), 8)
+ test_util.set_producer_version(ops.get_default_graph(), 8)
grad = gen_nn_ops._batch_norm_with_global_normalization_grad(
x, m, v, gamma, backprop, epsilon, scale_after_normalization)
dx, dm, dv, db, dg = grad
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index cfff73774b..5e6cafd6aa 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -89,52 +89,63 @@ def _Conv2DBackpropFilterGrad(op, grad):
@ops.RegisterGradient("Conv3D")
def _Conv3DGrad(op, grad):
data_format = op.get_attr("data_format")
- return [nn_ops.conv3d_backprop_input_v2(array_ops.shape(op.inputs[0]),
- op.inputs[1],
- grad,
- strides=op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=data_format),
- nn_ops.conv3d_backprop_filter_v2(op.inputs[0],
- array_ops.shape(op.inputs[1]),
- grad,
- strides=op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=data_format)]
+ return [
+ nn_ops.conv3d_backprop_input_v2(
+ array_ops.shape(op.inputs[0]),
+ op.inputs[1],
+ grad,
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=data_format),
+ nn_ops.conv3d_backprop_filter_v2(
+ op.inputs[0],
+ array_ops.shape(op.inputs[1]),
+ grad,
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=data_format)
+ ]
@ops.RegisterGradient("Conv3DBackpropInputV2")
def _Conv3DBackpropInputGrad(op, grad):
data_format = op.get_attr("data_format")
- return [None,
- nn_ops.conv3d_backprop_filter_v2(grad,
- array_ops.shape(op.inputs[1]),
- op.inputs[2],
- strides=op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=data_format),
- nn_ops.conv3d(grad,
- op.inputs[1],
- strides=op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=data_format)]
+ return [
+ None,
+ nn_ops.conv3d_backprop_filter_v2(
+ grad,
+ array_ops.shape(op.inputs[1]),
+ op.inputs[2],
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=data_format),
+ nn_ops.conv3d(
+ grad,
+ op.inputs[1],
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=data_format)
+ ]
@ops.RegisterGradient("Conv3DBackpropFilterV2")
def _Conv3DBackpropFilterGrad(op, grad):
data_format = op.get_attr("data_format")
- return [nn_ops.conv3d_backprop_input_v2(array_ops.shape(op.inputs[0]),
- grad,
- op.inputs[2],
- strides=op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=data_format),
- None,
- nn_ops.conv3d(op.inputs[0],
- grad,
- strides=op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=data_format)]
+ return [
+ nn_ops.conv3d_backprop_input_v2(
+ array_ops.shape(op.inputs[0]),
+ grad,
+ op.inputs[2],
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=data_format), None,
+ nn_ops.conv3d(
+ op.inputs[0],
+ grad,
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=data_format)
+ ]
@ops.RegisterGradient("AvgPool3D")
@@ -150,12 +161,13 @@ def _AvgPool3DGrad(op, grad):
@ops.RegisterGradient("AvgPool3DGrad")
def _AvgPool3DGradGrad(op, grad):
- return (array_ops.stop_gradient(op.inputs[0]), gen_nn_ops.avg_pool3d(
- grad,
- op.get_attr("ksize"),
- op.get_attr("strides"),
- op.get_attr("padding"),
- data_format=op.get_attr("data_format")))
+ return (array_ops.stop_gradient(op.inputs[0]),
+ gen_nn_ops.avg_pool3d(
+ grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ op.get_attr("padding"),
+ data_format=op.get_attr("data_format")))
@ops.RegisterGradient("MaxPool3D")
@@ -173,9 +185,9 @@ def _MaxPool3DGrad(op, grad):
@ops.RegisterGradient("MaxPool3DGrad")
def _MaxPool3DGradGrad(op, grad):
return (array_ops.zeros(
- shape=array_ops.shape(op.inputs[0]),
- dtype=op.inputs[0].dtype), array_ops.zeros(
- shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
+ array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
gen_nn_ops._max_pool3d_grad_grad(
op.inputs[0],
op.inputs[1],
@@ -189,9 +201,9 @@ def _MaxPool3DGradGrad(op, grad):
@ops.RegisterGradient("MaxPool3DGradGrad")
def _MaxPool3DGradGradGrad(op, grad):
return (array_ops.zeros(
- shape=array_ops.shape(op.inputs[0]),
- dtype=op.inputs[0].dtype), array_ops.zeros(
- shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
+ array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
gen_nn_ops._max_pool3d_grad(
op.inputs[0],
op.inputs[1],
@@ -272,8 +284,9 @@ def _BiasAddGrad(op, received_grad):
data_format = op.get_attr("data_format")
except ValueError:
data_format = None
- return (received_grad, gen_nn_ops.bias_add_grad(out_backprop=received_grad,
- data_format=data_format))
+ return (received_grad,
+ gen_nn_ops.bias_add_grad(
+ out_backprop=received_grad, data_format=data_format))
@ops.RegisterGradient("BiasAddGrad")
@@ -346,10 +359,9 @@ def _ReluGrad(op, grad):
def _EluGradGrad(op, grad):
elu_x = op.inputs[1]
return (gen_nn_ops._elu_grad(grad, op.outputs[0]),
- array_ops.where(elu_x < 0,
- grad * op.inputs[0],
- array_ops.zeros(shape=array_ops.shape(elu_x),
- dtype=elu_x.dtype)))
+ array_ops.where(elu_x < 0, grad * op.inputs[0],
+ array_ops.zeros(
+ shape=array_ops.shape(elu_x), dtype=elu_x.dtype)))
@ops.RegisterGradient("SeluGrad")
@@ -357,9 +369,11 @@ def _SeluGradGrad(op, grad):
x = op.inputs[1]
scale_alpha = 1.7580993408473768599402175208123
return (gen_nn_ops._elu_grad(grad, op.outputs[0]),
- array_ops.where(
- x < 0., gen_nn_ops._elu_grad(grad, op.outputs[0] + scale_alpha),
- array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype)))
+ array_ops.where(x < 0.,
+ gen_nn_ops._elu_grad(grad,
+ op.outputs[0] + scale_alpha),
+ array_ops.zeros(
+ shape=array_ops.shape(x), dtype=x.dtype)))
@ops.RegisterGradient("Relu6")
@@ -370,8 +384,8 @@ def _Relu6Grad(op, grad):
@ops.RegisterGradient("Relu6Grad")
def _Relu6GradGrad(op, grad):
x = op.inputs[1]
- return (gen_nn_ops._relu6_grad(grad, x), array_ops.zeros(
- shape=array_ops.shape(x), dtype=x.dtype))
+ return (gen_nn_ops._relu6_grad(grad, x),
+ array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
@ops.RegisterGradient("Elu")
@@ -410,8 +424,8 @@ def _SoftsignGrad(op, grad):
@ops.RegisterGradient("ReluGrad")
def _ReluGradGrad(op, grad):
x = op.inputs[1]
- return (gen_nn_ops._relu_grad(grad, x), array_ops.zeros(
- shape=array_ops.shape(x), dtype=x.dtype))
+ return (gen_nn_ops._relu_grad(grad, x),
+ array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
def _BroadcastMul(vec, mat):
@@ -455,8 +469,8 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
softmax = nn_ops.softmax(logits)
grad += ((grad_grad - array_ops.squeeze(
- math_ops.matmul(grad_grad[:, None, :],
- softmax[:, :, None]), axis=1)) * softmax)
+ math_ops.matmul(grad_grad[:, None, :], softmax[:, :, None]), axis=1)) *
+ softmax)
return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))
@@ -473,7 +487,8 @@ def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
# so we make sure we prevent silently incorrect results by raising
# an error if the second derivative is requested via prevent_gradient.
sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
- op.outputs[1], message="Currently there is no way to take the second "
+ op.outputs[1],
+ message="Currently there is no way to take the second "
"derivative of sparse_softmax_cross_entropy_with_logits due to the fused "
"implementation's interaction with tf.gradients()")
return _BroadcastMul(grad_0, sparse_softmax_grad_without_gradient), None
@@ -531,14 +546,16 @@ def _DepthwiseConv2dNativeGrad(op, grad):
@ops.RegisterGradient("Dilation2D")
def _Dilation2DGrad(op, grad):
- return [nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad,
- op.get_attr("strides"),
- op.get_attr("rates"),
- op.get_attr("padding")),
- nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad,
- op.get_attr("strides"),
- op.get_attr("rates"),
- op.get_attr("padding"))]
+ return [
+ nn_ops.dilation2d_backprop_input(op.inputs[0], op.inputs[1], grad,
+ op.get_attr("strides"),
+ op.get_attr("rates"),
+ op.get_attr("padding")),
+ nn_ops.dilation2d_backprop_filter(op.inputs[0], op.inputs[1], grad,
+ op.get_attr("strides"),
+ op.get_attr("rates"),
+ op.get_attr("padding"))
+ ]
@ops.RegisterGradient("LRN")
@@ -547,8 +564,10 @@ def _LRNGrad(op, grad):
bias = op.get_attr("bias")
alpha = op.get_attr("alpha")
beta = op.get_attr("beta")
- return [gen_nn_ops._lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius,
- bias, alpha, beta)]
+ return [
+ gen_nn_ops._lrn_grad(grad, op.inputs[0], op.outputs[0], depth_radius,
+ bias, alpha, beta)
+ ]
@ops.RegisterGradient("AvgPool")
@@ -564,54 +583,58 @@ def _AvgPoolGrad(op, grad):
@ops.RegisterGradient("AvgPoolGrad")
def _AvgPoolGradGrad(op, grad):
- return (array_ops.stop_gradient(op.inputs[0]), gen_nn_ops._avg_pool(
- grad,
- op.get_attr("ksize"),
- op.get_attr("strides"),
- op.get_attr("padding"),
- data_format=op.get_attr("data_format")))
+ return (array_ops.stop_gradient(op.inputs[0]),
+ gen_nn_ops._avg_pool(
+ grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ op.get_attr("padding"),
+ data_format=op.get_attr("data_format")))
@ops.RegisterGradient("MaxPool")
def _MaxPoolGrad(op, grad):
- return gen_nn_ops._max_pool_grad(op.inputs[0],
- op.outputs[0],
- grad,
- op.get_attr("ksize"),
- op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=op.get_attr("data_format"))
+ return gen_nn_ops._max_pool_grad(
+ op.inputs[0],
+ op.outputs[0],
+ grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=op.get_attr("data_format"))
@ops.RegisterGradient("MaxPoolV2")
def _MaxPoolGradV2(op, grad):
ksize = op.inputs[1]
strides = op.inputs[2]
- return gen_nn_ops.max_pool_grad_v2(op.inputs[0],
- op.outputs[0],
- grad,
- ksize,
- strides,
- padding=op.get_attr("padding"),
- data_format=op.get_attr("data_format")), None, None
+ return gen_nn_ops.max_pool_grad_v2(
+ op.inputs[0],
+ op.outputs[0],
+ grad,
+ ksize,
+ strides,
+ padding=op.get_attr("padding"),
+ data_format=op.get_attr("data_format")), None, None
@ops.RegisterGradient("MaxPoolWithArgmax")
def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
- return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0],
- grad,
- op.outputs[1],
- op.get_attr("ksize"),
- op.get_attr("strides"),
- padding=op.get_attr("padding"))
+ return gen_nn_ops._max_pool_grad_with_argmax(
+ op.inputs[0],
+ grad,
+ op.outputs[1],
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ padding=op.get_attr("padding"))
@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
return (array_ops.zeros(
- shape=array_ops.shape(op.inputs[0]),
- dtype=op.inputs[0].dtype), array_ops.zeros(
- shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
+ array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
gen_nn_ops._max_pool_grad_grad(
op.inputs[0],
op.inputs[1],
@@ -627,9 +650,9 @@ def _MaxPoolGradGradV2(op, grad):
ksize = op.inputs[3]
strides = op.inputs[4]
return (array_ops.zeros(
- shape=array_ops.shape(op.inputs[0]),
- dtype=op.inputs[0].dtype), array_ops.zeros(
- shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
+ array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
gen_nn_ops.max_pool_grad_grad_v2(
op.inputs[0],
op.inputs[1],
@@ -643,9 +666,9 @@ def _MaxPoolGradGradV2(op, grad):
@ops.RegisterGradient("MaxPoolGradGrad")
def _MaxPoolGradGradGrad(op, grad):
return (array_ops.zeros(
- shape=array_ops.shape(op.inputs[0]),
- dtype=op.inputs[0].dtype), array_ops.zeros(
- shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
+ array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
gen_nn_ops._max_pool_grad(
op.inputs[0],
op.inputs[1],
@@ -674,10 +697,9 @@ def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
Input backprop for FractionalMaxPool op.
"""
# pylint: disable=protected-access
- return gen_nn_ops._fractional_max_pool_grad(op.inputs[0], op.outputs[0],
- grad_0, op.outputs[1],
- op.outputs[2],
- op.get_attr("overlapping"))
+ return gen_nn_ops._fractional_max_pool_grad(
+ op.inputs[0], op.outputs[0], grad_0, op.outputs[1], op.outputs[2],
+ op.get_attr("overlapping"))
@ops.RegisterGradient("FractionalAvgPool")
@@ -761,8 +783,9 @@ def _BaseFusedBatchNormGrad(op, use_v2, *grad):
epsilon = op.get_attr("epsilon")
data_format = op.get_attr("data_format")
is_training = op.get_attr("is_training")
- grad_fun = (gen_nn_ops.fused_batch_norm_grad_v2 if use_v2
- else gen_nn_ops.fused_batch_norm_grad)
+ grad_fun = (
+ gen_nn_ops.fused_batch_norm_grad_v2
+ if use_v2 else gen_nn_ops.fused_batch_norm_grad)
if is_training:
return grad_fun(
grad_y,
@@ -786,7 +809,7 @@ def _BaseFusedBatchNormGrad(op, use_v2, *grad):
pop_mean,
pop_var,
epsilon=epsilon,
- data_format='NHWC',
+ data_format="NHWC",
is_training=is_training)
if data_format == b"NCHW":
dx = array_ops.transpose(dx, [0, 3, 1, 2])
@@ -803,18 +826,28 @@ def _FusedBatchNormV2Grad(op, *grad):
return _BaseFusedBatchNormGrad(op, True, *grad)
-def _BatchNormGrad(grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training=True):
+def _BatchNormGrad(grad_y,
+ x,
+ scale,
+ pop_mean,
+ pop_var,
+ epsilon,
+ data_format,
+ is_training=True):
"""Returns the gradients for the 3 inputs of BatchNorm.
Args:
grad_y: A `Tensor` of 4 dimensions for gradient for y.
x: A `Tensor` of 4 dimensions for x.
scale: A `Tensor` of 1 dimension for scaling.
- pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when is_training=False.
- pop_var: A `Tensor` of 1 dimension for the population variance. Only used when is_training=False.
+ pop_mean: A `Tensor` of 1 dimension for the population mean. Only used when
+ is_training=False.
+ pop_var: A `Tensor` of 1 dimension for the population variance. Only used
+ when is_training=False.
epsilon: A small float number added to the variance of x.
data_format: The data format for input. Either b"NHWC" or b"NCHW".
- is_training: A bool value to indicate the operation is for training (default)
+ is_training: A bool value to indicate the operation is for training
+ (default)
or inference.
Returns:
@@ -900,7 +933,7 @@ def _FusedBatchNormGradGrad(op, *grad):
grad_grad_scale = grad[1]
grad_grad_offset = grad[2]
grad_x, grad_scale, grad_offset = _BatchNormGrad(
- grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)
+ grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)
grad_initial = [grad_grad_x, grad_grad_scale, grad_grad_offset]
grad_grad_y, grad_x, grad_scale = gradients_impl.gradients(
[grad_x, grad_scale, grad_offset], [grad_y, x, scale], grad_initial)
@@ -954,14 +987,15 @@ def _TopKGrad(op, grad, _):
# Substitute grad to appropriate locations and fill the rest with zeros,
# finally reshaping it to the original input shape.
- return [array_ops.reshape(
- sparse_ops.sparse_to_dense(ind,
- array_ops.reshape(
- math_ops.reduce_prod(in_shape), [1]),
- array_ops.reshape(grad, [-1]),
- validate_indices=False),
- in_shape), array_ops.zeros(
- [], dtype=dtypes.int32)]
+ return [
+ array_ops.reshape(
+ sparse_ops.sparse_to_dense(
+ ind,
+ array_ops.reshape(math_ops.reduce_prod(in_shape), [1]),
+ array_ops.reshape(grad, [-1]),
+ validate_indices=False), in_shape),
+ array_ops.zeros([], dtype=dtypes.int32)
+ ]
@ops.RegisterGradient("NthElement")
@@ -983,11 +1017,9 @@ def _NthElementGrad(op, grad):
# dimension. If there are multiple elements then the gradient will be
# divided between them.
indicators = math_ops.cast(
- math_ops.equal(array_ops.expand_dims(output, -1), input),
- grad.dtype)
+ math_ops.equal(array_ops.expand_dims(output, -1), input), grad.dtype)
grad = array_ops.expand_dims(grad, -1)
- num_selected = array_ops.expand_dims(
- math_ops.reduce_sum(indicators, -1), -1)
+ num_selected = array_ops.expand_dims(math_ops.reduce_sum(indicators, -1), -1)
return [math_ops.div(indicators, num_selected) * grad, None]
diff --git a/tensorflow/python/ops/nn_grad_test.py b/tensorflow/python/ops/nn_grad_test.py
index f7541c0e89..aa7539ae9f 100644
--- a/tensorflow/python/ops/nn_grad_test.py
+++ b/tensorflow/python/ops/nn_grad_test.py
@@ -30,17 +30,20 @@ from tensorflow.python.platform import test
class Relu6OpTest(test.TestCase):
+
def testRelu6GradGrad(self):
- inputs = constant_op.constant([[-2, -1, 1, 3], [5, 7, 8, 9]],
- dtype=dtypes.float32)
+ inputs = constant_op.constant(
+ [[-2, -1, 1, 3], [5, 7, 8, 9]], dtype=dtypes.float32)
x_init_value = np.array([[-3.5, -1.5, 2, 4], [4.5, 7.5, 8.5, 11]])
r = nn_ops.relu6(inputs)
r_g = gradients_impl.gradients(r, inputs)[0]
with self.test_session():
error = gradient_checker.compute_gradient_error(
- inputs, inputs.get_shape().as_list(),
- r_g, r_g.get_shape().as_list(),
- x_init_value=x_init_value)
+ inputs,
+ inputs.get_shape().as_list(),
+ r_g,
+ r_g.get_shape().as_list(),
+ x_init_value=x_init_value)
self.assertLess(error, 1e-4)
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 879c206313..bdf41cd75d 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -348,11 +348,11 @@ class ResourceVariable(variables.Variable):
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
self._save_slice_info = None
- self._in_graph_mode = context.in_graph_mode()
# Save the graph's container prefix for error checking. Reading the value of
# the ResourceVariable from another Graph in Eager mode is an error.
self._container_prefix = ops.get_default_graph()._container_prefix # pylint: disable=protected-access
- with ops.control_dependencies(None):
+ with ops.init_scope():
+ self._in_graph_mode = context.in_graph_mode()
with ops.name_scope(name, "Variable", []
if init_from_fn else [initial_value]) as name:
# pylint: disable=protected-access
@@ -835,25 +835,45 @@ class ResourceVariable(variables.Variable):
return self.value()
def __iadd__(self, unused_other):
- raise RuntimeError("Variable += value not supported.")
+ raise RuntimeError("Variable += value not supported. Use "
+ "variable.assign_add(value) to modify the variable "
+ "value and variable = variable + value to get a new "
+ "Tensor object.")
def __isub__(self, unused_other):
- raise RuntimeError("Variable -= value not supported.")
+ raise RuntimeError("Variable -= value not supported. Use "
+ "variable.assign_sub(value) to modify the variable "
+ "value and variable = variable - value to get a new "
+ "Tensor object.")
def __imul__(self, unused_other):
- raise RuntimeError("Variable *= value not supported.")
+ raise RuntimeError("Variable *= value not supported. Use "
+ "variable.assign_mul(value) to modify the variable "
+ "value and variable = variable * value to get a new "
+ "Tensor object.")
def __idiv__(self, unused_other):
- raise RuntimeError("Variable /= value not supported.")
+ raise RuntimeError("Variable /= value not supported. Use "
+ "variable.assign_div(value) to modify the variable "
+ "value and variable = variable / value to get a new "
+ "Tensor object.")
def __itruediv__(self, unused_other):
- raise RuntimeError("Variable /= value not supported.")
+ raise RuntimeError("Variable /= value not supported. Use "
+ "variable.assign_div(value) to modify the variable "
+ "value and variable = variable / value to get a new "
+ "Tensor object.")
def __irealdiv__(self, unused_other):
- raise RuntimeError("Variable /= value not supported.")
+ raise RuntimeError("Variable /= value not supported. Use "
+ "variable.assign_div(value) to modify the variable "
+ "value and variable = variable / value to get a new "
+ "Tensor object.")
def __ipow__(self, unused_other):
- raise RuntimeError("Variable **= value not supported.")
+ raise RuntimeError("Variable **= value not supported. Use "
+ "value and variable = variable ** value to get a new "
+ "Tensor object.")
def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index a1008f1c83..a10e1963d1 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -812,7 +812,10 @@ def _dynamic_rnn_loop(cell,
return (time + 1, output_ta_t, new_state)
if in_graph_mode:
- loop_bound = max_sequence_length
+ # Make sure that we run at least 1 step, if necessary, to ensure
+ # the TensorArrays pick up the dynamic shape.
+ loop_bound = math_ops.minimum(
+ time_steps, math_ops.maximum(1, max_sequence_length))
else:
# Using max_sequence_length isn't currently supported in the Eager branch.
loop_bound = time_steps
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index 1990087072..15127862a4 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -155,27 +155,24 @@ def einsum(equation, *inputs, **kwargs):
indices in its subscript, or
- the input shapes are inconsistent along a particular axis.
"""
- name = kwargs.pop("name", None)
+ name = kwargs.pop('name', None)
if kwargs:
- raise TypeError("invalid keyword arguments for this function: " +
- ", ".join([format(key)
- for key in sorted(list(kwargs.keys()))]))
- with ops.name_scope(name, "einsum", [equation, inputs]) as name:
+ raise TypeError('invalid keyword arguments for this function: ' + ', '.join(
+ [format(key) for key in sorted(list(kwargs.keys()))]))
+ with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
if '...' in equation:
raise ValueError('Subscripts with ellipses are not yet supported.')
match = re.match('([a-z,]+)(->[a-z]*)?', equation)
if not match:
- raise ValueError(
- 'Indices have incorrect format: %s' % equation
- )
+ raise ValueError('Indices have incorrect format: %s' % equation)
inputs = list(inputs)
input_axis_labels = match.group(1).split(',')
if len(inputs) != len(input_axis_labels):
- raise ValueError('Got %d arguments for equation "%s", expecting %d' % (
- len(inputs), equation, len(input_axis_labels)))
+ raise ValueError('Got %d arguments for equation "%s", expecting %d' %
+ (len(inputs), equation, len(input_axis_labels)))
axis_labels = set(''.join(input_axis_labels))
if match.group(2):
@@ -188,10 +185,8 @@ def einsum(equation, *inputs, **kwargs):
for ax in axes_:
counts[ax] += 1
- output_axis_labels = ''.join(sorted(
- ax for ax in indices
- if counts[ax] == 1
- ))
+ output_axis_labels = ''.join(
+ sorted(ax for ax in indices if counts[ax] == 1))
for a in axis_labels:
input_count = sum(1 for s in input_axis_labels if a in s)
@@ -203,22 +198,23 @@ def einsum(equation, *inputs, **kwargs):
temp = inputs[0]
temp_axis_labels = input_axis_labels[0]
- for i in xrange(len(inputs)-1):
- axes_to_sum = (set(temp_axis_labels) & set(input_axis_labels[i+1])
- - set(output_axis_labels))
- temp, temp_axis_labels = _einsum_reduction(temp,
- temp_axis_labels,
- inputs[i+1],
- input_axis_labels[i+1],
- axes_to_sum)
+ for i in xrange(len(inputs) - 1):
+ axes_to_sum = (
+ set(temp_axis_labels) &
+ set(input_axis_labels[i + 1]) - set(output_axis_labels))
+ temp, temp_axis_labels = _einsum_reduction(
+ temp, temp_axis_labels, inputs[i + 1], input_axis_labels[i + 1],
+ axes_to_sum)
missing_indices = set(temp_axis_labels) - set(output_axis_labels)
if missing_indices:
- reduction_indices = [i for i, a in enumerate(temp_axis_labels)
- if a not in output_axis_labels]
+ reduction_indices = [
+ i for i, a in enumerate(temp_axis_labels)
+ if a not in output_axis_labels
+ ]
temp = math_ops.reduce_sum(temp, reduction_indices=reduction_indices)
- temp_axis_labels = ''.join(a for a in temp_axis_labels
- if a in output_axis_labels)
+ temp_axis_labels = ''.join(
+ a for a in temp_axis_labels if a in output_axis_labels)
if sorted(temp_axis_labels) != sorted(output_axis_labels):
raise ValueError('Invalid equation: %s' % equation)
@@ -296,8 +292,10 @@ def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
return (1, a)
axis_labels = [t0_axis_labels, t1_axis_labels]
- sorted_axes = [sorted(sym_list, key=lambda a: sort_key(i, a))
- for i, sym_list in enumerate(axis_labels)]
+ sorted_axes = [
+ sorted(sym_list, key=lambda a: sort_key(i, a))
+ for i, sym_list in enumerate(axis_labels)
+ ]
inputs = [t0, t1]
for i, axes_str in enumerate(axis_labels):
perm = [axes_str.find(a) for a in sorted_axes[i]]
@@ -325,30 +323,30 @@ def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
num_broadcast_elements_t0 = _total_size(
t0_shape[len(preserved_axes):-len(axes_to_sum)])
num_summed_elements = _total_size(t0_shape[-len(axes_to_sum):])
- new_shape = (t0_shape[:len(preserved_axes)]
- + [num_broadcast_elements_t0, num_summed_elements])
+ new_shape = (
+ t0_shape[:len(preserved_axes)] +
+ [num_broadcast_elements_t0, num_summed_elements])
t0 = _reshape_if_necessary(t0, new_shape)
t1_shape = _get_shape(t1)
num_broadcast_elements_t1 = _total_size(
- t1_shape[len(preserved_axes)+len(axes_to_sum):])
- new_shape = (t1_shape[:len(preserved_axes)]
- + [num_summed_elements, num_broadcast_elements_t1])
+ t1_shape[len(preserved_axes) + len(axes_to_sum):])
+ new_shape = (
+ t1_shape[:len(preserved_axes)] +
+ [num_summed_elements, num_broadcast_elements_t1])
t1 = _reshape_if_necessary(t1, new_shape)
product = math_ops.matmul(t0, t1)
# Undo compaction of broadcast axes
uncompacted_shape = (
- t0_shape[:len(preserved_axes)+len(broadcast_axes[0])]
- + t1_shape[len(t1_shape)-len(broadcast_axes[1]):]
- )
+ t0_shape[:len(preserved_axes) + len(broadcast_axes[0])] +
+ t1_shape[len(t1_shape) - len(broadcast_axes[1]):])
product = _reshape_if_necessary(product, uncompacted_shape)
product_axes = (
- sorted_axes[0][:len(preserved_axes)+len(broadcast_axes[0])] +
- sorted_axes[1][len(sorted_axes[1])-len(broadcast_axes[1]):]
- )
+ sorted_axes[0][:len(preserved_axes) + len(broadcast_axes[0])] +
+ sorted_axes[1][len(sorted_axes[1]) - len(broadcast_axes[1]):])
return product, ''.join(product_axes)
@@ -402,13 +400,11 @@ def _total_size(shape_values):
def _exponential_space_einsum(equation, *inputs):
"""Fallback implementation that supports summing an index over > 2 inputs."""
if '...' in equation:
- raise ValueError("Subscripts with ellipses are not yet supported.")
+ raise ValueError('Subscripts with ellipses are not yet supported.')
match = re.match('([a-z,]+)(->[a-z]*)?', equation)
if not match:
- raise ValueError(
- 'Indices have incorrect format: %s' % equation
- )
+ raise ValueError('Indices have incorrect format: %s' % equation)
inputs = list(inputs)
idx_in = match.group(1).split(',')
@@ -425,21 +421,15 @@ def _exponential_space_einsum(equation, *inputs):
for ax in axes_:
counts[ax] += 1
- idx_out = ''.join(sorted(
- ax for ax in indices
- if counts[ax] == 1
- ))
+ idx_out = ''.join(sorted(ax for ax in indices if counts[ax] == 1))
if len(idx_in) != len(inputs):
- raise ValueError(
- 'Expected %d inputs but got %d' % (len(idx_in), len(inputs))
- )
+ raise ValueError('Expected %d inputs but got %d' % (len(idx_in),
+ len(inputs)))
missing_idx = set(idx_out).difference(idx_all)
if missing_idx:
- raise ValueError(
- 'Unknown output axes: %s' % missing_idx
- )
+ raise ValueError('Unknown output axes: %s' % missing_idx)
axis_order = {}
for ax in indices:
@@ -452,18 +442,17 @@ def _exponential_space_einsum(equation, *inputs):
for i, (input_, axes_) in enumerate(zip(inputs, idx_in)):
if input_.get_shape().ndims != len(axes_):
raise ValueError(
- 'Input %d with axes %s has incorrect' \
- ' number of dimensions (expected %d, got %d)' % (
- i, axes_, len(axes_), input_.get_shape().ndims
- )
+ 'Input %d with axes %s has incorrect' \
+ ' number of dimensions (expected %d, got %d)' % (
+ i, axes_, len(axes_), input_.get_shape().ndims
+ )
)
sorted_idx = sorted(axes_, key=axis_order.get)
if len(set(axes_)) != len(axes_):
raise ValueError(
- 'Subscript not supported: an axis appears more than once: %s' % axes_
- )
+ 'Subscript not supported: an axis appears more than once: %s' % axes_)
if list(axes_) != sorted_idx:
permuted = [axes_.find(ax) for ax in sorted_idx]
@@ -487,16 +476,15 @@ def _exponential_space_einsum(equation, *inputs):
dims.append(dim)
if len(set(dims)) > 1:
- raise ValueError(
- 'Dimension mismatch on axis: %s' % ax
- )
+ raise ValueError('Dimension mismatch on axis: %s' % ax)
if ax not in idx_out:
reduction_idx.append(j)
# reshape, multiply
- expanded_inputs = [array_ops.reshape(input_, shape)
- for input_, shape in zip(inputs, shapes)]
+ expanded_inputs = [
+ array_ops.reshape(input_, shape) for input_, shape in zip(inputs, shapes)
+ ]
expanded_output = 1
for input_ in expanded_inputs:
expanded_output *= input_
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py
index c1a66717d8..2c212f4548 100644
--- a/tensorflow/python/ops/special_math_ops_test.py
+++ b/tensorflow/python/ops/special_math_ops_test.py
@@ -39,8 +39,9 @@ class LBetaTest(test.TestCase):
x_one_half = [2, 1.]
with self.test_session(use_gpu=True):
self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_one)).eval())
- self.assertAllClose(
- 0.5, math_ops.exp(special_math_ops.lbeta(x_one_half)).eval())
+ self.assertAllClose(0.5,
+ math_ops.exp(
+ special_math_ops.lbeta(x_one_half)).eval())
self.assertEqual([], special_math_ops.lbeta(x_one).get_shape())
def test_one_dimensional_arg_dynamic(self):
@@ -70,8 +71,9 @@ class LBetaTest(test.TestCase):
# Should evaluate to 1/2.
x_one_half = [[2, 1.], [2, 1.]]
with self.test_session(use_gpu=True):
- self.assertAllClose(
- [0.5, 0.5], math_ops.exp(special_math_ops.lbeta(x_one_half)).eval())
+ self.assertAllClose([0.5, 0.5],
+ math_ops.exp(
+ special_math_ops.lbeta(x_one_half)).eval())
self.assertEqual((2,), special_math_ops.lbeta(x_one_half).get_shape())
def test_two_dimensional_arg_dynamic(self):
@@ -86,10 +88,12 @@ class LBetaTest(test.TestCase):
# Should evaluate to 1/2.
x_one_half = [[2, 1.], [2, 1.]]
with self.test_session(use_gpu=True):
- self.assertAllClose(
- [0.5, 0.5], math_ops.exp(special_math_ops.lbeta(x_one_half)).eval())
+ self.assertAllClose([0.5, 0.5],
+ math_ops.exp(
+ special_math_ops.lbeta(x_one_half)).eval())
self.assertEqual(
- (2,), array_ops.shape(special_math_ops.lbeta(x_one_half)).eval())
+ (2,),
+ array_ops.shape(special_math_ops.lbeta(x_one_half)).eval())
self.assertEqual(
tensor_shape.TensorShape([2]),
special_math_ops.lbeta(x_one_half).get_shape())
@@ -97,8 +101,8 @@ class LBetaTest(test.TestCase):
def test_complicated_shape(self):
with self.test_session(use_gpu=True):
x = ops.convert_to_tensor(np.random.rand(3, 2, 2))
- self.assertAllEqual(
- (3, 2), array_ops.shape(special_math_ops.lbeta(x)).eval())
+ self.assertAllEqual((3, 2),
+ array_ops.shape(special_math_ops.lbeta(x)).eval())
self.assertEqual(
tensor_shape.TensorShape([3, 2]),
special_math_ops.lbeta(x).get_shape())
@@ -155,7 +159,6 @@ class EinsumTest(test.TestCase):
'ijk->i',
'ijk->kji',
'ji,kj->ik',
-
'ikl,kji->kl',
'klj,lki->ij',
'ijk,ilj->kli',
@@ -164,7 +167,6 @@ class EinsumTest(test.TestCase):
'i,ijk,j->k',
'ij,ij,jk,kl->il',
'ij,kj,il,jm->ml',
-
'a,ab,abc->abc',
'a,b,ab->ab',
'ab,ab,c->',
@@ -173,25 +175,21 @@ class EinsumTest(test.TestCase):
'ab,ab,cd,cd->ac',
'ab,ab,cd,cd->cd',
'ab,ab,cd,cd,ef,ef->',
-
'ab,cd,ef->abcdef',
'ab,cd,ef->acdf',
'ab,cd,de->abcde',
'ab,cd,de->be',
'ab,bcd,cd->abcd',
'ab,bcd,cd->abd',
-
'eb,cb,fb->cef',
'abcd,ad',
'bd,db,eac->ace',
'ba,ac,da->bcd',
-
'ab,ab',
'ab,ba',
'abc,abc',
'abc,bac',
'abc,cba',
-
'dba,ead,cad->bce',
'aef,fbc,dca->bde',
]
@@ -234,10 +232,8 @@ class EinsumTest(test.TestCase):
def test_invalid(self):
for axes in self.invalid_cases:
inputs = [
- array_ops.placeholder(
- dtypes.float32, shape=(3, 4)),
- array_ops.placeholder(
- dtypes.float32, shape=(3, 4)),
+ array_ops.placeholder(dtypes.float32, shape=(3, 4)),
+ array_ops.placeholder(dtypes.float32, shape=(3, 4)),
]
with self.assertRaises(ValueError):
_ = special_math_ops.einsum(axes, *inputs)
@@ -245,16 +241,22 @@ class EinsumTest(test.TestCase):
def test_invalid_keyword_arguments(self):
m0 = array_ops.placeholder(dtypes.int32, shape=(1, None))
m1 = array_ops.placeholder(dtypes.int32, shape=(None, 1))
- with self.assertRaisesRegexp(TypeError,
+ with self.assertRaisesRegexp(
+ TypeError,
'invalid keyword arguments for this function: invalid1, invalid2'):
- _ = special_math_ops.einsum('ij,jk->ik', m0, m1, name="name",
- invalid1="value1", invalid2="value2")
+ _ = special_math_ops.einsum(
+ 'ij,jk->ik',
+ m0,
+ m1,
+ name='name',
+ invalid1='value1',
+ invalid2='value2')
def test_dim_mismatch(self):
for axes, input_shapes in self.dim_mismatch_cases:
inputs = [
- array_ops.placeholder(
- dtypes.float32, shape=shape) for shape in input_shapes
+ array_ops.placeholder(dtypes.float32, shape=shape)
+ for shape in input_shapes
]
with self.assertRaises(ValueError):
_ = special_math_ops.einsum(axes, *inputs)
@@ -291,8 +293,8 @@ class EinsumTest(test.TestCase):
m0: [[1, 2, 3]],
m1: [[2], [1], [1]],
}
- np.testing.assert_almost_equal(
- [[7]], sess.run(out, feed_dict=feed_dict))
+ np.testing.assert_almost_equal([[7]], sess.run(
+ out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(None, 3))
@@ -312,11 +314,11 @@ class EinsumTest(test.TestCase):
out = special_math_ops.einsum('ijk,kl->ijl', m0, m1)
with session.Session() as sess:
feed_dict = {
- m0: [[[1,2]]],
+ m0: [[[1, 2]]],
m1: [[3], [2]],
}
- np.testing.assert_almost_equal(
- [[[7]]], sess.run(out, feed_dict=feed_dict))
+ np.testing.assert_almost_equal([[[7]]],
+ sess.run(out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(2, 1))
@@ -325,10 +327,10 @@ class EinsumTest(test.TestCase):
with session.Session() as sess:
feed_dict = {
m0: [[3], [2]],
- m1: [[[1,2]]],
+ m1: [[[1, 2]]],
}
- np.testing.assert_almost_equal(
- [[[7]]], sess.run(out, feed_dict=feed_dict))
+ np.testing.assert_almost_equal([[[7]]],
+ sess.run(out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2))
@@ -339,8 +341,8 @@ class EinsumTest(test.TestCase):
m0: [[[1, 2]]],
m1: [3, 2],
}
- np.testing.assert_almost_equal(
- [[7]], sess.run(out, feed_dict=feed_dict))
+ np.testing.assert_almost_equal([[7]], sess.run(
+ out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(None, 2, None, 2))
@@ -351,8 +353,8 @@ class EinsumTest(test.TestCase):
m0: [[[[1, 2]], [[2, 1]]]],
m1: [[3, 2]],
}
- np.testing.assert_almost_equal(
- [[[7, 8]]], sess.run(out, feed_dict=feed_dict))
+ np.testing.assert_almost_equal([[[7, 8]]],
+ sess.run(out, feed_dict=feed_dict))
if __name__ == '__main__':
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index db594ac6a0..81565a6377 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -771,8 +771,8 @@ class _VariableStore(object):
if initializer is None:
initializer, initializing_from_value = self._get_default_initializer(
name=name, shape=shape, dtype=dtype)
- # Clear control dependencies while creating the initializer.
- with ops.control_dependencies(None):
+ # Enter an init scope when creating the initializer.
+ with ops.init_scope():
if initializing_from_value:
init_val = initializer
variable_dtype = None
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 7d7fa646c0..19e3298e40 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.deprecation import deprecated
@@ -211,6 +212,7 @@ class Variable(object):
if not context.in_graph_mode():
raise RuntimeError("tf.Variable not supported in Eager mode. "
"Please use tfe.Variable instead")
+ self._in_graph_mode = context.in_graph_mode()
if variable_def:
# If variable_def is provided, recreates the variable from its fields.
if initial_value:
@@ -306,7 +308,7 @@ class Variable(object):
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
- with ops.control_dependencies(None):
+ with ops.init_scope():
with ops.name_scope(name, "Variable", [] if init_from_fn else
[initial_value]) as name:
@@ -377,8 +379,8 @@ class Variable(object):
else:
with ops.colocate_with(self._variable.op):
self._snapshot = array_ops.identity(self._variable, name="read")
+ ops.add_to_collections(collections, self)
- ops.add_to_collections(collections, self)
self._caching_device = caching_device
self._save_slice_info = None
self._constraint = constraint
@@ -552,7 +554,7 @@ class Variable(object):
A `Tensor` holding the value of this variable after its initializer
has run.
"""
- with ops.control_dependencies(None):
+ with ops.init_scope():
return control_flow_ops.cond(is_variable_initialized(self),
self.read_value,
lambda: self.initial_value)
@@ -1021,6 +1023,61 @@ class Variable(object):
return Variable(variable_def=variable_def,
import_scope=import_scope)
+ def __iadd__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable += will be deprecated. Use variable.assign_add"
+ " if you want assignment to the variable value or 'x = x + y'"
+ " if you want a new python Tensor object.", 1)
+ return self + other
+
+ def __isub__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable -= will be deprecated. Use variable.assign_sub"
+ " if you want assignment to the variable value or 'x = x - y'"
+ " if you want a new python Tensor object.", 1)
+ return self - other
+
+ def __imul__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable *= will be deprecated. Use variable.assign_mul"
+ " if you want assignment to the variable value or 'x = x * y'"
+ " if you want a new python Tensor object.", 1)
+ return self * other
+
+ def __idiv__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable /= will be deprecated. Use variable.assign_div"
+ " if you want assignment to the variable value or 'x = x / y'"
+ " if you want a new python Tensor object.", 1)
+ return self / other
+
+ def __itruediv__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable /= will be deprecated. Use variable.assign_div"
+ " if you want assignment to the variable value or 'x = x / y'"
+ " if you want a new python Tensor object.", 1)
+ return self / other
+
+ def __irealdiv__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable /= will be deprecated. Use variable.assign_div"
+ " if you want assignment to the variable value or 'x = x / y'"
+ " if you want a new python Tensor object.", 1)
+ return self / other
+
+ def __ipow__(self, other):
+ logging.log_first_n(
+ logging.WARN,
+ "Variable **= will be deprecated. Use 'x = x ** y'"
+ " if you want a new python Tensor object.", 1)
+ return self ** other
+
class SaveSliceInfo(object):
"""Information on how to save this Variable as a slice.
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index 9153855588..04ba28c219 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -224,15 +224,15 @@ class PrintModelAnalysisTest(test.TestCase):
# pylint: disable=line-too-long
with gfile.Open(outfile, 'r') as f:
lines = f.read().split('\n')
+ self.assertGreater(len(lines), 5)
result = '\n'.join([l[:min(len(l), 80)] for l in lines])
- self.assertEqual(
- compat.as_bytes(
- 'node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/168.86k flops)\n model_analyzer_testlib.py:63:BuildFullModel (0/1.80k params, 0/45.37k flops)\n model_analyzer_testlib.py:40:BuildSmallModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:44:BuildSmallModel (0/4 params, 0/8 flops)\n model_analyzer_testlib.py:48:BuildSmallModel (0/648 params, 0/1.30k flops)\n model_analyzer_testlib.py:49:BuildSmallModel (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:53:BuildSmallModel (0/1.15k params, 0/2.30k flops)\n model_analyzer_testlib.py:54:BuildSmallModel (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:63:BuildFullModel (gradient) (0/0 params, 0/67.39k f\n model_analyzer_testlib.py:49:BuildSmallModel (gradient) (0/0 params, 0/46.66\n model_analyzer_testlib.py:54:BuildSmallModel (gradient) (0/0 params, 0/20.74\n model_analyzer_testlib.py:67:BuildFullModel (0/1.04k params, 0/18.58k flops)\n model_analyzer_testlib.py:67:BuildFullModel (gradient) (0/0 params, 0/37.00k f\n model_analyzer_testlib.py:69:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:70:BuildFullModel (0/0 params, 0/258 flops)\n model_analyzer_testlib.py:70:BuildFullModel (gradient) (0/0 params, 0/129 flop\n model_analyzer_testlib.py:72:BuildFullModel (0/0 params, 0/141 flops)\n'
- ), compat.as_bytes(lib.CheckAndRemoveDoc(result)))
+ self.assertTrue(
+ compat.as_text(lib.CheckAndRemoveDoc(result))
+ .startswith('node name | # parameters | # float_ops'))
self.assertLess(0, tfprof_node.total_exec_micros)
self.assertEqual(2844, tfprof_node.total_parameters)
- self.assertEqual(168863, tfprof_node.total_float_ops)
+ self.assertLess(168800, tfprof_node.total_float_ops)
self.assertEqual(8, len(tfprof_node.children))
self.assertEqual('_TFProfRoot', tfprof_node.name)
self.assertEqual(
diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py
index 8716058e61..dd876cbe7f 100644
--- a/tensorflow/python/tools/inspect_checkpoint.py
+++ b/tensorflow/python/tools/inspect_checkpoint.py
@@ -97,8 +97,9 @@ def parse_numpy_printoption(kv_str):
raise argparse.ArgumentTypeError(
"Setting '%s' from the command line is not supported." % k)
try:
- v = (v_type(v_str) if v_type is not bool
- else flags.BooleanParser().parse(v_str))
+ v = (
+ v_type(v_str)
+ if v_type is not bool else flags.BooleanParser().parse(v_str))
except ValueError as e:
raise argparse.ArgumentTypeError(e.message)
np.set_printoptions(**{k: v})
@@ -121,9 +122,12 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
- "--file_name", type=str, default="", help="Checkpoint filename. "
- "Note, if using Checkpoint V2 format, file_name is the "
- "shared prefix between all files in the checkpoint.")
+ "--file_name",
+ type=str,
+ default="",
+ help="Checkpoint filename. "
+ "Note, if using Checkpoint V2 format, file_name is the "
+ "shared prefix between all files in the checkpoint.")
parser.add_argument(
"--tensor_name",
type=str,
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index 5054873bc1..b5d3e78797 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -176,7 +176,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
reader = load_checkpoint(ckpt_dir_or_file)
variable_map = reader.get_variable_to_shape_map()
- for tensor_name_in_ckpt, current_var_or_name in six.iteritems(assignment_map):
+ for tensor_name_in_ckpt, current_var_or_name in sorted(
+ six.iteritems(assignment_map)):
var = None
# Check if this is Variable object or list of Variable objects (in case of
# partitioned variables).
@@ -233,7 +234,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
if "/part_" in var_name:
var_name = var_name[:var_name.index("/part_")]
scope_variables.add(var_name)
- for var_name in scope_variables:
+ for var_name in sorted(scope_variables):
# Lookup name with specified prefix and suffix from current variable.
# If tensor_name given is '/' (root), don't use it for full name.
full_tensor_name = var_name[len(scopes):]
diff --git a/tensorflow/python/training/coordinator_test.py b/tensorflow/python/training/coordinator_test.py
index 149d3eed41..3e4ac1dfff 100644
--- a/tensorflow/python/training/coordinator_test.py
+++ b/tensorflow/python/training/coordinator_test.py
@@ -85,8 +85,8 @@ class CoordinatorTest(test.TestCase):
self.assertFalse(coord.wait_for_stop(0.1))
wait_for_stop_ev = threading.Event()
has_stopped_ev = threading.Event()
- t = threading.Thread(target=StopOnEvent,
- args=(coord, wait_for_stop_ev, has_stopped_ev))
+ t = threading.Thread(
+ target=StopOnEvent, args=(coord, wait_for_stop_ev, has_stopped_ev))
t.start()
self.assertFalse(coord.should_stop())
self.assertFalse(coord.wait_for_stop(0.01))
@@ -100,7 +100,8 @@ class CoordinatorTest(test.TestCase):
threads = [
threading.Thread(target=SleepABit, args=(0.01,)),
threading.Thread(target=SleepABit, args=(0.02,)),
- threading.Thread(target=SleepABit, args=(0.01,))]
+ threading.Thread(target=SleepABit, args=(0.01,))
+ ]
for t in threads:
t.start()
coord.join(threads)
@@ -112,7 +113,8 @@ class CoordinatorTest(test.TestCase):
threads = [
threading.Thread(target=SleepABit, args=(0.01, coord)),
threading.Thread(target=SleepABit, args=(0.02, coord)),
- threading.Thread(target=SleepABit, args=(0.01, coord))]
+ threading.Thread(target=SleepABit, args=(0.01, coord))
+ ]
for t in threads:
t.start()
WaitForThreadsToRegister(coord, 3)
@@ -125,7 +127,8 @@ class CoordinatorTest(test.TestCase):
threads = [
threading.Thread(target=SleepABit, args=(0.01, coord)),
threading.Thread(target=SleepABit, args=(0.02,)),
- threading.Thread(target=SleepABit, args=(0.01, coord))]
+ threading.Thread(target=SleepABit, args=(0.01, coord))
+ ]
for t in threads:
t.start()
WaitForThreadsToRegister(coord, 2)
@@ -135,14 +138,17 @@ class CoordinatorTest(test.TestCase):
self.assertFalse(t.is_alive())
def testJoinGraceExpires(self):
+
def TestWithGracePeriod(stop_grace_period):
coord = coordinator.Coordinator()
wait_for_stop_ev = threading.Event()
has_stopped_ev = threading.Event()
threads = [
- threading.Thread(target=StopOnEvent,
- args=(coord, wait_for_stop_ev, has_stopped_ev)),
- threading.Thread(target=SleepABit, args=(10.0,))]
+ threading.Thread(
+ target=StopOnEvent,
+ args=(coord, wait_for_stop_ev, has_stopped_ev)),
+ threading.Thread(target=SleepABit, args=(10.0,))
+ ]
for t in threads:
t.daemon = True
t.start()
@@ -150,6 +156,7 @@ class CoordinatorTest(test.TestCase):
has_stopped_ev.wait()
with self.assertRaisesRegexp(RuntimeError, "threads still running"):
coord.join(threads, stop_grace_period_secs=stop_grace_period)
+
TestWithGracePeriod(1e-10)
TestWithGracePeriod(0.002)
TestWithGracePeriod(1.0)
@@ -159,16 +166,16 @@ class CoordinatorTest(test.TestCase):
wait_for_stop_ev = threading.Event()
has_stopped_ev = threading.Event()
threads = [
- threading.Thread(target=StopOnEvent,
- args=(coord, wait_for_stop_ev, has_stopped_ev)),
- threading.Thread(target=SleepABit, args=(10.0,))]
+ threading.Thread(
+ target=StopOnEvent, args=(coord, wait_for_stop_ev, has_stopped_ev)),
+ threading.Thread(target=SleepABit, args=(10.0,))
+ ]
for t in threads:
t.daemon = True
t.start()
wait_for_stop_ev.set()
has_stopped_ev.wait()
- coord.join(
- threads, stop_grace_period_secs=1., ignore_live_threads=True)
+ coord.join(threads, stop_grace_period_secs=1., ignore_live_threads=True)
def testJoinRaiseReportExcInfo(self):
coord = coordinator.Coordinator()
@@ -180,7 +187,8 @@ class CoordinatorTest(test.TestCase):
args=(coord, ev_1, ev_2, RuntimeError("First"), False)),
threading.Thread(
target=RaiseOnEvent,
- args=(coord, ev_2, None, RuntimeError("Too late"), False))]
+ args=(coord, ev_2, None, RuntimeError("Too late"), False))
+ ]
for t in threads:
t.start()
@@ -199,7 +207,8 @@ class CoordinatorTest(test.TestCase):
args=(coord, ev_1, ev_2, RuntimeError("First"), True)),
threading.Thread(
target=RaiseOnEvent,
- args=(coord, ev_2, None, RuntimeError("Too late"), True))]
+ args=(coord, ev_2, None, RuntimeError("Too late"), True))
+ ]
for t in threads:
t.start()
@@ -214,9 +223,8 @@ class CoordinatorTest(test.TestCase):
threading.Thread(
target=RaiseOnEvent,
args=(coord, ev_1, None,
- errors_impl.OutOfRangeError(None, None, "First"),
- True))
- ]
+ errors_impl.OutOfRangeError(None, None, "First"), True))
+ ]
for t in threads:
t.start()
@@ -230,7 +238,7 @@ class CoordinatorTest(test.TestCase):
threading.Thread(
target=RaiseOnEvent,
args=(coord, ev_1, None, ValueError("Clean stop"), True))
- ]
+ ]
for t in threads:
t.start()
@@ -247,7 +255,8 @@ class CoordinatorTest(test.TestCase):
args=(coord, ev_1, ev_2, RuntimeError("First"))),
threading.Thread(
target=RaiseOnEventUsingContextHandler,
- args=(coord, ev_2, None, RuntimeError("Too late")))]
+ args=(coord, ev_2, None, RuntimeError("Too late")))
+ ]
for t in threads:
t.start()
@@ -262,7 +271,7 @@ class CoordinatorTest(test.TestCase):
threading.Thread(
target=RaiseOnEvent,
args=(coord, ev_1, None, RuntimeError("First"), True)),
- ]
+ ]
for t in threads:
t.start()
@@ -274,7 +283,7 @@ class CoordinatorTest(test.TestCase):
threading.Thread(
target=RaiseOnEvent,
args=(coord, ev_1, None, RuntimeError("Second"), True)),
- ]
+ ]
for t in threads:
t.start()
with self.assertRaisesRegexp(RuntimeError, "Second"):
@@ -337,24 +346,29 @@ class LooperTest(test.TestCase):
def testTargetArgs(self):
n = [3]
coord = coordinator.Coordinator()
- thread = coordinator.LooperThread.loop(coord, 0, target=_StopAt0,
- args=(coord, n))
+ thread = coordinator.LooperThread.loop(
+ coord, 0, target=_StopAt0, args=(coord, n))
coord.join([thread])
self.assertEqual(0, n[0])
def testTargetKwargs(self):
n = [3]
coord = coordinator.Coordinator()
- thread = coordinator.LooperThread.loop(coord, 0, target=_StopAt0,
- kwargs={"coord": coord, "n": n})
+ thread = coordinator.LooperThread.loop(
+ coord, 0, target=_StopAt0, kwargs={
+ "coord": coord,
+ "n": n
+ })
coord.join([thread])
self.assertEqual(0, n[0])
def testTargetMixedArgs(self):
n = [3]
coord = coordinator.Coordinator()
- thread = coordinator.LooperThread.loop(coord, 0, target=_StopAt0,
- args=(coord,), kwargs={"n": n})
+ thread = coordinator.LooperThread.loop(
+ coord, 0, target=_StopAt0, args=(coord,), kwargs={
+ "n": n
+ })
coord.join([thread])
self.assertEqual(0, n[0])
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index e34c759e89..43ed1ac170 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -187,7 +187,7 @@ def _zero_debias(unbiased_var, value, decay):
with variable_scope.variable_scope(
unbiased_var.op.name, values=[unbiased_var, value, decay]) as scope:
with ops.colocate_with(unbiased_var):
- with ops.control_dependencies(None):
+ with ops.init_scope():
biased_initializer = init_ops.zeros_initializer(
dtype=unbiased_var.dtype)(unbiased_var.get_shape())
local_step_initializer = init_ops.zeros_initializer()
@@ -385,7 +385,7 @@ class ExponentialMovingAverage(object):
# For variables: to lower communication bandwidth across devices we keep
# the moving averages on the same device as the variables. For other
# tensors, we rely on the existing device allocation mechanism.
- with ops.control_dependencies(None):
+ with ops.init_scope():
if isinstance(var, variables.Variable):
avg = slot_creator.create_slot(var,
var.initialized_value(),
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 038469b1ba..719b83e5ca 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -514,7 +514,7 @@ class Optimizer(object):
if not var_list:
raise ValueError("No gradients provided for any variable: %s." %
([str(v) for _, _, v in converted_grads_and_vars],))
- with ops.control_dependencies(None):
+ with ops.init_scope():
self._create_slots([_get_variable_for(v) for v in var_list])
update_ops = []
with ops.name_scope(name, self._name) as name:
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index f32d456155..23d11c88ed 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -11,6 +11,10 @@ load(
"if_static",
)
load(
+ "@local_config_tensorrt//:build_defs.bzl",
+ "if_tensorrt",
+)
+load(
"@local_config_cuda//cuda:build_defs.bzl",
"if_cuda",
"cuda_default_copts",
@@ -197,6 +201,7 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False):
"-fno-exceptions",
"-ftemplate-depth=900"])
+ if_cuda(["-DGOOGLE_CUDA=1"])
+ + if_tensorrt(["-DGOOGLE_TENSORRT=1"])
+ if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML", "-fopenmp",])
+ if_android_arm(["-mfpu=neon"])
+ if_linux_x86_64(["-msse3"])
@@ -866,9 +871,11 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs):
When the library is built with --config=cuda:
- - both deps and cuda_deps are used as dependencies
- - the cuda runtime is added as a dependency (if necessary)
- - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts
+ - Both deps and cuda_deps are used as dependencies.
+ - The cuda runtime is added as a dependency (if necessary).
+ - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts.
+ - In addition, when the library is also built with TensorRT enabled, it
+ additionally passes -DGOOGLE_TENSORRT=1 to the list of copts.
Args:
- cuda_deps: BUILD dependencies which will be linked if and only if:
@@ -887,7 +894,8 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs):
clean_dep("//tensorflow/core:cuda"),
"@local_config_cuda//cuda:cuda_headers"
]),
- copts=copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]),
+ copts=(copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]) +
+ if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
**kwargs)
register_extension_info(
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
index 7fe3e2db09..2bf584fa29 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
@@ -160,15 +160,15 @@ tf_class {
}
member_method {
name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
}
member_method {
name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
}
member_method {
name: "from_config"
@@ -228,7 +228,7 @@ tf_class {
}
member_method {
name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
}
member_method {
name: "predict_on_batch"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.densenet.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.densenet.pbtxt
new file mode 100644
index 0000000000..42cb914450
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.densenet.pbtxt
@@ -0,0 +1,23 @@
+path: "tensorflow.keras.applications.densenet"
+tf_module {
+ member_method {
+ name: "DenseNet121"
+ argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "DenseNet169"
+ argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "DenseNet201"
+ argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "decode_predictions"
+ argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
+ }
+ member_method {
+ name: "preprocess_input"
+ argspec: "args=[\'x\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.nasnet.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.nasnet.pbtxt
new file mode 100644
index 0000000000..cd75b87540
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.nasnet.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.keras.applications.nasnet"
+tf_module {
+ member_method {
+ name: "NASNetLarge"
+ argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "NASNetMobile"
+ argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "decode_predictions"
+ argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
+ }
+ member_method {
+ name: "preprocess_input"
+ argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt
index daeb5aad41..9fc086eb8e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.keras.applications"
tf_module {
member {
+ name: "densenet"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "inception_resnet_v2"
mtype: "<type \'module\'>"
}
@@ -13,6 +17,10 @@ tf_module {
mtype: "<type \'module\'>"
}
member {
+ name: "nasnet"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "resnet50"
mtype: "<type \'module\'>"
}
@@ -29,6 +37,18 @@ tf_module {
mtype: "<type \'module\'>"
}
member_method {
+ name: "DenseNet121"
+ argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "DenseNet169"
+ argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "DenseNet201"
+ argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
name: "InceptionResNetV2"
argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
}
@@ -41,6 +61,14 @@ tf_module {
argspec: "args=[\'input_shape\', \'alpha\', \'depth_multiplier\', \'dropout\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'1\', \'0.001\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
}
member_method {
+ name: "NASNetLarge"
+ argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
+ name: "NASNetMobile"
+ argspec: "args=[\'input_shape\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
+ }
+ member_method {
name: "ResNet50"
argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt
index 44fbe0f7a0..ba2d083a75 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt
@@ -398,7 +398,7 @@ tf_module {
}
member_method {
name: "rnn"
- argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\'], "
+ argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "round"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt
index 8719c07ca3..d4c85a4519 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'schedule\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'schedule\', \'verbose\'], varargs=None, keywords=None, defaults=[\'0\'], "
}
member_method {
name: "on_batch_begin"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt
index ef08f9b20f..bda31751d4 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt
@@ -2,6 +2,6 @@ path: "tensorflow.keras.datasets.boston_housing"
tf_module {
member_method {
name: "load_data"
- argspec: "args=[\'path\', \'seed\', \'test_split\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'113\', \'0.2\'], "
+ argspec: "args=[\'path\', \'test_split\', \'seed\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'0.2\', \'113\'], "
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt
index 8b1c17e9da..ff962876b6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt
@@ -6,6 +6,6 @@ tf_module {
}
member_method {
name: "load_data"
- argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=None, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], "
+ argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], "
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt
index 6b3ed1e9af..2da4a13067 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt
@@ -6,6 +6,6 @@ tf_module {
}
member_method {
name: "load_data"
- argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=None, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], "
+ argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=kwargs, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], "
}
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
index a32151e22f..770a107b66 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -127,7 +127,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
index 46b1713196..0ce42b706e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
index 9bfaf27562..b371ad148c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -127,7 +127,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
index 2b8ac4f1f4..2f5e65a0c5 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
@@ -123,7 +123,7 @@ tf_class {
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "compute_mask"
@@ -131,7 +131,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
index c9a0b88725..ff08def0a0 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -127,7 +127,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
index b847e224d6..6db22ca032 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
@@ -116,7 +116,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -128,7 +128,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
index 86578d958e..07d3f023e5 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -127,7 +127,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
index 348012dcde..92b9760d53 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
index 0419251083..83c528b401 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
index 337e85e812..b329f1c46b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
index 1357dc0f0d..d0f6d2a14f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -183,7 +183,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -195,7 +195,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
index b71a08f6c3..57596badf1 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
index a01a6067ef..3829353cc3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
index 0dbbdf2838..3b171b137a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
index 964ef89c2e..0036d6805b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
@@ -187,7 +187,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -199,7 +199,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
index 6a7b23c540..8134fb7386 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
index 324745e5a3..c5d4523009 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
index e12ae05054..bcbed9241b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
index 9e889ca863..ff0db15f19 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -127,7 +127,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
index 932680941d..1d3f33f045 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -127,7 +127,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
index db644f958f..c86bc49b22 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
index 74fa1db020..b29f65d79d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -94,7 +94,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'activity_regularizer\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "add_loss"
@@ -118,7 +118,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -130,7 +130,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt
new file mode 100644
index 0000000000..dd67b76523
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt
@@ -0,0 +1,186 @@
+path: "tensorflow.keras.layers.SeparableConv1D"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.SeparableConv1D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional.SeparableConv1D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._SeparableConv\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_output_shape"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
new file mode 100644
index 0000000000..bf62c095e7
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
@@ -0,0 +1,186 @@
+path: "tensorflow.keras.layers.SeparableConvolution1D"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.SeparableConv1D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional.SeparableConv1D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._SeparableConv\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_output_shape"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
index 3414810db4..6e3cde3e3e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
@@ -114,7 +114,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
index cf34034ef0..b875898a81 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
@@ -175,7 +175,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
@@ -187,7 +187,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt
new file mode 100644
index 0000000000..ee4b2fa39e
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt
@@ -0,0 +1,183 @@
+path: "tensorflow.keras.layers.Softmax"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.advanced_activations.Softmax\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'axis\'], varargs=None, keywords=kwargs, defaults=[\'-1\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_output_shape"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
index b76499658d..db9f90caef 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -118,7 +118,7 @@ tf_class {
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
index 2376d815a6..ef31c5443e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
@@ -126,7 +126,7 @@ tf_class {
}
member_method {
name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
index fe336c4be5..088c8e88e2 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
@@ -293,10 +293,18 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "SeparableConv1D"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "SeparableConv2D"
mtype: "<type \'type\'>"
}
member {
+ name: "SeparableConvolution1D"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "SeparableConvolution2D"
mtype: "<type \'type\'>"
}
@@ -309,6 +317,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "Softmax"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "SpatialDropout1D"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
index d239098b0b..0b816b5863 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
@@ -160,15 +160,15 @@ tf_class {
}
member_method {
name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
}
member_method {
name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
}
member_method {
name: "from_config"
@@ -228,7 +228,7 @@ tf_class {
}
member_method {
name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
+ argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
}
member_method {
name: "predict_on_batch"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt
index ed040c1586..32667cf31e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'1.0\', \'0.95\', \'1e-08\', \'0.0\'], "
+ argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'1.0\', \'0.95\', \'None\', \'0.0\'], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt
index a24651429a..efca59e8e4 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'lr\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'1e-08\', \'0.0\'], "
+ argspec: "args=[\'self\', \'lr\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'None\', \'0.0\'], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt
index a0d978fded..5546e2067a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-08\', \'0.0\'], "
+ argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\', \'amsgrad\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'None\', \'0.0\', \'False\'], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt
index 1b70c93ad5..aaa54a1060 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'1e-08\', \'0.0\'], "
+ argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'None\', \'0.0\'], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt
index b49dbe5cf8..1fada7fd9c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'schedule_decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'1e-08\', \'0.004\'], "
+ argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'schedule_decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'None\', \'0.004\'], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
index c8860d80d4..fd3f97f35d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'1e-08\', \'0.0\'], "
+ argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'None\', \'0.0\'], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt
index 5bc8c40120..ce91caa1af 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt
@@ -4,7 +4,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'num_words\', \'filters\', \'lower\', \'split\', \'char_level\'], varargs=None, keywords=None, defaults=[\'None\', \'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \', \'False\'], "
+ argspec: "args=[\'self\', \'num_words\', \'filters\', \'lower\', \'split\', \'char_level\', \'oov_token\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \', \'False\', \'None\'], "
}
member_method {
name: "fit_on_sequences"
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index aa341b144c..27fa1b89ce 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -177,7 +177,13 @@ do_pylint() {
echo "pylint took $((PYLINT_END_TIME - PYLINT_START_TIME)) s"
echo ""
- grep -E '(\[E|\[W0311|\[W0312)' ${OUTPUT_FILE} > ${ERRORS_FILE}
+ # Report only what we care about
+ # Ref https://pylint.readthedocs.io/en/latest/technical_reference/features.html
+ # E: all errors
+ # W0311 bad-indentation
+ # W0312 mixed-indentation
+ # C0330 bad-continuation
+ grep -E '(\[E|\[W0311|\[W0312|\[C0330)' ${OUTPUT_FILE} > ${ERRORS_FILE}
N_ERRORS=0
while read -r LINE; do
diff --git a/tensorflow/tools/compatibility/tf_upgrade.py b/tensorflow/tools/compatibility/tf_upgrade.py
index fa1cc73905..f678681dac 100644
--- a/tensorflow/tools/compatibility/tf_upgrade.py
+++ b/tensorflow/tools/compatibility/tf_upgrade.py
@@ -236,8 +236,8 @@ class _ASTCallVisitor(ast.NodeVisitor):
new_col_offset = col - m.start(1) - 1
return line, new_col_offset
else:
- if (reversed_preceding_text=="" or
- reversed_preceding_text.isspace()):
+ if (reversed_preceding_text == "" or
+ reversed_preceding_text.isspace()):
line = line - 1
prev_line = self._lines[line - 1]
# TODO(aselle):
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 9145d9e58a..f7d9075032 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -1,6 +1,7 @@
# TensorFlow external dependencies that can be loaded in WORKSPACE files.
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
+load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
load("//third_party/git:git_configure.bzl", "git_configure")
load("//third_party/py:python_configure.bzl", "python_configure")
@@ -68,6 +69,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
check_bazel_version_at_least("0.5.4")
clang6_configure(name="local_config_clang6")
cuda_configure(name="local_config_cuda")
+ tensorrt_configure(name="local_config_tensorrt")
git_configure(name="local_config_git")
sycl_configure(name="local_config_sycl")
python_configure(name="local_config_python")
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 2727fa5efe..8e1dd8a54f 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -236,7 +236,7 @@ def _cudnn_install_basedir(repository_ctx):
return cudnn_install_path
-def _matches_version(environ_version, detected_version):
+def matches_version(environ_version, detected_version):
"""Checks whether the user-specified version matches the detected version.
This function performs a weak matching so that if the user specifies only the
@@ -317,7 +317,7 @@ def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
environ_version = ""
if _TF_CUDA_VERSION in repository_ctx.os.environ:
environ_version = repository_ctx.os.environ[_TF_CUDA_VERSION].strip()
- if environ_version and not _matches_version(environ_version, full_version):
+ if environ_version and not matches_version(environ_version, full_version):
auto_configure_fail(
("CUDA version detected from nvcc (%s) does not match " +
"TF_CUDA_VERSION (%s)") % (full_version, environ_version))
@@ -338,35 +338,49 @@ _DEFINE_CUDNN_MINOR = "#define CUDNN_MINOR"
_DEFINE_CUDNN_PATCHLEVEL = "#define CUDNN_PATCHLEVEL"
-def _find_cuda_define(repository_ctx, cudnn_header_dir, define):
- """Returns the value of a #define in cudnn.h
+def find_cuda_define(repository_ctx, header_dir, header_file, define):
+ """Returns the value of a #define in a header file.
- Greps through cudnn.h and returns the value of the specified #define. If the
- #define is not found, then raise an error.
+ Greps through a header file and returns the value of the specified #define.
+ If the #define is not found, then raise an error.
Args:
repository_ctx: The repository context.
- cudnn_header_dir: The directory containing the cuDNN header.
+ header_dir: The directory containing the header file.
+ header_file: The header file name.
define: The #define to search for.
Returns:
- The value of the #define found in cudnn.h.
+ The value of the #define found in the header.
"""
- # Confirm location of cudnn.h and grep for the line defining CUDNN_MAJOR.
- cudnn_h_path = repository_ctx.path("%s/cudnn.h" % cudnn_header_dir)
- if not cudnn_h_path.exists:
- auto_configure_fail("Cannot find cudnn.h at %s" % str(cudnn_h_path))
- result = repository_ctx.execute(["grep", "--color=never", "-E", define, str(cudnn_h_path)])
+ # Confirm location of the header and grep for the line defining the macro.
+ h_path = repository_ctx.path("%s/%s" % (header_dir, header_file))
+ if not h_path.exists:
+ auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
+ result = repository_ctx.execute(
+ # Grep one more lines as some #defines are splitted into two lines.
+ ["grep", "--color=never", "-A1", "-E", define, str(h_path)])
if result.stderr:
- auto_configure_fail("Error reading %s: %s" %
- (result.stderr, str(cudnn_h_path)))
+ auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))
- # Parse the cuDNN major version from the line defining CUDNN_MAJOR
- lines = result.stdout.splitlines()
- if len(lines) == 0 or lines[0].find(define) == -1:
+ # Parse the version from the line defining the macro.
+ if result.stdout.find(define) == -1:
auto_configure_fail("Cannot find line containing '%s' in %s" %
- (define, str(cudnn_h_path)))
- return lines[0].replace(define, "").strip()
+ (define, h_path))
+ version = result.stdout
+ # Remove the new line and '\' character if any.
+ version = version.replace("\\", " ")
+ version = version.replace("\n", " ")
+ version = version.replace(define, "").lstrip()
+ # Remove the code after the version number.
+ version_end = version.find(" ")
+ if version_end != -1:
+ if version_end == 0:
+ auto_configure_fail(
+ "Cannot extract the version from line containing '%s' in %s" %
+ (define, str(h_path)))
+ version = version[:version_end].strip()
+ return version
def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
@@ -382,12 +396,12 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
"""
cudnn_header_dir = _find_cudnn_header_dir(repository_ctx,
cudnn_install_basedir)
- major_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
- _DEFINE_CUDNN_MAJOR)
- minor_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
- _DEFINE_CUDNN_MINOR)
- patch_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
- _DEFINE_CUDNN_PATCHLEVEL)
+ major_version = find_cuda_define(
+ repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MAJOR)
+ minor_version = find_cuda_define(
+ repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MINOR)
+ patch_version = find_cuda_define(
+ repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_PATCHLEVEL)
full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
# Check whether TF_CUDNN_VERSION was set by the user and fail if it does not
@@ -395,7 +409,7 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
environ_version = ""
if _TF_CUDNN_VERSION in repository_ctx.os.environ:
environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip()
- if environ_version and not _matches_version(environ_version, full_version):
+ if environ_version and not matches_version(environ_version, full_version):
cudnn_h_path = repository_ctx.path("%s/include/cudnn.h" %
cudnn_install_basedir)
auto_configure_fail(
@@ -427,7 +441,7 @@ def _compute_capabilities(repository_ctx):
return capabilities
-def _cpu_value(repository_ctx):
+def get_cpu_value(repository_ctx):
"""Returns the name of the host operating system.
Args:
@@ -447,7 +461,7 @@ def _cpu_value(repository_ctx):
def _is_windows(repository_ctx):
"""Returns true if the host operating system is windows."""
- return _cpu_value(repository_ctx) == "Windows"
+ return get_cpu_value(repository_ctx) == "Windows"
def _lib_name(lib, cpu_value, version="", static=False):
"""Constructs the platform-specific name of a library.
@@ -582,11 +596,8 @@ def _find_libs(repository_ctx, cuda_config):
cuda_config: The CUDA config as returned by _get_cuda_config
Returns:
- Map of library names to structs of filename and path as returned by
- _find_cuda_lib and _find_cupti_lib.
+ Map of library names to structs of filename and path.
"""
- cudnn_version = cuda_config.cudnn_version
- cudnn_ext = ".%s" % cudnn_version if cudnn_version else ""
cpu_value = cuda_config.cpu_value
return {
"cuda": _find_cuda_lib("cuda", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path),
@@ -611,7 +622,7 @@ def _find_libs(repository_ctx, cuda_config):
"cudnn": _find_cuda_lib(
"cudnn", repository_ctx, cpu_value, cuda_config.cudnn_install_basedir,
cuda_config.cudnn_version),
- "cupti": _find_cupti_lib(repository_ctx, cuda_config),
+ "cupti": _find_cupti_lib(repository_ctx, cuda_config)
}
@@ -654,7 +665,7 @@ def _get_cuda_config(repository_ctx):
compute_capabilities: A list of the system's CUDA compute capabilities.
cpu_value: The name of the host operating system.
"""
- cpu_value = _cpu_value(repository_ctx)
+ cpu_value = get_cpu_value(repository_ctx)
cuda_toolkit_path = _cuda_toolkit_path(repository_ctx)
cuda_version = _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value)
cudnn_install_basedir = _cudnn_install_basedir(repository_ctx)
@@ -712,13 +723,13 @@ error_gpu_disabled()
def _create_dummy_repository(repository_ctx):
- cpu_value = _cpu_value(repository_ctx)
+ cpu_value = get_cpu_value(repository_ctx)
# Set up BUILD file for cuda/.
_tpl(repository_ctx, "cuda:build_defs.bzl",
{
"%{cuda_is_configured}": "False",
- "%{cuda_extra_copts}": "[]"
+ "%{cuda_extra_copts}": "[]",
})
_tpl(repository_ctx, "cuda:BUILD",
{
@@ -805,8 +816,8 @@ def _norm_path(path):
return path
-def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
- src_files = [], dest_files = []):
+def symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
+ src_files = [], dest_files = []):
"""Returns a genrule to symlink(or copy if on Windows) a set of files.
If src_dir is passed, files will be read from the given directory; otherwise
@@ -913,11 +924,11 @@ def _create_local_cuda_repository(repository_ctx):
# cuda_toolkit_path
cuda_toolkit_path = cuda_config.cuda_toolkit_path
cuda_include_path = cuda_toolkit_path + "/include"
- genrules = [_symlink_genrule_for_dir(repository_ctx,
+ genrules = [symlink_genrule_for_dir(repository_ctx,
cuda_include_path, "cuda/include", "cuda-include")]
- genrules.append(_symlink_genrule_for_dir(repository_ctx,
+ genrules.append(symlink_genrule_for_dir(repository_ctx,
cuda_toolkit_path + "/nvvm", "cuda/nvvm", "cuda-nvvm"))
- genrules.append(_symlink_genrule_for_dir(repository_ctx,
+ genrules.append(symlink_genrule_for_dir(repository_ctx,
cuda_toolkit_path + "/extras/CUPTI/include",
"cuda/extras/CUPTI/include", "cuda-extras"))
@@ -927,15 +938,15 @@ def _create_local_cuda_repository(repository_ctx):
for lib in cuda_libs.values():
cuda_lib_src.append(lib.path)
cuda_lib_dest.append("cuda/lib/" + lib.file_name)
- genrules.append(_symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib",
- cuda_lib_src, cuda_lib_dest))
+ genrules.append(symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib",
+ cuda_lib_src, cuda_lib_dest))
- # Set up the symbolic links for cudnn if cudnn was was not installed to
+ # Set up the symbolic links for cudnn if cndnn was not installed to
# CUDA_TOOLKIT_PATH.
included_files = _read_dir(repository_ctx, cuda_include_path).replace(
cuda_include_path, '').splitlines()
if '/cudnn.h' not in included_files:
- genrules.append(_symlink_genrule_for_dir(repository_ctx, None,
+ genrules.append(symlink_genrule_for_dir(repository_ctx, None,
"cuda/include/", "cudnn-include", [cudnn_header_dir + "/cudnn.h"],
["cudnn.h"]))
else:
@@ -952,7 +963,6 @@ def _create_local_cuda_repository(repository_ctx):
"%{cuda_is_configured}": "True",
"%{cuda_extra_copts}": _compute_cuda_extra_copts(
repository_ctx, cuda_config.compute_capabilities),
-
})
_tpl(repository_ctx, "cuda:BUILD",
{
diff --git a/third_party/tensorrt/BUILD b/third_party/tensorrt/BUILD
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/third_party/tensorrt/BUILD
diff --git a/third_party/tensorrt/BUILD.tpl b/third_party/tensorrt/BUILD.tpl
new file mode 100644
index 0000000000..feaeb0bea6
--- /dev/null
+++ b/third_party/tensorrt/BUILD.tpl
@@ -0,0 +1,67 @@
+# NVIDIA TensorRT
+# A high-performance deep learning inference optimizer and runtime.
+
+licenses(["notice"])
+
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts")
+
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+ name = "tensorrt_headers",
+ hdrs = [%{tensorrt_headers}],
+ includes = [
+ "include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "nv_infer",
+ srcs = [%{nv_infer}],
+ data = [%{nv_infer}],
+ includes = [
+ "include",
+ ],
+ copts= cuda_default_copts(),
+ deps = [
+ "@local_config_cuda//cuda:cuda",
+ ":tensorrt_headers",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "nv_infer_plugin",
+ srcs = [%{nv_infer_plugin}],
+ data = [%{nv_infer_plugin}],
+ includes = [
+ "include",
+ ],
+ copts= cuda_default_copts(),
+ deps = [
+ "@local_config_cuda//cuda:cuda",
+ ":nv_infer",
+ ":tensorrt_headers",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "nv_parsers",
+ srcs = [%{nv_parsers}],
+ data = [%{nv_parsers}],
+ includes = [
+ "include",
+ ],
+ copts= cuda_default_copts(),
+ deps = [
+ ":tensorrt_headers",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+%{tensorrt_genrules}
diff --git a/third_party/tensorrt/build_defs.bzl.tpl b/third_party/tensorrt/build_defs.bzl.tpl
new file mode 100644
index 0000000000..0dc3a7ba2d
--- /dev/null
+++ b/third_party/tensorrt/build_defs.bzl.tpl
@@ -0,0 +1,7 @@
+# Build configurations for TensorRT.
+
+def if_tensorrt(if_true, if_false=[]):
+ """Tests whether TensorRT was enabled during the configure process."""
+ if %{tensorrt_is_configured}:
+ return if_true
+ return if_false
diff --git a/third_party/tensorrt/tensorrt_configure.bzl b/third_party/tensorrt/tensorrt_configure.bzl
new file mode 100644
index 0000000000..8aa0f28f39
--- /dev/null
+++ b/third_party/tensorrt/tensorrt_configure.bzl
@@ -0,0 +1,224 @@
+# -*- Python -*-
+"""Repository rule for TensorRT configuration.
+
+`tensorrt_configure` depends on the following environment variables:
+
+ * `TF_TENSORRT_VERSION`: The TensorRT libnvinfer version.
+ * `TENSORRT_INSTALL_PATH`: The installation path of the TensorRT library.
+"""
+
+load(
+ "//third_party/gpus:cuda_configure.bzl",
+ "auto_configure_fail",
+ "get_cpu_value",
+ "find_cuda_define",
+ "matches_version",
+ "symlink_genrule_for_dir",
+)
+
+_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
+_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
+
+_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin", "nvparsers"]
+_TF_TENSORRT_HEADERS = [
+ "NvInfer.h", "NvInferPlugin.h", "NvCaffeParser.h", "NvUffParser.h",
+ "NvUtils.h"
+]
+
+_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
+_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
+_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"
+
+
+def _headers_exist(repository_ctx, path):
+ """Returns whether all TensorRT header files could be found in 'path'.
+
+ Args:
+ repository_ctx: The repository context.
+ path: The TensorRT include path to check.
+
+ Returns:
+ True if all TensorRT header files can be found in the path.
+ """
+ for h in _TF_TENSORRT_HEADERS:
+ if not repository_ctx.path("%s/%s" % (path, h)).exists:
+ return False
+ return True
+
+
+def _find_trt_header_dir(repository_ctx, trt_install_path):
+ """Returns the path to the directory containing headers of TensorRT.
+
+ Args:
+ repository_ctx: The repository context.
+ trt_install_path: The TensorRT library install directory.
+
+ Returns:
+ The path of the directory containing the TensorRT header.
+ """
+ if trt_install_path == "/usr/lib/x86_64-linux-gnu":
+ path = "/usr/include/x86_64-linux-gnu"
+ if _headers_exist(repository_ctx, path):
+ return path
+ path = str(repository_ctx.path("%s/../include" % trt_install_path).realpath)
+ if _headers_exist(repository_ctx, path):
+ return path
+ auto_configure_fail(
+ "Cannot find NvInfer.h with TensorRT install path %s" % trt_install_path)
+
+
+def _trt_lib_version(repository_ctx, trt_install_path):
+ """Detects the library (e.g. libnvinfer) version of TensorRT.
+
+ Args:
+ repository_ctx: The repository context.
+ trt_install_path: The TensorRT library install directory.
+
+ Returns:
+ A string containing the library version of TensorRT.
+ """
+ trt_header_dir = _find_trt_header_dir(repository_ctx, trt_install_path)
+ major_version = find_cuda_define(repository_ctx, trt_header_dir, "NvInfer.h",
+ _DEFINE_TENSORRT_SONAME_MAJOR)
+ minor_version = find_cuda_define(repository_ctx, trt_header_dir, "NvInfer.h",
+ _DEFINE_TENSORRT_SONAME_MINOR)
+ patch_version = find_cuda_define(repository_ctx, trt_header_dir, "NvInfer.h",
+ _DEFINE_TENSORRT_SONAME_PATCH)
+ full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
+ environ_version = repository_ctx.os.environ[_TF_TENSORRT_VERSION].strip()
+ if not matches_version(environ_version, full_version):
+ auto_configure_fail(
+ ("TensorRT library version detected from %s/%s (%s) does not match " +
+ "TF_TENSORRT_VERSION (%s). To fix this rerun configure again.") %
+ (trt_header_dir, "NvInfer.h", full_version, environ_version))
+ return environ_version
+
+
+def _find_trt_libs(repository_ctx, trt_install_path, trt_lib_version):
+ """Finds the given TensorRT library on the system.
+
+ Adapted from code contributed by Sami Kama (https://github.com/samikama).
+
+ Args:
+ repository_ctx: The repository context.
+ trt_install_path: The TensorRT library installation directory.
+ trt_lib_version: The version of TensorRT library files as returned
+ by _trt_lib_version.
+
+ Returns:
+ Map of library names to structs with the following fields:
+ src_file_path: The full path to the library found on the system.
+ dst_file_name: The basename of the target library.
+ """
+ objdump = repository_ctx.which("objdump")
+ result = {}
+ for lib in _TF_TENSORRT_LIBS:
+ dst_file_name = "lib%s.so.%s" % (lib, trt_lib_version)
+ src_file_path = repository_ctx.path("%s/%s" % (trt_install_path,
+ dst_file_name))
+ if not src_file_path.exists:
+ auto_configure_fail(
+ "Cannot find TensorRT library %s" % str(src_file_path))
+ if objdump != None:
+ objdump_out = repository_ctx.execute([objdump, "-p", str(src_file_path)])
+ for line in objdump_out.stdout.splitlines():
+ if "SONAME" in line:
+ dst_file_name = line.strip().split(" ")[-1]
+ result.update({
+ lib:
+ struct(
+ dst_file_name=dst_file_name,
+ src_file_path=str(src_file_path.realpath))
+ })
+ return result
+
+
+def _tpl(repository_ctx, tpl, substitutions):
+ repository_ctx.template(tpl, Label("//third_party/tensorrt:%s.tpl" % tpl),
+ substitutions)
+
+
+def _create_dummy_repository(repository_ctx):
+ """Create a dummy TensorRT repository."""
+ _tpl(repository_ctx, "build_defs.bzl", {"%{tensorrt_is_configured}": "False"})
+ substitutions = {
+ "%{tensorrt_genrules}": "",
+ "%{tensorrt_headers}": "",
+ }
+ for lib in _TF_TENSORRT_LIBS:
+ k = "%%{%s}" % lib.replace("nv", "nv_")
+ substitutions.update({k: ""})
+ _tpl(repository_ctx, "BUILD", substitutions)
+
+
+def _tensorrt_configure_impl(repository_ctx):
+ """Implementation of the tensorrt_configure repository rule."""
+ if _TENSORRT_INSTALL_PATH not in repository_ctx.os.environ:
+ _create_dummy_repository(repository_ctx)
+ return
+
+ if (get_cpu_value(repository_ctx) != "Linux"):
+ auto_configure_fail("TensorRT is supported only on Linux.")
+ if _TF_TENSORRT_VERSION not in repository_ctx.os.environ:
+ auto_configure_fail("TensorRT library (libnvinfer) version is not set.")
+ trt_install_path = repository_ctx.os.environ[_TENSORRT_INSTALL_PATH].strip()
+ if not repository_ctx.path(trt_install_path).exists:
+ auto_configure_fail(
+ "Cannot find TensorRT install path %s." % trt_install_path)
+
+ # Set up the symbolic links for the library files.
+ trt_lib_version = _trt_lib_version(repository_ctx, trt_install_path)
+ trt_libs = _find_trt_libs(repository_ctx, trt_install_path, trt_lib_version)
+ trt_lib_src = []
+ trt_lib_dest = []
+ for lib in trt_libs.values():
+ trt_lib_src.append(lib.src_file_path)
+ trt_lib_dest.append(lib.dst_file_name)
+ genrules = [
+ symlink_genrule_for_dir(repository_ctx, None, "tensorrt/lib/",
+ "tensorrt_lib", trt_lib_src, trt_lib_dest)
+ ]
+
+ # Set up the symbolic links for the header files.
+ trt_header_dir = _find_trt_header_dir(repository_ctx, trt_install_path)
+ src_files = [
+ "%s/%s" % (trt_header_dir, header) for header in _TF_TENSORRT_HEADERS
+ ]
+ dest_files = _TF_TENSORRT_HEADERS
+ genrules.append(
+ symlink_genrule_for_dir(repository_ctx, None, "tensorrt/include/",
+ "tensorrt_include", src_files, dest_files))
+
+ # Set up config file.
+ _tpl(repository_ctx, "build_defs.bzl", {"%{tensorrt_is_configured}": "True"})
+
+ # Set up BUILD file.
+ substitutions = {
+ "%{tensorrt_genrules}": "\n".join(genrules),
+ "%{tensorrt_headers}": '":tensorrt_include"',
+ }
+ for lib in _TF_TENSORRT_LIBS:
+ k = "%%{%s}" % lib.replace("nv", "nv_")
+ v = '"tensorrt/lib/%s"' % trt_libs[lib].dst_file_name
+ substitutions.update({k: v})
+ _tpl(repository_ctx, "BUILD", substitutions)
+
+
+tensorrt_configure = repository_rule(
+ implementation=_tensorrt_configure_impl,
+ environ=[
+ _TENSORRT_INSTALL_PATH,
+ _TF_TENSORRT_VERSION,
+ ],
+)
+"""Detects and configures the local CUDA toolchain.
+
+Add the following to your WORKSPACE FILE:
+
+```python
+tensorrt_configure(name = "local_config_tensorrt")
+```
+
+Args:
+ name: A unique name for this workspace rule.
+"""