author    Rasmus Munk Larsen <rmlarsen@google.com>  2018-04-06 14:24:42 -0700
committer GitHub <noreply@github.com>  2018-04-06 14:24:42 -0700
commit    8b5212011c7b67b7f8c2ea1b641aa0a7151c82d0
tree      0b748e2aae4c0e2de843a75adb6e443edbbbb3df
parent    08e4863ea6c7e75c85f097760216509f85081916
Branch 191925087 (#18299)
* Fix docstring. PiperOrigin-RevId: 191747417
* Use constants in tf.zeros if the constant won't be too big. Using fill saves on GraphDef size, but can slow down models since the total number of ops is greater (fill + shape + constant op). This change makes us only use fill for large shapes. (See the first sketch after this message.) PiperOrigin-RevId: 191747456
* Fix typos in "Profile Model Float Operations" documentation. PiperOrigin-RevId: 191751175
* Added a call in CheckpointSaverHook.after_create_session to always save a checkpoint before the first training step. PiperOrigin-RevId: 191753026
* Document the expected regular structure of the statistical testing library. PiperOrigin-RevId: 191753693
* Refine BatchReshape error messages. PiperOrigin-RevId: 191754120
* Include the operators module in the test framework as well. PiperOrigin-RevId: 191756100
* Expand activity analysis to the test nodes of if and while statements. PiperOrigin-RevId: 191756234
* Inline more functions. PiperOrigin-RevId: 191761109
* Sync only the convolutional_recurrent file to Keras 2.1.5. PiperOrigin-RevId: 191763101
* Internal change. PiperOrigin-RevId: 191769724
* Expose odeint_fixed in tf.contrib.integrate. PiperOrigin-RevId: 191769890
* Automated g4 rollback of changelist 191761109. PiperOrigin-RevId: 191771969
* Fix final eval bottleneck creation to work in cases where it isn't cached already. Fixes #17423. PiperOrigin-RevId: 191773001
* Fix regression caused by cl/191020868: re-use materialized shapes for other broadcast gradient shape nodes. PiperOrigin-RevId: 191779263
* Save the original from_proto method before calling it, to avoid an infinite loop. PiperOrigin-RevId: 191784430
* Automated g4 rollback of changelist 191753026. PiperOrigin-RevId: 191784709
* [XLA] Remove a dead function and a stale TODO. PiperOrigin-RevId: 191786563
* Enable branch prediction in TensorFlow. PiperOrigin-RevId: 191788253
* Change the loss_reduction default to SUM_OVER_BATCH_SIZE for multi_class_head and binary_classification_head. PiperOrigin-RevId: 191793392
* Quantized LSTM support improvements. PiperOrigin-RevId: 191794956
* Fix TF_ImportGraphDefResults and TF_Function leaks in the Python API. PiperOrigin-RevId: 191797853
* [XLA] Better support for mul reductions in MakeFakeArguments(): mul reductions want a 1 as their init value, not a 0 or a random value. PiperOrigin-RevId: 191802819
* Disable tests that are currently failing with CUDA 9. PiperOrigin-RevId: 191805453
* Make tf.contrib.estimator.add_metrics work with warm-starting. PiperOrigin-RevId: 191805682
* Add a Raspberry Pi section and link to the GitHub build instructions. PiperOrigin-RevId: 191807862
* Add for and while loops to the list of operators. Do not use them yet. PiperOrigin-RevId: 191807973
* [TF:XLA] No need to set the return value in the while loop's condition. PiperOrigin-RevId: 191809110
* Add functions to extract the basic symbols on which a composite name relies. This in turn allows statically obtaining a block's syntactic closure. PiperOrigin-RevId: 191809965
* Add a link for the index file in the performance tab. PiperOrigin-RevId: 191811610
* Added an option to inline all functions in aggressive mode. PiperOrigin-RevId: 191819577
* Make the concat handler support mixed-range input. PiperOrigin-RevId: 191822664
* Automated g4 rollback of changelist 191605505. PiperOrigin-RevId: 191824447
* Add a command-line parameter to toco to change the way toco rescales input and output tensors. PiperOrigin-RevId: 191825756
* Refactor and add a proto field required by POD support. PiperOrigin-RevId: 191826636
* Lazily evaluate shapes with the C API enabled. This change makes it so shapes are computed only when requested with _USE_C_API = True. Note that the C API will still raise a shape error if necessary when the op is created. In addition, it cleans up the logic for _USE_C_SHAPES = True: in this case, we lazily fetch and cache shapes directly from the C API, and no longer need set_shapes_for_outputs at all. PiperOrigin-RevId: 191830565
* [XLA] Don't call Literal::Get in HloEvaluator's convolution loop. This speeds up the conv implementation because Literal::Get calls Literal::Piece::data, which is relatively slow; instead, we call Literal::Data() once and cache the result. Before: ConvolutionTest/0.StridedFilter (59094 ms). After: ConvolutionTest/0.StridedFilter (41812 ms). Speedup: 59/42 = 1.4x. PiperOrigin-RevId: 191830741
* Added a `drop_final_batch` argument to make_batched_features_dataset. This allows the batch_and_drop_remainder function to be used instead of the default batch function. PiperOrigin-RevId: 191831842
* Add RunMetadata logging to tf.train.ProfilerHook for TensorBoard memory/CPU usage visualization. PiperOrigin-RevId: 191832832
* [XLA] Don't call MultidimensionalIndexToLinearIndex in HloEvaluator's convolution routine. Before: ConvolutionTest/0.StridedFilter (41812 ms). After: ConvolutionTest/0.StridedFilter (28054 ms). Speedup: 42/28 = 1.5x. PiperOrigin-RevId: 191835735
* Expose the adaptive sampling option for SDCA and shuffle the data when adaptive sampling is off. PiperOrigin-RevId: 191836004
* Swap in the new implementation of while and for loops. PiperOrigin-RevId: 191838806
* Upgrade libpng. PiperOrigin-RevId: 191840652
* Fix a StringPiece use-after-free in MasterSession::ReffedClientGraph: use the owned ClientGraph as the source for the node_to_name_ map, rather than the borrowed GraphExecutionState (which can be deleted while the ReffedClientGraph is in use). PiperOrigin-RevId: 191847023
* Add a test to check graceful handling of out-of-memory conditions. PiperOrigin-RevId: 191860462
* Internal change. PiperOrigin-RevId: 191869400
* Fix typos in XlaCompilationCache. PiperOrigin-RevId: 191881135
* Define a PRNG seeding style for new code in Distributions and TF Probability, with rationales, and implement a lightweight PRNG for seed generation in that style. This enables incremental refactoring of existing code into the new style. (See the second sketch after this message.) PiperOrigin-RevId: 191884573
* Avoid marking clusters containing only Identity ops for compilation, which would produce clusters where XLA cannot optimize anything. PiperOrigin-RevId: 191887414
* Add a description to the LPIRC 2018 competition benchmarker. PiperOrigin-RevId: 191889484
* The training model need not be built when the KFAC optimizer is initialized, so self._variables will be an empty list; pass a function that returns the list of trainable variables to the estimator instead. PiperOrigin-RevId: 191893084
* Fix up support for the case where a given array name occurs multiple times in the inputs/outputs list of an op. The (non-essential) computation of the optimal workspace size had not been updated for that case, causing it to fail on a simple test case. Moreover, the initial implementation had some redundant usage of std::find that this CL moves into a shared helper function. PiperOrigin-RevId: 191894081
* Support overriding device filters for gRPC by overriding the requests with the default session config. PiperOrigin-RevId: 191895856
* Tweaked docstrings in LayerCollection. PiperOrigin-RevId: 191897098
* [TPUClusterResolver] Start a TFServer when running in GKE. This change allows advanced input pipelines (e.g. StreamingFilesDataset, or split pipelines that use py_funcs) to run in GKE and GKE-like environments. PiperOrigin-RevId: 191897639
* [tf.data] Enable using `tf.contrib.data.prefetch_to_device()` in eager mode. The added functionality is a substitute for the implicit prefetching in `tfe.Iterator`, and the two paths will converge in a future change. Fixes #18260. PiperOrigin-RevId: 191897666
* Materialize tensor array sizes whenever possible. PiperOrigin-RevId: 191900015
* Object-based checkpointing support for unidirectional cuDNN LSTM cells. Once checked in, this will be the only way I know of to save canonical weights when executing eagerly; eager's name-based saving support will only do the opaque parameter buffer. I'm not going to try converting everything in one go, but it's a start at least, and everything else should raise a NotImplementedError rather than silently not saving correctly. Single-layer cuDNN cells can be swapped for unwrapped cuDNN-compatible cells or single cells wrapped in MultiRNNCells; multi-layer cells need MultiRNNCell wrapping. PiperOrigin-RevId: 191905703
* Allow TFE_NewContext to fail more reasonably when SWIG is checking status. Before: TFE_Context would check nullptr, and the function would fail straight away. Now: TFE_Context is nullptr, so it skips down to checking the status, and an error is raised. I'm not able to find in the SWIG documentation how to order typemaps in the generated code; ideally, I'd order it to check the status typemap first. This code makes it not dependent on the ordering either way. PiperOrigin-RevId: 191905893
* Change GetInstructionCallContext to take an opcode instead of an HloInstruction. This enables use of the function without an actual instruction (e.g., if you just have an HloProto). PiperOrigin-RevId: 191905914
* The TPU cost estimator has been modified to also account for the memory cost in the execution time. Until more sophisticated methods are added, we resort to the roofline model to calculate that cost. PiperOrigin-RevId: 191913626
* Properly handle callable objects. PiperOrigin-RevId: 191913834
* Minor doc clarification for the reduce_sum return type. PiperOrigin-RevId: 191914398
* Added a headers-only version of tensorflow/core/kernels:cwise_lib, cwise_lib_hdrs. This is for clients that want to use the cwise_ops machinery when making their own custom ops; including cwise_lib directly causes multiple-definition linker errors. PiperOrigin-RevId: 191914445
* [TF:XLA] Create a despecializing pass pipeline. When comparing backends, it is useful to take an HLO optimized for one backend and perform transformations in order to match numerics. This can be thought of as finding a lowest common denominator. Move this grouping of passes into its own HloPassPipeline that can be reused in a few different places. PiperOrigin-RevId: 191914799
* Update tf.keras to Keras 2.1.5. PiperOrigin-RevId: 191914904
* Remove `TF_InitializeTPU` and `TF_ShutdownTPU` from the experimental C API as they are no longer needed. Also remove a duplicate function declaration. PiperOrigin-RevId: 191918408
* Fix a small performance regression in microbenchmarks. PiperOrigin-RevId: 191919464
* Support RNN profiling in StreamExecutor for CUDA GPUs. This change does not apply autotuning to TF cuDNN kernels; it only provides lower-level support. PiperOrigin-RevId: 191919566
* Validate errorReporter and improve the documentation on it. PiperOrigin-RevId: 191920009
* Fix a few bugs in ArithmeticOptimizer and make it robust to failures of shape inference. PiperOrigin-RevId: 191922788
* Update the rewriter options with the optimizer options. PiperOrigin-RevId: 191923287
* Pull changes from prefetching_ops to support dicts in prefetching_ops_v2 in distribute, and update the estimator test to use prefetching. Also update the README to reflect the support of dictionaries. PiperOrigin-RevId: 191924990
* Replaced calls to deprecated tensorflow::StringPiece methods with their tensorflow::str_util equivalents. This will allow the deprecated methods to be removed. PiperOrigin-RevId: 191925087
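To make the tf.zeros change above concrete, here is a minimal sketch of the size-based heuristic, assuming a hypothetical cutoff (_MAX_CONST_ELEMENTS) and helper name; neither is taken from the commit:

import numpy as np
import tensorflow as tf

_MAX_CONST_ELEMENTS = 1000  # hypothetical cutoff, not the commit's value

def zeros_sketch(shape, dtype=tf.float32):
  """Emit a single Const op for small shapes, Fill for large ones."""
  if np.prod(shape) <= _MAX_CONST_ELEMENTS:
    # Small tensor: one Const op (bigger GraphDef, fewer ops at runtime).
    return tf.constant(0, dtype=dtype, shape=shape)
  # Large tensor: Fill keeps the GraphDef small at the cost of extra ops
  # (shape const + scalar const + Fill).
  return tf.fill(shape, tf.constant(0, dtype=dtype))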
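The PRNG seeding style above is backed by the new seed_stream.py (see the file list below). As a rough illustration of the idea, deriving a reproducible stream of distinct seeds from one seed/salt pair, here is a hash-based sketch; it is not the actual implementation:

import hashlib

class SeedStreamSketch(object):
  """Yields a fresh, reproducible seed on each call."""

  def __init__(self, seed, salt):
    self._seed = seed
    self._salt = salt  # distinct salts give independent streams
    self._counter = 0

  def __call__(self):
    # Hash (seed, salt, counter) so each call returns a new seed without
    # perturbing any other op's seed stream.
    self._counter += 1
    msg = "{}_{}_{}".format(self._seed, self._salt, self._counter)
    digest = hashlib.sha256(msg.encode("utf-8")).hexdigest()
    return int(digest, 16) % (2 ** 31)

stream = SeedStreamSketch(seed=42, salt="my_distribution")
seed_a, seed_b = stream(), stream()  # two distinct, reproducible seeds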
-rw-r--r--  tensorflow/c/c_api_experimental.cc | 51
-rw-r--r--  tensorflow/c/c_api_experimental.h | 21
-rw-r--r--  tensorflow/cc/profiler/BUILD | 3
-rw-r--r--  tensorflow/compiler/jit/BUILD | 4
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc | 5
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass.cc | 19
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 34
-rw-r--r--  tensorflow/compiler/jit/xla_compilation_cache.h | 6
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.cc | 34
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.h | 7
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.cc | 34
-rw-r--r--  tensorflow/compiler/tests/BUILD | 20
-rw-r--r--  tensorflow/compiler/tests/jit_test.py | 36
-rw-r--r--  tensorflow/compiler/tests/oom_test.py | 61
-rw-r--r--  tensorflow/compiler/tf2xla/lib/while_loop.cc | 1
-rw-r--r--  tensorflow/compiler/xla/executable_run_options.cc | 7
-rw-r--r--  tensorflow/compiler/xla/executable_run_options.h | 4
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 15
-rw-r--r--  tensorflow/compiler/xla/service/call_graph.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/call_graph.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 23
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_runtime.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_runtime.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc | 84
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc | 23
-rw-r--r--  tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc | 128
-rw-r--r--  tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h | 84
-rw-r--r--  tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/despecializer.cc | 35
-rw-r--r--  tensorflow/compiler/xla/service/despecializer.h | 45
-rw-r--r--  tensorflow/compiler/xla/service/flatten_call_graph.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc | 71
-rw-r--r--  tensorflow/compiler/xla/service/service.h | 4
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc | 68
-rw-r--r--  tensorflow/compiler/xla/xla.proto | 3
-rw-r--r--  tensorflow/contrib/autograph/converters/BUILD | 12
-rw-r--r--  tensorflow/contrib/autograph/converters/break_statements.py | 7
-rw-r--r--  tensorflow/contrib/autograph/converters/control_flow.py | 76
-rw-r--r--  tensorflow/contrib/autograph/converters/control_flow_test.py | 72
-rw-r--r--  tensorflow/contrib/autograph/converters/converter_test_base.py | 2
-rw-r--r--  tensorflow/contrib/autograph/converters/for_loops.py | 92
-rw-r--r--  tensorflow/contrib/autograph/converters/for_loops_test.py | 70
-rw-r--r--  tensorflow/contrib/autograph/impl/api_test.py | 11
-rw-r--r--  tensorflow/contrib/autograph/impl/conversion.py | 3
-rw-r--r--  tensorflow/contrib/autograph/operators/BUILD | 17
-rw-r--r--  tensorflow/contrib/autograph/operators/__init__.py | 5
-rw-r--r--  tensorflow/contrib/autograph/operators/control_flow.py | 179
-rw-r--r--  tensorflow/contrib/autograph/operators/control_flow_test.py | 82
-rw-r--r--  tensorflow/contrib/autograph/pyct/ast_util.py | 2
-rw-r--r--  tensorflow/contrib/autograph/pyct/inspect_utils.py | 6
-rw-r--r--  tensorflow/contrib/autograph/pyct/inspect_utils_test.py | 9
-rw-r--r--  tensorflow/contrib/autograph/pyct/qual_names.py | 23
-rw-r--r--  tensorflow/contrib/autograph/pyct/qual_names_test.py | 15
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/activity.py | 18
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py | 2
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/annos.py | 1
-rw-r--r--  tensorflow/contrib/autograph/utils/__init__.py | 3
-rw-r--r--  tensorflow/contrib/autograph/utils/builtins.py | 68
-rw-r--r--  tensorflow/contrib/autograph/utils/multiple_dispatch.py | 41
-rw-r--r--  tensorflow/contrib/autograph/utils/multiple_dispatch_test.py | 23
-rw-r--r--  tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py | 75
-rw-r--r--  tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py | 8
-rw-r--r--  tensorflow/contrib/cudnn_rnn/BUILD | 1
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py | 151
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py | 20
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 75
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py | 38
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py | 20
-rw-r--r--  tensorflow/contrib/data/python/ops/prefetching_ops.py | 114
-rw-r--r--  tensorflow/contrib/data/python/ops/readers.py | 11
-rw-r--r--  tensorflow/contrib/distribute/README.md | 6
-rw-r--r--  tensorflow/contrib/distribute/python/estimator_integration_test.py | 2
-rw-r--r--  tensorflow/contrib/distribute/python/prefetching_ops_v2.py | 4
-rw-r--r--  tensorflow/contrib/distributions/BUILD | 10
-rw-r--r--  tensorflow/contrib/distributions/__init__.py | 2
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py | 4
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py | 70
-rw-r--r--  tensorflow/contrib/distributions/python/ops/batch_reshape.py | 5
-rw-r--r--  tensorflow/contrib/distributions/python/ops/seed_stream.py | 228
-rw-r--r--  tensorflow/contrib/distributions/python/ops/statistical_testing.py | 111
-rw-r--r--  tensorflow/contrib/eager/python/datasets.py | 7
-rw-r--r--  tensorflow/contrib/eager/python/datasets_test.py | 13
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/extenders.py | 5
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/head.py | 11
-rw-r--r--  tensorflow/contrib/integrate/__init__.py | 1
-rw-r--r--  tensorflow/contrib/kfac/python/ops/estimator.py | 11
-rw-r--r--  tensorflow/contrib/kfac/python/ops/layer_collection.py | 123
-rw-r--r--  tensorflow/contrib/kfac/python/ops/optimizer.py | 10
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py | 54
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py | 7
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py | 9
-rw-r--r--  tensorflow/contrib/lite/java/BUILD | 39
-rw-r--r--  tensorflow/contrib/lite/java/ovic/README.md | 83
-rw-r--r--  tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java | 35
-rw-r--r--  tensorflow/contrib/lite/kernels/concatenation.cc | 22
-rw-r--r--  tensorflow/contrib/lite/kernels/concatenation_test.cc | 68
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 56
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 55
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor.h | 23
-rw-r--r--  tensorflow/contrib/lite/model.cc | 26
-rw-r--r--  tensorflow/contrib/lite/model.h | 26
-rw-r--r--  tensorflow/contrib/lite/toco/allocate_transient_arrays.cc | 36
-rw-r--r--  tensorflow/contrib/lite/toco/args.h | 1
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc | 4
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc | 45
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc | 14
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/quantize.cc | 3
-rw-r--r--  tensorflow/contrib/lite/toco/model_cmdline_flags.cc | 14
-rw-r--r--  tensorflow/contrib/lite/toco/model_flags.proto | 6
-rw-r--r--  tensorflow/contrib/lite/toco/toco_tooling.cc | 26
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc | 23
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.h | 2
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py | 8
-rw-r--r--  tensorflow/contrib/testing/python/framework/fake_summary_writer.py | 6
-rw-r--r--  tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc | 68
-rw-r--r--  tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc | 75
-rw-r--r--  tensorflow/contrib/tpu/profiler/dump_tpu_profile.h | 1
-rw-r--r--  tensorflow/contrib/tpu/profiler/tpu_profiler.proto | 22
-rw-r--r--  tensorflow/core/distributed_runtime/master_session.cc | 24
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc | 26
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_master_service.h | 7
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc | 3
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 15
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator.h | 5
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 48
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.h | 2
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc | 53
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc | 42
-rw-r--r--  tensorflow/core/grappler/optimizers/function_optimizer.cc | 24
-rw-r--r--  tensorflow/core/grappler/optimizers/function_optimizer.h | 5
-rw-r--r--  tensorflow/core/grappler/optimizers/function_optimizer_test.cc | 8
-rw-r--r--  tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc | 4
-rw-r--r--  tensorflow/core/kernels/BUILD | 11
-rw-r--r--  tensorflow/core/kernels/crop_and_resize_op_test.cc | 9
-rw-r--r--  tensorflow/core/kernels/cudnn_rnn_ops.cc | 46
-rw-r--r--  tensorflow/core/kernels/decode_image_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/dynamic_partition_op_test.cc | 5
-rw-r--r--  tensorflow/core/kernels/dynamic_stitch_op_test.cc | 27
-rw-r--r--  tensorflow/core/kernels/gather_op_test.cc | 3
-rw-r--r--  tensorflow/core/kernels/non_max_suppression_op_test.cc | 9
-rw-r--r--  tensorflow/core/kernels/quantize_and_dequantize_op_test.cc | 9
-rw-r--r--  tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc | 3
-rw-r--r--  tensorflow/core/kernels/resize_bicubic_op_test.cc | 6
-rw-r--r--  tensorflow/core/kernels/resize_bilinear_op_test.cc | 18
-rw-r--r--  tensorflow/core/kernels/roll_op_test.cc | 18
-rw-r--r--  tensorflow/core/kernels/scatter_nd_op_test.cc | 27
-rw-r--r--  tensorflow/core/kernels/scatter_op_test.cc | 18
-rw-r--r--  tensorflow/core/kernels/sdca_internal.cc | 5
-rw-r--r--  tensorflow/core/kernels/sdca_internal.h | 7
-rw-r--r--  tensorflow/core/kernels/sdca_ops.cc | 6
-rw-r--r--  tensorflow/core/kernels/shape_op_test.cc | 5
-rw-r--r--  tensorflow/core/kernels/softmax_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/softmax_op_gpu.cu.cc | 3
-rw-r--r--  tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc | 3
-rw-r--r--  tensorflow/core/kernels/summary_op_test.cc | 13
-rw-r--r--  tensorflow/core/platform/macros.h | 17
-rw-r--r--  tensorflow/core/profiler/g3doc/profile_model_architecture.md | 32
-rw-r--r--  tensorflow/docs_src/mobile/tflite/devguide.md | 9
-rw-r--r--  tensorflow/docs_src/performance/leftnav_files | 1
-rw-r--r--  tensorflow/examples/image_retraining/retrain.py | 7
-rw-r--r--  tensorflow/python/client/tf_session_helper.cc | 9
-rw-r--r--  tensorflow/python/client/tf_session_helper.h | 7
-rw-r--r--  tensorflow/python/eager/benchmarks_test.py | 23
-rw-r--r--  tensorflow/python/eager/function.py | 4
-rw-r--r--  tensorflow/python/eager/graph_callable.py | 2
-rw-r--r--  tensorflow/python/framework/c_api_util.py | 26
-rw-r--r--  tensorflow/python/framework/function.py | 10
-rw-r--r--  tensorflow/python/framework/function_test.py | 32
-rw-r--r--  tensorflow/python/framework/importer.py | 14
-rw-r--r--  tensorflow/python/framework/ops.py | 182
-rw-r--r--  tensorflow/python/framework/tensor_util.py | 3
-rw-r--r--  tensorflow/python/keras/_impl/keras/applications/mobilenet.py | 222
-rw-r--r--  tensorflow/python/keras/_impl/keras/applications/resnet50.py | 5
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/convolutional.py | 195
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py | 1222
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py | 1
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/convolutional_test.py | 38
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/recurrent.py | 137
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/recurrent_test.py | 16
-rw-r--r--  tensorflow/python/keras/layers/__init__.py | 1
-rw-r--r--  tensorflow/python/ops/array_ops.py | 23
-rw-r--r--  tensorflow/python/ops/math_ops.py | 5
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py | 7
-rw-r--r--  tensorflow/python/pywrap_tfe.i | 6
-rw-r--r--  tensorflow/python/training/basic_session_run_hooks.py | 3
-rw-r--r--  tensorflow/python/training/basic_session_run_hooks_test.py | 13
-rw-r--r--  tensorflow/python/training/distribute.py | 5
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_blas.cc | 18
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.cc | 216
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.h | 32
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_timer.h | 7
-rw-r--r--  tensorflow/stream_executor/dnn.cc | 4
-rw-r--r--  tensorflow/stream_executor/dnn.h | 22
-rw-r--r--  tensorflow/stream_executor/stream.cc | 36
-rw-r--r--  tensorflow/stream_executor/stream.h | 18
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.cc | 14
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.h | 11
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt | 114
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt | 187
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt | 2
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt | 6
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt | 4
-rw-r--r--  tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh | 12
-rw-r--r--  tensorflow/workspace.bzl | 8
-rw-r--r--  third_party/png.BUILD | 12
206 files changed, 5394 insertions, 1928 deletions
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index bea9378571..e82a546092 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -56,57 +56,6 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
}
}
-void TF_InitializeTPU(TF_Session* session, TF_Status* status) {
- VLOG(1) << "Initializing TPU";
- TF_Operation* config_op =
- TF_GraphOperationByName(session->graph, "ConfigureDistributedTPU");
- if (config_op == nullptr) {
- status->status = tensorflow::errors::Internal(
- "Unable to find node ConfigureDistributedTPU in the TF graph.");
- return;
- }
-
- TF_Output config_node{config_op, 0};
-
- TF_Tensor* dummy_output;
- TF_SessionRun(session, /*run_options*/ nullptr,
- // input related parameters
- /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
- // output related parameters
- /*outputs*/ &config_node, /*output_values*/ &dummy_output,
- /*noutputs*/ 1,
- /*targets*/ nullptr, /*ntargets*/ 0,
- /*run_metadata*/ nullptr, status);
- if (status->status.ok()) {
- TF_DeleteTensor(dummy_output);
- }
-}
-
-void TF_ShutdownTPU(TF_Session* session, TF_Status* status) {
- {
- tensorflow::mutex_lock c(session->graph->mu);
- VLOG(1) << "Shutting down TPU, with input graph: "
- << session->graph->graph.ToGraphDefDebug().DebugString();
- }
-
- TF_Operation* shutdown_op =
- TF_GraphOperationByName(session->graph, "ShutdownDistributedTPU");
- if (shutdown_op == nullptr) {
- status->status = tensorflow::errors::Internal(
- "Unable to find node ShutdownDistributedTPU in the TF graph.");
- return;
- }
-
- TF_SessionRun(session, /*run_options*/ nullptr,
- // input related parameters
- /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
- // output related parameters
- /*outputs*/ nullptr, /*output_values*/ nullptr,
- /*noutputs*/ 0,
- /*targets*/ &shutdown_op, /*ntargets*/ 1,
- /*run_metadata*/ nullptr, status);
-}
-
const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
tensorflow::mutex_lock c(graph->mu);
const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index ebcec8176b..666342974e 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -60,27 +60,6 @@ extern "C" {
TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options,
unsigned char enable);
-// Initializes TPU system. Must be called exactly once before TF_SessionRun() is
-// called on a TPU graph.
-//
-// The session graph must contain a node named ConfigureDistributedTPU.
-// TODO(b/74774824): Improve the API on initializing TPU system.
-TF_CAPI_EXPORT extern void TF_InitializeTPU(TF_Session* session,
- TF_Status* status);
-
-// Shuts down TPU system. For any `session` where TF_InitializeTPU() has
-// been successfully called, this call must be made exactly once before the
-// session is closed.
-// The session graph must contain a node named ShutdownDistributedTPU.
-TF_CAPI_EXPORT extern void TF_ShutdownTPU(TF_Session* session,
- TF_Status* status);
-
-// Returns the graph content in a human-readable format, with length set in
-// `len`. The format is subject to change in the future.
-// The returned string is heap-allocated, and caller should call free() on it.
-TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph,
- size_t* len);
-
// Returns the graph content in a human-readable format, with length set in
// `len`. The format is subject to change in the future.
// The returned string is heap-allocated, and caller should call free() on it.
diff --git a/tensorflow/cc/profiler/BUILD b/tensorflow/cc/profiler/BUILD
index 00799526fc..cf65fe1ab9 100644
--- a/tensorflow/cc/profiler/BUILD
+++ b/tensorflow/cc/profiler/BUILD
@@ -9,6 +9,9 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
tf_cuda_cc_test(
name = "profiler_test",
srcs = ["profiler_test.cc"],
+ tags = [
+ "noguitar", # b/77649654
+ ],
deps = [
":profiler",
"//tensorflow/cc:cc_ops",
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 24aa203c00..a492fc6b9b 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -204,14 +204,14 @@ cc_library(
":common",
":xla_compilation_cache",
":xla_tensor",
+ "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
- "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
+ "//tensorflow/core:gpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 2d6511a45b..f48941fce3 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -155,6 +155,9 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
options.device_allocator = xla_allocator;
+ // TODO(b/77671268): We don't set variable_representation_shape_fn here. This
+ // is restricted to Variables, but we need something like this to apply to
+ // normal Tensors too.
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
@@ -179,8 +182,10 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
run_options.set_stream(stream);
run_options.set_allocator(xla_allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+ run_options.set_rng_seed(ctx->step_id());
Env* env = Env::Default();
auto start_time = env->NowMicros();
+
auto run_result = executable->Run(launch_context.arguments(), run_options);
OP_REQUIRES(ctx, run_result.ok(), run_result.status());
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 3b631d6f4e..386240ff8d 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -732,11 +732,15 @@ Status MarkForCompilationPass::RunImpl(
}
}
- // Count the number of elements in each cluster.
- std::vector<int> cluster_sizes(graph->num_node_ids());
+ // Count the number of non-trivial elements in each cluster.
+ std::vector<int> effective_cluster_sizes(graph->num_node_ids());
for (const Node* n : compilation_candidates) {
int cluster = clusters[n->id()].Get().representative;
- cluster_sizes[cluster]++;
+ // Identity nodes will be removed if the node gets marked for compilation.
+ // Therefore we don't want to count them towards the effective cluster size.
+ if (n->def().op() != "Identity") {
+ effective_cluster_sizes[cluster]++;
+ }
}
// Names for each cluster.
@@ -769,9 +773,12 @@ Status MarkForCompilationPass::RunImpl(
const XlaOpRegistry::DeviceRegistration* registration;
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration);
- // Or compile if this is a cluster of >= min_cluster_size compilable
- // operators.
- if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation ||
+ // Compile if this is a cluster of >= min_cluster_size compilable operators.
+ // Also, always compile if the operator is placed on a device that requires
+ // compilation, or if it contains at least one op that is marked for
+ // compilation that is not an Identity op.
+ if (effective_cluster_sizes[cluster] >= min_cluster_size ||
+ (effective_cluster_sizes[cluster] > 0 && marked_for_compilation) ||
registration->requires_compilation) {
string& name = cluster_names[cluster];
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 2e362e0a63..80edaf28b8 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -16,7 +16,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -575,5 +577,37 @@ TEST(XlaCompilationTest, Retval) {
EXPECT_EQ(clusters["A"], clusters["B"]);
}
+TEST(XlaCompilationTest, DontCountIdentityOps) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ Scope root = Scope::NewRootScope().ExitOnError();
+ {
+ auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
+ auto b = ops::Identity(root.WithOpName("B"), a);
+ auto c = ops::Identity(root.WithOpName("C"), b);
+ auto r = ops::_Retval(root.WithOpName("R"), c, 0);
+ }
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+ auto clusters = GetClusters(*graph);
+
+ EXPECT_TRUE(clusters.empty());
+}
+
+TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ Scope root = Scope::NewRootScope().ExitOnError();
+ {
+ auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
+ auto b = ops::Identity(root.WithOpName("B"), a);
+ b.node()->AddAttr(kXlaCompileAttr, true);
+ auto r = ops::_Retval(root.WithOpName("R"), b, 0);
+ }
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+ auto clusters = GetClusters(*graph);
+
+ EXPECT_TRUE(clusters.empty());
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index 5c0c79b880..be1043d8c3 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -52,13 +52,14 @@ class XlaCompilationCache : public ResourceBase {
// Compiles a function into a XlaCompiler::CompilationResult that can be used
// to execute an XLA Computation. Compilation results are cached.
// `function` is the name of a Tensorflow function to compile.
- // `constant_args` is a maps of tensorflow argument number to constant value.
+ // `constant_args` is a map of tensorflow argument number to its constant
+ // value.
// `variable_args` is a snapshot of the current values of the
// resource variable arguments to `function`; uninitialized variables are
// represented by an absent OptionalTensor.
// The result of compilation is written to `*compilation_result`, which must
// be non-null. If `executable` is non-null, also builds an
- // xla::LocalExecutable and sets `executable to point to it. The resulting
+ // xla::LocalExecutable and sets `executable` to point to it. The resulting
// executable pointer may be null if the computation has no non-constant
// outputs.
Status Compile(const XlaCompiler::Options& options,
@@ -96,6 +97,7 @@ class XlaCompilationCache : public ResourceBase {
xla::LocalExecutable** executable,
const XlaCompiler::CompileOptions* compile_options,
bool compile_single_op);
+
// Takes `result` which has been compiled from a Tensorflow subgraph to a
// XLA computation already, and generates an XLA LocalExecutable `executable`.
Status BuildExecutable(const XlaCompiler::Options& options,
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 6a57831cde..43eb164012 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/platform/mem.h"
@@ -53,8 +54,33 @@ XlaTransferManager::XlaTransferManager(se::Stream* stream,
bool transfer_as_literal)
: stream_(stream),
client_(client),
+ transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal) {}
+Status XlaTransferManager::TransferLiteralToDevice(
+ const Tensor& host_tensor, Tensor* device_tensor) const {
+ xla::Literal literal;
+ TF_RETURN_IF_ERROR(HostTensorToLiteral(host_tensor, &literal));
+ VLOG(1) << "Transfer to device as literal: " << literal.ToString();
+
+ const xla::ShapedBuffer& shaped_buffer =
+ XlaTensor::FromTensor(device_tensor)->shaped_buffer();
+ return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal,
+ shaped_buffer);
+}
+
+Status XlaTransferManager::TransferLiteralFromDevice(
+ Tensor* host_tensor, const Tensor& device_tensor) const {
+ const xla::ShapedBuffer& shaped_buffer =
+ XlaTensor::FromTensor(&device_tensor)->shaped_buffer();
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
+ transfer_manager_->TransferLiteralFromDevice(
+ stream_->parent(), shaped_buffer));
+ VLOG(1) << "Transfer from device as literal: " << literal->ToString();
+ return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor);
+}
+
void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
Tensor* device_tensor,
@@ -86,9 +112,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
Status status;
if (transfer_as_literal_) {
- status = xla::Unimplemented(
- "XlaTransferManager::CopyCPUTensorToDevice not implemented for "
- "literals");
+ status = TransferLiteralToDevice(*cpu_tensor, device_tensor);
} else {
stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
@@ -129,9 +153,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
Status status;
if (transfer_as_literal_) {
- status = xla::Unimplemented(
- "XlaTransferManager::CopyDeviceTensorToCPU not implemented for "
- "literals");
+ status = TransferLiteralFromDevice(cpu_tensor, *device_tensor);
} else {
stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index a8ad511fbd..ad914a1c23 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -57,11 +57,18 @@ class XlaTransferManager {
perftools::gputools::Stream* stream() const { return stream_; }
private:
+ Status TransferLiteralToDevice(const Tensor& host_tensor,
+ Tensor* device_tensor) const;
+ Status TransferLiteralFromDevice(Tensor* host_tensor,
+ const Tensor& device_tensor) const;
+
// Stream obtained from a Device, used to transfer tensors between
// CPU and device.
perftools::gputools::Stream* stream_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
+ // Transfer manager, for marshalling data to and from the device.
+ xla::TransferManager* transfer_manager_;
// True if we must use XLA's TransferManager for correct device transfers.
bool transfer_as_literal_;
};
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 354be1e1b5..50b0061d69 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -16,12 +16,14 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
@@ -165,6 +167,8 @@ void XlaComputationLaunchContext::PopulateOutputs(
// Computation output should always be a tuple.
if (VLOG_IS_ON(2)) {
VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString();
+ VLOG(2) << "Result tuple shape (on device): "
+ << output->on_device_shape().DebugString();
}
CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
@@ -179,6 +183,10 @@ void XlaComputationLaunchContext::PopulateOutputs(
const size_t total_bytes = const_tensor.TotalBytes();
if (stream && total_bytes > 0) {
// Copy host -> device. (Empty tensors don't have backing buffers.)
+ // Manually allocate memory using an XlaTensorBuffer so we can allocate
+ // as much memory as the device requires (as given by
+ // GetByteSizeRequirement). This avoids XlaTransferManager having to
+ // reallocate the device buffer later.
VLOG(1) << "Constant output tensor on device";
OP_REQUIRES_OK(
@@ -189,15 +197,23 @@ void XlaComputationLaunchContext::PopulateOutputs(
client_, stream->parent()->device_ordinal()));
}
- const void* src_ptr = DMAHelper::base(&const_tensor);
- gpu::DeviceMemoryBase dst_ptr =
- XlaTensor::DeviceMemoryFromTensor(*output_tensor);
- // Memcpying asynchronously is safe for the GPU, but the CPU uses a
- // shared allocator so hold a reference to the copied-to buffer until
- // complete.
- TensorReference ref(*output_tensor);
- stream->ThenMemcpy(&dst_ptr, src_ptr, total_bytes);
- stream->ThenDoHostCallback([ref] { ref.Unref(); });
+ Device* device = dynamic_cast<Device*>(ctx->device());
+ OP_REQUIRES(ctx, device != nullptr,
+ errors::Internal("DeviceBase was not a Device."));
+ ctx->op_device_context()->CopyCPUTensorToDevice(
+ &const_tensor, device, output_tensor,
+ [&](Status status) { TF_CHECK_OK(status); });
+
+ if (device->device_type() == DEVICE_GPU) {
+ // The GPUDeviceContext enqueues the host->device transfer in a
+ // separate stream from the main compute stream. We must ensure the
+ // compute stream is synchronized with the host->device transfer
+ // stream now otherwise we will create a race condition.
+ auto* gpu_device_context =
+ static_cast<GPUDeviceContext*>(ctx->op_device_context());
+ gpu_device_context->stream()->ThenWaitFor(
+ gpu_device_context->host_to_device_stream());
+ }
} else {
// No copy required.
ctx->set_output(i, const_tensor);
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index edabdc218a..e345c1266a 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -192,6 +192,26 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "oom_test",
+ size = "medium",
+ srcs = ["oom_test.py"],
+ disabled_backends = [
+ "cpu",
+ "cpu_ondemand",
+ ],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:array_ops_gen",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:gradient_checker",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "conv2d_test",
size = "medium",
srcs = ["conv2d_test.py"],
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index f9d87c2d1c..1f7da659e5 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.contrib.compiler import jit
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -38,6 +39,18 @@ from tensorflow.python.platform import test
jit_scope = jit.experimental_jit_scope
+# Disable rewrites to make sure we don't end up having to update this test
+# whenever we implement new ones.
+def NoRewriteSessionConfig():
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True,
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ function_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ return config_pb2.ConfigProto(graph_options=graph_options)
+
+
def CompiledKernel(fn, *inputs, **kwargs):
"""Execute 'fn' as a compiled XLA kernel, with 'inputs'."""
name = kwargs.pop("name", None)
@@ -81,7 +94,7 @@ class JitLaunchTest(test.TestCase):
# actually ran. However, it is sometimes possible for _XlaLaunch ops to be
# constant-folded away, so the check is optional.
def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
- with session_lib.Session() as sess:
+ with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
placeholders = []
feeds = {}
for arg in args:
@@ -258,7 +271,7 @@ class XlaCompilationTest(test.TestCase):
def testReshape(self):
"""Tests an operator with compile-time constant and non-constant inputs."""
- with self.test_session() as sess:
+ with self.test_session(config=NoRewriteSessionConfig()) as sess:
x = array_ops.placeholder(dtypes.float32)
y = array_ops.placeholder(dtypes.int32)
with jit_scope():
@@ -282,7 +295,7 @@ class XlaCompilationTest(test.TestCase):
def testIgnoredArguments(self):
"""Tests that JIT computations can ignore formal parameters."""
- with self.test_session() as sess:
+ with self.test_session(config=NoRewriteSessionConfig()) as sess:
x = array_ops.placeholder(dtypes.int32)
y = array_ops.placeholder(dtypes.int32)
with jit_scope():
@@ -306,7 +319,7 @@ class XlaCompilationTest(test.TestCase):
def testLoops(self):
"""Tests that compilation accepts computations containing loops."""
- with self.test_session() as session:
+ with self.test_session(config=NoRewriteSessionConfig()) as session:
x = array_ops.placeholder(dtypes.float32)
with jit_scope():
c = lambda i, _: math_ops.less(i, 5)
@@ -324,7 +337,7 @@ class XlaCompilationTest(test.TestCase):
def testCond(self):
"""Tests that compilation handles switch operators."""
- with self.test_session() as session:
+ with self.test_session(config=NoRewriteSessionConfig()) as session:
x = array_ops.placeholder(dtypes.float32)
y = array_ops.placeholder(dtypes.float32)
c = array_ops.placeholder(dtypes.bool)
@@ -365,7 +378,8 @@ class XlaCompilationTest(test.TestCase):
inp = array_ops.placeholder(dtypes.float32)
out = Entry(inp)
- with self.test_session(graph=g, use_gpu=True) as sess:
+ with self.test_session(
+ config=NoRewriteSessionConfig(), graph=g, use_gpu=True) as sess:
run_metadata = config_pb2.RunMetadata()
val = sess.run(out,
feed_dict={inp: [2., 10.]},
@@ -377,7 +391,7 @@ class XlaCompilationTest(test.TestCase):
def testLoopDeadlock(self):
"""Regression test for bug that caused deadlocks in graphs with loops."""
- with self.test_session() as session:
+ with self.test_session(config=NoRewriteSessionConfig()) as session:
x = array_ops.placeholder(dtypes.float32)
with jit_scope():
y = x + 1.0
@@ -404,10 +418,10 @@ class XlaCompilationTest(test.TestCase):
y = Forward(x)
dx, = gradients_impl.gradients(y, [x], 1.0)
- cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
- optimizer_options=config_pb2.OptimizerOptions(
- opt_level=config_pb2.OptimizerOptions.L1,
- do_function_inlining=True)))
+ cfg = NoRewriteSessionConfig()
+ cfg.graph_options.optimizer_options.opt_level = (
+ config_pb2.OptimizerOptions.L1)
+ cfg.graph_options.optimizer_options.do_function_inlining = True
with session_lib.Session(graph=g, config=cfg) as sess:
run_metadata = config_pb2.RunMetadata()
dx_val = sess.run(dx,
diff --git a/tensorflow/compiler/tests/oom_test.py b/tensorflow/compiler/tests/oom_test.py
new file mode 100644
index 0000000000..1434e965e3
--- /dev/null
+++ b/tensorflow/compiler/tests/oom_test.py
@@ -0,0 +1,61 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for out-of-memory conditions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import googletest
+
+
+class OutOfMemoryTest(xla_test.XLATestCase):
+
+ def testOutputOutOfMemory(self):
+ """Allocates tensors until out of memory.
+
+ Generates a large rank-1 tensor. The tensor is an output of an XLA
+ computation, not constant.
+
+ Check that a ResourceExhaustedError is raised and can be caught.
+
+ We spin in a loop generating larger and larger tensors until an OOM event
+ happens. We may be running sandboxed, so have a small host memory limit, so
+ any hardcoded value is unlikely to land in the sweet spot between device
+ memory size and host memory size with stability.
+ """
+
+ def test_loop():
+ size = 2e8
+ while True:
+ with self.test_session():
+ # Force the compiled code to not be constant by feeding in an addend.
+ p = array_ops.placeholder(dtypes.float32, shape=[])
+ with self.test_scope():
+ # Create a large R1 tensor.
+ c = array_ops.zeros([size, 1]) + p
+
+ c.eval(feed_dict={p: 1.0})
+ size *= 2
+
+ self.assertRaises(errors.ResourceExhaustedError, test_loop)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc
index 86c02ac2e6..495d9c6078 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.cc
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc
@@ -54,7 +54,6 @@ xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
auto result,
condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
cond_builder.get()));
- TF_RETURN_IF_ERROR(cond_builder->SetReturnValue(result));
}
TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());
diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc
index 392ad9010a..1700c97718 100644
--- a/tensorflow/compiler/xla/executable_run_options.cc
+++ b/tensorflow/compiler/xla/executable_run_options.cc
@@ -87,4 +87,11 @@ const DeviceAssignment* ExecutableRunOptions::device_assignment() const {
return device_assignment_;
}
+ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) {
+ rng_seed_ = rng_seed;
+ return *this;
+}
+
+int ExecutableRunOptions::rng_seed() const { return rng_seed_; }
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index d4fcbf0493..2c1d9ffff1 100644
--- a/tensorflow/compiler/xla/executable_run_options.h
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -84,6 +84,9 @@ class ExecutableRunOptions {
DeviceAssignment* device_assignment);
const DeviceAssignment* device_assignment() const;
+ ExecutableRunOptions& set_rng_seed(int rng_seed);
+ int rng_seed() const;
+
private:
DeviceMemoryAllocator* allocator_ = nullptr;
int device_ordinal_ = -1;
@@ -92,6 +95,7 @@ class ExecutableRunOptions {
tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
ExecutionProfile* execution_profile_ = nullptr;
+ int rng_seed_ = 0;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index c8ed3e3a2b..f037663e3f 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -40,6 +40,9 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
flags->set_xla_cpu_multi_thread_eigen(true);
flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
flags->set_xla_eliminate_hlo_implicit_broadcast(true);
+#ifdef INTEL_MKL
+ flags->set_xla_cpu_use_mkl_dnn(true);
+#endif // INTEL_MKL
// Set cudnn batchnorm off by default; it does not provide a performance win
// on average.
@@ -288,6 +291,10 @@ void AllocateFlags() {
flag_values->xla_gpu_use_cudnn_batchnorm(),
"Allows the GPU backend to implement batchnorm HLOs using cudnn, "
"rather than expanding them to a soup of HLOs."),
+ tensorflow::Flag("xla_cpu_use_mkl_dnn",
+ bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn),
+ flag_values->xla_cpu_use_mkl_dnn(),
+ "Generate calls to MKL-DNN in the CPU backend."),
});
ParseFlagsFromEnv(*flag_objects);
}
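As a usage note for the new flag: XLA debug options such as xla_cpu_use_mkl_dnn are picked up from the XLA_FLAGS environment variable by the ParseFlagsFromEnv call above. A minimal sketch (the flag name comes from this diff; the rest is illustrative):

import os

# Must be set before the process first parses XLA flags, i.e. before
# TensorFlow is imported.
os.environ["XLA_FLAGS"] = "--xla_cpu_use_mkl_dnn=false"

import tensorflow as tf  # deliberately imported after XLA_FLAGS is set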
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 3a99d84bea..db91e80407 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2640,6 +2640,21 @@ tf_cc_test(
)
cc_library(
+ name = "despecializer",
+ srcs = ["despecializer.cc"],
+ hdrs = ["despecializer.h"],
+ deps = [
+ ":bfloat16_normalization",
+ ":defuser",
+ ":hlo",
+ ":hlo_pass",
+ ":hlo_pass_pipeline",
+ ":implicit_broadcast_remover",
+ "//tensorflow/compiler/xla:statusor",
+ ],
+)
+
+cc_library(
name = "source_map_util",
srcs = ["source_map_util.cc"],
hdrs = ["source_map_util.h"],
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index 13eb02ca01..a8053d15e1 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -51,8 +51,8 @@ std::ostream& operator<<(std::ostream& out, const CallContext& context) {
return out;
}
-CallContext GetInstructionCallContext(const HloInstruction* instruction) {
- switch (instruction->opcode()) {
+CallContext GetInstructionCallContext(HloOpcode opcode) {
+ switch (opcode) {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kWhile:
@@ -101,7 +101,7 @@ void CallGraphNode::AddCallerCallSite(const CallSite& caller_callsite) {
void CallGraphNode::AddCallSiteForInstruction(HloInstruction* instruction) {
CHECK_EQ(instruction->parent(), computation());
- const CallContext context = GetInstructionCallContext(instruction);
+ const CallContext context = GetInstructionCallContext(instruction->opcode());
if (!instruction->called_computations().empty()) {
CHECK(context == CallContext::kSequential ||
context == CallContext::kParallel);
diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h
index 688c4085df..97d3811508 100644
--- a/tensorflow/compiler/xla/service/call_graph.h
+++ b/tensorflow/compiler/xla/service/call_graph.h
@@ -53,7 +53,7 @@ enum class CallContext {
string CallContextToString(CallContext context);
std::ostream& operator<<(std::ostream& out, const CallContext& context);
-CallContext GetInstructionCallContext(const HloInstruction* instruction);
+CallContext GetInstructionCallContext(HloOpcode opcode);
// Represents an HLO instruction which calls one or more computations.
class CallSite {
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 966e2d0fc5..246b802861 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -18,6 +18,10 @@ load(":build_defs.bzl", "runtime_copts")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS")
+load(
+ "//third_party/mkl:build_defs.bzl",
+ "if_mkl",
+)
# Filegroup used to collect source files for dependency checking.
filegroup(
@@ -170,6 +174,7 @@ cc_library(
":runtime_fft",
":runtime_fork_join",
":runtime_matmul",
+ ":runtime_matmul_mkl",
":runtime_single_threaded_conv2d",
":runtime_single_threaded_matmul",
"@llvm//:execution_engine",
@@ -539,6 +544,22 @@ cc_library(
)
cc_library(
+ name = "runtime_matmul_mkl",
+ srcs = ["runtime_matmul_mkl.cc"],
+ hdrs = ["runtime_matmul_mkl.h"],
+ copts = runtime_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/core:framework_lite",
+ "//third_party/eigen3",
+ ] + if_mkl([
+ "//third_party/mkl:intel_binary_blob",
+ "@mkl_dnn",
+ ]),
+)
+
+cc_library(
name = "runtime_single_threaded_conv2d",
srcs = [
"runtime_conv2d_impl.h",
@@ -584,10 +605,12 @@ cc_library(
tf_cc_test(
name = "cpu_runtime_test",
srcs = ["cpu_runtime_test.cc"],
+ shard_count = 10,
tags = ["optonly"],
deps = [
":cpu_runtime",
":runtime_matmul",
+ ":runtime_matmul_mkl",
":runtime_single_threaded_matmul",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 9a3bd68c80..872b0be1f8 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -37,6 +37,14 @@ extern const char* const kEigenMatMulF32SymbolName =
"__xla_cpu_runtime_EigenMatMulF32";
extern const char* const kEigenMatMulF64SymbolName =
"__xla_cpu_runtime_EigenMatMulF64";
+extern const char* const kMKLMatMulF32SymbolName =
+ "__xla_cpu_runtime_MKLMatMulF32";
+extern const char* const kMKLMatMulF64SymbolName =
+ "__xla_cpu_runtime_MKLMatMulF64";
+extern const char* const kMKLSingleThreadedMatMulF32SymbolName =
+ "__xla_cpu_runtime_MKLSingleThreadedMatMulF32";
+extern const char* const kMKLSingleThreadedMatMulF64SymbolName =
+ "__xla_cpu_runtime_MKLSingleThreadedMatMulF64";
extern const char* const kEigenConvF16SymbolName =
"__xla_cpu_runtime_EigenConvF16";
extern const char* const kEigenConvF32SymbolName =
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index e61d6ea28b..e392e231b4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -44,6 +44,10 @@ namespace runtime {
extern const char* const kEigenMatMulF16SymbolName;
extern const char* const kEigenMatMulF32SymbolName;
extern const char* const kEigenMatMulF64SymbolName;
+extern const char* const kMKLMatMulF32SymbolName;
+extern const char* const kMKLMatMulF64SymbolName;
+extern const char* const kMKLSingleThreadedMatMulF32SymbolName;
+extern const char* const kMKLSingleThreadedMatMulF64SymbolName;
extern const char* const kEigenConvF16SymbolName;
extern const char* const kEigenConvF32SymbolName;
extern const char* const kEigenFftSymbolName;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
index f385829cdf..2ac950e6d9 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
@@ -130,25 +131,23 @@ MatMulShape MatMulShapes[] = {
// * transpose_lhs
// * transpose_rhs
// * single_threaded
-using EigenMatMulTestParam = std::tuple<MatMulShape, bool, bool, bool>;
+using MatMulTestParam = std::tuple<MatMulShape, bool, bool, bool>;
-class EigenMatMulTest
- : public CpuRuntimeTest,
- public ::testing::WithParamInterface<EigenMatMulTestParam> {
+class EigenMatMulTest : public CpuRuntimeTest,
+ public ::testing::WithParamInterface<MatMulTestParam> {
public:
- static string Name(
- const ::testing::TestParamInfo<EigenMatMulTestParam>& info) {
+ static string Name(const ::testing::TestParamInfo<MatMulTestParam>& info) {
MatMulShape shape = std::get<0>(info.param);
bool transpose_lhs = std::get<1>(info.param);
bool transpose_rhs = std::get<2>(info.param);
bool single_threaded = std::get<3>(info.param);
return tensorflow::strings::Printf(
- "MatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
+ "EigenMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "",
single_threaded ? "single" : "multi");
}
-}; // namespace xla
+};
TEST_P(EigenMatMulTest, DoIt) {
MatMulShape shape = std::get<0>(GetParam());
@@ -169,5 +168,74 @@ INSTANTIATE_TEST_CASE_P(EigenMatMulTestInstantiaion, EigenMatMulTest,
::testing::Bool()),
EigenMatMulTest::Name);
+#ifdef INTEL_MKL
+class MKLMatMulTest : public CpuRuntimeTest,
+ public ::testing::WithParamInterface<MatMulTestParam> {
+ public:
+ static string Name(const ::testing::TestParamInfo<MatMulTestParam>& info) {
+ MatMulShape shape = std::get<0>(info.param);
+ bool transpose_lhs = std::get<1>(info.param);
+ bool transpose_rhs = std::get<2>(info.param);
+ bool single_threaded = std::get<3>(info.param);
+
+ return tensorflow::strings::Printf(
+ "MKLMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n,
+ transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "",
+ single_threaded ? "single" : "multi");
+ }
+};
+
+std::unique_ptr<Array2D<float>> MKLMatrixMultiply(const Array2D<float>& a,
+ const Array2D<float>& b,
+ bool transpose_lhs,
+ bool transpose_rhs,
+ bool single_threaded) {
+ CHECK_EQ(a.width(), b.height());
+ int64 m = a.height();
+ int64 n = b.width();
+ int64 k = a.width();
+
+ // The MKL matmul runtime function expects its inputs in column-major
+ // order, while Array2D stores data in row-major order. Create transposes
+ // of a and b; the 'data' buffer of each transposed array is the original
+ // array in column-major order.
+ auto a_transpose = MaybeTransposeArray2D(a, !transpose_lhs);
+ auto b_transpose = MaybeTransposeArray2D(b, !transpose_rhs);
+
+ // Since we're going to transpose c before returning it, swap the order of the
+ // dimension sizes to ensure the returned array is properly dimensioned.
+ auto c_transpose = MakeUnique<Array2D<float>>(n, m);
+ if (single_threaded) {
+ __xla_cpu_runtime_MKLSingleThreadedMatMulF32(
+ nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(),
+ m, n, k, transpose_lhs, transpose_rhs);
+ } else {
+ __xla_cpu_runtime_MKLMatMulF32(nullptr, c_transpose->data(),
+ a_transpose->data(), b_transpose->data(), m,
+ n, k, transpose_lhs, transpose_rhs);
+ }
+ return MaybeTransposeArray2D(*c_transpose, true);
+}
+
+TEST_P(MKLMatMulTest, DoIt) {
+ MatMulShape shape = std::get<0>(GetParam());
+ bool transpose_lhs = std::get<1>(GetParam());
+ bool transpose_rhs = std::get<2>(GetParam());
+ bool single_threaded = std::get<3>(GetParam());
+
+ auto a = MakeLinspaceArray2D(0.0, 1.0, shape.m, shape.k);
+ auto b = MakeLinspaceArray2D(-2.0, 2.0, shape.k, shape.n);
+ auto c =
+ MKLMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs, single_threaded);
+ CheckMatrixMultiply(*a, *b, *c);
+}
+
+INSTANTIATE_TEST_CASE_P(MKLMatMulTestInstantiaion, MKLMatMulTest,
+ ::testing::Combine(::testing::ValuesIn(MatMulShapes),
+ ::testing::Bool(), ::testing::Bool(),
+ ::testing::Bool()),
+ MKLMatMulTest::Name);
+#endif // INTEL_MKL
+
} // namespace
} // namespace xla
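MKLMatrixMultiply above works around a layout mismatch: Array2D stores data row-major while the MKL runtime expects column-major buffers, so the test transposes each operand and hands the runtime the transposed array's buffer. A small NumPy sketch of the identity this relies on (NumPy is only an illustration here, not part of the test):

import numpy as np

a = np.arange(6, dtype=np.float32).reshape(2, 3)  # row-major 2x3
# Flattening the transpose in row-major order yields the original matrix
# laid out column-major, which is the layout the MKL runtime expects.
col_major_buffer = a.T.ravel()
assert np.array_equal(col_major_buffer, a.ravel(order='F'))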
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 8b1e20d79e..29afd8ea5f 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -918,28 +918,35 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
// The two transpose_... parameters are actually booleans, but we use int32
// to avoid target-dependent calling convention details.
- bool multi_threaded_eigen =
+ bool multi_threaded =
hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
+ bool use_mkl_dnn = hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn();
PrimitiveType type = target_array_.GetShape().element_type();
llvm::Type* float_type;
const char* fn_name;
switch (type) {
case F16:
- fn_name = multi_threaded_eigen
+ fn_name = multi_threaded
? runtime::kEigenMatMulF16SymbolName
: runtime::kEigenSingleThreadedMatMulF16SymbolName;
float_type = ir_builder_->getHalfTy();
break;
case F32:
- fn_name = multi_threaded_eigen
- ? runtime::kEigenMatMulF32SymbolName
- : runtime::kEigenSingleThreadedMatMulF32SymbolName;
+ fn_name = multi_threaded
+ ? (use_mkl_dnn ? runtime::kMKLMatMulF32SymbolName
+ : runtime::kEigenMatMulF32SymbolName)
+ : (use_mkl_dnn
+ ? runtime::kMKLSingleThreadedMatMulF32SymbolName
+ : runtime::kEigenSingleThreadedMatMulF32SymbolName);
float_type = ir_builder_->getFloatTy();
break;
case F64:
- fn_name = multi_threaded_eigen
- ? runtime::kEigenMatMulF64SymbolName
- : runtime::kEigenSingleThreadedMatMulF64SymbolName;
+ fn_name = multi_threaded
+ ? (use_mkl_dnn ? runtime::kMKLMatMulF64SymbolName
+ : runtime::kEigenMatMulF64SymbolName)
+ : (use_mkl_dnn
+ ? runtime::kMKLSingleThreadedMatMulF64SymbolName
+ : runtime::kEigenSingleThreadedMatMulF64SymbolName);
float_type = ir_builder_->getDoubleTy();
break;
default:
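The nested conditionals above pick one of four runtime entry points per element type, keyed on whether the Eigen thread pool is in use and whether MKL-DNN was requested. A sketch of the same dispatch as a table, using the F32 symbol names declared in cpu_runtime.cc (the helper and table names are illustrative):

# Illustrative mirror of the C++ selection logic for F32.
_MATMUL_SYMBOLS = {
    (True, True): '__xla_cpu_runtime_MKLMatMulF32',
    (True, False): '__xla_cpu_runtime_EigenMatMulF32',
    (False, True): '__xla_cpu_runtime_MKLSingleThreadedMatMulF32',
    (False, False): '__xla_cpu_runtime_EigenSingleThreadedMatMulF32',
}

def matmul_symbol(multi_threaded, use_mkl_dnn):
    return _MATMUL_SYMBOLS[(multi_threaded, use_mkl_dnn)]

assert matmul_symbol(True, False) == '__xla_cpu_runtime_EigenMatMulF32'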
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
new file mode 100644
index 0000000000..92da5f71c2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
@@ -0,0 +1,128 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
+#include "third_party/intel_mkl_ml/include/mkl_cblas.h"
+#include "third_party/intel_mkl_ml/include/mkl_service.h"
+
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/core/platform/types.h"
+
+#define EIGEN_USE_THREADS
+#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool"
+
+using tensorflow::int32;
+using tensorflow::int64;
+
+namespace {
+// BLAS GEMM API for single-precision (F32) matrix multiplication.
+
+// MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c.
+// Since XLA MatMul does not use alpha and beta, we set them to 1.0 and 0.0.
+// Matrices lhs, rhs, and out are all column-major.
+void MatMulF32(const void* run_options_ptr, float* out, float* lhs, float* rhs,
+ int64 m, int64 n, int64 k, int32 transpose_lhs,
+ int32 transpose_rhs) {
+ const float alpha = 1.0f, beta = 0.0f;
+ // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c,
+ // respectively. For column-major matrices, the leading dimension is the
+ // stride between consecutive columns (which equals the number of rows). If
+ // the matrix is transposed, the leading dimension is the stride between
+ // consecutive rows (which equals the number of columns).
+ int lda = transpose_lhs ? k : m;
+ int ldb = transpose_rhs ? n : k;
+ int ldc = m;
+ cblas_sgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans,
+ transpose_rhs ? CblasTrans : CblasNoTrans, m, n, k, alpha, lhs,
+ lda, rhs, ldb, beta, out, ldc);
+}
+
+// BLAS GEMM API for double-precision (F64) matrix multiplication.
+
+// MatMul function is defined as: c = alpha * op(a) * op(b) + beta * c.
+// Since XLA MatMul does not use alpha and beta, we set them to 1.0 and 0.0.
+// Matrices lhs, rhs, and out are all column-major.
+void MatMulF64(const void* run_options_ptr, double* out, double* lhs,
+ double* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs,
+ int32 transpose_rhs) {
+ const double alpha = 1.0, beta = 0.0;
+ // lda, ldb, and ldc are the leading dimensions of matrices a, b, and c,
+ // respectively. For a column-major matrix, the leading dimension is the
+ // stride between consecutive columns (which equals the number of rows). If
+ // the matrix is transposed, the leading dimension is the stride between
+ // consecutive rows (which equals the number of columns).
+ int lda = transpose_lhs ? k : m;
+ int ldb = transpose_rhs ? n : k;
+ int ldc = m;
+ cblas_dgemm(CblasColMajor, transpose_lhs ? CblasTrans : CblasNoTrans,
+ transpose_rhs ? CblasTrans : CblasNoTrans, m, n, k, alpha, lhs,
+ lda, rhs, ldb, beta, out, ldc);
+}
+
+} // namespace
+
+void __xla_cpu_runtime_MKLMatMulF32(const void* run_options_ptr, float* out,
+ float* lhs, float* rhs, int64 m, int64 n,
+ int64 k, int32 transpose_lhs,
+ int32 transpose_rhs) {
+ const xla::ExecutableRunOptions* run_options =
+ static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
+ // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread
+ // number specified in intra_op_thread_pool to MKL.
+ int prev_num_threads = mkl_set_num_threads_local(
+ run_options->intra_op_thread_pool()->numThreads());
+ MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+ // Set thread number back to the previous number.
+ mkl_set_num_threads_local(prev_num_threads);
+}
+// BLAS GEMM API for double-precision (F64) matrix multiplication.
+void __xla_cpu_runtime_MKLMatMulF64(const void* run_options_ptr, double* out,
+ double* lhs, double* rhs, int64 m, int64 n,
+ int64 k, int32 transpose_lhs,
+ int32 transpose_rhs) {
+ const xla::ExecutableRunOptions* run_options =
+ static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
+ // BLAS GEMM MatMul uses OpenMP for parallelization, so we pass the thread
+ // number specified in intra_op_thread_pool to MKL.
+ int prev_num_threads = mkl_set_num_threads_local(
+ run_options->intra_op_thread_pool()->numThreads());
+ MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+ // Set thread number back to the previous number.
+ mkl_set_num_threads_local(prev_num_threads);
+}
+void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(const void* run_options_ptr,
+ float* out, float* lhs,
+ float* rhs, int64 m, int64 n,
+ int64 k, int32 transpose_lhs,
+ int32 transpose_rhs) {
+ // Set the thread number to 1 for single-threaded execution.
+ int prev_num_threads = mkl_set_num_threads_local(1);
+ MatMulF32(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+ // Set thread number back to the previous number.
+ mkl_set_num_threads_local(prev_num_threads);
+}
+void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(const void* run_options_ptr,
+ double* out, double* lhs,
+ double* rhs, int64 m, int64 n,
+ int64 k, int32 transpose_lhs,
+ int32 transpose_rhs) {
+ // Set the thread number to 1 for single-threaded execution.
+ int prev_num_threads = mkl_set_num_threads_local(1);
+ MatMulF64(nullptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs);
+ // Set thread number back to the previous number.
+ mkl_set_num_threads_local(prev_num_threads);
+}
+#endif // INTEL_MKL
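The lda/ldb/ldc comments above encode the standard BLAS rule: for a column-major, non-transposed matrix, the leading dimension is the element stride between consecutive columns, which equals the number of rows. A NumPy sketch checking that rule (purely illustrative):

import numpy as np

m, k = 4, 3
a = np.asfortranarray(np.zeros((m, k), dtype=np.float32))  # column-major
# Stride between consecutive columns, in elements, equals the row count.
lda = a.strides[1] // a.itemsize
assert lda == m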
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h
new file mode 100644
index 0000000000..831b796efb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_
+
+#include <iostream>
+#include "tensorflow/core/platform/types.h"
+#ifdef INTEL_MKL
+#include "third_party/intel_mkl_ml/include/mkl_cblas.h"
+
+extern void __xla_cpu_runtime_MKLMatMulF32(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
+ float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
+ tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs);
+extern void __xla_cpu_runtime_MKLMatMulF64(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
+ double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
+ tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs);
+extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
+ float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
+ tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs);
+extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
+ double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
+ tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs);
+
+#else
+extern void __xla_cpu_runtime_MKLMatMulF32(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
+ float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
+ tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs) {
+ std::cerr << "Attempt to call MKL MatMul runtime library without defining "
+ "INTEL_MKL. Add --config=mkl to build with MKL.";
+ exit(1);
+}
+extern void __xla_cpu_runtime_MKLMatMulF64(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
+ double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
+ tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs) {
+ std::cerr << "Attempt to call MKL MatMul runtime library without defining "
+ "INTEL_MKL. Add --config=mkl to build with MKL.";
+ exit(1);
+}
+extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF32(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
+ float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
+ tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs) {
+ std::cerr << "Attempt to call MKL MatMul runtime library without defining "
+ "INTEL_MKL. Add --config=mkl to build with MKL.";
+ exit(1);
+}
+extern void __xla_cpu_runtime_MKLSingleThreadedMatMulF64(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
+ double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
+ tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs) {
+ std::cerr << "Attempt to call MKL MatMul runtime library without defining "
+ "INTEL_MKL. Add --config=mkl to build with MKL.";
+ exit(1);
+}
+
+#endif // INTEL_MKL
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_MKL_H_
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index 4198260a22..b7ce5bbe47 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
@@ -183,6 +184,10 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
+ REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32);
+ REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64);
+ REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF32);
+ REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16);
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc
new file mode 100644
index 0000000000..d938f3a2c4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/despecializer.cc
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/despecializer.h"
+
+#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
+#include "tensorflow/compiler/xla/service/defuser.h"
+#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h"
+
+namespace xla {
+
+Despecializer::Despecializer() : pipeline_("despecializer") {
+ // TODO(b/70588125): Also deal with window reversal in a fast way.
+ pipeline_.AddPass<Defuser>();
+ pipeline_.AddPass<ImplicitBroadcastRemover>();
+ pipeline_.AddPass<BFloat16MixedPrecisionRemoval>();
+}
+
+StatusOr<bool> Despecializer::Run(HloModule* module) {
+ return pipeline_.Run(module);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h
new file mode 100644
index 0000000000..af48f4ab6e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/despecializer.h
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DESPECIALIZER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_DESPECIALIZER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// Creates an HloPassPipeline containing multiple HloPasses that can
+// despecialize an optimized HloModule. This is useful to run an HloModule
+// optimized for one specific platform on a different platform (undoing
+// platform-specific passes) with matching numerics for comparison.
+//
+// Current despecialization passes are Defuser, ImplicitBroadcastRemover,
+// and BFloat16MixedPrecisionRemoval.
+class Despecializer : public HloPassInterface {
+ public:
+ Despecializer();
+ tensorflow::StringPiece name() const override { return "despecializer"; }
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ HloPassPipeline pipeline_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DESPECIALIZER_H_
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc
index 2b6caa1494..85409b330b 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph.cc
+++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc
@@ -93,7 +93,7 @@ Status FlattenNode(const CallGraphNode& node) {
auto current = worklist.back();
worklist.pop_back();
for (auto* instruction : current->instructions()) {
- if (GetInstructionCallContext(instruction) !=
+ if (GetInstructionCallContext(instruction->opcode()) !=
CallContext::kSequential) {
continue;
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 9d7251b6ae..53ad8909c5 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -202,6 +202,25 @@ void IterateThroughWindow(
} while (IndexUtil::BumpIndices(window_shape, &window_index));
}
+// Creates a vector of multipliers which can be used to create a linear index
+// into shape.
+//
+// Given the multidimensional index {i1, ..., iN} and
+// M = MakeDimMultipliers(shape), the corresponding linear index LI is simply
+//
+// LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N].
+//
+// This lets you calculate LI given the multidimensional indices in any order.
+DimensionVector MakeDimMultipliers(const Shape& shape) {
+ DimensionVector v(ShapeUtil::Rank(shape));
+ int64 scale = 1;
+ for (auto dim : LayoutUtil::MinorToMajor(shape)) {
+ v[dim] = scale;
+ scale *= shape.dimensions(dim);
+ }
+ return v;
+}
+
} // namespace
template <typename ReturnT, typename ElementwiseT>
@@ -999,25 +1018,30 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
const Shape& window_shape =
ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes);
- DimensionVector lhs_index(lhs_rank);
- DimensionVector rhs_index(rhs_rank);
+ DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape);
+ DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape);
+
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
+ auto lhs_literal_data = lhs_literal.data<ReturnT>();
+ auto rhs_literal_data = rhs_literal.data<ReturnT>();
+
auto func = [&](ArraySlice<int64> out_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
-
- std::fill(lhs_index.begin(), lhs_index.end(), 0);
- std::fill(rhs_index.begin(), rhs_index.end(), 0);
std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0);
- lhs_index[input_batch_dim] = out_index[output_batch_dim];
- rhs_index[kernel_output_z_dim] = out_index[output_z_dim];
-
// Convolve input feature with kernel.
do {
for (int64 iz = 0; iz < z_size; ++iz) {
- lhs_index[input_z_dim] = iz;
- rhs_index[kernel_input_z_dim] = iz;
+ int64 lhs_linear_index = 0;
+ lhs_linear_index += out_index[output_batch_dim] *
+ lhs_dim_multipliers[input_batch_dim];
+ lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim];
+
+ int64 rhs_linear_index = 0;
+ rhs_linear_index += out_index[output_z_dim] *
+ rhs_dim_multipliers[kernel_output_z_dim];
+ rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim];
// Find corresponding spatial dimension index for input (lhs).
for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
@@ -1042,29 +1066,32 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
// Calculate the actual lhs (input) index after dilation. As an
// optimization, skip this integer divide if there's no dilation.
+ int64 lhs_spatial_index;
if (window_dim.base_dilation() > 1) {
- lhs_index[input_spatial_dim] =
- undilated_index / window_dim.base_dilation();
+ lhs_spatial_index = undilated_index / window_dim.base_dilation();
} else {
- lhs_index[input_spatial_dim] = undilated_index;
+ lhs_spatial_index = undilated_index;
}
+ lhs_linear_index +=
+ lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim];
- // Skip if input index is not in bound.
- if (!(lhs_index[input_spatial_dim] >= 0 &&
- lhs_index[input_spatial_dim] <
+ // Skip if input index is not in bounds.
+ if (!(lhs_spatial_index >= 0 &&
+ lhs_spatial_index <
lhs_shape.dimensions(input_spatial_dim))) {
goto cnt;
}
- rhs_index[dnums.kernel_spatial_dimensions(ki)] =
- window_dim.window_reversal()
- ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
- : rhs_spatial_index[ki];
+ rhs_linear_index +=
+ (window_dim.window_reversal()
+ ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
+ : rhs_spatial_index[ki]) *
+ rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)];
}
result_val +=
- static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
- static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
+ static_cast<ElementwiseT>(lhs_literal_data[lhs_linear_index]) *
+ static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]);
}
cnt : {}
} while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
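MakeDimMultipliers converts a shape's minor-to-major layout into per-dimension strides, which lets the convolution loop accumulate one scalar linear index per operand instead of filling and re-reading full multidimensional index vectors. A sketch of the same computation, with the shape passed as plain lists (hypothetical stand-ins for the C++ Shape and Layout types):

def make_dim_multipliers(dimensions, minor_to_major):
    # Walk dimensions from minor-most to major-most, assigning the running
    # product of dimension sizes as each dimension's multiplier.
    v = [0] * len(dimensions)
    scale = 1
    for dim in minor_to_major:
        v[dim] = scale
        scale *= dimensions[dim]
    return v

# A 2x3 shape with layout {1, 0} (dimension 1 minor-most): the linear
# index of (i0, i1) is i0 * 3 + i1 * 1.
assert make_dim_multipliers([2, 3], [1, 0]) == [3, 1]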
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index e09d58bbe7..9fa72c1b8c 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -300,8 +300,6 @@ class Service : public ServiceInterface {
Service(const ServiceOptions& options,
std::unique_ptr<Backend> execute_backend);
- static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend();
-
// Resolves the given argument handles in the allocation tracker and returns
// the corresponding allocations for every replica. The function also verifies
// that each allocation matches the execution platform and device ordinal of
@@ -437,8 +435,6 @@ class Service : public ServiceInterface {
CompilationCache compilation_cache_;
// Backend to compile and execute computations on.
- //
- // TODO(b/28616830): Support multiple backends for execution.
std::unique_ptr<Backend> execute_backend_;
TF_DISALLOW_COPY_AND_ASSIGN(Service);
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 821432ef7d..68f75d50cb 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -160,27 +160,38 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
return std::move(literal);
}
-// Matches binary addition computations.
-bool LooksLikeSum(const HloComputation& computation) {
+enum class ConstantType { kUnknown, kZero, kOne };
+
+// Return the constant type required by this computation, if known.
+ConstantType GetInitValue(const HloComputation& computation) {
const HloInstruction* const root = computation.root_instruction();
- return root->opcode() == HloOpcode::kAdd &&
- computation.num_parameters() == 2 &&
- root->operand(0)->opcode() == HloOpcode::kParameter &&
- root->operand(1)->opcode() == HloOpcode::kParameter &&
- root->operand(0) != root->operand(1);
+ if (computation.num_parameters() != 2 ||
+ root->operand(0)->opcode() != HloOpcode::kParameter ||
+ root->operand(1)->opcode() != HloOpcode::kParameter ||
+ root->operand(0) == root->operand(1)) {
+ return ConstantType::kUnknown;
+ }
+
+ switch (root->opcode()) {
+ case HloOpcode::kAdd:
+ return ConstantType::kZero;
+ case HloOpcode::kMultiply:
+ return ConstantType::kOne;
+ default:
+ return ConstantType::kUnknown;
+ }
}
-// Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition,
-// which requires an init_value of 0 rather than a random value.
-bool NeedsZeroInitValue(const HloUse& use) {
+// Reduce, ReduceWindow, and SelectAndScatter ops may need a non-random
+// initialization value.
+bool NeedsInitValue(const HloUse& use) {
const HloInstruction* const instruction = use.instruction;
const HloOpcode opcode = instruction->opcode();
const int64 op_num = use.operand_number;
return (
((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) &&
- op_num == 1 && LooksLikeSum(*instruction->to_apply())) ||
- (opcode == HloOpcode::kSelectAndScatter && op_num == 2 &&
- LooksLikeSum(*instruction->scatter())));
+ op_num == 1) ||
+ (opcode == HloOpcode::kSelectAndScatter && op_num == 2));
}
// Generate random values that are constrained to the input_shape minus the
@@ -222,7 +233,7 @@ std::vector<HloInstruction*> FindConstrainedUses(
auto fused_uses = FindConstrainedUses(dataflow, *to_analyze);
constrained_uses.insert(constrained_uses.end(), fused_uses.begin(),
fused_uses.end());
- } else if (NeedsZeroInitValue(use)) {
+ } else if (NeedsInitValue(use)) {
constrained_uses.push_back(instruction);
} else if (opcode == HloOpcode::kConvert ||
opcode == HloOpcode::kReducePrecision) {
@@ -243,7 +254,8 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
HloInstruction* needs_index = nullptr;
- HloInstruction* needs_zero = nullptr;
+ HloInstruction* needs_constant = nullptr;
+ ConstantType constant_type = ConstantType::kUnknown;
for (HloInstruction* use : constrained_uses) {
switch (use->opcode()) {
case HloOpcode::kDynamicSlice:
@@ -258,8 +270,13 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
+ needs_constant = use;
+ constant_type = GetInitValue(*use->to_apply());
+ break;
+
case HloOpcode::kSelectAndScatter:
- needs_zero = use;
+ needs_constant = use;
+ constant_type = GetInitValue(*use->scatter());
break;
default:
@@ -268,17 +285,26 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
use->ToString().c_str());
}
}
- if (needs_index != nullptr && needs_zero != nullptr) {
+ if (needs_index != nullptr && needs_constant != nullptr) {
return Unimplemented(
"Conflicting operand generation constraints.\nNeeds index: %s\nNeeds "
- "zero: %s\n",
- needs_index->ToString().c_str(), needs_zero->ToString().c_str());
+ "constant: %s\n",
+ needs_index->ToString().c_str(), needs_constant->ToString().c_str());
}
if (needs_index != nullptr) {
return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(),
needs_index->shape(), engine);
- } else if (needs_zero != nullptr) {
- return Literal::CreateFromShape(param.shape());
+ } else if (needs_constant != nullptr) {
+ switch (constant_type) {
+ case ConstantType::kZero:
+ return Literal::Zero(param.shape().element_type()).CloneToUnique();
+ case ConstantType::kOne:
+ return Literal::One(param.shape().element_type()).CloneToUnique();
+ case ConstantType::kUnknown:
+ // We want the identity element for the computation, but we don't really
+ // know what it is - so any value we generate will be just as wrong.
+ return MakeFakeLiteralInternal(param.shape(), engine);
+ }
} else {
return MakeFakeLiteralInternal(param.shape(), engine);
}
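The rewritten test_utils logic chooses the identity element of the reduction's combining computation: 0 for addition, 1 for multiplication. Anything else biases the reduction result, which is why a random init value is unusable here. A toy illustration:

from functools import reduce
import operator

values = [2.0, 3.0, 4.0]
# Using the identity of the combining op leaves the result unbiased:
assert reduce(operator.add, values, 0.0) == 9.0
assert reduce(operator.mul, values, 1.0) == 24.0
# A non-identity init (e.g. 0 for multiply) corrupts it:
assert reduce(operator.mul, values, 0.0) == 0.0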
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 5cb18113e5..f9943f71d3 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -189,6 +189,9 @@ message DebugOptions {
// directory.
string xla_dump_per_pass_hlo_proto_to = 96;
+ // Generate calls to MKL-DNN in the CPU backend.
+ bool xla_cpu_use_mkl_dnn = 97;
+
// Extra options to pass to the compilation backend; specific interpretation
// of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;
diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD
index c5a0dc1095..8f9bffa55e 100644
--- a/tensorflow/contrib/autograph/converters/BUILD
+++ b/tensorflow/contrib/autograph/converters/BUILD
@@ -24,7 +24,6 @@ py_library(
"continue_statements.py",
"control_flow.py",
"decorators.py",
- "for_loops.py",
"ifexp.py",
"list_comprehension.py",
"lists.py",
@@ -49,6 +48,7 @@ py_library(
visibility = ["//tensorflow:__subpackages__"],
deps = [
":converters",
+ "//tensorflow/contrib/autograph/operators",
"//tensorflow/contrib/autograph/pyct",
"//tensorflow/contrib/autograph/pyct/static_analysis",
"//tensorflow/contrib/autograph/utils",
@@ -133,16 +133,6 @@ py_test(
)
py_test(
- name = "for_loops_test",
- srcs = ["for_loops_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":test_lib",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
name = "name_scopes_test",
srcs = ["name_scopes_test.py"],
deps = [
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py
index 48026bccab..62115d4005 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/contrib/autograph/converters/break_statements.py
@@ -32,6 +32,7 @@ class BreakCanonicalizationTransformer(transformer.Base):
def __init__(self, context):
super(BreakCanonicalizationTransformer, self).__init__(context)
# This is a stack structure, to correctly process nested loops.
+ # Each item is a list [break_used, break_variable_name]
self.break_uses = []
def _create_break_check(self):
@@ -99,9 +100,9 @@ class BreakCanonicalizationTransformer(transformer.Base):
self.break_uses.append([False, break_var])
node.body = self._manual_visit_list(node.body)
if self.break_uses[-1][0]:
- anno.setanno(node, 'extra_cond',
- gast.UnaryOp(gast.Not(),
- gast.Name(break_var, gast.Load(), None)))
+ extra_cond = templates.replace_as_expression(
+ 'not var_name', var_name=break_var)
+ anno.setanno(node, 'extra_cond', extra_cond)
final_nodes = [self._create_break_init(), node]
else:
final_nodes = node
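The canonicalization above replaces each break with an assignment to a generated flag variable and threads 'not <flag>' into the loop condition via the extra_cond annotation. A rough pure-Python sketch of the rewrite it targets (break_ is an illustrative name, not the converter's literal output):

def sum_until_negative(values):
    break_ = False
    total, i = 0, 0
    while i < len(values) and not break_:
        if values[i] < 0:
            break_ = True  # replaces the original 'break'
        if not break_:
            total += values[i]
            i += 1
    return total

assert sum_until_negative([1, 2, -1, 5]) == 3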
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py
index 49d932026f..55a28e8ac3 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/contrib/autograph/converters/control_flow.py
@@ -22,6 +22,7 @@ import gast
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
+from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
@@ -49,11 +50,6 @@ class ControlFlowTransformer(transformer.Base):
def __init__(self, context):
super(ControlFlowTransformer, self).__init__(context)
- # pylint:disable=invalid-name
-
- def visit_For(self, node):
- assert False, 'for statement should have been canonicalized at this point'
-
def _create_cond_branch(self, body_name, aliased_orig_names,
aliased_new_names, body, returns):
if aliased_orig_names:
@@ -170,6 +166,13 @@ class ControlFlowTransformer(transformer.Base):
body_closure = body_scope.modified - body_scope.created
all_referenced = body_scope.referenced
+ cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE)
+ cond_closure = set()
+ for s in cond_scope.referenced:
+ for root in s.support_set:
+ if root not in body_scope.created:
+ cond_closure.add(root)
+
state = list(body_closure)
if not state:
# TODO(mdan): Implement this properly.
@@ -204,7 +207,8 @@ class ControlFlowTransformer(transformer.Base):
def body_name(state_ssf):
body
return state_ssf,
- state_ast_tuple = autograph_utils.run_while(test_name, body_name, [state])
+ state_ast_tuple = __ops.while_loop(
+ test_name, body_name, (state,), (extra_deps,))
"""
node = templates.replace(
template,
@@ -216,11 +220,67 @@ class ControlFlowTransformer(transformer.Base):
test=test,
body_name=self.context.namer.new_symbol('loop_body',
body_scope.referenced),
- body=node_body)
+ body=node_body,
+ extra_deps=tuple(s.ast() for s in cond_closure),
+ )
return node
- # pylint:enable=invalid-name
+ def visit_For(self, node):
+ self.generic_visit(node)
+
+ body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
+ body_closure = body_scope.modified - body_scope.created
+ all_referenced = body_scope.referenced
+
+ state = list(body_closure)
+
+ state_ssf = [
+ self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
+ ]
+ ssf_map = {
+ name: ssf
+ for name, ssf in zip(state, state_ssf)
+ if str(name) != ssf
+ }
+
+ if len(state) == 1:
+ state = state[0]
+ state_ssf = state_ssf[0]
+ state_ast_tuple = state
+ else:
+ state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
+
+ node_body = ast_util.rename_symbols(node.body, ssf_map)
+ if anno.hasanno(node, 'extra_cond'):
+ extra_cond = anno.getanno(node, 'extra_cond')
+ extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
+ else:
+ extra_cond = parser.parse_expression('True')
+
+ template = """
+ def extra_cond_name(state_ssf):
+ return extra_cond_expr
+ def body_name(iterate, state_ssf):
+ body
+ return state_ssf,
+ state_ast_tuple = __ops.for_loop(
+ iterated, extra_cond_name, body_name, (state,))
+ """
+ node = templates.replace(
+ template,
+ state=state,
+ state_ssf=state_ssf,
+ state_ast_tuple=state_ast_tuple,
+ iterated=node.iter,
+ iterate=node.target,
+ extra_cond_name=self.context.namer.new_symbol('extra_cond',
+ all_referenced),
+ extra_cond_expr=extra_cond,
+ body_name=self.context.namer.new_symbol('loop_body', all_referenced),
+ body=node_body)
+
+ return node
def transform(node, context):
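visit_For above lowers a for statement into a call to the functional for_loop operator: the modified-but-not-created symbols become explicit loop state, the body becomes a tuple-returning function of (iterate, state), and any break canonicalization lands in extra_cond. A pure-Python stand-in for the contract (mirroring _py_for_loop in the operators module; names are illustrative):

def for_loop(iterated, extra_cond, loop_body, init_state):
    state = init_state
    for iterate in iterated:
        if not extra_cond(*state):
            break
        state = loop_body(iterate, *state)
    # Single-element states are unpacked, matching the operator's special
    # case for a single loop variable.
    return state[0] if len(state) == 1 else state

# Roughly what 'for e in l: s += e' lowers to:
l, s = [1, 2, 3], 0
s = for_loop(l, lambda s: True, lambda e, s: (s + e,), (s,))
assert s == 6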
diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py
index 86fed51f27..c5610b16b4 100644
--- a/tensorflow/contrib/autograph/converters/control_flow_test.py
+++ b/tensorflow/contrib/autograph/converters/control_flow_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.contrib.autograph.converters import control_flow
from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
@@ -94,6 +95,77 @@ class ControlFlowTest(converter_test_base.TestCase):
with self.test_session() as sess:
self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1))))
+ def test_simple_for(self):
+
+ def test_fn(l):
+ s1 = 0
+ s2 = 0
+ for e in l:
+ s1 += e
+ s2 += e * e
+ return s1, s2
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
+ with self.test_session() as sess:
+ l = [1, 2, 3]
+ self.assertEqual(
+ test_fn(l), sess.run(result.test_fn(constant_op.constant(l))))
+ l = []
+ self.assertEqual(
+ test_fn(l),
+ sess.run(
+ result.test_fn(
+ constant_op.constant(l, shape=(0,), dtype=dtypes.int32))))
+
+ def test_for_single_var(self):
+
+ def test_fn(l):
+ s = 0
+ for e in l:
+ s += e
+ return s
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
+ with self.test_session() as sess:
+ l = [1, 2, 3]
+ self.assertEqual(
+ test_fn(l), sess.run(result.test_fn(constant_op.constant(l))))
+ l = []
+ self.assertEqual(
+ test_fn(l),
+ sess.run(
+ result.test_fn(
+ constant_op.constant(l, shape=(0,), dtype=dtypes.int32))))
+
+ def test_for_with_iterated_expression(self):
+
+ eval_count = [0]
+
+ def count_evals(x):
+ eval_count[0] += 1
+ return x
+
+ def test_fn(n):
+ s = 0
+ for e in count_evals(range(n)):
+ s += e
+ return s
+
+ node = self.parse_and_analyze(test_fn, {'count_evals': count_evals})
+ node = control_flow.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
+ result.count_evals = count_evals
+ self.assertEqual(test_fn(5), result.test_fn(5))
+ # count_evals ran twice: once for test_fn and once for result.test_fn.
+ self.assertEqual(eval_count[0], 2)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/converters/converter_test_base.py b/tensorflow/contrib/autograph/converters/converter_test_base.py
index 3ea2cfd668..6f75e9a529 100644
--- a/tensorflow/contrib/autograph/converters/converter_test_base.py
+++ b/tensorflow/contrib/autograph/converters/converter_test_base.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import contextlib
import imp
+from tensorflow.contrib.autograph import operators
from tensorflow.contrib.autograph import utils
from tensorflow.contrib.autograph.pyct import compiler
from tensorflow.contrib.autograph.pyct import context
@@ -77,6 +78,7 @@ class TestCase(test.TestCase):
result.tf = self.make_fake_mod('fake_tf', *symbols)
result.autograph_utils = utils
result.autograph_api = self.make_fake_mod('fake_api', converted_call)
+ result.__dict__['__ops'] = operators
yield result
except Exception: # pylint:disable=broad-except
if source is None:
diff --git a/tensorflow/contrib/autograph/converters/for_loops.py b/tensorflow/contrib/autograph/converters/for_loops.py
deleted file mode 100644
index 4999c47bdc..0000000000
--- a/tensorflow/contrib/autograph/converters/for_loops.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Canonicalizes for loops into while loops.
-
-This canonicalizer uses the len function on its argument. That should be
-converted to a tf.shape separately.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
-
-
-class ForLoopCanonicalizationTransformer(transformer.Base):
- """Canonicalizes for loops (e.g. into while loops)."""
-
- def __init__(self, context):
- super(ForLoopCanonicalizationTransformer, self).__init__(context)
-
- def visit_For(self, node):
- self.generic_visit(node)
- body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
- i_var = self.context.namer.new_symbol('i', body_scope.referenced)
- smart_loop_iter_var = self.context.namer.new_symbol('smart_loop_iter',
- body_scope.referenced)
- cont_var = self.context.namer.new_symbol('cont', body_scope.referenced)
- # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
- if anno.hasanno(node, 'extra_cond'):
- template = """
- i = 0
- smart_loop_iter = autograph_utils.dynamic_dataset(loop_iter)
- cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter)
- while cont and extra_cond:
- body
- i += 1
- cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter)
- """
- return templates.replace(
- template,
- loop_iter=node.iter,
- target=node.target,
- body=node.body,
- i=i_var,
- smart_loop_iter=smart_loop_iter_var,
- cont=cont_var,
- extra_cond=anno.getanno(node, 'extra_cond'))
- else:
- template = """
- i = 0
- smart_loop_iter = autograph_utils.dynamic_dataset(loop_iter)
- cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter)
- while cont:
- body
- i += 1
- cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter)
- """
- repl = templates.replace(
- template,
- loop_iter=node.iter,
- target=node.target,
- body=node.body,
- i=i_var,
- smart_loop_iter=smart_loop_iter_var,
- cont=cont_var)
- return repl
-
- def visit_Continue(self, node):
- assert False, 'continue statement should be desugared at this point'
-
- def visit_Break(self, node):
- assert False, 'break statement should be desugared at this point'
-
-
-def transform(node, context):
- return ForLoopCanonicalizationTransformer(context).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/for_loops_test.py b/tensorflow/contrib/autograph/converters/for_loops_test.py
deleted file mode 100644
index 943f52de55..0000000000
--- a/tensorflow/contrib/autograph/converters/for_loops_test.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for for_loops module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.autograph.converters import converter_test_base
-from tensorflow.contrib.autograph.converters import for_loops
-from tensorflow.python.platform import test
-
-
-class ControlFlowTest(converter_test_base.TestCase):
-
- def test_basic_for(self):
-
- def test_fn(l):
- s = 0
- for e in l:
- s += e
- return s
-
- node = self.parse_and_analyze(test_fn, {})
- node = for_loops.transform(node, self.ctx)
-
- with self.compiled(node) as result:
- l = [1, 2, 3]
- self.assertEqual(test_fn(l), result.test_fn(l))
- l = []
- self.assertEqual(test_fn(l), result.test_fn(l))
-
- def test_for_with_iterated_expression(self):
-
- eval_count = [0]
-
- def count_evals(x):
- eval_count[0] += 1
- return x
-
- def test_fn(n):
- s = 0
- for e in count_evals(range(n)):
- s += e
- return s
-
- node = self.parse_and_analyze(test_fn, {'count_evals': count_evals})
- node = for_loops.transform(node, self.ctx)
-
- with self.compiled(node) as result:
- result.count_evals = count_evals
- self.assertEqual(test_fn(5), result.test_fn(5))
- # count_evals ran twice, once for test_fn and another for result.test_fn
- self.assertEqual(eval_count[0], 2)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py
index ee2d301d75..f9db07778a 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/contrib/autograph/impl/api_test.py
@@ -37,8 +37,12 @@ class ApiTest(test.TestCase):
def setUp(self):
config.COMPILED_IMPORT_STATEMENTS = (
'from __future__ import print_function',
- 'from tensorflow.contrib.autograph import utils as '
- 'autograph_utils', 'tf = autograph_utils.fake_tf()')
+ 'from tensorflow.contrib.autograph import utils'
+ ' as autograph_utils',
+ 'from tensorflow.contrib.autograph import operators'
+ ' as __ops',
+ 'tf = autograph_utils.fake_tf()',
+ )
def test_decorator_recurses(self):
@@ -197,8 +201,7 @@ class ApiTest(test.TestCase):
compiled_code = api.to_code(test_fn)
- # Just check for some key words and that it is parseable Python code.
- self.assertRegexpMatches(compiled_code, 'autograph_utils\\.run_while')
+ # Just check that it is parseable Python code.
self.assertIsNotNone(parser.parse_str(compiled_code))
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py
index 62a49cd92d..3bacc94300 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/contrib/autograph/impl/conversion.py
@@ -28,7 +28,6 @@ from tensorflow.contrib.autograph.converters import call_trees
from tensorflow.contrib.autograph.converters import continue_statements
from tensorflow.contrib.autograph.converters import control_flow
from tensorflow.contrib.autograph.converters import decorators
-from tensorflow.contrib.autograph.converters import for_loops
from tensorflow.contrib.autograph.converters import ifexp
from tensorflow.contrib.autograph.converters import lists
from tensorflow.contrib.autograph.converters import logical_expressions
@@ -324,8 +323,6 @@ def node_to_graph(node, ctx, nocompile_decorators):
node = _static_analysis_pass(node, ctx)
node = lists.transform(node, ctx)
- node = for_loops.transform(node, ctx)
- # for_loops may insert new global references.
node = builtin_functions.transform(node, ctx)
node = _static_analysis_pass(node, ctx)
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD
index 7856c253bd..4c62468575 100644
--- a/tensorflow/contrib/autograph/operators/BUILD
+++ b/tensorflow/contrib/autograph/operators/BUILD
@@ -2,6 +2,8 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+load("//tensorflow:tensorflow.bzl", "py_test")
+
filegroup(
name = "all_files",
srcs = glob(
@@ -18,8 +20,21 @@ py_library(
name = "operators",
srcs = [
"__init__.py",
+ "control_flow.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
- deps = [],
+ deps = [
+ "//tensorflow/contrib/autograph/utils",
+ ],
+)
+
+py_test(
+ name = "control_flow_test",
+ srcs = ["control_flow_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":operators",
+ "//tensorflow/python:client_testlib",
+ ],
)
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py
index c3f4cab69e..04b4734551 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/contrib/autograph/operators/__init__.py
@@ -22,3 +22,8 @@ closures for the body.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
+# TODO(mdan): Add a container for implementation-specific toggles (throughout).
+
+from tensorflow.contrib.autograph.operators.control_flow import for_loop
+from tensorflow.contrib.autograph.operators.control_flow import while_loop
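Converted code calls these operators directly; for_loop dispatches on the iterated object (Tensor, Dataset, or plain Python iterable, per the implementation below). A hedged usage sketch with a plain list, which takes the pure-Python path:

from tensorflow.contrib.autograph.operators import for_loop

# Summing a plain Python list; a non-Tensor, non-Dataset iterable is
# handled by the _py_for_loop overload.
result = for_loop(
    iterated=[1, 2, 3],
    extra_cond=lambda s: True,
    loop_body=lambda i, s: (s + i,),
    init_state=(0,))
assert result == 6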
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
new file mode 100644
index 0000000000..5b8cb2d63c
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/control_flow.py
@@ -0,0 +1,179 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Control flow statements: loops, conditionals, etc."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.utils import builtins
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_math_ops
+
+
+def for_loop(iterated, extra_cond, loop_body, init_state):
+ """Functional form of a for statement.
+
+ The loop operates on a so-called state, which includes all symbols that are
+ variant across loop iterations, excluding the iterate. In what follows we
+ refer to state as either a tuple of entities that represent an actual state,
+ or a list of arguments of the corresponding types.
+
+ Args:
+ iterated: The entity being iterated over.
+ extra_cond: Callable with the state as arguments, and boolean return type.
+ An additional loop condition.
+ loop_body: Callable with the iterate and the state as arguments, and
+ state as return type. The actual loop body.
+ init_state: Tuple containing the initial state.
+
+ Returns:
+ Tuple containing the final state.
+ """
+ if tensor_util.is_tensor(iterated):
+ return _known_len_for_loop(iterated, extra_cond, loop_body, init_state)
+ elif isinstance(iterated, dataset_ops.Dataset):
+ return _dataset_for_loop(iterated, extra_cond, loop_body, init_state)
+ else:
+ return _py_for_loop(iterated, extra_cond, loop_body, init_state)
+
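For reference, a minimal usage sketch of the dispatch above; the values are hypothetical and mirror the tests added later in this change:

    from tensorflow.contrib.autograph.operators import control_flow

    # A plain Python iterable dispatches to _py_for_loop and runs eagerly:
    s = control_flow.for_loop(
        range(5),
        extra_cond=lambda s: True,
        loop_body=lambda i, s: (s + i,),
        init_state=(0,))  # s == 10
    # A Tensor dispatches to _known_len_for_loop (staged as a tf.while_loop);
    # a tf.data.Dataset dispatches to _dataset_for_loop.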
+
+def _py_for_loop(iterated, extra_cond, loop_body, init_state):
+ """Overload of for_loop that executes a Python for loop."""
+ state = init_state
+ for iterate in iterated:
+ if not extra_cond(*state):
+ break
+ state = loop_body(iterate, *state)
+
+ # TODO(mdan): Remove this special case.
+ if len(state) == 1:
+ return state[0]
+ return state
+
+
+def _known_len_for_loop(iterated, extra_cond, loop_body, init_state):
+ """Overload of for_loop that iterates over objects that define a length."""
+ n = builtins.dynamic_len(iterated)
+
+ def while_body(iterate_index, *state):
+ iterate = iterated[iterate_index]
+ new_state = loop_body(iterate, *state)
+ return (iterate_index + 1,) + new_state
+
+ def while_cond(iterate_index, *state):
+ return gen_math_ops.logical_and(iterate_index < n, extra_cond(*state))
+
+ results = while_loop(
+ while_cond,
+ while_body,
+ init_state=(0,) + init_state,
+ extra_deps=(iterated,))
+ # Dropping the iteration index because it's not syntactically visible.
+ results = results[1:]
+
+ # TODO(mdan): Remove this special case.
+ if len(results) == 1:
+ return results[0]
+ return results
+
+
+def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
+ """Overload of for_loop that iterates over TF Datasets."""
+ # Because Datasets only expose get_next, in the style of Python iterators,
+ # we are forced to unpack the loop as:
+ #
+ # epoch_number, iterate = ds.get_next()
+ # while epoch_number < 2:
+ # <body>
+ # epoch_number, iterate = ds.get_next()
+ epoch_numbers = dataset_ops.Dataset.range(2)
+ def tag_with(ds, tag):
+ return dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(tag).repeat(), ds))
+ ds_with_epoch = epoch_numbers.flat_map(lambda i: tag_with(ds, i))
+
+ iterator = ds_with_epoch.make_initializable_iterator()
+ with ops.control_dependencies((iterator.initializer,)):
+ epoch_number, iterate = iterator.get_next()
+
+ def while_body(epoch_number, iterate, *state):
+ new_state = loop_body(iterate, *state)
+ epoch_number, iterate = iterator.get_next()
+ return (epoch_number, iterate) + new_state
+
+ def while_cond(epoch_number, iterate, *state):
+ del iterate
+ return gen_math_ops.logical_and(epoch_number < 1, extra_cond(*state))
+
+ results = while_loop(
+ while_cond,
+ while_body,
+ init_state=(epoch_number, iterate) + init_state,
+ extra_deps=())
+ # Dropping the epoch number and iterate because they are not syntactically
+ # visible.
+ results = results[2:]
+
+ # TODO(mdan): Remove this special case.
+ if len(results) == 1:
+ return results[0]
+ return results
+
+
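A sketch of the epoch tagging above, with hypothetical values:

    # For ds = Dataset.range(3), ds_with_epoch yields
    # (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2).
    # while_cond fails as soon as epoch_number reaches 1, so exactly one
    # full epoch of `ds` is consumed by the loop.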
+def while_loop(loop_cond, loop_body, init_state, extra_deps):
+ """Functional form of a while statement.
+
+ The loop operates on a so-called state, which includes all symbols that are
+ variant across loop iterations. In what follows we refer to state as either
+ a tuple of entities that represent an actual state, or a list of arguments
+ of the corresponding types.
+
+ Args:
+ loop_cond: Callable with the state as arguments, and boolean return type.
+ The loop condition.
+ loop_body: Callable with the state as arguments, and state as return type.
+ The actual loop body.
+ init_state: Tuple containing the initial state.
+ extra_deps: Tuple containing additional entities on which the loop may
+ depend, such as loop invariants referenced by loop_cond. Used
+ exclusively for dispatch control.
+
+ Returns:
+ Tuple containing the final state.
+ """
+ # TODO(mdan): Consider adding a generic mechanism for dynamic dispatch.
+ # That could be something as simple as a collection of dispatch rules, with
+ # some prioritization.
+ if any(tensor_util.is_tensor(v) for v in init_state + extra_deps):
+ return _tf_while_loop(loop_cond, loop_body, init_state)
+ else:
+ return _py_while_loop(loop_cond, loop_body, init_state)
+
+
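As with for_loop, dispatch is by type: any Tensor in init_state + extra_deps stages a tf.while_loop, otherwise a plain Python loop runs. A minimal sketch, assuming `tf` is imported and mirroring the new WhileLoopTest:

    n = tf.constant(5)
    i, s = while_loop(
        loop_cond=lambda i, s: i < n,
        loop_body=lambda i, s: (i + 1, s + i),
        init_state=(0, 0),
        extra_deps=(n,))  # evaluates to (5, 10)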
+def _tf_while_loop(loop_cond, loop_body, init_state):
+ """Overload of while_loop that stages a TF while_loop."""
+ return control_flow_ops.while_loop(loop_cond, loop_body, init_state)
+
+
+def _py_while_loop(loop_cond, loop_body, init_state):
+ """Overload of while_loop that executes a Python while loop."""
+ state = init_state
+ while loop_cond(*state):
+ state = loop_body(*state)
+ return state
diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py
new file mode 100644
index 0000000000..9112b1627f
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/control_flow_test.py
@@ -0,0 +1,82 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for control_flow module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph import operators
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ForLoopTest(test.TestCase):
+
+ def test_tensor(self):
+ s = operators.for_loop(
+ constant_op.constant([1, 2, 3, 4]),
+ extra_cond=lambda s: True,
+ loop_body=lambda i, s: (s + i,),
+ init_state=(0,))
+ with self.test_session() as sess:
+ self.assertEqual((10,), sess.run(s))
+
+ def test_python(self):
+ s = operators.for_loop(
+ range(5),
+ extra_cond=lambda s: True,
+ loop_body=lambda i, s: (s + i,),
+ init_state=(0,))
+ self.assertEqual(10, s)
+
+ def test_dataset(self):
+ to_int32 = lambda i: math_ops.cast(i, dtypes.int32)
+ s = operators.for_loop(
+ dataset_ops.Dataset.range(5).map(to_int32),
+ extra_cond=lambda s: True,
+ loop_body=lambda i, s: (s + i,),
+ init_state=(0,))
+ with self.test_session() as sess:
+ self.assertEqual((10,), sess.run(s))
+
+
+class WhileLoopTest(test.TestCase):
+
+ def test_tensor(self):
+ n = constant_op.constant(5)
+ results = operators.while_loop(
+ loop_cond=lambda i, s: i < n,
+ loop_body=lambda i, s: (i + 1, s + i,),
+ init_state=(0, 0),
+ extra_deps=(n,))
+ with self.test_session() as sess:
+ self.assertEqual((5, 10), sess.run(results))
+
+ def test_python(self):
+ n = 5
+ results = operators.while_loop(
+ loop_cond=lambda i, s: i < n,
+ loop_body=lambda i, s: (i + 1, s + i),
+ init_state=(0, 0),
+ extra_deps=(n,))
+ self.assertEqual((5, 10), results)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py
index 4f76a69522..4a70bab440 100644
--- a/tensorflow/contrib/autograph/pyct/ast_util.py
+++ b/tensorflow/contrib/autograph/pyct/ast_util.py
@@ -28,7 +28,7 @@ from tensorflow.contrib.autograph.pyct import anno
class CleanCopier(gast.NodeVisitor):
"""Copy AST nodes.
- The copied nodes will ignore almost all fields that prefixed by '__'.
+ The copied nodes will ignore almost all fields that are prefixed by '__'.
Exceptions are made for some annotations.
"""
diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils.py b/tensorflow/contrib/autograph/pyct/inspect_utils.py
index d19c6ed75e..30a5961821 100644
--- a/tensorflow/contrib/autograph/pyct/inspect_utils.py
+++ b/tensorflow/contrib/autograph/pyct/inspect_utils.py
@@ -74,6 +74,12 @@ def getmethodclass(m):
ValueError: if the class could not be resolved for any unexpected reason.
"""
+ # Callable objects: return their own class.
+ if (not hasattr(m, '__name__') and hasattr(m, '__class__') and
+ hasattr(m, '__call__')):
+ if isinstance(m.__class__, six.class_types):
+ return m.__class__
+
# Instance method and class methods: should be bound to a non-null "self".
# If self is a class, then it's a class method.
if hasattr(m, '__self__'):
diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py
index ddca6f963b..eda3fc13fd 100644
--- a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py
+++ b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py
@@ -225,6 +225,15 @@ class InspectUtilsTest(test.TestCase):
inspect_utils.getmethodclass(test_obj.wrap_decorated_member),
LocalClass)
+ def test_getmethodclass_callables(self):
+ class TestCallable(object):
+
+ def __call__(self):
+ pass
+
+ c = TestCallable()
+ self.assertEqual(inspect_utils.getmethodclass(c), TestCallable)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py
index 4d5764a974..583cf7ecd7 100644
--- a/tensorflow/contrib/autograph/pyct/qual_names.py
+++ b/tensorflow/contrib/autograph/pyct/qual_names.py
@@ -112,6 +112,29 @@ class QN(object):
raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0])
return self._parent
+ @property
+ def support_set(self):
+ """Returns the set of simple symbols that this QN relies on.
+
+ This would be the smallest set of symbols necessary for the QN to
+ statically resolve (assuming properties and index ranges are verified
+ at runtime).
+
+ Examples:
+ 'a.b' has only one support symbol, 'a'
+ 'a[i]' has two support symbols, 'a' and 'i'
+ """
+ # TODO(mdan): This might be the set of Name nodes in the AST. Track those?
+ roots = set()
+ if self.has_attr():
+ roots.update(self.parent.support_set)
+ elif self.has_subscript():
+ roots.update(self.parent.support_set)
+ roots.update(self.qn[1].support_set)
+ else:
+ roots.add(self)
+ return roots
+
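An illustration, mirroring the qual_names tests added below:

    QN('a').support_set == {QN('a')}
    QN(QN('a'), subscript=QN('b')).support_set == {QN('a'), QN('b')}
    QN(QN('a'), attr='b').support_set == {QN('a')}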
def __hash__(self):
return hash(self.qn + (self._has_attr, self._has_subscript))
diff --git a/tensorflow/contrib/autograph/pyct/qual_names_test.py b/tensorflow/contrib/autograph/pyct/qual_names_test.py
index 103bd25aa3..264afd508c 100644
--- a/tensorflow/contrib/autograph/pyct/qual_names_test.py
+++ b/tensorflow/contrib/autograph/pyct/qual_names_test.py
@@ -154,6 +154,21 @@ class QNTest(test.TestCase):
a_sub_three = QN(a, subscript=QN(qual_names.NumberLiteral(3)))
self.assertEqual(a_sub_three.ast().slice.value.n, 3)
+ def test_support_set(self):
+ a = QN('a')
+ b = QN('b')
+ c = QN('c')
+ a_sub_b = QN(a, subscript=b)
+ a_dot_b = QN(a, attr='b')
+ a_dot_b_dot_c = QN(a_dot_b, attr='c')
+ a_dot_b_sub_c = QN(a_dot_b, subscript=c)
+
+ self.assertSetEqual(a.support_set, set((a,)))
+ self.assertSetEqual(a_sub_b.support_set, set((a, b)))
+ self.assertSetEqual(a_dot_b.support_set, set((a,)))
+ self.assertSetEqual(a_dot_b_dot_c.support_set, set((a,)))
+ self.assertSetEqual(a_dot_b_sub_c.support_set, set((a, c)))
+
class QNResolverTest(test.TestCase):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
index da6a2f6f05..6dd53091fa 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
@@ -265,10 +265,10 @@ class ActivityAnalizer(transformer.Base):
qn = QN(node.name)
self.scope.mark_write(qn)
current_scope = self.scope
- fndef_scope = Scope(current_scope, isolated=True)
- self.scope = fndef_scope
+ body_scope = Scope(current_scope, isolated=True)
+ self.scope = body_scope
self.generic_visit(node)
- anno.setanno(node, NodeAnno.BODY_SCOPE, fndef_scope)
+ anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope)
self.scope = current_scope
return node
@@ -282,7 +282,13 @@ class ActivityAnalizer(transformer.Base):
return node
def visit_If(self, node):
+ current_scope = self.scope
+ cond_scope = Scope(current_scope, isolated=False)
+ self.scope = cond_scope
self.visit(node.test)
+ anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
+ self.scope = current_scope
+
node = self._process_parallel_blocks(node,
((node.body, NodeAnno.BODY_SCOPE),
(node.orelse, NodeAnno.ORELSE_SCOPE)))
@@ -297,7 +303,13 @@ class ActivityAnalizer(transformer.Base):
return node
def visit_While(self, node):
+ current_scope = self.scope
+ cond_scope = Scope(current_scope, isolated=False)
+ self.scope = cond_scope
self.visit(node.test)
+ anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
+ self.scope = current_scope
+
node = self._process_parallel_blocks(node,
((node.body, NodeAnno.BODY_SCOPE),
(node.orelse, NodeAnno.ORELSE_SCOPE)))
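With this change the test expression of an if or while statement gets its own COND_SCOPE. A sketch of what is now recorded, matching the updated activity_test below:

    while b > 0:  # COND_SCOPE records the read of 'b'
      ...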
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
index 37c28872bb..1e6c686b01 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
@@ -204,6 +204,8 @@ class ActivityAnalizerTest(test.TestCase):
self.assertScopeIsRmc(
anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
('b', 'c'), ('a', 'b', 'c'))
+ self.assertScopeIsRmc(
+ anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), (), ())
def test_for(self):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
index 5254b83ca7..d6d9f7e1a6 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
@@ -43,6 +43,7 @@ class NodeAnno(NoValue):
# Scopes
# Scopes are represented by objects of type activity.Scope.
ARGS_SCOPE = 'The scope for the argument list of a function call.'
+ COND_SCOPE = 'The scope for the test node of a conditional statement.'
BODY_SCOPE = (
'The scope for the main body of a statement (True branch for if '
'statements, main body for loops).')
diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py
index 22898b17e9..817d4126d1 100644
--- a/tensorflow/contrib/autograph/utils/__init__.py
+++ b/tensorflow/contrib/autograph/utils/__init__.py
@@ -19,8 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin
-from tensorflow.contrib.autograph.utils.builtins import dynamic_dataset
-from tensorflow.contrib.autograph.utils.builtins import dynamic_for_cond
from tensorflow.contrib.autograph.utils.builtins import dynamic_print
from tensorflow.contrib.autograph.utils.builtins import dynamic_range
from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns
@@ -28,7 +26,6 @@ from tensorflow.contrib.autograph.utils.misc import alias_tensors
from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is
from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is_not
from tensorflow.contrib.autograph.utils.multiple_dispatch import run_cond
-from tensorflow.contrib.autograph.utils.multiple_dispatch import run_while
from tensorflow.contrib.autograph.utils.py_func import wrap_py_func
from tensorflow.contrib.autograph.utils.tensor_list import dynamic_list_append
from tensorflow.contrib.autograph.utils.testing import fake_tf
diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py
index c6af0e4d13..7fbb7c09d8 100644
--- a/tensorflow/contrib/autograph/utils/builtins.py
+++ b/tensorflow/contrib/autograph/utils/builtins.py
@@ -24,10 +24,8 @@ import six
from tensorflow.contrib.autograph.utils import py_func
from tensorflow.contrib.autograph.utils import type_check
-from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_inspect
@@ -106,69 +104,3 @@ def dynamic_print(*values):
return py_func.wrap_py_func(
flushed_print, None, values, use_dummy_return=True)
-
-
-def dynamic_dataset(iterated):
- """Implementartion of smart tf.data.Dataset epoch wrapping.
-
- The function checks if the input is a tf.data.Dataset and if so then wraps it
- so that for each element it returns it also returns the current epoch the
- dataset iteration is in, for two epochs. If the input is not a
- tf.data.Dataset then it just returns the input.
-
- Args:
- iterated: The iterable or tf.data.Dataset that is being iterated over.
- Returns:
- Either just the untouched input, or in the case of input being a
- tf.data.Dataset then it returns a wrapped tf.data.Dataset where for each
- element it returns it also returns the current epoch the dataset iteration
- is in.
- """
- if not isinstance(iterated, dataset_ops.Dataset):
- return iterated
-
- def epoch_dataset_number_helper(i):
- return dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensors(i).repeat(), iterated))
-
- epoch_numbers = dataset_ops.Dataset.range(2)
- return epoch_numbers.flat_map(epoch_dataset_number_helper)
-
-
-def dynamic_for_cond(iteration, iterated):
- """Implementartion of smart while-loop condition using dynamic dispatch.
-
- The function checks if it is iterating over a tf.data.Dataset or not, and in
- the case it is not then it simply returns if we are still in range of the
- iterated and the next element. If it is iterating over a dataset then it only
- iterates for a single epoch.
-
- Args:
- iteration: The current iteration of the loop.
- iterated: The iterable or tf.data.Dataset that is being iterated over.
- Returns:
- A tuple of a bool that indicates whether the loop should continue, and the
- next element in iterated.
- """
- # TODO(znado): Clean up.
- # TODO(znado): This won't work for unpacked iterates. Fix.
- if isinstance(iterated, dataset_ops.Dataset):
- curr_epoch, next_elem = iterated.make_one_shot_iterator().get_next()
- return math_ops.less(curr_epoch, 1), next_elem
- elif tensor_util.is_tensor(iterated):
- if iterated.shape.ndims > 1:
- elem_shape = array_ops.shape(iterated)[1:]
- else:
- elem_shape = ()
- if iterated.shape.ndims == 0 or iterated.shape[0] == 0:
- return False, array_ops.zeros(elem_shape, iterated.dtype)
- return control_flow_ops.cond(
- math_ops.less(iteration, dynamic_len(iterated)),
- lambda: (True, iterated[iteration]),
- lambda: (False, array_ops.zeros(elem_shape, iterated.dtype)))
- elif hasattr(iterated, '__len__'):
- if iteration < len(iterated):
- return True, iterated[iteration]
- return False, None
- else:
- raise NotImplementedError('Python iterators not yet supported.')
diff --git a/tensorflow/contrib/autograph/utils/multiple_dispatch.py b/tensorflow/contrib/autograph/utils/multiple_dispatch.py
index 47049255f3..70eef5676f 100644
--- a/tensorflow/contrib/autograph/utils/multiple_dispatch.py
+++ b/tensorflow/contrib/autograph/utils/multiple_dispatch.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import six
-
from tensorflow.contrib.autograph.utils.type_check import is_tensor
from tensorflow.python.ops import control_flow_ops
@@ -66,42 +64,3 @@ def py_cond(condition, true_fn, false_fn):
if len(results) == 1:
return results[0]
return results
-
-
-def run_while(cond_fn, body_fn, init_args):
- """Type-dependent functional while loop.
-
- Args:
- cond_fn: A Python callable implementing the stop conditions of the loop.
- body_fn: A Python callable implementing the body of the loop.
- init_args: The initial values of the arguments that will be passed to both
- cond_fn and body_fn.
-
- Returns:
- result: A list of values with the same shape and type as init_args. If any
- of the init_args, or any variables closed-over in cond_fn are Tensors,
- tf.while_loop will be used, otherwise a Python while loop will be ran.
-
- Raises:
- ValueError: if init_args is not a tuple or list with one or more elements.
- """
- if not isinstance(init_args, (tuple, list)) or not init_args:
- raise ValueError(
- 'init_args must be a non-empty list or tuple, found %s' % init_args)
-
- # TODO(alexbw): statically determine all active variables in cond_fn,
- # and pass them directly
- closure_vars = tuple(
- [c.cell_contents for c in six.get_function_closure(cond_fn) or []])
- possibly_tensors = tuple(init_args) + closure_vars
- if is_tensor(*possibly_tensors):
- return control_flow_ops.while_loop(cond_fn, body_fn, init_args)
- else:
- return py_while_loop(cond_fn, body_fn, init_args)
-
-
-def py_while_loop(cond_fn, body_fn, init_args):
- state = init_args
- while cond_fn(*state):
- state = body_fn(*state)
- return state
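Callers of the removed run_while are expected to migrate to the new operators.while_loop; a rough sketch of the mapping (closure scanning is replaced by the explicit extra_deps argument):

    # Before: multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 1.0, 0.5])
    # After:
    operators.while_loop(cond_fn, body_fn,
                         init_state=(3.0, 1.0, 0.5), extra_deps=())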
diff --git a/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py b/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py
index e6a41bb416..f72f8e94a0 100644
--- a/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py
+++ b/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py
@@ -70,29 +70,6 @@ class MultipleDispatchTest(test.TestCase):
out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn)
self.assertEqual(sess.run(out), 3)
- def test_run_while_python(self):
- cond_fn = lambda x, t, s: x > t
- body_fn = lambda x, t, s: (x * s, t, s)
-
- x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 1.0, 0.5])
- self.assertEqual(x, 0.75)
-
- x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 4.0, 0.5])
- self.assertEqual(x, 3.0)
-
- def test_run_while_tf(self):
- cond_fn = lambda x, t, s: x > t
- body_fn = lambda x, t, s: (x * s, t, s)
-
- with Session() as sess:
- x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn,
- [constant(3.0), 1.0, 0.5])
- self.assertEqual(sess.run(x), 0.75)
-
- x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn,
- [constant(3.0), 4.0, 0.5])
- self.assertEqual(sess.run(x), 3.0)
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 95c5c920aa..5a2771229d 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -61,11 +61,13 @@ class TPUClusterResolver(ClusterResolver):
return False
return True
- def _inGke(self):
+ @staticmethod
+ def _inGke():
"""When running in GKE, the environment variable will be set."""
return _GKE_ENV_VARIABLE in os.environ
- def _gkeMaster(self):
+ @staticmethod
+ def _gkeMaster():
return os.environ[_GKE_ENV_VARIABLE].split(',')[0]
def __init__(self,
@@ -119,8 +121,9 @@ class TPUClusterResolver(ClusterResolver):
'Using multiple TPUs in a single session is not yet implemented')
tpu = tpu[0]
+ in_gke = self._inGke()
# When using GKE with Cloud TPUs, the env variable will be set.
- if tpu is None and self._inGke():
+ if tpu is None and in_gke:
tpu = self._gkeMaster()
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
@@ -158,7 +161,8 @@ class TPUClusterResolver(ClusterResolver):
self._service = service
self._coordinator_name = coordinator_name
- if coordinator_name and not coordinator_address and should_resolve:
+ if coordinator_name and not coordinator_address and (should_resolve or
+ in_gke):
self._start_local_server()
else:
self._coordinator_address = coordinator_address
@@ -204,31 +208,50 @@ class TPUClusterResolver(ClusterResolver):
Raises:
RuntimeError: If the provided TPU is not healthy.
"""
- if not self._shouldResolve():
- return server_lib.ClusterSpec({})
-
- full_name = 'projects/%s/locations/%s/nodes/%s' % (
- self._project, self._zone, compat.as_text(self._tpu))
- request = self._service.projects().locations().nodes().get(name=full_name)
- response = request.execute()
-
- if 'health' in response and response['health'] != 'HEALTHY':
- raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
- response['health']))
-
- if 'networkEndpoints' in response:
- worker_list = [
- '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
- for endpoint in response['networkEndpoints']
- ]
+ ############################################################################
+ # There are 5 potential cases this code must handle:
+ # 1. [Normal case.] We should resolve the TPU name to a set of tasks, and
+ # a. Create a ClusterSpec that includes the coordinator job
+ # b. Create a ClusterSpec without the coordinator job.
+ # 2. [GKE / No API Access.] We should not resolve the TPU name to a set of
+ # tasks and
+ # a. Create a ClusterSpec with the coordinator
+ # b. Create a ClusterSpec without the coordinator
+ # 3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec.
+ ############################################################################
+
+ if self._shouldResolve():
+ # Case 1.
+ full_name = 'projects/%s/locations/%s/nodes/%s' % (
+ self._project, self._zone, compat.as_text(self._tpu))
+ request = self._service.projects().locations().nodes().get(name=full_name)
+ response = request.execute()
+
+ if 'health' in response and response['health'] != 'HEALTHY':
+ raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
+ response['health']))
+
+ if 'networkEndpoints' in response:
+ worker_list = [
+ '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
+ for endpoint in response['networkEndpoints']
+ ]
+ else:
+ # Fall back to the deprecated response format
+ instance_url = '%s:%s' % (response['ipAddress'], response['port'])
+ worker_list = [instance_url]
+
+ cluster_spec = {self._job_name: worker_list}
else:
- # Fall back to the deprecated response format
- instance_url = '%s:%s' % (response['ipAddress'], response['port'])
- worker_list = [instance_url]
-
- cluster_spec = {self._job_name: worker_list}
+ if not self._tpu.startswith(compat.as_bytes('grpc://')):
+ # Case 3.
+ return server_lib.ClusterSpec({})
+ # Case 2.
+ cluster_spec = {self._job_name: [self._tpu[len(
+ compat.as_bytes('grpc://')):]]}
if self._coordinator_address:
+ # {1, 2}.a
cluster_spec[self._coordinator_name] = [self._coordinator_address]
return server_lib.ClusterSpec(cluster_spec)
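For case 2 (GKE: a gRPC address but no Cloud API resolution), a sketch of the result using the endpoint value from the tests below; the job name is hypothetical:

    # tpu == b'grpc://10.120.27.5:8470', job_name == 'worker'
    # -> ClusterSpec({'worker': ['10.120.27.5:8470']})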
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
index e1e3e6867a..dff7a03b68 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -362,14 +362,10 @@ class TPUClusterResolverTest(test.TestCase):
def testGkeEnvironment(self):
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ)
- tpu_cluster_resolver = TPUClusterResolver()
- self.assertTrue(tpu_cluster_resolver._inGke())
+ self.assertTrue(TPUClusterResolver._inGke())
self.assertEqual(
compat.as_bytes('grpc://10.120.27.5:8470'),
- compat.as_bytes(tpu_cluster_resolver._gkeMaster()))
- self.assertEqual(
- compat.as_bytes('grpc://10.120.27.5:8470'),
- compat.as_bytes(tpu_cluster_resolver.get_master()))
+ compat.as_bytes(TPUClusterResolver._gkeMaster()))
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD
index 8b5d13f725..d68015ae15 100644
--- a/tensorflow/contrib/cudnn_rnn/BUILD
+++ b/tensorflow/contrib/cudnn_rnn/BUILD
@@ -25,6 +25,7 @@ tf_custom_op_py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/contrib/eager/python:checkpointable_utils",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index 9897c31a98..9cc6ca09ad 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import argparse
import collections
+import functools
import itertools
import os
import sys
@@ -28,13 +29,14 @@ import numpy as np
from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
+from tensorflow.contrib.eager.python import checkpointable_utils
from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
-from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_nn_ops
@@ -265,7 +267,7 @@ def _CreateCudnnCompatibleCanonicalRNN(rnn, inputs, is_bidi=False, scope=None):
return outputs, (output_state_fw, output_state_bw)
-class CudnnRNNTestBasic(TensorFlowTestCase):
+class CudnnRNNTestBasic(test_util.TensorFlowTestCase):
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@@ -467,7 +469,7 @@ class CudnnRNNTestBasic(TensorFlowTestCase):
# TODO(jamesqin): Transform to parameterized test after it is included in the
# TF open source codebase.
-class CudnnRNNTestSaveRestore(TensorFlowTestCase):
+class CudnnRNNTestSaveRestore(test_util.TensorFlowTestCase):
def _CompareWeights(self, lhs, rhs):
self.assertEqual(len(lhs), len(rhs))
@@ -701,9 +703,146 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase):
self._TestSaveRestoreHelper(CUDNN_RNN_RELU)
+class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
+
+ def _VerifyCheckpoint(
+ self, checkpoint_path, compatible_cell_fn, cudnn_cell_fn,
+ num_layers, input_size, expected_variable_values, num_applications=3):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with ops.device("gpu:0"):
+ cudnn_layer = cudnn_cell_fn()
+ cudnn_checkpoint = checkpointable_utils.Checkpoint(cell=cudnn_layer)
+ status = cudnn_checkpoint.restore(checkpoint_path)
+ inputs = 3. * array_ops.ones([num_applications, num_layers, input_size],
+ dtype=dtypes.float32)
+ cudnn_output, _ = cudnn_layer(inputs)
+ status.assert_consumed().run_restore_ops()
+ second_save_path = cudnn_checkpoint.save(checkpoint_prefix)
+ restore_layer = compatible_cell_fn()
+ restore_layer_checkpoint = checkpointable_utils.Checkpoint(
+ cell=restore_layer)
+ status = restore_layer_checkpoint.restore(second_save_path)
+ current_state = restore_layer.zero_state(1, dtypes.float32)
+ for _ in range(num_applications):
+ restore_layer_output, current_state = restore_layer(
+ inputs=3. * array_ops.ones([1, input_size]),
+ state=current_state)
+ status.assert_consumed().run_restore_ops()
+ self.assertTrue(restore_layer.variables)
+ for variable, expected_value in zip(
+ restore_layer.variables, expected_variable_values):
+ self.assertAllClose(expected_value, self.evaluate(variable))
+ self.assertAllClose(self.evaluate(restore_layer_output),
+ self.evaluate(cudnn_output)[-1, -1:, ...])
+
+ def _CheckpointableSingleCellUnidirectionalTestTemplate(
+ self, single_cell_fn, cudnn_cell_fn):
+ # Single-layer cuDNN cells with object-based checkpointing should be
+ # checkpoint compatible with either single CudnnCompatible cells or
+ # MultiRnnCells with one cell.
+ input_size = 3
+ save_cell_layer = single_cell_fn()
+ save_cell_layer(
+ inputs=array_ops.ones([1, input_size]),
+ state=save_cell_layer.zero_state(1, dtypes.float32))
+ self.assertTrue(save_cell_layer.variables)
+ expected_values = []
+ np.random.seed(10)
+ for variable in save_cell_layer.variables:
+ value = np.random.normal(size=variable.shape)
+ expected_values.append(value)
+ self.evaluate(variable.assign(value))
+ save_checkpoint = checkpointable_utils.Checkpoint(cell=save_cell_layer)
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ first_save_path = save_checkpoint.save(checkpoint_prefix)
+ self._VerifyCheckpoint(
+ checkpoint_path=first_save_path,
+ compatible_cell_fn=
+ lambda: rnn_cell_impl.MultiRNNCell([single_cell_fn()]),
+ cudnn_cell_fn=cudnn_cell_fn,
+ num_layers=1,
+ expected_variable_values=expected_values,
+ input_size=input_size)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ @test_util.run_in_graph_and_eager_modes()
+ def testLSTMCheckpointableSingleLayer(self):
+ num_units = 2
+ direction = CUDNN_RNN_UNIDIRECTION
+ self._CheckpointableSingleCellUnidirectionalTestTemplate(
+ single_cell_fn=functools.partial(
+ cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units),
+ cudnn_cell_fn=functools.partial(
+ cudnn_rnn.CudnnLSTM, num_layers=1, num_units=num_units,
+ direction=direction, name="awesome_lstm"))
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ @test_util.run_in_graph_and_eager_modes()
+ def testGRUCheckpointableSingleLayer(self):
+ num_units = 2
+ direction = CUDNN_RNN_UNIDIRECTION
+ with self.assertRaises(NotImplementedError):
+ # TODO(allenl): Implement object-based saving for GRUs and other cells.
+ self._CheckpointableSingleCellUnidirectionalTestTemplate(
+ single_cell_fn=functools.partial(
+ cudnn_rnn_ops.CudnnCompatibleGRUCell, num_units=num_units),
+ cudnn_cell_fn=functools.partial(
+ cudnn_rnn.CudnnGRU, num_layers=1, num_units=num_units,
+ direction=direction, name="awesome_gru"))
+
+ def _CheckpointableMultiLayerTestTemplate(
+ self, single_cell_fn, cudnn_cell_fn, num_layers):
+
+ def _MultiCellFn():
+ return rnn_cell_impl.MultiRNNCell(
+ [single_cell_fn() for _ in range(num_layers)])
+ input_size = 3
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(graph=save_graph):
+ save_layer = _MultiCellFn()
+ save_layer(inputs=array_ops.ones([1, input_size]),
+ state=save_layer.zero_state(1, dtypes.float32))
+ self.assertTrue(save_layer.variables)
+ expected_values = []
+ np.random.seed(10)
+ for variable in save_layer.variables:
+ value = np.random.normal(size=variable.shape)
+ expected_values.append(value)
+ self.evaluate(variable.assign(value))
+ save_checkpoint = checkpointable_utils.Checkpoint(cell=save_layer)
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ first_save_path = save_checkpoint.save(checkpoint_prefix)
+ self._VerifyCheckpoint(
+ checkpoint_path=first_save_path,
+ compatible_cell_fn=_MultiCellFn, cudnn_cell_fn=cudnn_cell_fn,
+ num_layers=num_layers,
+ expected_variable_values=expected_values,
+ input_size=input_size)
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ @test_util.run_in_graph_and_eager_modes()
+ def testCudnnCompatibleLSTMCheckpointableMultiLayer(self):
+ num_units = 2
+ num_layers = 3
+ direction = CUDNN_RNN_UNIDIRECTION
+ self._CheckpointableMultiLayerTestTemplate(
+ single_cell_fn=functools.partial(
+ cudnn_rnn_ops.CudnnCompatibleLSTMCell, num_units=num_units),
+ cudnn_cell_fn=functools.partial(
+ cudnn_rnn.CudnnLSTM, num_layers=num_layers, num_units=num_units,
+ direction=direction, name="awesome_lstm"),
+ num_layers=num_layers)
+
+
# TODO(jamesqin): Transform to parameterized test after it is included in the
# TF open source codebase.
-class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase):
+class CudnnRNNTestCompatibleRNNCells(test_util.TensorFlowTestCase):
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
@@ -884,7 +1023,7 @@ class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase):
rtol=2e-5)
-class CudnnRNNTestParamsSize(TensorFlowTestCase):
+class CudnnRNNTestParamsSize(test_util.TensorFlowTestCase):
def _TestOpaqueParamsSize(self, rnn_mode, num_layers, num_units, input_size,
dtype, direction):
@@ -931,7 +1070,7 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase):
dtype, direction)
-class CudnnRNNTestTraining(TensorFlowTestCase):
+class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
def _ComputeNumericGrad(self, sess, y, x, delta=1e-4, step=1):
"""Compute the numeric gradient of y wrt to x.
diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
index 36fba917a8..00d9544602 100644
--- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
+++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
@@ -142,6 +142,9 @@ class _CudnnRNN(base_layer.Layer):
"""
# pylint:enable=line-too-long
+ # TODO(allenl): Document object-based saving and checkpoint compatibility once
+ # it's implemented for more cuDNN Layers.
+
# The following are constants defined by subclasses.
# Type of RNN cell.
_rnn_mode = None
@@ -363,6 +366,11 @@ class _CudnnRNN(base_layer.Layer):
self._create_saveable()
self.built = True
+ def _gather_saveables_for_checkpoint(self):
+ raise NotImplementedError(
+ "This cell does not yet support object-based saving. File a feature "
+ "request if this limitation bothers you.")
+
def call(self, inputs, initial_state=None, training=True):
"""Runs the forward step for the RNN model.
@@ -499,6 +507,8 @@ class _CudnnRNN(base_layer.Layer):
direction=self.direction,
scope=vs.get_variable_scope(),
name="%s_saveable" % self.trainable_variables[0].name.split(":")[0])
+ self._saveable._add_checkpointable_dependencies( # pylint: disable=protected-access
+ checkpointable=self, dtype=self._plain_dtype)
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
@@ -521,6 +531,16 @@ class CudnnLSTM(_CudnnRNN):
return ([self.num_layers * self.num_dirs, batch_size, self.num_units],
[self.num_layers * self.num_dirs, batch_size, self.num_units])
+ @property
+ def _gather_saveables_for_checkpoint(self):
+ if self._direction == CUDNN_RNN_UNIDIRECTION:
+ # Skip one inheritance level to avoid NotImplementedError.
+ return super(_CudnnRNN, self)._gather_saveables_for_checkpoint
+ else:
+ raise NotImplementedError(
+ "Object-based saving does not currently support bidirectional LSTM "
+ "cells. File a feature request if this limitation bothers you.")
+
class _CudnnRNNNoInputC(_CudnnRNN):
"""Abstract simple CudnnRNN layer without input_c."""
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index 2ac9442406..9796aae4b0 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.eager.python import checkpointable_utils
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
@@ -31,6 +32,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.training import checkpointable as checkpointable_lib
from tensorflow.python.training import saver
CUDNN_RNN_UNIDIRECTION = "unidirectional"
@@ -266,13 +268,16 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
# instead of having the master pull all slices and then save them.
slice_spec = ""
params = weights + biases
- param_names = weight_names + bias_names
+ self._weight_names = weight_names
+ self._bias_names = bias_names
+ self._param_names = weight_names + bias_names
+ prefixed_param_names = weight_names + bias_names
if self._scope:
- param_names = ["%s/%s" % (self._scope, pn) for pn in param_names]
-
+ prefixed_param_names = [
+ "%s/%s" % (self._scope, pn) for pn in prefixed_param_names]
specs = [
saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name)
- for param, param_name in zip(params, param_names)
+ for param, param_name in zip(params, prefixed_param_names)
]
super(CudnnOpaqueParamsSaveable, self).__init__(
array_ops.identity(self._variables), specs, name)
@@ -285,6 +290,45 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
return state_ops.assign(
self._variables, opaque_params, validate_shape=False)
+ def _checkpointable_save(self, save_buffer):
+ weights, biases = self._OpaqueParamsToCanonical()
+ with ops.device("gpu:0"):
+ (weights, _), (biases, _) = self._TransformCanonical(
+ weights, biases)
+ for name, tensor in zip(self._param_names, weights + biases):
+ save_buffer[name] = array_ops.identity(tensor)
+
+ def _checkpointable_restore(self, restore_buffer):
+ tensors = [array_ops.identity(restore_buffer[name])
+ for name in self._param_names]
+ return self.restore(
+ restored_tensors=tensors,
+ restored_shapes=None # Unused
+ )
+
+ def _add_checkpointable_dependencies(self, checkpointable, dtype):
+ """Add canonical weight dependencies to `checkpointable`.
+
+ When saving or restoring, converts to or from the opaque buffer
+ format. Weights are saved and loaded in the configuration expected by
+ cuDNN-compatible cells.
+
+ Args:
+ checkpointable: An object inheriting from `CheckpointableBase` to add
+ dependencies to (typically the cuDNN `Layer`).
+ dtype: The dtype for the canonical parameter Tensors.
+ """
+ split_dependencies = checkpointable_utils.split_dependency(
+ component_names=self._param_names,
+ component_dtypes=(dtype,) * len(self._param_names),
+ fill_save_buffer_fn=self._checkpointable_save,
+ consume_restore_buffer_fn=self._checkpointable_restore)
+ self._checkpointable_track_params(checkpointable, split_dependencies)
+
+ def _checkpointable_track_params(self, checkpointable, params):
+ """Tracks parameters in a canonical configuration."""
+ return # NotImplementedError raised by the Layer.
+
def _TFCanonicalNamePrefix(self, layer, is_fwd=True):
if self._direction == CUDNN_RNN_UNIDIRECTION:
return "rnn/multi_rnn_cell/cell_%d/%s" % (layer, self._rnn_cell_name)
@@ -574,6 +618,29 @@ class CudnnLSTMSaveable(CudnnOpaqueParamsSaveable):
tf_biases.append(b)
tf_bias_names.append(prefix + "/bias")
+ def _checkpointable_track_params(self, checkpointable, params):
+ """Track parameters for compatibility with CudnnCompatibleLSTMCell."""
+ biases = []
+ weights = []
+ for name in self._weight_names:
+ weights.append(params[name])
+ for name in self._bias_names:
+ biases.append(params[name])
+ assert len(params) == len(weights) + len(biases)
+ if len(weights) == 1 and len(biases) == 1:
+ # For single-layer cells, allow substituting a cell with no MultiRNNCell
+ # wrapping.
+ kernel, = weights # pylint: disable=unbalanced-tuple-unpacking
+ bias, = biases # pylint: disable=unbalanced-tuple-unpacking
+ checkpointable._track_checkpointable(kernel, name="kernel") # pylint: disable=protected-access
+ checkpointable._track_checkpointable(bias, name="bias") # pylint: disable=protected-access
+ assert len(biases) == len(weights)
+ for cell_index, (bias, kernel) in enumerate(zip(biases, weights)):
+ cell = checkpointable_lib.Checkpointable()
+ checkpointable._track_checkpointable(cell, name="cell-%d" % cell_index) # pylint: disable=protected-access
+ cell.bias = bias
+ cell.kernel = kernel
+
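Taken together, these hooks make object-based checkpoints of a single-layer CudnnLSTM loadable into a CudnnCompatibleLSTMCell and vice versa. A sketch mirroring the new cudnn_rnn_test cases (the layer must have been called so its variables exist; the path is hypothetical):

    lstm = cudnn_rnn.CudnnLSTM(num_layers=1, num_units=2)
    ckpt = checkpointable_utils.Checkpoint(cell=lstm)
    save_path = ckpt.save('/tmp/ckpt')
    ckpt.restore(save_path)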
class CudnnGRUSaveable(CudnnOpaqueParamsSaveable):
"""SaveableObject implementation handling Cudnn GRU opaque params."""
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 4b50260670..b08132cd72 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
@@ -265,6 +266,43 @@ class PrefetchToDeviceTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
+ def testPrefetchSparseTensorsToDevice(self):
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2])
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto()
+ worker_config.device_count["CPU"] = 2
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
def testPrefetchToDeviceGpu(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 6ee1b572f1..f3e9302409 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -271,7 +271,8 @@ class ReadBatchFeaturesTest(test.TestCase):
reader_num_threads=1,
parser_num_threads=1,
shuffle=False,
- shuffle_seed=None):
+ shuffle_seed=None,
+ drop_final_batch=False):
self.filenames = filenames
self.num_epochs = num_epochs
self.batch_size = batch_size
@@ -289,7 +290,8 @@ class ReadBatchFeaturesTest(test.TestCase):
shuffle=shuffle,
shuffle_seed=shuffle_seed,
reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ parser_num_threads=parser_num_threads,
+ drop_final_batch=drop_final_batch).make_one_shot_iterator(
).get_next()
def _record(self, f, r):
@@ -559,6 +561,20 @@ class ReadBatchFeaturesTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
+ def testDropFinalBatch(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 10]:
+ with ops.Graph().as_default():
+ # Basic test: read from file 0.
+ self.outputs = self._read_batch_features(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ drop_final_batch=True)
+ for _, tensor in self.outputs.items():
+ if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
+ self.assertEqual(tensor.shape[0], batch_size)
+
class MakeCsvDatasetTest(test.TestCase):
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 77e23d0319..89c04dc89a 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -25,10 +25,11 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
+from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops
# TODO(rohanj): Add a python class that constructs resource in the __init__
@@ -111,19 +112,7 @@ class _PrefetchToDeviceIterator(object):
self._input_iterator.output_shapes,
self._input_iterator.output_classes)
ret = remote_iterator.get_next()
-
- # Convert any `SparseTensorValue`s to `SparseTensor`s.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
- ])
-
- # Serialize any sparse tensors and convert result to tensors.
- ret = nest.pack_sequence_as(ret, [
- ops.convert_to_tensor(t)
- for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
- ])
- return nest.flatten(ret)
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
with ops.device(device):
self._buffering_resource = function_buffering_resource(
@@ -179,6 +168,68 @@ class _PrefetchToDeviceIterator(object):
@property
def output_types(self):
return self._input_dataset.output_types
+
+
+class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
+ """A replacement for @{tf.data.Iterator} that prefetches to another device.
+
+ Args:
+ input_dataset: The input dataset.
+ device: A fully specified device string where we want to prefetch to.
+ buffer_size: Size of the prefetching buffer.
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ device,
+ buffer_size):
+ with ops.device("/device:CPU:0"):
+ super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
+ input_iterator_handle = core_gen_dataset_ops.iterator_to_string_handle(
+ self._resource)
+
+ self._device = device
+
+ @function.Defun(dtypes.string)
+ def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ handle, self.output_types, self.output_shapes, self.output_classes)
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ _prefetch_fn.add_to_graph(None)
+
+ with ops.device(device):
+ self._buffering_resource = function_buffering_resource(
+ f=_prefetch_fn,
+ target_device=gen_dataset_ops.iterator_get_device(self._resource),
+ string_arg=input_iterator_handle,
+ buffer_size=buffer_size,
+ shared_name=iterator_ops._generate_shared_name(
+ "function_buffer_resource"))
+
+ def _next_internal(self):
+ """Returns a nested structure of `tf.Tensor`s containing the next element.
+ """
+ # This runs in sync mode as iterators use an error status to communicate
+ # that there is no more data to iterate over.
+ # TODO(b/77291417): Fix
+ with context.execution_mode(context.SYNC):
+ with ops.device(self._device):
+ ret = gen_dataset_ops.function_buffering_resource_get_next(
+ function_buffer_resource=self._buffering_resource,
+ output_types=self._flat_output_types)
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(self._output_types, ret), self._output_types,
+ self._output_shapes, self._output_classes)
# pylint: enable=protected-access
@@ -190,12 +241,37 @@ class _PrefetchToDeviceDataset(dataset_ops.Dataset):
self._device = device
self._buffer_size = buffer_size if buffer_size is not None else 1
+ # The static analysis cannot tell that the eager iterator's superclass has
+ # a `next()` method.
+ # pylint: disable=non-iterator-returned
+ def __iter__(self):
+ """Creates an `Iterator` for enumerating the elements of this dataset.
+
+ The returned iterator implements the Python iterator protocol and therefore
+ can only be used in eager mode.
+
+ Returns:
+ An `Iterator` over the elements of this dataset.
+
+ Raises:
+ RuntimeError: If eager execution is enabled.
+ """
+ if context.executing_eagerly():
+ return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
+ self._buffer_size)
+ else:
+ raise RuntimeError("dataset.__iter__() is only supported when eager "
+ "execution is enabled.")
+ # pylint: enable=non-iterator-returned
+
def make_one_shot_iterator(self):
- return _PrefetchToDeviceIterator(
- self._input_dataset,
- one_shot=True,
- device=self._device,
- buffer_size=self._buffer_size)
+ if context.executing_eagerly():
+ return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
+ self._buffer_size)
+ else:
+ return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True,
+ device=self._device,
+ buffer_size=self._buffer_size)
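A usage sketch in eager mode (assuming eager execution has been enabled beforehand):

    ds = dataset_ops.Dataset.range(10).apply(
        prefetching_ops.prefetch_to_device('/gpu:0'))
    for element in ds:  # __iter__ returns a _PrefetchToDeviceEagerIterator
      ...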
def make_initializable_iterator(self, shared_name=None):
return _PrefetchToDeviceIterator(
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 9a48aa02fb..b8eb09978e 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -370,7 +370,8 @@ def make_batched_features_dataset(file_pattern,
prefetch_buffer_size=1,
reader_num_threads=1,
parser_num_threads=2,
- sloppy_ordering=False):
+ sloppy_ordering=False,
+ drop_final_batch=False):
"""Returns a `Dataset` of feature dictionaries from `Example` protos.
Example:
@@ -443,6 +444,9 @@ def make_batched_features_dataset(file_pattern,
produced is deterministic prior to shuffling (elements are still
randomized if `shuffle=True`. Note that if the seed is set, then order
of elements after shuffling is deterministic). Defaults to `False`.
+ drop_final_batch: If `True`, and the batch size does not evenly divide the
+ input dataset size, the final smaller batch will be dropped. Defaults to
+ `False`.
Returns:
A dataset of `dict` elements. Each `dict` maps feature keys to
@@ -481,7 +485,10 @@ def make_batched_features_dataset(file_pattern,
elif shuffle:
dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed)
- dataset = dataset.batch(batch_size)
+ if drop_final_batch:
+ dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size))
+ else:
+ dataset = dataset.batch(batch_size)
# Parse `Example` tensors to a dictionary of `Feature` tensors.
dataset = dataset.map(
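A hedged usage sketch of the new flag; `file_pattern` and `features` stand in for real arguments:

    dataset = make_batched_features_dataset(
        file_pattern, batch_size=32, features=features,
        drop_final_batch=True)  # any trailing partial batch is dropped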
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 28483f4c88..14de1e8f49 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -117,7 +117,7 @@ in the input function gives a solid boost in performance. When using
This feature is in early stages and there are a lot of improvements forthcoming:
* Metrics are not yet supported during distributed training.
-* Summaries are currently computed in every tower.
+* Summaries are only computed in the first tower in `MirroredStrategy`.
* Evaluation is not yet distributed.
* Eager support is in the works; performance can be more challenging with eager
execution.
@@ -129,10 +129,6 @@ effective batch size will be `num_gpus * batch_size`. Therefore, consider
adjusting your learning rate or batch size according to the number of GPUs.
We are working on addressing this limitation by splitting each batch across GPUs
instead.
-* Dictionaries inside dataset in the input are not supported when prefetching
-on GPUs is turned on. (If you need to use dictionaries in the dataset, turn off
-prefetching on GPUs by passing param `prefetch_on_device=False` to
-`MirroredStrategy`)
* PartitionedVariables are not supported yet.
## What's next?
diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py
index 2b49b8f4ef..c5a520ab5a 100644
--- a/tensorflow/contrib/distribute/python/estimator_integration_test.py
+++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py
@@ -61,7 +61,7 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase,
mode=['graph'],
distribution=[
combinations.one_device_strategy,
- combinations.mirrored_strategy_without_prefetch
+ combinations.mirrored_strategy_with_gpu_and_cpu
]))
def test_complete_flow_with_mode(self, distribution):
label_dimension = 2
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
index e1ddf3cece..dfcbb8568f 100644
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
@@ -45,10 +45,12 @@ class _PrefetchToDeviceIterator(object):
@function.Defun(dtypes.string)
def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
remote_iterator = iterator_ops.Iterator.from_string_handle(
handle, input_iterator.output_types, input_iterator.output_shapes,
input_iterator.output_classes)
- return remote_iterator.get_next()
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
target_device = gen_dataset_ops.iterator_get_device(
input_iterator._iterator_resource)
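A hedged sketch of the round trip implied by this change: a `Defun` must return a flat list of dense tensors, so sparse components are serialized before crossing the function boundary and deserialized by the consumer. The helper modules are assumed from the `nest` and `sparse` names used above:

```python
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse


def _flatten_for_defun(element):
  # Serialize any SparseTensors into dense string tensors, then flatten
  # the (possibly nested) structure into a list of Tensors that a Defun
  # can legally return.
  return nest.flatten(sparse.serialize_sparse_tensors(element))
```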
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 9799901483..fec6eafd4a 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -491,6 +491,16 @@ cuda_py_test(
)
cuda_py_test(
+ name = "seed_stream_test",
+ size = "small",
+ srcs = ["python/kernel_tests/seed_stream_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+cuda_py_test(
name = "statistical_testing_test",
size = "medium",
srcs = [
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 4d4489468d..ddf59891e6 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -59,6 +59,7 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import *
from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
from tensorflow.contrib.distributions.python.ops.sample_stats import *
+from tensorflow.contrib.distributions.python.ops.seed_stream import *
from tensorflow.contrib.distributions.python.ops.sinh_arcsinh import *
from tensorflow.contrib.distributions.python.ops.test_util import *
from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import *
@@ -126,6 +127,7 @@ _allowed_symbols = [
'NormalWithSoftplusScale',
'Poisson',
'PoissonLogNormalQuadratureCompound',
+ 'SeedStream',
'SinhArcsinh',
'StudentT',
'StudentTWithAbsDfSoftplusScale',
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
index c6c8d2cf6e..59d549b7b8 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py
@@ -536,14 +536,14 @@ class _BatchReshapeTest(object):
if self.is_static_shape:
with self.assertRaisesRegexp(NotImplementedError,
- "too few event dims"):
+ "too few batch and event dims"):
poisson_141_reshaped.log_prob(x_4)
with self.assertRaisesRegexp(NotImplementedError,
"unexpected batch and event shape"):
poisson_141_reshaped.log_prob(x_114)
return
- with self.assertRaisesOpError("too few event dims"):
+ with self.assertRaisesOpError("too few batch and event dims"):
with self.test_session():
poisson_141_reshaped.log_prob(x_4).eval()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py
new file mode 100644
index 0000000000..9680573317
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the SeedStream class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops import seed_stream
+from tensorflow.python.platform import test
+
+
+class SeedStreamTest(test.TestCase):
+
+ def assertAllUnique(self, items):
+ self.assertEqual(len(items), len(set(items)))
+
+ def testNonRepetition(self):
+ # The probability of repetitions in a short stream from a correct
+ # PRNG is negligible; this test catches bugs that prevent state
+ # updates.
+ strm = seed_stream.SeedStream(seed=4, salt="salt")
+ output = [strm() for _ in range(50)]
+ self.assertEqual(sorted(output), sorted(list(set(output))))
+
+ def testReproducibility(self):
+ strm1 = seed_stream.SeedStream(seed=4, salt="salt")
+ strm2 = seed_stream.SeedStream(seed=4, salt="salt")
+ strm3 = seed_stream.SeedStream(seed=4, salt="salt")
+ outputs = [strm1() for _ in range(50)]
+ self.assertEqual(outputs, [strm2() for _ in range(50)])
+ self.assertEqual(outputs, [strm3() for _ in range(50)])
+
+ def testSeededDistinctness(self):
+ strm1 = seed_stream.SeedStream(seed=4, salt="salt")
+ strm2 = seed_stream.SeedStream(seed=5, salt="salt")
+ self.assertAllUnique(
+ [strm1() for _ in range(50)] + [strm2() for _ in range(50)])
+
+ def testSaltedDistinctness(self):
+ strm1 = seed_stream.SeedStream(seed=4, salt="salt")
+ strm2 = seed_stream.SeedStream(seed=4, salt="another salt")
+ self.assertAllUnique(
+ [strm1() for _ in range(50)] + [strm2() for _ in range(50)])
+
+ def testNestingRobustness(self):
+ # SeedStreams started from generated seeds should not collide with
+ # the master or with each other, even if the salts are the same.
+ strm1 = seed_stream.SeedStream(seed=4, salt="salt")
+ strm2 = seed_stream.SeedStream(strm1(), salt="salt")
+ strm3 = seed_stream.SeedStream(strm1(), salt="salt")
+ outputs = [strm1() for _ in range(50)]
+ self.assertAllUnique(
+ outputs + [strm2() for _ in range(50)] + [strm3() for _ in range(50)])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
index 3e6c35e0d6..bf5590cd55 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -290,7 +290,7 @@ class BatchReshape(distribution_lib.Distribution):
isinstance(expected_batch_event_ndims, int)):
if x_ndims < expected_batch_event_ndims:
raise NotImplementedError(
- "Broadcasting is not supported; too few event dims "
+ "Broadcasting is not supported; too few batch and event dims "
"(expected at least {}, saw {}).".format(
expected_batch_event_ndims, x_ndims))
ndims_assertion = []
@@ -299,7 +299,8 @@ class BatchReshape(distribution_lib.Distribution):
check_ops.assert_greater_equal(
x_ndims,
expected_batch_event_ndims,
- message="Broadcasting is not supported; too few event dims.",
+ message=("Broadcasting is not supported; too few "
+ "batch and event dims."),
name="assert_batch_and_event_ndims_large_enough"),
]
diff --git a/tensorflow/contrib/distributions/python/ops/seed_stream.py b/tensorflow/contrib/distributions/python/ops/seed_stream.py
new file mode 100644
index 0000000000..056d349688
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/seed_stream.py
@@ -0,0 +1,228 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Local PRNG for amplifying seed entropy into seeds for base operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import hashlib
+
+
+class SeedStream(object):
+ """Local PRNG for amplifying seed entropy into seeds for base operations.
+
+ Writing sampling code which correctly sets the pseudo-random number
+ generator (PRNG) seed is surprisingly difficult. This class serves as
+ a helper for the TensorFlow Probability coding pattern designed to
+ avoid common mistakes.
+
+ # Motivating Example
+
+ A common first-cut implementation of a sampler for the beta
+ distribution is to compute the ratio of a gamma with itself plus
+ another gamma. This code snippet tries to do that, but contains a
+ surprisingly common error:
+
+ ```python
+ def broken_beta(shape, alpha, beta, seed):
+ x = tf.random_gamma(shape, alpha, seed=seed)
+ y = tf.random_gamma(shape, beta, seed=seed)
+ return x / (x + y)
+ ```
+
+ The mistake is that the two gamma draws are seeded with the same
+ seed. This causes them to always produce the same results, which,
+ in turn, leads this code snippet to always return `0.5`. Because it
+ can happen across abstraction boundaries, this kind of error is
+ surprisingly easy to make when handling immutable seeds.
+
+ # Goals
+
+ TensorFlow Probability adopts a code style designed to eliminate the
+ above class of error, without exacerbating others. The goals of
+ this code style are:
+
+ - Support reproducibility of results (by encouraging seeding of all
+ pseudo-random operations).
+
+ - Avoid shared-write global state (by not relying on a global PRNG).
+
+ - Prevent accidental seed reuse by TF Probability implementers. This
+ goal is served with the local pseudo-random seed generator provided
+ in this module.
+
+ - Mitigate potential accidental seed reuse by TF Probability clients
+ (with a salting scheme).
+
+ - Prevent accidental resonances with downstream PRNGs (by hashing the
+ output).
+
+ ## Non-goals
+
+ - Implementing a high-performance PRNG for generating large amounts of
+ entropy. That's the job of the underlying TensorFlow PRNG we are
+ seeding.
+
+ - Avoiding random seed collisions, aka "birthday attacks".
+
+ # Code pattern
+
+ ```python
+ def random_beta(shape, alpha, beta, seed): # (a)
+ seed = SeedStream(seed, salt="random_beta") # (b)
+ x = tf.random_gamma(shape, alpha, seed=seed()) # (c)
+ y = tf.random_gamma(shape, beta, seed=seed()) # (c)
+ return x / (x + y)
+ ```
+
+ The elements of this pattern are:
+
+ - Accept an explicit seed (line a) as an argument in all public
+ functions, and write the function to be deterministic (up to any
+ numerical issues) for fixed seed.
+
+ - Rationale: This provides the client with the ability to reproduce
+ results. Accepting an immutable seed rather than a mutable PRNG
+ object reduces code coupling, permitting different sections to be
+ reproducible independently.
+
+ - Use that seed only to initialize a local `SeedStream` instance (line b).
+
+ - Rationale: Avoids accidental seed reuse.
+
+ - Supply the name of the function being implemented as a salt to the
+ `SeedStream` instance (line b). This serves to keep the salts
+ unique; unique salts ensure that different functions always produce
+ independent results for TF Probability clients, even when called
+ with the same seed.
+
+ - Seed each callee operation with the output of a unique call to the
+ `SeedStream` instance (lines c). This ensures reproducibility of
+ results while preventing seed reuse across callee invocations.
+
+ # Why salt?
+
+ Salting the `SeedStream` instances (with unique salts) is defensive
+ programming against a client accidentally committing a mistake
+ similar to our motivating example. Consider the following situation
+ that might arise without salting:
+
+ ```python
+ def tfp_foo(seed):
+ seed = SeedStream(seed, salt="")
+ foo_stuff = tf.random_normal(seed=seed())
+ ...
+
+ def tfp_bar(seed):
+ seed = SeedStream(seed, salt="")
+ bar_stuff = tf.random_normal(seed=seed())
+ ...
+
+ def client_baz(seed):
+ foo = tfp_foo(seed=seed)
+ bar = tfp_bar(seed=seed)
+ ...
+ ```
+
+ The client should have used different seeds as inputs to `foo` and
+ `bar`. However, because they didn't, *and because `foo` and `bar`
+ both sample a Gaussian internally as their first action*, the
+ internal `foo_stuff` and `bar_stuff` will be the same, and the
+ returned `foo` and `bar` will not be independent, leading to subtly
+ incorrect answers from the client's simulation. This kind of bug is
+ particularly insidious for the client, because it depends on a
+ Distributions implementation detail, namely the order in which `foo`
+ and `bar` invoke the samplers they depend on. In particular, a
+ Bayesflow team member can introduce such a bug in previously
+ (accidentally) correct client code by performing an internal
+ refactoring that causes this operation order alignment.
+
+ A salting discipline eliminates this problem by making sure that the
+ seeds seen by `foo`'s callees will differ from those seen by `bar`'s
+ callees, even if `foo` and `bar` are invoked with the same input
+ seed.
+ """
+
+ def __init__(self, seed, salt):
+ """Initializes a `SeedStream`.
+
+ Args:
+ seed: Any Python object convertible to string, supplying the
+ initial entropy. If `None`, operations seeded with seeds
+ drawn from this `SeedStream` will follow TensorFlow semantics
+ for not being seeded.
+ salt: Any Python object convertible to string, supplying
+ auxiliary entropy. Must be unique across the Distributions
+ and TensorFlow Probability code base. See class docstring for
+ rationale.
+ """
+ self._seed = seed
+ self._salt = salt
+ self._counter = 0
+
+ def __call__(self):
+ """Returns a fresh integer usable as a seed in downstream operations.
+
+ If this `SeedStream` was initialized with `seed=None`, returns
+ `None`. This has the effect that downstream operations (both
+ `SeedStream`s and primitive TensorFlow ops) will behave as though
+ they were unseeded.
+
+ The returned integer is non-negative, and uniformly distributed in
+ the half-open interval `[0, 2**512)`. This is consistent with
+ TensorFlow, as TensorFlow operations internally use the residue of
+ the given seed modulo `2**31 - 1` (see
+ `tensorflow/python/framework/random_seed.py`).
+
+ Returns:
+ seed: A fresh integer usable as a seed in downstream operations,
+ or `None`.
+ """
+ self._counter += 1
+ if self._seed is None:
+ return None
+ composite = str((self._seed, self._counter, self._salt)).encode("utf-8")
+ return int(hashlib.sha512(composite).hexdigest(), 16)
+
+ @property
+ def original_seed(self):
+ return self._seed
+
+ @property
+ def salt(self):
+ return self._salt
+
+# Design rationales for the SeedStream class
+#
+# - Salts are accepted for the reason given above to supply them.
+#
+# - A `None` seed propagates to downstream seeds, so they exhibit
+# their "unseeded" behavior.
+#
+# - The return value is a Python int so it can be passed directly to
+# TensorFlow operations as a seed. It is large to avoid losing seed
+# space needlessly (TF will internally read only the last 31 bits).
+#
+# - The output is hashed with a crypto-grade hash function as a form
+# of defensive programming: this reliably prevents all possible
+# accidental resonances with all possible downstream PRNGs. The
+# specific function used is not important; SHA512 was ready to hand.
+#
+# - The internal state update is a simple counter because (a) given
+# that the output is hashed anyway, this is enough, and (b) letting
+# it be this predictable permits a future "generate many seeds in
+# parallel" operation whose results would agree with running
+# sequentially.
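Putting the module to use, here is the docstring's own code pattern as a runnable sketch. `tf.random_gamma` is the TF 1.x op; the large SHA-512-derived seed is reduced modulo `2**31 - 1` internally, as noted above:

```python
import tensorflow as tf

from tensorflow.contrib.distributions.python.ops import seed_stream


def random_beta(shape, alpha, beta, seed=None):
  # One stream per public function, salted with the function name.
  stream = seed_stream.SeedStream(seed, salt="random_beta")
  x = tf.random_gamma(shape, alpha, seed=stream())  # distinct seed
  y = tf.random_gamma(shape, beta, seed=stream())   # distinct seed
  return x / (x + y)
```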
diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
index d66c34cc1a..5c52015e5f 100644
--- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py
+++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
@@ -12,7 +12,114 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Statistical test assertions calibrated for their error rates."""
+"""Statistical test assertions calibrated for their error rates.
+
+Statistical tests have an inescapable probability of error: a correct
+sampler can still fail a test by chance, and an incorrect sampler can
+still pass a test by chance. This library is about bounding both of
+those error rates. This requires admitting a task-specific notion of
+"discrepancy": Correct code will fail rarely, code that misbehaves by
+more than the discrepancy will pass rarely, and nothing reliable can
+be said about code that misbehaves, but misbehaves by less than the
+discrepancy.
+
+# Example
+
+Consider testing that the mean of a scalar probability distribution P
+is some expected constant. Suppose the support of P is the interval
+`[0, 1]`. Then you might do this:
+
+```python
+tfd = tf.contrib.distributions
+
+expected_mean = ...
+num_samples = 5000
+samples = ... draw 5000 samples from P
+
+# Check that the mean looks right
+check1 = tfd.assert_true_mean_equal_by_dkwm(
+ samples, low=0., high=1., expected=expected_mean,
+ false_fail_rate=1e-6)
+
+# Check that the difference in means detectable with 5000 samples is
+# small enough
+check2 = tf.assert_less(
+ tfd.min_discrepancy_of_true_means_detectable_by_dkwm(
+ num_samples, low=0., high=1.0,
+ false_fail_rate=1e-6, false_pass_rate=1e-6),
+ 0.01)
+
+# Be sure to execute both assertion ops
+sess.run([check1, check2])
+```
+
+The second assertion is an instance of experiment design. It's a
+deterministic computation (independent of the code under test) that
+checks that `5000` samples is enough to reliably resolve mean
+differences of `0.01` or more. Here "reliably" means that if the code
+under test is correct, the probability of drawing an unlucky sample
+that causes this test to fail is at most 1e-6; and if the code under
+test is incorrect enough that its true mean differs from the expected
+value by at least 0.01, then the probability of drawing a "lucky" sample
+that causes the test to false-pass is also at most 1e-6.
+
+# Overview
+
+Every function in this library can be characterized in terms of:
+
+- The property being tested, such as the full density of the
+ distribution under test, or just its true mean, or a single
+ Bernoulli probability, etc.
+
+- The relation being asserted, e.g., whether the mean is less, more,
+ or equal to the given expected value.
+
+- The stochastic bound being relied upon, such as the
+ [Dvoretzky-Kiefer-Wolfowitz-Massart inequality]
+ (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval)
+ or the CDF of the binomial distribution (for assertions about
+ Bernoulli probabilities).
+
+- The number of sample sets in the statistical test. For example,
+ testing equality of means has a one-sample variant, where the
+ expected mean is given exactly, and a two-sample variant, where the
+ expected mean is itself given by a set of samples (e.g., from an
+ alternative algorithm).
+
+- What operation(s) of the test are to be performed. Each test has
+ three of these:
+
+ 1. `assert` executes the test. Specifically, it creates a TF op that
+ produces an error if it has enough evidence to prove that the
+ property under test is violated. These functions depend on the
+ desired false failure rate, because that determines the sizes of
+ appropriate confidence intervals, etc.
+
+ 2. `min_discrepancy` computes the smallest difference reliably
+ detectable by that test, given the sample count and error rates.
+ What it's a difference of is test-specific. For example, a test
+ for equality of means would make detection guarantees about the
+ difference between the true means.
+
+ 3. `min_num_samples` computes the minimum number of samples needed
+ to reliably detect a given discrepancy with given error rates.
+
+ The latter two are for experiment design, and are meant to be
+ usable either interactively or inline in the overall test method.
+
+This library follows a naming convention, to make room for every
+combination of the above. A name mentions the operation first, then
+the property, then the relation, then the bound, then, if the test
+takes more than one set of samples, a token indicating this. For
+example, `assert_true_mean_equal_by_dkwm` (which is implicitly
+one-sample). Each name is a grammatically sound noun phrase (or verb
+phrase, for the asserts).
+
+# Asymptotic properties
+
+The number of samples needed tends to scale as `O(1/discrepancy**2)` and
+as `O(log(1/error_rate))`.
+"""
from __future__ import absolute_import
from __future__ import division
@@ -40,7 +147,7 @@ __all__ = [
def _batch_sort_vector(x, ascending=True, name=None):
- with ops.name_scope(name, "sort_each_row", [x]):
+ with ops.name_scope(name, "_batch_sort_vector", [x]):
x = ops.convert_to_tensor(x, name="x")
n = array_ops.shape(x)[-1]
if ascending:
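A hedged sketch of the experiment-design direction described in the docstring: rather than checking after the fact that a fixed sample count suffices, ask for the minimum number of samples needed to resolve a given discrepancy. The function name follows the stated naming convention and is assumed to be exposed by this module:

```python
import tensorflow as tf

from tensorflow.contrib.distributions.python.ops import statistical_testing

# Minimum samples needed to detect a 0.01 shift in the true mean of a
# [0, 1]-supported distribution at the stated error rates.
needed = statistical_testing.min_num_samples_for_dkwm_mean_test(
    discrepancy=0.01, low=0., high=1.,
    false_fail_rate=1e-6, false_pass_rate=1e-6)

with tf.Session() as sess:
  print(sess.run(needed))
```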
diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py
index 99b1e098d5..0783d1b5d7 100644
--- a/tensorflow/contrib/eager/python/datasets.py
+++ b/tensorflow/contrib/eager/python/datasets.py
@@ -71,8 +71,15 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
dataset: A `tf.data.Dataset` object.
Raises:
+ TypeError: If `dataset` is an unsupported type.
RuntimeError: When invoked without eager execution enabled.
"""
+ if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset): # pylint: disable=protected-access
+ raise TypeError(
+ "`tf.contrib.data.prefetch_to_device()` is not compatible with "
+ "`tf.contrib.eager.Iterator`. Use `for ... in dataset:` to iterate "
+ "over the dataset instead.")
+
super(Iterator, self).__init__(dataset)
if not context.context().device_spec.device_type:
is_remote_device = False
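A minimal sketch of the supported pattern after this change, mirroring the test below: iterate a `prefetch_to_device` dataset directly with `for ... in` under eager execution instead of wrapping it in `tf.contrib.eager.Iterator` (which now raises `TypeError`). The GPU device string is an assumption:

```python
import tensorflow as tf

from tensorflow.contrib.data.python.ops import prefetching_ops

tf.enable_eager_execution()

ds = tf.data.Dataset.from_tensor_slices([0., 1.])
ds = ds.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
for x in ds:  # supported; datasets.Iterator(ds) would raise TypeError
  print(x)
```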
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py
index c658505de4..f76a896d3d 100644
--- a/tensorflow/contrib/eager/python/datasets_test.py
+++ b/tensorflow/contrib/eager/python/datasets_test.py
@@ -24,6 +24,7 @@ import time
import numpy as np
from tensorflow.contrib import lookup
+from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.contrib.data.python.ops import threadpool
from tensorflow.contrib.data.python.ops import unique
from tensorflow.contrib.eager.python import checkpointable_utils
@@ -192,6 +193,18 @@ class IteratorTest(test.TestCase):
x = math_ops.add(x, x)
self.assertAllEqual([0., 2.], x.numpy())
+ def testTensorsExplicitPrefetchToDevice(self):
+ ds = Dataset.from_tensor_slices([0., 1.])
+ ds = ds.apply(prefetching_ops.prefetch_to_device(test.gpu_device_name()))
+
+ with self.assertRaisesRegexp(TypeError, 'prefetch_to_device'):
+ datasets.Iterator(ds)
+
+ for i, x in enumerate(ds):
+ with ops.device(test.gpu_device_name()):
+ x = math_ops.add(x, x)
+ self.assertEqual(float(i) + float(i), x.numpy())
+
def testOverrideThreadPool(self):
def get_thread_id(_):
diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py
index 266ae93305..201699ed77 100644
--- a/tensorflow/contrib/estimator/python/estimator/extenders.py
+++ b/tensorflow/contrib/estimator/python/estimator/extenders.py
@@ -97,7 +97,10 @@ def add_metrics(estimator, metric_fn):
return estimator_lib.Estimator(
model_fn=new_model_fn,
model_dir=estimator.model_dir,
- config=estimator.config)
+ config=estimator.config,
+ # pylint: disable=protected-access
+ warm_start_from=estimator._warm_start_settings)
+ # pylint: enable=protected-access
def clip_gradients_by_norm(optimizer, clip_norm):
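A hedged usage sketch of the behavior this fixes: `add_metrics` rebuilds the `Estimator`, and it now carries over the wrapped estimator's warm-start settings. The feature column, checkpoint path, and metric below are illustrative placeholders:

```python
import tensorflow as tf

base = tf.estimator.LinearClassifier(
    feature_columns=[tf.feature_column.numeric_column('x')],
    warm_start_from='/tmp/previous_model')  # hypothetical checkpoint dir


def extra_metrics(labels, predictions):
  # Extra metric keyed by name; merged into the eval metrics.
  del predictions  # unused in this toy metric
  return {'mean_label': tf.metrics.mean(labels)}


# The returned estimator now inherits `base`'s warm-start settings.
estimator = tf.contrib.estimator.add_metrics(base, extra_metrics)
```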
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index 85ef3291ba..ae2fd8b490 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -41,11 +41,10 @@ from tensorflow.python.training import training_util
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
-# TODO(b/65403806): Switch loss_reduction default to SUM_OVER_BATCH_SIZE.
def multi_class_head(n_classes,
weight_column=None,
label_vocabulary=None,
- loss_reduction=losses.Reduction.SUM,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
loss_fn=None,
name=None):
"""Creates a `_Head` for multi class classification.
@@ -86,7 +85,8 @@ def multi_class_head(n_classes,
have any value in `label_vocabulary`. Note that errors will be raised if
`label_vocabulary` is not provided but labels are strings.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
- reduce training loss over batch. Defaults to `SUM`.
+ reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, i.e.
+ the weighted sum of losses divided by batch size. See `tf.losses.Reduction`.
loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -111,7 +111,7 @@ def binary_classification_head(
weight_column=None,
thresholds=None,
label_vocabulary=None,
- loss_reduction=losses.Reduction.SUM,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
loss_fn=None,
name=None):
"""Creates a `_Head` for single label binary classification.
@@ -155,7 +155,8 @@ def binary_classification_head(
`label_vocabulary`. Note that errors will be raised if `label_vocabulary`
is not provided but labels are strings.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
- reduce training loss over batch. Defaults to `SUM`.
+ reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, i.e.
+ the weighted sum of losses divided by batch size. See `tf.losses.Reduction`.
loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
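A toy illustration (not part of the patch) of why the default changed: with `Reduction.SUM` the reported loss scales with the batch size, while `SUM_OVER_BATCH_SIZE` divides by the number of elements, decoupling learning-rate tuning from batch size:

```python
import tensorflow as tf

per_example = tf.constant([1.0, 2.0, 3.0, 4.0])  # batch of 4 losses
loss_sum = tf.losses.compute_weighted_loss(
    per_example, reduction=tf.losses.Reduction.SUM)                  # 10.0
loss_avg = tf.losses.compute_weighted_loss(
    per_example, reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)  # 2.5
```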
diff --git a/tensorflow/contrib/integrate/__init__.py b/tensorflow/contrib/integrate/__init__.py
index 68bf511099..694f0c14bd 100644
--- a/tensorflow/contrib/integrate/__init__.py
+++ b/tensorflow/contrib/integrate/__init__.py
@@ -18,6 +18,7 @@
See the @{$python/contrib.integrate} guide.
@@odeint
+@@odeint_fixed
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
index ced1110676..d11c9c8288 100644
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ b/tensorflow/contrib/kfac/python/ops/estimator.py
@@ -85,9 +85,9 @@ class FisherEstimator(object):
"""Create a FisherEstimator object.
Args:
- variables: A list of the variables for which to estimate the Fisher. This
- must match the variables registered in layer_collection (if it is not
- None).
+ variables: A `list` of variables or a `callable` which returns the variables
+ for which to estimate the Fisher. This must match the variables
+ registered in layer_collection (if it is not None).
cov_ema_decay: The decay factor used when calculating the covariance
estimate moving averages.
damping: float. The damping factor used to stabilize training due to
@@ -147,7 +147,10 @@ class FisherEstimator(object):
@property
def variables(self):
- return self._variables
+ if callable(self._variables):
+ return self._variables()
+ else:
+ return self._variables
@property
def damping(self):
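The change above adopts a lazy-variables pattern. A minimal standalone sketch (class name hypothetical) of how storing a callable defers variable lookup to use time, so variables created after construction are still included:

```python
import tensorflow as tf


class VariableHolder(object):
  def __init__(self, var_list=None):
    # Either an explicit list or a callable returning one; defaults to
    # the trainable-variables *function*, evaluated lazily.
    self._variables = var_list or tf.trainable_variables

  @property
  def variables(self):
    if callable(self._variables):
      return self._variables()
    return self._variables
```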
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 19608aca47..411da033c3 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -84,7 +84,7 @@ _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = {
APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB
}
-# Possible value for 'reuse' keyword argument. Sets 'reuse' to
+# Possible value for `reuse` keyword argument. Sets `reuse` to
# tf.get_variable_scope().reuse.
VARIABLE_SCOPE = "VARIABLE_SCOPE"
@@ -294,8 +294,8 @@ class LayerCollection(object):
layer_key: A variable or tuple of variables. The key to check for in
existing registrations and to register if valid.
fisher_block: The associated `FisherBlock`.
- reuse: Method to use for inserting new `FisherBlock`s. One of True, False,
- or 'VARIABLE_SCOPE'.
+ reuse: Method to use for inserting new `FisherBlock`s. One of True, False,
+ or `VARIABLE_SCOPE`.
Raises:
ValueError: If `layer_key` was already registered and reuse is `False`,
@@ -359,15 +359,14 @@ class LayerCollection(object):
is None.
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
a new name is generated. (Default: None)
- reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock.
- If False, create a new FisherBlock. If VARIABLE_SCOPE, use
- tf.get_variable_scope().reuse.
+ reuse: (OPTIONAL) bool or str. If True, adds `loss` as an additional
+ tower for the existing loss function.
Raises:
ValueError: If reuse == True and name == None.
ValueError: If reuse == True and seed != None.
- KeyError: If reuse == True and no existing LossFunction with 'name' found.
- KeyError: If reuse == False and existing LossFunction with 'name' found.
+ KeyError: If reuse == True and no existing LossFunction with `name` found.
+ KeyError: If reuse == False and existing LossFunction with `name` found.
"""
name = name or self._graph.unique_name(base_name)
@@ -491,24 +490,24 @@ class LayerCollection(object):
"""
params = frozenset(utils.ensure_sequence(params))
- # Check if any of the variables in 'params' is already in
- # 'self.fisher_blocks.keys()'.
+ # Check if any of the variables in `params` is already in
+ # `self.fisher_blocks.keys()`.
for registered_params, fisher_block in self.fisher_blocks.items():
registered_params_set = set(utils.ensure_sequence(registered_params))
for variable in params:
if (variable in registered_params_set and
params != registered_params_set):
raise ValueError(
- "Can't link parameters {}, variable {} was already registered in "
+ "Can`t link parameters {}, variable {} was already registered in "
"group {} with layer {}".format(params, variable,
registered_params, fisher_block))
- # Check if any of the variables in 'params' is already in
- # 'self.linked_parameters'.
+ # Check if any of the variables in `params` is already in
+ # `self.linked_parameters`.
for variable in params:
for other_linked_params in self.linked_parameters:
if variable in other_linked_params:
- raise ValueError("Can't link parameters {}, variable {} was already "
+ raise ValueError("Can`t link parameters {}, variable {} was already "
"linked in group {}.".format(params, variable,
other_linked_params))
self._linked_parameters[params] = approximation
@@ -576,15 +575,15 @@ class LayerCollection(object):
produced by layer.
approx: str or None. If not None must be "kron". The Fisher
approximation to use. If None the default value is used. (Default: None)
- reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
additional mini-batch/tower of data to use when estimating the Fisher
block for this layer (which must have already been registered). If
"VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
(Default: "VARIABLE_SCOPE")
Raises:
- ValueError: For improper value to 'approx'.
- KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
block_type, approx = self._get_block_type(
@@ -618,15 +617,15 @@ class LayerCollection(object):
approx: str or None. If not None must be one of "kron" or "diagonal".
The Fisher approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
additional mini-batch/tower of data to use when estimating the Fisher
block for this layer (which must have already been registered). If
"VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
(Default: "VARIABLE_SCOPE")
Raises:
- ValueError: For improper value to 'approx'.
- KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
@@ -669,15 +668,15 @@ class LayerCollection(object):
approx: str or None. If not None must be one of "kron" or "diagonal".
The Fisher approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
additional mini-batch/tower of data to use when estimating the Fisher
block for this layer (which must have already been registered). If
"VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
(Default: "VARIABLE_SCOPE")
Raises:
- ValueError: For improper value to 'approx'.
- KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
@@ -686,7 +685,7 @@ class LayerCollection(object):
_CONV2D_APPROX_TO_BLOCK_TYPES)
# It feels bad to pass in configuration that has to do with the internal
- # implementation. And then we can't use the same constructor for both
+ # implementation. And then we can't use the same constructor for both
# anymore and are thus forced to use this ugly if-statement.
# TODO(b/74793309): Clean this up?
if approx == APPROX_KRONECKER_NAME:
@@ -749,15 +748,15 @@ class LayerCollection(object):
approx: str or None. If not None must be one of "kron" or "diagonal".
The Fisher approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
additional mini-batch/tower of data to use when estimating the Fisher
block for this layer (which must have already been registered). If
"VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
(Default: "VARIABLE_SCOPE")
Raises:
- ValueError: For improper value to 'approx'.
- KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
# TODO(b/74793309): Have this use _get_block_type like the other
@@ -804,15 +803,15 @@ class LayerCollection(object):
data_format: str or None. Format of data.
approx: str or None. If not None must "diagonal". The Fisher
approximation to use. If None the default value is used. (Default: None)
- reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
additional mini-batch/tower of data to use when estimating the Fisher
block for this layer (which must have already been registered). If
"VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
(Default: "VARIABLE_SCOPE")
Raises:
- ValueError: For improper value to 'approx'.
- KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
# TODO(b/74793309): Have this use _get_block_type like the other
@@ -872,15 +871,15 @@ class LayerCollection(object):
approx: str or None. If not None must be one of "kron" or "diagonal".
The Fisher approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, this adds 'inputs' and 'outputs' as an
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
additional mini-batch/tower of data to use when estimating the Fisher
block for this layer (which must have already been registered). If
"VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
(Default: "VARIABLE_SCOPE")
Raises:
- ValueError: For improper value to 'approx'.
- KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
self.register_depthwise_conv2d(
@@ -917,14 +916,14 @@ class LayerCollection(object):
approx: str or None. It not None, must be one of "full" or "diagonal".
The Fisher approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, this adds 'batch_size' to the total
+ reuse: bool or str. If True, this adds `batch_size` to the total
mini-batch size use when estimating the Fisher block for this layer
(which must have already been registered). If "VARIABLE_SCOPE", use
tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
Raises:
- ValueError: For improper value to 'approx'.
- KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
block_type, approx = self._get_block_type(
@@ -954,10 +953,10 @@ class LayerCollection(object):
correspond to a "time-step" in an RNN). OR, can be single Tensor, of
shape [num_uses * batch_size , input_size], which is a reshaped version
of a Tensor of shape [num_uses, batch_size, input_size].
- outputs: A list of Tensors, the same length as 'inputs', each of shape
+ outputs: A list of Tensors, the same length as `inputs`, each of shape
[batch_size, output_size]. Outputs produced by layer. The list indexes
each use in the graph (which might correspond to a "time-step" in an
- RNN). Needs to correspond with the order used in 'inputs'. OR, can be
+ RNN). Needs to correspond with the order used in `inputs`. OR, can be
a single Tensor of shape [num_uses * batch_size, output_size], which is
a reshaped version of a Tensor of shape [num_uses, batch_size,
output_size].
@@ -967,16 +966,16 @@ class LayerCollection(object):
approx: str or None. If not None, must be of "kron_indep", "kron_series_1"
or "kron_series_2". The Fisher approximation to use. If None the default
value is used. (Default: None)
- reuse: bool or str. If True, this adds inputs and outputs as an
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
additional mini-batch/tower of data to use when estimating the Fisher
block for this layer (which must have already been registered). If
"VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word 'use' here has a completely different meaning to "use in the graph"
- as it perturns to the 'inputs', 'outputs', and 'num_uses' arguments.)
+ word `use` here has a completely different meaning from "use in the graph"
+ as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
(Default: "VARIABLE_SCOPE")
Raises:
- ValueError: For improper value to 'approx'.
+ ValueError: For improper value to `approx`.
"""
block_type, approx = self._get_block_type(
params, approx, self.default_fully_connected_multi_approximation,
@@ -1025,7 +1024,7 @@ class LayerCollection(object):
outputs: A list of Tensors, each of shape [batch_size, height, width,
out_channels]. Output produced by layer. The list indexes each use
in the graph (which might correspond to a "time-step" in an RNN).
- Needs to correspond with the order used in 'inputs'. OR, can be a
+ Needs to correspond with the order used in `inputs`. OR, can be a
single Tensor, of shape [num_uses * batch_size, height, width,
out_channels], which is a reshaped version of a Tensor of shape
[num_uses, batch_size, height, width, out_channels].
@@ -1037,17 +1036,17 @@ class LayerCollection(object):
approx: str or None. If not None must by "kron_indep". The Fisher
approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, this adds inputs and outputs as an
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
additional mini-batch/tower of data to use when estimating the Fisher
block for this layer (which must have already been registered). If
"VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word 'use' here has a completely different meaning to "use in the graph"
- as it perturns to the 'inputs', 'outputs', and 'num_uses' arguments.)
+ word `use` here has a completely different meaning from "use in the graph"
+ as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
(Default: "VARIABLE_SCOPE")
Raises:
- ValueError: For improper value to 'approx'.
- KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
block_type, approx = self._get_block_type(
@@ -1098,7 +1097,7 @@ class LayerCollection(object):
outputs: A list of Tensors, each of shape [batch_size, embedding_size].
Outputs produced by layer. The list indexes each use in the graph
(which might correspond to a "time-step" in an RNN). Needs to
- correspond with the order used in 'inputs'. OR, can be a
+ correspond with the order used in `inputs`. OR, can be a
single Tensor, of shape [num_uses * batch_size, embedding_size], which
is a reshaped version of a Tensor of shape [num_uses, batch_size,
embedding_size].
@@ -1108,17 +1107,17 @@ class LayerCollection(object):
approx: str or None. If not None must by "kron_indep". The Fisher
approximation to use. If None the default value is used.
(Default: None)
- reuse: bool or str. If True, this adds inputs and outputs as an
+ reuse: bool or str. If True, this adds `inputs` and `outputs` as an
additional mini-batch/tower of data to use when estimating the Fisher
block for this layer (which must have already been registered). If
"VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word 'use' here has a completely different meaning to "use in the graph"
- as it perturns to the 'inputs', 'outputs', and 'num_uses' arguments.)
+ word `use` here has a completely different meaning from "use in the graph"
+ as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
(Default: "VARIABLE_SCOPE")
Raises:
- ValueError: For improper value to 'approx'.
- KeyError: If reuse == True but no FisherBlock found for 'params'.
+ ValueError: For improper value to `approx`.
+ KeyError: If reuse == True but no FisherBlock found for `params`.
ValueError: If reuse == True and FisherBlock found but of the wrong type.
"""
block_type, approx = self._get_block_type(
@@ -1156,7 +1155,7 @@ class LayerCollection(object):
(Default: None)
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
a new name is generated. (Default: None)
- reuse: bool or str. If True, this adds 'logits' as an additional
+ reuse: bool or str. If True, this adds `logits` as an additional
mini-batch/tower of inputs to the loss-function/predictive distribution
(which must have already been registered). If "VARIABLE_SCOPE", use
tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
@@ -1190,7 +1189,7 @@ class LayerCollection(object):
(Default: None)
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
a new name is generated. (Default: None)
- reuse: bool or str. If True, this adds 'mean' and 'var' as an additional
+ reuse: bool or str. If True, this adds `mean` and `var` as an additional
mini-batch/tower of inputs to the loss-function/predictive distribution
(which must have already been registered). If "VARIABLE_SCOPE", use
tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
@@ -1219,7 +1218,7 @@ class LayerCollection(object):
(Default: None)
name: (OPTIONAL) str or None. Unique name for this loss function. If None,
a new name is generated. (Default: None)
- reuse: bool or str. If True, this adds 'logits' as an additional
+ reuse: bool or str. If True, this adds `logits` as an additional
mini-batch/tower of inputs to the loss-function/predictive distribution
(which must have already been registered). If "VARIABLE_SCOPE", use
tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
@@ -1231,18 +1230,18 @@ class LayerCollection(object):
name=name, reuse=reuse)
def make_or_get_factor(self, cls, args):
- """Insert 'cls(args)' into 'self.fisher_factors' if not already present.
+ """Insert `cls(args)` into 'self.fisher_factors` if not already present.
- Wraps constructor in 'tf.variable_scope()' to ensure variables constructed
- in 'cls.__init__' are placed under this LayerCollection's scope.
+ Wraps constructor in `tf.variable_scope()` to ensure variables constructed
+ in `cls.__init__` are placed under this LayerCollection's scope.
Args:
cls: Class that implements FisherFactor.
- args: Tuple of arguments to pass into 'cls's constructor. Must be
+ args: Tuple of arguments to pass into `cls`'s constructor. Must be
hashable.
Returns:
- Instance of 'cls' found in self.fisher_factors.
+ Instance of `cls` found in self.fisher_factors.
"""
try:
hash(args)
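The docstring above describes a memoized-constructor pattern. A standalone sketch (names hypothetical) of the same idea, keyed on hashable constructor args and scoped so created variables land under one collection's scope:

```python
import tensorflow as tf


class FactorCache(object):
  def __init__(self, scope_name):
    self._scope_name = scope_name
    self._factors = {}  # (cls, args) -> instance

  def make_or_get(self, cls, args):
    # `args` must be hashable, mirroring the requirement above.
    key = (cls, args)
    if key not in self._factors:
      with tf.variable_scope(self._scope_name):
        self._factors[key] = cls(*args)
    return self._factors[key]
```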
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py
index 843aeef7d8..f01c5a8322 100644
--- a/tensorflow/contrib/kfac/python/ops/optimizer.py
+++ b/tensorflow/contrib/kfac/python/ops/optimizer.py
@@ -108,13 +108,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
ValueError: If momentum is non-zero and momentum_type is not 'regular'
or 'adam'.
"""
-
- variables = var_list
- if variables is None:
- variables = tf_variables.trainable_variables()
-
# Parameters to be passed to the Fisher estimator:
- self._variables = variables
+ self._variables = var_list or tf_variables.trainable_variables
self._cov_ema_decay = cov_ema_decay
self._layers = layer_collection
self._estimation_mode = estimation_mode
@@ -235,7 +230,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
@property
def variables(self):
- return self._variables
+ return self._fisher_est.variables
@property
def damping(self):
@@ -373,6 +368,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
else:
kwargs["var_list"] = kwargs.get("var_list") or self.variables
var_list = kwargs["var_list"]
+
if set(var_list) != set(self.variables):
raise ValueError("var_list doesn't match with set of Fisher-estimating "
"variables.")
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index cfe62fac43..ac50699f59 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import random
import threading
from tensorflow.contrib.linear_optimizer.python.ops.sdca_ops import SdcaModel
@@ -102,6 +103,33 @@ def make_example_dict(example_protos, example_weights):
example_ids=['%d' % i for i in range(0, len(example_protos))])
+def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero):
+ random.seed(1)
+ sparse_features = [
+ SparseFeatureColumn(
+ [int(i / num_non_zero) for i in range(num_examples * num_non_zero)],
+ [int(random.random() * dim) for _ in range(
+ num_examples * num_non_zero)],
+ [num_non_zero**(-0.5) for _ in range(num_examples * num_non_zero)])
+ ]
+ examples_dict = dict(
+ sparse_features=sparse_features,
+ dense_features=[],
+ example_weights=[random.random() for _ in range(num_examples)],
+ example_labels=[
+ 1. if random.random() > 0.5 else 0. for _ in range(num_examples)
+ ],
+ example_ids=[str(i) for i in range(num_examples)])
+
+ weights = variables_lib.Variable(
+ array_ops.zeros([dim], dtype=dtypes.float32))
+ variables_dict = dict(
+ sparse_features_weights=[weights],
+ dense_features_weights=[])
+
+ return examples_dict, variables_dict
+
+
def make_variable_dict(max_age, max_gender):
# TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from
# examples_dict.
@@ -235,6 +263,32 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
self.assertAllClose(
0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+ def testSparseRandom(self):
+ dim = 20
+ num_examples = 1000
+ # Number of non-zero features per example.
+ non_zeros = 10
+ # Setup test data.
+ with self._single_threaded_test_session():
+ examples, variables = make_random_examples_and_variables_dicts(
+ num_examples, dim, non_zeros)
+ options = dict(
+ symmetric_l2_regularization=.1,
+ symmetric_l1_regularization=0,
+ num_table_shards=1,
+ adaptive=False,
+ loss_type='logistic_loss')
+
+ lr = SdcaModel(examples, variables, options)
+ variables_lib.global_variables_initializer().run()
+ train_op = lr.minimize()
+ for _ in range(4):
+ train_op.run()
+ lr.update_weights(train_op).run()
+ # Duality gap is 1.4e-5.
+ # It would be 0.01 without shuffling and 0.02 with adaptive sampling.
+ self.assertNear(0.0, lr.approximate_duality_gap().eval(), err=1e-3)
+
def testDistributedSimple(self):
# Setup test data
example_protos = [
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 3f5fdc18bb..f980746a19 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -168,6 +168,10 @@ class SdcaModel(object):
# of workers
return self._options.get('num_loss_partitions', 1)
+ def _adaptive(self):
+ # Perform adaptive sampling.
+ return self._options.get('adaptive', True)
+
def _num_table_shards(self):
# Number of hash table shards.
# Return 1 if not specified or if the value is 'None'
@@ -344,7 +348,8 @@ class SdcaModel(object):
l1=self._options['symmetric_l1_regularization'],
l2=self._symmetric_l2_regularization(),
num_loss_partitions=self._num_loss_partitions(),
- num_inner_iterations=1)
+ num_inner_iterations=1,
+ adaptative=self._adaptive())
# pylint: enable=protected-access
with ops.control_dependencies([esu]):
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
index 92d022f2a3..dffdddacfb 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
@@ -71,12 +71,14 @@ class SDCAOptimizer(object):
num_loss_partitions=1,
num_table_shards=None,
symmetric_l1_regularization=0.0,
- symmetric_l2_regularization=1.0):
+ symmetric_l2_regularization=1.0,
+ adaptive=True):
self._example_id_column = example_id_column
self._num_loss_partitions = num_loss_partitions
self._num_table_shards = num_table_shards
self._symmetric_l1_regularization = symmetric_l1_regularization
self._symmetric_l2_regularization = symmetric_l2_regularization
+ self._adaptive = adaptive
def get_name(self):
return 'SDCAOptimizer'
@@ -101,6 +103,10 @@ class SDCAOptimizer(object):
def symmetric_l2_regularization(self):
return self._symmetric_l2_regularization
+ @property
+ def adaptive(self):
+ return self._adaptive
+
def get_train_step(self, columns_to_variables, weight_column_name, loss_type,
features, targets, global_step):
"""Returns the training operation of an SdcaModel optimizer."""
@@ -228,6 +234,7 @@ class SDCAOptimizer(object):
options=dict(
symmetric_l1_regularization=self._symmetric_l1_regularization,
symmetric_l2_regularization=self._symmetric_l2_regularization,
+ adaptive=self._adaptive,
num_loss_partitions=self._num_loss_partitions,
num_table_shards=self._num_table_shards,
loss_type=loss_type))
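A hedged constructor sketch for the new flag: adaptive sampling is on by default, and passing `adaptive=False` restores the previous behavior (as the kernel test above does):

```python
from tensorflow.contrib.linear_optimizer.python.sdca_optimizer import (
    SDCAOptimizer)

optimizer = SDCAOptimizer(
    example_id_column='example_id',
    symmetric_l2_regularization=1.0,
    adaptive=False)  # disable the new adaptive sampling
```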
diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD
index 7f7a2632dd..b14230acd7 100644
--- a/tensorflow/contrib/lite/java/BUILD
+++ b/tensorflow/contrib/lite/java/BUILD
@@ -47,6 +47,23 @@ android_library(
)
java_library(
+ name = "ovicbenchmarkerlib",
+ srcs = [
+ "ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java",
+ "ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
+ ],
+ javacopts = JAVACOPTS,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":libtensorflowlite_jni.so",
+ ":tensorflowlite_java",
+ "//tensorflow/contrib/lite/java/src/main/native",
+ "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
+ "@org_checkerframework_qual",
+ ],
+)
+
+java_library(
name = "tensorflowlitelib",
srcs = glob(
[
@@ -147,6 +164,28 @@ java_test(
],
)
+java_test(
+ name = "OvicClassifierTest",
+ size = "medium",
+ srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"],
+ data = [
+ "ovic/src/testdata/float_model.lite",
+ "ovic/src/testdata/labels.txt",
+ "ovic/src/testdata/low_res_model.lite",
+ "ovic/src/testdata/quantized_model.lite",
+ "ovic/src/testdata/test_image_128.jpg",
+ "ovic/src/testdata/test_image_224.jpg",
+ ],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.ovic.OvicClassifierTest",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":ovicbenchmarkerlib",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
filegroup(
name = "libtensorflowlite_jni",
srcs = select({
diff --git a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/contrib/lite/java/ovic/README.md
new file mode 100644
index 0000000000..76c33838bf
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/README.md
@@ -0,0 +1,83 @@
+# Benchmarker for LPIRC Workshop at CVPR 2018
+
+This folder contains building code for track one of the [Low Power ImageNet Recognition Challenge workshop at CVPR 2018.](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018)
+
+## Prerequisites
+
+Follow the steps [here](https://www.tensorflow.org/mobile/tflite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK.
+
+## To test the benchmarker:
+
+The testing utilities help you make sure that your submission in TfLite format will be processed as expected by the competition's benchmarking system.
+
+Note: for now the tests only provide correctness checks, i.e. that the classifier predicts the correct category on the test image; they do not include on-device latency measurements. The tests do print the latency measured on a desktop computer, but that is not indicative of on-device run time.
+We are releasing a benchmarker APK that will allow developers to measure latency on their own devices.
+
+### Obtain the sample models
+
+The test data (models and images) should be downloaded automatically for you by Bazel. In case they are not, you can manually install them as below.
+
+Note: all commands should be called from your tensorflow installation folder (under this folder you should find `tensorflow/contrib/lite`).
+
+
+* Download the [testdata package](https://storage.googleapis.com/download.tensorflow.org/data/ovic.zip):
+
+```sh
+curl -L https://storage.googleapis.com/download.tensorflow.org/data/ovic.zip -o /tmp/ovic.zip
+```
+
+* Unzip the package into the testdata folder:
+
+```sh
+unzip -j /tmp/ovic.zip -d tensorflow/contrib/lite/java/ovic/src/testdata/
+```
+
+### Run tests
+
+You can run the tests with Bazel as shown below. This helps ensure that the installation is correct.
+
+```sh
+bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java:OvicClassifierTest --test_output=all
+```
+
+### Test your submissions
+
+Once you have a submission that follows the instructions from the [competition site](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018), you can verify it as below.
+
+* Move your submission to the testdata folder:
+
+Assuming the submission file is located at `/tmp/my_model.lite`:
+
+```sh
+cp /tmp/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/
+```
+
+* Resize the test image to the resolution expected by your submission:
+
+The test images can be found at `tensorflow/contrib/lite/java/ovic/src/testdata/test_image_*.jpg`. You may reuse these images if your image resolutions are 128x128 or 224x224.
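+
+If you need a different resolution, you can resize a test image yourself, as in the sketch below. This is a minimal, illustrative example using only the JDK; the file paths and the 180x180 target resolution are assumptions, not part of the competition tooling.
+
+```java
+import java.awt.Graphics2D;
+import java.awt.image.BufferedImage;
+import java.io.File;
+import java.io.IOException;
+import javax.imageio.ImageIO;
+
+public class ResizeTestImage {
+  public static void main(String[] args) throws IOException {
+    BufferedImage src = ImageIO.read(new File("testdata/test_image_224.jpg"));
+    int size = 180;  // Replace with the input resolution of your submission.
+    BufferedImage dst = new BufferedImage(size, size, BufferedImage.TYPE_INT_RGB);
+    Graphics2D g = dst.createGraphics();
+    g.drawImage(src, 0, 0, size, size, null);  // Scales to size x size.
+    g.dispose();
+    ImageIO.write(dst, "jpg", new File("testdata/my_test_image.jpg"));
+  }
+}
+```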
+
+* Add your model and test image to the BUILD rule:
+
+```python
+java_test(
+ name = "OvicClassifierTest",
+ size = "medium",
+ srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"],
+ data = [
+ "ovic/src/testdata/float_model.lite",
+ "ovic/src/testdata/labels.txt",
+ "ovic/src/testdata/low_res_model.lite",
+ "ovic/src/testdata/quantized_model.lite",
+ "ovic/src/testdata/test_image_128.jpg",
+ "ovic/src/testdata/test_image_224.jpg",
+ "ovic/src/testdata/my_model.lite", # <--- Your submission.
+ "ovic/src/testdata/my_test_image.jpg", # <--- Your test image.
+ ],
+ ...
+```
+
+* Modify `OvicClassifierTest.java` to test your model.
+
+Change `TEST_IMAGE_PATH` to `testdata/my_test_image.jpg`. If your model runs inference in floating point, change `FLOAT_MODEL_PATH` to `testdata/my_model.lite`. If your model runs [quantized inference](https://www.tensorflow.org/performance/quantization), change `QUANTIZED_MODEL_PATH` to `testdata/my_model.lite`.
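+
+For example, for a floating-point submission the relevant constants might end up as below. This is an illustrative sketch only; keep whatever path prefix the other constants in the test file already use.
+
+```java
+// Hypothetical values for a floating-point submission named my_model.lite.
+private static final String TEST_IMAGE_PATH = "testdata/my_test_image.jpg";
+private static final String FLOAT_MODEL_PATH = "testdata/my_model.lite";
+```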
+
+Now you can run the Bazel tests to catch any runtime issues with the submission.
diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
index 4fd23a99d2..098ed8ceba 100644
--- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
+++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
@@ -26,7 +26,6 @@ import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
-import java.nio.file.Paths;
import javax.imageio.ImageIO;
import org.junit.Before;
import org.junit.Test;
@@ -45,27 +44,33 @@ public final class OvicClassifierTest {
private ByteBuffer testImage = null;
private ByteBuffer lowResTestImage = null;
private OvicSingleImageResult testResult = null;
- private static final String LABELS_PATH = "testdata/labels.txt";
- private static final String QUANTIZED_MODEL_PATH = "testdata/quantized_model.lite";
- private static final String LOW_RES_MODEL_PATH = "testdata/low_res_model.lite";
- private static final String FLOAT_MODEL_PATH = "testdata/float_model.lite";
- private static final String TEST_IMAGE_PATH = "testdata/test_image_224.jpg";
- private static final String TEST_LOW_RES_IMAGE_PATH = "testdata/test_image_128.jpg";
+ private static final String LABELS_PATH =
+ "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt";
+ private static final String QUANTIZED_MODEL_PATH =
+ "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/quantized_model.lite";
+ private static final String LOW_RES_MODEL_PATH =
+ "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/low_res_model.lite";
+ private static final String FLOAT_MODEL_PATH =
+ "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/float_model.lite";
+ private static final String TEST_IMAGE_PATH =
+ "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/test_image_224.jpg";
+ private static final String TEST_LOW_RES_IMAGE_PATH =
+ "third_party/tensorflow/contrib/lite/java/ovic/src/testdata/test_image_128.jpg";
private static final int TEST_IMAGE_GROUNDTRUTH = 653; // "military uniform"
@Before
public void setUp() {
try {
- File labelsfile = new File(getTestDir(LABELS_PATH));
+ File labelsfile = new File(LABELS_PATH);
labelsInputStream = new FileInputStream(labelsfile);
- quantizedModel = loadModelFile(getTestDir(QUANTIZED_MODEL_PATH));
- floatModel = loadModelFile(getTestDir(FLOAT_MODEL_PATH));
- lowResModel = loadModelFile(getTestDir(LOW_RES_MODEL_PATH));
- File imageFile = new File(getTestDir(TEST_IMAGE_PATH));
+ quantizedModel = loadModelFile(QUANTIZED_MODEL_PATH);
+ floatModel = loadModelFile(FLOAT_MODEL_PATH);
+ lowResModel = loadModelFile(LOW_RES_MODEL_PATH);
+ File imageFile = new File(TEST_IMAGE_PATH);
BufferedImage img = ImageIO.read(imageFile);
testImage = toByteBuffer(img);
// Low res image and models.
- imageFile = new File(getTestDir(TEST_LOW_RES_IMAGE_PATH));
+ imageFile = new File(TEST_LOW_RES_IMAGE_PATH);
img = ImageIO.read(imageFile);
lowResTestImage = toByteBuffer(img);
} catch (IOException e) {
@@ -74,10 +79,6 @@ public final class OvicClassifierTest {
System.out.println("Successful setup");
}
- private static String getTestDir(String testfile) throws IOException {
- return Paths.get("third_party/tensorflow/contrib/lite/java/ovic/src/", testfile).toString();
- }
-
@Test
public void ovicClassifier_quantizedModelCreateSuccess() throws Exception {
classifier = new OvicClassifier(labelsInputStream, quantizedModel);
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
index a619ada86a..45ea8d0049 100644
--- a/tensorflow/contrib/lite/kernels/concatenation.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -67,10 +67,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* t = &context->tensors[node->inputs->data[i]];
TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
TF_LITE_ENSURE_EQ(context, t->type, input_type);
- if (input_type == kTfLiteUInt8) {
- TF_LITE_ENSURE_EQ(context, t->params.zero_point, t0->params.zero_point);
- TF_LITE_ENSURE_EQ(context, t->params.scale, t0->params.scale);
- }
for (int d = 0; d < t0->dims->size; ++d) {
if (d == axis) {
sum_axis += t->dims->data[axis];
@@ -87,11 +83,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
TF_LITE_ENSURE_EQ(context, output->type, input_type);
- if (input_type == kTfLiteUInt8) {
- TF_LITE_ENSURE_EQ(context, output->params.zero_point,
- t0->params.zero_point);
- TF_LITE_ENSURE_EQ(context, output->params.scale, t0->params.scale);
- }
return context->ResizeTensor(context, output, output_size);
}
@@ -115,6 +106,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
all_inputs.dims(), node->inputs->size, GetTensorData<scalar>(output), \
GetTensorDims(output))
+#define TF_LITE_CONCATENATION_QUANTIZED(type) \
+ VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
+ type::Concatenation( \
+ RemapDim(NumDimensions(output), axis), all_inputs.data(), \
+ all_inputs.dims(), all_inputs.zero_point(), all_inputs.scale(), \
+ node->inputs->size, GetTensorData<uint8>(output), GetTensorDims(output), \
+ output->params.zero_point, output->params.scale)
+
switch (output->type) { // Already know in/outtypes are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
@@ -125,9 +124,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
break;
case kTfLiteUInt8:
if (kernel_type == kReference) {
- TF_LITE_CONCATENATION(reference_ops, uint8_t);
+ TF_LITE_CONCATENATION_QUANTIZED(reference_ops);
} else {
- TF_LITE_CONCATENATION(optimized_ops, uint8_t);
+ TF_LITE_CONCATENATION_QUANTIZED(optimized_ops);
}
break;
default:
@@ -136,6 +135,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
+#undef TF_LITE_CONCATENATION_QUANTIZED
#undef TF_LITE_CONCATENATION
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/contrib/lite/kernels/concatenation_test.cc
index ba1ffc5f84..467ff6f7e1 100644
--- a/tensorflow/contrib/lite/kernels/concatenation_test.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation_test.cc
@@ -28,6 +28,7 @@ class BaseConcatenationOpModel : public SingleOpModel {
public:
// TODO(ahentz): Also test different activation types, axis, input
// dimensions.
+ BaseConcatenationOpModel() {}
BaseConcatenationOpModel(const TensorData& input_template, int axis,
int num_inputs) {
std::vector<std::vector<int>> all_input_shapes;
@@ -60,6 +61,23 @@ class ConcatenationOpModel : public BaseConcatenationOpModel {
class QuantizedConcatenationOpModel : public BaseConcatenationOpModel {
public:
using BaseConcatenationOpModel::BaseConcatenationOpModel;
+ QuantizedConcatenationOpModel(const std::vector<TensorData>& input_template,
+ int axis, int num_inputs,
+ const TensorData& output_template) {
+ std::vector<std::vector<int>> all_input_shapes;
+ CHECK_EQ(input_template.size(), num_inputs);
+ for (int i = 0; i < num_inputs; ++i) {
+ all_input_shapes.push_back(input_template[i].shape);
+ AddInput(input_template[i]);
+ }
+ output_ = AddOutput({output_template.type, /*shape=*/{},
+ output_template.min, output_template.max});
+ SetBuiltinOp(
+ BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions,
+ CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE)
+ .Union());
+ BuildInterpreter(all_input_shapes);
+ }
void SetInput(int index, std::initializer_list<float> data) {
QuantizeAndPopulate<uint8_t>(index, data);
}
@@ -168,6 +186,56 @@ TEST(ConcatenationOpTest, FourInputsQuantized) {
}));
}
+TEST(ConcatenationOpTest, FourInputsQuantizedMixedRange) {
+ QuantizedConcatenationOpModel m0({{TensorType_UINT8, {2, 1, 2}, -10.7, 10.8},
+ {TensorType_UINT8, {2, 1, 2}, 0, 12.8},
+ {TensorType_UINT8, {2, 1, 2}, -11, 11.8},
+ {TensorType_UINT8, {2, 1, 2}, 0, 7.4}},
+ /*axis=*/2, /*num_inputs=*/4,
+ {TensorType_UINT8, {2, 1, 2}, -12.7, 12.8});
+
+ m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f});
+ m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
+ m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f});
+ m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
+ m0.Invoke();
+ EXPECT_THAT(m0.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({
+ 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, //
+ 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, //
+ })));
+ EXPECT_THAT(m0.GetOutput(), ElementsAreArray({
+ 137, 157, 138, 158, 139, 159, 140, 160, //
+ 167, 197, 168, 198, 169, 199, 170, 200, //
+ }));
+}
+
+TEST(ConcatenationOpTest, FourInputsQuantizedMixedRangeClampingLogic) {
+ QuantizedConcatenationOpModel m0({{TensorType_UINT8, {2, 1, 2}, -10.7, 10.8},
+ {TensorType_UINT8, {2, 1, 2}, 0, 12.8},
+ {TensorType_UINT8, {2, 1, 2}, -11, 11.8},
+ {TensorType_UINT8, {2, 1, 2}, 0, 7.4}},
+ /*axis=*/2, /*num_inputs=*/4,
+ {TensorType_UINT8, {2, 1, 2}, -1., 1.});
+
+ m0.SetInput(0, {1.0f, -3.0f, -4.0f, -7.0f});
+ m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f});
+ m0.SetInput(2, {1.2f, -3.2f, -4.2f, 7.2f});
+ m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f});
+ m0.Invoke();
+ EXPECT_THAT(m0.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, //
+ -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, //
+ },
+ 4e-3)));
+ EXPECT_THAT(m0.GetOutput(), ElementsAreArray({
+ 255, 0, 255, 255, 255, 0, 255, 255, //
+ 0, 0, 255, 255, 0, 255, 255, 255, //
+ }));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 3642da311c..9a274612ad 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -2732,6 +2732,62 @@ void Concatenation(int concat_dim, const Scalar* const* input_data,
}
}
+// TODO(prabhumk): This is the same as the reference implementation.
+// TODO(prabhumk): The quantized implementation of concatenation isn't fully
+// quantized as it takes scale as a floating point value. This should be fixed
+// when optimizing this routine further.
+inline void Concatenation(int concat_dim, const uint8* const* input_data,
+ const Dims<4>* const* input_dims,
+ const int32* input_zeropoint,
+ const float* input_scale, int inputs_count,
+ uint8* output_data, const Dims<4>& output_dims,
+ const int32 output_zeropoint,
+ const float output_scale) {
+ // The arguments input_zeropoint and input_scale are expected to be arrays
+ // that hold the quantization parameters for all the inputs to the concat
+ // operator.
+ gemmlowp::ScopedProfilingLabel label("Concatenation");
+ TFLITE_DCHECK_GT(inputs_count, 1);
+ int concat_size = 0;
+ for (int i = 0; i < inputs_count; i++) {
+ for (int j = 0; j < 4; j++) {
+ if (j != concat_dim) {
+ MatchingArraySize(*input_dims[i], j, output_dims, j);
+ }
+ }
+ concat_size += ArraySize(*input_dims[i], concat_dim);
+ }
+ TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
+ int outer_size = 1;
+ for (int i = concat_dim + 1; i < 4; i++) {
+ outer_size *= output_dims.sizes[i];
+ }
+ const float inverse_output_scale = 1.f / output_scale;
+ uint8* output_ptr = output_data;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < inputs_count; ++i) {
+ const int copy_size =
+ input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
+ const uint8* input_ptr = input_data[i] + k * copy_size;
+ if (input_zeropoint[i] == output_zeropoint &&
+ input_scale[i] == output_scale) {
+ memcpy(output_ptr, input_ptr, copy_size);
+ } else {
+ const float scale = input_scale[i] * inverse_output_scale;
+ const float bias = -input_zeropoint[i] * scale;
+ for (int j = 0; j < copy_size; ++j) {
+ const int32_t value =
+ static_cast<int32_t>(round(input_ptr[j] * scale + bias)) +
+ output_zeropoint;
+ output_ptr[j] =
+ static_cast<uint8_t>(std::max(std::min(255, value), 0));
+ }
+ }
+ output_ptr += copy_size;
+ }
+ }
+}
+
template <FusedActivationFunctionType Ac, typename Scalar>
void DepthConcatenation(const Scalar* const* input_data,
const Dims<4>* const* input_dims, int inputs_count,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 3575974ae9..31e190e248 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -1566,6 +1566,61 @@ void Concatenation(int concat_dim, const Scalar* const* input_data,
}
}
+// TODO(prabhumk): This is the same as the optimized implementation.
+// TODO(prabhumk): The quantized implementation of concatenation isn't fully
+// quantized as it takes scale as a floating point value. This should be fixed
+// when optimizing this routine further.
+inline void Concatenation(int concat_dim, const uint8* const* input_data,
+ const Dims<4>* const* input_dims,
+ const int32* input_zeropoint,
+ const float* input_scale, int inputs_count,
+ uint8* output_data, const Dims<4>& output_dims,
+ const int32 output_zeropoint,
+ const float output_scale) {
+ // The arguments input_zeropoint and input_scale are expected to be arrays
+ // that hold the quantization parameters for all the inputs to the concat
+ // operator.
+ TFLITE_DCHECK_GT(inputs_count, 1);
+ int64_t concat_size = 0;
+ for (int i = 0; i < inputs_count; i++) {
+ for (int j = 0; j < 4; j++) {
+ if (j != concat_dim) {
+ MatchingArraySize(*input_dims[i], j, output_dims, j);
+ }
+ }
+ concat_size += ArraySize(*input_dims[i], concat_dim);
+ }
+ TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
+ int64_t outer_size = 1;
+ for (int i = concat_dim + 1; i < 4; i++) {
+ outer_size *= output_dims.sizes[i];
+ }
+ const float inverse_output_scale = 1.f / output_scale;
+ uint8* output_ptr = output_data;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < inputs_count; ++i) {
+ const int copy_size =
+ input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
+ const uint8* input_ptr = input_data[i] + k * copy_size;
+ if (input_zeropoint[i] == output_zeropoint &&
+ input_scale[i] == output_scale) {
+ memcpy(output_ptr, input_ptr, copy_size);
+ } else {
+ const float scale = input_scale[i] * inverse_output_scale;
+ const float bias = -input_zeropoint[i] * scale;
+ for (int j = 0; j < copy_size; ++j) {
+ const int32_t value =
+ static_cast<int32_t>(round(input_ptr[j] * scale + bias)) +
+ output_zeropoint;
+ output_ptr[j] =
+ static_cast<uint8_t>(std::max(std::min(255, value), 0));
+ }
+ }
+ output_ptr += copy_size;
+ }
+ }
+}
+
template <FusedActivationFunctionType Ac, typename Scalar>
void DepthConcatenation(const Scalar* const* input_data,
const Dims<4>* const* input_dims, int inputs_count,
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index 62e38e0d4c..4bce2ffaaf 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -126,6 +126,29 @@ class VectorOfTensors {
std::vector<Dims<4>*> all_dims_ptr_;
};
+// A list of quantized tensors in a format that can be used by kernels like
+// split and concatenation.
+class VectorOfQuantizedTensors : public VectorOfTensors<uint8> {
+ public:
+ // Build with the tensors in 'tensor_list'.
+ VectorOfQuantizedTensors(const TfLiteContext& context,
+ const TfLiteIntArray& tensor_list)
+ : VectorOfTensors<uint8>(context, tensor_list) {
+ for (int i = 0; i < tensor_list.size; ++i) {
+ TfLiteTensor* t = &context.tensors[tensor_list.data[i]];
+ zero_point_.push_back(t->params.zero_point);
+ scale_.push_back(t->params.scale);
+ }
+ }
+
+ const float* scale() const { return scale_.data(); }
+ const int32* zero_point() const { return zero_point_.data(); }
+
+ private:
+ std::vector<int32> zero_point_;
+ std::vector<float> scale_;
+};
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 606f4a5635..3448de68e8 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -30,6 +30,13 @@ limitations under the License.
namespace tflite {
+namespace {
+// Ensure that ErrorReporter is non-null.
+ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
+ return e ? e : DefaultErrorReporter();
+}
+} // namespace
+
const char* kEmptyTensorName = "";
TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
@@ -78,6 +85,8 @@ std::unique_ptr<Allocation> GetAllocationFromFile(const char* filename,
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
const char* filename, ErrorReporter* error_reporter) {
+ error_reporter = ValidateErrorReporter(error_reporter);
+
std::unique_ptr<FlatBufferModel> model;
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
error_reporter, /*use_nnapi=*/true);
@@ -89,6 +98,8 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile(
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
const char* filename, TfLiteVerifier* verifier,
ErrorReporter* error_reporter) {
+ error_reporter = ValidateErrorReporter(error_reporter);
+
std::unique_ptr<FlatBufferModel> model;
auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
error_reporter, /*use_nnapi=*/true);
@@ -104,6 +115,8 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) {
+ error_reporter = ValidateErrorReporter(error_reporter);
+
std::unique_ptr<FlatBufferModel> model;
Allocation* allocation =
new MemoryAllocation(buffer, buffer_size, error_reporter);
@@ -114,6 +127,8 @@ std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer(
std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromModel(
const tflite::Model* model_spec, ErrorReporter* error_reporter) {
+ error_reporter = ValidateErrorReporter(error_reporter);
+
std::unique_ptr<FlatBufferModel> model;
model.reset(new FlatBufferModel(model_spec, error_reporter));
if (!model->initialized()) model.reset();
@@ -133,15 +148,13 @@ bool FlatBufferModel::CheckModelIdentifier() const {
FlatBufferModel::FlatBufferModel(const Model* model,
ErrorReporter* error_reporter)
- : error_reporter_(error_reporter ? error_reporter
- : DefaultErrorReporter()) {
+ : error_reporter_(ValidateErrorReporter(error_reporter)) {
model_ = model;
}
FlatBufferModel::FlatBufferModel(Allocation* allocation,
ErrorReporter* error_reporter)
- : error_reporter_(error_reporter ? error_reporter
- : DefaultErrorReporter()) {
+ : error_reporter_(ValidateErrorReporter(error_reporter)) {
allocation_ = allocation;
if (!allocation_->valid() || !CheckModelIdentifier()) return;
@@ -154,7 +167,7 @@ InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model,
const OpResolver& op_resolver)
: model_(model.GetModel()),
op_resolver_(op_resolver),
- error_reporter_(model.error_reporter()),
+ error_reporter_(ValidateErrorReporter(model.error_reporter())),
allocation_(model.allocation()) {}
InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
@@ -162,8 +175,7 @@ InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
ErrorReporter* error_reporter)
: model_(model),
op_resolver_(op_resolver),
- error_reporter_(error_reporter ? error_reporter
- : DefaultErrorReporter()) {}
+ error_reporter_(ValidateErrorReporter(error_reporter)) {}
TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
TfLiteStatus status = kTfLiteOk;
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 036dc46e03..5a55b031a8 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -56,27 +56,37 @@ class TfLiteVerifier {
// or mmapped. This uses flatbuffers as the serialization format.
class FlatBufferModel {
public:
- // Builds a model based on a file. Returns a nullptr in case of failure.
+ // Builds a model based on a file.
+ // Caller retains ownership of `error_reporter` and must ensure that it
+ // outlives the FlatBufferModel instance.
+ // Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> BuildFromFile(
const char* filename,
ErrorReporter* error_reporter = DefaultErrorReporter());
// Verifies whether the content of the file is legit, then builds a model
- // based on the file. Returns a nullptr in case of failure.
+ // based on the file.
+ // Caller retains ownership of `error_reporter` and must ensure that it
+ // outlives the FlatBufferModel instance.
+ // Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> VerifyAndBuildFromFile(
const char* filename, TfLiteVerifier* verifier = nullptr,
ErrorReporter* error_reporter = DefaultErrorReporter());
// Builds a model based on a pre-loaded flatbuffer. The caller retains
// ownership of the buffer and should keep it alive until the returned object
- // is destroyed. Returns a nullptr in case of failure.
+ // is destroyed. Caller retains ownership of `error_reporter` and must
+ // ensure that it outlives the FlatBufferModel instance.
+ // Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
const char* buffer, size_t buffer_size,
ErrorReporter* error_reporter = DefaultErrorReporter());
// Builds a model directly from a flatbuffer pointer. The caller retains
// ownership of the buffer and should keep it alive until the returned object
- // is destroyed. Returns a nullptr in case of failure.
+ // is destroyed. Caller retains ownership of `error_reporter` and must
+ // ensure that it outlives the FlatBufferModel instance.
+ // Returns a nullptr in case of failure.
static std::unique_ptr<FlatBufferModel> BuildFromModel(
const tflite::Model* model_spec,
ErrorReporter* error_reporter = DefaultErrorReporter());
@@ -100,7 +110,10 @@ class FlatBufferModel {
private:
// Loads a model from a given allocation. FlatBufferModel will take over the
- // ownership of `allocation`, and delete it in desctructor.
+ // ownership of `allocation`, and delete it in the destructor. Ownership of
+ // `error_reporter` remains with the caller, which must ensure that it
+ // outlives the FlatBufferModel. This is to allow multiple models to use
+ // the same ErrorReporter instance.
FlatBufferModel(Allocation* allocation,
ErrorReporter* error_reporter = DefaultErrorReporter());
@@ -111,7 +124,10 @@ class FlatBufferModel {
// Flatbuffer traverser pointer. (Model* is a pointer that is within the
// allocated memory of the data allocated by allocation's internals.
const tflite::Model* model_ = nullptr;
+ // The error reporter to use for reporting model errors and subsequent
+ // errors when the interpreter is created.
ErrorReporter* error_reporter_;
+ // The allocator used for holding memory of the model.
Allocation* allocation_ = nullptr;
};
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
index 621fbcb98d..1f3ea2e1c7 100644
--- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
@@ -200,6 +200,12 @@ void DeallocateTransientArray(const Model& model, const string& array_name,
allocator->Deallocate(*array->alloc);
}
+void PushBackIfNotFound(const string& s, std::vector<string>* v) {
+ if (std::find(v->begin(), v->end(), s) == v->end()) {
+ v->push_back(s);
+ }
+}
+
} // namespace
void AllocateTransientArrays(Model* model,
@@ -251,18 +257,12 @@ void AllocateTransientArrays(Model* model,
std::vector<string> arrays_to_allocate;
for (const auto& input : op->inputs) {
if (StartsAt(array_lifespans[input], op_index)) {
- if (std::find(arrays_to_allocate.begin(), arrays_to_allocate.end(),
- input) == arrays_to_allocate.end()) {
- arrays_to_allocate.push_back(input);
- }
+ PushBackIfNotFound(input, &arrays_to_allocate);
}
}
for (const auto& output : op->outputs) {
if (StartsAt(array_lifespans[output], op_index)) {
- if (std::find(arrays_to_allocate.begin(), arrays_to_allocate.end(),
- output) == arrays_to_allocate.end()) {
- arrays_to_allocate.push_back(output);
- }
+ PushBackIfNotFound(output, &arrays_to_allocate);
}
}
for (const string& array : arrays_to_allocate) {
@@ -274,18 +274,12 @@ void AllocateTransientArrays(Model* model,
std::vector<string> arrays_to_deallocate;
for (const auto& input : op->inputs) {
if (EndsAt(array_lifespans[input], op_index)) {
- if (std::find(arrays_to_deallocate.begin(), arrays_to_deallocate.end(),
- input) == arrays_to_deallocate.end()) {
- arrays_to_deallocate.push_back(input);
- }
+ PushBackIfNotFound(input, &arrays_to_deallocate);
}
}
for (const auto& output : op->outputs) {
if (EndsAt(array_lifespans[output], op_index)) {
- if (std::find(arrays_to_deallocate.begin(), arrays_to_deallocate.end(),
- output) == arrays_to_deallocate.end()) {
- arrays_to_deallocate.push_back(output);
- }
+ PushBackIfNotFound(output, &arrays_to_deallocate);
}
}
for (const string& array : arrays_to_deallocate) {
@@ -310,17 +304,21 @@ void AllocateTransientArrays(Model* model,
// for each operator, compute the sum of the sizes of the array that must
// be live during the execution of this operator, plus the size of
// persistent arrays that must be live at all times.
- std::size_t size = persistent_alloc_size;
+ std::vector<string> non_persistent_edges;
for (const auto& input : op->inputs) {
if (!array_lifespans[input].persistent) {
- size += TransientArraySize(*model, input, transient_data_alignment);
+ PushBackIfNotFound(input, &non_persistent_edges);
}
}
for (const auto& output : op->outputs) {
if (!array_lifespans[output].persistent) {
- size += TransientArraySize(*model, output, transient_data_alignment);
+ PushBackIfNotFound(output, &non_persistent_edges);
}
}
+ std::size_t size = persistent_alloc_size;
+ for (const string& edge : non_persistent_edges) {
+ size += TransientArraySize(*model, edge, transient_data_alignment);
+ }
// The optimal total size is the maximum of all operator-specific sizes.
optimal_transient_alloc_size = std::max(optimal_transient_alloc_size, size);
}
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 39e49bc347..7a7059e357 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -202,6 +202,7 @@ struct ParsedModelFlags {
Arg<toco::IntList> input_shape;
Arg<toco::StringMapList> rnn_states;
Arg<toco::StringMapList> model_checks;
+ Arg<bool> change_concat_input_ranges = Arg<bool>(true);
// Debugging output options.
// TODO(benoitjacob): these shouldn't be ModelFlags.
Arg<string> graphviz_first_array;
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 5d51431005..4a77196aab 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -37,6 +37,7 @@ limitations under the License.
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT16;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_UINT8;
@@ -1868,6 +1869,9 @@ void AddPlaceholder(const string& name, ArrayDataType type,
case ArrayDataType::kInt64:
(*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
break;
+ case ArrayDataType::kInt16:
+ (*placeholder->mutable_attr())["dtype"].set_type(DT_INT16);
+ break;
default:
LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index 23c9e3246b..437e30a918 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -95,30 +95,37 @@ bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
overall_minmax.min = overall_min;
overall_minmax.max = overall_max;
bool changed = false;
- for (const auto& input : op->inputs) {
- auto& array = model->GetArray(input);
- if (!array.minmax) {
- changed = true;
- } else if (!(overall_minmax == array.GetMinMax())) {
- changed = true;
- LOG(WARNING)
- << "Tweaking the MinMax of array " << input << ", which is "
- << "an input to " << LogName(*op) << ", because we want all inputs "
- << "and outputs of a Concatenation operator to have the same MinMax "
- << "so that it can be implemented as a pure byte-copy, no "
- "arithmetic.";
+ if (model->flags.change_concat_input_ranges()) {
+ for (const auto& input : op->inputs) {
+ auto& array = model->GetArray(input);
+ if (!array.minmax) {
+ changed = true;
+ } else if (!(overall_minmax == array.GetMinMax())) {
+ changed = true;
+ LOG(WARNING)
+ << "Tweaking the MinMax of array " << input << ", which is "
+ << "an input to " << LogName(*op) << ", because we want all inputs "
+ << "and outputs of a Concatenation operator to have the same "
+ << "MinMax so that it can be implemented as a pure byte-copy, no "
+ "arithmetic.";
+ }
+ array.GetOrCreateMinMax() = overall_minmax;
}
- array.GetOrCreateMinMax() = overall_minmax;
}
if (!output.minmax) {
changed = true;
} else if (!(overall_minmax == output.GetMinMax())) {
- changed = true;
- LOG(WARNING)
- << "Tweaking the MinMax of the output array of " << LogName(*op)
- << ", because we want all inputs "
- << "and outputs of a Concatenation operator to have the same MinMax "
- << "so that it can be implemented as a pure byte-copy, no arithmetic.";
+ if (model->flags.change_concat_input_ranges()) {
+ changed = true;
+ LOG(WARNING)
+ << "Tweaking the MinMax of the output array of " << LogName(*op)
+ << ", because we want all inputs "
+ << "and outputs of a Concatenation operator to have the same MinMax "
+ << "so that it can be implemented as a pure byte-copy, no "
+ << "arithmetic.";
+ } else {
+ return false;
+ }
}
output.GetOrCreateMinMax() = overall_minmax;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
index 935da9f966..183b3d3f2e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
@@ -78,15 +78,21 @@ bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
image_input_op->outputs = {dequantized_input_name};
model->operators.emplace(model->operators.begin(), image_input_op);
- CHECK(input_array.final_data_type == ArrayDataType::kUint8);
- input_array.data_type = ArrayDataType::kUint8;
dequantized_input_array.data_type = ArrayDataType::kFloat;
const auto& input_minmax = input_array.GetMinMax();
auto& dequantized_input_minmax = dequantized_input_array.GetOrCreateMinMax();
dequantized_input_minmax = input_minmax;
auto& input_qparams = input_array.GetOrCreateQuantizationParams();
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(input_minmax,
- &input_qparams);
+ input_array.data_type = input_array.final_data_type;
+ if (input_array.data_type == ArrayDataType::kUint8) {
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(input_minmax,
+ &input_qparams);
+ } else if (input_array.data_type == ArrayDataType::kInt16) {
+ GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(input_minmax,
+ &input_qparams);
+ } else {
+ LOG(FATAL) << "unhandled data type";
+ }
transformation->AddMessageF(
"Created %s"
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 7784558b22..5b1268f9a9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -431,7 +431,8 @@ bool ChooseQuantizationForOperatorOutput(
(op.type == OperatorType::kSpaceToDepth) ||
(op.type == OperatorType::kTensorFlowReshape) ||
(op.type == OperatorType::kTensorFlowSplit) ||
- (op.type == OperatorType::kConcatenation)) {
+ (op.type == OperatorType::kConcatenation &&
+ model->flags.change_concat_input_ranges())) {
int data_input_index = 0;
if (op.type == OperatorType::kTensorFlowSplit) {
data_input_index = 1;
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index 245eb52444..7bbeab7c9d 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -165,6 +165,11 @@ bool ParseModelFlagsFromCommandLineFlags(
"Path to an optional file containing a serialized ModelFlags proto. "
"Options specified on the command line will override the values in "
"the proto."),
+ Flag("change_concat_input_ranges",
+ parsed_flags.change_concat_input_ranges.bind(),
+ parsed_flags.change_concat_input_ranges.default_value(),
+ "Boolean to change the behavior of min/max ranges for inputs and"
+ " output of the concat operators."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -399,12 +404,15 @@ void ReadModelFlagsFromCommandLineFlags(
parsed_model_flags.allow_nonascii_arrays.value());
model_flags->set_allow_nonexistent_arrays(
parsed_model_flags.allow_nonexistent_arrays.value());
+ model_flags->set_change_concat_input_ranges(
+ parsed_model_flags.change_concat_input_ranges.value());
if (parsed_model_flags.arrays_extra_info_file.specified()) {
string arrays_extra_info_file_contents;
- port::file::GetContents(parsed_model_flags.arrays_extra_info_file.value(),
- &arrays_extra_info_file_contents,
- port::file::Defaults());
+ CHECK(port::file::GetContents(
+ parsed_model_flags.arrays_extra_info_file.value(),
+ &arrays_extra_info_file_contents, port::file::Defaults())
+ .ok());
ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
model_flags->mutable_arrays_extra_info());
}
diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto
index 835dea49eb..d23e80c464 100644
--- a/tensorflow/contrib/lite/toco/model_flags.proto
+++ b/tensorflow/contrib/lite/toco/model_flags.proto
@@ -128,7 +128,7 @@ message ArraysExtraInfo {
// optional int32 input_dims = 11 [ default = 4];
// repeated int32 input_shape = 13;
//
-// Next ID to USE: 19.
+// Next ID to USE: 20.
message ModelFlags {
// Information about the input arrays, i.e. the arrays from which input
// activations will be read.
@@ -175,4 +175,8 @@ message ModelFlags {
// If set, this ArraysExtraInfo allows to pass extra information about arrays
// not specified in the input model file, such as extra MinMax information.
optional ArraysExtraInfo arrays_extra_info = 18;
+
+ // When set to false, toco will not change the input ranges and the output
+ // range of concat operators to the combined range of all inputs.
+ optional bool change_concat_input_ranges = 19 [default = true];
}
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 76e9a27aef..96c5ebd64f 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -130,20 +130,26 @@ bool SupportsPreallocatedWorkspace(FileFormat format) {
}
bool IsRealValued(toco::ArrayDataType type) {
+ // TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used
+ // for quantized real-number values, and no other integer type is ever used
+ // for that. This is dirty and should be resolved as part of a more general
+ // push to more explicitly distinguish between true integers and
+ // integers used as quantized values representing real numbers.
return static_cast<bool>(type == toco::ArrayDataType::kFloat ||
- type == toco::ArrayDataType::kUint8);
+ type == toco::ArrayDataType::kUint8 ||
+ type == toco::ArrayDataType::kInt16);
}
void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
const FileFormat output_format = toco_flags.output_format();
ArrayDataType type;
- if (toco_flags.has_inference_input_type()) {
+ if (!SupportsQuantization(output_format)) {
+ // Data type is implicitly float for non-quantized formats
+ type = ArrayDataType::kFloat;
+ } else if (toco_flags.has_inference_input_type()) {
type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
} else if (toco_flags.has_inference_type()) {
type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
- } else if (!SupportsQuantization(output_format)) {
- // Data type is implicitly float for non-quantized formats
- type = ArrayDataType::kFloat;
} else {
// Nothing to do. Data types stay as-is.
return;
@@ -198,11 +204,6 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
}
void Transform(const TocoFlags& toco_flags, Model* model) {
- // Clean up after import.
- SetFinalDataTypeOnInputs(toco_flags, model);
- UseArraysExtraInfo(model);
- FinishBuildingRNNStates(model);
-
const FileFormat output_format = toco_flags.output_format();
const IODataType inference_type = toco_flags.inference_type();
@@ -215,6 +216,11 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
<< "Quantized inference is not allowed with float inputs.";
}
+ // Clean up after import.
+ SetFinalDataTypeOnInputs(toco_flags, model);
+ UseArraysExtraInfo(model, quantize_output);
+ FinishBuildingRNNStates(model);
+
// Remove unused ops before performing any other optimizations. This is to
// stop optimizations from crossing the input/output boundaries. For example
// this will stop BatchNorm fusing if the output node is in between a conv
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 56fa8f4b69..b72f5fa2a7 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -1378,12 +1378,22 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
const float mean_value = input_array_proto.mean_value();
const float std_value = input_array_proto.std_value();
MinMax input_minmax;
- input_minmax.min = (0.f - mean_value) / std_value;
- input_minmax.max = (255.f - mean_value) / std_value;
+ float qmin = 0, qmax = 255;
+ if (input_array.data_type == ArrayDataType::kInt16) {
+ qmin = -32768;
+ qmax = 32767;
+ }
+ input_minmax.min = (qmin - mean_value) / std_value;
+ input_minmax.max = (qmax - mean_value) / std_value;
if (input_array.minmax) {
if (input_array_proto.has_mean_value() ||
input_array_proto.has_std_value()) {
- CHECK(input_minmax == *input_array.minmax)
+ const double width = input_minmax.max - input_minmax.min;
+ const double kMinMaxAllowedDiff = 1e-6 * width;
+ CHECK(std::abs(input_minmax.min - input_array.minmax->min) <
+ kMinMaxAllowedDiff &&
+ std::abs(input_minmax.max - input_array.minmax->max) <
+ kMinMaxAllowedDiff)
<< input_minmax.min << ", " << input_minmax.max
<< " != " << input_array.minmax->min << ", "
<< input_array.minmax->max;
@@ -1403,7 +1413,8 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
CHECK(input_array.shape().dims_size());
}
}
-
+ model->flags.set_change_concat_input_ranges(
+ model_flags.change_concat_input_ranges());
model->flags.set_allow_nonascii_arrays(model_flags.allow_nonascii_arrays());
model->flags.set_allow_nonexistent_arrays(
model_flags.allow_nonexistent_arrays());
@@ -2000,7 +2011,7 @@ void FinishBuildingRNNStates(Model* model) {
}
}
-void UseArraysExtraInfo(Model* model) {
+void UseArraysExtraInfo(Model* model, bool quantize_output) {
for (const auto& entry : model->flags.arrays_extra_info().entries()) {
if (!model->HasArray(entry.name())) {
continue;
@@ -2012,7 +2023,7 @@ void UseArraysExtraInfo(Model* model) {
minmax.min = entry.min();
minmax.max = entry.max();
}
- if (entry.has_data_type()) {
+ if (entry.has_data_type() && quantize_output) {
array.final_data_type =
ConvertIODataTypeToArrayDataType(entry.data_type());
}
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index 259ee7fbd0..dfd81173c3 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -285,7 +285,7 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type);
// already quantized, then case (a) should hold.
void FinishBuildingRNNStates(Model* model);
-void UseArraysExtraInfo(Model* model);
+void UseArraysExtraInfo(Model* model, bool quantize_output);
} // namespace toco
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 63fdd91d36..c7d85862f6 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -842,12 +842,12 @@ class RNNCellTest(test.TestCase):
batch_size = 3
input_size = 4
expected_state_c = np.array(
- [[6.450831e-04, 4.697885e-04], [9.862894e-05, 7.212213e-04],
- [4.401947e-04, 9.143004e-04]],
+ [[0.00072015, 0.00036633], [0.00083481, 0.00047266],
+ [0.00085111, 0.00053054]],
dtype=np.float32)
expected_state_h = np.array(
- [[4.621217e-04, 3.365449e-04], [7.438179e-05, 5.439147e-04],
- [3.347936e-04, 6.953785e-04]],
+ [[0.0005159, 0.00026243], [0.00062958, 0.00035646],
+ [0.00064732, 0.00040351]],
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
diff --git a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py
index 15a415df30..eac34afc4a 100644
--- a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py
+++ b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py
@@ -52,6 +52,7 @@ class FakeSummaryWriter(object):
self._added_graphs = []
self._added_meta_graphs = []
self._added_session_logs = []
+ self._added_run_metadata = {}
@property
def summaries(self):
@@ -127,6 +128,11 @@ class FakeSummaryWriter(object):
# pylint: disable=unused-argument
self._added_session_logs.append(session_log)
+ def add_run_metadata(self, run_metadata, tag, global_step=None):
+ if (global_step is not None) and (global_step < 0):
+ raise ValueError('Invalid global_step %s.' % global_step)
+ self._added_run_metadata[tag] = run_metadata
+
def flush(self):
pass
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index f2003e04dd..6b198dbc16 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -64,9 +64,11 @@ Status ValidateHostPortPair(const string& host_port) {
return Status::OK();
}
-ProfileResponse Profile(const string& service_addr, int duration_ms,
- const string& repository_root, const string& session_id,
- const ProfileOptions& opts) {
+// Returns whether the returned trace is empty.
+// Failures are handled by CHECK, i.e. abort().
+bool Profile(const string& service_addr, const string& logdir, int duration_ms,
+ const string& repository_root, const string& session_id,
+ const ProfileOptions& opts) {
ProfileRequest request;
request.set_duration_ms(duration_ms);
request.set_max_events(kMaxEvents);
@@ -94,7 +96,31 @@ ProfileResponse Profile(const string& service_addr, int duration_ms,
channel_args));
ProfileResponse response;
TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response)));
- return response;
+
+ if (!response.encoded_trace().empty()) {
+ TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile(
+ logdir, session_id, "", response, &std::cout));
+ // Print this at the end so that it's not buried in irrelevant LOG messages.
+ std::cout
+ << "NOTE: using the trace duration " << duration_ms << "ms."
+ << std::endl
+ << "Set an appropriate duration (with --duration_ms) if you "
+ "don't see a full step in your trace or the captured trace is too "
+ "large."
+ << std::endl;
+ }
+
+ return response.encoded_trace().empty();
+}
+
+// Starts a new profiling session that includes all the hosts in hostnames,
+// for the time interval of duration_ms. Possibly saves the profiling
+// result in the directory specified by repository_root and session_id.
+bool NewSession(const string& service_addr,
+ const std::vector<tensorflow::string>& hostnames,
+ int duration_ms, const string& repository_root,
+ const string& session_id, const ProfileOptions& opts) {
+ return true;
}
} // namespace
@@ -104,12 +130,16 @@ ProfileResponse Profile(const string& service_addr, int duration_ms,
int main(int argc, char** argv) {
tensorflow::string FLAGS_service_addr;
tensorflow::string FLAGS_logdir;
+ tensorflow::string FLAGS_workers_list;
int FLAGS_duration_ms = 2000;
int FLAGS_num_tracing_attempts = 3;
bool FLAGS_include_dataset_ops = true;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("service_addr", &FLAGS_service_addr,
"Address of TPU profiler service e.g. localhost:8466"),
+ tensorflow::Flag("workers_list", &FLAGS_workers_list,
+ "The list of worker TPUs that we are about to profile "
+ "in the current session."),
tensorflow::Flag("logdir", &FLAGS_logdir,
"Path of TensorBoard log directory e.g. /tmp/tb_log, "
"gs://tb_bucket"),
@@ -153,18 +183,30 @@ int main(int argc, char** argv) {
constexpr char kProfilePluginDirectory[] = "plugins/profile/";
tensorflow::string repository_root =
::tensorflow::io::JoinPath(FLAGS_logdir, kProfilePluginDirectory);
+ std::vector<tensorflow::string> hostnames =
+ tensorflow::str_util::Split(FLAGS_workers_list, ",");
+
+ bool empty_trace = false;
while (true) {
std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. "
<< "Remaining attempt(s): " << remaining_attempts-- << std::endl;
- response = tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms,
- repository_root, session_id, opts);
- if (remaining_attempts <= 0 || !response.encoded_trace().empty()) break;
+ if (hostnames.empty()) {
+ empty_trace = tensorflow::tpu::Profile(FLAGS_service_addr, FLAGS_logdir,
+ duration_ms, repository_root,
+ session_id, opts);
+ } else {
+ tensorflow::string tpu_master = FLAGS_service_addr;
+ empty_trace =
+ tensorflow::tpu::NewSession(tpu_master, hostnames, duration_ms,
+ repository_root, session_id, opts);
+ }
+ if (remaining_attempts <= 0 || !empty_trace) break;
std::cout << "No trace event is collected. Automatically retrying."
<< std::endl
<< std::endl;
}
- if (response.encoded_trace().empty()) {
+ if (empty_trace) {
std::cout << "No trace event is collected after "
<< FLAGS_num_tracing_attempts << " attempt(s). "
<< "Perhaps, you want to try again (with more attempts?)."
@@ -175,13 +217,5 @@ int main(int argc, char** argv) {
return 0;
}
- TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile(
- FLAGS_logdir, session_id, response, &std::cout));
- // Print this at the end so that it's not buried in irrelevant LOG messages.
- std::cout
- << "NOTE: using the trace duration " << duration_ms << "ms." << std::endl
- << "Set an appropriate duration (with --duration_ms) if you "
- "don't see a full step in your trace or the captured trace is too "
- "large."
- << std::endl;
+ return 0;
}
diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
index ebd6185faa..ae508583f8 100644
--- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
@@ -41,6 +41,7 @@ namespace {
using ::tensorflow::io::JoinPath;
using ::tensorflow::protobuf::util::JsonOptions;
using ::tensorflow::protobuf::util::MessageToJsonString;
+using ::tensorflow::strings::StrCat;
constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph.";
constexpr char kJsonOpProfileFileName[] = "op_profile.json";
@@ -61,28 +62,33 @@ Status WriteGzippedDataToFile(const string& filename, const string& data) {
return Status::OK();
}
-Status DumpTraceToLogDirectory(StringPiece run_dir, const string& encoded_trace,
- std::ostream* os) {
+Status DumpTraceToLogDirectory(StringPiece run_dir, const string& host_prefix,
+ const string& encoded_trace, std::ostream* os) {
string proto_path = JoinPath(run_dir, kProtoTraceFileName);
TF_RETURN_IF_ERROR(
WriteStringToFile(Env::Default(), proto_path, encoded_trace));
LOG(INFO) << "Dumped raw-proto trace data to " << proto_path;
- string json_path = JoinPath(run_dir, kJsonTraceFileName);
+ string json_path = JoinPath(run_dir, StrCat(host_prefix, kJsonTraceFileName));
Trace trace;
trace.ParseFromString(encoded_trace);
- *os << "Trace contains " << trace.trace_events_size() << " events."
- << std::endl;
+ if (os) {
+ *os << "Trace contains " << trace.trace_events_size() << " events."
+ << std::endl;
+ }
TF_RETURN_IF_ERROR(
WriteGzippedDataToFile(json_path, TraceEventsToJson(trace)));
- *os << "Dumped JSON trace data to " << json_path << std::endl;
+ if (os) {
+ *os << "Dumped JSON trace data to " << json_path << std::endl;
+ }
return Status::OK();
}
Status DumpOpProfileToLogDirectory(StringPiece run_dir,
+ const string& host_prefix,
const tpu::op_profile::Profile& profile,
std::ostream* os) {
- string path = JoinPath(run_dir, kJsonOpProfileFileName);
+ string path = JoinPath(run_dir, StrCat(host_prefix, kJsonOpProfileFileName));
string json;
JsonOptions options;
options.always_print_primitive_fields = true;
@@ -93,49 +99,20 @@ Status DumpOpProfileToLogDirectory(StringPiece run_dir,
string(status.error_message()));
}
TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, json));
- *os << "Dumped json op profile data to " << path << std::endl;
+ if (os) {
+ *os << "Dumped json op profile data to " << path << std::endl;
+ }
return Status::OK();
}
Status DumpToolDataToLogDirectory(StringPiece run_dir,
+ const string& host_prefix,
const tensorflow::ProfileToolData& tool,
std::ostream* os) {
- string path = JoinPath(run_dir, tool.name());
+ string path = JoinPath(run_dir, StrCat(host_prefix, tool.name()));
TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, tool.data()));
- *os << "Dumped tool data for " << tool.name() << " to " << path << std::endl;
- return Status::OK();
-}
-
-Status DumpGraphEvents(const string& logdir, const string& run,
- const ProfileResponse& response, std::ostream* os) {
- int num_graphs = response.computation_graph_size();
- if (response.computation_graph_size() == 0) return Status::OK();
- // The server might generates multiple graphs for one program; we simply
- // pick the first one.
- if (num_graphs > 1) {
- *os << num_graphs
- << " TPU program variants observed over the profiling period. "
- << "One computation graph will be chosen arbitrarily." << std::endl;
- }
- // The graph plugin expects the graph in <logdir>/<run>/<event.file>.
- string run_dir = JoinPath(logdir, strings::StrCat(kGraphRunPrefix, run));
- TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(run_dir));
- EventsWriter event_writer(JoinPath(run_dir, "events"));
- Event event;
- // Add the computation graph.
- event.set_graph_def(response.computation_graph(0).SerializeAsString());
- event_writer.WriteEvent(event);
- *os << "Wrote a HLO graph to " << event_writer.FileName() << std::endl;
-
- if (response.has_hlo_metadata()) {
- tensorflow::TaggedRunMetadata tagged_run_metadata;
- tagged_run_metadata.set_tag(run);
- tagged_run_metadata.set_run_metadata(
- response.hlo_metadata().SerializeAsString());
- tensorflow::Event meta_event;
- *meta_event.mutable_tagged_run_metadata() = tagged_run_metadata;
- event_writer.WriteEvent(meta_event);
- *os << "Wrote HLO ops run metadata to " << event_writer.FileName()
+ if (os) {
+ *os << "Dumped tool data for " << tool.name() << " to " << path
<< std::endl;
}
return Status::OK();
@@ -144,27 +121,29 @@ Status DumpGraphEvents(const string& logdir, const string& run,
} // namespace
Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
+ const string& host,
const ProfileResponse& response,
std::ostream* os) {
// Dumps profile data to <logdir>/plugins/profile/<run>/.
+ string host_prefix = host.empty() ? "" : StrCat(host, ".");
string profile_run_dir = JoinPath(logdir, kProfilePluginDirectory, run);
TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(profile_run_dir));
// Ignore computation_graph for now.
if (!response.encoded_trace().empty()) {
LOG(INFO) << "Converting trace events to TraceViewer JSON.";
- TF_RETURN_IF_ERROR(
- DumpTraceToLogDirectory(profile_run_dir, response.encoded_trace(), os));
+ TF_RETURN_IF_ERROR(DumpTraceToLogDirectory(profile_run_dir, host_prefix,
+ response.encoded_trace(), os));
}
if (response.has_op_profile() &&
(response.op_profile().has_by_program_structure() ||
response.op_profile().has_by_category())) {
- TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir,
+ TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir, host_prefix,
response.op_profile(), os));
}
for (const auto& tool_data : response.tool_data()) {
- TF_RETURN_IF_ERROR(
- DumpToolDataToLogDirectory(profile_run_dir, tool_data, os));
+ TF_RETURN_IF_ERROR(DumpToolDataToLogDirectory(profile_run_dir, host_prefix,
+ tool_data, os));
}
return Status::OK();
diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h
index 29ef977bac..ecf21b1de2 100644
--- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h
+++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h
@@ -32,6 +32,7 @@ namespace tpu {
// Note: this function creates a directory even when all fields in
// ProfileResponse are unset/empty.
Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
+ const string& host,
const ProfileResponse& response,
std::ostream* os);
diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
index cddc3cd1b4..8505c4bc69 100644
--- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
+++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
@@ -21,6 +21,17 @@ message ProfileOptions {
// next-field: 2
}
+message ToolRequestOptions {
+ // Required output formats for the tool; should be one of "json", "proto",
+ // "raw", etc. If not specified (for backward compatibility), the default
+ // format is used; most tools use the json format.
+ string output_formats = 2;
+
+ // Whether to save the result directly to the repository or pass it back
+ // to the caller. Defaults to false for backward compatibility.
+ bool save_to_repo = 3;
+}
+
message ProfileRequest {
// In future, the caller will be able to customize when profiling starts and
// stops. For now, it collects `duration_ms` milliseconds worth of data.
@@ -30,9 +41,12 @@ message ProfileRequest {
// events.
uint64 max_events = 2;
- // required profiling tools name such as "input_pipeline_analyzer" etc
+ // Required profiling tool names, such as "input_pipeline_analyzer", etc.
repeated string tools = 3;
+ // Specifies the requirements for each tool.
+ map<string, ToolRequestOptions> tool_options = 8;
+
// Optional profiling options that control how a TF session will be profiled.
ProfileOptions opts = 4;
@@ -43,10 +57,14 @@ message ProfileRequest {
// The user provided profile session identifier.
string session_id = 6;
+ // The hostname of the system where the profile should happen.
+ // We use it as an identifier in part of our output filename.
+ string host_name = 7;
+
// In future, the caller will indicate which TF session is being profiled, and
// only data relating to that program will be returned. For now, we assume
// all activity during the profiling period is relevant.
- // next-field: 7
+ // next-field: 9
}
message ProfileToolData {
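A sketch of how a caller might populate the new request fields, using the standard protoc-generated C++ setters for the messages above (namespace qualifiers elided; the tool name and host are hypothetical):

  ProfileRequest req;
  req.set_duration_ms(2000);                 // collect 2s worth of data
  req.add_tools("input_pipeline_analyzer");
  ToolRequestOptions opts;
  opts.set_output_formats("json");           // "json", "proto", "raw", ...
  opts.set_save_to_repo(false);              // return the result to the caller
  (*req.mutable_tool_options())["input_pipeline_analyzer"] = opts;
  req.set_host_name("worker0");              // becomes part of the filename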
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 01da54fcb3..64adf35c5e 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -66,8 +66,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
std::unique_ptr<ClientGraph> cg,
const SessionOptions& session_opts,
const StatsPublisherFactory& stats_publisher_factory,
- GraphExecutionState* execution_state, bool is_partial,
- WorkerCacheInterface* worker_cache, bool should_deregister)
+ bool is_partial, WorkerCacheInterface* worker_cache,
+ bool should_deregister)
: session_handle_(handle),
client_graph_(std::move(cg)),
session_opts_(session_opts),
@@ -80,8 +80,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
- // Initialize a name to node map for testing that fetches are reachable.
- for (Node* n : execution_state->full_graph()->nodes()) {
+ // Initialize a name to node map for processing device stats.
+ for (Node* n : client_graph_->graph.nodes()) {
name_to_node_.insert({n->name(), n});
}
}
@@ -829,8 +829,6 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats(
// TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends
// on once at setup time to prevent us from computing the dependencies
// every time.
-// TODO(suharshs,mrry): Consider removing the need for execution_state to reduce
-// contention.
Status MasterSession::ReffedClientGraph::CheckFetches(
const RunStepRequestWrapper& req, const RunState* run_state,
GraphExecutionState* execution_state) {
@@ -840,8 +838,8 @@ Status MasterSession::ReffedClientGraph::CheckFetches(
// Skip if already fed.
if (input.second) continue;
TensorId id(ParseTensorName(input.first));
- const auto it = name_to_node_.find(id.first);
- if (it == name_to_node_.end()) {
+ const Node* n = execution_state->get_node_by_name(id.first.ToString());
+ if (n == nullptr) {
return errors::NotFound("Feed ", input.first, ": not found");
}
pending_feeds.insert(id);
@@ -856,11 +854,11 @@ Status MasterSession::ReffedClientGraph::CheckFetches(
for (size_t i = 0; i < req.num_fetches(); ++i) {
const string& fetch = req.fetch_name(i);
const TensorId id(ParseTensorName(fetch));
- auto it = name_to_node_.find(id.first);
- if (it == name_to_node_.end()) {
+ const Node* n = execution_state->get_node_by_name(id.first.ToString());
+ if (n == nullptr) {
return errors::NotFound("Fetch ", fetch, ": not found");
}
- stack.push_back(it->second);
+ stack.push_back(n);
}
// Any tensor needed for fetches can't be in pending_feeds.
@@ -1293,8 +1291,8 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
WorkerCacheInterface* worker_cache = get_worker_cache();
auto entry = new ReffedClientGraph(
handle_, opts, std::move(client_graph), session_opts_,
- stats_publisher_factory_, execution_state_.get(), is_partial,
- worker_cache, !should_delete_worker_sessions_);
+ stats_publisher_factory_, is_partial, worker_cache,
+ !should_delete_worker_sessions_);
iter = m->insert({hash, entry}).first;
VLOG(1) << "Preparing to execute new graph";
}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
index b4d18d8607..63745e8ebd 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
@@ -47,11 +47,11 @@ namespace tensorflow {
class GrpcMasterService : public AsyncServiceInterface {
public:
- GrpcMasterService(Master* master, int64 default_timeout_in_ms,
+ GrpcMasterService(Master* master, const ConfigProto& default_session_config,
::grpc::ServerBuilder* builder)
: master_impl_(master),
- default_timeout_in_ms_(default_timeout_in_ms),
- is_shutdown_(false) {
+ is_shutdown_(false),
+ default_session_config_(default_session_config) {
builder->RegisterService(&master_service_);
cq_ = builder->AddCompletionQueue();
}
@@ -129,12 +129,12 @@ class GrpcMasterService : public AsyncServiceInterface {
private:
Master* master_impl_ = nullptr; // Not owned.
- const int64 default_timeout_in_ms_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
grpc::MasterService::AsyncService master_service_;
mutex mu_;
bool is_shutdown_ GUARDED_BY(mu_);
+ const ConfigProto default_session_config_;
::grpc::Alarm* shutdown_alarm_ = nullptr;
template <class RequestMessage, class ResponseMessage>
@@ -144,9 +144,13 @@ class GrpcMasterService : public AsyncServiceInterface {
// RPC handler for creating a session.
void CreateSessionHandler(
MasterCall<CreateSessionRequest, CreateSessionResponse>* call) {
- master_impl_->CreateSession(&call->request, &call->response,
- [call](const Status& status) {
+ CreateSessionRequest* rewritten_req = new CreateSessionRequest;
+ rewritten_req->mutable_config()->MergeFrom(default_session_config_);
+ rewritten_req->MergeFrom(call->request);
+ master_impl_->CreateSession(rewritten_req, &call->response,
+ [call, rewritten_req](const Status& status) {
call->SendResponse(ToGrpcStatus(status));
+ delete rewritten_req;
});
ENQUEUE_REQUEST(CreateSession, true);
}
@@ -178,7 +182,7 @@ class GrpcMasterService : public AsyncServiceInterface {
if (call->request.options().timeout_in_ms() > 0) {
call_opts->SetTimeout(call->request.options().timeout_in_ms());
} else {
- call_opts->SetTimeout(default_timeout_in_ms_);
+ call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
}
RunStepRequestWrapper* wrapped_request =
new ProtoRunStepRequest(&call->request);
@@ -249,10 +253,10 @@ class GrpcMasterService : public AsyncServiceInterface {
TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterService);
};
-AsyncServiceInterface* NewGrpcMasterService(Master* master,
- int64 default_timeout_in_ms,
- ::grpc::ServerBuilder* builder) {
- return new GrpcMasterService(master, default_timeout_in_ms, builder);
+AsyncServiceInterface* NewGrpcMasterService(
+ Master* master, const ConfigProto& default_session_config,
+ ::grpc::ServerBuilder* builder) {
+ return new GrpcMasterService(master, default_session_config, builder);
}
} // end namespace tensorflow
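The rewritten CreateSession path relies on protobuf MergeFrom ordering: the server's default session config is merged into the request copy first, so any config field the client set to a non-default value overwrites the default afterwards. The same pattern in isolation (client_request is a stand-in for call->request):

  tensorflow::ConfigProto defaults;
  defaults.set_operation_timeout_in_ms(60000);       // server-wide default

  tensorflow::CreateSessionRequest rewritten;
  rewritten.mutable_config()->MergeFrom(defaults);   // defaults go in first
  rewritten.MergeFrom(client_request);               // client settings win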
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h
index 473604f257..f0fe5b0c4e 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/master.pb.h"
namespace grpc {
class ServerBuilder;
@@ -28,9 +29,9 @@ namespace tensorflow {
class AsyncServiceInterface;
class Master;
-AsyncServiceInterface* NewGrpcMasterService(Master* master,
- int64 default_timeout_in_ms,
- ::grpc::ServerBuilder* builder);
+AsyncServiceInterface* NewGrpcMasterService(
+ Master* master, const ConfigProto& default_session_config,
+ ::grpc::ServerBuilder* builder);
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index a6f4be3eaf..be19103582 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -183,8 +183,7 @@ Status GrpcServer::Init(
builder.SetOption(
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
master_impl_ = CreateMaster(&master_env_);
- master_service_ = NewGrpcMasterService(
- master_impl_.get(), config.operation_timeout_in_ms(), &builder);
+ master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
worker_impl_ =
worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_);
worker_service_ =
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 14e46ecdd9..79735e6cc2 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -459,11 +459,7 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
Costs costs;
costs.compute_time = compute_cost;
costs.memory_time = memory_cost;
- if (compute_memory_overlap_) {
- costs.execution_time = std::max(compute_cost, memory_cost);
- } else {
- costs.execution_time = compute_cost + memory_cost;
- }
+ CombineCostsAndUpdateExecutionTime(&costs);
return costs;
}
@@ -1375,5 +1371,14 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
return costs;
}
+void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
+ Costs* costs) const {
+ if (compute_memory_overlap_) {
+ costs->execution_time = std::max(costs->compute_time, costs->memory_time);
+ } else {
+ costs->execution_time = costs->compute_time + costs->memory_time;
+ }
+}
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index fcbecbb6dc..7080264698 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -173,6 +173,11 @@ class OpLevelCostEstimator {
const TensorShapeProto& original_image_shape, const OpInfo& op_info,
bool* found_unknown_shapes);
+ // This method calculates the execution time depending on whether IO can
+ // overlap with computation. It assumes the memory and the compute times have
+ // already been calculated.
+ void CombineCostsAndUpdateExecutionTime(Costs* costs) const;
+
protected:
std::map<string, int> elementwise_ops_;
typedef std::function<Costs(const OpContext& op_context)> CostImpl;
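The extracted helper centralizes the overlap decision; in effect (a simplified standalone sketch of the same arithmetic):

  #include <algorithm>
  #include <cstdint>

  // If IO can overlap with computation, the slower of the two dominates;
  // otherwise the two times add up.
  int64_t CombinedExecutionTime(int64_t compute, int64_t memory, bool overlap) {
    return overlap ? std::max(compute, memory) : compute + memory;
  }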
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 59a5695af0..7bf264ba30 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -237,17 +237,16 @@ bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
return false;
}
- // Now, src_shape and dst_shape have at most one dimension with unknown
- // sizes, and are compatible. Therefore, the reshape is a no-op when
- //
- // 1. at least one of them is fully-defined, or
- // 2. both are partially defined and the -1 appears on the same dimension,
- // i.e., IsIdenticalTo returns true.
- if (src_num_unknown_dim_sizes == 1 && dst_num_unknown_dim_sizes == 1) {
- return dst_shape.IsIdenticalTo(src_shape);
+ // If dst_num_unknown_dim_sizes != src_num_unknown_dim_sizes, removing this
+ // reshape would weaken shape inference in subsequent passes.
+ if (src_num_unknown_dim_sizes != dst_num_unknown_dim_sizes) {
+ return false;
}
- return true;
+ // Remove the reshape if both shapes are fully defined, or if both are
+ // partially defined and the unknown or symbolic dimension appears in the
+ // same position, i.e., if IsIdenticalTo returns true.
+ return dst_shape.IsIdenticalTo(src_shape);
}
NodeDef* GetTailOfValuePreservingChain(
@@ -727,7 +726,9 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
// Hoist non-shared factors up into the new AddN node.
for (int i = 0; i < unique_factors.size(); ++i) {
- new_add_node->set_input(i, unique_factors[i]);
+ const string& unique_factor_i = unique_factors[i];
+ new_add_node->set_input(i, unique_factor_i);
+ ctx_.node_map->AddOutput(unique_factor_i, new_add_node->name());
}
// Add control deps on add node
@@ -859,13 +860,18 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
NodeDef* node_perm;
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
+ if (!IsConstant(*node_perm)) {
+ return Status::OK();
+ }
std::vector<int64> node_perm_values;
TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
-
if (input->op() == node->op()) {
// Remove pairs of transposes that cancel each other.
NodeDef* input_perm;
TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &input_perm));
+ if (!IsConstant(*input_perm)) {
+ return Status::OK();
+ }
std::vector<int64> input_perm_values;
TF_RETURN_IF_ERROR(GetPermutation(*input_perm, &input_perm_values));
if (AreInversePermutations(node_perm_values, input_perm_values)) {
@@ -1337,9 +1343,9 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// ^ |
// | |
// input ---+
- NodeDef* reshape = node_map_->GetNode(node->name());
+ NodeDef* reshape = const_cast<NodeDef*>(node);
int output_pos = 0;
- string input_node_name = ParseNodeName(node->input(0), &output_pos);
+ string input_node_name = ParseNodeName(reshape->input(0), &output_pos);
const NodeDef* input = node_map_->GetNode(input_node_name);
if (input->op() == "Reshape" && !HasControlInputs(*input)) {
reshape->set_input(0, input->input(0));
@@ -1653,7 +1659,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
return "";
}
-Status ArithmeticOptimizer::SimplifyArithmeticOps() {
+Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
SetVector<NodeDef*> nodes_to_simplify;
nodes_to_simplify.Reserve(optimized_graph_->node_size());
for (int i = 0; i < optimized_graph_->node_size(); ++i) {
@@ -1668,11 +1674,11 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
const auto stop = [](const string& result) { return !result.empty(); };
GraphOptimizerStagePipeline<string> pipeline(stop);
- if (options_.combine_add_to_addn)
+ if (options_.combine_add_to_addn && can_use_shapes)
pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
- if (options_.hoist_common_factor_out_of_aggregation)
+ if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes)
pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
- if (options_.remove_identity_transpose)
+ if (options_.remove_identity_transpose && can_use_shapes)
pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
if (options_.remove_redundant_bitcast)
pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
@@ -1759,10 +1765,14 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
// Shapes are only needed in aggressive mode.
graph_properties_.reset(new GraphProperties(item));
- TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
+ const Status status = graph_properties_->InferStatically(false);
+ const bool can_use_shapes = status.ok();
+ if (!can_use_shapes) {
+ VLOG(1) << "Shape inference failed." << status.error_message();
+ }
// Perform the optimizations.
- TF_RETURN_IF_ERROR(SimplifyArithmeticOps());
+ TF_RETURN_IF_ERROR(SimplifyArithmeticOps(can_use_shapes));
optimized_graph->Swap(optimized_graph_);
return Status::OK();
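The tightened ReshapeIsIdentity rule earlier in this file can be paraphrased on plain shape vectors (a simplified sketch, with -1 marking an unknown dimension; the real code compares PartialTensorShapes via IsIdenticalTo):

  #include <cstdint>
  #include <vector>

  // Removable only when src and dst have the same number of unknown (-1)
  // dimensions and are otherwise identical, so removing the reshape never
  // weakens shape inference in later passes.
  bool ReshapeIsRemovable(const std::vector<int64_t>& src,
                          const std::vector<int64_t>& dst) {
    int src_unknown = 0, dst_unknown = 0;
    for (int64_t d : src) src_unknown += (d == -1);
    for (int64_t d : dst) dst_unknown += (d == -1);
    if (src_unknown != dst_unknown) return false;
    return src == dst;  // stands in for dst_shape.IsIdenticalTo(src_shape)
  }

  // e.g. {-1, 4} vs {-1, 4} -> removable; {2, 4} vs {-1, 4} -> keep the reshape.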
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 7e81ed0a1f..39b89dedba 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -105,7 +105,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
// Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse
// transposes.
- Status SimplifyArithmeticOps();
+ Status SimplifyArithmeticOps(bool can_use_shapes);
// Tries to simplify the expression that roots at `node` and replaces the uses
// of `node` to the simplified expression. Returns the name of the simplified
// tensor (e.g. "split:1") or an empty string if no simplification is
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index d941a0b3f9..b2a1ce6ab6 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -298,7 +298,8 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
for (int node_idx = 0; node_idx < node_count; ++node_idx) {
NodeDef* node = graph_->mutable_node(node_idx);
const string op = node->op();
- if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
+ if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN" &&
+ op != "TensorArraySizeV3") {
continue;
}
@@ -349,6 +350,36 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
continue;
}
+ if (op == "TensorArraySizeV3") {
+ const NodeDef* array = node_map_->GetNode(node->input(0));
+ if (array->attr().count("dynamic_size") != 0 &&
+ array->attr().at("dynamic_size").b()) {
+ continue;
+ }
+ const NodeDef* array_size = node_map_->GetNode(array->input(0));
+ if (IsReallyConstant(*array_size)) {
+ // Don't materialize 0 sizes to avoid triggering incorrect static
+ // checks. A 0-sized array that can't grow isn't useful anyway.
+ const TensorProto& raw_val = array_size->attr().at("value").tensor();
+ if (raw_val.dtype() != DT_INT32) {
+ continue;
+ }
+ Tensor value(raw_val.dtype(), raw_val.tensor_shape());
+ if (!value.FromProto(raw_val)) {
+ continue;
+ }
+ if (value.flat<int32>()(0) == 0) {
+ continue;
+ }
+ node->set_op("Const");
+ *node->mutable_attr() = array_size->attr();
+ node->set_input(0, AsControlDependency(NodeName(node->input(0))));
+ node->set_input(1, AddControlDependency(NodeName(node->input(1)),
+ graph_, node_map_.get()));
+ }
+ continue;
+ }
+
// Handle ShapeN materialization case.
// It's possible that not all input tensors have known shapes.
CHECK_EQ(op, "ShapeN");
@@ -552,7 +583,6 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
const DataType type = node.attr().at("T").type();
NodeDef* out[2];
- bool created_const = false;
for (int j = 0; j < 2; ++j) {
int reduction_indices = reduce_dims[j].size();
Tensor value(type, TensorShape({reduction_indices}));
@@ -576,20 +606,17 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
AddControlDependency(node.name(), graph_, node_map_.get());
*out[j]->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
- created_const = true;
}
}
- if (created_const) {
- const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
- for (NodeDef* output : outputs) {
- for (int k = 0; k < output->input_size(); ++k) {
- int port;
- string node_name = ParseNodeName(output->input(k), &port);
- if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
- *output->mutable_input(k) = out[port]->name();
- node_map_->UpdateInput(output->name(), node_name, out[port]->name());
- }
+ const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
+ for (NodeDef* output : outputs) {
+ for (int k = 0; k < output->input_size(); ++k) {
+ int port;
+ string node_name = ParseNodeName(output->input(k), &port);
+ if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
+ *output->mutable_input(k) = out[port]->name();
+ node_map_->UpdateInput(output->name(), node_name, out[port]->name());
}
}
}
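The TensorArraySizeV3 rewrite above keeps execution ordering intact by demoting the op's data inputs to control inputs (a "^name" input in a NodeDef denotes a control dependency in GraphDef). A minimal sketch of that demotion (dependency names are hypothetical; the real code also copies the size constant's attrs onto the node):

  #include <string>
  #include "tensorflow/core/framework/node_def.pb.h"

  // Turn a size op into a Const while preserving ordering via control deps.
  void DemoteToConst(tensorflow::NodeDef* node, const std::string& handle_dep,
                     const std::string& flow_dep) {
    node->set_op("Const");
    node->set_input(0, "^" + handle_dep);  // control dep on the array handle
    node->set_input(1, "^" + flow_dep);    // control dep on the flow input
  }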
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 71ee81dfde..08c92687e3 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -2402,6 +2402,48 @@ TEST_F(ConstantFoldingTest, Enter) {
}
}
+TEST_F(ConstantFoldingTest, TensorArraySize) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+ Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({}));
+ auto dynamic_array =
+ ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT,
+ ops::TensorArray::DynamicSize(true));
+ auto static_array =
+ ops::TensorArray(scope.WithOpName("static"), size, DT_FLOAT,
+ ops::TensorArray::DynamicSize(false));
+ auto dynamic_sz = ops::TensorArraySize(
+ scope.WithOpName("dynamic_sz"), dynamic_array.handle, dynamic_array.flow);
+ auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"),
+ static_array.handle, static_array.flow);
+
+ GrapplerItem item;
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ auto tensors_expected =
+ EvaluateNodes(item.graph, {"dynamic_sz", "static_sz"});
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(5, output.node_size());
+ EXPECT_EQ("dynamic_sz", output.node(3).name());
+ EXPECT_EQ("TensorArraySizeV3", output.node(3).op());
+ EXPECT_EQ("static_sz", output.node(4).name());
+ EXPECT_EQ("Const", output.node(4).op());
+
+ auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"});
+ EXPECT_EQ(2, tensors_expected.size());
+ EXPECT_EQ(2, tensors_actual.size());
+ test::ExpectTensorEqual<int32>(tensors_expected[0], tensors_actual[0]);
+ test::ExpectTensorEqual<int32>(tensors_expected[1], tensors_actual[1]);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index f1da469a6c..343c89a9da 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -36,8 +36,11 @@ namespace {
class FunctionInliningContext {
public:
- explicit FunctionInliningContext(const GrapplerItem& item)
- : library_(&item.graph.library()), functions_(InliningCandidates(item)) {}
+ explicit FunctionInliningContext(const GrapplerItem& item,
+ RewriterConfig::Toggle opt_level)
+ : library_(&item.graph.library()),
+ opt_level_(opt_level),
+ functions_(InliningCandidates(item)) {}
const FunctionDefLibrary& Library() const { return *library_; }
@@ -59,13 +62,9 @@ class FunctionInliningContext {
std::unordered_map<string, const FunctionDef*> functions;
for (const FunctionDef& func : item.graph.library().function()) {
// Don't inline functions marked as noinline
- if (func.attr().count("_noinline") != 0) {
- continue;
- }
- // Don't touch anything marked XLA to prevent XLA failures further down
- // the road.
- if (func.attr().count("_XlaCompile") > 0 &&
- func.attr().at("_XlaCompile").b()) {
+ if (func.attr().count("_noinline") != 0 &&
+ func.attr().at("_noinline").b() &&
+ opt_level_ != RewriterConfig::AGGRESSIVE) {
continue;
}
// Can't create IdentityN nodes with no input or output: skip these
@@ -80,6 +79,7 @@ class FunctionInliningContext {
}
const FunctionDefLibrary* library_;
+ RewriterConfig::Toggle opt_level_;
std::unordered_map<string, const FunctionDef*> functions_;
TF_DISALLOW_COPY_AND_ASSIGN(FunctionInliningContext);
@@ -206,6 +206,10 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
TF_RETURN_IF_ERROR(InlineFunction(func_body_node, *func_body_node_func,
ctx, optimized_graph));
} else {
+ // Annotate the node with the function attributes.
+ for (const auto& attr : func.attr()) {
+ func_body_node.mutable_attr()->insert(attr);
+ }
// Move the node to the main graph
optimized_graph->add_node()->Swap(&func_body_node);
}
@@ -367,7 +371,7 @@ Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env,
Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
- FunctionInliningContext function_inlining_ctx(item);
+ FunctionInliningContext function_inlining_ctx(item, opt_level_);
// Nothing to do here.
if (!function_inlining_ctx.HasInlinedFunctions()) {
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.h b/tensorflow/core/grappler/optimizers/function_optimizer.h
index 41444e4673..b124efe01d 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.h
@@ -26,7 +26,7 @@ namespace grappler {
// operations to make the overall graph more efficient.
class FunctionOptimizer : public GraphOptimizer {
public:
- FunctionOptimizer(RewriterConfig::Toggle opt_level) {}
+ FunctionOptimizer(RewriterConfig::Toggle opt_level) : opt_level_(opt_level) {}
~FunctionOptimizer() override {}
string name() const override { return "function_optimizer"; };
@@ -36,6 +36,9 @@ class FunctionOptimizer : public GraphOptimizer {
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimized_graph, double result) override;
+
+ private:
+ RewriterConfig::Toggle opt_level_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
index c804d75756..fe26a56fc2 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
@@ -412,7 +412,7 @@ TEST_F(FunctionOptimizerTest, InlineFunctionWithNestedFunctionCall) {
{mul_func, square_func});
GraphDef output;
- FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ FunctionOptimizer optimizer(RewriterConfig::ON);
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
int count = 0;
@@ -508,7 +508,7 @@ TEST_F(FunctionOptimizerTest, SymbolicGradients) {
TF_EXPECT_OK(scope.ToGraphDef(&item.graph));
*item.graph.mutable_library()->add_function() = func;
- FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ FunctionOptimizer optimizer(RewriterConfig::ON);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -550,7 +550,7 @@ TEST_F(FunctionOptimizerTest, SymbolicGradientsIdentity) {
TF_EXPECT_OK(scope.ToGraphDef(&item.graph));
*item.graph.mutable_library()->add_function() = func;
- FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ FunctionOptimizer optimizer(RewriterConfig::ON);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -613,7 +613,7 @@ TEST_F(FunctionOptimizerTest, SymbolicGradientsNoInlineFunc) {
TF_EXPECT_OK(scope.ToGraphDef(&item.graph));
*item.graph.mutable_library()->add_function() = func;
- FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ FunctionOptimizer optimizer(RewriterConfig::ON);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
// The optimizer should succeed but the graphs should be the same.
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc
index 7044705ade..1ea57f7b4f 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc
@@ -42,6 +42,10 @@ Status GetInputNode(const GraphOptimizerContext& ctx, const string& input,
Status GetTensorProperties(const GraphOptimizerContext& ctx,
const string& tensor,
OpInfo::TensorProperties* properties) {
+ if (ctx.graph_properties == nullptr) {
+ return errors::InvalidArgument("Graph properties are unknown.");
+ }
+
int port;
string tensor_node_name = ParseNodeName(tensor, &port);
if (port < 0) {
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index d2a2cdd13d..1857d8d655 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2265,6 +2265,7 @@ tf_cc_tests(
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
@@ -5905,6 +5906,7 @@ tf_cc_test(
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@@ -6180,3 +6182,12 @@ cc_library(
"@gemmlowp",
],
)
+
+# Header-only version of cwise_lib for clients that want to use the cwise_ops
+# functionality in their own custom ops.
+cc_header_only_library(
+ name = "cwise_lib_hdrs",
+ deps = [
+ ":cwise_lib",
+ ],
+)
diff --git a/tensorflow/core/kernels/crop_and_resize_op_test.cc b/tensorflow/core/kernels/crop_and_resize_op_test.cc
index a35e1b0788..709082e799 100644
--- a/tensorflow/core/kernels/crop_and_resize_op_test.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -242,7 +243,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidInputShape) {
AddInputFromArray<int32>(TensorShape({2}), {4, 4});
Status s = RunOpKernel();
ASSERT_FALSE(s.ok());
- EXPECT_TRUE(StringPiece(s.ToString()).contains("input image must be 4-D"))
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), "input image must be 4-D"))
<< s;
}
@@ -255,7 +256,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
Status s = RunOpKernel();
ASSERT_FALSE(s.ok());
EXPECT_TRUE(
- StringPiece(s.ToString()).contains("box_index has incompatible shape"))
+ str_util::StrContains(s.ToString(), "box_index has incompatible shape"))
<< s;
}
@@ -267,8 +268,8 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
AddInputFromArray<int32>(TensorShape({2}), {3, 3});
Status s = RunOpKernel();
ASSERT_FALSE(s.ok());
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("box_index has values outside [0, batch_size)"))
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(), "box_index has values outside [0, batch_size)"))
<< s;
}
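This and the test files that follow make the same mechanical substitution: the StringPiece::contains member function gives way to the free function in tensorflow/core/lib/strings/str_util.h. In sketch form:

  // Before: member function on a temporary StringPiece.
  EXPECT_TRUE(StringPiece(s.ToString()).contains("some substring"));
  // After: free function taking the haystack and needle directly.
  EXPECT_TRUE(str_util::StrContains(s.ToString(), "some substring"));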
diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc
index ba9686e94e..07dc786d9b 100644
--- a/tensorflow/core/kernels/cudnn_rnn_ops.cc
+++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc
@@ -104,6 +104,7 @@ namespace {
using perftools::gputools::DeviceMemory;
using perftools::gputools::DeviceMemoryBase;
using perftools::gputools::ScratchAllocator;
+using perftools::gputools::dnn::AlgorithmConfig;
using perftools::gputools::dnn::RnnDirectionMode;
using perftools::gputools::dnn::RnnInputMode;
using perftools::gputools::dnn::RnnMode;
@@ -544,9 +545,10 @@ class CudnnRNNKernelCommon : public OpKernel {
auto* stream = context->op_device_context()->stream();
// ExtractCudnnRNNParamsInfo is only called by op_kernels that do not require
// a random number generator, therefore set state_allocator to nullptr.
+ const AlgorithmConfig algo_config;
auto rnn_desc_s = stream->parent()->createRnnDescriptor(
num_layers, num_units, input_size, input_mode, rnn_direction_mode(),
- rnn_mode(), ToDataType<T>::value, dropout(), seed(),
+ rnn_mode(), ToDataType<T>::value, algo_config, dropout(), seed(),
nullptr /* state_allocator */);
if (!rnn_desc_s.ok()) {
return FromExecutorStatus(rnn_desc_s);
@@ -891,22 +893,24 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
new CudnnRNNPersistentSpaceAllocator(context);
rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
+ const AlgorithmConfig algo_config;
auto rnn_desc_s = executor->createRnnDescriptor(
model_shapes.num_layers, model_shapes.num_units,
model_shapes.input_size, input_mode, rnn_direction_mode(),
- rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator);
+ rnn_mode(), data_type, algo_config, dropout(), seed(),
+ dropout_state_allocator);
OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
}
launch_status =
stream
- ->ThenRnnForward(*rnn_state.rnn_desc, *input_desc, input_data,
- *hidden_state_desc, input_h_data,
- *hidden_state_desc, input_c_data, params_data,
- *output_desc, &output_data, *hidden_state_desc,
- &output_h_data, *hidden_state_desc,
- &output_c_data, is_training_,
- &reserve_space_allocator, &workspace_allocator)
+ ->ThenRnnForward(
+ *rnn_state.rnn_desc, *input_desc, input_data,
+ *hidden_state_desc, input_h_data, *hidden_state_desc,
+ input_c_data, params_data, *output_desc, &output_data,
+ *hidden_state_desc, &output_h_data, *hidden_state_desc,
+ &output_c_data, is_training_, &reserve_space_allocator,
+ &workspace_allocator, /* output_result_profile */ nullptr)
.ok();
}
OP_REQUIRES(context, launch_status,
@@ -1095,25 +1099,27 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
new CudnnRNNPersistentSpaceAllocator(context);
rnn_state.dropout_state_allocator.reset(dropout_state_allocator);
+ const AlgorithmConfig algo_config;
auto rnn_desc_s = executor->createRnnDescriptor(
model_shapes.num_layers, model_shapes.num_units,
model_shapes.input_size, input_mode, rnn_direction_mode(),
- rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator);
+ rnn_mode(), data_type, algo_config, dropout(), seed(),
+ dropout_state_allocator);
OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s));
rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie());
}
launch_status =
stream
- ->ThenRnnBackward(*rnn_state.rnn_desc, *input_desc, input_data,
- *hidden_state_desc, input_h_data,
- *hidden_state_desc, input_c_data, params_data,
- *output_desc, output_data, *hidden_state_desc,
- output_h_data, *hidden_state_desc,
- output_c_data, output_backprop_data,
- output_h_backprop_data, output_c_backprop_data,
- &input_backprop_data, &input_h_backprop_data,
- &input_c_backprop_data, &params_backprop_data,
- &reserve_space_uint8, &workspace_allocator)
+ ->ThenRnnBackward(
+ *rnn_state.rnn_desc, *input_desc, input_data,
+ *hidden_state_desc, input_h_data, *hidden_state_desc,
+ input_c_data, params_data, *output_desc, output_data,
+ *hidden_state_desc, output_h_data, *hidden_state_desc,
+ output_c_data, output_backprop_data, output_h_backprop_data,
+ output_c_backprop_data, &input_backprop_data,
+ &input_h_backprop_data, &input_c_backprop_data,
+ &params_backprop_data, &reserve_space_uint8,
+ &workspace_allocator, /* output_result_profile */ nullptr)
.ok();
}
OP_REQUIRES(context, launch_status,
diff --git a/tensorflow/core/kernels/decode_image_op.cc b/tensorflow/core/kernels/decode_image_op.cc
index 912d04c153..2cafa44f37 100644
--- a/tensorflow/core/kernels/decode_image_op.cc
+++ b/tensorflow/core/kernels/decode_image_op.cc
@@ -41,9 +41,9 @@ enum FileFormat {
// Classify the contents of a file based on starting bytes (the magic number).
FileFormat ClassifyFileFormat(StringPiece data) {
// The 4th byte of JPEG is '\xe0' or '\xe1', so check just the first three
- if (data.starts_with("\xff\xd8\xff")) return kJpgFormat;
- if (data.starts_with("\x89PNG\r\n\x1a\n")) return kPngFormat;
- if (data.starts_with("\x47\x49\x46\x38")) return kGifFormat;
+ if (str_util::StartsWith(data, "\xff\xd8\xff")) return kJpgFormat;
+ if (str_util::StartsWith(data, "\x89PNG\r\n\x1a\n")) return kPngFormat;
+ if (str_util::StartsWith(data, "\x47\x49\x46\x38")) return kGifFormat;
return kUnknownFormat;
}
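The prefix checks migrate the same way: StringPiece::starts_with becomes str_util::StartsWith. A self-contained sketch of the magic-number pattern used above:

  #include "tensorflow/core/lib/core/stringpiece.h"
  #include "tensorflow/core/lib/strings/str_util.h"

  // Classify a buffer by its leading magic bytes (PNG shown here).
  bool LooksLikePng(tensorflow::StringPiece data) {
    return tensorflow::str_util::StartsWith(data, "\x89PNG\r\n\x1a\n");
  }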
diff --git a/tensorflow/core/kernels/dynamic_partition_op_test.cc b/tensorflow/core/kernels/dynamic_partition_op_test.cc
index 9a7ed0af21..17eb4e24b7 100644
--- a/tensorflow/core/kernels/dynamic_partition_op_test.cc
+++ b/tensorflow/core/kernels/dynamic_partition_op_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -153,8 +154,8 @@ TEST_F(DynamicPartitionOpTest, Error_IndexOutOfRange) {
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
AddInputFromArray<int32>(TensorShape({5}), {0, 2, 99, 2, 2});
Status s = RunOpKernel();
- EXPECT_TRUE(
- StringPiece(s.ToString()).contains("partitions[2] = 99 is not in [0, 4)"))
+ EXPECT_TRUE(str_util::StrContains(s.ToString(),
+ "partitions[2] = 99 is not in [0, 4)"))
<< s;
}
diff --git a/tensorflow/core/kernels/dynamic_stitch_op_test.cc b/tensorflow/core/kernels/dynamic_stitch_op_test.cc
index 6775893ce6..7fa6e320f5 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op_test.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -88,9 +89,9 @@ TEST_F(DynamicStitchOpTest, Error_IndicesMultiDimensional) {
AddInputFromArray<float>(TensorShape({3}), {0, 40, 70});
AddInputFromArray<float>(TensorShape({5}), {10, 60, 20, 30, 50});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("data[1].shape = [5] does not start with "
- "indices[1].shape = [1,5]"))
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(),
+ "data[1].shape = [5] does not start with indices[1].shape = [1,5]"))
<< s;
}
@@ -103,9 +104,9 @@ TEST_F(DynamicStitchOpTest, Error_DataNumDimsMismatch) {
AddInputFromArray<float>(TensorShape({3}), {0, 40, 70});
AddInputFromArray<float>(TensorShape({1, 5}), {10, 60, 20, 30, 50});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("data[1].shape = [1,5] does not start with "
- "indices[1].shape = [5]"))
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(),
+ "data[1].shape = [1,5] does not start with indices[1].shape = [5]"))
<< s;
}
@@ -119,9 +120,10 @@ TEST_F(DynamicStitchOpTest, Error_DataDimSizeMismatch) {
AddInputFromArray<float>(TensorShape({4, 2}),
{10, 11, 60, 61, 20, 21, 30, 31});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("Need data[0].shape[1:] = data[1].shape[1:], "
- "got data[0].shape = [3,1], data[1].shape = [4,2]"))
+ EXPECT_TRUE(
+ str_util::StrContains(s.ToString(),
+ "Need data[0].shape[1:] = data[1].shape[1:], got "
+ "data[0].shape = [3,1], data[1].shape = [4,2]"))
<< s;
}
@@ -134,10 +136,9 @@ TEST_F(DynamicStitchOpTest, Error_DataAndIndicesSizeMismatch) {
AddInputFromArray<float>(TensorShape({3}), {0, 40, 70});
AddInputFromArray<float>(TensorShape({4}), {10, 60, 20, 30});
Status s = RunOpKernel();
- EXPECT_TRUE(
- StringPiece(s.ToString())
- .contains(
- "data[1].shape = [4] does not start with indices[1].shape = [5]"))
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(),
+ "data[1].shape = [4] does not start with indices[1].shape = [5]"))
<< s;
}
diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc
index 3edcb34bca..0409cadb67 100644
--- a/tensorflow/core/kernels/gather_op_test.cc
+++ b/tensorflow/core/kernels/gather_op_test.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -171,7 +172,7 @@ TEST_F(GatherOpTest, Error_IndexOutOfRange) {
AddInputFromArray<int32>(TensorShape({}), {0});
Status s = RunOpKernel();
EXPECT_TRUE(
- StringPiece(s.ToString()).contains("indices[2] = 99 is not in [0, 5)"))
+ str_util::StrContains(s.ToString(), "indices[2] = 99 is not in [0, 5)"))
<< s;
}
diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc
index 67d9217b95..9387fb13bc 100644
--- a/tensorflow/core/kernels/non_max_suppression_op_test.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -147,7 +148,7 @@ TEST_F(NonMaxSuppressionOpTest, TestInconsistentBoxAndScoreShapes) {
ASSERT_FALSE(s.ok());
EXPECT_TRUE(
- StringPiece(s.ToString()).contains("scores has incompatible shape"))
+ str_util::StrContains(s.ToString(), "scores has incompatible shape"))
<< s;
}
@@ -160,7 +161,7 @@ TEST_F(NonMaxSuppressionOpTest, TestInvalidIOUThreshold) {
ASSERT_FALSE(s.ok());
EXPECT_TRUE(
- StringPiece(s.ToString()).contains("iou_threshold must be in [0, 1]"))
+ str_util::StrContains(s.ToString(), "iou_threshold must be in [0, 1]"))
<< s;
}
@@ -308,7 +309,7 @@ TEST_F(NonMaxSuppressionV2OpTest, TestInconsistentBoxAndScoreShapes) {
ASSERT_FALSE(s.ok());
EXPECT_TRUE(
- StringPiece(s.ToString()).contains("scores has incompatible shape"))
+ str_util::StrContains(s.ToString(), "scores has incompatible shape"))
<< s;
}
@@ -322,7 +323,7 @@ TEST_F(NonMaxSuppressionV2OpTest, TestInvalidIOUThreshold) {
ASSERT_FALSE(s.ok());
EXPECT_TRUE(
- StringPiece(s.ToString()).contains("iou_threshold must be in [0, 1]"))
+ str_util::StrContains(s.ToString(), "iou_threshold must be in [0, 1]"))
<< s;
}
diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
index 5ffcc7d65d..e41df12d91 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
@@ -379,8 +380,8 @@ TEST_F(QuantizeAndDequantizeTest, Invalid_range_given) {
AddInputFromArray<float>(TensorShape({}), {0.0}); // Max
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("Invalid range: input_min 1 > input_max 0"))
+ EXPECT_TRUE(str_util::StrContains(s.ToString(),
+ "Invalid range: input_min 1 > input_max 0"))
<< s;
}
@@ -401,8 +402,8 @@ TEST_F(QuantizeAndDequantizeTest, Invalid_range_given_V3) {
AddInputFromArray<int32>(TensorShape({}), {8}); // num_bits
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("Invalid range: input_min 1 > input_max 0"))
+ EXPECT_TRUE(str_util::StrContains(s.ToString(),
+ "Invalid range: input_min 1 > input_max 0"))
<< s;
}
diff --git a/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc
index d5b37b1ce1..9217c25978 100644
--- a/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
@@ -181,7 +182,7 @@ class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
int cluster_count = 0;
for (const NodeDef& node_def : output_graph_def_.node()) {
const string& name = node_def.name();
- if (StringPiece(name).starts_with(REMOTE_FUSED_GRAPH_NODE_NAME)) {
+ if (str_util::StartsWith(name, REMOTE_FUSED_GRAPH_NODE_NAME)) {
++cluster_count;
RemoteFusedGraphExecuteInfo info;
string serialized_proto;
diff --git a/tensorflow/core/kernels/resize_bicubic_op_test.cc b/tensorflow/core/kernels/resize_bicubic_op_test.cc
index 25a37d5e1a..c23570d885 100644
--- a/tensorflow/core/kernels/resize_bicubic_op_test.cc
+++ b/tensorflow/core/kernels/resize_bicubic_op_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -218,9 +219,8 @@ TEST_F(ResizeBicubicOpTest, TestBicubic2x2To0x0) {
AddInputFromArray<int32>(TensorShape({2}), {0, 0});
Status s = RunOpKernel();
- EXPECT_TRUE(
- StringPiece(s.ToString())
- .contains("Invalid argument: output dimensions must be positive"))
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(), "Invalid argument: output dimensions must be positive"))
<< s;
}
diff --git a/tensorflow/core/kernels/resize_bilinear_op_test.cc b/tensorflow/core/kernels/resize_bilinear_op_test.cc
index a920e60281..6d57892828 100644
--- a/tensorflow/core/kernels/resize_bilinear_op_test.cc
+++ b/tensorflow/core/kernels/resize_bilinear_op_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -457,9 +458,8 @@ TEST_F(ResizeBilinearOpTest, TestInvalidOutputSize) {
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<int32>(TensorShape({2}), {0, 0});
Status s = RunOpKernel();
- EXPECT_TRUE(
- StringPiece(s.ToString())
- .contains("Invalid argument: output dimensions must be positive"))
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(), "Invalid argument: output dimensions must be positive"))
<< s;
}
@@ -467,8 +467,8 @@ TEST_F(ResizeBilinearOpTest, TestInvalidInputShape) {
AddInputFromArray<float>(TensorShape({2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<int32>(TensorShape({2}), {4, 4});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("Invalid argument: input must be 4-dimensional"))
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(), "Invalid argument: input must be 4-dimensional"))
<< s;
}
@@ -476,8 +476,8 @@ TEST_F(ResizeBilinearOpTest, TestInvalidSizeDim) {
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<int32>(TensorShape({2, 1}), {4, 4});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("Invalid argument: shape_t must be 1-dimensional"))
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(), "Invalid argument: shape_t must be 1-dimensional"))
<< s;
}
@@ -485,8 +485,8 @@ TEST_F(ResizeBilinearOpTest, TestInvalidSizeElements) {
AddInputFromArray<float>(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4});
AddInputFromArray<int32>(TensorShape({3}), {4, 4, 1});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("Invalid argument: shape_t must have two elements"))
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(), "Invalid argument: shape_t must have two elements"))
<< s;
}
diff --git a/tensorflow/core/kernels/roll_op_test.cc b/tensorflow/core/kernels/roll_op_test.cc
index 90b6f8d0f3..e431226aa6 100644
--- a/tensorflow/core/kernels/roll_op_test.cc
+++ b/tensorflow/core/kernels/roll_op_test.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -372,7 +373,8 @@ TEST_F(RollOpTest, Error_InputMustBeVectorOrHigher) {
AddInputFromArray<int32>(TensorShape({}), {1});
AddInputFromArray<int32>(TensorShape({}), {0});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString()).contains("input must be 1-D or higher"))
+ EXPECT_TRUE(
+ str_util::StrContains(s.ToString(), "input must be 1-D or higher"))
<< s;
}
@@ -384,8 +386,8 @@ TEST_F(RollOpTest, Error_AxisMustBeScalarOrVector) {
AddInputFromArray<int32>(TensorShape({}), {1});
AddInputFromArray<int32>(TensorShape({1, 2}), {0, 1});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("axis must be a scalar or a 1-D vector"))
+ EXPECT_TRUE(str_util::StrContains(s.ToString(),
+ "axis must be a scalar or a 1-D vector"))
<< s;
}
@@ -397,8 +399,8 @@ TEST_F(RollOpTest, Error_ShiftMustBeScalarOrVector) {
AddInputFromArray<int32>(TensorShape({1, 2}), {0, 1});
AddInputFromArray<int32>(TensorShape({}), {1});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("shift must be a scalar or a 1-D vector"))
+ EXPECT_TRUE(str_util::StrContains(s.ToString(),
+ "shift must be a scalar or a 1-D vector"))
<< s;
}
@@ -410,8 +412,8 @@ TEST_F(RollOpTest, Error_ShiftAndAxisMustBeSameSize) {
AddInputFromArray<int32>(TensorShape({1}), {1});
AddInputFromArray<int32>(TensorShape({2}), {0, 1});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("shift and axis must have the same size"))
+ EXPECT_TRUE(str_util::StrContains(s.ToString(),
+ "shift and axis must have the same size"))
<< s;
}
@@ -423,7 +425,7 @@ TEST_F(RollOpTest, Error_AxisOutOfRange) {
AddInputFromArray<int32>(TensorShape({}), {1});
AddInputFromArray<int32>(TensorShape({}), {1});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString()).contains("is out of range")) << s;
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), "is out of range")) << s;
}
// isd - (inner shift dimension) The inner most dimension to be shifted.
diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc
index ae81efa31d..c134a8dd5b 100644
--- a/tensorflow/core/kernels/scatter_nd_op_test.cc
+++ b/tensorflow/core/kernels/scatter_nd_op_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -183,9 +184,8 @@ TEST_F(ScatterNdUpdateOpTest, Error_IndexOutOfRange) {
AddInputFromArray<float>(TensorShape({3, 3}),
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
Status s = RunOpKernel();
- EXPECT_TRUE(
- StringPiece(s.ToString())
- .contains("Invalid indices: [2,0] = [99] does not index into [5,3]"))
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(), "Invalid indices: [2,0] = [99] does not index into [5,3]"))
<< s;
}
@@ -198,10 +198,10 @@ TEST_F(ScatterNdUpdateOpTest, Error_WrongDimsIndices) {
AddInputFromArray<float>(TensorShape({3, 3}),
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("The outermost dimension of updates and indices "
- "must match. Got indices.shape [1,3,1], "
- "updates.shape [3,3]"))
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(),
+ "The outermost dimension of updates and indices must match. Got "
+ "indices.shape [1,3,1], updates.shape [3,3]"))
<< s;
}
@@ -216,10 +216,8 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
TensorShape({3, 4}),
{100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});
Status s = RunOpKernel();
- EXPECT_TRUE(
- StringPiece(s.ToString())
- .contains("Must have updates.shape = indices.shape[:batch_dim]"))
-
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(), "Must have updates.shape = indices.shape[:batch_dim]"))
<< s;
}
@@ -233,10 +231,9 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
AddInputFromArray<float>(TensorShape({2, 3}),
{100, 101, 102, 10000, 10001, 10002});
Status s = RunOpKernel();
- EXPECT_TRUE(
- StringPiece(s.ToString())
- .contains(
- "The outermost dimension of updates and indices must match."))
+ EXPECT_TRUE(str_util::StrContains(
+ s.ToString(),
+ "The outermost dimension of updates and indices must match."))
<< s;
}
diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc
index 5b3537b94c..2ec8c42233 100644
--- a/tensorflow/core/kernels/scatter_op_test.cc
+++ b/tensorflow/core/kernels/scatter_op_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -170,7 +171,7 @@ TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) {
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
Status s = RunOpKernel();
EXPECT_TRUE(
- StringPiece(s.ToString()).contains("indices[2] = 99 is not in [0, 5)"))
+ str_util::StrContains(s.ToString(), "indices[2] = 99 is not in [0, 5)"))
<< s;
}
@@ -183,8 +184,9 @@ TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) {
AddInputFromArray<float>(TensorShape({3, 3}),
{100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("Must have updates.shape = indices.shape + "
+ EXPECT_TRUE(
+ str_util::StrContains(s.ToString(),
+ "Must have updates.shape = indices.shape + "
"params.shape[1:] or updates.shape = [], got "))
<< s;
}
@@ -200,8 +202,9 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
TensorShape({3, 4}),
{100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("Must have updates.shape = indices.shape + "
+ EXPECT_TRUE(
+ str_util::StrContains(s.ToString(),
+ "Must have updates.shape = indices.shape + "
"params.shape[1:] or updates.shape = [], got "))
<< s;
@@ -217,8 +220,9 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
AddInputFromArray<float>(TensorShape({2, 3}),
{100, 101, 102, 10000, 10001, 10002});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString())
- .contains("Must have updates.shape = indices.shape + "
+ EXPECT_TRUE(
+ str_util::StrContains(s.ToString(),
+ "Must have updates.shape = indices.shape + "
"params.shape[1:] or updates.shape = [], got "))
<< s;
}
diff --git a/tensorflow/core/kernels/sdca_internal.cc b/tensorflow/core/kernels/sdca_internal.cc
index 5a389a6548..623de2a482 100644
--- a/tensorflow/core/kernels/sdca_internal.cc
+++ b/tensorflow/core/kernels/sdca_internal.cc
@@ -302,6 +302,11 @@ Status Examples::SampleAdaptiveProbabilities(
return Status::OK();
}
+void Examples::RandomShuffle() {
+ std::iota(sampled_index_.begin(), sampled_index_.end(), 0);
+ std::random_shuffle(sampled_index_.begin(), sampled_index_.end());
+}
+
// TODO(sibyl-Aix6ihai): Refactor/shorten this function.
Status Examples::Initialize(OpKernelContext* const context,
const ModelWeights& weights,
diff --git a/tensorflow/core/kernels/sdca_internal.h b/tensorflow/core/kernels/sdca_internal.h
index 1665b1210e..bfdb3febdc 100644
--- a/tensorflow/core/kernels/sdca_internal.h
+++ b/tensorflow/core/kernels/sdca_internal.h
@@ -322,10 +322,7 @@ class Examples {
return examples_.at(example_index);
}
- int sampled_index(const int id, const bool adaptive) const {
- if (adaptive) return sampled_index_[id];
- return id;
- }
+ int sampled_index(const int id) const { return sampled_index_[id]; }
// Adaptive SDCA in the current implementation only works for
// binary classification, where the input argument for num_weight_vectors
@@ -337,6 +334,8 @@ class Examples {
const std::unique_ptr<DualLossUpdater>& loss_updater,
const int num_weight_vectors);
+ void RandomShuffle();
+
int num_examples() const { return examples_.size(); }
int num_features() const { return num_features_; }
diff --git a/tensorflow/core/kernels/sdca_ops.cc b/tensorflow/core/kernels/sdca_ops.cc
index 5b63057f3f..55e68b348b 100644
--- a/tensorflow/core/kernels/sdca_ops.cc
+++ b/tensorflow/core/kernels/sdca_ops.cc
@@ -153,8 +153,9 @@ void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
options.num_loss_partitions, options.regularizations,
model_weights, example_state_data, options.loss_updater,
/*num_weight_vectors =*/1));
+ } else {
+ examples.RandomShuffle();
}
-
mutex mu;
Status train_step_status GUARDED_BY(mu);
std::atomic<std::int64_t> atomic_index(-1);
@@ -162,8 +163,7 @@ void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
// The static_cast here is safe since begin and end can be at most
// num_examples which is an int.
for (int id = static_cast<int>(begin); id < end; ++id) {
- const int64 example_index =
- examples.sampled_index(++atomic_index, options.adaptive);
+ const int64 example_index = examples.sampled_index(++atomic_index);
const Example& example = examples.example(example_index);
const float dual = example_state_data(example_index, 0);
const float example_weight = example.example_weight();
diff --git a/tensorflow/core/kernels/shape_op_test.cc b/tensorflow/core/kernels/shape_op_test.cc
index a545fb146c..9cd590ae61 100644
--- a/tensorflow/core/kernels/shape_op_test.cc
+++ b/tensorflow/core/kernels/shape_op_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -62,8 +63,8 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE");
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE",
GetShapeFromKnownVecSize);
-static void ExpectHasError(const Status& s, const string& substr) {
- EXPECT_TRUE(StringPiece(s.ToString()).contains(substr))
+static void ExpectHasError(const Status& s, StringPiece substr) {
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
<< ">>" << s << "<<, expected substring >>" << substr << "<<";
}
diff --git a/tensorflow/core/kernels/softmax_op.cc b/tensorflow/core/kernels/softmax_op.cc
index e1712ac239..e72608945b 100644
--- a/tensorflow/core/kernels/softmax_op.cc
+++ b/tensorflow/core/kernels/softmax_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
// See docs in ../ops/nn_ops.cc.
+#include "tensorflow/core/lib/strings/str_util.h"
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -55,7 +56,7 @@ template <typename Device, typename T>
class SoftmaxOp : public OpKernel {
public:
explicit SoftmaxOp(OpKernelConstruction* context) : OpKernel(context) {
- log_ = StringPiece(type_string()).starts_with("Log");
+ log_ = str_util::StartsWith(type_string(), "Log");
}
void Compute(OpKernelContext* context) override {
diff --git a/tensorflow/core/kernels/softmax_op_gpu.cu.cc b/tensorflow/core/kernels/softmax_op_gpu.cu.cc
index 130d693dbd..b63dcbb163 100644
--- a/tensorflow/core/kernels/softmax_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/softmax_op_gpu.cu.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/lib/strings/str_util.h"
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
@@ -128,7 +129,7 @@ template <typename T>
class SoftmaxOpGPU : public OpKernel {
public:
explicit SoftmaxOpGPU(OpKernelConstruction* context) : OpKernel(context) {
- log_ = StringPiece(type_string()).starts_with("Log");
+ log_ = str_util::StartsWith(type_string(), "Log");
}
void Compute(OpKernelContext* context) override {
diff --git a/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc b/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc
index fe198af7e6..29577ebb4e 100644
--- a/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc
+++ b/tensorflow/core/kernels/sparse_dense_binary_op_shared_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -32,7 +33,7 @@ namespace tensorflow {
namespace {
static void ExpectHasSubstr(StringPiece s, StringPiece expected) {
- EXPECT_TRUE(StringPiece(s).contains(expected))
+ EXPECT_TRUE(str_util::StrContains(s, expected))
<< "'" << s << "' does not contain '" << expected << "'";
}
diff --git a/tensorflow/core/kernels/summary_op_test.cc b/tensorflow/core/kernels/summary_op_test.cc
index 3c46abb8ab..9dcabcc584 100644
--- a/tensorflow/core/kernels/summary_op_test.cc
+++ b/tensorflow/core/kernels/summary_op_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -122,7 +123,7 @@ TEST_F(SummaryScalarOpTest, Error_MismatchedSize) {
AddInputFromArray<string>(TensorShape({2}), {"tag1", "tag2"});
AddInputFromArray<float>(TensorShape({3}), {1.0f, -0.73f, 10000.0f});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString()).contains("not the same shape")) << s;
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), "not the same shape")) << s;
}
TEST_F(SummaryScalarOpTest, Error_WrongDimsTags) {
@@ -133,7 +134,7 @@ TEST_F(SummaryScalarOpTest, Error_WrongDimsTags) {
AddInputFromArray<float>(TensorShape({2}), {1.0f, -0.73f});
Status s = RunOpKernel();
EXPECT_TRUE(
- StringPiece(s.ToString()).contains("tags and values not the same shape"))
+ str_util::StrContains(s.ToString(), "tags and values not the same shape"))
<< s;
}
@@ -145,7 +146,7 @@ TEST_F(SummaryScalarOpTest, Error_WrongDimsValues) {
AddInputFromArray<float>(TensorShape({2, 1}), {1.0f, -0.73f});
Status s = RunOpKernel();
EXPECT_TRUE(
- StringPiece(s.ToString()).contains("tags and values not the same shape"))
+ str_util::StrContains(s.ToString(), "tags and values not the same shape"))
<< s;
}
@@ -256,7 +257,7 @@ TEST_F(SummaryHistoOpTest, Error_WrongDimsTags) {
AddInputFromArray<string>(TensorShape({2, 1}), {"tag1", "tag2"});
AddInputFromArray<float>(TensorShape({2}), {1.0f, -0.73f});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString()).contains("tags must be scalar")) << s;
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), "tags must be scalar")) << s;
}
TEST_F(SummaryHistoOpTest, Error_TooManyTagValues) {
@@ -266,7 +267,7 @@ TEST_F(SummaryHistoOpTest, Error_TooManyTagValues) {
AddInputFromArray<string>(TensorShape({2}), {"tag1", "tag2"});
AddInputFromArray<float>(TensorShape({2, 1}), {1.0f, -0.73f});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString()).contains("tags must be scalar")) << s;
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), "tags must be scalar")) << s;
}
// --------------------------------------------------------------------------
@@ -365,7 +366,7 @@ TEST_F(SummaryMergeOpTest, Error_MismatchedSize) {
AddInputFromArray<string>(TensorShape({2}),
{s1.SerializeAsString(), s2.SerializeAsString()});
Status s = RunOpKernel();
- EXPECT_TRUE(StringPiece(s.ToString()).contains("Duplicate tag")) << s;
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), "Duplicate tag")) << s;
}
} // namespace
diff --git a/tensorflow/core/platform/macros.h b/tensorflow/core/platform/macros.h
index 6119edfd5a..1b1faed703 100644
--- a/tensorflow/core/platform/macros.h
+++ b/tensorflow/core/platform/macros.h
@@ -67,11 +67,18 @@ limitations under the License.
#define TF_EXPORT __attribute__((visibility("default")))
#endif // COMPILER_MSVC
-// GCC can be told that a certain branch is not likely to be taken (for
-// instance, a CHECK failure), and use that information in static analysis.
-// Giving it this information can help it optimize for the common case in
-// the absence of better information (ie. -fprofile-arcs).
-#if defined(COMPILER_GCC3)
+#ifdef __has_builtin
+#define TF_HAS_BUILTIN(x) __has_builtin(x)
+#else
+#define TF_HAS_BUILTIN(x) 0
+#endif
+
+// Compilers can be told that a certain branch is not likely to be taken
+// (for instance, a CHECK failure), and use that information in static
+// analysis. Giving them this information can help them optimize for the
+// common case in the absence of better information (i.e.
+// -fprofile-arcs).
+#if TF_HAS_BUILTIN(__builtin_expect) || (defined(__GNUC__) && __GNUC__ >= 3)
#define TF_PREDICT_FALSE(x) (__builtin_expect(x, 0))
#define TF_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1))
#else
diff --git a/tensorflow/core/profiler/g3doc/profile_model_architecture.md b/tensorflow/core/profiler/g3doc/profile_model_architecture.md
index 61bb66bd21..4ccd43ce68 100644
--- a/tensorflow/core/profiler/g3doc/profile_model_architecture.md
+++ b/tensorflow/core/profiler/g3doc/profile_model_architecture.md
@@ -45,22 +45,22 @@ sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
For an operation to have float operation statistics:
-* It must have `RegisterStatistics('flops')` defined in TensorFlow. tfprof
-use the definition to calculate float operations. Contributes are welcome.
-
-* It must have known "shape" information for RegisterStatistics('flops')
-to calculate the statistics. It is suggested to pass in `-run_meta_path` if
-shape is only known during runtime. tfprof can fill in the missing shape with
-the runtime shape information from RunMetadata.
-Hence, it is suggested to use `-account_displayed_op_only`
-option so that you know the statistics are only for the operations printed out.
-
-* If no RunMetadata provided, tfprof count float_ops of each graph node once,
-even if it is defined in tf.while_loop. This is because tfprof doesn't know
-how many times are run statically. If RunMetadata provided, tfprof calculate
-float_ops as float_ops * run_count.
-
-
+* It must have `RegisterStatistics('flops')` defined in TensorFlow. tfprof
+  uses the definition to calculate float operations. Contributions are
+  welcome.
+
+* It must have known "shape" information for RegisterStatistics('flops') to
+ calculate the statistics. It is suggested to pass in `-run_meta_path` if
+ shape is only known during runtime. tfprof can fill in the missing shape
+ with the runtime shape information from RunMetadata. Hence, it is suggested
+ to use `-account_displayed_op_only` option so that you know the statistics
+ are only for the operations printed out.
+
+* If no RunMetadata is provided, tfprof counts float_ops of each graph node
+ once, even if it is defined in a tf.while_loop. This is because tfprof
+ doesn't know statically how many times each graph node is run. If
+ RunMetadata is provided, tfprof calculates float_ops as float_ops *
+ run_count.
```python
# To profile float operations on the command line, you need to pass --graph_path
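The bullets above say tfprof can fill in runtime shapes from RunMetadata passed via `-run_meta_path`. A hedged sketch of producing such a file with the TF 1.x session API (`train_op` and the output path are illustrative placeholders):

```python
import tensorflow as tf

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_meta = tf.RunMetadata()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # `train_op` stands in for whatever op you want profiled.
    sess.run(train_op, options=run_options, run_metadata=run_meta)
with open('/tmp/run_meta', 'wb') as f:
    f.write(run_meta.SerializeToString())  # pass this path as -run_meta_path
```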
diff --git a/tensorflow/docs_src/mobile/tflite/devguide.md b/tensorflow/docs_src/mobile/tflite/devguide.md
index 96392a3c9b..4133bc172a 100644
--- a/tensorflow/docs_src/mobile/tflite/devguide.md
+++ b/tensorflow/docs_src/mobile/tflite/devguide.md
@@ -190,7 +190,7 @@ graph visualization.
## 3. Use the TensorFlow Lite model for inference in a mobile app
-After completing the prior steps, you should now have a .tflite model file.
+After completing the prior steps, you should now have a `.tflite` model file.
### Android
@@ -222,3 +222,10 @@ trained Tensorflow models to the
[CoreML](https://developer.apple.com/machine-learning/) format for use on Apple
devices. To use the converter, refer to the
[Tensorflow-CoreML converter documentation](https://github.com/tf-coreml/tf-coreml).
+
+### Raspberry Pi
+
+Compile TensorFlow Lite for a Raspberry Pi by following the
+[RPi build instructions](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/rpi.md).
+This compiles a static library file (`.a`) used to build your app. There are
+plans for Python bindings and a demo app.
diff --git a/tensorflow/docs_src/performance/leftnav_files b/tensorflow/docs_src/performance/leftnav_files
index d11a7e5d07..1f894c39fe 100644
--- a/tensorflow/docs_src/performance/leftnav_files
+++ b/tensorflow/docs_src/performance/leftnav_files
@@ -1,3 +1,4 @@
+index.md
performance_guide.md
datasets_performance.md
performance_models.md
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 99a71206ac..fcc191250f 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -870,15 +870,16 @@ def run_final_eval(sess, model_info, class_count, image_lists, jpeg_data_tensor,
resized_image_tensor: The input node of the recognition graph.
bottleneck_tensor: The bottleneck output layer of the CNN graph.
"""
- (sess, bottleneck_input, ground_truth_input, evaluation_step,
- prediction) = build_eval_session(model_info, class_count)
-
test_bottlenecks, test_ground_truth, test_filenames = (
get_random_cached_bottlenecks(sess, image_lists, FLAGS.test_batch_size,
'testing', FLAGS.bottleneck_dir,
FLAGS.image_dir, jpeg_data_tensor,
decoded_image_tensor, resized_image_tensor,
bottleneck_tensor, FLAGS.architecture))
+
+ (sess, bottleneck_input, ground_truth_input, evaluation_step,
+ prediction) = build_eval_session(model_info, class_count)
+
test_accuracy, predictions = sess.run(
[evaluation_step, prediction],
feed_dict={
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index b48d758e4a..b6481e7e29 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -629,15 +629,6 @@ void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status);
}
-std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
- TF_Output output,
- int num_dims,
- TF_Status* status) {
- std::vector<int64_t> dims(num_dims);
- TF_GraphGetTensorShape(graph, output, dims.data(), num_dims, status);
- return dims;
-}
-
std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
TF_ImportGraphDefResults* results) {
int num_missing_unused_input_mappings;
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index d2b4abc476..cfd27c2bee 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -229,13 +229,6 @@ void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
const std::vector<int64_t>& dims,
bool unknown_shape, TF_Status* status);
-// Return the shape of output. `num_dims` should be the output of
-// TF_GraphGetTensorNumDims. If `num_dims = -1`, this should not be called.
-std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
- TF_Output output,
- int num_dims,
- TF_Status* status);
-
// Returns the string representations of the missing unused input mappings.
std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
TF_ImportGraphDefResults* results);
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 7ad37058fd..3aad4a114a 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -217,10 +217,11 @@ class MicroBenchmarks(test.Benchmark):
self._run(f, 30000)
def benchmark_tf_gradient_function_identity(self):
- m = self._m_2
- self._run(
- lambda: backprop.gradients_function(gen_array_ops.identity, [0])(m),
- 30000)
+ with context.device(CPU):
+ m = gen_array_ops.identity(self._m_2)
+ self._run(
+ lambda: backprop.gradients_function(gen_array_ops.identity, [0])(m),
+ 30000)
def benchmark_tf_gradient_forward_identity(self):
with backprop.GradientTape() as tape:
@@ -236,10 +237,11 @@ class MicroBenchmarks(test.Benchmark):
self._run(f, 30000)
def benchmark_tf_gradient_function_no_op(self):
- m = self._m_2
- self._run(
- lambda: backprop.gradients_function(lambda x: x, [0])(m),
- 30000)
+ with context.device(CPU):
+ m = gen_array_ops.identity(self._m_2)
+ self._run(
+ lambda: backprop.gradients_function(lambda x: x, [0])(m),
+ 30000)
def _benchmark_np_matmul(self, m, transpose_b, num_iters):
a = m.cpu().numpy()
@@ -271,11 +273,12 @@ class MicroBenchmarks(test.Benchmark):
# pylint: disable=protected-access
ctx_handle = context.context()._handle
# pylint: enable=protected-access
+ device = context.context().device_name
attrs = ("transpose_a", False, "transpose_b", transpose_b, "T",
m.dtype.as_datatype_enum)
def func():
- pywrap_tensorflow.TFE_Py_Execute(ctx_handle, None, "MatMul", inputs,
- attrs, 1)
+ pywrap_tensorflow.TFE_Py_Execute(ctx_handle, device, "MatMul",
+ inputs, attrs, 1)
self._run(func, num_iters)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 711eddcec1..61859d6be3 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -294,7 +294,7 @@ class _EagerDefinedFunction(object):
self.signature = function_def.signature
self.grad_func_name = None
self.python_grad_func = None
- self._c_func = fn
+ self._c_func = c_api_util.ScopedTFFunction(fn)
self._grad_func = None
@@ -661,7 +661,7 @@ def _defun_internal(name, func, args, kwds):
if context.executing_eagerly():
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
- _register(f._c_func) # pylint: disable=protected-access
+ _register(f._c_func.func) # pylint: disable=protected-access
return GraphModeFunction(
fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
func_outputs, output_shapes, variables)
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index ee5d87f083..d40ea982c7 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -325,7 +325,7 @@ def _graph_callable_internal(func, shape_and_dtypes):
# Also, what about the gradient registry of these functions? Those need to be
# addressed as well.
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
- function._register(f._c_func) # pylint: disable=protected-access
+ function._register(f._c_func.func) # pylint: disable=protected-access
initializer_function = function.GraphModeFunction(
initialization_name,
placeholder_inputs,
diff --git a/tensorflow/python/framework/c_api_util.py b/tensorflow/python/framework/c_api_util.py
index 4356a534b4..7bbe3183df 100644
--- a/tensorflow/python/framework/c_api_util.py
+++ b/tensorflow/python/framework/c_api_util.py
@@ -63,6 +63,32 @@ class ScopedTFImportGraphDefOptions(object):
c_api.TF_DeleteImportGraphDefOptions(self.options)
+class ScopedTFImportGraphDefResults(object):
+ """Wrapper around TF_ImportGraphDefOptions that handles deletion."""
+
+ def __init__(self, results):
+ self.results = results
+
+ def __del__(self):
+    # Note: when we're destructing the global context (i.e. when the process
+    # is terminating), other modules may already have been deleted.
+ if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
+ c_api.TF_DeleteImportGraphDefResults(self.results)
+
+
+class ScopedTFFunction(object):
+ """Wrapper around TF_Function that handles deletion."""
+
+ def __init__(self, func):
+ self.func = func
+
+ def __del__(self):
+    # Note: when we're destructing the global context (i.e. when the process
+    # is terminating), other modules may already have been deleted.
+ if c_api is not None and c_api.TF_DeleteFunction is not None:
+ c_api.TF_DeleteFunction(self.func)
+
+
@tf_contextlib.contextmanager
def tf_buffer(data=None):
"""Context manager that creates and deletes TF_Buffer.
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index c5caf9ebc0..9570f009a5 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -274,7 +274,7 @@ class _DefinedFunction(object):
self._create_definition_if_needed()
if self._c_func:
with c_api_util.tf_buffer() as buf:
- c_api.TF_FunctionToFunctionDef(self._c_func, buf)
+ c_api.TF_FunctionToFunctionDef(self._c_func.func, buf)
fdef = function_pb2.FunctionDef()
proto_data = c_api.TF_GetBuffer(buf)
fdef.ParseFromString(compat.as_bytes(proto_data))
@@ -397,7 +397,7 @@ class _DefinedFunction(object):
if self._out_names else [])
description = self._func.__doc__ or None
# pylint: disable=protected-access
- self._c_func = c_api.TF_GraphToFunction_wrapper(
+ c_func = c_api.TF_GraphToFunction_wrapper(
temp_graph._c_graph,
base_func_name,
self._func_name is None, # append_hash_to_fn_name
@@ -407,6 +407,7 @@ class _DefinedFunction(object):
output_names,
None, # opts
description)
+ self._c_func = c_api_util.ScopedTFFunction(c_func)
# pylint: enable=protected-access
self._set_c_attrs(kwargs_attr)
@@ -429,7 +430,7 @@ class _DefinedFunction(object):
serialized = attr_value.SerializeToString()
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use the same status.
- c_api.TF_FunctionSetAttrValueProto(self._c_func, compat.as_str(name),
+ c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name),
serialized)
def _create_hash_str(self, input_arg, output_arg, node_def):
@@ -825,7 +826,8 @@ def _from_definition(fdef, grad_func=None):
# pylint: disable=protected-access
if ops._USE_C_API:
serialized = fdef.SerializeToString()
- result._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+ c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+ result._c_func = c_api_util.ScopedTFFunction(c_func)
result._extra_inputs = []
else:
result._definition = fdef
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 83d256fab6..c05396b06e 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -58,12 +58,32 @@ def _OptimizerOptions():
for cse in [False, True]:
for inline in [False, True]:
for cfold in [False, True]:
- yield config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
- optimizer_options=config_pb2.OptimizerOptions(
- opt_level=config_pb2.OptimizerOptions.L0,
- do_common_subexpression_elimination=cse,
- do_function_inlining=inline,
- do_constant_folding=cfold)))
+ cfg = config_pb2.ConfigProto(
+ graph_options=config_pb2.GraphOptions(
+ optimizer_options=config_pb2.OptimizerOptions(
+ opt_level=config_pb2.OptimizerOptions.L0,
+ do_common_subexpression_elimination=cse,
+ do_function_inlining=inline,
+ do_constant_folding=cfold)))
+ if cse:
+ cfg.graph_options.rewrite_options.arithmetic_optimization = (
+ rewriter_config_pb2.RewriterConfig.ON)
+ else:
+ cfg.graph_options.rewrite_options.arithmetic_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ if inline:
+ cfg.graph_options.rewrite_options.function_optimization = (
+ rewriter_config_pb2.RewriterConfig.ON)
+ else:
+ cfg.graph_options.rewrite_options.function_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ if cfold:
+ cfg.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.ON)
+ else:
+ cfg.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ yield cfg
@test_util.with_c_api
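The rewritten `_OptimizerOptions()` toggles three rewriter options with parallel if/else blocks. A hedged, behavior-equivalent sketch that folds the toggles into a helper (assumes the same `config_pb2`/`rewriter_config_pb2` protos the test already uses):

```python
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2


def _optimizer_options():
  def toggle(flag):
    return (rewriter_config_pb2.RewriterConfig.ON
            if flag else rewriter_config_pb2.RewriterConfig.OFF)

  for cse in [False, True]:
    for inline in [False, True]:
      for cfold in [False, True]:
        cfg = config_pb2.ConfigProto(
            graph_options=config_pb2.GraphOptions(
                optimizer_options=config_pb2.OptimizerOptions(
                    opt_level=config_pb2.OptimizerOptions.L0,
                    do_common_subexpression_elimination=cse,
                    do_function_inlining=inline,
                    do_constant_folding=cfold)))
        rewrite = cfg.graph_options.rewrite_options
        rewrite.arithmetic_optimization = toggle(cse)
        rewrite.function_optimization = toggle(inline)
        rewrite.constant_folding = toggle(cfold)
        yield cfg
```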
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index 23f529b988..3f8a8c4bef 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -487,6 +487,7 @@ def import_graph_def(graph_def,
try:
results = c_api.TF_GraphImportGraphDefWithResults(
graph._c_graph, serialized, options) # pylint: disable=protected-access
+ results = c_api_util.ScopedTFImportGraphDefResults(results)
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
@@ -515,7 +516,7 @@ def import_graph_def(graph_def,
# they are likely to be due to a typo.
missing_unused_input_keys = (
c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
- results))
+ results.results))
if missing_unused_input_keys:
missing_unused_input_keys = [
compat.as_str(s) for s in missing_unused_input_keys
@@ -527,7 +528,7 @@ def import_graph_def(graph_def,
if return_elements is None:
return None
else:
- return _GatherReturnElements(return_elements, graph, results)
+ return _GatherReturnElements(return_elements, graph, results.results)
else:
g = graph
@@ -684,11 +685,10 @@ def import_graph_def(graph_def,
', '.join(x.name for x in op._input_types))))
# pylint: enable=protected-access
- if not g._is_function(op.type): # pylint: disable=protected-access
- # Execute shape inference for this op.
- # NOTE(mrry): If the graph contains a cycle, the full shape
- # information may not be available for this op's inputs.
- ops.set_shapes_for_outputs(op)
+ # Execute shape inference for this op.
+ # NOTE(mrry): If the graph contains a cycle, the full shape
+ # information may not be available for this op's inputs.
+ ops.set_shape_and_handle_data_for_outputs(op)
# For nodes with _output_shapes set, set the output shapes.
if '_output_shapes' in op.node_def.attr:
for i, output in enumerate(op.outputs):
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 2d55f98a1c..2574fa57a4 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -289,15 +289,26 @@ class Tensor(_TensorLike):
self._op = op
self._value_index = value_index
self._dtype = dtypes.as_dtype(dtype)
- self._shape_val = tensor_shape.unknown_shape()
+
+ if _USE_C_API:
+ # This will be set by set_shape_and_handle_data_for_outputs.
+ self._shape_val = None
+ else:
+      # The Python code requires that all tensors start with a shape to
+      # support shape inference on imported while loops. This isn't necessary
+      # with the C API enabled because the C API provides the shapes for
+      # imported nodes.
+ # TODO(skyewm): remove when _USE_C_API is removed.
+ self._shape_val = tensor_shape.unknown_shape()
+
# List of operations that use this Tensor as input. We maintain this list
# to easily navigate a computation graph.
self._consumers = []
- # Attributes used for C++ shape inference. Not inspected, only forwarded.
- # If set, will be a HandleData object from cpp_shape_inference.proto.
- # TODO(b/74620627): remove when _USE_C_SHAPES is removed
- self._handle_data = None
+ if not _USE_C_SHAPES:
+ # Attributes used for C++ shape inference. Not inspected, only forwarded.
+ # If set, will be a HandleData object from cpp_shape_inference.proto.
+ self._handle_data = None
+
self._id = uid()
@property
@@ -371,18 +382,45 @@ class Tensor(_TensorLike):
A `TensorShape` representing the shape of this tensor.
"""
- graph = self._op._graph._c_graph # pylint: disable=protected-access
- if graph and _USE_C_SHAPES:
- num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output())
- if num_dims == -1:
- dim_list = None
+ if self._shape_val is None:
+ if _USE_C_SHAPES:
+ self._shape_val = self._c_api_shape()
else:
- dim_list = c_api.TF_GraphGetTensorShape_wrapper(
- graph, self._as_tf_output(), num_dims)
- dim_list = [None if i == -1 else i for i in dim_list]
- return tensor_shape.TensorShape(dim_list)
+ assert _USE_C_API
+ # Call set_shape_and_handle_data_for_outputs in topological order on all
+ # ops that are needed to compute self.op's shape. We do this instead of
+      # having set_shape_and_handle_data_for_outputs recursively call
+      # Operation.shape on self.op.inputs, which could overflow the call
+      # stack.
+ need_shapes = self._get_input_ops_without_shapes(self.op)
+ need_shapes.sort(key=lambda op: op._id)
+ for op in need_shapes:
+ set_shape_and_handle_data_for_outputs(op)
return self._shape_val
+ def _get_input_ops_without_shapes(self, target_op):
+ """Returns ops needing shape inference to compute target_op's shape."""
+ result = []
+ stack = [self._op]
+ visited = set()
+ while stack:
+ op = stack.pop()
+ if op in visited: continue
+ result.append(op)
+ stack.extend(t.op for t in op.inputs if t._shape_val is None)
+ visited.add(op)
+ return result
+
+ def _c_api_shape(self):
+ """Returns the TensorShape of this tensor according to the C API."""
+ c_graph = self._op._graph._c_graph # pylint: disable=protected-access
+ shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
+ c_graph, self._as_tf_output())
+ if unknown_shape:
+ return tensor_shape.unknown_shape()
+ else:
+ shape_vector = [None if d == -1 else d for d in shape_vector]
+ return tensor_shape.TensorShape(shape_vector)
+
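`Tensor.shape` now computes shapes lazily: `_get_input_ops_without_shapes` walks the input graph iteratively, and shape inference then runs in ascending `op._id` order (ids increase as ops are created, so this is a valid topological order for acyclic graphs). A generic sketch of that recursion-free strategy (`inputs_of`, `infer`, and the `id` attribute are stand-ins):

```python
def lazy_shape(op, shape_cache, inputs_of, infer):
    """Computes op's shape without recursing through deep graphs."""
    pending, stack, seen = [], [op], set()
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        pending.append(node)
        # Only descend into inputs whose shapes are still unknown.
        stack.extend(d for d in inputs_of(node) if d not in shape_cache)
    # Creation ids increase topologically, so inputs are inferred first.
    for node in sorted(pending, key=lambda n: n.id):
        shape_cache[node] = infer(node)
    return shape_cache[op]
```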
@property
def _shape(self):
logging.warning("Tensor._shape is private, use Tensor.shape "
@@ -466,8 +504,11 @@ class Tensor(_TensorLike):
ValueError: If `shape` is not compatible with the current shape of
this tensor.
"""
- if not _USE_C_SHAPES: # pylint: disable=protected-access
- self._shape_val = self._shape_val.merge_with(shape)
+ if _USE_C_SHAPES: # pylint: disable=protected-access
+ # Reset cached shape.
+ self._shape_val = None
+ else:
+ self._shape_val = self.shape.merge_with(shape)
if not self._op._graph._c_graph: return
@@ -579,6 +620,16 @@ class Tensor(_TensorLike):
# Necessary to support Python's collection membership operators
return id(self) == id(other)
+ def __copy__(self):
+ # Make sure _shape_val is computed before we copy.
+ # TODO(b/77597810): get rid of Tensor copies.
+ if self._shape_val is None:
+ set_shape_and_handle_data_for_outputs(self.op)
+ cls = self.__class__
+ result = cls.__new__(cls)
+ result.__dict__.update(self.__dict__)
+ return result
+
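`__copy__` above materializes `_shape_val` before the shallow copy so the clone never shares an unset cache. A minimal sketch of the same idiom outside TensorFlow (`expensive_compute` is a hypothetical helper):

```python
import copy


class LazyCached(object):
    def __init__(self):
        self._cached = None

    @property
    def value(self):
        if self._cached is None:
            self._cached = expensive_compute()  # hypothetical helper
        return self._cached

    def __copy__(self):
        _ = self.value  # force the lazy field before copying
        clone = self.__class__.__new__(self.__class__)
        clone.__dict__.update(self.__dict__)
        return clone
```

`copy.copy(obj)` invokes `__copy__`, so the clone always starts from a fully populated `__dict__`.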
# NOTE(mrry): This enables the Tensor's overloaded "right" binary
# operators to run when the left operand is an ndarray, because it
# accords the Tensor class higher priority than an ndarray, or a
@@ -1932,6 +1983,13 @@ class Operation(object):
if not isinstance(tensor, Tensor):
raise TypeError("tensor must be a Tensor: %s" % tensor)
_assert_same_graph(self, tensor)
+
+ # Make sure output shapes are already computed for this op in case we create
+ # a cycle (we cannot compute shapes for cycles). Usually shapes are computed
+ # lazily upon request.
+ if not _USE_C_SHAPES:
+ set_shape_and_handle_data_for_outputs(self)
+
if self._c_op:
# Reset cached inputs.
self._inputs_val = None
@@ -2474,35 +2532,41 @@ class RegisterShape(object):
return f
-def _set_shapes_for_outputs_c_api(op):
- """set_shapes_for_outputs implementation when C API is enabled."""
- # The C API computes the shapes when the TF_Operation is created. Fetch the
- # output shapes from the C object.
+# TODO(b/74620627): remove when _USE_C_SHAPES is removed
+def _set_shape_and_handle_data_for_outputs_c_api(op):
+ """Set shapes and resource handle data using info from the C API."""
+ assert not _USE_C_SHAPES
for output in op.outputs:
- # pylint: disable=protected-access
- shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
+ output._shape_val = output._c_api_shape()
+ # Set the resource handle data for compatibility with the Python shape
+ # inference code.
+ serialized = c_api.ResourceHandleShapeAndType(
op._graph._c_graph, output._as_tf_output())
- # pylint: enable=protected-access
- if unknown_shape:
- output.set_shape(tensor_shape.unknown_shape())
- elif not shape_vector:
- output.set_shape(tensor_shape.scalar())
- else:
- shape_vector = [None if d == -1 else d for d in shape_vector]
- output.set_shape(tensor_shape.TensorShape(shape_vector))
-
- serialized = c_api.ResourceHandleShapeAndType(op._graph._c_graph,
- output._as_tf_output())
if serialized:
output._handle_data = (
- cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
- compat.as_bytes(serialized)))
+ cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
+ .FromString(compat.as_bytes(serialized)))
else:
output._handle_data = None
-# TODO(skyewm): remove this when _USE_C_API flag is removed.
-def _set_shapes_for_outputs(op):
- """set_shapes_for_outputs implementation when C API is disabled."""
+
+# TODO(b/74620627): remove when _USE_C_SHAPES is removed
+def set_shape_and_handle_data_for_outputs(op):
+ """Set the shapes and resource handle data for op's outputs.
+
+ When _USE_C_API = True, this is lazily called when a tensor's shape is first
+ requested. Usually this should work automatically, but some edge cases may
+  require manually calling this first to make sure Tensor._shape_val and
+ Tensor._handle_data are set (e.g. manually overriding _handle_data, copying a
+ Tensor).
+ """
+ if _USE_C_SHAPES: return
+
+ if op.graph._is_function(op.type):
+ for output in op.outputs:
+ output._shape_val = tensor_shape.unknown_shape()
+ return
+
try:
shape_func = _shape_registry.lookup(op.type)
except LookupError:
@@ -2521,8 +2585,10 @@ def _set_shapes_for_outputs(op):
shapes = shapes_dict["shapes"]
handle_datas = shapes_dict["handle_data"]
for output, handle_data in zip(op.outputs, handle_datas):
+ # Don't override any existing handle data that may have been manually set.
# pylint: disable=protected-access
- output._handle_data = handle_data
+ if output._handle_data is None:
+ output._handle_data = handle_data
# pylint: enable=protected-access
if len(op.outputs) != len(shapes):
@@ -2530,15 +2596,8 @@ def _set_shapes_for_outputs(op):
"Shape function for op %s returned %d shapes but expected %d %s %s" %
(op, len(shapes), len(op.outputs), shape_func.__name__, str(shapes)))
for output, s in zip(op.outputs, shapes):
- output.set_shape(s)
-
-
-def set_shapes_for_outputs(op):
- """Set the shapes for op's outputs."""
- if op._c_op and _USE_C_SHAPES: # pylint: disable=protected-access
- return _set_shapes_for_outputs_c_api(op)
- else:
- return _set_shapes_for_outputs(op)
+ output._shape_val = tensor_shape.unknown_shape()
+ output._shape_val = output._shape_val.merge_with(s)
class OpStats(object):
@@ -3216,9 +3275,11 @@ class Graph(object):
# as this will be unnecessary.
if not function._c_func:
serialized = function.definition.SerializeToString()
- function._c_func = c_api.TF_FunctionImportFunctionDef(serialized)
- gradient = function._grad_func._c_func if function._grad_func else None
- c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient)
+ c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+ function._c_func = c_api_util.ScopedTFFunction(c_func)
+ gradient = (function._grad_func._c_func.func if function._grad_func
+ else None)
+ c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
else:
# If there is already a function with the same name, raise an error
# if bodies are different. Else, do nothing. The C API version above
@@ -3329,18 +3390,14 @@ class Graph(object):
original_op=self._default_original_op,
op_def=op_def)
- # TODO(vrv): Instead of eagerly filling in shape property for every op,
- # only populate the shape when requested.
+ # Note: shapes are lazily computed with the C API enabled.
#
# TODO(skyewm): unlike in the original Python implementation, the C API
# always computes shape information (even for function calls, which the
# original Python shape inference code doesn't handle). Deprecate the
# compute_shapes argument.
- #
- # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
- # is removed
- if (ret._c_op and _USE_C_SHAPES) or compute_shapes: # pylint: disable=protected-access
- set_shapes_for_outputs(ret)
+ if not _USE_C_API and compute_shapes:
+ set_shape_and_handle_data_for_outputs(ret)
self._create_op_helper(ret, compute_shapes=compute_shapes,
compute_device=compute_device)
@@ -3482,18 +3539,17 @@ class Graph(object):
for c_op in c_api_util.new_tf_operations(self)
]
+ # pylint: disable=protected-access
for op in new_ops:
# Operations created by the C API always retrieve shapes from the C API so
# we preserve the shapes of ops created in import_graph_def (from the
# "_output_shapes" attr of the imported NodeDef).
- # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
- # is removed.
- _set_shapes_for_outputs_c_api(op)
+ if not _USE_C_SHAPES:
+ _set_shape_and_handle_data_for_outputs_c_api(op)
new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
- # pylint: disable=protected-access
op._add_control_inputs(new_control_inputs)
op._control_flow_post_processing()
- # pylint: enable=protected-access
+ # pylint: enable=protected-access
return new_ops
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 984bcecdfe..64b0fa6c00 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -22,7 +22,6 @@ import six
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
-from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import compat
@@ -828,7 +827,7 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
Returns:
A `TensorShape` based on the constant value of the given `tensor`.
"""
- if context.executing_eagerly():
+ if isinstance(tensor, ops.EagerTensor):
return tensor_shape.as_shape(
[dim if dim != -1 else None for dim in tensor.numpy()])
diff --git a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
index ad96b53a45..12775fccec 100644
--- a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
@@ -84,11 +84,13 @@ from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import BatchNormalization
from tensorflow.python.keras._impl.keras.layers import Conv2D
+from tensorflow.python.keras._impl.keras.layers import DepthwiseConv2D
from tensorflow.python.keras._impl.keras.layers import Dropout
from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
from tensorflow.python.keras._impl.keras.layers import Input
from tensorflow.python.keras._impl.keras.layers import Reshape
+from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.utils import conv_utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
@@ -116,195 +118,6 @@ def preprocess_input(x):
return imagenet_utils.preprocess_input(x, mode='tf')
-class DepthwiseConv2D(Conv2D):
- """Depthwise separable 2D convolution.
-
- Depthwise Separable convolutions consists in performing
- just the first step in a depthwise spatial convolution
- (which acts on each input channel separately).
- The `depth_multiplier` argument controls how many
- output channels are generated per input channel in the depthwise step.
-
- Arguments:
- kernel_size: An integer or tuple/list of 2 integers, specifying the
- width and height of the 2D convolution window.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- strides: An integer or tuple/list of 2 integers,
- specifying the strides of the convolution along the width and height.
- Can be a single integer to specify the same value for
- all spatial dimensions.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: one of `'valid'` or `'same'` (case-insensitive).
- depth_multiplier: The number of depthwise convolution output channels
- for each input channel.
- The total number of depthwise convolution output
- channels will be equal to `filters_in * depth_multiplier`.
- data_format: A string,
- one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, height, width, channels)` while `channels_first`
- corresponds to inputs with shape
- `(batch, channels, height, width)`.
- It defaults to the `image_data_format` value found in your
- Keras config file at `~/.keras/keras.json`.
- If you never set it, then it will be 'channels_last'.
- activation: Activation function to use.
- If you don't specify anything, no activation is applied
- (ie. 'linear' activation: `a(x) = x`).
- use_bias: Boolean, whether the layer uses a bias vector.
- depthwise_initializer: Initializer for the depthwise kernel matrix.
- bias_initializer: Initializer for the bias vector.
- depthwise_regularizer: Regularizer function applied to
- the depthwise kernel matrix.
- bias_regularizer: Regularizer function applied to the bias vector.
- activity_regularizer: Regularizer function applied to
- the output of the layer (its 'activation')..
- depthwise_constraint: Constraint function applied to
- the depthwise kernel matrix.
- bias_constraint: Constraint function applied to the bias vector.
-
- Input shape:
- 4D tensor with shape:
- `[batch, channels, rows, cols]` if data_format='channels_first'
- or 4D tensor with shape:
- `[batch, rows, cols, channels]` if data_format='channels_last'.
-
- Output shape:
- 4D tensor with shape:
- `[batch, filters, new_rows, new_cols]` if data_format='channels_first'
- or 4D tensor with shape:
- `[batch, new_rows, new_cols, filters]` if data_format='channels_last'.
- `rows` and `cols` values might have changed due to padding.
- """
-
- def __init__(self,
- kernel_size,
- strides=(1, 1),
- padding='valid',
- depth_multiplier=1,
- data_format=None,
- activation=None,
- use_bias=True,
- depthwise_initializer='glorot_uniform',
- bias_initializer='zeros',
- depthwise_regularizer=None,
- bias_regularizer=None,
- activity_regularizer=None,
- depthwise_constraint=None,
- bias_constraint=None,
- **kwargs):
- super(DepthwiseConv2D, self).__init__(
- filters=None,
- kernel_size=kernel_size,
- strides=strides,
- padding=padding,
- data_format=data_format,
- activation=activation,
- use_bias=use_bias,
- bias_regularizer=bias_regularizer,
- activity_regularizer=activity_regularizer,
- bias_constraint=bias_constraint,
- **kwargs)
- self.depth_multiplier = depth_multiplier
- self.depthwise_initializer = initializers.get(depthwise_initializer)
- self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
- self.depthwise_constraint = constraints.get(depthwise_constraint)
- self.bias_initializer = initializers.get(bias_initializer)
-
- @shape_type_conversion
- def build(self, input_shape):
- if len(input_shape) < 4:
- raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. '
- 'Received input shape:', str(input_shape))
- if self.data_format == 'channels_first':
- channel_axis = 1
- else:
- channel_axis = 3
- if input_shape[channel_axis] is None:
- raise ValueError('The channel dimension of the inputs to '
- '`DepthwiseConv2D` '
- 'should be defined. Found `None`.')
- input_dim = int(input_shape[channel_axis])
- depthwise_kernel_shape = (self.kernel_size[0], self.kernel_size[1],
- input_dim, self.depth_multiplier)
-
- self.depthwise_kernel = self.add_weight(
- shape=depthwise_kernel_shape,
- initializer=self.depthwise_initializer,
- name='depthwise_kernel',
- regularizer=self.depthwise_regularizer,
- constraint=self.depthwise_constraint)
-
- if self.use_bias:
- self.bias = self.add_weight(
- shape=(input_dim * self.depth_multiplier,),
- initializer=self.bias_initializer,
- name='bias',
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint)
- else:
- self.bias = None
- # Set input spec.
- self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
- self.built = True
-
- def call(self, inputs, training=None):
- outputs = K.depthwise_conv2d(
- inputs,
- self.depthwise_kernel,
- strides=self.strides,
- padding=self.padding,
- dilation_rate=self.dilation_rate,
- data_format=self.data_format)
-
- if self.bias:
- outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
-
- if self.activation is not None:
- return self.activation(outputs)
-
- return outputs
-
- @shape_type_conversion
- def compute_output_shape(self, input_shape):
- if self.data_format == 'channels_first':
- rows = input_shape[2]
- cols = input_shape[3]
- out_filters = input_shape[1] * self.depth_multiplier
- elif self.data_format == 'channels_last':
- rows = input_shape[1]
- cols = input_shape[2]
- out_filters = input_shape[3] * self.depth_multiplier
-
- rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
- self.padding, self.strides[0])
- cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
- self.padding, self.strides[1])
-
- if self.data_format == 'channels_first':
- return (input_shape[0], out_filters, rows, cols)
- elif self.data_format == 'channels_last':
- return (input_shape[0], rows, cols, out_filters)
-
- def get_config(self):
- config = super(DepthwiseConv2D, self).get_config()
- config.pop('filters')
- config.pop('kernel_initializer')
- config.pop('kernel_regularizer')
- config.pop('kernel_constraint')
- config['depth_multiplier'] = self.depth_multiplier
- config['depthwise_initializer'] = initializers.serialize(
- self.depthwise_initializer)
- config['depthwise_regularizer'] = regularizers.serialize(
- self.depthwise_regularizer)
- config['depthwise_constraint'] = constraints.serialize(
- self.depthwise_constraint)
- return config
-
-
@tf_export('keras.applications.MobileNet',
'keras.applications.mobilenet.MobileNet')
def MobileNet(input_shape=None,
@@ -318,18 +131,11 @@ def MobileNet(input_shape=None,
classes=1000):
"""Instantiates the MobileNet architecture.
- Note that only TensorFlow is supported for now,
- therefore it only works with the data format
- `image_data_format='channels_last'` in your Keras config
- at `~/.keras/keras.json`.
-
To load a MobileNet model via `load_model`, import the custom
- objects `relu6` and `DepthwiseConv2D` and pass them to the
- `custom_objects` parameter.
+  object `relu6` and pass it to the `custom_objects` parameter.
E.g.
model = load_model('mobilenet.h5', custom_objects={
- 'relu6': mobilenet.relu6,
- 'DepthwiseConv2D': mobilenet.DepthwiseConv2D})
+ 'relu6': mobilenet.relu6})
Arguments:
input_shape: optional shape tuple, only to be specified
@@ -383,11 +189,6 @@ def MobileNet(input_shape=None,
backend that does not support separable convolutions.
"""
- if K.backend() != 'tensorflow':
- raise RuntimeError('Only TensorFlow backend is currently supported, '
- 'as other backends do not support '
- 'depthwise convolution.')
-
if not (weights in {'imagenet', None} or os.path.exists(weights)):
raise ValueError('The `weights` argument should be either '
'`None` (random initialization), `imagenet` '
@@ -522,7 +323,7 @@ def MobileNet(input_shape=None,
# load weights
if weights == 'imagenet':
if K.image_data_format() == 'channels_first':
- raise ValueError('Weights for "channels_last" format '
+ raise ValueError('Weights for "channels_first" format '
'are not available.')
if alpha == 1.0:
alpha_text = '1_0'
@@ -598,14 +399,14 @@ def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
"""
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
filters = int(filters * alpha)
+ x = ZeroPadding2D(padding=(1, 1), name='conv1_pad')(inputs)
x = Conv2D(
filters,
kernel,
- padding='same',
+ padding='valid',
use_bias=False,
strides=strides,
- name='conv1')(
- inputs)
+ name='conv1')(x)
x = BatchNormalization(axis=channel_axis, name='conv1_bn')(x)
return Activation(relu6, name='conv1_relu')(x)
@@ -665,15 +466,14 @@ def _depthwise_conv_block(inputs,
"""
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1
pointwise_conv_filters = int(pointwise_conv_filters * alpha)
-
+ x = ZeroPadding2D(padding=(1, 1), name='conv_pad_%d' % block_id)(inputs)
x = DepthwiseConv2D( # pylint: disable=not-callable
(3, 3),
- padding='same',
+ padding='valid',
depth_multiplier=depth_multiplier,
strides=strides,
use_bias=False,
- name='conv_dw_%d' % block_id)(
- inputs)
+ name='conv_dw_%d' % block_id)(x)
x = BatchNormalization(axis=channel_axis, name='conv_dw_%d_bn' % block_id)(x)
x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)
diff --git a/tensorflow/python/keras/_impl/keras/applications/resnet50.py b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
index 46c0e63557..f8c6aff4f2 100644
--- a/tensorflow/python/keras/_impl/keras/applications/resnet50.py
+++ b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
@@ -45,6 +45,7 @@ from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
from tensorflow.python.keras._impl.keras.layers import Input
from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import ZeroPadding2D
from tensorflow.python.keras._impl.keras.models import Model
from tensorflow.python.keras._impl.keras.utils import layer_utils
from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
@@ -236,9 +237,9 @@ def ResNet50(include_top=True,
else:
bn_axis = 1
+ x = ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
x = Conv2D(
- 64, (7, 7), strides=(2, 2), padding='same', name='conv1')(
- img_input)
+ 64, (7, 7), strides=(2, 2), padding='valid', name='conv1')(x)
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3), strides=(2, 2))(x)
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
index 162ae6c28f..7cdebc6aa4 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
@@ -27,6 +27,7 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
# imports for backwards namespace compatibility
# pylint: disable=unused-import
from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D
@@ -1024,6 +1025,200 @@ class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer):
return dict(list(base_config.items()) + list(config.items()))
+@tf_export('keras.layers.DepthwiseConv2D')
+class DepthwiseConv2D(Conv2D):
+ """Depthwise separable 2D convolution.
+
+  Depthwise separable convolutions consist of performing
+  just the first step of a depthwise spatial convolution
+ (which acts on each input channel separately).
+ The `depth_multiplier` argument controls how many
+ output channels are generated per input channel in the depthwise step.
+
+ Arguments:
+ kernel_size: An integer or tuple/list of 2 integers, specifying the
+ width and height of the 2D convolution window.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ strides: An integer or tuple/list of 2 integers,
+ specifying the strides of the convolution along the width and height.
+ Can be a single integer to specify the same value for
+ all spatial dimensions.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: one of `'valid'` or `'same'` (case-insensitive).
+ depth_multiplier: The number of depthwise convolution output channels
+ for each input channel.
+ The total number of depthwise convolution output
+ channels will be equal to `filters_in * depth_multiplier`.
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, height, width, channels)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, channels, height, width)`.
+ It defaults to the `image_data_format` value found in your
+ Keras config file at `~/.keras/keras.json`.
+ If you never set it, then it will be 'channels_last'.
+ activation: Activation function to use.
+ If you don't specify anything, no activation is applied
+ (ie. 'linear' activation: `a(x) = x`).
+ use_bias: Boolean, whether the layer uses a bias vector.
+ depthwise_initializer: Initializer for the depthwise kernel matrix.
+ bias_initializer: Initializer for the bias vector.
+ depthwise_regularizer: Regularizer function applied to
+ the depthwise kernel matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
+ activity_regularizer: Regularizer function applied to
+ the output of the layer (its 'activation').
+ depthwise_constraint: Constraint function applied to
+ the depthwise kernel matrix.
+ bias_constraint: Constraint function applied to the bias vector.
+
+ Input shape:
+ 4D tensor with shape:
+ `[batch, channels, rows, cols]` if data_format='channels_first'
+ or 4D tensor with shape:
+ `[batch, rows, cols, channels]` if data_format='channels_last'.
+
+ Output shape:
+ 4D tensor with shape:
+ `[batch, filters, new_rows, new_cols]` if data_format='channels_first'
+ or 4D tensor with shape:
+ `[batch, new_rows, new_cols, filters]` if data_format='channels_last'.
+ `rows` and `cols` values might have changed due to padding.
+ """
+
+ def __init__(self,
+ kernel_size,
+ strides=(1, 1),
+ padding='valid',
+ depth_multiplier=1,
+ data_format=None,
+ activation=None,
+ use_bias=True,
+ depthwise_initializer='glorot_uniform',
+ bias_initializer='zeros',
+ depthwise_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ depthwise_constraint=None,
+ bias_constraint=None,
+ **kwargs):
+ super(DepthwiseConv2D, self).__init__(
+ filters=None,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ activation=activation,
+ use_bias=use_bias,
+ bias_regularizer=bias_regularizer,
+ activity_regularizer=activity_regularizer,
+ bias_constraint=bias_constraint,
+ **kwargs)
+ self.depth_multiplier = depth_multiplier
+ self.depthwise_initializer = initializers.get(depthwise_initializer)
+ self.depthwise_regularizer = regularizers.get(depthwise_regularizer)
+ self.depthwise_constraint = constraints.get(depthwise_constraint)
+ self.bias_initializer = initializers.get(bias_initializer)
+
+ def build(self, input_shape):
+ if len(input_shape) < 4:
+ raise ValueError('Inputs to `DepthwiseConv2D` should have rank 4. '
+ 'Received input shape:', str(input_shape))
+ if self.data_format == 'channels_first':
+ channel_axis = 1
+ else:
+ channel_axis = 3
+ if input_shape[channel_axis] is None:
+ raise ValueError('The channel dimension of the inputs to '
+ '`DepthwiseConv2D` '
+ 'should be defined. Found `None`.')
+ input_dim = int(input_shape[channel_axis])
+ depthwise_kernel_shape = (self.kernel_size[0],
+ self.kernel_size[1],
+ input_dim,
+ self.depth_multiplier)
+
+ self.depthwise_kernel = self.add_weight(
+ shape=depthwise_kernel_shape,
+ initializer=self.depthwise_initializer,
+ name='depthwise_kernel',
+ regularizer=self.depthwise_regularizer,
+ constraint=self.depthwise_constraint)
+
+ if self.use_bias:
+ self.bias = self.add_weight(shape=(input_dim * self.depth_multiplier,),
+ initializer=self.bias_initializer,
+ name='bias',
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint)
+ else:
+ self.bias = None
+ # Set input spec.
+ self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
+ self.built = True
+
+ def call(self, inputs, training=None):
+ outputs = K.depthwise_conv2d(
+ inputs,
+ self.depthwise_kernel,
+ strides=self.strides,
+ padding=self.padding,
+ dilation_rate=self.dilation_rate,
+ data_format=self.data_format)
+
+ if self.bias:
+ outputs = K.bias_add(
+ outputs,
+ self.bias,
+ data_format=self.data_format)
+
+ if self.activation is not None:
+ return self.activation(outputs)
+
+ return outputs
+
+ @shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ if self.data_format == 'channels_first':
+ rows = input_shape[2]
+ cols = input_shape[3]
+ out_filters = input_shape[1] * self.depth_multiplier
+ elif self.data_format == 'channels_last':
+ rows = input_shape[1]
+ cols = input_shape[2]
+ out_filters = input_shape[3] * self.depth_multiplier
+
+ rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
+ self.padding,
+ self.strides[0])
+ cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
+ self.padding,
+ self.strides[1])
+ if self.data_format == 'channels_first':
+ return (input_shape[0], out_filters, rows, cols)
+ elif self.data_format == 'channels_last':
+ return (input_shape[0], rows, cols, out_filters)
+
+ def get_config(self):
+ config = super(DepthwiseConv2D, self).get_config()
+ config.pop('filters')
+ config.pop('kernel_initializer')
+ config.pop('kernel_regularizer')
+ config.pop('kernel_constraint')
+ config['depth_multiplier'] = self.depth_multiplier
+ config['depthwise_initializer'] = initializers.serialize(
+ self.depthwise_initializer)
+ config['depthwise_regularizer'] = regularizers.serialize(
+ self.depthwise_regularizer)
+ config['depthwise_constraint'] = constraints.serialize(
+ self.depthwise_constraint)
+ return config
+
+
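For orientation, a minimal usage sketch of the new layer (assuming it is exported as `tf.keras.layers.DepthwiseConv2D`, consistent with the `tf_export` decorators elsewhere in this file; shapes follow the docstring above):

import numpy as np
import tensorflow as tf

# channels_last input: (batch, rows, cols, channels)
x = np.random.rand(2, 7, 6, 3).astype('float32')

model = tf.keras.models.Sequential([
    tf.keras.layers.DepthwiseConv2D(kernel_size=(3, 3),
                                    depth_multiplier=2,
                                    padding='valid',
                                    input_shape=(7, 6, 3)),
])
# Output channels = channels * depth_multiplier = 3 * 2 = 6, and
# rows/cols shrink by kernel_size - 1 under 'valid' padding.
print(model.predict(x).shape)  # (2, 5, 4, 6)
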
@tf_export('keras.layers.UpSampling1D')
class UpSampling1D(Layer):
"""Upsampling layer for 1D inputs.
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
index b78962d66a..6b2a1d98fe 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+# pylint: disable=protected-access
"""Convolutional-recurrent layers.
"""
from __future__ import absolute_import
@@ -26,181 +27,456 @@ from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine import Layer
from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
-from tensorflow.python.keras._impl.keras.layers.recurrent import Recurrent
+from tensorflow.python.keras._impl.keras.layers.recurrent import _generate_dropout_mask
+from tensorflow.python.keras._impl.keras.layers.recurrent import RNN
from tensorflow.python.keras._impl.keras.utils import conv_utils
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
+from tensorflow.python.keras._impl.keras.utils import generic_utils
from tensorflow.python.util.tf_export import tf_export
-class ConvRecurrent2D(Recurrent):
- """Abstract base class for convolutional recurrent layers.
-
- Do not use in a model -- it's not a functional layer!
+class ConvRNN2D(RNN):
+ """Base class for convolutional-recurrent layers.
Arguments:
- filters: Integer, the dimensionality of the output space
- (i.e. the number of output filters in the convolution).
- kernel_size: An integer or tuple/list of n integers, specifying the
- dimensions of the convolution window.
- strides: An integer or tuple/list of n integers,
- specifying the strides of the convolution.
- Specifying any stride value != 1 is incompatible with specifying
- any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
- data_format: A string,
- one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, time, ..., channels)`
- while `channels_first` corresponds to
- inputs with shape `(batch, time, channels, ...)`.
- It defaults to the `image_data_format` value found in your
- Keras config file at `~/.keras/keras.json`.
- If you never set it, then it will be "channels_last".
- dilation_rate: An integer or tuple/list of n integers, specifying
- the dilation rate to use for dilated convolution.
- Currently, specifying any `dilation_rate` value != 1 is
- incompatible with specifying any `strides` value != 1.
- return_sequences: Boolean. Whether to return the last output
- in the output sequence, or the full sequence.
- go_backwards: Boolean (default False).
- If True, rocess the input sequence backwards.
- stateful: Boolean (default False). If True, the last state
- for each sample at index i in a batch will be used as initial
- state for the sample of index i in the following batch.
+ cell: A RNN cell instance. A RNN cell is a class that has:
+ - a `call(input_at_t, states_at_t)` method, returning
+ `(output_at_t, states_at_t_plus_1)`. The call method of the
+ cell can also take the optional argument `constants`, see
+ section "Note on passing external constants" below.
+ - a `state_size` attribute. This can be a single integer
+ (single state) in which case it is
+ the number of channels of the recurrent state
+ (which should be the same as the number of channels of the cell
+ output). This can also be a list/tuple of integers
+ (one size per state). In this case, the first entry
+ (`state_size[0]`) should be the same as
+ the size of the cell output.
+ return_sequences: Boolean. Whether to return the last output
+ in the output sequence, or the full sequence.
+ return_state: Boolean. Whether to return the last state
+ in addition to the output.
+ go_backwards: Boolean (default False).
+ If True, process the input sequence backwards and return the
+ reversed sequence.
+ stateful: Boolean (default False). If True, the last state
+ for each sample at index i in a batch will be used as initial
+ state for the sample of index i in the following batch.
+ input_shape: Use this argument to specify the shape of the
+ input when this layer is the first one in a model.
Input shape:
- 5D tensor with shape `(num_samples, timesteps, channels, rows, cols)`.
+ 5D tensor with shape:
+ `(samples, timesteps, channels, rows, cols)`
+ if data_format='channels_first' or 5D tensor with shape:
+ `(samples, timesteps, rows, cols, channels)`
+ if data_format='channels_last'.
Output shape:
- - if `return_sequences`: 5D tensor with shape
- `(num_samples, timesteps, channels, rows, cols)`.
- - else, 4D tensor with shape `(num_samples, channels, rows, cols)`.
-
- # Masking
- This layer supports masking for input data with a variable number
- of timesteps. To introduce masks to your data,
- use an `Embedding` layer with the `mask_zero` parameter
- set to `True`.
- **Note:** for the time being, masking is only supported with Theano.
-
- # Note on using statefulness in RNNs
- You can set RNN layers to be 'stateful', which means that the states
- computed for the samples in one batch will be reused as initial states
- for the samples in the next batch.
- This assumes a one-to-one mapping between
- samples in different successive batches.
-
- To enable statefulness:
- - specify `stateful=True` in the layer constructor.
- - specify a fixed batch size for your model, by passing
- a `batch_input_size=(...)` to the first layer in your model.
- This is the expected shape of your inputs *including the batch
- size*.
- It should be a tuple of integers, e.g. `(32, 10, 100)`.
-
- To reset the states of your model, call `.reset_states()` on either
- a specific layer, or on your entire model.
+ - if `return_state`: a list of tensors. The first tensor is
+ the output. The remaining tensors are the last states,
+ each 4D tensor with shape:
+ `(samples, filters, new_rows, new_cols)`
+ if data_format='channels_first'
+ or 4D tensor with shape:
+ `(samples, new_rows, new_cols, filters)`
+ if data_format='channels_last'.
+ `rows` and `cols` values might have changed due to padding.
+ - if `return_sequences`: 5D tensor with shape:
+ `(samples, timesteps, filters, new_rows, new_cols)`
+ if data_format='channels_first'
+ or 5D tensor with shape:
+ `(samples, timesteps, new_rows, new_cols, filters)`
+ if data_format='channels_last'.
+ - else, 4D tensor with shape:
+ `(samples, filters, new_rows, new_cols)`
+ if data_format='channels_first'
+ or 4D tensor with shape:
+ `(samples, new_rows, new_cols, filters)`
+ if data_format='channels_last'.
+
+ Masking:
+ This layer supports masking for input data with a variable number
+ of timesteps. To introduce masks to your data,
+ use an Embedding layer with the `mask_zero` parameter
+ set to `True`.
+
+ Note on using statefulness in RNNs:
+ You can set RNN layers to be 'stateful', which means that the states
+ computed for the samples in one batch will be reused as initial states
+ for the samples in the next batch. This assumes a one-to-one mapping
+ between samples in different successive batches.
+ To enable statefulness:
+ - specify `stateful=True` in the layer constructor.
+ - specify a fixed batch size for your model, by passing
+ - if sequential model:
+ `batch_input_shape=(...)` to the first layer in your model.
+ - if functional model with 1 or more Input layers:
+ `batch_shape=(...)` to all the first layers in your model.
+ This is the expected shape of your inputs
+ *including the batch size*.
+ It should be a tuple of integers,
+ e.g. `(32, 10, 100, 100, 32)`.
+ Note that the number of rows and columns should be specified
+ too.
+ - specify `shuffle=False` when calling fit().
+ To reset the states of your model, call `.reset_states()` on either
+ a specific layer, or on your entire model.
+
+ Note on specifying the initial state of RNNs:
+ You can specify the initial state of RNN layers symbolically by
+ calling them with the keyword argument `initial_state`. The value of
+ `initial_state` should be a tensor or list of tensors representing
+ the initial state of the RNN layer.
+ You can specify the initial state of RNN layers numerically by
+ calling `reset_states` with the keyword argument `states`. The value of
+ `states` should be a numpy array or list of numpy arrays representing
+ the initial state of the RNN layer.
+
+ Note on passing external constants to RNNs:
+ You can pass "external" constants to the cell using the `constants`
+ keyword argument of `RNN.__call__` (as well as `RNN.call`). This
+ requires that the `cell.call` method accepts the same keyword argument
+ `constants`. Such constants can be used to condition the cell
+ transformation on additional static inputs (not changing over time),
+ a.k.a. an attention mechanism.
"""
def __init__(self,
- filters,
- kernel_size,
- strides=(1, 1),
- padding='valid',
- data_format=None,
- dilation_rate=(1, 1),
+ cell,
return_sequences=False,
+ return_state=False,
go_backwards=False,
stateful=False,
+ unroll=False,
**kwargs):
- super(ConvRecurrent2D, self).__init__(**kwargs)
- self.filters = filters
- self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
- self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
- self.padding = conv_utils.normalize_padding(padding)
- self.data_format = conv_utils.normalize_data_format(data_format)
- self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2,
- 'dilation_rate')
- self.return_sequences = return_sequences
- self.go_backwards = go_backwards
- self.stateful = stateful
+ if unroll:
+ raise TypeError('Unrolling isn\'t possible with '
+ 'convolutional RNNs.')
+ if isinstance(cell, (list, tuple)):
+ # StackedConvRNN2DCells is not implemented yet.
+ raise TypeError('It is not possible at the moment to '
+ 'stack convolutional cells.')
+ super(ConvRNN2D, self).__init__(cell,
+ return_sequences,
+ return_state,
+ go_backwards,
+ stateful,
+ unroll,
+ **kwargs)
self.input_spec = [InputSpec(ndim=5)]
- self.state_spec = None
+ self.states = None
@shape_type_conversion
def compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
- if self.data_format == 'channels_first':
+
+ cell = self.cell
+ if cell.data_format == 'channels_first':
rows = input_shape[3]
cols = input_shape[4]
- elif self.data_format == 'channels_last':
+ elif cell.data_format == 'channels_last':
rows = input_shape[2]
cols = input_shape[3]
- rows = conv_utils.conv_output_length(
- rows,
- self.kernel_size[0],
- padding=self.padding,
- stride=self.strides[0],
- dilation=self.dilation_rate[0])
- cols = conv_utils.conv_output_length(
- cols,
- self.kernel_size[1],
- padding=self.padding,
- stride=self.strides[1],
- dilation=self.dilation_rate[1])
+ rows = conv_utils.conv_output_length(rows,
+ cell.kernel_size[0],
+ padding=cell.padding,
+ stride=cell.strides[0],
+ dilation=cell.dilation_rate[0])
+ cols = conv_utils.conv_output_length(cols,
+ cell.kernel_size[1],
+ padding=cell.padding,
+ stride=cell.strides[1],
+ dilation=cell.dilation_rate[1])
+
+ if cell.data_format == 'channels_first':
+ output_shape = input_shape[:2] + (cell.filters, rows, cols)
+ elif cell.data_format == 'channels_last':
+ output_shape = input_shape[:2] + (rows, cols, cell.filters)
+
+ if not self.return_sequences:
+ output_shape = output_shape[:1] + output_shape[2:]
+
+ if self.return_state:
+ output_shape = [output_shape]
+ if cell.data_format == 'channels_first':
+ output_shape += [(input_shape[0], cell.filters, rows, cols)
+ for _ in range(2)]
+ elif cell.data_format == 'channels_last':
+ output_shape += [(input_shape[0], rows, cols, cell.filters)
+ for _ in range(2)]
+ return output_shape
+
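The spatial arithmetic above is delegated to `conv_utils.conv_output_length`; in closed form, with kernel size $k$, stride $s$ and dilation $d$ (a restatement for reference, not part of the change):

\[
\text{new\_rows} =
\begin{cases}
\lceil \text{rows} / s \rceil & \text{if padding = 'same'} \\
\lceil (\text{rows} - d(k-1)) / s \rceil & \text{if padding = 'valid'}
\end{cases}
\]

The same formula applies to `cols` with `kernel_size[1]` and `strides[1]`.
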
+ @shape_type_conversion
+ def build(self, input_shape):
+ # Note input_shape will be list of shapes of initial states and
+ # constants if these are passed in __call__.
+ if self._num_constants is not None:
+ constants_shape = input_shape[-self._num_constants:]
+ else:
+ constants_shape = None
+
+ if isinstance(input_shape, list):
+ input_shape = input_shape[0]
+
+ batch_size = input_shape[0] if self.stateful else None
+ self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:5])
+
+ # allow cell (if layer) to build before we set or validate state_spec
+ if isinstance(self.cell, Layer):
+ step_input_shape = (input_shape[0],) + input_shape[2:]
+ if constants_shape is not None:
+ self.cell.build([step_input_shape] + constants_shape)
+ else:
+ self.cell.build(step_input_shape)
+
+ # set or validate state_spec
+ if hasattr(self.cell.state_size, '__len__'):
+ state_size = list(self.cell.state_size)
+ else:
+ state_size = [self.cell.state_size]
+
+ if self.state_spec is not None:
+ # initial_state was passed in call, check compatibility
+ if self.cell.data_format == 'channels_first':
+ ch_dim = 1
+ elif self.cell.data_format == 'channels_last':
+ ch_dim = 3
+ if [spec.shape[ch_dim] for spec in self.state_spec] != state_size:
+ raise ValueError(
+ 'An initial_state was passed that is not compatible with '
+ '`cell.state_size`. Received `state_spec`={}; '
+ 'However `cell.state_size` is '
+ '{}'.format([spec.shape for spec in self.state_spec],
+ self.cell.state_size))
+ else:
+ if self.cell.data_format == 'channels_first':
+ self.state_spec = [InputSpec(shape=(None, dim, None, None))
+ for dim in state_size]
+ elif self.cell.data_format == 'channels_last':
+ self.state_spec = [InputSpec(shape=(None, None, None, dim))
+ for dim in state_size]
+ if self.stateful:
+ self.reset_states()
+ self.built = True
+
+ def get_initial_state(self, inputs):
+ # (samples, timesteps, rows, cols, filters)
+ initial_state = K.zeros_like(inputs)
+ # (samples, rows, cols, filters)
+ initial_state = K.sum(initial_state, axis=1)
+ shape = list(self.cell.kernel_shape)
+ shape[-1] = self.cell.filters
+ initial_state = self.cell.input_conv(initial_state,
+ K.zeros(tuple(shape)),
+ padding=self.cell.padding)
+
+ if hasattr(self.cell.state_size, '__len__'):
+ return [initial_state for _ in self.cell.state_size]
+ else:
+ return [initial_state]
+
+ def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
+ inputs, initial_state, constants = self._standardize_args(
+ inputs, initial_state, constants)
+
+ if initial_state is None and constants is None:
+ return super(ConvRNN2D, self).__call__(inputs, **kwargs)
+
+ # If any of `initial_state` or `constants` are specified and are Keras
+ # tensors, then add them to the inputs and temporarily modify the
+ # input_spec to include them.
+
+ additional_inputs = []
+ additional_specs = []
+ if initial_state is not None:
+ kwargs['initial_state'] = initial_state
+ additional_inputs += initial_state
+ self.state_spec = []
+ for state in initial_state:
+ shape = K.int_shape(state)
+ self.state_spec.append(InputSpec(shape=shape))
+
+ additional_specs += self.state_spec
+ if constants is not None:
+ kwargs['constants'] = constants
+ additional_inputs += constants
+ self.constants_spec = [InputSpec(shape=K.int_shape(constant))
+ for constant in constants]
+ self._num_constants = len(constants)
+ additional_specs += self.constants_spec
+ # at this point additional_inputs cannot be empty
+ for tensor in additional_inputs:
+ if K.is_keras_tensor(tensor) != K.is_keras_tensor(additional_inputs[0]):
+ raise ValueError('The initial state or constants of an RNN'
+ ' layer cannot be specified with a mix of'
+ ' Keras tensors and non-Keras tensors')
+
+ if K.is_keras_tensor(additional_inputs[0]):
+ # Compute the full input spec, including state and constants
+ full_input = [inputs] + additional_inputs
+ full_input_spec = self.input_spec + additional_specs
+ # Perform the call with temporarily replaced input_spec
+ original_input_spec = self.input_spec
+ self.input_spec = full_input_spec
+ output = super(ConvRNN2D, self).__call__(full_input, **kwargs)
+ self.input_spec = original_input_spec
+ return output
+ else:
+ return super(ConvRNN2D, self).__call__(inputs, **kwargs)
+
+ def call(self,
+ inputs,
+ mask=None,
+ training=None,
+ initial_state=None,
+ constants=None):
+ # note that the .build() method of subclasses MUST define
+ # self.input_spec and self.state_spec with complete input shapes.
+ if isinstance(inputs, list):
+ inputs = inputs[0]
+ if initial_state is not None:
+ pass
+ elif self.stateful:
+ initial_state = self.states
+ else:
+ initial_state = self.get_initial_state(inputs)
+
+ if isinstance(mask, list):
+ mask = mask[0]
+
+ if len(initial_state) != len(self.states):
+ raise ValueError('Layer has ' + str(len(self.states)) +
+ ' states but was passed ' +
+ str(len(initial_state)) +
+ ' initial states.')
+ timesteps = K.int_shape(inputs)[1]
+
+ kwargs = {}
+ if generic_utils.has_arg(self.cell.call, 'training'):
+ kwargs['training'] = training
+
+ if constants:
+ if not generic_utils.has_arg(self.cell.call, 'constants'):
+ raise ValueError('RNN cell does not support constants')
+
+ def step(inputs, states):
+ constants = states[-self._num_constants:]
+ states = states[:-self._num_constants]
+ return self.cell.call(inputs, states, constants=constants,
+ **kwargs)
+ else:
+ def step(inputs, states):
+ return self.cell.call(inputs, states, **kwargs)
+
+ last_output, outputs, states = K.rnn(step,
+ inputs,
+ initial_state,
+ constants=constants,
+ go_backwards=self.go_backwards,
+ mask=mask,
+ input_length=timesteps)
+ if self.stateful:
+ updates = []
+ for i in range(len(states)):
+ updates.append(K.update(self.states[i], states[i]))
+ self.add_update(updates, inputs=True)
+
if self.return_sequences:
- if self.data_format == 'channels_first':
- output_shape = (input_shape[0], input_shape[1], self.filters, rows,
- cols)
- elif self.data_format == 'channels_last':
- output_shape = (input_shape[0], input_shape[1], rows, cols,
- self.filters)
+ output = outputs
else:
- if self.data_format == 'channels_first':
- output_shape = (input_shape[0], self.filters, rows, cols)
- elif self.data_format == 'channels_last':
- output_shape = (input_shape[0], rows, cols, self.filters)
+ output = last_output
+
+ # Properly set learning phase
+ if getattr(last_output, '_uses_learning_phase', False):
+ output._uses_learning_phase = True
if self.return_state:
- if self.data_format == 'channels_first':
- output_shape = [output_shape] + [
- (input_shape[0], self.filters, rows, cols) for _ in range(2)
- ]
- elif self.data_format == 'channels_last':
- output_shape = [output_shape] + [
- (input_shape[0], rows, cols, self.filters) for _ in range(2)
- ]
+ if not isinstance(states, (list, tuple)):
+ states = [states]
+ else:
+ states = list(states)
+ return [output] + states
+ else:
+ return output
- return output_shape
+ def reset_states(self, states=None):
+ if not self.stateful:
+ raise AttributeError('Layer must be stateful.')
+ input_shape = self.input_spec[0].shape
+ state_shape = self.compute_output_shape(input_shape)
+ if self.return_state:
+ state_shape = state_shape[0]
+ if self.return_sequences:
+ state_shape = state_shape[:1].concatenate(state_shape[2:])
+ if None in state_shape:
+ raise ValueError('If a RNN is stateful, it needs to know '
+ 'its batch size. Specify the batch size '
+ 'of your input tensors: \n'
+ '- If using a Sequential model, '
+ 'specify the batch size by passing '
+ 'a `batch_input_shape` '
+ 'argument to your first layer.\n'
+ '- If using the functional API, specify '
+ 'the time dimension by passing a '
+ '`batch_shape` argument to your Input layer.\n'
+ 'The same thing goes for the number of rows and '
+ 'columns.')
- def get_config(self):
- config = {
- 'filters': self.filters,
- 'kernel_size': self.kernel_size,
- 'strides': self.strides,
- 'padding': self.padding,
- 'data_format': self.data_format,
- 'dilation_rate': self.dilation_rate,
- 'return_sequences': self.return_sequences,
- 'go_backwards': self.go_backwards,
- 'stateful': self.stateful
- }
- base_config = super(ConvRecurrent2D, self).get_config()
- return dict(list(base_config.items()) + list(config.items()))
+ # helper function
+ def get_tuple_shape(nb_channels):
+ result = list(state_shape)
+ if self.cell.data_format == 'channels_first':
+ result[1] = nb_channels
+ elif self.cell.data_format == 'channels_last':
+ result[3] = nb_channels
+ else:
+ raise KeyError('Unknown data_format: ' +
+ str(self.cell.data_format))
+ return tuple(result)
+ # initialize state if None
+ if self.states[0] is None:
+ if hasattr(self.cell.state_size, '__len__'):
+ self.states = [K.zeros(get_tuple_shape(dim))
+ for dim in self.cell.state_size]
+ else:
+ self.states = [K.zeros(get_tuple_shape(self.cell.state_size))]
+ elif states is None:
+ if hasattr(self.cell.state_size, '__len__'):
+ for state, dim in zip(self.states, self.cell.state_size):
+ K.set_value(state, np.zeros(get_tuple_shape(dim)))
+ else:
+ K.set_value(self.states[0],
+ np.zeros(get_tuple_shape(self.cell.state_size)))
+ else:
+ if not isinstance(states, (list, tuple)):
+ states = [states]
+ if len(states) != len(self.states):
+ raise ValueError('Layer ' + self.name + ' expects ' +
+ str(len(self.states)) + ' states, ' +
+ 'but it received ' + str(len(states)) +
+ ' state values. Input received: ' + str(states))
+ for index, (value, state) in enumerate(zip(states, self.states)):
+ if hasattr(self.cell.state_size, '__len__'):
+ dim = self.cell.state_size[index]
+ else:
+ dim = self.cell.state_size
+ if value.shape != get_tuple_shape(dim):
+ raise ValueError('State ' + str(index) +
+ ' is incompatible with layer ' +
+ self.name + ': expected shape=' +
+ str(get_tuple_shape(dim)) +
+ ', found shape=' + str(value.shape))
+ # TODO(anjalisridhar): consider batch calls to `set_value`.
+ K.set_value(state, value)
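To make the stateful path concrete, a small sketch (a hypothetical configuration using the `ConvLSTM2D` subclass defined below; the two states are h and c, each of shape `(batch, rows, cols, filters)` for channels_last):

import numpy as np
import tensorflow as tf

# Stateful conv-RNNs need the full batch shape up front:
# (batch, time, rows, cols, channels) for channels_last.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.ConvLSTM2D(filters=4, kernel_size=(3, 3),
                                     padding='same', stateful=True,
                                     batch_input_shape=(8, 5, 10, 10, 2)))

layer = model.layers[0]
layer.reset_states()  # zero both states
# Or seed them explicitly; each array must be (8, 10, 10, 4) here:
layer.reset_states(states=[np.zeros((8, 10, 10, 4)),
                           np.zeros((8, 10, 10, 4))])
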
-@tf_export('keras.layers.ConvLSTM2D')
-class ConvLSTM2D(ConvRecurrent2D):
- """Convolutional LSTM.
- It is similar to an LSTM layer, but the input transformations
- and recurrent transformations are both convolutional.
+class ConvLSTM2DCell(Layer):
+ """Cell class for the ConvLSTM2D layer.
- Arguments:
+ # Arguments
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of n integers, specifying the
@@ -212,11 +488,6 @@ class ConvLSTM2D(ConvRecurrent2D):
padding: One of `"valid"` or `"same"` (case-insensitive).
data_format: A string,
one of `channels_last` (default) or `channels_first`.
- The ordering of the dimensions in the inputs.
- `channels_last` corresponds to inputs with shape
- `(batch, time, ..., channels)`
- while `channels_first` corresponds to
- inputs with shape `(batch, time, channels, ...)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
@@ -231,71 +502,32 @@ class ConvLSTM2D(ConvRecurrent2D):
for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
- used for the linear transformation of the inputs..
+ used for the linear transformation of the inputs.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
- used for the linear transformation of the recurrent state..
+ used for the linear transformation of the recurrent state.
bias_initializer: Initializer for the bias vector.
unit_forget_bias: Boolean.
If True, add 1 to the bias of the forget gate at initialization.
Use in combination with `bias_initializer="zeros"`.
- This is recommended in [Jozefowicz et
- al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
+ This is recommended in [Jozefowicz et al.]
+ (http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
the `recurrent_kernel` weights matrix.
bias_regularizer: Regularizer function applied to the bias vector.
- activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation")..
kernel_constraint: Constraint function applied to
the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
the `recurrent_kernel` weights matrix.
bias_constraint: Constraint function applied to the bias vector.
- return_sequences: Boolean. Whether to return the last output
- in the output sequence, or the full sequence.
- go_backwards: Boolean (default False).
- If True, rocess the input sequence backwards.
- stateful: Boolean (default False). If True, the last state
- for each sample at index i in a batch will be used as initial
- state for the sample of index i in the following batch.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
-
- Input shape:
- - if data_format='channels_first'
- 5D tensor with shape:
- `(samples,time, channels, rows, cols)`
- - if data_format='channels_last'
- 5D tensor with shape:
- `(samples,time, rows, cols, channels)`
-
- Output shape:
- - if `return_sequences`
- - if data_format='channels_first'
- 5D tensor with shape:
- `(samples, time, filters, output_row, output_col)`
- - if data_format='channels_last'
- 5D tensor with shape:
- `(samples, time, output_row, output_col, filters)`
- - else
- - if data_format ='channels_first'
- 4D tensor with shape:
- `(samples, filters, output_row, output_col)`
- - if data_format='channels_last'
- 4D tensor with shape:
- `(samples, output_row, output_col, filters)`
- where o_row and o_col depend on the shape of the filter and
- the padding
-
- Raises:
- ValueError: in case of invalid constructor arguments.
-
"""
def __init__(self,
@@ -315,27 +547,20 @@ class ConvLSTM2D(ConvRecurrent2D):
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
- activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
- return_sequences=False,
- go_backwards=False,
- stateful=False,
dropout=0.,
recurrent_dropout=0.,
**kwargs):
- super(ConvLSTM2D, self).__init__(
- filters,
- kernel_size,
- strides=strides,
- padding=padding,
- data_format=data_format,
- dilation_rate=dilation_rate,
- return_sequences=return_sequences,
- go_backwards=go_backwards,
- stateful=stateful,
- **kwargs)
+ super(ConvLSTM2DCell, self).__init__(**kwargs)
+ self.filters = filters
+ self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size')
+ self.strides = conv_utils.normalize_tuple(strides, 2, 'strides')
+ self.padding = conv_utils.normalize_padding(padding)
+ self.data_format = conv_utils.normalize_data_format(data_format)
+ self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2,
+ 'dilation_rate')
self.activation = activations.get(activation)
self.recurrent_activation = activations.get(recurrent_activation)
self.use_bias = use_bias
@@ -348,7 +573,6 @@ class ConvLSTM2D(ConvRecurrent2D):
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
- self.activity_regularizer = regularizers.get(activity_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.recurrent_constraint = constraints.get(recurrent_constraint)
@@ -356,45 +580,29 @@ class ConvLSTM2D(ConvRecurrent2D):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
- self.state_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]
+ self.state_size = (self.filters, self.filters)
+ self._dropout_mask = None
+ self._recurrent_dropout_mask = None
- @shape_type_conversion
def build(self, input_shape):
- if isinstance(input_shape, list):
- input_shape = input_shape[0]
- batch_size = input_shape[0] if self.stateful else None
- self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_shape[2:])
- if self.stateful:
- self.reset_states()
- else:
- # initial states: 2 all-zero tensor of shape (filters)
- self.states = [None, None]
if self.data_format == 'channels_first':
- channel_axis = 2
+ channel_axis = 1
else:
channel_axis = -1
if input_shape[channel_axis] is None:
raise ValueError('The channel dimension of the inputs '
'should be defined. Found `None`.')
input_dim = input_shape[channel_axis]
- state_shape = [None] * 4
- state_shape[channel_axis] = input_dim
- state_shape = tuple(state_shape)
- self.state_spec = [
- InputSpec(shape=state_shape),
- InputSpec(shape=state_shape)
- ]
kernel_shape = self.kernel_size + (input_dim, self.filters * 4)
self.kernel_shape = kernel_shape
recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)
- self.kernel = self.add_weight(
- shape=kernel_shape,
- initializer=self.kernel_initializer,
- name='kernel',
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint)
+ self.kernel = self.add_weight(shape=kernel_shape,
+ initializer=self.kernel_initializer,
+ name='kernel',
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint)
self.recurrent_kernel = self.add_weight(
shape=recurrent_kernel_shape,
initializer=self.recurrent_initializer,
@@ -402,25 +610,24 @@ class ConvLSTM2D(ConvRecurrent2D):
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)
if self.use_bias:
- self.bias = self.add_weight(
- shape=(self.filters * 4,),
- initializer=self.bias_initializer,
- name='bias',
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint)
+ self.bias = self.add_weight(shape=(self.filters * 4,),
+ initializer=self.bias_initializer,
+ name='bias',
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint)
if self.unit_forget_bias:
bias_value = np.zeros((self.filters * 4,))
- bias_value[self.filters:self.filters * 2] = 1.
+ bias_value[self.filters: self.filters * 2] = 1.
K.set_value(self.bias, bias_value)
else:
self.bias = None
self.kernel_i = self.kernel[:, :, :, :self.filters]
self.recurrent_kernel_i = self.recurrent_kernel[:, :, :, :self.filters]
- self.kernel_f = self.kernel[:, :, :, self.filters:self.filters * 2]
+ self.kernel_f = self.kernel[:, :, :, self.filters: self.filters * 2]
self.recurrent_kernel_f = self.recurrent_kernel[:, :, :, self.filters:
self.filters * 2]
- self.kernel_c = self.kernel[:, :, :, self.filters * 2:self.filters * 3]
+ self.kernel_c = self.kernel[:, :, :, self.filters * 2: self.filters * 3]
self.recurrent_kernel_c = self.recurrent_kernel[:, :, :, self.filters * 2:
self.filters * 3]
self.kernel_o = self.kernel[:, :, :, self.filters * 3:]
@@ -428,8 +635,8 @@ class ConvLSTM2D(ConvRecurrent2D):
if self.use_bias:
self.bias_i = self.bias[:self.filters]
- self.bias_f = self.bias[self.filters:self.filters * 2]
- self.bias_c = self.bias[self.filters * 2:self.filters * 3]
+ self.bias_f = self.bias[self.filters: self.filters * 2]
+ self.bias_c = self.bias[self.filters * 2: self.filters * 3]
self.bias_o = self.bias[self.filters * 3:]
else:
self.bias_i = None
@@ -438,166 +645,419 @@ class ConvLSTM2D(ConvRecurrent2D):
self.bias_o = None
self.built = True
- def get_initial_state(self, inputs):
- # (samples, timesteps, rows, cols, filters)
- initial_state = array_ops.zeros_like(inputs)
- # (samples, rows, cols, filters)
- initial_state = math_ops.reduce_sum(initial_state, axis=1)
- shape = list(self.kernel_shape)
- shape[-1] = self.filters
- initial_state = self.input_conv(
- initial_state, K.zeros(tuple(shape)), padding=self.padding)
-
- initial_states = [initial_state for _ in range(2)]
- return initial_states
+ def call(self, inputs, states, training=None):
+ if 0 < self.dropout < 1 and self._dropout_mask is None:
+ self._dropout_mask = _generate_dropout_mask(
+ K.ones_like(inputs),
+ self.dropout,
+ training=training,
+ count=4)
+ if (0 < self.recurrent_dropout < 1 and
+ self._recurrent_dropout_mask is None):
+ self._recurrent_dropout_mask = _generate_dropout_mask(
+ K.ones_like(states[1]),
+ self.recurrent_dropout,
+ training=training,
+ count=4)
- def reset_states(self):
- if not self.stateful:
- raise RuntimeError('Layer must be stateful.')
- input_shape = self.input_spec[0].shape
+ # dropout matrices for input units
+ dp_mask = self._dropout_mask
+ # dropout matrices for recurrent units
+ rec_dp_mask = self._recurrent_dropout_mask
- if not input_shape[0]:
- raise ValueError('If a RNN is stateful, a complete '
- 'input_shape must be provided '
- '(including batch size). '
- 'Got input shape: ' + str(input_shape))
+ h_tm1 = states[0] # previous memory state
+ c_tm1 = states[1] # previous carry state
- if self.return_state:
- output_shape = tuple(self.compute_output_shape(input_shape)[0].as_list())
- else:
- output_shape = tuple(self.compute_output_shape(input_shape).as_list())
- if self.return_sequences:
- output_shape = (input_shape[0],) + output_shape[2:]
+ if 0 < self.dropout < 1.:
+ inputs_i = inputs * dp_mask[0]
+ inputs_f = inputs * dp_mask[1]
+ inputs_c = inputs * dp_mask[2]
+ inputs_o = inputs * dp_mask[3]
else:
- output_shape = (input_shape[0],) + output_shape[1:]
+ inputs_i = inputs
+ inputs_f = inputs
+ inputs_c = inputs
+ inputs_o = inputs
- if hasattr(self, 'states'):
- K.set_value(self.states[0],
- np.zeros(output_shape))
- K.set_value(self.states[1],
- np.zeros(output_shape))
+ if 0 < self.recurrent_dropout < 1.:
+ h_tm1_i = h_tm1 * rec_dp_mask[0]
+ h_tm1_f = h_tm1 * rec_dp_mask[1]
+ h_tm1_c = h_tm1 * rec_dp_mask[2]
+ h_tm1_o = h_tm1 * rec_dp_mask[3]
else:
- self.states = [
- K.zeros(output_shape),
- K.zeros(output_shape)
- ]
-
- def get_constants(self, inputs, training=None):
- constants = []
- if self.implementation == 0 and 0 < self.dropout < 1:
- ones = array_ops.zeros_like(inputs)
- ones = math_ops.reduce_sum(ones, axis=1)
- ones += 1
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- dp_mask = [
- K.in_train_phase(dropped_inputs, ones, training=training)
- for _ in range(4)
- ]
- constants.append(dp_mask)
- else:
- constants.append([K.cast_to_floatx(1.) for _ in range(4)])
-
- if 0 < self.recurrent_dropout < 1:
- shape = list(self.kernel_shape)
- shape[-1] = self.filters
- ones = array_ops.zeros_like(inputs)
- ones = math_ops.reduce_sum(ones, axis=1)
- ones = self.input_conv(ones, K.zeros(shape), padding=self.padding)
- ones += 1.
-
- def dropped_inputs(): # pylint: disable=function-redefined
- return K.dropout(ones, self.recurrent_dropout)
-
- rec_dp_mask = [
- K.in_train_phase(dropped_inputs, ones, training=training)
- for _ in range(4)
- ]
- constants.append(rec_dp_mask)
- else:
- constants.append([K.cast_to_floatx(1.) for _ in range(4)])
- return constants
+ h_tm1_i = h_tm1
+ h_tm1_f = h_tm1
+ h_tm1_c = h_tm1
+ h_tm1_o = h_tm1
+
+ x_i = self.input_conv(inputs_i, self.kernel_i, self.bias_i,
+ padding=self.padding)
+ x_f = self.input_conv(inputs_f, self.kernel_f, self.bias_f,
+ padding=self.padding)
+ x_c = self.input_conv(inputs_c, self.kernel_c, self.bias_c,
+ padding=self.padding)
+ x_o = self.input_conv(inputs_o, self.kernel_o, self.bias_o,
+ padding=self.padding)
+ h_i = self.recurrent_conv(h_tm1_i,
+ self.recurrent_kernel_i)
+ h_f = self.recurrent_conv(h_tm1_f,
+ self.recurrent_kernel_f)
+ h_c = self.recurrent_conv(h_tm1_c,
+ self.recurrent_kernel_c)
+ h_o = self.recurrent_conv(h_tm1_o,
+ self.recurrent_kernel_o)
+
+ i = self.recurrent_activation(x_i + h_i)
+ f = self.recurrent_activation(x_f + h_f)
+ c = f * c_tm1 + i * self.activation(x_c + h_c)
+ o = self.recurrent_activation(x_o + h_o)
+ h = o * self.activation(c)
+
+ if 0 < self.dropout + self.recurrent_dropout:
+ if training is None:
+ h._uses_learning_phase = True
+
+ return h, [h, c]
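In equation form, `call` above computes the standard ConvLSTM update, where $*$ is convolution, $\circ$ the elementwise product, $\sigma$ the `recurrent_activation`, $\phi$ the `activation`, and biases enter only through `input_conv` (dropout masks omitted; a restatement of the code, not part of the change):

\[
\begin{aligned}
i_t &= \sigma(W_i * x_t + U_i * h_{t-1} + b_i) \\
f_t &= \sigma(W_f * x_t + U_f * h_{t-1} + b_f) \\
c_t &= f_t \circ c_{t-1} + i_t \circ \phi(W_c * x_t + U_c * h_{t-1} + b_c) \\
o_t &= \sigma(W_o * x_t + U_o * h_{t-1} + b_o) \\
h_t &= o_t \circ \phi(c_t)
\end{aligned}
\]
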
def input_conv(self, x, w, b=None, padding='valid'):
- conv_out = K.conv2d(
- x,
- w,
- strides=self.strides,
- padding=padding,
- data_format=self.data_format,
- dilation_rate=self.dilation_rate)
+ conv_out = K.conv2d(x, w, strides=self.strides,
+ padding=padding,
+ data_format=self.data_format,
+ dilation_rate=self.dilation_rate)
if b is not None:
- conv_out = K.bias_add(conv_out, b, data_format=self.data_format)
+ conv_out = K.bias_add(conv_out, b,
+ data_format=self.data_format)
return conv_out
def recurrent_conv(self, x, w):
- conv_out = K.conv2d(
- x, w, strides=(1, 1), padding='same', data_format=self.data_format)
+ conv_out = K.conv2d(x, w, strides=(1, 1),
+ padding='same',
+ data_format=self.data_format)
return conv_out
- def step(self, inputs, states):
- assert len(states) == 4
- h_tm1 = states[0]
- c_tm1 = states[1]
- dp_mask = states[2]
- rec_dp_mask = states[3]
-
- x_i = self.input_conv(
- inputs * dp_mask[0], self.kernel_i, self.bias_i, padding=self.padding)
- x_f = self.input_conv(
- inputs * dp_mask[1], self.kernel_f, self.bias_f, padding=self.padding)
- x_c = self.input_conv(
- inputs * dp_mask[2], self.kernel_c, self.bias_c, padding=self.padding)
- x_o = self.input_conv(
- inputs * dp_mask[3], self.kernel_o, self.bias_o, padding=self.padding)
- h_i = self.recurrent_conv(h_tm1 * rec_dp_mask[0], self.recurrent_kernel_i)
- h_f = self.recurrent_conv(h_tm1 * rec_dp_mask[1], self.recurrent_kernel_f)
- h_c = self.recurrent_conv(h_tm1 * rec_dp_mask[2], self.recurrent_kernel_c)
- h_o = self.recurrent_conv(h_tm1 * rec_dp_mask[3], self.recurrent_kernel_o)
+ def get_config(self):
+ config = {'filters': self.filters,
+ 'kernel_size': self.kernel_size,
+ 'strides': self.strides,
+ 'padding': self.padding,
+ 'data_format': self.data_format,
+ 'dilation_rate': self.dilation_rate,
+ 'activation': activations.serialize(self.activation),
+ 'recurrent_activation': activations.serialize(
+ self.recurrent_activation),
+ 'use_bias': self.use_bias,
+ 'kernel_initializer': initializers.serialize(
+ self.kernel_initializer),
+ 'recurrent_initializer': initializers.serialize(
+ self.recurrent_initializer),
+ 'bias_initializer': initializers.serialize(self.bias_initializer),
+ 'unit_forget_bias': self.unit_forget_bias,
+ 'kernel_regularizer': regularizers.serialize(
+ self.kernel_regularizer),
+ 'recurrent_regularizer': regularizers.serialize(
+ self.recurrent_regularizer),
+ 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'kernel_constraint': constraints.serialize(
+ self.kernel_constraint),
+ 'recurrent_constraint': constraints.serialize(
+ self.recurrent_constraint),
+ 'bias_constraint': constraints.serialize(self.bias_constraint),
+ 'dropout': self.dropout,
+ 'recurrent_dropout': self.recurrent_dropout}
+ base_config = super(ConvLSTM2DCell, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
- i = self.recurrent_activation(x_i + h_i)
- f = self.recurrent_activation(x_f + h_f)
- c = f * c_tm1 + i * self.activation(x_c + h_c)
- o = self.recurrent_activation(x_o + h_o)
- h = o * self.activation(c)
- return h, [h, c]
+@tf_export('keras.layers.ConvLSTM2D')
+class ConvLSTM2D(ConvRNN2D):
+ """Convolutional LSTM.
+
+ It is similar to an LSTM layer, but the input transformations
+ and recurrent transformations are both convolutional.
+
+ Arguments:
+ filters: Integer, the dimensionality of the output space
+ (i.e. the number of output filters in the convolution).
+ kernel_size: An integer or tuple/list of n integers, specifying the
+ dimensions of the convolution window.
+ strides: An integer or tuple/list of n integers,
+ specifying the strides of the convolution.
+ Specifying any stride value != 1 is incompatible with specifying
+ any `dilation_rate` value != 1.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, time, ..., channels)`
+ while `channels_first` corresponds to
+ inputs with shape `(batch, time, channels, ...)`.
+ It defaults to the `image_data_format` value found in your
+ Keras config file at `~/.keras/keras.json`.
+ If you never set it, then it will be "channels_last".
+ dilation_rate: An integer or tuple/list of n integers, specifying
+ the dilation rate to use for dilated convolution.
+ Currently, specifying any `dilation_rate` value != 1 is
+ incompatible with specifying any `strides` value != 1.
+ activation: Activation function to use.
+ If you don't specify anything, no activation is applied
+ (i.e. "linear" activation: `a(x) = x`).
+ recurrent_activation: Activation function to use
+ for the recurrent step.
+ use_bias: Boolean, whether the layer uses a bias vector.
+ kernel_initializer: Initializer for the `kernel` weights matrix,
+ used for the linear transformation of the inputs.
+ recurrent_initializer: Initializer for the `recurrent_kernel`
+ weights matrix,
+ used for the linear transformation of the recurrent state.
+ bias_initializer: Initializer for the bias vector.
+ unit_forget_bias: Boolean.
+ If True, add 1 to the bias of the forget gate at initialization.
+ Use in combination with `bias_initializer="zeros"`.
+ This is recommended in [Jozefowicz et al.]
+ (http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
+ kernel_regularizer: Regularizer function applied to
+ the `kernel` weights matrix.
+ recurrent_regularizer: Regularizer function applied to
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
+ activity_regularizer: Regularizer function applied to
+ the output of the layer (its 'activation').
+ kernel_constraint: Constraint function applied to
+ the `kernel` weights matrix.
+ recurrent_constraint: Constraint function applied to
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
+ return_sequences: Boolean. Whether to return the last output
+ in the output sequence, or the full sequence.
+ go_backwards: Boolean (default False).
+ If True, process the input sequence backwards.
+ stateful: Boolean (default False). If True, the last state
+ for each sample at index i in a batch will be used as initial
+ state for the sample of index i in the following batch.
+ dropout: Float between 0 and 1.
+ Fraction of the units to drop for
+ the linear transformation of the inputs.
+ recurrent_dropout: Float between 0 and 1.
+ Fraction of the units to drop for
+ the linear transformation of the recurrent state.
+
+ Input shape:
+ - if data_format='channels_first'
+ 5D tensor with shape:
+ `(samples, time, channels, rows, cols)`
+ - if data_format='channels_last'
+ 5D tensor with shape:
+ `(samples, time, rows, cols, channels)`
+
+ Output shape:
+ - if `return_sequences`
+ - if data_format='channels_first'
+ 5D tensor with shape:
+ `(samples, time, filters, output_row, output_col)`
+ - if data_format='channels_last'
+ 5D tensor with shape:
+ `(samples, time, output_row, output_col, filters)`
+ - else
+ - if data_format ='channels_first'
+ 4D tensor with shape:
+ `(samples, filters, output_row, output_col)`
+ - if data_format='channels_last'
+ 4D tensor with shape:
+ `(samples, output_row, output_col, filters)`
+ where `output_row` and `output_col` depend on the shape of the
+ filter and the padding.
+
+ Raises:
+ ValueError: in case of invalid constructor arguments.
+
+ References:
+ - [Convolutional LSTM Network: A Machine Learning Approach for
+ Precipitation Nowcasting](http://arxiv.org/abs/1506.04214v1)
+ The current implementation does not include the feedback loop on the
+ cell's output.
+
+ """
+
+ def __init__(self,
+ filters,
+ kernel_size,
+ strides=(1, 1),
+ padding='valid',
+ data_format=None,
+ dilation_rate=(1, 1),
+ activation='tanh',
+ recurrent_activation='hard_sigmoid',
+ use_bias=True,
+ kernel_initializer='glorot_uniform',
+ recurrent_initializer='orthogonal',
+ bias_initializer='zeros',
+ unit_forget_bias=True,
+ kernel_regularizer=None,
+ recurrent_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ kernel_constraint=None,
+ recurrent_constraint=None,
+ bias_constraint=None,
+ return_sequences=False,
+ go_backwards=False,
+ stateful=False,
+ dropout=0.,
+ recurrent_dropout=0.,
+ **kwargs):
+ cell = ConvLSTM2DCell(filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ dilation_rate=dilation_rate,
+ activation=activation,
+ recurrent_activation=recurrent_activation,
+ use_bias=use_bias,
+ kernel_initializer=kernel_initializer,
+ recurrent_initializer=recurrent_initializer,
+ bias_initializer=bias_initializer,
+ unit_forget_bias=unit_forget_bias,
+ kernel_regularizer=kernel_regularizer,
+ recurrent_regularizer=recurrent_regularizer,
+ bias_regularizer=bias_regularizer,
+ kernel_constraint=kernel_constraint,
+ recurrent_constraint=recurrent_constraint,
+ bias_constraint=bias_constraint,
+ dropout=dropout,
+ recurrent_dropout=recurrent_dropout)
+ super(ConvLSTM2D, self).__init__(cell,
+ return_sequences=return_sequences,
+ go_backwards=go_backwards,
+ stateful=stateful,
+ **kwargs)
+ self.activity_regularizer = regularizers.get(activity_regularizer)
+
+ def call(self, inputs, mask=None, training=None, initial_state=None):
+ return super(ConvLSTM2D, self).call(inputs,
+ mask=mask,
+ training=training,
+ initial_state=initial_state)
+
+ @property
+ def filters(self):
+ return self.cell.filters
+
+ @property
+ def kernel_size(self):
+ return self.cell.kernel_size
+
+ @property
+ def strides(self):
+ return self.cell.strides
+
+ @property
+ def padding(self):
+ return self.cell.padding
+
+ @property
+ def data_format(self):
+ return self.cell.data_format
+
+ @property
+ def dilation_rate(self):
+ return self.cell.dilation_rate
+
+ @property
+ def activation(self):
+ return self.cell.activation
+
+ @property
+ def recurrent_activation(self):
+ return self.cell.recurrent_activation
+
+ @property
+ def use_bias(self):
+ return self.cell.use_bias
+
+ @property
+ def kernel_initializer(self):
+ return self.cell.kernel_initializer
+
+ @property
+ def recurrent_initializer(self):
+ return self.cell.recurrent_initializer
+
+ @property
+ def bias_initializer(self):
+ return self.cell.bias_initializer
+
+ @property
+ def unit_forget_bias(self):
+ return self.cell.unit_forget_bias
+
+ @property
+ def kernel_regularizer(self):
+ return self.cell.kernel_regularizer
+
+ @property
+ def recurrent_regularizer(self):
+ return self.cell.recurrent_regularizer
+
+ @property
+ def bias_regularizer(self):
+ return self.cell.bias_regularizer
+
+ @property
+ def kernel_constraint(self):
+ return self.cell.kernel_constraint
+
+ @property
+ def recurrent_constraint(self):
+ return self.cell.recurrent_constraint
+
+ @property
+ def bias_constraint(self):
+ return self.cell.bias_constraint
+
+ @property
+ def dropout(self):
+ return self.cell.dropout
+
+ @property
+ def recurrent_dropout(self):
+ return self.cell.recurrent_dropout
def get_config(self):
- config = {
- 'activation':
- activations.serialize(self.activation),
- 'recurrent_activation':
- activations.serialize(self.recurrent_activation),
- 'use_bias':
- self.use_bias,
- 'kernel_initializer':
- initializers.serialize(self.kernel_initializer),
- 'recurrent_initializer':
- initializers.serialize(self.recurrent_initializer),
- 'bias_initializer':
- initializers.serialize(self.bias_initializer),
- 'unit_forget_bias':
- self.unit_forget_bias,
- 'kernel_regularizer':
- regularizers.serialize(self.kernel_regularizer),
- 'recurrent_regularizer':
- regularizers.serialize(self.recurrent_regularizer),
- 'bias_regularizer':
- regularizers.serialize(self.bias_regularizer),
- 'activity_regularizer':
- regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint':
- constraints.serialize(self.kernel_constraint),
- 'recurrent_constraint':
- constraints.serialize(self.recurrent_constraint),
- 'bias_constraint':
- constraints.serialize(self.bias_constraint),
- 'dropout':
- self.dropout,
- 'recurrent_dropout':
- self.recurrent_dropout
- }
+ config = {'filters': self.filters,
+ 'kernel_size': self.kernel_size,
+ 'strides': self.strides,
+ 'padding': self.padding,
+ 'data_format': self.data_format,
+ 'dilation_rate': self.dilation_rate,
+ 'activation': activations.serialize(self.activation),
+ 'recurrent_activation': activations.serialize(
+ self.recurrent_activation),
+ 'use_bias': self.use_bias,
+ 'kernel_initializer': initializers.serialize(
+ self.kernel_initializer),
+ 'recurrent_initializer': initializers.serialize(
+ self.recurrent_initializer),
+ 'bias_initializer': initializers.serialize(self.bias_initializer),
+ 'unit_forget_bias': self.unit_forget_bias,
+ 'kernel_regularizer': regularizers.serialize(
+ self.kernel_regularizer),
+ 'recurrent_regularizer': regularizers.serialize(
+ self.recurrent_regularizer),
+ 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'activity_regularizer': regularizers.serialize(
+ self.activity_regularizer),
+ 'kernel_constraint': constraints.serialize(
+ self.kernel_constraint),
+ 'recurrent_constraint': constraints.serialize(
+ self.recurrent_constraint),
+ 'bias_constraint': constraints.serialize(self.bias_constraint),
+ 'dropout': self.dropout,
+ 'recurrent_dropout': self.recurrent_dropout}
base_config = super(ConvLSTM2D, self).get_config()
+ del base_config['cell']
return dict(list(base_config.items()) + list(config.items()))
+
+ @classmethod
+ def from_config(cls, config):
+ return cls(**config)
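A quick shape check for the refactored layer (a sketch with illustrative sizes; `return_sequences` keeps the time axis, per the docstring above):

import numpy as np
import tensorflow as tf

x = np.random.rand(2, 5, 10, 10, 3).astype('float32')
inp = tf.keras.layers.Input(shape=(5, 10, 10, 3))

seq_out = tf.keras.layers.ConvLSTM2D(filters=8, kernel_size=(3, 3),
                                     padding='same',
                                     return_sequences=True)(inp)
last_out = tf.keras.layers.ConvLSTM2D(filters=8, kernel_size=(3, 3),
                                      padding='same')(inp)

print(tf.keras.models.Model(inp, seq_out).predict(x).shape)   # (2, 5, 10, 10, 8)
print(tf.keras.models.Model(inp, last_out).predict(x).shape)  # (2, 10, 10, 8)
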
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py
index 60137bdd72..9e768b4e95 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py
@@ -64,6 +64,7 @@ class ConvLSTMTest(test.TestCase):
self.assertEqual(len(states), 2)
model = keras.models.Model(x, states[0])
state = model.predict(inputs)
+
self.assertAllClose(
keras.backend.eval(layer.states[0]), state, atol=1e-4)
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
index f4a134b96c..12b4267675 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
@@ -961,5 +961,43 @@ class CroppingTest(test.TestCase):
keras.layers.Cropping3D(cropping=None)
+class DepthwiseConv2DTest(test.TestCase):
+
+ def _run_test(self, kwargs, arg, values):
+ num_samples = 2
+ stack_size = 3
+ num_row = 7
+ num_col = 6
+
+ test_kwargs = copy.copy(kwargs)
+ for value in values:
+ test_kwargs[arg] = value
+ with self.test_session(use_gpu=True):
+ testing_utils.layer_test(
+ keras.layers.DepthwiseConv2D,
+ kwargs=test_kwargs,
+ input_shape=(num_samples, num_row, num_col, stack_size))
+
+ def test_depthwise_conv2d(self):
+ kwargs = {'kernel_size': (3, 3)}
+
+ self._run_test(kwargs, 'padding', ['valid', 'same'])
+ self._run_test(kwargs, 'strides', [(2, 2)])
+ if test.is_gpu_available(cuda_only=True):
+ self._run_test(kwargs, 'data_format', ['channels_first'])
+ self._run_test(kwargs, 'depth_multiplier', [1, 2])
+
+ kwargs = {'kernel_size': 3,
+ 'padding': 'valid',
+ 'data_format': 'channels_first',
+ 'activation': None,
+ 'depthwise_regularizer': 'l2',
+ 'bias_regularizer': 'l2',
+ 'activity_regularizer': 'l2',
+ 'depthwise_constraint': 'unit_norm',
+ 'strides': (2, 2),
+ }
+ self._run_test(kwargs, 'depth_multiplier', [1])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index 7f9f77c296..f53db987ff 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -251,7 +251,7 @@ class RNN(Layer):
It is also possible for `cell` to be a list of RNN cell instances,
in which cases the cells get stacked on after the other in the RNN,
implementing an efficient stacked RNN.
- return_sequences: Boolean. Whether to return the last output.
+ return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
in addition to the output.
@@ -797,10 +797,10 @@ class RNN(Layer):
@property
def losses(self):
- losses = []
+ layer_losses = super(RNN, self).losses
if isinstance(self.cell, Layer):
- losses += self.cell.losses
- return losses + self._losses
+ return self.cell.losses + layer_losses
+ return layer_losses
@property
def updates(self):
@@ -1017,7 +1017,7 @@ class SimpleRNN(RNN):
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
- return_sequences: Boolean. Whether to return the last output.
+ return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
in addition to the output.
@@ -1237,6 +1237,9 @@ class GRUCell(Layer):
batch them into fewer, larger operations. These modes will
have different performance profiles on different hardware and
for different applications.
+ reset_after: GRU convention (whether to apply reset gate after or
+ before matrix multiplication). False = "before" (default),
+ True = "after" (CuDNN compatible).
"""
def __init__(self,
@@ -1256,6 +1259,7 @@ class GRUCell(Layer):
dropout=0.,
recurrent_dropout=0.,
implementation=1,
+ reset_after=False,
**kwargs):
super(GRUCell, self).__init__(**kwargs)
self.units = units
@@ -1278,6 +1282,7 @@ class GRUCell(Layer):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
self.implementation = implementation
+ self.reset_after = reset_after
self.state_size = self.units
self._dropout_mask = None
self._recurrent_dropout_mask = None
@@ -1299,12 +1304,25 @@ class GRUCell(Layer):
constraint=self.recurrent_constraint)
if self.use_bias:
- self.bias = self.add_weight(
- shape=(self.units * 3,),
- name='bias',
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint)
+ if not self.reset_after:
+ bias_shape = (3 * self.units,)
+ else:
+ # separate biases for input and recurrent kernels
+ # Note: the shape is intentionally different from CuDNNGRU biases
+ # `(2 * 3 * self.units,)`, so that we can distinguish the classes
+ # when loading and converting saved weights.
+ bias_shape = (2, 3 * self.units)
+ self.bias = self.add_weight(shape=bias_shape,
+ name='bias',
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint)
+ if not self.reset_after:
+ self.input_bias, self.recurrent_bias = self.bias, None
+ else:
+ self.input_bias = K.flatten(self.bias[0])
+ self.recurrent_bias = K.flatten(self.bias[1])
+
else:
self.bias = None
self.built = True
@@ -1340,13 +1358,15 @@ class GRUCell(Layer):
inputs_z = inputs
inputs_r = inputs
inputs_h = inputs
+
x_z = K.dot(inputs_z, self.kernel[:, :self.units])
x_r = K.dot(inputs_r, self.kernel[:, self.units:self.units * 2])
x_h = K.dot(inputs_h, self.kernel[:, self.units * 2:])
+
if self.use_bias:
- x_z = K.bias_add(x_z, self.bias[:self.units])
- x_r = K.bias_add(x_r, self.bias[self.units:self.units * 2])
- x_h = K.bias_add(x_h, self.bias[self.units * 2:])
+ x_z = K.bias_add(x_z, self.input_bias[:self.units])
+ x_r = K.bias_add(x_r, self.input_bias[self.units: self.units * 2])
+ x_h = K.bias_add(x_h, self.input_bias[self.units * 2:])
if 0. < self.recurrent_dropout < 1.:
h_tm1_z = h_tm1 * rec_dp_mask[0]
@@ -1356,42 +1376,70 @@ class GRUCell(Layer):
h_tm1_z = h_tm1
h_tm1_r = h_tm1
h_tm1_h = h_tm1
- z = self.recurrent_activation(
- x_z + K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units]))
- r = self.recurrent_activation(
- x_r + K.dot(h_tm1_r, self.recurrent_kernel[:, self.units:
- self.units * 2]))
-
- hh = self.activation(x_h + K.dot(r * h_tm1_h,
- self.recurrent_kernel[:,
- self.units * 2:]))
+
+ recurrent_z = K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units])
+ recurrent_r = K.dot(h_tm1_r,
+ self.recurrent_kernel[:, self.units:self.units * 2])
+ if self.reset_after and self.use_bias:
+ recurrent_z = K.bias_add(recurrent_z, self.recurrent_bias[:self.units])
+ recurrent_r = K.bias_add(recurrent_r,
+ self.recurrent_bias[self.units:
+ self.units * 2])
+
+ z = self.recurrent_activation(x_z + recurrent_z)
+ r = self.recurrent_activation(x_r + recurrent_r)
+
+ # reset gate applied after/before matrix multiplication
+ if self.reset_after:
+ recurrent_h = K.dot(h_tm1_h, self.recurrent_kernel[:, self.units * 2:])
+ if self.use_bias:
+ recurrent_h = K.bias_add(recurrent_h,
+ self.recurrent_bias[self.units * 2:])
+ recurrent_h = r * recurrent_h
+ else:
+ recurrent_h = K.dot(r * h_tm1_h,
+ self.recurrent_kernel[:, self.units * 2:])
+
+ hh = self.activation(x_h + recurrent_h)
else:
if 0. < self.dropout < 1.:
inputs *= dp_mask[0]
+
+ # inputs projected by all gate matrices at once
matrix_x = K.dot(inputs, self.kernel)
if self.use_bias:
- matrix_x = K.bias_add(matrix_x, self.bias)
+ # biases: bias_z_i, bias_r_i, bias_h_i
+ matrix_x = K.bias_add(matrix_x, self.input_bias)
+
+ x_z = matrix_x[:, :self.units]
+ x_r = matrix_x[:, self.units: 2 * self.units]
+ x_h = matrix_x[:, 2 * self.units:]
+
if 0. < self.recurrent_dropout < 1.:
h_tm1 *= rec_dp_mask[0]
matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])
- x_z = matrix_x[:, :self.units]
- x_r = matrix_x[:, self.units:2 * self.units]
recurrent_z = matrix_inner[:, :self.units]
recurrent_r = matrix_inner[:, self.units:2 * self.units]
z = self.recurrent_activation(x_z + recurrent_z)
r = self.recurrent_activation(x_r + recurrent_r)
- x_h = matrix_x[:, 2 * self.units:]
- recurrent_h = K.dot(r * h_tm1, self.recurrent_kernel[:, 2 * self.units:])
+ if self.reset_after:
+ recurrent_h = r * matrix_inner[:, 2 * self.units:]
+ else:
+ recurrent_h = K.dot(r * h_tm1,
+ self.recurrent_kernel[:, 2 * self.units:])
+
hh = self.activation(x_h + recurrent_h)
+ # previous and candidate state mixed by update gate
h = z * h_tm1 + (1 - z) * hh
if 0 < self.dropout + self.recurrent_dropout:
if training is None and not context.executing_eagerly():
# This would be harmless to set in eager mode, but eager tensors
# disallow setting arbitrary attributes.
h._uses_learning_phase = True
+
return h, [h]
def get_config(self):
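With `implementation=2` (the `else` branch in the hunk above), all three input projections come from a single matmul and are sliced afterwards; a standalone NumPy sketch with illustrative shapes:

    import numpy as np

    batch, input_dim, units = 2, 5, 4
    inputs = np.random.rand(batch, input_dim)
    kernel = np.random.rand(input_dim, 3 * units)   # [W_z | W_r | W_h]

    matrix_x = inputs @ kernel                      # one fused projection
    x_z = matrix_x[:, :units]
    x_r = matrix_x[:, units:2 * units]
    x_h = matrix_x[:, 2 * units:]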
@@ -1415,7 +1463,8 @@ class GRUCell(Layer):
'bias_constraint': constraints.serialize(self.bias_constraint),
'dropout': self.dropout,
'recurrent_dropout': self.recurrent_dropout,
- 'implementation': self.implementation
+ 'implementation': self.implementation,
+ 'reset_after': self.reset_after
}
base_config = super(GRUCell, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -1423,9 +1472,16 @@ class GRUCell(Layer):
@tf_export('keras.layers.GRU')
class GRU(RNN):
- """Gated Recurrent Unit - Cho et al.
+ """Gated Recurrent Unit - Cho et al. 2014.
- 2014.
+  There are two variants. The default one is based on 1406.1078v3 and
+  has the reset gate applied to the hidden state before the matrix
+  multiplication. The other one is based on the original 1406.1078v1 and
+  has the order reversed.
+
+  The second variant is compatible with CuDNNGRU (GPU-only) and allows
+  inference on CPU. As a consequence, it has separate biases for `kernel`
+  and `recurrent_kernel`. To use this variant, set `reset_after=True` and
+  `recurrent_activation='sigmoid'`.
Arguments:
units: Positive integer, dimensionality of the output space.
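A hedged usage sketch of the CuDNN-compatible variant described in the docstring above (64 units is illustrative):

    import tensorflow as tf

    # Weights trained with this configuration are interchangeable with
    # CuDNNGRU, so a GPU-trained model can run inference on CPU.
    layer = tf.keras.layers.GRU(64, reset_after=True,
                                recurrent_activation='sigmoid')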
@@ -1469,7 +1525,7 @@ class GRU(RNN):
batch them into fewer, larger operations. These modes will
have different performance profiles on different hardware and
for different applications.
- return_sequences: Boolean. Whether to return the last output.
+ return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
in addition to the output.
@@ -1485,6 +1541,9 @@ class GRU(RNN):
Unrolling can speed-up a RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
+ reset_after: GRU convention (whether to apply reset gate after or
+ before matrix multiplication). False = "before" (default),
+ True = "after" (CuDNN compatible).
"""
@@ -1511,6 +1570,7 @@ class GRU(RNN):
go_backwards=False,
stateful=False,
unroll=False,
+ reset_after=False,
**kwargs):
if implementation == 0:
logging.warning('`implementation=0` has been deprecated, '
@@ -1532,7 +1592,8 @@ class GRU(RNN):
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
- implementation=implementation)
+ implementation=implementation,
+ reset_after=reset_after)
super(GRU, self).__init__(
cell,
return_sequences=return_sequences,
@@ -1613,6 +1674,10 @@ class GRU(RNN):
def implementation(self):
return self.cell.implementation
+ @property
+ def reset_after(self):
+ return self.cell.reset_after
+
def get_config(self):
config = {
'units':
@@ -1648,7 +1713,9 @@ class GRU(RNN):
'recurrent_dropout':
self.recurrent_dropout,
'implementation':
- self.implementation
+ self.implementation,
+ 'reset_after':
+ self.reset_after
}
base_config = super(GRU, self).get_config()
del base_config['cell']
@@ -1929,7 +1996,7 @@ class LSTMCell(Layer):
@tf_export('keras.layers.LSTM')
class LSTM(RNN):
- """Long-Short Term Memory layer - Hochreiter 1997.
+ """Long Short-Term Memory layer - Hochreiter 1997.
Arguments:
units: Positive integer, dimensionality of the output space.
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
index fb743b617f..641b563a25 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
@@ -232,6 +232,7 @@ class RNNTest(test.TestCase):
cell = RNNCellWithConstants(32)
layer = keras.layers.RNN(cell)
y = layer(x, constants=c)
+
model = keras.models.Model([x, c], y)
model.compile(optimizer='rmsprop', loss='mse')
model.train_on_batch(
@@ -280,6 +281,20 @@ class RNNTest(test.TestCase):
)
with self.test_session():
+ # Test GRUCell reset_after property.
+ x = keras.Input((None, 5))
+ c = keras.Input((3,))
+ cells = [keras.layers.recurrent.GRUCell(32, reset_after=True)]
+ layer = keras.layers.recurrent.RNN(cells)
+ y = layer(x, constants=c)
+ model = keras.models.Model([x, c], y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ [np.zeros((6, 5, 5)), np.zeros((6, 3))],
+ np.zeros((6, 32))
+ )
+
+ with self.test_session():
# Test stacked RNN serialization
x_np = np.random.random((6, 5, 5))
c_np = np.random.random((6, 3))
@@ -541,6 +556,5 @@ class RNNTest(test.TestCase):
[tuple(o.as_list()) for o in output_shape],
expected_output_shape)
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
index 84ee5040dc..b45cafed31 100644
--- a/tensorflow/python/keras/layers/__init__.py
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -49,6 +49,7 @@ from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution
from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose
from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution1D
from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import DepthwiseConv2D
# Image processing layers.
from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling1D
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 68d446602e..fa26e07c85 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1566,6 +1566,16 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False):
# pylint: enable=invalid-name
+def _constant_if_small(value, shape, dtype, name):
+ try:
+ if np.prod(shape) < 1000:
+ return constant(value, shape=shape, dtype=dtype, name=name)
+ except TypeError:
+ # Happens when shape is a Tensor, list with Tensor elements, etc.
+ pass
+ return None
+
+
@tf_export("zeros")
def zeros(shape, dtype=dtypes.float32, name=None):
"""Creates a tensor with all elements set to zero.
@@ -1596,8 +1606,15 @@ def zeros(shape, dtype=dtypes.float32, name=None):
zero = ""
else:
zero = 0
+
if not isinstance(shape, ops.Tensor):
try:
+ # Create a constant if it won't be very big. Otherwise create a fill op
+ # to prevent serialized GraphDefs from becoming too large.
+ output = _constant_if_small(zero, shape, dtype, name)
+ if output is not None:
+ return output
+
# Go through tensor shapes to get int64-if-needed semantics
shape = constant_op._tensor_shape_tensor_conversion_function(
tensor_shape.TensorShape(shape))
@@ -1729,6 +1746,12 @@ def ones(shape, dtype=dtypes.float32, name=None):
one = True if dtype == dtypes.bool else 1
if not isinstance(shape, ops.Tensor):
try:
+ # Create a constant if it won't be very big. Otherwise create a fill op
+ # to prevent serialized GraphDefs from becoming too large.
+ output = _constant_if_small(one, shape, dtype, name)
+ if output is not None:
+ return output
+
# Go through tensor shapes to get int64-if-needed semantics
shape = constant_op._tensor_shape_tensor_conversion_function(
tensor_shape.TensorShape(shape))
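A standalone sketch of the heuristic used by `_constant_if_small` above (`np.full` stands in for emitting a `tf.constant`; the 1000-element threshold matches the hunk, and symbolic shapes raise `TypeError` in `np.prod` and fall through to the fill path):

    import numpy as np

    def constant_if_small(value, shape, threshold=1000):
        try:
            if np.prod(shape) < threshold:
                return np.full(shape, value)  # small: embed as a constant
        except TypeError:
            # shape contains symbolic values; caller should emit a fill op
            pass
        return None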
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index b460ce5b95..01d670ea2d 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1402,10 +1402,11 @@ def reduce_sum(input_tensor,
keep_dims: Deprecated alias for `keepdims`.
Returns:
- The reduced tensor.
+ The reduced tensor, of the same dtype as the input_tensor.
@compatibility(numpy)
- Equivalent to np.sum
+  Equivalent to np.sum, apart from the fact that numpy upcasts uint8 and
+  int32 to int64 while tensorflow returns the same dtype as the input.
@end_compatibility
"""
keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
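The dtype difference called out above, illustrated (a small sketch; note that keeping the input dtype means integer sums can overflow):

    import numpy as np
    import tensorflow as tf

    x = np.array([2**30, 2**30], dtype=np.int32)
    print(np.sum(x).dtype)                      # int64 on most platforms
    print(tf.reduce_sum(tf.constant(x)).dtype)  # int32: input dtype kept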
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 07e25e540c..508ba9bfee 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -72,7 +72,12 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
# know the shape and dtype of the variable pointed to by a handle. Since
# shape inference doesn't run in eager mode we copy this data here for when
# the handle is captured by an eager mode function.
- handle._handle_data = h._handle_data # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ if h._handle_data is None:
+ ops.set_shape_and_handle_data_for_outputs(h.op)
+ handle._handle_data = h._handle_data
+ # pylint: enable=protected-access
+
# Clean up our reference cycles to avoid making the garbage collector run.
# pylint: disable=protected-access
# OrderedDict, constructed on Graph creation, makes a simple reference loop
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 7acb8eeb1a..5ee55301df 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -120,9 +120,9 @@ limitations under the License.
}
%typemap(out) (TFE_Context*) {
- if ($1 == nullptr) {
- SWIG_fail;
- } else {
+  // When the returned TFE_Context* is a nullptr, we expect the status to be
+  // non-OK. The error is then raised in another typemap.
+ if ($1 != nullptr) {
$result = PyCapsule_New($1, nullptr, TFE_DeleteContextCapsule);
}
}
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index aae757b99a..094a9e886b 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -859,6 +859,7 @@ class ProfilerHook(session_run_hook.SessionRunHook):
showing the sizes and lifetimes of tensors.
"""
self._output_file = os.path.join(output_dir, "timeline-{}.json")
+ self._file_writer = SummaryWriterCache.get(output_dir)
self._show_dataflow = show_dataflow
self._show_memory = show_memory
self._timer = SecondOrStepTimer(
@@ -889,6 +890,8 @@ class ProfilerHook(session_run_hook.SessionRunHook):
self._save(global_step,
self._output_file.format(global_step),
run_values.run_metadata.step_stats)
+ self._file_writer.add_run_metadata(run_values.run_metadata,
+ "step_%d" % global_step)
self._next_step = global_step + 1
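Illustrative usage (assuming a `train_op` built elsewhere in the graph; `/tmp/profile` is an arbitrary path). With this change, the captured `RunMetadata` is also written to the events file under `step_<n>`, so it appears in TensorBoard in addition to the `timeline-<n>.json` files:

    import tensorflow as tf

    hook = tf.train.ProfilerHook(save_steps=100, output_dir='/tmp/profile')
    with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
        while not sess.should_stop():
            sess.run(train_op)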
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index 2547661e52..f39a5261a9 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -1274,6 +1274,19 @@ class ProfilerHookTest(test.TestCase):
sess.run(self.train_op) # Saved.
self.assertEqual(3, self._count_timeline_files())
+ def test_run_metadata_saves_in_first_step(self):
+ writer_cache.FileWriterCache.clear()
+ fake_summary_writer.FakeSummaryWriter.install()
+ fake_writer = writer_cache.FileWriterCache.get(self.output_dir)
+ with self.graph.as_default():
+ hook = basic_session_run_hooks.ProfilerHook(
+ save_secs=2, output_dir=self.output_dir)
+ with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
+ sess.run(self.train_op) # Saved.
+ self.assertEqual(
+ list(fake_writer._added_run_metadata.keys()), ['step_1'])
+ fake_summary_writer.FakeSummaryWriter.uninstall()
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 16e200d64d..c6b2dcdf98 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -1226,13 +1226,16 @@ _default_tower_mode = _DefaultTowerThreadMode()
# So here we catch any attempts to deserialize variables
# when using distribution strategies.
# pylint: disable=protected-access
+_original_from_proto = resource_variable_ops._from_proto_fn
+
+
def _from_proto_fn(v, import_scope=None):
if has_distribution_strategy():
raise NotImplementedError(
"Deserialization of variables is not yet supported when using"
"distributed strategies.")
else:
- resource_variable_ops._from_proto_fn(v, import_scope=import_scope)
+ return _original_from_proto(v, import_scope=import_scope)
resource_variable_ops._from_proto_fn = _from_proto_fn
# pylint: enable=protected-access
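A minimal, self-contained sketch of the pattern behind this fix (a `types.SimpleNamespace` stands in for the patched module): capture the original callable before rebinding, and propagate its return value; calling the module attribute from inside the wrapper would re-enter the wrapper itself.

    import types

    mod = types.SimpleNamespace(fn=lambda x: x + 1)  # stand-in "module"

    _original_fn = mod.fn  # capture BEFORE rebinding the attribute

    def fn(x):
        # Using mod.fn here would call this wrapper recursively.
        return _original_fn(x)  # also: return the value, don't drop it

    mod.fn = fn
    assert mod.fn(1) == 2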
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index c563f8f931..1c550dbb13 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -2076,12 +2076,6 @@ bool CUDABlas::DoBlasGemvWithProfilingImpl(
const DeviceMemory<T> &a, int lda, const DeviceMemory<T> &x, int incx,
const T &beta, DeviceMemory<T> *y, int incy,
blas::ProfileResult *output_profile_result) {
- struct TimerDeleter {
- void operator()(CUDATimer *t) {
- t->Destroy();
- delete t;
- }
- };
std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (output_profile_result != nullptr) {
timer.reset(new CUDATimer(parent_));
@@ -2114,12 +2108,6 @@ bool CUDABlas::DoBlasGemmWithProfilingImpl(
uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result) {
- struct TimerDeleter {
- void operator()(CUDATimer *t) {
- t->Destroy();
- delete t;
- }
- };
std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (output_profile_result != nullptr) {
timer.reset(new CUDATimer(parent_));
@@ -2188,12 +2176,6 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
return false;
}
- struct TimerDeleter {
- void operator()(CUDATimer *t) {
- t->Destroy();
- delete t;
- }
- };
std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (output_profile_result != nullptr) {
timer.reset(new CUDATimer(parent_));
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index f408c06f46..3fd9275289 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -297,6 +297,9 @@ CUDNN_DNN_ROUTINE_EACH_R7(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
namespace {
+// Forward declaration.
+cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
+
cudnnHandle_t ToHandle(void* opaque_handle) {
return static_cast<cudnnHandle_t>(opaque_handle);
}
@@ -381,6 +384,23 @@ port::Status GetCudnnProperty(libraryPropertyType type, int* value) {
}
return port::Status::OK();
}
+
+cudnnRNNAlgo_t ToCudnnRNNAlgo(const dnn::AlgorithmDesc& algorithm) {
+ if (algorithm.is_default()) {
+ return CUDNN_RNN_ALGO_STANDARD;
+ } else {
+ cudnnRNNAlgo_t algo = static_cast<cudnnRNNAlgo_t>(algorithm.algo_id());
+ switch (algo) {
+ case CUDNN_RNN_ALGO_STANDARD:
+ case CUDNN_RNN_ALGO_PERSIST_STATIC:
+ case CUDNN_RNN_ALGO_PERSIST_DYNAMIC:
+ return algo;
+ default:
+ LOG(FATAL) << "Unsupported Cudnn RNN algorithm: "
+ << algorithm.algo_id();
+ }
+ }
+}
#endif
port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
@@ -1124,6 +1144,8 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
cudnnRNNInputMode_t input_mode,
cudnnDirectionMode_t direction_mode,
cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
+ cudnnDataType_t compute_type,
+ const dnn::AlgorithmConfig& algorithm_config,
float dropout, uint64 seed,
ScratchAllocator* state_allocator)
: parent_(parent),
@@ -1134,7 +1156,9 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
input_mode_(input_mode),
direction_mode_(direction_mode),
rnn_mode_(rnn_mode),
- data_type_(data_type) {
+ data_type_(data_type),
+ compute_type_(compute_type),
+ algorithm_config_(algorithm_config) {
// Create the dropout handle.
cudnn_dropout_desc_.reset(new CudnnDropoutDescriptor(
parent, cudnn_handle, dropout, seed, state_allocator));
@@ -1148,18 +1172,20 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor");
#if CUDNN_VERSION >= 6000
// TODO: allow the user to choose an algorithm.
- cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
+ cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config_.algorithm());
status = wrap::cudnnSetRNNDescriptor_v6(
parent, cudnn_handle, rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/,
input_mode /*inputMode*/, direction_mode /*direction*/,
- rnn_mode /*mode*/, rnn_algo /*algo*/, data_type /*dataType*/);
+ rnn_mode /*mode*/, rnn_algo /*algo*/, compute_type /*dataType*/);
#else
+    CHECK(algorithm_config_.is_default())
+        << "Non-default algorithm not supported for cuDNN version < 6.0";
status = wrap::cudnnSetRNNDescriptor(
parent, rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/,
input_mode /*inputMode*/, direction_mode /*direction*/,
- rnn_mode /*mode*/, data_type /*dataType*/);
+ rnn_mode /*mode*/, compute_type /*dataType*/);
#endif
CUDNN_RETURN_IF_FAIL(status, "Unable to update RNN descriptor");
@@ -1170,9 +1196,7 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
SetFailure(cudnn_params_desc_->Status());
return;
}
- if (data_type == CUDNN_DATA_HALF) {
- set_use_tensor_op_math(true);
- }
+ set_use_tensor_op_math(algorithm_config_.algorithm().tensor_ops_enabled());
}
~CudnnRnnDescriptor() override {
if (rnn_desc_) {
@@ -1206,6 +1230,10 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
cudnnDataType_t data_type() const { return data_type_; }
+ cudnnDataType_t compute_type() const { return compute_type_; }
+ const dnn::AlgorithmConfig& algorithm_config() const {
+ return algorithm_config_;
+ }
int64 ParamsSizeInBytes() const override {
return cudnn_params_desc_->params_size_in_bytes();
}
@@ -1236,6 +1264,8 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
cudnnDirectionMode_t direction_mode_;
cudnnRNNMode_t rnn_mode_;
cudnnDataType_t data_type_;
+ cudnnDataType_t compute_type_;
+ dnn::AlgorithmConfig algorithm_config_;
std::unique_ptr<CudnnDropoutDescriptor> cudnn_dropout_desc_;
std::unique_ptr<CudnnRnnParamsDescriptor> cudnn_params_desc_;
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
@@ -1608,7 +1638,8 @@ bool CudnnSupport::DoRnnForwardImpl(
const CudnnRnnStateTensorDescriptor& output_c_desc,
DeviceMemory<T>* output_c_data, bool is_training,
ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
// extract model parameters
RnnModelDims model_dims;
bool res = ExtractAndCheckRnnForward(
@@ -1665,9 +1696,24 @@ bool CudnnSupport::DoRnnForwardImpl(
}
}
+ std::unique_ptr<CUDATimer, TimerDeleter> timer;
+ const bool is_profiling = output_profile_result != nullptr;
+ if (is_profiling) {
+ timer.reset(new CUDATimer(parent_));
+ if (!timer->Init()) {
+ return false;
+ }
+    // The start and stop of the timer should be as close to the Cudnn call
+    // as possible. It is still possible for other threads to issue work onto
+    // this stream, so it could take multiple profiling measurements.
+ if (!timer->Start(AsCUDAStream(stream))) {
+ return false;
+ }
+ }
// make the forward call
+ cudnnStatus_t status;
if (!is_training) {
- cudnnStatus_t status = wrap::cudnnRNNForwardInference(
+ status = wrap::cudnnRNNForwardInference(
parent_, ToHandle(dnn_handle_) /*handle*/,
rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
@@ -1679,13 +1725,8 @@ bool CudnnSupport::DoRnnForwardImpl(
output_c_desc.handle() /*cyDesc*/, output_c_data->opaque() /*cy*/,
workspace.opaque() /*workspace*/,
workspace.size() /*workSpaceSizeInBytes*/);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "Failed to call cudnnRNNForwardInference: "
- << ToString(status);
- return false;
- }
} else {
- cudnnStatus_t status = wrap::cudnnRNNForwardTraining(
+ status = wrap::cudnnRNNForwardTraining(
parent_, ToHandle(dnn_handle_) /*handle*/,
rnn_desc.handle() /*rnnDesc*/, model_dims.seq_length /*seqLength*/,
input_desc.handles() /*xDesc*/, input_data.opaque() /*x*/,
@@ -1699,8 +1740,24 @@ bool CudnnSupport::DoRnnForwardImpl(
workspace.size() /*workSpaceSizeInBytes*/,
reserve_space.opaque() /*reserveSpace*/,
reserve_space.size() /*reserveSpaceSizeInBytes*/);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "Failed to call cudnnRNNForwardTraining"
+ }
+ if (is_profiling) {
+ if (!timer->Stop(AsCUDAStream(stream))) {
+ return false;
+ }
+ if (status == CUDNN_STATUS_SUCCESS) {
+ auto algo_desc = rnn_desc.algorithm_config().algorithm();
+ output_profile_result->set_algorithm(algo_desc);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
+ }
+ }
+ if (status != CUDNN_STATUS_SUCCESS) {
+ // Silently return when we are profiling.
+ if (!is_profiling) {
+ LOG(ERROR) << "Failed to call "
+ << (is_training ? "cudnnRNNForwardTraining "
+ : "cudnnRNNForwardInference ")
<< ToString(status);
return false;
}
@@ -1732,7 +1789,8 @@ bool CudnnSupport::DoRnnBackwardImpl(
DeviceMemory<T>* input_c_backprop_data,
DeviceMemory<T>* params_backprop_data,
DeviceMemory<uint8>* reserve_space_data,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
// extract model parameters
RnnModelDims model_dims;
bool res = ExtractAndCheckRnnForward(
@@ -1761,6 +1819,20 @@ bool CudnnSupport::DoRnnBackwardImpl(
return false;
}
+ std::unique_ptr<CUDATimer, TimerDeleter> timer;
+ const bool is_profiling = output_profile_result != nullptr;
+ if (is_profiling) {
+ timer.reset(new CUDATimer(parent_));
+ if (!timer->Init()) {
+ return false;
+ }
+    // The start and stop of the timer should be as close to the Cudnn call
+    // as possible. It is still possible for other threads to issue work onto
+    // this stream, so it could take multiple profiling measurements.
+ if (!timer->Start(AsCUDAStream(stream))) {
+ return false;
+ }
+ }
// make the backward data call
cudnnStatus_t status = wrap::cudnnRNNBackwardData(
parent_, ToHandle(dnn_handle_) /*handle*/, rnn_desc.handle() /*rnnDesc*/,
@@ -1781,7 +1853,11 @@ bool CudnnSupport::DoRnnBackwardImpl(
workspace.size() /*workSpaceSizeInBytes*/,
reserve_space_data->opaque() /*reserveSpace*/,
reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
+
if (status != CUDNN_STATUS_SUCCESS) {
+ if (is_profiling) {
+ timer->Stop(AsCUDAStream(stream));
+ }
LOG(ERROR) << "Failed to call cudnnRNNBackwardData: " << ToString(status);
return false;
}
@@ -1803,11 +1879,23 @@ bool CudnnSupport::DoRnnBackwardImpl(
reserve_space_data->opaque() /*reserveSpace*/,
reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
if (status != CUDNN_STATUS_SUCCESS) {
+ if (is_profiling) {
+ timer->Stop(AsCUDAStream(stream));
+ }
LOG(ERROR) << "Failed to call cudnnRNNBackwardWeights: "
<< ToString(status);
return false;
}
}
+ if (is_profiling) {
+ if (!timer->Stop(AsCUDAStream(stream))) {
+ return false;
+ }
+ auto algo_desc = rnn_desc.algorithm_config().algorithm();
+ output_profile_result->set_algorithm(algo_desc);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
+ }
return true;
}
@@ -1819,15 +1907,17 @@ CudnnSupport::createRnnDescriptor(int num_layers, int hidden_size,
int input_size, dnn::RnnInputMode input_mode,
dnn::RnnDirectionMode direction_mode,
dnn::RnnMode rnn_mode,
- dnn::DataType data_type, float dropout,
- uint64 seed,
+ dnn::DataType data_type,
+ const dnn::AlgorithmConfig& algorithm_config,
+ float dropout, uint64 seed,
ScratchAllocator* state_allocator) {
#if CUDNN_VERSION >= 5000
mutex_lock lock{dnn_handle_mutex_};
std::unique_ptr<CudnnRnnDescriptor> rnn_desc(new CudnnRnnDescriptor(
parent_, ToHandle(dnn_handle_), num_layers, hidden_size, input_size,
ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode),
- ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type), dropout, seed,
+ ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type),
+ GetRnnComputeType(data_type), algorithm_config, dropout, seed,
state_allocator));
if (!rnn_desc->ok()) {
return rnn_desc->Status();
@@ -1904,7 +1994,8 @@ bool CudnnSupport::DoRnnForward(
const dnn::RnnStateTensorDescriptor& output_c_desc,
DeviceMemory<Eigen::half>* output_c_data, bool is_training,
ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
#if CUDNN_VERSION >= 5000
const CudnnRnnDescriptor& cudnn_rnn_desc =
static_cast<const CudnnRnnDescriptor&>(rnn_desc);
@@ -1925,7 +2016,8 @@ bool CudnnSupport::DoRnnForward(
stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, is_training, reserve_space_allocator, workspace_allocator);
+ output_c_data, is_training, reserve_space_allocator, workspace_allocator,
+ output_profile_result);
#else
return false;
#endif // CUDNN_VERSION
@@ -1946,7 +2038,8 @@ bool CudnnSupport::DoRnnForward(
const dnn::RnnStateTensorDescriptor& output_c_desc,
DeviceMemory<float>* output_c_data, bool is_training,
ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
#if CUDNN_VERSION >= 5000
const CudnnRnnDescriptor& cudnn_rnn_desc =
static_cast<const CudnnRnnDescriptor&>(rnn_desc);
@@ -1967,7 +2060,8 @@ bool CudnnSupport::DoRnnForward(
stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, is_training, reserve_space_allocator, workspace_allocator);
+ output_c_data, is_training, reserve_space_allocator, workspace_allocator,
+ output_profile_result);
#else
return false;
#endif // CUDNN_VERSION
@@ -1989,7 +2083,8 @@ bool CudnnSupport::DoRnnForward(
const dnn::RnnStateTensorDescriptor& output_c_desc,
DeviceMemory<double>* output_c_data, bool is_training,
ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
#if CUDNN_VERSION >= 5000
const CudnnRnnDescriptor& cudnn_rnn_desc =
static_cast<const CudnnRnnDescriptor&>(rnn_desc);
@@ -2010,7 +2105,8 @@ bool CudnnSupport::DoRnnForward(
stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, is_training, reserve_space_allocator, workspace_allocator);
+ output_c_data, is_training, reserve_space_allocator, workspace_allocator,
+ output_profile_result);
#else
return false;
#endif // CUDNN_VERSION
@@ -2039,7 +2135,8 @@ bool CudnnSupport::DoRnnBackward(
DeviceMemory<Eigen::half>* input_c_backprop_data,
DeviceMemory<Eigen::half>* params_backprop_data,
DeviceMemory<uint8>* reserve_space_data,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
#if CUDNN_VERSION >= 5000
const CudnnRnnDescriptor& cudnn_rnn_desc =
static_cast<const CudnnRnnDescriptor&>(rnn_desc);
@@ -2063,7 +2160,7 @@ bool CudnnSupport::DoRnnBackward(
output_c_data, output_backprop_data, output_h_backprop_data,
output_c_backprop_data, input_backprop_data, input_h_backprop_data,
input_c_backprop_data, params_backprop_data, reserve_space_data,
- workspace_allocator);
+ workspace_allocator, output_profile_result);
#else
return false;
#endif // CUDNN_VERSION
@@ -2091,7 +2188,8 @@ bool CudnnSupport::DoRnnBackward(
DeviceMemory<float>* input_c_backprop_data,
DeviceMemory<float>* params_backprop_data,
DeviceMemory<uint8>* reserve_space_data,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
#if CUDNN_VERSION >= 5000
const CudnnRnnDescriptor& cudnn_rnn_desc =
static_cast<const CudnnRnnDescriptor&>(rnn_desc);
@@ -2115,7 +2213,7 @@ bool CudnnSupport::DoRnnBackward(
output_c_data, output_backprop_data, output_h_backprop_data,
output_c_backprop_data, input_backprop_data, input_h_backprop_data,
input_c_backprop_data, params_backprop_data, reserve_space_data,
- workspace_allocator);
+ workspace_allocator, output_profile_result);
#else
return false;
#endif // CUDNN_VERSION
@@ -2144,7 +2242,8 @@ bool CudnnSupport::DoRnnBackward(
DeviceMemory<double>* input_c_backprop_data,
DeviceMemory<double>* params_backprop_data,
DeviceMemory<uint8>* reserve_space_data,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
#if CUDNN_VERSION >= 5000
const CudnnRnnDescriptor& cudnn_rnn_desc =
static_cast<const CudnnRnnDescriptor&>(rnn_desc);
@@ -2168,7 +2267,7 @@ bool CudnnSupport::DoRnnBackward(
output_c_data, output_backprop_data, output_h_backprop_data,
output_c_backprop_data, input_backprop_data, input_h_backprop_data,
input_c_backprop_data, params_backprop_data, reserve_space_data,
- workspace_allocator);
+ workspace_allocator, output_profile_result);
#else
return false;
#endif // CUDNN_VERSION
@@ -2363,6 +2462,33 @@ cudnnDataType_t GetConvComputeType<double>() {
return CUDNN_DATA_DOUBLE;
}
+// A helper struct to decide whether to use FP32 as the internal compute type
+// for RNNs when the input data type is FP16. By default it is turned on;
+// users can explicitly disable it (i.e. choose FP16 as the internal compute
+// type) by setting the env-var "TF_FP16_RNN_USE_FP32_COMPUTE=0".
+struct RnnDoFP32ComputationFP16Input {
+ static constexpr const char* kName = "TF_FP16_RNN_USE_FP32_COMPUTE";
+ static constexpr bool kDefaultFlag = true;
+};
+
+// A helper function to return the internal compute type for
+// RNNs in cudnn.
+cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
+ switch (data_type) {
+ case dnn::DataType::kFloat:
+ return CUDNN_DATA_FLOAT;
+ case dnn::DataType::kDouble:
+ return CUDNN_DATA_DOUBLE;
+ case dnn::DataType::kHalf:
+ if (CudnnEnvVar<RnnDoFP32ComputationFP16Input>::IsEnabled()) {
+ return CUDNN_DATA_FLOAT;
+ } else {
+ return CUDNN_DATA_HALF;
+ }
+ default:
+ LOG(FATAL) << "Invalid RNN data type: " << static_cast<int>(data_type);
+ }
+}
} // namespace
template <class T>
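The env-var above can be toggled from the hosting Python process; a hedged sketch (the assumption here is that the flag must be set before the first cuDNN RNN op runs, since it is read when the RNN descriptor is created):

    import os

    # Default ("1") keeps FP32 accumulation for FP16 RNN inputs;
    # "0" opts into FP16 internal compute.
    os.environ["TF_FP16_RNN_USE_FP32_COMPUTE"] = "0"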
@@ -2742,6 +2868,30 @@ bool CudnnSupport::GetConvolveAlgorithms(
return true;
}
+bool CudnnSupport::GetRnnAlgorithms(
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) {
+ std::vector<dnn::AlgorithmDesc::Index> algo_types = {
+ // clang-format off
+#if CUDNN_VERSION >= 6000
+ CUDNN_RNN_ALGO_STANDARD,
+ CUDNN_RNN_ALGO_PERSIST_STATIC,
+ CUDNN_RNN_ALGO_PERSIST_DYNAMIC,
+#endif
+ // clang-format on
+ };
+
+ out_algorithms->clear();
+ for (auto i : algo_types) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/false});
+#if CUDNN_VERSION >= 7100
+ if (RnnTensorOpMathEnabled()) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/true});
+ }
+#endif
+ }
+ return true;
+}
+
bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused, int cc_major, int cc_minor,
std::vector<dnn::AlgorithmDesc>* out_algorithms) {
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 48d56f71e3..e40ba9b012 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -50,8 +50,9 @@ class CudnnSupport : public dnn::DnnSupport {
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
int num_layers, int hidden_size, int input_size,
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
- dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout,
- uint64 seed, ScratchAllocator* state_allocator) override;
+ dnn::RnnMode rnn_mode, dnn::DataType data_type,
+ const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
+ ScratchAllocator* state_allocator) override;
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
@@ -77,7 +78,8 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::RnnStateTensorDescriptor& output_c_desc,
DeviceMemory<Eigen::half>* output_c_data, bool is_training,
ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator) override;
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) override;
bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
const dnn::RnnSequenceTensorDescriptor& input_desc,
@@ -94,7 +96,8 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::RnnStateTensorDescriptor& output_c_desc,
DeviceMemory<float>* output_c_data, bool is_training,
ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator) override;
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) override;
bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
const dnn::RnnSequenceTensorDescriptor& input_desc,
@@ -111,7 +114,8 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::RnnStateTensorDescriptor& output_c_desc,
DeviceMemory<double>* output_c_data, bool is_training,
ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator) override;
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) override;
bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
const dnn::RnnSequenceTensorDescriptor& input_desc,
@@ -135,7 +139,8 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemory<Eigen::half>* input_c_backprop_data,
DeviceMemory<Eigen::half>* params_backprop_data,
DeviceMemory<uint8>* reserve_space_data,
- ScratchAllocator* workspace_allocator) override;
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) override;
bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
const dnn::RnnSequenceTensorDescriptor& input_desc,
@@ -159,7 +164,8 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemory<float>* input_c_backprop_data,
DeviceMemory<float>* params_backprop_data,
DeviceMemory<uint8>* reserve_space_data,
- ScratchAllocator* workspace_allocator) override;
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) override;
bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
const dnn::RnnSequenceTensorDescriptor& input_desc,
@@ -183,12 +189,16 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemory<double>* input_c_backprop_data,
DeviceMemory<double>* params_backprop_data,
DeviceMemory<uint8>* reserve_space_data,
- ScratchAllocator* workspace_allocator) override;
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) override;
bool GetConvolveAlgorithms(
bool with_winograd_nonfused, int cc_major, int cc_minor,
std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
+ bool GetRnnAlgorithms(
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
+
bool GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused, int cc_major, int cc_minor,
std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
@@ -746,7 +756,8 @@ class CudnnSupport : public dnn::DnnSupport {
const CudnnRnnStateTensorDescriptor& output_c_desc,
DeviceMemory<T>* output_c_data, bool is_training,
ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator);
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result);
template <class T>
bool DoRnnBackwardImpl(Stream* stream, const CudnnRnnDescriptor& rnn_desc,
@@ -771,7 +782,8 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemory<T>* input_c_backprop_data,
DeviceMemory<T>* params_backprop_data,
DeviceMemory<uint8>* reserve_space_data,
- ScratchAllocator* workspace_allocator);
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result);
SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
};
diff --git a/tensorflow/stream_executor/cuda/cuda_timer.h b/tensorflow/stream_executor/cuda/cuda_timer.h
index 4a2714dc1f..2abc55ec94 100644
--- a/tensorflow/stream_executor/cuda/cuda_timer.h
+++ b/tensorflow/stream_executor/cuda/cuda_timer.h
@@ -77,6 +77,13 @@ class CUDATimer : public internal::TimerInterface {
// executing in a stream.
};
+struct TimerDeleter {
+ void operator()(CUDATimer *t) {
+ t->Destroy();
+ delete t;
+ }
+};
+
} // namespace cuda
} // namespace gputools
} // namespace perftools
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 44144a0613..0a3c4bcf50 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -28,6 +28,10 @@ bool DnnSupport::GetConvolveAlgorithms(
return false;
}
+bool DnnSupport::GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms) {
+ return false;
+}
+
bool DnnSupport::GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused, int cc_major, int cc_minor,
std::vector<AlgorithmDesc>* out_algorithms) {
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index b41536e638..43cfd313c1 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -1195,6 +1195,9 @@ class DnnSupport {
bool with_winograd_nonfused, int cc_major, int cc_minor,
std::vector<AlgorithmDesc>* out_algorithms);
+  // Returns a list of supported RNN algorithms.
+ virtual bool GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms);
+
// Version of DoConvolve that uses pre-quantized 8 bit coefficients.
// coefficient_scales specifies the scaling of each column of coefficients:
// original float coefficient[row * num_columns + column] =
@@ -2001,6 +2004,7 @@ class DnnSupport {
dnn::RnnInputMode input_mode,
dnn::RnnDirectionMode direction_mode,
dnn::RnnMode rnn_mode, dnn::DataType data_type,
+ const dnn::AlgorithmConfig& algorithm_config,
float dropout, uint64 seed,
ScratchAllocator* state_allocator) {
return port::Status{port::error::UNIMPLEMENTED,
@@ -2076,7 +2080,8 @@ class DnnSupport {
DeviceMemory<Eigen::half>* output_c_data,
bool is_training,
ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
return false;
}
@@ -2096,7 +2101,8 @@ class DnnSupport {
DeviceMemory<float>* output_c_data,
bool is_training,
ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
return false;
}
@@ -2116,7 +2122,8 @@ class DnnSupport {
DeviceMemory<double>* output_c_data,
bool is_training,
ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
return false;
}
// Enqueue a backward operation of the RNN model onto the stream.
@@ -2183,7 +2190,8 @@ class DnnSupport {
DeviceMemory<Eigen::half>* input_c_backprop_data,
DeviceMemory<Eigen::half>* params_backprop_data,
DeviceMemory<uint8>* reserve_space_data,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
return false;
}
@@ -2210,7 +2218,8 @@ class DnnSupport {
DeviceMemory<float>* input_c_backprop_data,
DeviceMemory<float>* params_backprop_data,
DeviceMemory<uint8>* reserve_space_data,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
return false;
}
@@ -2237,7 +2246,8 @@ class DnnSupport {
DeviceMemory<double>* input_c_backprop_data,
DeviceMemory<double>* params_backprop_data,
DeviceMemory<uint8>* reserve_space_data,
- ScratchAllocator* workspace_allocator) {
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result) {
return false;
}
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 1e3afde268..fe498507a8 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -4795,7 +4795,8 @@ Stream &Stream::ThenRnnForward(
const dnn::RnnStateTensorDescriptor &output_c_desc,
DeviceMemory<Eigen::half> *output_c_data, bool is_training,
ScratchAllocator *reserve_space_allocator,
- ScratchAllocator *workspace_allocator) {
+ ScratchAllocator *workspace_allocator,
+ dnn::ProfileResult *output_profile_result) {
// TODO(zhengxq): add VLOG PARAM calls.
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
@@ -4803,7 +4804,8 @@ Stream &Stream::ThenRnnForward(
this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
input_c_desc, input_c_data, params, output_desc, output_data,
output_h_desc, output_h_data, output_c_desc, output_c_data,
- is_training, reserve_space_allocator, workspace_allocator));
+ is_training, reserve_space_allocator, workspace_allocator,
+ output_profile_result));
} else {
SetError();
LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
@@ -4827,7 +4829,8 @@ Stream &Stream::ThenRnnForward(
const dnn::RnnStateTensorDescriptor &output_c_desc,
DeviceMemory<float> *output_c_data, bool is_training,
ScratchAllocator *reserve_space_allocator,
- ScratchAllocator *workspace_allocator) {
+ ScratchAllocator *workspace_allocator,
+ dnn::ProfileResult *output_profile_result) {
// TODO(zhengxq): add VLOG PARAM calls.
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
@@ -4835,7 +4838,8 @@ Stream &Stream::ThenRnnForward(
this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
input_c_desc, input_c_data, params, output_desc, output_data,
output_h_desc, output_h_data, output_c_desc, output_c_data,
- is_training, reserve_space_allocator, workspace_allocator));
+ is_training, reserve_space_allocator, workspace_allocator,
+ output_profile_result));
} else {
SetError();
LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
@@ -4860,7 +4864,8 @@ Stream &Stream::ThenRnnForward(
const dnn::RnnStateTensorDescriptor &output_c_desc,
DeviceMemory<double> *output_c_data, bool is_training,
ScratchAllocator *reserve_space_allocator,
- ScratchAllocator *workspace_allocator) {
+ ScratchAllocator *workspace_allocator,
+ dnn::ProfileResult *output_profile_result) {
// TODO(zhengxq): add VLOG PARAM calls.
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
@@ -4868,7 +4873,8 @@ Stream &Stream::ThenRnnForward(
this, rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
input_c_desc, input_c_data, params, output_desc, output_data,
output_h_desc, output_h_data, output_c_desc, output_c_data,
- is_training, reserve_space_allocator, workspace_allocator));
+ is_training, reserve_space_allocator, workspace_allocator,
+ output_profile_result));
} else {
SetError();
LOG(WARNING) << "Attempting to call ThenRnnForward without DNN support";
@@ -4900,7 +4906,8 @@ Stream &Stream::ThenRnnBackward(
DeviceMemory<Eigen::half> *input_c_backprop_data,
DeviceMemory<Eigen::half> *params_backprop_data,
DeviceMemory<uint8> *reserve_space_data,
- ScratchAllocator *workspace_allocator) {
+ ScratchAllocator *workspace_allocator,
+ dnn::ProfileResult *output_profile_result) {
// TODO(zhengxq): add VLOG PARAM calls.
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
@@ -4910,7 +4917,8 @@ Stream &Stream::ThenRnnBackward(
output_h_desc, output_h_data, output_c_desc, output_c_data,
output_backprop_data, output_h_backprop_data, output_c_backprop_data,
input_backprop_data, input_h_backprop_data, input_c_backprop_data,
- params_backprop_data, reserve_space_data, workspace_allocator));
+ params_backprop_data, reserve_space_data, workspace_allocator,
+ output_profile_result));
} else {
SetError();
LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
@@ -4941,7 +4949,8 @@ Stream &Stream::ThenRnnBackward(
DeviceMemory<float> *input_c_backprop_data,
DeviceMemory<float> *params_backprop_data,
DeviceMemory<uint8> *reserve_space_data,
- ScratchAllocator *workspace_allocator) {
+ ScratchAllocator *workspace_allocator,
+ dnn::ProfileResult *output_profile_result) {
// TODO(zhengxq): add VLOG PARAM calls.
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
@@ -4951,7 +4960,8 @@ Stream &Stream::ThenRnnBackward(
output_h_desc, output_h_data, output_c_desc, output_c_data,
output_backprop_data, output_h_backprop_data, output_c_backprop_data,
input_backprop_data, input_h_backprop_data, input_c_backprop_data,
- params_backprop_data, reserve_space_data, workspace_allocator));
+ params_backprop_data, reserve_space_data, workspace_allocator,
+ output_profile_result));
} else {
SetError();
LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
@@ -4983,7 +4993,8 @@ Stream &Stream::ThenRnnBackward(
DeviceMemory<double> *input_c_backprop_data,
DeviceMemory<double> *params_backprop_data,
DeviceMemory<uint8> *reserve_space_data,
- ScratchAllocator *workspace_allocator) {
+ ScratchAllocator *workspace_allocator,
+ dnn::ProfileResult *output_profile_result) {
// TODO(zhengxq): add VLOG PARAM calls.
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
@@ -4993,7 +5004,8 @@ Stream &Stream::ThenRnnBackward(
output_h_desc, output_h_data, output_c_desc, output_c_data,
output_backprop_data, output_h_backprop_data, output_c_backprop_data,
input_backprop_data, input_h_backprop_data, input_c_backprop_data,
- params_backprop_data, reserve_space_data, workspace_allocator));
+ params_backprop_data, reserve_space_data, workspace_allocator,
+ output_profile_result));
} else {
SetError();
LOG(WARNING) << "Attempting to call ThenRnnBackward without DNN support";
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index d7d1131569..4af426001f 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1802,7 +1802,8 @@ class Stream {
DeviceMemory<Eigen::half> *output_c_data,
bool is_training,
ScratchAllocator *reserve_space_allocator,
- ScratchAllocator *workspace_allocator);
+ ScratchAllocator *workspace_allocator,
+ dnn::ProfileResult *output_profile_result);
Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
const dnn::RnnSequenceTensorDescriptor &input_desc,
@@ -1819,7 +1820,8 @@ class Stream {
const dnn::RnnStateTensorDescriptor &output_c_desc,
DeviceMemory<float> *output_c_data, bool is_training,
ScratchAllocator *reserve_space_allocator,
- ScratchAllocator *workspace_allocator);
+ ScratchAllocator *workspace_allocator,
+ dnn::ProfileResult *output_profile_result);
Stream &ThenRnnForward(const dnn::RnnDescriptor &rnn_desc,
const dnn::RnnSequenceTensorDescriptor &input_desc,
@@ -1836,7 +1838,8 @@ class Stream {
const dnn::RnnStateTensorDescriptor &output_c_desc,
DeviceMemory<double> *output_c_data, bool is_training,
ScratchAllocator *reserve_space_allocator,
- ScratchAllocator *workspace_allocator);
+ ScratchAllocator *workspace_allocator,
+ dnn::ProfileResult *output_profile_result);
// Enqueue a backward operation of the RNN model onto the stream.
// See DnnSupport::DoRnnBackward for more details.
@@ -1863,7 +1866,8 @@ class Stream {
DeviceMemory<Eigen::half> *input_c_backprop_data,
DeviceMemory<Eigen::half> *params_backprop_data,
DeviceMemory<uint8> *reserve_space_data,
- ScratchAllocator *workspace_allocator);
+ ScratchAllocator *workspace_allocator,
+ dnn::ProfileResult *output_profile_result);
Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
const dnn::RnnSequenceTensorDescriptor &input_desc,
@@ -1887,7 +1891,8 @@ class Stream {
DeviceMemory<float> *input_c_backprop_data,
DeviceMemory<float> *params_backprop_data,
DeviceMemory<uint8> *reserve_space_data,
- ScratchAllocator *workspace_allocator);
+ ScratchAllocator *workspace_allocator,
+ dnn::ProfileResult *output_profile_result);
Stream &ThenRnnBackward(const dnn::RnnDescriptor &rnn_desc,
const dnn::RnnSequenceTensorDescriptor &input_desc,
@@ -1911,7 +1916,8 @@ class Stream {
DeviceMemory<double> *input_c_backprop_data,
DeviceMemory<double> *params_backprop_data,
DeviceMemory<uint8> *reserve_space_data,
- ScratchAllocator *workspace_allocator);
+ ScratchAllocator *workspace_allocator,
+ dnn::ProfileResult *output_profile_result);
// Enqueue onto the stream a operation that transforms a tensor.
// See DnnSupport::DoTransformTensor for more details.
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index afca1c2e59..f55fa68402 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -305,6 +305,15 @@ bool StreamExecutor::GetConvolveAlgorithms(
cc_minor, out_algorithms);
}
+bool StreamExecutor::GetRnnAlgorithms(
+ std::vector<dnn::AlgorithmDesc> *out_algorithms) {
+ dnn::DnnSupport *dnn_support = AsDnn();
+ if (!dnn_support) {
+ return false;
+ }
+ return dnn_support->GetRnnAlgorithms(out_algorithms);
+}
+
bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused,
std::vector<dnn::AlgorithmDesc> *out_algorithms) {
@@ -344,7 +353,8 @@ port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
int num_layers, int hidden_size, int input_size,
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
- dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout, uint64 seed,
+ dnn::RnnMode rnn_mode, dnn::DataType data_type,
+ const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
ScratchAllocator *state_allocator) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
@@ -353,7 +363,7 @@ StreamExecutor::createRnnDescriptor(
}
return dnn_support->createRnnDescriptor(
num_layers, hidden_size, input_size, input_mode, direction_mode, rnn_mode,
- data_type, dropout, seed, state_allocator);
+ data_type, algorithm_config, dropout, seed, state_allocator);
}
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index a2a77218cb..69d0374d73 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -349,10 +349,14 @@ class StreamExecutor {
// platform that underlies this interface.
bool SupportsDnn() const;
- // Get the list of supported algorithms for the forward convolution opeartion.
+ // Returns the list of supported algorithms for the forward convolution
+ // operation.
bool GetConvolveAlgorithms(bool with_winograd_nonfused,
std::vector<dnn::AlgorithmDesc> *out_algorithms);
+  // Returns the list of supported algorithms for the RNN operation.
+ bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);
+
// Get the list of supported algorithms for the backward convolution on data.
bool GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused,
@@ -372,8 +376,9 @@ class StreamExecutor {
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
int num_layers, int hidden_size, int input_size,
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
- dnn::RnnMode rnn_mode, dnn::DataType data_type, float dropout,
- uint64 seed, ScratchAllocator *state_allocator);
+ dnn::RnnMode rnn_mode, dnn::DataType data_type,
+ const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
+ ScratchAllocator *state_allocator);
// Create a RNN sequence descriptor that specifies either the input or output
// sequence. The caller retains the ownership of the returned descriptor.
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
index 6a7da1aef8..a535f18170 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
@@ -1,21 +1,53 @@
path: "tensorflow.keras.layers.ConvLSTM2D"
tf_class {
is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional_recurrent.ConvLSTM2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional_recurrent.ConvRecurrent2D\'>"
- is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.Recurrent\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional_recurrent.ConvRNN2D\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
+ name: "activation"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
member {
+ name: "bias_constraint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "bias_initializer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "bias_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "data_format"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dilation_rate"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dropout"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
+ name: "filters"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "graph"
mtype: "<type \'property\'>"
}
@@ -36,6 +68,22 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "kernel_constraint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "kernel_initializer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "kernel_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "kernel_size"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "losses"
mtype: "<type \'property\'>"
}
@@ -68,10 +116,42 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "padding"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_activation"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_constraint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_dropout"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_initializer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "recurrent_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "scope_name"
mtype: "<type \'property\'>"
}
member {
+ name: "states"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "strides"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -80,10 +160,18 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "unit_forget_bias"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
+ name: "use_bias"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "variables"
mtype: "<type \'property\'>"
}
@@ -144,10 +232,6 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_constants"
- argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "get_initial_state"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
@@ -188,27 +272,11 @@ tf_class {
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "input_conv"
- argspec: "args=[\'self\', \'x\', \'w\', \'b\', \'padding\'], varargs=None, keywords=None, defaults=[\'None\', \'valid\'], "
- }
- member_method {
- name: "preprocess_input"
- argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "recurrent_conv"
- argspec: "args=[\'self\', \'x\', \'w\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "reset_states"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
- member_method {
- name: "step"
- argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None"
- }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
new file mode 100644
index 0000000000..b38716aa2c
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
@@ -0,0 +1,187 @@
+path: "tensorflow.keras.layers.DepthwiseConv2D"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.DepthwiseConv2D\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.convolutional.Conv2D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional.Conv2D\'>"
+ is_instance: "<class \'tensorflow.python.layers.convolutional._Conv\'>"
+ is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.base_layer.Layer\'>"
+ is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.training.checkpointable.CheckpointableBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "scope_name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'kernel_size\', \'strides\', \'padding\', \'depth_multiplier\', \'data_format\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'1\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_output_shape"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
index 1fd3febad2..4274b8d425 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -91,7 +91,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\'], "
+ argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'reset_after\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
index f5f41d879d..8d9f06083c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -123,6 +123,10 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
+ name: "reset_after"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "scope_name"
mtype: "<type \'property\'>"
}
@@ -160,7 +164,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
+ argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'reset_after\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
index 088c8e88e2..affc9bd09b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
@@ -117,6 +117,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "DepthwiseConv2D"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "Dot"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
index 438c5d52f6..5e9ae497e1 100644
--- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
@@ -42,6 +42,14 @@ source "tensorflow/tools/ci_build/windows/bazel/common_env.sh" \
source "tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh" \
|| { echo "Failed to source bazel_test_lib.sh" >&2; exit 1; }
+skip_test=0
+
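+# Scan the script arguments for --skip_test; when present, the script exits
+# right after building the pip package and skips the Python test suite, e.g.:
+#   build_tf_windows.sh --skip_test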
+for ARG in "$@"; do
+ if [[ "$ARG" == --skip_test ]]; then
+ skip_test=1
+ fi
+done
+
run_configure_for_cpu_build
# --define=override_eigen_strong_inline=true speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc
@@ -49,6 +57,10 @@ run_configure_for_cpu_build
BUILD_OPTS="--define=override_eigen_strong_inline=true"
bazel build -c opt $BUILD_OPTS tensorflow/tools/pip_package:build_pip_package || exit $?
+if [[ "$skip_test" == 1 ]]; then
+ exit 0
+fi
+
# Create a python test directory to avoid package name conflict
PY_TEST_DIR="py_test_dir"
create_python_test_dir "${PY_TEST_DIR}"
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 26523bb020..018a395063 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -221,11 +221,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "png_archive",
urls = [
- "https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.2.53.tar.gz",
- "https://github.com/glennrp/libpng/archive/v1.2.53.tar.gz",
+ "https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz",
+ "https://github.com/glennrp/libpng/archive/v1.6.34.tar.gz",
],
- sha256 = "716c59c7dfc808a4c368f8ada526932be72b2fcea11dd85dc9d88b1df1dfe9c2",
- strip_prefix = "libpng-1.2.53",
+ sha256 = "e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef",
+ strip_prefix = "libpng-1.6.34",
build_file = clean_dep("//third_party:png.BUILD"),
)
diff --git a/third_party/png.BUILD b/third_party/png.BUILD
index 6a7ad719aa..76ab32d69c 100644
--- a/third_party/png.BUILD
+++ b/third_party/png.BUILD
@@ -9,15 +9,20 @@ cc_library(
name = "png",
srcs = [
"png.c",
+ "pngdebug.h",
"pngerror.c",
"pngget.c",
+ "pnginfo.h",
+ "pnglibconf.h",
"pngmem.c",
"pngpread.c",
+ "pngpriv.h",
"pngread.c",
"pngrio.c",
"pngrtran.c",
"pngrutil.c",
"pngset.c",
+ "pngstruct.h",
"pngtrans.c",
"pngwio.c",
"pngwrite.c",
@@ -33,3 +38,10 @@ cc_library(
visibility = ["//visibility:public"],
deps = ["@zlib_archive//:zlib"],
)
+
+genrule(
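+ # Generates pnglibconf.h from libpng's prebuilt template. The sed
+ # substitution rewrites PNG_ZLIB_VERNUM from 0 to 0x12b0, which corresponds
+ # to zlib 1.2.11, so the header agrees with the bundled @zlib_archive//:zlib.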
+ name = "snappy_stubs_public_h",
+ srcs = ["scripts/pnglibconf.h.prebuilt"],
+ outs = ["pnglibconf.h"],
+ cmd = "sed -e 's/PNG_ZLIB_VERNUM 0/PNG_ZLIB_VERNUM 0x12b0/' $< >$@",
+)