aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Frank Chen <frankchn@google.com>2017-10-06 11:44:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-06 11:59:53 -0700
commit5eaefbabce16bffeeb4b19cee9890b1aeccabb09 (patch)
treef207e219167af0c4e9676a3ee51d629f4d85d828
parent3110185270e93e0b6a3e82be9199febed1239602 (diff)
Merge changes from github.
END_PUBLIC --- Commit ee0fdc296 authored by Gunhan Gulsoy<gunan@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Add noasan tag to estimator_test PiperOrigin-RevId: 171075499 --- Commit a02116882 authored by Justin Lebar<jlebar@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA:CPU] Put the HLO name in IR values that hold the HLO's value. PiperOrigin-RevId: 171075449 --- Commit 89aaac4bc authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Allow Layer.add_update() in Eager mode. PiperOrigin-RevId: 171070861 --- Commit 840dcae57 authored by Amit Patankar<amitpatankar@google.com> Committed by gunan<gunan@google.com>: Updating the install sources file with a supported configs table (#13450) * Updating the install sources file with a supported configs page. * Implementing Gunan's suggestions. * Adding GCC string to Linux compiler. * Updating the bazel/cmake column. --- Commit 89df2e336 authored by Igor Saprykin<isaprykin@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Add the 'is_the_final_export' signal to Exporters. Use them in training. When the training ends, the final export is performed via `Exporter.export()` call. That final export is going to have is_the_final_export parameter being set to true. If `TrainSpec.max_steps` is `None`, then "when training ends" is undefined. We are going to train forever. In that case, `is_the_final_export` is going to be always False. I added a note about it. PiperOrigin-RevId: 171070760 --- Commit 4486b4f69 authored by Akshay Agrawal<akshayka@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Make graph_callable compatible with functions that do not return anything PiperOrigin-RevId: 171067061 --- Commit 39565c0cb authored by Martin Wicke<wicke@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Publish train_and_evaluate and associated classes. PiperOrigin-RevId: 171066379 --- Commit 3b4477000 authored by Saurabh Saxena<srbs@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Make VariantTensorData::tensors_size() const. PiperOrigin-RevId: 171063397 --- Commit 53cc63a2d authored by Dhananjay Nakrani<dhananjayn@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [part 1] Add support for int32 & int64 in RandomPoissonOp. This computes int32/int64-precision poisson samples with double precision intermediate calculations (same as it's done for `half`) respectively. part 2 will switch over python calls to new op once forward compatibility period has passed. PiperOrigin-RevId: 171058336 --- Commit 70fc9bf9b authored by Asim Shankar<ashankar@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Java: Add support for loading op libraries dynamically. This change adds the equivalent of tf.load_op_library in Python to Java. (https://github.com/tensorflow/tensorflow/commit/5c7f9e316d8c7735308a217310350d416d7498cc was required to make this possible) Though, TensorFlow.loadLibrary() is likely to fail on Windows as symbols required by custom op libraries (those exported by the tensorflow_framework library) are not exported by the monolithic JNI library yet. This should help with #10454 and #13476 PiperOrigin-RevId: 171054707 --- Commit e7c53698e authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Internal cleanup PiperOrigin-RevId: 171053770 --- Commit cc8ee6c0f authored by Alexandre Passos<apassos@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fast path for tf.conj when it should be pass-through. PiperOrigin-RevId: 171053662 --- Commit c41dbc3c1 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Adding TF Boosted trees regression example on boston dataset, minor fix for mnist example. PiperOrigin-RevId: 171052367 --- Commit d66e77f7c authored by Mustafa Ispir<ispir@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Added get variable utils to tf.estimator.Estimator. PiperOrigin-RevId: 171052121 --- Commit 083bd5dde authored by Asim Shankar<ashankar@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Java: Add support for loading op libraries dynamically. This change adds the equivalent of tf.load_op_library in Python to Java. (https://github.com/tensorflow/tensorflow/commit/5c7f9e316d8c7735308a217310350d416d7498cc was required to make this possible) Though, TensorFlow.loadLibrary() is likely to fail on Windows as symbols required by custom op libraries (those exported by the tensorflow_framework library) are not exported by the monolithic JNI library yet. This should help with #10454 and #13476 PiperOrigin-RevId: 171054707 --- Commit 2fe6cf285 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Internal cleanup PiperOrigin-RevId: 171053770 --- Commit 15155493b authored by Alexandre Passos<apassos@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fast path for tf.conj when it should be pass-through. PiperOrigin-RevId: 171053662 --- Commit 6c954d0b3 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Adding TF Boosted trees regression example on boston dataset, minor fix for mnist example. PiperOrigin-RevId: 171052367 --- Commit ad69076eb authored by Mustafa Ispir<ispir@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Added get variable utils to tf.estimator.Estimator. PiperOrigin-RevId: 171052121 --- Commit 3cf41b2ed authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Test save/restore variable from graph_callable. PiperOrigin-RevId: 171051237 --- Commit cf17ec96e authored by Yangzihao Wang<yangzihao@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Add V2 versions of output window size computation functions for convolution. These V2 versions take arbitrary dilation rates. In preparation for the support of native cudnn dilated convolution. PiperOrigin-RevId: 171048878 --- Commit 491584ff4 authored by Asim Shankar<ashankar@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: eager: Always run dataset iterator operations on CPU. It has no kernels for other devices. With an explicit "tf.device()" before invoking the kernel we ensure that Iterator.next() functions even when placed inside a: with tf.device("/device:GPU:0") PiperOrigin-RevId: 171048558 --- Commit 3b354016e authored by Igor Saprykin<isaprykin@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Rename SavedModelExporter to LatestExporter. PiperOrigin-RevId: 171048345 --- Commit 943c6d7af authored by Jianwei Xie<xiejw@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: errors out if the evaluator has task id > 0. PiperOrigin-RevId: 171047652 --- Commit 8c9ef4466 authored by Mark Heffernan<meheff@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Expand set of 64-bit type tests in LocalClientExecuteTest.ShapeBufferToLiteralConversion64bit and factor out into their own test. PiperOrigin-RevId: 171043047 --- Commit cc521eb06 authored by Benoit Steiner<bsteiner@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Place all the nodes created by the trivial_test_graph_input_yielder PiperOrigin-RevId: 171045878 --- Commit 9b9301240 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA:CPU] Factor out parallel task assignment from cpu parallelization prep (no functional changes). PiperOrigin-RevId: 171045137 --- Commit 558d878d9 authored by Allen Lavoie<allenl@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: TFTS: Move normalization to the base class, start using it for state space models Preivously, state space models adjusted their priors based on the data (e.g. setting initial variances to match sample variance) but did not normalize the data itself. When the data has a rather extreme scale, this runs into precision issues. After this CL, state space models will first normalize, then use adjusted statistics on top of that normalization to estimate initial observation/transition noise. Also fixes an issue where start-of-series statistics were incorrect for the first batch (which only shows up with large input scales). PiperOrigin-RevId: 171044863 --- Commit 266f77156 authored by Mark Heffernan<meheff@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Expand set of 64-bit type tests in LocalClientExecuteTest.ShapeBufferToLiteralConversion64bit and factor out into their own test. PiperOrigin-RevId: 171043047 --- Commit c9915d1a2 authored by Shanqing Cai<cais@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [tf-signal] Fix pip tests by including test_util in signal_py PiperOrigin-RevId: 171042732 --- Commit f8550f4e9 authored by Mark Heffernan<meheff@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Expand set of 64-bit type tests in LocalClientExecuteTest.ShapeBufferToLiteralConversion64bit and factor out into their own test. PiperOrigin-RevId: 171043047 --- Commit 87dc532cd authored by Shanqing Cai<cais@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [tf-signal] Fix pip tests by including test_util in signal_py PiperOrigin-RevId: 171042732 --- Commit 0578dd65e authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Add more debugging output for XLA send/recv. PiperOrigin-RevId: 171041978 --- Commit 23992bb09 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Several minor documentation fixes. PiperOrigin-RevId: 171038610 --- Commit af14ed3f3 authored by Jianwei Xie<xiejw@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Some docstring twists and argument validations. PiperOrigin-RevId: 171037949 --- Commit 6b90a65f6 authored by Mark Heffernan<meheff@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Remove "hybrid" HloModuleConfig option. The option was used to generate executables which only generated the array values of tuple-shaped outputs, not the tuple index tables.. With cl/170133015, ShapedBuffers which hold the computation output now have materialized tuples with these index tables so this option is no longer desired or necessary. No functional change. Just cleanup. PiperOrigin-RevId: 171035738 --- Commit 41a0264ab authored by Mustafa Ispir<ispir@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Added utilities to make global step reading deterministic. Used them in Estimator. Enabled/Fixed some tests. PiperOrigin-RevId: 171035291 --- Commit 9d7843c0a authored by Skye Wanderman-Milne<skyewm@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Add optional unused_input_map_keys output param to ImportGraphDef This is a more general feature than that in the Python importer, which raises an exception if the input map contains unused names. PiperOrigin-RevId: 171029316 --- Commit 4f10a6597 authored by Mark Heffernan<meheff@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Add vlogging of HloModule before and after fusion. PiperOrigin-RevId: 171029054 --- Commit 9e658545a authored by Reed Wanderman-Milne<reedwm@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Document what dtype tf.image.resize_images returns. For consistency, tf.image.resize_images now will always return a float32 when method != ResizeMethod.NEAREST_NEIGHBOR. Before, it returned the same dtype as its input if it could be determined statically that the height and width would not be changed. PiperOrigin-RevId: 171028825 --- Commit 4d70239f0 authored by Jianwei Xie<xiejw@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Replace the contrib FC with core FC in canned Estimator docstring. PiperOrigin-RevId: 171027602 --- Commit 6a1b867ff authored by Jianwei Xie<xiejw@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Adds the docstring with details for tf.estimator.train_and_evaluate PiperOrigin-RevId: 171027527 --- Commit 7209c1602 authored by Peter Hawkins<phawkins@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [TF:XLA] Mark IdentityN as CompilationOnly(). PiperOrigin-RevId: 171025171 --- Commit 8e22eb874 authored by FAIJUL<md.faijul.amin@intel.com> Committed by Benoit Steiner<benoitsteiner@users.noreply.github.com>: Eigen BiasAdd and BiasAddGrad Fix for NCHW Format. (#13158) --- Commit 7db7a890c authored by Jingyue Wu<jingyue@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [Grappler] Move InferOutputShapes to GraphProperties. So it can be used by other optimizers. No functional changes. PiperOrigin-RevId: 171010106 --- Commit 2114fd51e authored by Peter Hawkins<phawkins@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [TF:XLA] Improve numerical stability of SoftPlus. PiperOrigin-RevId: 171003559 --- Commit 727d6270f authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fix race condition in TensorForest tree traversal. PiperOrigin-RevId: 170990425 --- Commit d016cb020 authored by Suharsh Sivakumar<suharshs@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fix c++ gradients issue where multiple dependent outputs result in incorrect answer. The issue is that we incorrectly calculate the pending num_expected_backprops for outputs nodes when one output transitively depends on another. this is because we use output nodes as an indicator of when we need to end our traversal. Instead we should only use output nodes that don't transitively get consumed by other output nodes as end indicators for our traversal. This change implements that fix. Fixes #13190 PiperOrigin-RevId: 170971937 --- Commit 5405f3bd7 authored by gunan<gunan@google.com> Committed by Frank Chen<frankchn@gmail.com>: Fix tf-signal tests on pip packages. (#13483) --- Commit f9f037c1c authored by Eugene Brevdo<ebrevdo@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Bugfix to LSTMBlockCell and friends: clipping is off by default. * Rename broken API argu clip_cell boolean to cell_clip value. * Make default no clipping. PiperOrigin-RevId: 170960975 --- Commit bfaaefa9e authored by Frank Chen<frankchn@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Update APIs for TPU Cluster Resolver to remove the custom API definition and instead use a standard definition file stored in GCS. PiperOrigin-RevId: 170960877 --- Commit c31c118a3 authored by Ian Langmore<langmore@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Extend tf.contrib.bijector API to handle some non-injective transforms. AbsoluteValue Bijector added to contrib/distributions/bijectors/ TransformedDistribution udpated to handle some non-injective transforms. PiperOrigin-RevId: 170960054 --- Commit 664dd0859 authored by Frank Chen<frankchn@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Disable cluster_function_library_runtime_test on Mac OS as it is currently failing with an Unimplemented error PiperOrigin-RevId: 170958505 --- Commit 6af7ab97a authored by Mahmoud Abuzaina<mahmoud.abuzaina@intel.com> Committed by gunan<gunan@google.com>: MKL-DNN open source integration. (#13135) * MKL-DNN conv and build integration * Adding new files that were mistakenly missing from the PR * Minor change in the pip package build file * Added missing #include * Fixed a linking failure when running the bazel test * Fixing BUILD file format * Using -fopenmp for building mkl_dnn only when running on linux * Fixing build rule attribute value * Removing unnecessary deps from mkl test rule * Removed deps on mkl-dnn when not building with --config=mkl --- Commit 93fa1af76 authored by Akshay Agrawal<akshayka@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Make graph_callable, defun tf_decorators PiperOrigin-RevId: 170948777 --- Commit b39525785 authored by Mustafa Ispir<ispir@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Added comment re:behavior of listener in case of multiple saver hooks. PiperOrigin-RevId: 170946536 --- Commit de14fcbb6 authored by Igor Saprykin<isaprykin@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Support evaluation in `_TrainingExecutor.run_master()`. This CL aims to address the following TODO: # TODO(b/66720832): Once listener API is added into Estimator.train, the # eval and export process should be wrapped as a listener and passed to # _start_distributed_training. The expected behavior should be # 1. The export is invoked after each intermediate evaluation. # 2. The evaluation and export should be invoked correctly at the end of # training. This should be fine if the listener works as intended (it will # send the `after_save` signal for the final ckpt saving). 1. is achieved as follows: a. saving_evaluators are added to the CheckpointSaverHook's listeners inside the Estimator. b. MonitoredSession calls after_run() of CheckpointSaverHook, which in turn calls after_save on the listeners. 2. is achieved in a similar way, but when MonitoredSession calls .end() on CheckpointSaverHook. PiperOrigin-RevId: 170945961 --- Commit d4ea993ca authored by Alexandre Passos<apassos@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Removes unnecessary eager-mode call to convert_to_tensor in record_gradient. PiperOrigin-RevId: 170944265 --- Commit add6d2d03 authored by RJ Ryan<rjryan@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [tf-signal] Use tf.spectral.dct in mfccs_from_log_mel_spectrograms instead of a private implementation. PiperOrigin-RevId: 170943986 --- Commit b959da92f authored by Jiri Simsa<jsimsa@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fixing CPU implementation of parallel_stack for tensors with non-zero rank. PiperOrigin-RevId: 170942814 --- Commit 4cf61262a authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Improve TFGAN documentation. PiperOrigin-RevId: 170940188 --- Commit 0068086b9 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Introduce `tf.data` namespace. PiperOrigin-RevId: 170939033 --- Commit 0c8dbc1fd authored by Alexandre Passos<apassos@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: matmul uses shape_tuple internally PiperOrigin-RevId: 170938790 --- Commit ad37fa81f authored by Igor Saprykin<isaprykin@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Refactor ExportStrategies into Exporters. This design eliminates some indirection. Instead of combining an `export_fn` with `make_export_strategy` call to arrive at an ExportStrategy that is going to call the supplied `export_fn` inside its `export` call with Exporters one just defines the `export` call in an Exporter. PiperOrigin-RevId: 170936640 --- Commit b925f8553 authored by Alexandre Passos<apassos@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fast-path for EagerTensorBase.dtype PiperOrigin-RevId: 170933005 --- Commit 08e266d9b authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Pass activity_regularizer to __init__ instead of using the (now deprecated) property setter. PiperOrigin-RevId: 170932807 --- Commit b002c8b7d authored by Jingyue Wu<jingyue@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [Grappler] Fold chains of reshapes. Reshape(Reshape(input, shape1), shape2) is equivalent to Reshape(input, shape2). PiperOrigin-RevId: 170932278 --- Commit 075d1d13b authored by horance<horance@aliyun.com> Committed by Frank Chen<frankchn@gmail.com>: remove warning for forward decl (#13459) --- Commit 931609fcf authored by Ryohei Kuroki<ryohei.kuroki@gmail.com> Committed by Frank Chen<frankchn@gmail.com>: Remove unnecessary specification for default kernel name (#13465) --- Commit 94463f521 authored by Akshay Agrawal<akshayka@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Preserve target function signature in custom_gradient decorator PiperOrigin-RevId: 170931715 --- Commit 681056636 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Internal change to simplify prediction ops. - it no longer returns predictions_no_dropout, which is mostly for debugging purpose. - as a consequence, MultipleAdditiveTrees::Predict() doesn't return prediction_no_dropout, and it accept trees_to_include indexes intead of trees_to_drop indexes. PiperOrigin-RevId: 170926422 --- Commit d6e963b82 authored by Asim Shankar<ashankar@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: SYCL: Fix build breakage introduced in https://github.com/tensorflow/tensorflow/commit/f0e8c545e0196b8b48ce0ad0f116df97d980d1f1 Fixes #13350 PiperOrigin-RevId: 170923862 --- Commit 5123f2971 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Internal cleanup. PiperOrigin-RevId: 170922297 --- Commit d0c76cd18 authored by Igor Saprykin<isaprykin@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Handle the absence of a fresh eval checkpoint in `run_local`. It is ~unexpected condition for an eval checkpoint to not be available after a train call to the estimator. There is a corner case when it is possible, but that's going to be resolved soon. This case is handled for continuous (distributed) evaluation differently. Instead of erroring out, we skip evaluation runs. That behavior is captured in the `test_skip_evaluation_due_to_ckpt` test. PiperOrigin-RevId: 170919925 --- Commit 435b31b9f authored by Gunhan Gulsoy<gunan@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 170892257 PiperOrigin-RevId: 171321707
-rw-r--r--README.md6
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc15
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.cc51
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.h4
-rw-r--r--tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java23
-rw-r--r--tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py40
-rw-r--r--tensorflow/contrib/data/python/ops/dataset_ops.py2
-rw-r--r--tensorflow/contrib/deprecated/__init__.py2
-rw-r--r--tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc10
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util.py6
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc57
-rw-r--r--tensorflow/contrib/memory_stats/__init__.py2
-rw-r--r--tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc22
-rw-r--r--tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc4
-rw-r--r--tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py22
-rw-r--r--tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py5
-rw-r--r--tensorflow/contrib/resampler/kernels/resampler_ops.cc2
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py10
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/helper.py2
-rw-r--r--tensorflow/contrib/signal/BUILD1
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py5
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py45
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/BUILD48
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/ar_model.py2
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators.py7
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head.py375
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head_test.py267
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/model_utils.py319
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py236
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py3
-rw-r--r--tensorflow/core/BUILD22
-rw-r--r--tensorflow/core/graph/mkl_graph_util.h128
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc2
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc2
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc2
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass_test.cc2
-rw-r--r--tensorflow/core/kernels/BUILD34
-rw-r--r--tensorflow/core/kernels/bias_op.cc159
-rw-r--r--tensorflow/core/kernels/conv_grad_filter_ops.cc55
-rw-r--r--tensorflow/core/kernels/conv_grad_input_ops.cc53
-rw-r--r--tensorflow/core/kernels/conv_grad_ops_3d.cc109
-rw-r--r--tensorflow/core/kernels/conv_ops.cc51
-rw-r--r--tensorflow/core/kernels/conv_ops_3d.cc51
-rw-r--r--tensorflow/core/kernels/decode_csv_op.cc19
-rw-r--r--tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc45
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc181
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc190
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc213
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.h308
-rw-r--r--tensorflow/core/kernels/mkl_cwise_ops_common.cc2
-rw-r--r--tensorflow/core/lib/strings/numbers.cc2
-rw-r--r--tensorflow/core/ops/dataset_ops.cc3
-rw-r--r--tensorflow/core/ops/nn_ops.cc84
-rw-r--r--tensorflow/core/ops/nn_ops_test.cc49
-rw-r--r--tensorflow/core/ops/parsing_ops.cc2
-rw-r--r--tensorflow/core/util/mkl_util.h401
-rw-r--r--tensorflow/docs_src/install/install_sources.md38
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java8
-rw-r--r--tensorflow/examples/tutorials/word2vec/word2vec_basic.py2
-rw-r--r--tensorflow/go/example_inception_inference_test.go2
-rw-r--r--tensorflow/go/tensor.go48
-rw-r--r--tensorflow/go/tensor_test.go10
-rw-r--r--tensorflow/java/src/gen/perl/tftypes-runall.pl2
-rw-r--r--tensorflow/java/src/gen/perl/tftypes.pl102
-rw-r--r--tensorflow/java/src/gen/resources/Tensors.java.tmpl31
-rw-r--r--tensorflow/java/src/gen/resources/tftypes.csv42
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/DataType.java39
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Graph.java7
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Input.java4
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java9
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Operand.java12
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Operation.java18
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java14
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Output.java12
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java5
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Session.java34
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Tensor.java241
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Tensors.java447
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java79
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/Operands.java8
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java34
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/UInt8.java21
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/package-info.java16
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java1
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java25
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/OperationTest.java19
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/SessionTest.java41
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java2
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/TensorTest.java99
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/TestUtil.java24
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/OperandsTest.java7
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/PrimitiveOpTest.java2
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java128
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java22
-rw-r--r--tensorflow/python/debug/lib/debug_graphs.py4
-rw-r--r--tensorflow/python/estimator/inputs/queues/feeding_functions.py2
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/topology_test.py2
-rw-r--r--tensorflow/python/kernel_tests/conv2d_transpose_test.py14
-rw-r--r--tensorflow/python/kernel_tests/decode_csv_op_test.py11
-rw-r--r--tensorflow/python/kernel_tests/summary_tensor_op_test.py2
-rw-r--r--tensorflow/python/ops/hidden_ops.txt1
-rw-r--r--tensorflow/python/ops/parsing_ops.py39
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc90
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.h12
-rw-r--r--tensorflow/stream_executor/dnn.cc12
-rw-r--r--tensorflow/stream_executor/dnn.h12
-rw-r--r--tensorflow/stream_executor/platform.h2
-rw-r--r--tensorflow/stream_executor/stream.h2
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc22
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h9
-rw-r--r--tensorflow/tensorflow.bzl35
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt2
-rwxr-xr-xtensorflow/tools/ci_build/install/install_golang.sh2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu4
-rw-r--r--tensorflow/tools/docker/jupyter_notebook_config.py1
-rw-r--r--tensorflow/tools/docs/parser.py4
-rw-r--r--tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc9
-rw-r--r--tensorflow/workspace.bzl17
-rw-r--r--third_party/gpus/cuda_configure.bzl2
-rw-r--r--third_party/mkl_dnn/BUILD1
-rw-r--r--third_party/mkl_dnn/mkldnn.BUILD25
122 files changed, 4102 insertions, 1655 deletions
diff --git a/README.md b/README.md
index 4cc53096e0..6339c57c95 100644
--- a/README.md
+++ b/README.md
@@ -48,9 +48,9 @@ GPU packages on all platforms will arrive soon!
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/))
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
-* Windows CPU-only: [Python 3.5 64-bit](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/))
-* Windows GPU: Coming soon!
-* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
+* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/))
+* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/))
+* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
#### *Try your first TensorFlow program*
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 1b5dd558dd..27c5da08c1 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -52,6 +52,11 @@ class XlaAllocator : public xla::DeviceMemoryAllocator {
bool retry_on_failure) override;
Status Deallocate(int device_ordinal, gpu::DeviceMemoryBase* mem) override;
+ // Register an Tensor (input or resource variable) with the allocator. If
+ // the operation returns an alias to one of its inputs, then the allocator
+ // needs to be able to handle it.
+ Status RegisterArgument(const Tensor* t);
+
// Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is
// interpreted as having data type 'dtype' and shape 'shape'.
Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype,
@@ -103,6 +108,14 @@ xla::StatusOr<gpu::DeviceMemoryBase> XlaAllocator::Allocate(
return gpu::DeviceMemoryBase(data, size);
}
+Status XlaAllocator::RegisterArgument(const Tensor* t) {
+ void* data =
+ reinterpret_cast<void*>(const_cast<char*>(t->tensor_data().data()));
+ TF_RET_CHECK(data != nullptr);
+ tensors_[data] = *t;
+ return Status::OK();
+}
+
Status XlaAllocator::Deallocate(int device_ordinal,
gpu::DeviceMemoryBase* mem) {
if (mem->opaque() != nullptr) {
@@ -284,6 +297,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
shape, client->platform(), client->default_device_ordinal(), dmem)
.ConsumeValueOrDie();
arg_ptrs[i] = arg_buffers[i].get();
+
+ OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t));
}
// Make the final parameter point at local_runtime_context.
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 89145a9038..7dd242425c 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -256,9 +256,9 @@ tensorflow::Status ConvolutionThunk::Convolve(
algorithm_config.algorithm_no_scratch().algo_id());
}
-std::vector<AlgorithmDesc::Index> ConvolutionThunk::GetAlgorithms(
+std::vector<AlgorithmDesc> ConvolutionThunk::GetAlgorithms(
se::StreamExecutor* stream_exec) const {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
// TODO(yangzihao): Currently disable the use of winograd nonfused in XLA
// by default. Should send in conv parameters and enable it when
// ShouldIncludeWinogradNonfusedAlgo() returns true.
@@ -297,32 +297,27 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
se::dnn::ProfileResult best_result;
se::dnn::ProfileResult best_result_without_scratch;
- std::vector<AlgorithmDesc::Index> algorithms =
- GetAlgorithms(stream->parent());
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- AlgorithmDesc algorithm(algo_index, use_tensor_ops);
- ConvolveScratchAllocator scratch_allocator(
- buffer_allocations.device_ordinal(),
- buffer_allocations.memory_allocator());
- se::dnn::ProfileResult profile_result;
- bool launch_ok =
- Convolve(input_descriptor, input_data, filter_descriptor,
- filter_data, output_descriptor, output_data,
- convolution_descriptor,
- se::dnn::AlgorithmConfig(algorithm, algorithm), stream,
- &scratch_allocator, &profile_result)
- .ok();
- if (launch_ok && profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalAllocatedBytes() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_without_scratch.elapsed_time_in_ms()) {
- best_result_without_scratch = profile_result;
- }
+ std::vector<AlgorithmDesc> algorithms = GetAlgorithms(stream->parent());
+ for (auto algorithm : algorithms) {
+ ConvolveScratchAllocator scratch_allocator(
+ buffer_allocations.device_ordinal(),
+ buffer_allocations.memory_allocator());
+ se::dnn::ProfileResult profile_result;
+ bool launch_ok =
+ Convolve(input_descriptor, input_data, filter_descriptor, filter_data,
+ output_descriptor, output_data, convolution_descriptor,
+ se::dnn::AlgorithmConfig(algorithm, algorithm), stream,
+ &scratch_allocator, &profile_result)
+ .ok();
+ if (launch_ok && profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalAllocatedBytes() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_without_scratch.elapsed_time_in_ms()) {
+ best_result_without_scratch = profile_result;
}
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index 509719c1fe..13432301b2 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -115,9 +115,7 @@ class ConvolutionThunk : public Thunk {
perftools::gputools::dnn::ProfileResult* profile_result);
// Returns the convolve algorithms that can be used for this ConvolutionThunk.
- // TODO(nluehr) GetAlgorithms should return AlgorithmDesc including both
- // tensor-op and non-tensor-op variants.
- std::vector<perftools::gputools::dnn::AlgorithmDesc::Index> GetAlgorithms(
+ std::vector<perftools::gputools::dnn::AlgorithmDesc> GetAlgorithms(
perftools::gputools::StreamExecutor* stream_exec) const;
// Fastest cuDNN convolution algorithm for this thunk learned from
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
index 395dd6c5d2..80e03f2036 100644
--- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
+++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
@@ -31,12 +31,13 @@ import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.List;
-import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
+import org.tensorflow.Tensors;
+import org.tensorflow.types.UInt8;
/**
* Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface
@@ -328,7 +329,7 @@ public class TensorFlowInferenceInterface {
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, byte[] src, long... dims) {
- addFeed(inputName, Tensor.create(DataType.UINT8, dims, ByteBuffer.wrap(src)));
+ addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src)));
}
/**
@@ -337,7 +338,7 @@ public class TensorFlowInferenceInterface {
* a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[] src) {
- addFeed(inputName, Tensor.create(src));
+ addFeed(inputName, Tensors.create(src));
}
/**
@@ -346,7 +347,7 @@ public class TensorFlowInferenceInterface {
* arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[][] src) {
- addFeed(inputName, Tensor.create(src));
+ addFeed(inputName, Tensors.create(src));
}
// Methods for taking a native Tensor and filling it with src from Java native IO buffers.
@@ -403,7 +404,7 @@ public class TensorFlowInferenceInterface {
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, ByteBuffer src, long... dims) {
- addFeed(inputName, Tensor.create(DataType.UINT8, dims, src));
+ addFeed(inputName, Tensor.create(UInt8.class, dims, src));
}
/**
@@ -544,7 +545,7 @@ public class TensorFlowInferenceInterface {
"Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
}
- private void addFeed(String inputName, Tensor t) {
+ private void addFeed(String inputName, Tensor<?> t) {
// The string format accepted by TensorFlowInferenceInterface is node_name[:output_index].
TensorId tid = TensorId.parse(inputName);
runner.feed(tid.name, tid.outputIndex, t);
@@ -578,7 +579,7 @@ public class TensorFlowInferenceInterface {
}
}
- private Tensor getTensor(String outputName) {
+ private Tensor<?> getTensor(String outputName) {
int i = 0;
for (String n : fetchNames) {
if (n.equals(outputName)) {
@@ -591,7 +592,7 @@ public class TensorFlowInferenceInterface {
}
private void closeFeeds() {
- for (Tensor t : feedTensors) {
+ for (Tensor<?> t : feedTensors) {
t.close();
}
feedTensors.clear();
@@ -599,7 +600,7 @@ public class TensorFlowInferenceInterface {
}
private void closeFetches() {
- for (Tensor t : fetchTensors) {
+ for (Tensor<?> t : fetchTensors) {
t.close();
}
fetchTensors.clear();
@@ -614,9 +615,9 @@ public class TensorFlowInferenceInterface {
// State reset on every call to run.
private Session.Runner runner;
private List<String> feedNames = new ArrayList<String>();
- private List<Tensor> feedTensors = new ArrayList<Tensor>();
+ private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>();
private List<String> fetchNames = new ArrayList<String>();
- private List<Tensor> fetchTensors = new ArrayList<Tensor>();
+ private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>();
// Mutable state.
private RunStats runStats;
diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
index dad3b4e10d..c329c6d4f7 100644
--- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
+++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
@@ -36,7 +36,7 @@ class WeightedQuantilesSummary {
struct SummaryEntry {
SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
const WeightType& max) {
- // Explicitely initialize all of memory (including padding from memory
+ // Explicitly initialize all of memory (including padding from memory
// alignment) to allow the struct to be msan-resistant "plain old data".
//
// POD = http://en.cppreference.com/w/cpp/concept/PODType
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 813c64d141..91f100e0f0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -253,6 +253,46 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testDenseToSparseBatchDatasetWithUnknownShape(self):
+ components = np.random.randint(5, size=(40,)).astype(np.int32)
+ iterator = (dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([x, x], x)).dense_to_sparse_batch(
+ 4, [5, -1]).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+
+ for start in range(0, len(components), 4):
+ results = sess.run(get_next)
+ self.assertAllEqual(
+ [[i, j, z] for i, c in enumerate(components[start:start+4])
+ for j in range(c) for z in range(c)], results.indices)
+ self.assertAllEqual(
+ [c for c in components[start:start+4]
+ for _ in range(c) for _ in range(c)],
+ results.values)
+ self.assertAllEqual(
+ [min(4, len(components) - start),
+ 5,
+ np.max(components[start:start+4])],
+ results.dense_shape)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testDenseToSparseBatchDatasetWithInvalidShape(self):
+ input_tensor = array_ops.constant([[1]])
+ iterator = (dataset_ops.Dataset.from_tensors(input_tensor)
+ .dense_to_sparse_batch(4, [-2]).make_initializable_iterator())
+ init_op = iterator.initializer
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "Dimension -2 must be >= -1"):
+ sess.run(init_op)
+
def testDenseToSparseBatchDatasetShapeErrors(self):
input_tensor = array_ops.placeholder(dtypes.int32)
iterator = (dataset_ops.Dataset.from_tensors(input_tensor).apply(
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
index ff89c47a2e..b74dcd3be2 100644
--- a/tensorflow/contrib/data/python/ops/dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -653,7 +653,7 @@ class Dataset(dataset_ops.Dataset):
```python
# Preprocess 4 files concurrently, and interleave blocks of 16 records from
# each file.
- filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ..."]
+ filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
dataset = (Dataset.from_tensor_slices(filenames)
.interleave(lambda x:
TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
diff --git a/tensorflow/contrib/deprecated/__init__.py b/tensorflow/contrib/deprecated/__init__.py
index bfea8445a7..7aff045de3 100644
--- a/tensorflow/contrib/deprecated/__init__.py
+++ b/tensorflow/contrib/deprecated/__init__.py
@@ -91,7 +91,7 @@ from __future__ import division
from __future__ import print_function
-# pylint: disable=unused-import,line-too-long
+# pylint: disable=unused-import
from tensorflow.python.ops.logging_ops import audio_summary
from tensorflow.python.ops.logging_ops import histogram_summary
from tensorflow.python.ops.logging_ops import image_summary
diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
index 888f5c38a2..b417a70b6e 100644
--- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
+++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
@@ -208,7 +208,15 @@ string GetTempFilename(const string& extension) {
}
struct stat statbuf;
if (!stat(dir, &statbuf) && S_ISDIR(statbuf.st_mode)) {
- return io::JoinPath(dir, StrCat("tmp_file_", getpid(), ".", extension));
+ string tmp_filepath =
+ io::JoinPath(dir, StrCat("tmp_file_XXXXXX", ".", extension));
+ int fd = mkstemps(&tmp_filepath[0], extension.length() + 1);
+ if (fd < 0) {
+ LOG(FATAL) << "Failed to create temp file.";
+ } else {
+ close(fd);
+ return tmp_filepath;
+ }
}
}
LOG(FATAL) << "No temp directory found.";
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py
index e595e4d90b..92a2a4ff2d 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util.py
@@ -78,9 +78,9 @@ def reduce_sum_n(tensors, name=None):
return math_ops.add_n(tensors, name=name_scope)
@deprecated(None,
- "Please switch to tf.confusion_matrix.remove_squeezable_dimensions. Note "
- "that order of the inputs and ouputs of labels and predictions have also "
- "been switched.")
+ 'Please switch to tf.confusion_matrix.remove_squeezable_dimensions.'
+ 'Note that order of the inputs and outputs of labels and '
+ 'predictions have also been switched.')
def remove_squeezable_dimensions(predictions, labels, name=None):
"""Squeeze last dim if ranks of `predictions` and `labels` differ by 1.
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 9275d5a22b..256f200868 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -493,42 +493,37 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
dnn::AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find(
fused_conv_parameters, &algorithm_config)) {
- std::vector<dnn::AlgorithmDesc::Index> algorithms;
+ std::vector<dnn::AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveAlgorithms(
fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(),
&algorithms));
dnn::ProfileResult best_result;
dnn::ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- dnn::AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
- dnn::ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenFusedConvolveWithAlgorithm(
- conv_input_desc, conv_input_ptr, conv_input_scale,
- filter_desc, filter_ptr, conv_desc, side_input_ptr,
- side_input_scale, bias_desc, bias_ptr,
- dnn::ActivationMode::kRelu, output_desc, &output_ptr,
- &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm),
- &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times to better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+ dnn::ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenFusedConvolveWithAlgorithm(
+ conv_input_desc, conv_input_ptr, conv_input_scale,
+ filter_desc, filter_ptr, conv_desc, side_input_ptr,
+ side_input_scale, bias_desc, bias_ptr,
+ dnn::ActivationMode::kRelu, output_desc, &output_ptr,
+ &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm),
+ &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
diff --git a/tensorflow/contrib/memory_stats/__init__.py b/tensorflow/contrib/memory_stats/__init__.py
index a2b2b65692..a32302c854 100644
--- a/tensorflow/contrib/memory_stats/__init__.py
+++ b/tensorflow/contrib/memory_stats/__init__.py
@@ -14,10 +14,12 @@
# ==============================================================================
"""Ops for memory statistics.
+@@BytesInUse
@@BytesLimit
@@MaxBytesInUse
"""
+from tensorflow.contrib.memory_stats.python.ops.memory_stats_ops import BytesInUse
from tensorflow.contrib.memory_stats.python.ops.memory_stats_ops import BytesLimit
from tensorflow.contrib.memory_stats.python.ops.memory_stats_ops import MaxBytesInUse
diff --git a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc
index 3b88535dce..7e2e96e160 100644
--- a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc
+++ b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc
@@ -40,6 +40,28 @@ class MemoryStatsOp : public OpKernel {
const AllocatorStats& allocator_stats) const = 0;
};
+// Op that measures current memory in bytes.
+class BytesInUseOp : public MemoryStatsOp {
+ public:
+ explicit BytesInUseOp(OpKernelConstruction* context)
+ : MemoryStatsOp(context) {}
+
+ private:
+ int64 ExtractAllocatorStats(
+ const AllocatorStats& allocator_stats) const override {
+ return allocator_stats.bytes_in_use;
+ }
+};
+
+// Register this op on GPU only, see comment for MaxBytesInUse for reason
+REGISTER_KERNEL_BUILDER(Name("BytesInUse").Device(DEVICE_GPU).HostMemory("out"),
+ BytesInUseOp);
+
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(
+ Name("BytesInUse").Device(DEVICE_SYCL).HostMemory("out"), MaxBytesInUseOp);
+#endif // TENSORFLOW_USE_SYCL
+
// Op that measures the total memory (in bytes) of a device.
class BytesLimitOp : public MemoryStatsOp {
public:
diff --git a/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc
index 08859c8613..42020cf7f6 100644
--- a/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc
+++ b/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc
@@ -17,6 +17,10 @@ limitations under the License.
namespace tensorflow {
+REGISTER_OP("BytesInUse")
+ .Output("out: int64")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("BytesLimit")
.Output("out: int64")
.SetIsStateful()
diff --git a/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py b/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py
index ec25c032f0..d1b430b803 100644
--- a/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py
+++ b/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.contrib.memory_stats.python.ops import memory_stats_ops
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
@@ -64,10 +65,29 @@ class MemoryStatsOpsTest(test_util.TensorFlowTestCase):
d = math_ops.matmul(c, b)
sess.run(d)
- max_bytes_in_use = sess.run(memory_stats_ops.MaxBytesInUse())
+ max_bytes_in_use_op = memory_stats_ops.MaxBytesInUse()
+ max_bytes_in_use = sess.run(max_bytes_in_use_op)
self.assertGreaterEqual(max_bytes_in_use, matrix_size_in_bytes * 3)
self.assertLess(max_bytes_in_use, matrix_size_in_bytes * 4)
+ # run chain with 2 ops, make sure BytesInUse captures intermediate
+ # memory usage
+ a = random_ops.random_uniform(matrix_shape, dtype=dtype)
+ with ops.control_dependencies([a]):
+ bytes_in_use_op = memory_stats_ops.BytesInUse()
+ with ops.control_dependencies([bytes_in_use_op]):
+ b = random_ops.random_uniform(matrix_shape, dtype=dtype)
+
+ _, bytes_in_use, max_bytes_in_use = sess.run([a, bytes_in_use_op,
+ max_bytes_in_use_op])
+
+ # intermediate result allocates 1 matrix, max usage is at least 2
+ self.assertGreaterEqual(bytes_in_use, matrix_size_in_bytes * 1)
+ self.assertLess(bytes_in_use, matrix_size_in_bytes * 2)
+
+ # max usage is still 3 because it reflects maxium from previous .run call
+ self.assertGreaterEqual(max_bytes_in_use, matrix_size_in_bytes * 3)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py b/tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py
index d35c6583ed..c0f7788c1c 100644
--- a/tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py
+++ b/tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py
@@ -26,6 +26,11 @@ _memory_stats_ops_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_memory_stats_ops.so"))
+def BytesInUse():
+ """Generates an op that computes the current memory of a device."""
+ return gen_memory_stats_ops.bytes_in_use()
+
+
def BytesLimit():
"""Generates an op that measures the total memory (in bytes) of a device."""
return gen_memory_stats_ops.bytes_limit()
diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops.cc b/tensorflow/contrib/resampler/kernels/resampler_ops.cc
index afc8bcd446..7d9ef14cef 100644
--- a/tensorflow/contrib/resampler/kernels/resampler_ops.cc
+++ b/tensorflow/contrib/resampler/kernels/resampler_ops.cc
@@ -122,7 +122,7 @@ struct Resampler2DFunctor<CPUDevice, T>{
};
// Rough estimate of work for each batch entry.
// From third_party/tensorflow/core/util/work_sharder.cc we gather that an
- // estimate of the cost of each work unit is needed to correclty shard the
+ // estimate of the cost of each work unit is needed to correctly shard the
// workload. Shard assumes each cost unit is 1ns, minimum cost per shard
// being 10us.
const int64 cost = static_cast<int64>(num_sampling_points) *
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 1b0327d62b..6702a89d22 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -525,7 +525,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
self._state_tuple_type = collections.namedtuple(
"GridLSTMStateTuple", state_names.strip(","))
self._state_size = self._state_tuple_type(
- *([num_units, num_units] * self._total_blocks))
+ *([num_units, num_units] * self._total_blocks))
else:
self._state_tuple_type = None
self._state_size = num_units * self._total_blocks * 2
@@ -2082,9 +2082,11 @@ def _conv(args,
shape_length = len(shapes[0])
for shape in shapes:
if len(shape) not in [3,4,5]:
- raise ValueError("Conv Linear expects 3D, 4D or 5D arguments: %s" % str(shapes))
+ raise ValueError("Conv Linear expects 3D, 4D "
+ "or 5D arguments: %s" % str(shapes))
if len(shape) != len(shapes[0]):
- raise ValueError("Conv Linear expects all args to be of same Dimensiton: %s" % str(shapes))
+ raise ValueError("Conv Linear expects all args "
+ "to be of same Dimension: %s" % str(shapes))
else:
total_arg_size_depth += shape[-1]
dtype = [a.dtype for a in args][0]
@@ -2102,7 +2104,7 @@ def _conv(args,
# Now the computation.
kernel = vs.get_variable(
- "kernel",
+ "kernel",
filter_size + [total_arg_size_depth, num_features],
dtype=dtype)
if len(args) == 1:
diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py
index 64e00c21c7..b55d90cbab 100644
--- a/tensorflow/contrib/seq2seq/python/ops/helper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/helper.py
@@ -309,7 +309,7 @@ class ScheduledEmbeddingTrainingHelper(TrainingHelper):
gen_array_ops.fill([self.batch_size], -1))
def next_inputs(self, time, outputs, state, sample_ids, name=None):
- with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
+ with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperNextInputs",
[time, outputs, state, sample_ids]):
(finished, base_next_inputs, state) = (
super(ScheduledEmbeddingTrainingHelper, self).next_inputs(
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD
index 43f24474ed..2204b684ac 100644
--- a/tensorflow/contrib/signal/BUILD
+++ b/tensorflow/contrib/signal/BUILD
@@ -5,6 +5,7 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+load("//tensorflow:tensorflow.bzl", "py_test") # @unused
py_library(
name = "signal_py",
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
index f9449095be..094568389c 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
@@ -135,7 +135,10 @@ class BoundingBox(ItemHandler):
"""
sides = []
for key in self._full_keys:
- side = array_ops.expand_dims(keys_to_tensors[key].values, 0)
+ side = keys_to_tensors[key]
+ if isinstance(side, sparse_tensor.SparseTensor):
+ side = side.values
+ side = array_ops.expand_dims(side, 0)
sides.append(side)
bounding_box = array_ops.concat(sides, 0)
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
index 96606b9c0e..60d1eba07f 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
@@ -692,7 +692,7 @@ class TFExampleDecoderTest(test.TestCase):
else:
self.assertAllClose(image, decoded_image, atol=0)
- def testDecodeExampleWithBoundingBox(self):
+ def testDecodeExampleWithBoundingBoxSparse(self):
num_bboxes = 10
np_ymin = np.random.rand(num_bboxes, 1)
np_xmin = np.random.rand(num_bboxes, 1)
@@ -731,6 +731,49 @@ class TFExampleDecoderTest(test.TestCase):
self.assertAllClose(np_bboxes, bboxes)
+ def testDecodeExampleWithBoundingBoxDense(self):
+ num_bboxes = 10
+ np_ymin = np.random.rand(num_bboxes, 1)
+ np_xmin = np.random.rand(num_bboxes, 1)
+ np_ymax = np.random.rand(num_bboxes, 1)
+ np_xmax = np.random.rand(num_bboxes, 1)
+ np_bboxes = np.hstack([np_ymin, np_xmin, np_ymax, np_xmax])
+
+ example = example_pb2.Example(features=feature_pb2.Features(feature={
+ 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
+ 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
+ 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
+ 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
+ }))
+ serialized_example = example.SerializeToString()
+
+ with self.test_session():
+ serialized_example = array_ops.reshape(serialized_example, shape=[])
+
+ keys_to_features = {
+ 'image/object/bbox/ymin': parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/xmin': parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/ymax': parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/xmax': parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ }
+
+ items_to_handlers = {
+ 'object/bbox':
+ tfexample_decoder.BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'],
+ 'image/object/bbox/'),
+ }
+
+ decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
+ items_to_handlers)
+ [tf_bboxes] = decoder.decode(serialized_example, ['object/bbox'])
+ bboxes = tf_bboxes.eval()
+
+ self.assertAllClose(np_bboxes, bboxes)
+
def testDecodeExampleWithRepeatedImages(self):
image_shape = (2, 3, 3)
image_format = 'png'
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index 2c4bed5db1..da583a2ba0 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -42,6 +42,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":feature_keys",
+ ":head",
":input_pipeline",
":model_utils",
"//tensorflow/python:util",
@@ -78,8 +79,8 @@ py_library(
deps = [
":ar_model",
":feature_keys",
+ ":head",
":math_utils",
- ":model_utils",
":state_management",
"//tensorflow/contrib/timeseries/python/timeseries/state_space_models:filtering_postprocessor",
"//tensorflow/contrib/timeseries/python/timeseries/state_space_models:state_space_model",
@@ -123,9 +124,9 @@ py_test(
)
py_library(
- name = "model_utils",
+ name = "head",
srcs = [
- "model_utils.py",
+ "head.py",
],
srcs_version = "PY2AND3",
deps = [
@@ -149,9 +150,9 @@ py_library(
)
py_test(
- name = "model_utils_test",
+ name = "head_test",
srcs = [
- "model_utils_test.py",
+ "head_test.py",
],
srcs_version = "PY2AND3",
tags = [
@@ -159,8 +160,8 @@ py_test(
],
deps = [
":feature_keys",
+ ":head",
":model",
- ":model_utils",
":state_management",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -175,6 +176,41 @@ py_test(
)
py_library(
+ name = "model_utils",
+ srcs = [
+ "model_utils.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":feature_keys",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:variable_scope",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "model_utils_test",
+ srcs = [
+ "model_utils_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip_gpu", # b/63391119
+ ],
+ deps = [
+ ":model_utils",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:variables",
+ ],
+)
+
+py_library(
name = "state_management",
srcs = [
"state_management.py",
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
index 267a5f88da..ff140efd48 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
@@ -374,7 +374,7 @@ class ARModel(model.TimeSeriesModel):
original_values = values
# Extra shape checking for the window size (above that in
- # model_utils.make_model_fn).
+ # `head.create_estimator_spec`).
expected_times_shape = [None, self.window_size]
if not times.get_shape().is_compatible_with(expected_times_shape):
raise ValueError(
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 4025a8f014..3738dfa154 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -20,8 +20,8 @@ from __future__ import print_function
from tensorflow.contrib.timeseries.python.timeseries import ar_model
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
+from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
from tensorflow.contrib.timeseries.python.timeseries import math_utils
-from tensorflow.contrib.timeseries.python.timeseries import model_utils
from tensorflow.contrib.timeseries.python.timeseries import state_management
from tensorflow.contrib.timeseries.python.timeseries.state_space_models import state_space_model
from tensorflow.contrib.timeseries.python.timeseries.state_space_models import structural_ensemble
@@ -59,9 +59,10 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
if optimizer is None:
optimizer = train.AdamOptimizer(0.02)
self._model = model
- model_fn = model_utils.make_model_fn(
+ ts_regression_head = ts_head_lib.time_series_regression_head(
model, state_manager, optimizer,
input_statistics_generator=input_statistics_generator)
+ model_fn = ts_regression_head.create_estimator_spec
super(TimeSeriesRegressor, self).__init__(
model_fn=model_fn,
model_dir=model_dir,
@@ -132,7 +133,7 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
with ops.Graph().as_default():
self._model.initialize_graph()
model_start_state = self._model.get_start_state()
- for prefixed_state_name, state_tensor in model_utils.state_to_dictionary(
+ for prefixed_state_name, state_tensor in ts_head_lib.state_to_dictionary(
model_start_state).items():
state_shape_with_batch = tensor_shape.TensorShape(
(default_batch_size,)).concatenate(state_tensor.get_shape())
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
new file mode 100644
index 0000000000..5896fc2a20
--- /dev/null
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -0,0 +1,375 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Timeseries head."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from tensorflow.contrib.framework.python.ops import variables
+from tensorflow.contrib.layers.python.layers import optimizers
+
+from tensorflow.contrib.timeseries.python.timeseries import feature_keys
+
+from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.estimator.export import export_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import nest
+
+
+def time_series_regression_head(model,
+ state_manager,
+ optimizer,
+ input_statistics_generator=None):
+ """Creates a `_Head` for time series regression.
+
+ Args:
+ model: A model for time series regression.
+ state_manager: A state manager.
+ optimizer: An optimizer.
+ input_statistics_generator: A input statistics generator.
+
+ Returns:
+ An instance of `_Head` for time series regression.
+ """
+ return _TimeSeriesRegressionHead(model, state_manager, optimizer,
+ input_statistics_generator)
+
+
+class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access
+ """See `time_series_regression_head`."""
+
+ def __init__(self,
+ model,
+ state_manager,
+ optimizer,
+ input_statistics_generator=None,
+ name=None):
+ self.model = model
+ self.state_manager = state_manager
+ self.optimizer = optimizer
+ self.input_statistics_generator = input_statistics_generator
+ self._name = name
+
+ def _train_ops(self, features):
+ """Add training ops to the graph."""
+ with variable_scope.variable_scope("model"):
+ model_outputs = self.state_manager.define_loss(
+ self.model, features, estimator_lib.ModeKeys.TRAIN)
+
+ train_op = optimizers.optimize_loss(
+ model_outputs.loss,
+ global_step=variables.get_global_step(),
+ optimizer=self.optimizer,
+ # Learning rate is set in the Optimizer object
+ learning_rate=None)
+ return estimator_lib.EstimatorSpec(
+ loss=model_outputs.loss,
+ mode=estimator_lib.ModeKeys.TRAIN,
+ train_op=train_op)
+
+ # TODO(terrytangyuan): suffix summary and metrics keys by `"/" + name`
+ @property
+ def name(self):
+ return self._name
+
+ # TODO(terrytangyuan): unused for now. Need to decouple
+ # `state_manager.define_loss` to satisfy the extendable return signature of
+ # `_Head.create_loss`.
+ def create_loss(self, features, mode, logits, labels):
+ """See `_Head`."""
+ return None
+
+ # TODO(terrytangyuan): check label dimension
+ @property
+ def logits_dimension(self):
+ return None
+
+ def _evaluate_ops(self, features):
+ """Add ops for evaluation (aka filtering) to the graph."""
+ with variable_scope.variable_scope("model"):
+ model_outputs = self.state_manager.define_loss(
+ self.model, features, estimator_lib.ModeKeys.EVAL)
+ metrics = {}
+ # Just output in-sample predictions for the last chunk seen
+ for prediction_key, prediction_value in model_outputs.predictions.items():
+ metrics[prediction_key] = _identity_metric_single(prediction_key,
+ prediction_value)
+ metrics[feature_keys.FilteringResults.TIMES] = _identity_metric_single(
+ feature_keys.FilteringResults.TIMES, model_outputs.prediction_times)
+ metrics[feature_keys.FilteringResults.STATE_TUPLE] = (
+ _identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE,
+ model_outputs.end_state))
+ return estimator_lib.EstimatorSpec(
+ loss=model_outputs.loss,
+ mode=estimator_lib.ModeKeys.EVAL,
+ eval_metric_ops=metrics,
+ predictions={})
+
+ def _predict_ops(self, features):
+ """Add ops for prediction to the graph."""
+ with variable_scope.variable_scope("model"):
+ prediction = self.model.predict(features=features)
+ prediction[feature_keys.PredictionResults.TIMES] = features[
+ feature_keys.PredictionFeatures.TIMES]
+ return estimator_lib.EstimatorSpec(
+ predictions=prediction, mode=estimator_lib.ModeKeys.PREDICT)
+
+ def _serving_ops(self, features):
+ """Add ops for serving to the graph."""
+ with variable_scope.variable_scope("model"):
+ prediction_outputs = self.model.predict(features=features)
+ with variable_scope.variable_scope("model", reuse=True):
+ filtering_outputs = self.state_manager.define_loss(
+ self.model, features, estimator_lib.ModeKeys.EVAL)
+
+ return estimator_lib.EstimatorSpec(
+ mode=estimator_lib.ModeKeys.PREDICT,
+ export_outputs={
+ feature_keys.SavedModelLabels.PREDICT:
+ export_lib.PredictOutput(prediction_outputs),
+ feature_keys.SavedModelLabels.FILTER:
+ export_lib.PredictOutput(
+ state_to_dictionary(filtering_outputs.end_state))
+ },
+ # Likely unused, but it is necessary to return `predictions` to satisfy
+ # the Estimator's error checking.
+ predictions={})
+
+ def _convert_feature_to_tensor(self, name, value):
+ """Casts features to the correct dtype based on their name."""
+ if name in [
+ feature_keys.TrainEvalFeatures.TIMES,
+ feature_keys.PredictionFeatures.TIMES
+ ]:
+ return math_ops.cast(value, dtypes.int64)
+ if name == feature_keys.TrainEvalFeatures.VALUES:
+ return math_ops.cast(value, self.model.dtype)
+ if name == feature_keys.PredictionFeatures.STATE_TUPLE:
+ return value # Correct dtypes are model-dependent
+ return ops.convert_to_tensor(value)
+
+ def _gather_state(self, features):
+ """Returns `features` with state packed, indicates if packing was done."""
+ prefixed_state_re = re.compile(r"^" + feature_keys.State.STATE_PREFIX +
+ r"_(\d+)$")
+ numbered_state = []
+ for key, tensor in features.items():
+ search_result = prefixed_state_re.search(key)
+ if search_result:
+ numbered_state.append((int(search_result.group(1)), key, tensor))
+ if not numbered_state:
+ return features, False
+ features = features.copy()
+ for _, key, _ in numbered_state:
+ del features[key]
+ numbered_state.sort(key=lambda number, *_: number)
+ features[feature_keys.State.STATE_TUPLE] = nest.pack_sequence_as(
+ structure=self.model.get_start_state(),
+ flat_sequence=[tensor for _, _, tensor in numbered_state])
+ return features, True
+
+ def create_estimator_spec(self, features, mode, labels=None):
+ """Performs basic error checking and returns an EstimatorSpec."""
+ with ops.name_scope("head"):
+ if labels:
+ raise ValueError(
+ "The model received a `labels` dictionary, which is "
+ "not supported. Pass '{}' and '{}' as "
+ "features.".format(feature_keys.TrainEvalFeatures.TIMES,
+ feature_keys.TrainEvalFeatures.VALUES))
+ del labels
+ features = {
+ name: self._convert_feature_to_tensor(name=name, value=value)
+ for name, value in features.items()
+ }
+ if self.input_statistics_generator is not None:
+ input_statistics = self.input_statistics_generator.initialize_graph(
+ features, update_statistics=(mode == estimator_lib.ModeKeys.TRAIN))
+ else:
+ input_statistics = None
+ self.model.initialize_graph(input_statistics=input_statistics)
+
+ # _gather_state requires the model to have its graph initialized (so it
+ # has access to the structure of the model's state)
+ features, passed_flat_state = self._gather_state(features)
+ if (mode == estimator_lib.ModeKeys.TRAIN or
+ mode == estimator_lib.ModeKeys.EVAL):
+ _check_train_eval_features(features, self.model)
+ elif mode == estimator_lib.ModeKeys.PREDICT:
+ _check_predict_features(features)
+ else:
+ raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode))
+
+ self.state_manager.initialize_graph(
+ model=self.model, input_statistics=input_statistics)
+
+ if mode == estimator_lib.ModeKeys.TRAIN:
+ return self._train_ops(features)
+ elif mode == estimator_lib.ModeKeys.EVAL:
+ return self._evaluate_ops(features)
+ elif mode == estimator_lib.ModeKeys.PREDICT and not passed_flat_state:
+ return self._predict_ops(features)
+ elif mode == estimator_lib.ModeKeys.PREDICT and passed_flat_state:
+ # The mode is PREDICT, but we're actually in export_savedmodel for
+ # serving. We want to return two graphs: one for filtering (state + data
+ # -> state) and one for predicting (state -> prediction).
+ return self._serving_ops(features)
+
+
+def _check_feature_shapes_compatible_with(features,
+ compatible_with_name,
+ compatible_with_value,
+ ignore=None):
+ """Checks all features are compatible with the given time-like feature."""
+ if ignore is None:
+ ignore = set()
+ for name, value in features.items():
+ if name in ignore:
+ continue
+ feature_shape = value.get_shape()
+ if feature_shape.ndims is None:
+ continue
+ if feature_shape.ndims < 2:
+ raise ValueError(
+ ("Features must have shape (batch dimension, window size, ...) "
+ "(got rank {} for feature '{}')").format(feature_shape.ndims, name))
+ if not feature_shape[:2].is_compatible_with(
+ compatible_with_value.get_shape()):
+ raise ValueError(
+ ("Features must have shape (batch dimension, window size, ...) "
+ "where batch dimension and window size match the "
+ "'{times_feature}' feature (got shape {feature_shape} for "
+ "feature '{feature_name}' but shape {times_shape} for feature "
+ "'{times_feature}')").format(
+ times_feature=compatible_with_name,
+ feature_shape=feature_shape,
+ feature_name=name,
+ times_shape=compatible_with_value.get_shape()))
+
+
+def _check_predict_features(features):
+ """Raises errors if features are not suitable for prediction."""
+ if feature_keys.PredictionFeatures.TIMES not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.TIMES))
+ if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.STATE_TUPLE))
+ times_feature = features[feature_keys.PredictionFeatures.TIMES]
+ if not times_feature.get_shape().is_compatible_with([None, None]):
+ raise ValueError(
+ ("Expected shape (batch dimension, window size) for feature '{}' "
+ "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
+ times_feature.get_shape()))
+ _check_feature_shapes_compatible_with(
+ features=features,
+ compatible_with_name=feature_keys.PredictionFeatures.TIMES,
+ compatible_with_value=times_feature,
+ ignore=set([
+ feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes
+ ]))
+
+
+def _check_train_eval_features(features, model):
+ """Raise errors if features are not suitable for training/evaluation."""
+ if feature_keys.TrainEvalFeatures.TIMES not in features:
+ raise ValueError("Expected a '{}' feature for training/evaluation.".format(
+ feature_keys.TrainEvalFeatures.TIMES))
+ if feature_keys.TrainEvalFeatures.VALUES not in features:
+ raise ValueError("Expected a '{}' feature for training/evaluation.".format(
+ feature_keys.TrainEvalFeatures.VALUES))
+ times_feature = features[feature_keys.TrainEvalFeatures.TIMES]
+ if not times_feature.get_shape().is_compatible_with([None, None]):
+ raise ValueError(
+ ("Expected shape (batch dimension, window size) for feature '{}' "
+ "(got shape {})").format(feature_keys.TrainEvalFeatures.TIMES,
+ times_feature.get_shape()))
+ values_feature = features[feature_keys.TrainEvalFeatures.VALUES]
+ if not values_feature.get_shape().is_compatible_with(
+ [None, None, model.num_features]):
+ raise ValueError(
+ ("Expected shape (batch dimension, window size, {num_features}) "
+ "for feature '{feature_name}', since the model was configured "
+ "with num_features={num_features} (got shape {got_shape})").format(
+ num_features=model.num_features,
+ feature_name=feature_keys.TrainEvalFeatures.VALUES,
+ got_shape=times_feature.get_shape()))
+ _check_feature_shapes_compatible_with(
+ features=features,
+ compatible_with_name=feature_keys.TrainEvalFeatures.TIMES,
+ compatible_with_value=times_feature,
+ ignore=set([
+ feature_keys.State.STATE_TUPLE # Model-dependent shapes
+ ]))
+
+
+def _identity_metric_single(name, input_tensor):
+ """A metric which takes on its last updated value.
+
+ This keeps evaluation metrics in sync with one another, since update ops are
+ run separately from their result Tensors. Simply returning (input_tensor,
+ no_op) as a metric with a value but no update means that a metric will come
+ from a different batch of data than metrics which cache values in a Variable
+ (e.g. the default loss metric).
+
+ Args:
+ name: A name for the metric.
+ input_tensor: Any Tensor.
+ Returns:
+ A tuple of (value, update_op).
+ """
+ metric_variable = variable_scope.variable(
+ name="{}_identity_metric".format(name),
+ initial_value=array_ops.zeros([], dtype=input_tensor.dtype),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ validate_shape=False)
+ update_op = state_ops.assign(
+ metric_variable, input_tensor, validate_shape=False)
+ # This shape will be correct once the first update runs (but may be
+ # incomplete, so is not helpful for initializing the variable).
+ metric_variable.set_shape(input_tensor.get_shape())
+ return (metric_variable.value(), update_op)
+
+
+def _identity_metric_nested(name, input_tensors):
+ """Create identity metrics for a nested tuple of Tensors."""
+ update_ops = []
+ value_tensors = []
+ for tensor_number, tensor in enumerate(nest.flatten(input_tensors)):
+ value_tensor, update_op = _identity_metric_single(
+ name="{}_{}".format(name, tensor_number), input_tensor=tensor)
+ update_ops.append(update_op)
+ value_tensors.append(value_tensor)
+ return (nest.pack_sequence_as(input_tensors, value_tensors),
+ control_flow_ops.group(*update_ops))
+
+
+def state_to_dictionary(state_tuple):
+ """Flatten model state into a dictionary with string keys."""
+ flattened = {}
+ for state_number, state_value in enumerate(nest.flatten(state_tuple)):
+ prefixed_state_name = "{}_{:02d}".format(feature_keys.State.STATE_PREFIX,
+ state_number)
+ flattened[prefixed_state_name] = state_value
+ return flattened
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
new file mode 100644
index 0000000000..3415061cfd
--- /dev/null
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -0,0 +1,267 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for head."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.timeseries.python.timeseries import feature_keys
+from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
+from tensorflow.contrib.timeseries.python.timeseries import model
+from tensorflow.contrib.timeseries.python.timeseries import state_management
+
+from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import coordinator as coordinator_lib
+from tensorflow.python.training import queue_runner_impl
+from tensorflow.python.training import training as train
+
+
+class HeadTest(test.TestCase):
+
+ def test_labels_provided_error(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL,
+ estimator_lib.ModeKeys.PREDICT]:
+ with self.assertRaisesRegexp(ValueError, "labels"):
+ model_fn(features={}, labels={"a": "b"}, mode=mode)
+
+ def test_unknown_mode(self):
+ model_fn = _stub_model_fn()
+ with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"):
+ model_fn(features={}, labels={}, mode="Not a mode")
+
+
+class _TickerModel(object):
+ num_features = 1
+ dtype = dtypes.float32
+
+ def initialize_graph(self, input_statistics):
+ pass
+
+ def define_loss(self, features, mode):
+ del mode # unused
+ return model.ModelOutputs(
+ loss=features["ticker"],
+ end_state=(features["ticker"], features["ticker"]),
+ prediction_times=array_ops.zeros(()),
+ predictions={"ticker": features["ticker"]})
+
+
+class EvaluationMetricsTests(test.TestCase):
+
+ def test_metrics_consistent(self):
+ # Tests that the identity metrics used to report in-sample predictions match
+ # the behavior of standard metrics.
+ g = ops.Graph()
+ with g.as_default():
+ features = {
+ feature_keys.TrainEvalFeatures.TIMES:
+ array_ops.zeros((1, 1)),
+ feature_keys.TrainEvalFeatures.VALUES:
+ array_ops.zeros((1, 1, 1)),
+ "ticker":
+ array_ops.reshape(
+ math_ops.cast(
+ variables.Variable(
+ name="ticker",
+ initial_value=0,
+ dtype=dtypes.int64,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ .count_up_to(10),
+ dtype=dtypes.float32), (1, 1, 1))
+ }
+ model_fn = ts_head_lib.time_series_regression_head(
+ model=_TickerModel(),
+ state_manager=state_management.PassthroughStateManager(),
+ optimizer=train.GradientDescentOptimizer(0.001)).create_estimator_spec
+ outputs = model_fn(
+ features=features, labels=None, mode=estimator_lib.ModeKeys.EVAL)
+ metric_update_ops = [
+ metric[1] for metric in outputs.eval_metric_ops.values()]
+ loss_mean, loss_update = metrics.mean(outputs.loss)
+ metric_update_ops.append(loss_update)
+ with self.test_session() as sess:
+ coordinator = coordinator_lib.Coordinator()
+ queue_runner_impl.start_queue_runners(sess, coord=coordinator)
+ variables.local_variables_initializer().run()
+ sess.run(metric_update_ops)
+ loss_evaled, metric_evaled, nested_metric_evaled = sess.run(
+ (loss_mean, outputs.eval_metric_ops["ticker"][0],
+ outputs.eval_metric_ops[feature_keys.FilteringResults.STATE_TUPLE][
+ 0][0]))
+ # The custom model_utils metrics for in-sample predictions should be in
+ # sync with the Estimator's mean metric for model loss.
+ self.assertAllClose(0., loss_evaled)
+ self.assertAllClose((((0.,),),), metric_evaled)
+ self.assertAllClose((((0.,),),), nested_metric_evaled)
+ coordinator.request_stop()
+ coordinator.join()
+
+
+class _StubModel(object):
+ num_features = 3
+ dtype = dtypes.float64
+
+ def initialize_graph(self, input_statistics):
+ del input_statistics # unused
+
+
+def _stub_model_fn():
+ return ts_head_lib.time_series_regression_head(
+ model=_StubModel(),
+ state_manager=state_management.PassthroughStateManager(),
+ optimizer=train.AdamOptimizer(0.001)).create_estimator_spec
+
+
+class TrainEvalFeatureCheckingTests(test.TestCase):
+
+ def test_no_time_feature(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
+ with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
+ feature_keys.TrainEvalFeatures.TIMES)):
+ model_fn(
+ features={feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]},
+ labels=None,
+ mode=mode)
+
+ def test_no_value_feature(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
+ with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
+ feature_keys.TrainEvalFeatures.VALUES)):
+ model_fn(
+ features={feature_keys.TrainEvalFeatures.TIMES: [[1]]},
+ labels=None,
+ mode=mode)
+
+ def test_bad_time_rank(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
+ with self.assertRaisesRegexp(ValueError,
+ "Expected shape.*for feature '{}'".format(
+ feature_keys.TrainEvalFeatures.TIMES)):
+ model_fn(
+ features={
+ feature_keys.TrainEvalFeatures.TIMES: [[[1]]],
+ feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]
+ },
+ labels=None,
+ mode=mode)
+
+ def test_bad_value_rank(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
+ with self.assertRaisesRegexp(ValueError,
+ "Expected shape.*for feature '{}'".format(
+ feature_keys.TrainEvalFeatures.VALUES)):
+ model_fn(
+ features={
+ feature_keys.TrainEvalFeatures.TIMES: [[1]],
+ feature_keys.TrainEvalFeatures.VALUES: [[1.]]
+ },
+ labels=None,
+ mode=mode)
+
+ def test_bad_value_num_features(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
+ with self.assertRaisesRegexp(
+ ValueError, "Expected shape.*, 3.*for feature '{}'".format(
+ feature_keys.TrainEvalFeatures.VALUES)):
+ model_fn(
+ features={
+ feature_keys.TrainEvalFeatures.TIMES: [[1]],
+ feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]
+ },
+ labels=None,
+ mode=mode)
+
+ def test_bad_exogenous_shape(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Features must have shape.*for feature 'exogenous'"):
+ model_fn(
+ features={
+ feature_keys.TrainEvalFeatures.TIMES: [[1]],
+ feature_keys.TrainEvalFeatures.VALUES: [[[1., 2., 3.]]],
+ "exogenous": [[1], [2]]
+ },
+ labels=None,
+ mode=mode)
+
+
+class PredictFeatureCheckingTests(test.TestCase):
+
+ def test_no_time_feature(self):
+ model_fn = _stub_model_fn()
+ with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
+ feature_keys.PredictionFeatures.TIMES)):
+ model_fn(
+ features={
+ feature_keys.PredictionFeatures.STATE_TUPLE: ([[[1.]]], 1.)
+ },
+ labels=None,
+ mode=estimator_lib.ModeKeys.PREDICT)
+
+ def test_no_start_state_feature(self):
+ model_fn = _stub_model_fn()
+ with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
+ feature_keys.PredictionFeatures.STATE_TUPLE)):
+ model_fn(
+ features={feature_keys.PredictionFeatures.TIMES: [[1]]},
+ labels=None,
+ mode=estimator_lib.ModeKeys.PREDICT)
+
+ def test_bad_time_rank(self):
+ model_fn = _stub_model_fn()
+ with self.assertRaisesRegexp(ValueError,
+ "Expected shape.*for feature '{}'".format(
+ feature_keys.PredictionFeatures.TIMES)):
+ model_fn(
+ features={
+ feature_keys.PredictionFeatures.TIMES: 1,
+ feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.))
+ },
+ labels=None,
+ mode=estimator_lib.ModeKeys.PREDICT)
+
+ def test_bad_exogenous_shape(self):
+ model_fn = _stub_model_fn()
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Features must have shape.*for feature 'exogenous'"):
+ model_fn(
+ features={
+ feature_keys.PredictionFeatures.TIMES: [[1]],
+ feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.)),
+ "exogenous": 1.
+ },
+ labels=None,
+ mode=estimator_lib.ModeKeys.PREDICT)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/model_utils.py b/tensorflow/contrib/timeseries/python/timeseries/model_utils.py
index addcdb0575..b5d7cb376b 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/model_utils.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/model_utils.py
@@ -18,334 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import re
-
import numpy
-from tensorflow.contrib.framework.python.ops import variables
-from tensorflow.contrib.layers.python.layers import optimizers
-
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
-from tensorflow.python.estimator import estimator_lib
-from tensorflow.python.estimator.export import export_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
-
-
-def _check_feature_shapes_compatible_with(
- features, compatible_with_name, compatible_with_value, ignore=None):
- """Checks all features are compatible with the given time-like feature."""
- if ignore is None:
- ignore = set()
- for name, value in features.items():
- if name in ignore:
- continue
- feature_shape = value.get_shape()
- if feature_shape.ndims is None:
- continue
- if feature_shape.ndims < 2:
- raise ValueError(
- ("Features must have shape (batch dimension, window size, ...) "
- "(got rank {} for feature '{}')").format(
- feature_shape.ndims, name))
- if not feature_shape[:2].is_compatible_with(
- compatible_with_value.get_shape()):
- raise ValueError(
- ("Features must have shape (batch dimension, window size, ...) "
- "where batch dimension and window size match the "
- "'{times_feature}' feature (got shape {feature_shape} for "
- "feature '{feature_name}' but shape {times_shape} for feature "
- "'{times_feature}')").format(
- times_feature=compatible_with_name,
- feature_shape=feature_shape,
- feature_name=name,
- times_shape=compatible_with_value.get_shape()))
-
-
-def _check_predict_features(features):
- """Raises errors if features are not suitable for prediction."""
- if feature_keys.PredictionFeatures.TIMES not in features:
- raise ValueError("Expected a '{}' feature for prediction.".format(
- feature_keys.PredictionFeatures.TIMES))
- if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
- raise ValueError("Expected a '{}' feature for prediction.".format(
- feature_keys.PredictionFeatures.STATE_TUPLE))
- times_feature = features[feature_keys.PredictionFeatures.TIMES]
- if not times_feature.get_shape().is_compatible_with([None, None]):
- raise ValueError(
- ("Expected shape (batch dimension, window size) for feature '{}' "
- "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
- times_feature.get_shape()))
- _check_feature_shapes_compatible_with(
- features=features,
- compatible_with_name=feature_keys.PredictionFeatures.TIMES,
- compatible_with_value=times_feature,
- ignore=set([
- feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes
- ]))
-
-
-def _check_train_eval_features(features, model):
- """Raise errors if features are not suitable for training/evaluation."""
- if feature_keys.TrainEvalFeatures.TIMES not in features:
- raise ValueError("Expected a '{}' feature for training/evaluation.".format(
- feature_keys.TrainEvalFeatures.TIMES))
- if feature_keys.TrainEvalFeatures.VALUES not in features:
- raise ValueError("Expected a '{}' feature for training/evaluation.".format(
- feature_keys.TrainEvalFeatures.VALUES))
- times_feature = features[feature_keys.TrainEvalFeatures.TIMES]
- if not times_feature.get_shape().is_compatible_with([None, None]):
- raise ValueError(
- ("Expected shape (batch dimension, window size) for feature '{}' "
- "(got shape {})").format(feature_keys.TrainEvalFeatures.TIMES,
- times_feature.get_shape()))
- values_feature = features[feature_keys.TrainEvalFeatures.VALUES]
- if not values_feature.get_shape().is_compatible_with(
- [None, None, model.num_features]):
- raise ValueError(
- ("Expected shape (batch dimension, window size, {num_features}) "
- "for feature '{feature_name}', since the model was configured "
- "with num_features={num_features} (got shape {got_shape})").format(
- num_features=model.num_features,
- feature_name=feature_keys.TrainEvalFeatures.VALUES,
- got_shape=times_feature.get_shape()))
- _check_feature_shapes_compatible_with(
- features=features,
- compatible_with_name=feature_keys.TrainEvalFeatures.TIMES,
- compatible_with_value=times_feature,
- ignore=set([
- feature_keys.State.STATE_TUPLE # Model-dependent shapes
- ]))
-
-
-def _identity_metric_single(name, input_tensor):
- """A metric which takes on its last updated value.
-
- This keeps evaluation metrics in sync with one another, since update ops are
- run separately from their result Tensors. Simply returning (input_tensor,
- no_op) as a metric with a value but no update means that a metric will come
- from a different batch of data than metrics which cache values in a Variable
- (e.g. the default loss metric).
-
- Args:
- name: A name for the metric.
- input_tensor: Any Tensor.
- Returns:
- A tuple of (value, update_op).
- """
- metric_variable = variable_scope.variable(
- name="{}_identity_metric".format(name),
- initial_value=array_ops.zeros([], dtype=input_tensor.dtype),
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- validate_shape=False)
- update_op = state_ops.assign(metric_variable, input_tensor,
- validate_shape=False)
- # This shape will be correct once the first update runs (but may be
- # incomplete, so is not helpful for initializing the variable).
- metric_variable.set_shape(input_tensor.get_shape())
- return (metric_variable.value(), update_op)
-
-
-def _identity_metric_nested(name, input_tensors):
- """Create identity metrics for a nested tuple of Tensors."""
- update_ops = []
- value_tensors = []
- for tensor_number, tensor in enumerate(nest.flatten(input_tensors)):
- value_tensor, update_op = _identity_metric_single(
- name="{}_{}".format(name, tensor_number),
- input_tensor=tensor)
- update_ops.append(update_op)
- value_tensors.append(value_tensor)
- return (nest.pack_sequence_as(input_tensors, value_tensors),
- control_flow_ops.group(*update_ops))
-
-
-def state_to_dictionary(state_tuple):
- """Flatten model state into a dictionary with string keys."""
- flattened = {}
- for state_number, state_value in enumerate(nest.flatten(state_tuple)):
- prefixed_state_name = "{}_{:02d}".format(feature_keys.State.STATE_PREFIX,
- state_number)
- flattened[prefixed_state_name] = state_value
- return flattened
-
-
-def make_model_fn(
- model, state_manager, optimizer, input_statistics_generator=None):
- """Returns a model function suitable for use with a tf.estimator.
-
- Args:
- model: The object (inheriting from Model) to create a function for.
- state_manager: A state manager to wrap the model with (or
- PassthroughStateManager if no state needs to be managed).
- optimizer: An instance of `tf.train.Optimizer` to use for training.
- input_statistics_generator: An InputStatisticsFromMiniBatch object from
- math_utils.py, used for collecting statistics about input data during
- training.
- Returns:
- The model function, suitable for passing to a tf.estimator.Estimator.
- """
-
- def _convert_feature_to_tensor(name, value):
- """Casts features to the correct dtype based on their name."""
- if name in [
- feature_keys.TrainEvalFeatures.TIMES,
- feature_keys.PredictionFeatures.TIMES
- ]:
- return math_ops.cast(value, dtypes.int64)
- if name == feature_keys.TrainEvalFeatures.VALUES:
- return math_ops.cast(value, model.dtype)
- if name == feature_keys.PredictionFeatures.STATE_TUPLE:
- return value # Correct dtypes are model-dependent
- return ops.convert_to_tensor(value)
-
- def _gather_state(features):
- """Returns `features` with state packed, indicates if packing was done."""
- prefixed_state_re = re.compile(r"^" + feature_keys.State.STATE_PREFIX +
- r"_(\d+)$")
- numbered_state = []
- for key, tensor in features.items():
- search_result = prefixed_state_re.search(key)
- if search_result:
- numbered_state.append((int(search_result.group(1)), key, tensor))
- if not numbered_state:
- return features, False
- features = features.copy()
- for _, key, _ in numbered_state:
- del features[key]
- numbered_state.sort(key=lambda number, *_: number)
- features[feature_keys.State.STATE_TUPLE] = nest.pack_sequence_as(
- structure=model.get_start_state(),
- flat_sequence=[tensor for _, _, tensor in numbered_state])
- return features, True
-
- def _train(features):
- """Add training ops to the graph."""
- with variable_scope.variable_scope("model"):
- model_outputs = state_manager.define_loss(model, features,
- estimator_lib.ModeKeys.TRAIN)
- train_op = optimizers.optimize_loss(
- model_outputs.loss,
- global_step=variables.get_global_step(),
- optimizer=optimizer,
- # Learning rate is set in the Optimizer object
- learning_rate=None)
- return estimator_lib.EstimatorSpec(
- loss=model_outputs.loss,
- mode=estimator_lib.ModeKeys.TRAIN,
- train_op=train_op)
-
- def _evaluate(features):
- """Add ops for evaluation (aka filtering) to the graph."""
- with variable_scope.variable_scope("model"):
- model_outputs = state_manager.define_loss(model, features,
- estimator_lib.ModeKeys.EVAL)
- metrics = {}
- # Just output in-sample predictions for the last chunk seen
- for prediction_key, prediction_value in model_outputs.predictions.items():
- metrics[prediction_key] = _identity_metric_single(prediction_key,
- prediction_value)
- metrics[feature_keys.FilteringResults.TIMES] = _identity_metric_single(
- feature_keys.FilteringResults.TIMES, model_outputs.prediction_times)
- metrics[feature_keys.FilteringResults.STATE_TUPLE] = (
- _identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE,
- model_outputs.end_state))
- return estimator_lib.EstimatorSpec(
- loss=model_outputs.loss,
- mode=estimator_lib.ModeKeys.EVAL,
- eval_metric_ops=metrics,
- predictions={})
-
- def _predict(features):
- """Add ops for prediction to the graph."""
- with variable_scope.variable_scope("model"):
- prediction = model.predict(features=features)
- prediction[feature_keys.PredictionResults.TIMES] = features[
- feature_keys.PredictionFeatures.TIMES]
- return estimator_lib.EstimatorSpec(
- predictions=prediction, mode=estimator_lib.ModeKeys.PREDICT)
-
- def _serving(features):
- with variable_scope.variable_scope("model"):
- prediction_outputs = model.predict(features=features)
- with variable_scope.variable_scope("model", reuse=True):
- filtering_outputs = state_manager.define_loss(model, features,
- estimator_lib.ModeKeys.EVAL)
- return estimator_lib.EstimatorSpec(
- mode=estimator_lib.ModeKeys.PREDICT,
- export_outputs={
- feature_keys.SavedModelLabels.PREDICT:
- export_lib.PredictOutput(prediction_outputs),
- feature_keys.SavedModelLabels.FILTER:
- export_lib.PredictOutput(
- state_to_dictionary(filtering_outputs.end_state))
- },
- # Likely unused, but it is necessary to return `predictions` to satisfy
- # the Estimator's error checking.
- predictions={})
-
- def _model_fn(features, labels, mode):
- """Given a time series in `features`, define a loss for `mode`.
-
- Args:
- features: A dictionary, the output of a chunker (typically with keys
- feature_keys.TrainEvalFeatures.TIMES and
- feature_keys.TrainEvalFeatures.VALUES).
- labels: Not used; included for compatibility with tf.learn.
- mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL, INFER).
- Returns:
- A tuple of predictions, a loss Tensor, and a train op.
- Raises:
- ValueError: If the model makes predictions which do not have static shape
- information.
- """
- if labels:
- raise ValueError("The model received a `labels` dictionary, which is not"
- " supported. Pass '{}' and '{}' as features.".format(
- feature_keys.TrainEvalFeatures.TIMES,
- feature_keys.TrainEvalFeatures.VALUES))
- del labels
- features = {name: _convert_feature_to_tensor(name=name, value=value)
- for name, value in features.items()}
- if input_statistics_generator is not None:
- input_statistics = input_statistics_generator.initialize_graph(
- features, update_statistics=(mode == estimator_lib.ModeKeys.TRAIN))
- else:
- input_statistics = None
- model.initialize_graph(input_statistics=input_statistics)
- # _gather_state requires the model to have its graph initialized (so it has
- # access to the structure of the model's state)
- features, passed_flat_state = _gather_state(features)
- if (mode == estimator_lib.ModeKeys.TRAIN
- or mode == estimator_lib.ModeKeys.EVAL):
- _check_train_eval_features(features, model)
- elif mode == estimator_lib.ModeKeys.PREDICT:
- _check_predict_features(features)
- else:
- raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode))
- state_manager.initialize_graph(
- model=model, input_statistics=input_statistics)
- if mode == estimator_lib.ModeKeys.TRAIN:
- return _train(features)
- elif mode == estimator_lib.ModeKeys.EVAL:
- return _evaluate(features)
- elif mode == estimator_lib.ModeKeys.PREDICT and not passed_flat_state:
- return _predict(features)
- elif mode == estimator_lib.ModeKeys.PREDICT and passed_flat_state:
- # The mode is PREDICT, but we're actually in export_savedmodel for
- # serving. We want to return two graphs: one for filtering (state + data
- # -> state) and one for predicting (state -> prediction).
- return _serving(features)
- return _model_fn
# TODO(agarwal): Remove and replace with functionality from tf.slim
diff --git a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
index 2998689554..cfd31cc70d 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
@@ -18,22 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.timeseries.python.timeseries import feature_keys
-from tensorflow.contrib.timeseries.python.timeseries import model
from tensorflow.contrib.timeseries.python.timeseries import model_utils
-from tensorflow.contrib.timeseries.python.timeseries import state_management
-from tensorflow.python.estimator import estimator_lib
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import metrics
-from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.training import coordinator as coordinator_lib
-from tensorflow.python.training import queue_runner_impl
-from tensorflow.python.training import training as train
class ModelUtilsTest(test.TestCase):
@@ -46,230 +34,6 @@ class ModelUtilsTest(test.TestCase):
self.assertEqual(5, getter(parameter))
self.assertEqual(4, getter(overridden_parameter))
- def test_labels_provided_error(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL,
- estimator_lib.ModeKeys.PREDICT]:
- with self.assertRaisesRegexp(ValueError, "labels"):
- model_fn(features={}, labels={"a": "b"}, mode=mode)
-
- def test_unknown_mode(self):
- model_fn = _stub_model_fn()
- with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"):
- model_fn(features={}, labels={}, mode="Not a mode")
-
-
-class _TickerModel(object):
- num_features = 1
- dtype = dtypes.float32
-
- def initialize_graph(self, input_statistics):
- pass
-
- def define_loss(self, features, mode):
- del mode # unused
- return model.ModelOutputs(
- loss=features["ticker"],
- end_state=(features["ticker"], features["ticker"]),
- prediction_times=array_ops.zeros(()),
- predictions={"ticker": features["ticker"]})
-
-
-class EvaluationMetricsTests(test.TestCase):
-
- def test_metrics_consistent(self):
- # Tests that the identity metrics used to report in-sample predictions match
- # the behavior of standard metrics.
- g = ops.Graph()
- with g.as_default():
- features = {
- feature_keys.TrainEvalFeatures.TIMES:
- array_ops.zeros((1, 1)),
- feature_keys.TrainEvalFeatures.VALUES:
- array_ops.zeros((1, 1, 1)),
- "ticker":
- array_ops.reshape(
- math_ops.cast(
- variables.Variable(
- name="ticker",
- initial_value=0,
- dtype=dtypes.int64,
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
- .count_up_to(10),
- dtype=dtypes.float32), (1, 1, 1))
- }
- model_fn = model_utils.make_model_fn(
- model=_TickerModel(),
- state_manager=state_management.PassthroughStateManager(),
- optimizer=train.GradientDescentOptimizer(0.001))
- outputs = model_fn(
- features=features, labels=None, mode=estimator_lib.ModeKeys.EVAL)
- metric_update_ops = [
- metric[1] for metric in outputs.eval_metric_ops.values()]
- loss_mean, loss_update = metrics.mean(outputs.loss)
- metric_update_ops.append(loss_update)
- with self.test_session() as sess:
- coordinator = coordinator_lib.Coordinator()
- queue_runner_impl.start_queue_runners(sess, coord=coordinator)
- variables.local_variables_initializer().run()
- sess.run(metric_update_ops)
- loss_evaled, metric_evaled, nested_metric_evaled = sess.run(
- (loss_mean, outputs.eval_metric_ops["ticker"][0],
- outputs.eval_metric_ops[feature_keys.FilteringResults.STATE_TUPLE][
- 0][0]))
- # The custom model_utils metrics for in-sample predictions should be in
- # sync with the Estimator's mean metric for model loss.
- self.assertAllClose(0., loss_evaled)
- self.assertAllClose((((0.,),),), metric_evaled)
- self.assertAllClose((((0.,),),), nested_metric_evaled)
- coordinator.request_stop()
- coordinator.join()
-
-
-class _StubModel(object):
- num_features = 3
- dtype = dtypes.float64
-
- def initialize_graph(self, input_statistics):
- del input_statistics # unused
-
-
-def _stub_model_fn():
- return model_utils.make_model_fn(
- model=_StubModel(),
- state_manager=state_management.PassthroughStateManager(),
- optimizer=train.AdamOptimizer(0.001))
-
-
-class TrainEvalFeatureCheckingTests(test.TestCase):
-
- def test_no_time_feature(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
- with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
- feature_keys.TrainEvalFeatures.TIMES)):
- model_fn(
- features={feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]},
- labels=None,
- mode=mode)
-
- def test_no_value_feature(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
- with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
- feature_keys.TrainEvalFeatures.VALUES)):
- model_fn(
- features={feature_keys.TrainEvalFeatures.TIMES: [[1]]},
- labels=None,
- mode=mode)
-
- def test_bad_time_rank(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
- with self.assertRaisesRegexp(ValueError,
- "Expected shape.*for feature '{}'".format(
- feature_keys.TrainEvalFeatures.TIMES)):
- model_fn(
- features={
- feature_keys.TrainEvalFeatures.TIMES: [[[1]]],
- feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]
- },
- labels=None,
- mode=mode)
-
- def test_bad_value_rank(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
- with self.assertRaisesRegexp(ValueError,
- "Expected shape.*for feature '{}'".format(
- feature_keys.TrainEvalFeatures.VALUES)):
- model_fn(
- features={
- feature_keys.TrainEvalFeatures.TIMES: [[1]],
- feature_keys.TrainEvalFeatures.VALUES: [[1.]]
- },
- labels=None,
- mode=mode)
-
- def test_bad_value_num_features(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
- with self.assertRaisesRegexp(
- ValueError, "Expected shape.*, 3.*for feature '{}'".format(
- feature_keys.TrainEvalFeatures.VALUES)):
- model_fn(
- features={
- feature_keys.TrainEvalFeatures.TIMES: [[1]],
- feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]
- },
- labels=None,
- mode=mode)
-
- def test_bad_exogenous_shape(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
- with self.assertRaisesRegexp(
- ValueError,
- "Features must have shape.*for feature 'exogenous'"):
- model_fn(
- features={
- feature_keys.TrainEvalFeatures.TIMES: [[1]],
- feature_keys.TrainEvalFeatures.VALUES: [[[1., 2., 3.]]],
- "exogenous": [[1], [2]]
- },
- labels=None,
- mode=mode)
-
-
-class PredictFeatureCheckingTests(test.TestCase):
-
- def test_no_time_feature(self):
- model_fn = _stub_model_fn()
- with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
- feature_keys.PredictionFeatures.TIMES)):
- model_fn(
- features={
- feature_keys.PredictionFeatures.STATE_TUPLE: ([[[1.]]], 1.)
- },
- labels=None,
- mode=estimator_lib.ModeKeys.PREDICT)
-
- def test_no_start_state_feature(self):
- model_fn = _stub_model_fn()
- with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
- feature_keys.PredictionFeatures.STATE_TUPLE)):
- model_fn(
- features={feature_keys.PredictionFeatures.TIMES: [[1]]},
- labels=None,
- mode=estimator_lib.ModeKeys.PREDICT)
-
- def test_bad_time_rank(self):
- model_fn = _stub_model_fn()
- with self.assertRaisesRegexp(ValueError,
- "Expected shape.*for feature '{}'".format(
- feature_keys.PredictionFeatures.TIMES)):
- model_fn(
- features={
- feature_keys.PredictionFeatures.TIMES: 1,
- feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.))
- },
- labels=None,
- mode=estimator_lib.ModeKeys.PREDICT)
-
- def test_bad_exogenous_shape(self):
- model_fn = _stub_model_fn()
- with self.assertRaisesRegexp(
- ValueError,
- "Features must have shape.*for feature 'exogenous'"):
- model_fn(
- features={
- feature_keys.PredictionFeatures.TIMES: [[1]],
- feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.)),
- "exogenous": 1.
- },
- labels=None,
- mode=estimator_lib.ModeKeys.PREDICT)
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py b/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py
index 16e29f5e68..97f6d36a87 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py
@@ -23,6 +23,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.timeseries.python.timeseries import feature_keys as _feature_keys
+from tensorflow.contrib.timeseries.python.timeseries import head as _head
from tensorflow.contrib.timeseries.python.timeseries import input_pipeline as _input_pipeline
from tensorflow.contrib.timeseries.python.timeseries import model_utils as _model_utils
@@ -34,7 +35,7 @@ def _colate_features_to_feeds_and_fetches(continue_from, signature, features,
"""Uses a saved model signature to construct feed and fetch dictionaries."""
if _feature_keys.FilteringResults.STATE_TUPLE in continue_from:
# We're continuing from an evaluation, so we need to unpack/flatten state.
- state_values = _model_utils.state_to_dictionary(
+ state_values = _head.state_to_dictionary(
continue_from[_feature_keys.FilteringResults.STATE_TUPLE])
else:
state_values = continue_from
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index eb66d8e329..f3e43dd552 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1773,6 +1773,7 @@ tf_cuda_library(
) + if_mkl(
[
"//third_party/mkl:intel_binary_blob",
+ "@mkl_dnn//:mkl_dnn",
],
),
alwayslink = 1,
@@ -1933,7 +1934,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/visitable_allocator.h",
"graph/gradients.h",
"graph/quantize_training.h",
-]
+] + if_mkl(["graph/mkl_graph_util.h"])
tf_cuda_library(
name = "core_cpu_impl",
@@ -2034,7 +2035,10 @@ tf_cuda_library(
"//third_party/eigen3",
"//tensorflow/core/kernels:required",
] + if_mkl(
- ["//third_party/mkl:intel_binary_blob"],
+ [
+ "//third_party/mkl:intel_binary_blob",
+ "@mkl_dnn//:mkl_dnn",
+ ],
) + tf_additional_core_deps() + if_static([":core_cpu_impl"]),
alwayslink = 1,
)
@@ -2670,7 +2674,7 @@ tf_cc_test_mkl(
"graph/mkl_layout_pass_test.cc",
"graph/mkl_tfconversion_pass_test.cc",
],
- linkstatic = tf_kernel_tests_linkstatic(),
+ linkstatic = 1,
deps = [
":core",
":core_cpu",
@@ -2688,18 +2692,6 @@ tf_cc_test_mkl(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/core/kernels:mkl_aggregate_ops",
- "//tensorflow/core/kernels:mkl_concat_op",
- "//tensorflow/core/kernels:mkl_conv_op",
- "//tensorflow/core/kernels:mkl_cwise_ops_common",
- "//tensorflow/core/kernels:mkl_fused_batch_norm_op",
- "//tensorflow/core/kernels:mkl_identity_op",
- "//tensorflow/core/kernels:mkl_input_conversion_op",
- "//tensorflow/core/kernels:mkl_lrn_op",
- "//tensorflow/core/kernels:mkl_pooling_ops",
- "//tensorflow/core/kernels:mkl_relu_op",
- "//tensorflow/core/kernels:mkl_reshape_op",
- "//tensorflow/core/kernels:mkl_tfconv_op",
"//tensorflow/core/kernels:ops_util",
"//third_party/eigen3",
],
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
new file mode 100644
index 0000000000..cb32d64334
--- /dev/null
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -0,0 +1,128 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
+#define TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
+#ifdef INTEL_MKL
+
+#include <string>
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+// Since our ops are going to produce and also consume N addition tensors
+// (Mkl) for N Tensorflow tensors, we can have following different
+// orderings among these 2N tensors.
+//
+// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
+// consume A_m, B_m, and C_m additionally.
+//
+// INTERLEAVED: in this case 2N tensors are interleaved. So for above
+// example, the ordering looks like: A, A_m, B, B_m, C, C_m.
+//
+// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed
+// by N Mkl tensors. So for above example, the ordering looks
+// like: A, B, C, A_m, B_m, C_m
+//
+// Following APIs map index of original Tensorflow tensors to their
+// appropriate position based on selected ordering. For contiguous ordering,
+// we need to know the total number of tensors (parameter total).
+//
+typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
+// NOTE: Currently, we use contiguous ordering. If you change this, then you
+// would need to change Mkl op definitions in nn_ops.cc.
+static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
+
+// Get index of MetaData tensor from index 'n' of Data tensor.
+inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
+ if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+ // For interleaved ordering, Mkl tensor follows immediately after
+ // Tensorflow tensor.
+ return n + 1;
+ } else {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
+ return n + total_tensors / 2;
+ }
+}
+
+int inline GetTensorDataIndex(int n, int total_tensors) {
+ if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+ return 2 * n; // index corresponding to nth input/output tensor
+ } else {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ return n;
+ }
+}
+
+int inline GetTensorMetaDataIndex(int n, int total_tensors) {
+ // Get index for TensorData first and then use mapping function
+ // to get TensorMetaData index from TensorData index.
+ int tidx = GetTensorDataIndex(n, total_tensors);
+ return DataIndexToMetaDataIndex(tidx, total_tensors);
+}
+
+namespace mkl_op_registry {
+static const char* kMklOpLabel = "MklOp";
+static const char* kMklOpLabelPattern = "label='MklOp'";
+
+// Get the name of Mkl op from original TensorFlow op
+// We prefix 'Mkl' to the original op to get Mkl op.
+inline string GetMklOpName(const string& name) {
+ // Prefix that we add to Tensorflow op name to construct Mkl op name.
+ const char* const kMklOpPrefix = "_Mkl";
+ return string(kMklOpPrefix) + name;
+}
+
+// Check whether opname with type T is registered as MKL-compliant.
+//
+// @input: name of the op
+// @input: T datatype to be used for checking op
+// @return: true if opname is registered as Mkl op; false otherwise
+static inline bool IsMklOp(const std::string& op_name, DataType T) {
+ string kernel = KernelsRegisteredForOp(op_name);
+ bool result =
+ kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
+ if (result) {
+ VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel;
+ }
+ return result;
+}
+
+// Check whether opname with type T is registered as MKL-compliant and
+// is element-wise.
+//
+// @input: name of the op
+// @input: T datatype to be used for checking op
+// @return: true if opname is registered as element-wise Mkl op;
+// false otherwise
+static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) {
+ if (!IsMklOp(op_name, T)) {
+ return false;
+ }
+
+ bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
+ 0 == op_name.compare(GetMklOpName("Sub")) ||
+ 0 == op_name.compare(GetMklOpName("Mul")) ||
+ 0 == op_name.compare(GetMklOpName("Maximum")) ||
+ 0 == op_name.compare(GetMklOpName("SquaredDifference")));
+
+ VLOG(1) << "mkl_op_registry::" << op_name
+ << " is elementwise MKL op: " << result;
+ return result;
+}
+} // namespace mkl_op_registry
+} // namespace tensorflow
+#endif // INTEL_MKL
+#endif // TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 90377e54c7..f87a94a76a 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -37,8 +37,8 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/mkl_layout_pass.h"
-#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 6a41e3965a..a2b2f6530d 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#ifdef INTEL_MKL
#include "tensorflow/core/graph/mkl_layout_pass.h"
-#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
#include <algorithm>
#include <string>
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index 3f8b0e86d0..fe4588389e 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -33,8 +33,8 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
-#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index b01818f746..bbdbe78bbd 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#ifdef INTEL_MKL
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
-#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
#include <algorithm>
#include <string>
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 36fbf6b023..bdc6faefbc 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -820,6 +820,7 @@ tf_kernel_library(
hdrs = ["transpose_op.h"],
deps = ARRAY_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
+ "@mkl_dnn//:mkl_dnn",
]),
)
@@ -2596,6 +2597,7 @@ tf_kernel_library(
"//conditions:default": [],
}) + if_mkl([
"//third_party/mkl:intel_binary_blob",
+ "@mkl_dnn//:mkl_dnn",
]) + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]),
@@ -5501,8 +5503,10 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
+ ] + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
@@ -5516,8 +5520,10 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
+ ] + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
@@ -5566,16 +5572,19 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
+ ] + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
name = "mkl_fused_batch_norm_op",
srcs = ["mkl_fused_batch_norm_op.cc"],
- deps = NN_DEPS + [
+ deps = NN_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
@@ -5589,9 +5598,10 @@ tf_mkl_kernel_library(
tf_mkl_kernel_library(
name = "mkl_concat_op",
prefix = "mkl_concat_op",
- deps = ARRAY_DEPS + [
+ deps = ARRAY_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
@@ -5605,17 +5615,19 @@ tf_mkl_kernel_library(
tf_mkl_kernel_library(
name = "mkl_identity_op",
prefix = "mkl_identity_op",
- deps = ARRAY_DEPS + [
+ deps = ARRAY_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
name = "mkl_lrn_op",
prefix = "mkl_lrn_op",
- deps = NN_DEPS + [
+ deps = NN_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 1bdfafb89b..368993c827 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -39,6 +39,48 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
+namespace {
+
+void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
+ int32* batch, int32* height, int32* width,
+ int32* channel) {
+ *batch = 1;
+ *width = 1;
+ *height = 1;
+ *channel = 1;
+ if (data_format == FORMAT_NHWC) {
+ int32 channel_dim = value_tensor.dims() - 1;
+ *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
+ for (int32 i = 0; i < channel_dim; i++) {
+ *batch *= static_cast<int32>(value_tensor.dim_size(i));
+ }
+ } else if (data_format == FORMAT_NCHW) {
+ int32 channel_dim = value_tensor.dims() - 3;
+ int32 height_dim = value_tensor.dims() - 2;
+ int32 width_dim = value_tensor.dims() - 1;
+ *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
+ *height = static_cast<int32>(value_tensor.dim_size(height_dim));
+ *width = static_cast<int32>(value_tensor.dim_size(width_dim));
+ for (int32 i = 0; i < channel_dim; i++) {
+ *batch *= static_cast<int32>(value_tensor.dim_size(i));
+ }
+ }
+}
+
+template <class T>
+struct AccumulatorType {
+ typedef T type;
+};
+
+// float is faster on the CPU than half, and also more precise,
+// so use float for the temporary accumulators.
+template <>
+struct AccumulatorType<Eigen::half> {
+ typedef float type;
+};
+
+} // namespace
+
template <typename Device, typename T>
class BiasOp : public BinaryOp<T> {
public:
@@ -50,9 +92,6 @@ class BiasOp : public BinaryOp<T> {
} else {
data_format_ = FORMAT_NHWC;
}
- OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
- errors::InvalidArgument(context->device()->name() +
- " BiasOp only supports NHWC."));
}
void Compute(OpKernelContext* context) override {
@@ -65,9 +104,21 @@ class BiasOp : public BinaryOp<T> {
OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
errors::InvalidArgument("Biases must be 1D: ",
bias.shape().DebugString()));
- const auto last_dim = input.shape().dims() - 1;
+
+ // Added by intel_tf to support NCHW on CPU regardless of MKL used or not.
+ size_t channel_dim;
+ if (data_format_ == FORMAT_NCHW) {
+ OP_REQUIRES(context, input.dims() == 4,
+ errors::InvalidArgument(
+ "NCHW format supports only 4D input tensor."));
+ channel_dim = 1;
+ } else {
+ channel_dim = input.shape().dims() - 1; // End of code by intel_tf.
+ }
+
OP_REQUIRES(
- context, bias.shape().dim_size(0) == input.shape().dim_size(last_dim),
+ context,
+ bias.shape().dim_size(0) == input.shape().dim_size(channel_dim),
errors::InvalidArgument(
"Must provide as many biases as the last dimension "
"of the input tensor: ",
@@ -78,6 +129,19 @@ class BiasOp : public BinaryOp<T> {
{0}, 0, input.shape(), &output));
if (input.NumElements() == 0) return;
+ // Added by intel_tf to support NCHW on CPU regardless of MKL used or not.
+ if (data_format_ == FORMAT_NCHW) {
+ int32 batch, height, width, channel;
+ GetBiasValueDims(input, data_format_, &batch, &height, &width, &channel);
+ Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1);
+ Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width);
+ const Device& d = context->eigen_device<Device>();
+ output->tensor<T, 4>().device(d) =
+ input.tensor<T, 4>() +
+ bias.tensor<T, 1>().reshape(four_dims).broadcast(broad_cast_dims);
+ return;
+ } // End of code by intel_tf.
+
switch (input.shape().dims()) {
case 2:
Compute<2>(context, input, bias, output);
@@ -137,48 +201,6 @@ REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
#endif // TENSORFLOW_USE_SYCL
-namespace {
-
-void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
- int32* batch, int32* height, int32* width,
- int32* channel) {
- *batch = 1;
- *width = 1;
- *height = 1;
- *channel = 1;
- if (data_format == FORMAT_NHWC) {
- int32 channel_dim = value_tensor.dims() - 1;
- *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
- for (int32 i = 0; i < channel_dim; i++) {
- *batch *= static_cast<int32>(value_tensor.dim_size(i));
- }
- } else if (data_format == FORMAT_NCHW) {
- int32 channel_dim = value_tensor.dims() - 3;
- int32 height_dim = value_tensor.dims() - 2;
- int32 width_dim = value_tensor.dims() - 1;
- *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
- *height = static_cast<int32>(value_tensor.dim_size(height_dim));
- *width = static_cast<int32>(value_tensor.dim_size(width_dim));
- for (int32 i = 0; i < channel_dim; i++) {
- *batch *= static_cast<int32>(value_tensor.dim_size(i));
- }
- }
-}
-
-template <class T>
-struct AccumulatorType {
- typedef T type;
-};
-
-// float is faster on the CPU than half, and also more precise,
-// so use float for the temporary accumulators.
-template <>
-struct AccumulatorType<Eigen::half> {
- typedef float type;
-};
-
-} // namespace
-
template <typename Device, typename T>
class BiasGradOp : public OpKernel {
public:
@@ -190,9 +212,6 @@ class BiasGradOp : public OpKernel {
} else {
data_format_ = FORMAT_NHWC;
}
- OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
- errors::InvalidArgument(context->device()->name() +
- " BiasGradOp only supports NHWC."));
}
void Compute(OpKernelContext* context) override {
@@ -222,18 +241,40 @@ class BiasGradOp : public OpKernel {
// Eigen often crashes by design on empty tensors, but setZero is safe
output->template flat<T>().setZero();
} else {
- Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
+ // Added by intel_tf to support NCHW on CPU regardless of MKL used or not.
+ if (data_format_ == FORMAT_NCHW) {
+ OP_REQUIRES(context, output_backprop.dims() == 4,
+ errors::InvalidArgument(
+ "NCHW format supports only 4D input/output tensor."));
+ Eigen::DSizes<int, 4> four_dims(batch, channel, height, width);
+#ifdef EIGEN_HAS_INDEX_LIST
+ using idx0 = Eigen::type2index<0>;
+ using idx2 = Eigen::type2index<2>;
+ using idx3 = Eigen::type2index<3>;
+ Eigen::IndexList<idx0, idx2, idx3> reduction_axes;
+#else
+ Eigen::array<int, 3> reduction_axes = {0, 2, 3};
+#endif
+ output->template flat<T>().device(context->eigen_device<Device>()) =
+ output_backprop.flat<T>()
+ .template cast<typename AccumulatorType<T>::type>()
+ .reshape(four_dims)
+ .sum(reduction_axes)
+ .template cast<T>(); // End of code by intel_tf.
+ } else {
+ Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
#ifdef EIGEN_HAS_INDEX_LIST
- Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
+ Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
#else
- Eigen::array<int, 1> reduction_axis = {0};
+ Eigen::array<int, 1> reduction_axis = {0};
#endif
- output->template flat<T>().device(context->eigen_device<Device>()) =
- output_backprop.flat<T>()
- .template cast<typename AccumulatorType<T>::type>()
- .reshape(two_dims)
- .sum(reduction_axis)
- .template cast<T>();
+ output->template flat<T>().device(context->eigen_device<Device>()) =
+ output_backprop.flat<T>()
+ .template cast<typename AccumulatorType<T>::type>()
+ .reshape(two_dims)
+ .sum(reduction_axis)
+ .template cast<T>();
+ }
}
}
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 641077ca65..5e09963d2d 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -816,40 +816,35 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- CudnnScratchAllocator scratch_allocator(
- ConvolveBackwardFilterScratchSize, ctx);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveBackwardFilterWithAlgorithm(
- input_desc, input_ptr, output_desc, out_backprop_ptr,
- conv_desc, filter_desc, &filter_backprop_ptr,
- &scratch_allocator, AlgorithmConfig(profile_algorithm),
- &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times to better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
+ ctx);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardFilterWithAlgorithm(
+ input_desc, input_ptr, output_desc, out_backprop_ptr,
+ conv_desc, filter_desc, &filter_backprop_ptr,
+ &scratch_allocator, AlgorithmConfig(profile_algorithm),
+ &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 0732bf4046..0b2d01afa9 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -870,39 +870,34 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
- ctx);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveBackwardDataWithAlgorithm(
- filter_desc, filter_ptr, output_desc, out_backprop_ptr,
- conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
- AlgorithmConfig(profile_algorithm), &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times to better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
+ ctx);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardDataWithAlgorithm(
+ filter_desc, filter_ptr, output_desc, out_backprop_ptr,
+ conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
+ AlgorithmConfig(profile_algorithm), &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 8ad56053a8..21f5cb1716 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -654,40 +654,34 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- CudnnScratchAllocator scratch_allocator(
- ConvolveBackwardDataScratchSize, context);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveBackwardDataWithAlgorithm(
- filter_desc, filter_ptr, output_desc, out_backprop_ptr,
- conv_desc, input_desc, &in_backprop_ptr,
- &scratch_allocator, AlgorithmConfig(profile_algorithm),
- &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times to better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
+ context);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardDataWithAlgorithm(
+ filter_desc, filter_ptr, output_desc, out_backprop_ptr,
+ conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
+ AlgorithmConfig(profile_algorithm), &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
@@ -1026,40 +1020,35 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- CudnnScratchAllocator scratch_allocator(
- ConvolveBackwardFilterScratchSize, context);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveBackwardFilterWithAlgorithm(
- input_desc, input_ptr, output_desc, out_backprop_ptr,
- conv_desc, filter_desc, &filter_backprop_ptr,
- &scratch_allocator, AlgorithmConfig(profile_algorithm),
- &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times to better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(
+ ConvolveBackwardFilterScratchSize, context);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardFilterWithAlgorithm(
+ input_desc, input_ptr, output_desc, out_backprop_ptr,
+ conv_desc, filter_desc, &filter_backprop_ptr,
+ &scratch_allocator, AlgorithmConfig(profile_algorithm),
+ &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index dc03eeb658..bb67113fb0 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -662,38 +662,33 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune &&
!AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveWithAlgorithm(
- input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
- output_desc, &output_ptr, &scratch_allocator,
- AlgorithmConfig(profile_algorithm), &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times to better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveWithAlgorithm(
+ input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
+ output_desc, &output_ptr, &scratch_allocator,
+ AlgorithmConfig(profile_algorithm), &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 72758f707a..8a89d564de 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -390,38 +390,33 @@ struct LaunchConvOp<GPUDevice, T> {
if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveWithAlgorithm(
- input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
- output_desc, &output_ptr, &scratch_allocator,
- AlgorithmConfig(profile_algorithm), &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times to better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveWithAlgorithm(
+ input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
+ output_desc, &output_ptr, &scratch_allocator,
+ AlgorithmConfig(profile_algorithm), &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc
index 42ea23553b..5e48ae9766 100644
--- a/tensorflow/core/kernels/decode_csv_op.cc
+++ b/tensorflow/core/kernels/decode_csv_op.cc
@@ -36,8 +36,8 @@ class DecodeCSVOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_quote_delim", &use_quote_delim_));
OP_REQUIRES(ctx, delim.size() == 1,
errors::InvalidArgument("field_delim should be only 1 char"));
-
delim_ = delim[0];
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("na_value", &na_value_));
}
void Compute(OpKernelContext* ctx) override {
@@ -79,9 +79,9 @@ class DecodeCSVOp : public OpKernel {
const DataType& dtype = out_type_[f];
switch (dtype) {
case DT_INT32: {
- // If this field is empty, check if default is given:
+ // If this field is empty or NA value, check if default is given:
// If yes, use default value; Otherwise report error.
- if (fields[f].empty()) {
+ if (fields[f].empty() || fields[f] == na_value_) {
OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
errors::InvalidArgument(
"Field ", f,
@@ -99,9 +99,9 @@ class DecodeCSVOp : public OpKernel {
break;
}
case DT_INT64: {
- // If this field is empty, check if default is given:
+ // If this field is empty or NA value, check if default is given:
// If yes, use default value; Otherwise report error.
- if (fields[f].empty()) {
+ if (fields[f].empty() || fields[f] == na_value_) {
OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
errors::InvalidArgument(
"Field ", f,
@@ -119,9 +119,9 @@ class DecodeCSVOp : public OpKernel {
break;
}
case DT_FLOAT: {
- // If this field is empty, check if default is given:
+ // If this field is empty or NA value, check if default is given:
// If yes, use default value; Otherwise report error.
- if (fields[f].empty()) {
+ if (fields[f].empty() || fields[f] == na_value_) {
OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
errors::InvalidArgument(
"Field ", f,
@@ -138,9 +138,9 @@ class DecodeCSVOp : public OpKernel {
break;
}
case DT_STRING: {
- // If this field is empty, check if default is given:
+ // If this field is empty or NA value, check if default is given:
// If yes, use default value; Otherwise report error.
- if (fields[f].empty()) {
+ if (fields[f].empty() || fields[f] == na_value_) {
OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
errors::InvalidArgument(
"Field ", f,
@@ -165,6 +165,7 @@ class DecodeCSVOp : public OpKernel {
std::vector<DataType> out_type_;
char delim_;
bool use_quote_delim_;
+ string na_value_;
void ExtractFields(OpKernelContext* ctx, StringPiece input,
std::vector<string>* result) {
diff --git a/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc
index 25a6813d59..0174c8dfc8 100644
--- a/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc
@@ -49,10 +49,10 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, ctx->input("row_shape", &row_shape_t));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(row_shape_t->shape()),
errors::InvalidArgument("row_shape must be a vector"));
- TensorShape row_shape;
- for (size_t i = 0; i < row_shape_t->dim_size(0); ++i) {
- row_shape.AddDim(row_shape_t->vec<int64>()(i));
- }
+ PartialTensorShape row_shape;
+ OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape(
+ row_shape_t->vec<int64>().data(),
+ row_shape_t->NumElements(), &row_shape));
*output = nullptr;
@@ -78,7 +78,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
template <class T>
class Dataset : public DatasetBase {
public:
- Dataset(int64 batch_size, const TensorShape& row_shape,
+ Dataset(int64 batch_size, const PartialTensorShape& row_shape,
const DatasetBase* input)
: batch_size_(batch_size), row_shape_(row_shape), input_(input) {
input_->Ref();
@@ -129,9 +129,22 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
int64 total_elements = 0;
batch_elements.reserve(
DatasetIterator<Dataset<T>>::dataset()->batch_size_);
- const TensorShape& row_shape =
+ const PartialTensorShape& row_shape =
DatasetIterator<Dataset<T>>::dataset()->row_shape_;
const int row_ndims = row_shape.dims();
+
+ // Determine the size of the output tensors:
+ // * dense_shape will be [`row_shape + 1`].
+ Tensor dense_shape(cpu_allocator(), DT_INT64, {row_ndims + 1});
+ auto dense_shape_vec = dense_shape.vec<int64>();
+ for (size_t i = 0; i < row_ndims; ++i) {
+ if (row_shape.dim_size(i) == -1) {
+ dense_shape_vec(i + 1) = 0;
+ } else {
+ dense_shape_vec(i + 1) = row_shape.dim_size(i);
+ }
+ }
+
{
mutex_lock l(mu_);
*end_of_sequence = false;
@@ -156,9 +169,14 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
") that is incompatible with the row shape (",
row_shape.DebugString(), ").");
}
- for (int i = 0; i < row_ndims; ++i) {
- if (batch_element_tuple[0].shape().dim_size(i) >
- row_shape.dim_size(i)) {
+ for (int j = 0; j < row_ndims; ++j) {
+ // Take the maximum in the dimension if -1 is given.
+ if (row_shape.dim_size(j) == -1) {
+ dense_shape_vec(j + 1) =
+ std::max(batch_element_tuple[0].dim_size(j),
+ dense_shape_vec(j + 1));
+ } else if (batch_element_tuple[0].dim_size(j) >
+ row_shape.dim_size(j)) {
return errors::DataLoss(
"Input element had shape (",
batch_element_tuple[0].shape().DebugString(),
@@ -175,20 +193,16 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- // Determine the size of the output tensors:
// * indices will be [`total_elements`, `row_shape + 1`].
// * values will be [`total_elements`].
- // * dense_shape will be [`row_shape + 1`].
Tensor indices(cpu_allocator(), DT_INT64,
{total_elements, row_ndims + 1});
Tensor values(
cpu_allocator(),
DatasetIterator<Dataset<T>>::dataset()->output_dtypes()[1],
{total_elements});
- Tensor dense_shape(cpu_allocator(), DT_INT64, {row_ndims + 1});
auto indices_matrix = indices.matrix<int64>();
auto values_flat = values.flat<T>();
- auto dense_shape_vec = dense_shape.vec<int64>();
int64 current_position_in_values = 0;
for (int64 i = 0; i < batch_elements.size(); ++i) {
@@ -220,9 +234,6 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
}
dense_shape_vec(0) = batch_elements.size();
- for (size_t i = 0; i < row_ndims; ++i) {
- dense_shape_vec(i + 1) = row_shape.dim_size(i);
- }
out_tensors->push_back(std::move(indices));
out_tensors->push_back(std::move(values));
@@ -239,7 +250,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
};
const int64 batch_size_;
- const TensorShape row_shape_;
+ const PartialTensorShape row_shape_;
const DatasetBase* const input_;
std::vector<PartialTensorShape> output_shapes_;
};
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index f81a448e51..9080bf7be8 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -41,10 +42,24 @@ limitations under the License.
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+
+using mkldnn::prop_kind;
+using mkldnn::stream;
+
+using mkldnn::convolution_backward_weights;
+using mkldnn::convolution_direct;
+using mkldnn::convolution_forward;
+
+#endif
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifndef INTEL_MKL_DNN
+
template <typename Device, class T>
class MklConv2DCustomBackpropFilterOp : public OpKernel {
public:
@@ -411,6 +426,172 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
TensorFormat data_format_;
};
+#else
+
+template <typename Device, class T>
+class MklConv2DCustomBackpropFilterOp : public OpKernel {
+ public:
+ explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ int stride_n = GetTensorDim(strides_, data_format_, 'N');
+ int stride_c = GetTensorDim(strides_, data_format_, 'C');
+ OP_REQUIRES(
+ context, (stride_n == 1 && stride_c == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ try {
+ auto cpu_engine = engine(engine::cpu, 0);
+
+ MklDnnData<T> input(&cpu_engine);
+ MklDnnData<T> outbackprop(&cpu_engine);
+ MklDnnData<T> output(&cpu_engine);
+
+ // Input tensors
+ const Tensor& input_tensor = MklGetInput(context, 0);
+ const Tensor& filter_tensor = MklGetInput(context, 1);
+ const Tensor& obp_tensor = MklGetInput(context, 2); // Outbackprop
+
+ // Generate input shapes.
+ TensorShape filter_shape;
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(filter_tensor.shape()),
+ errors::InvalidArgument(
+ "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
+ filter_tensor.dims()));
+ OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+ filter_tensor.vec<int32>(), &filter_shape));
+ TensorShape input_shape = input_tensor.shape();
+ TensorShape obp_shape = obp_tensor.shape();
+
+ // By default, all dims are in MKL order. Only dims in TF order
+ // are those with prefix tf_order.
+ memory::dims obp_dims, fwd_input_dims, fwd_filter_dims;
+ memory::dims padding_l, padding_r, strides, fwd_output_dims;
+ memory::dims fwd_output_dims_tf_order;
+
+ // Get forward convolution parameters.
+ MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
+ conv_utl.GetConvFwdSizesInMklOrder(
+ input_shape, filter_shape, &fwd_input_dims, &fwd_filter_dims,
+ &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l,
+ &padding_r);
+ if (!context->status().ok()) return;
+
+ // Create Convolution forward descriptor since Convolution backward
+ // API needs it. For that, we first need to create input, filter
+ // and output memory descriptors.
+ auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_);
+ auto fwd_src_md =
+ memory::desc(fwd_input_dims, MklDnnType<T>(), mkl_data_format);
+ auto fwd_filter_md =
+ memory::desc(fwd_filter_dims, MklDnnType<T>(), memory::format::hwio);
+ auto fwd_out_md =
+ memory::desc(fwd_output_dims, MklDnnType<T>(), mkl_data_format);
+ auto fwd_desc = convolution_forward::desc(
+ prop_kind::forward, convolution_direct, fwd_src_md, fwd_filter_md,
+ fwd_out_md, strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
+ auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
+
+ // Allocate output tensor and shape
+ // TODO(nhasabni): Update this when support for MKL layout is added.
+ // Shape of output of Conv2DBackpropInput is same as 'input' of Conv2D.
+ TensorShape tf_output_shape(filter_shape);
+ MklShape mkl_output_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(false);
+ Tensor* output_tensor = nullptr;
+ AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
+ mkl_output_mkl_shape);
+
+ // Create memory for user data.
+ // Describe how the inputs and outputs of Convolution look like. Also
+ // specify buffers containing actual input and output data.
+ // Although input shape required is in MKL-DNN order, the layout is
+ // Tensorflow's layout (NHWC or NCHW depending on data format).
+ input.SetUsrMem(fwd_input_dims, mkl_data_format, &input_tensor);
+ // Outbackprop shape is NHWC or NCHW depending on data format. Since
+ // GetInputSizeInMklOrder function returns size in that order we just use
+ // use that function directly.
+ conv_utl.GetInputSizeInMklOrder(obp_shape, &obp_dims);
+ if (!context->status().ok()) return;
+ outbackprop.SetUsrMem(obp_dims, mkl_data_format, &obp_tensor);
+ // Although output shape required is in MKL-DNN order,
+ // layout is Tensorflow's filter layout (HWIO)
+ // Shape of output of Conv2DBackpropInput is same as shape of filter.
+ memory::dims bwd_output_dims = fwd_filter_dims;
+ output.SetUsrMem(bwd_output_dims, memory::format::hwio, output_tensor);
+
+ // Create memory descriptors for convolution data w/ no specified format.
+ input.SetOpMemDesc(fwd_input_dims, memory::format::any);
+ outbackprop.SetOpMemDesc(obp_dims, memory::format::any);
+ output.SetOpMemDesc(bwd_output_dims, memory::format::any);
+
+ // Create convolution backward weights primitive.
+ auto bwd_desc = convolution_backward_weights::desc(
+ convolution_direct, input.GetOpMemDesc(), output.GetOpMemDesc(),
+ outbackprop.GetOpMemDesc(), strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
+
+ auto bwd_pd = convolution_backward_weights::primitive_desc(
+ bwd_desc, cpu_engine, fwd_pd);
+
+ PrepareAndExecutePrimitive(bwd_pd, &input, &outbackprop, &output);
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ TensorFormat data_format_;
+
+ // Prepare and execute net - checks for input and output reorders.
+ void PrepareAndExecutePrimitive(
+ const convolution_backward_weights::primitive_desc& conv_pd,
+ MklDnnData<T>* input, MklDnnData<T>* obp, MklDnnData<T>* output) {
+ // Create reorders between user layout and MKL layout if it is needed and
+ // add it to the net before convolution.
+ std::vector<primitive> net;
+ input->CheckReorderToOpMem(conv_pd.src_primitive_desc(), &net);
+ obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);
+
+ // Memory for output of convolution. Since we may need reorder on the
+ // output side, we will prepare reorder primitive in case output
+ // reorder to user memory is required.
+ bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
+ conv_pd.diff_weights_primitive_desc());
+
+ net.push_back(convolution_backward_weights(
+ conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem()));
+
+ // Insert reorder primitive in the net for output reorder if reorder is
+ // required.
+ if (output_reorder_required) {
+ output->InsertReorderToUserMem(&net);
+ }
+
+ // Handle output reorder
+ stream(stream::kind::eager).submit(net).wait();
+ }
+};
+#endif
+
#define REGISTER_MKL_FILTER_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
.Device(DEVICE_CPU) \
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 00884d0981..4b6bf92e42 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -23,6 +23,8 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <algorithm>
#include <vector>
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -30,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -40,13 +43,24 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
+
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+
+using mkldnn::prop_kind;
+using mkldnn::stream;
+
+using mkldnn::convolution_backward_data;
+using mkldnn::convolution_direct;
+using mkldnn::convolution_forward;
+#endif
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifndef INTEL_MKL_DNN
+
template <typename Device, class T>
class MklConv2DCustomBackpropInputOp : public OpKernel {
public:
@@ -345,6 +359,178 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
TensorFormat data_format;
};
+#else
+
+template <typename Device, class T>
+class MklConv2DCustomBackpropInputOp : public OpKernel {
+ public:
+ ~MklConv2DCustomBackpropInputOp() {}
+ explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ string data_format_str;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
+ OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ int stride_n = GetTensorDim(strides_, data_format_, 'N');
+ int stride_c = GetTensorDim(strides_, data_format_, 'C');
+ OP_REQUIRES(
+ context, (stride_n == 1 && stride_c == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ try {
+ auto cpu_engine = engine(engine::cpu, 0);
+
+ MklDnnData<T> filter(&cpu_engine);
+ MklDnnData<T> outbackprop(&cpu_engine);
+ MklDnnData<T> output(&cpu_engine);
+
+ // Input tensors
+ const Tensor& input_tensor = MklGetInput(context, 0);
+ const Tensor& filter_tensor = MklGetInput(context, 1);
+ const Tensor& obp_tensor = MklGetInput(context, 2); // Outbackprop
+
+ // Generate input shape.
+ TensorShape input_shape;
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(input_tensor.shape()),
+ errors::InvalidArgument(
+ "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
+ input_tensor.dims()));
+ OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+ input_tensor.vec<int32>(), &input_shape));
+ TensorShape filter_shape = filter_tensor.shape();
+ TensorShape obp_shape = obp_tensor.shape();
+
+ // By default, all dims are in MKL order. Only dims in TF order
+ // are those with prefix tf_order.
+ memory::dims obp_dims, fwd_input_dims, fwd_filter_dims;
+ memory::dims padding_l, padding_r, strides, fwd_output_dims;
+ memory::dims fwd_output_dims_tf_order;
+
+ // Get forward convolution parameters.
+ MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
+ conv_utl.GetConvFwdSizesInMklOrder(
+ input_shape, filter_shape, &fwd_input_dims, &fwd_filter_dims,
+ &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l,
+ &padding_r);
+ if (!context->status().ok()) return;
+
+ // Create Convolution forward descriptor since Convolution backward
+ // API needs it. For that, we first need to create input, filter
+ // and output memory descriptors.
+ auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_);
+ auto fwd_src_md =
+ memory::desc(fwd_input_dims, MklDnnType<T>(), mkl_data_format);
+ auto fwd_filter_md =
+ memory::desc(fwd_filter_dims, MklDnnType<T>(), memory::format::hwio);
+ auto fwd_out_md =
+ memory::desc(fwd_output_dims, MklDnnType<T>(), mkl_data_format);
+ auto fwd_desc = convolution_forward::desc(
+ prop_kind::forward, convolution_direct, fwd_src_md, fwd_filter_md,
+ fwd_out_md, strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
+ auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
+
+ // Allocate output tensor and shape
+ // TODO(nhasabni): Update this when support for MKL layout is added.
+ // Shape of output of Conv2DBackpropInput is same as 'input' of Conv2D.
+ TensorShape tf_output_shape(input_shape);
+ MklShape mkl_output_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(false);
+ Tensor* output_tensor = nullptr;
+ AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
+ mkl_output_mkl_shape);
+
+ // Create memory for user data.
+ // Describe how the inputs and outputs of Convolution look like. Also
+ // specify buffers containing actual input and output data.
+ // Although input shape required is in MKL-DNN order, the layout is
+ // Tensorflow's layout (NHWC or NCHW depending on data format).
+ // Although filter shape (filter_dims) required is in MKL-DNN order,
+ // the layout is Tensorflow's layout (HWIO).
+ // Shape of Conv2DBackpropInput's filter is same as that of Conv2D filter.
+ filter.SetUsrMem(fwd_filter_dims, memory::format::hwio, &filter_tensor);
+ // Outbackprop shape is NHWC or NCHW depending on data format. Since
+ // GetInputSizeInMklOrder function returns size in that order we just use
+ // use that function directly.
+ conv_utl.GetInputSizeInMklOrder(obp_shape, &obp_dims);
+ if (!context->status().ok()) return;
+ outbackprop.SetUsrMem(obp_dims, mkl_data_format, &obp_tensor);
+ // Although output shape required is in MKL-DNN order,
+ // layout is Tensorflow's layout (NHWC or NCHW depending on data format).
+ // Shape of output of Conv2DBackpropInput is same as shape of 'input'
+ // of Conv2D.
+ memory::dims bwd_output_dims = fwd_input_dims;
+ output.SetUsrMem(bwd_output_dims, mkl_data_format, output_tensor);
+
+ // Create memory descriptors for convolution data w/ no specified format.
+ filter.SetOpMemDesc(fwd_filter_dims, memory::format::any);
+ outbackprop.SetOpMemDesc(obp_dims, memory::format::any);
+ output.SetOpMemDesc(bwd_output_dims, memory::format::any);
+
+ // Create convolution backward data primitive.
+ auto bwd_desc = convolution_backward_data::desc(
+ convolution_direct, output.GetOpMemDesc(), filter.GetOpMemDesc(),
+ outbackprop.GetOpMemDesc(), strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
+
+ auto bwd_pd = convolution_backward_data::primitive_desc(
+ bwd_desc, cpu_engine, fwd_pd);
+
+ PrepareAndExecutePrimitive(bwd_pd, &filter, &outbackprop, &output);
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ TensorFormat data_format_;
+
+ // Prepare and execute net - checks for input and output reorders.
+ void PrepareAndExecutePrimitive(
+ const convolution_backward_data::primitive_desc& conv_pd,
+ MklDnnData<T>* filter, MklDnnData<T>* obp, MklDnnData<T>* output) {
+ // Create reorders between user layout and MKL layout if it is needed and
+ // add it to the net before convolution.
+ std::vector<primitive> net;
+ filter->CheckReorderToOpMem(conv_pd.weights_primitive_desc(), &net);
+ obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);
+
+ // Memory for output of convolution. Since we may need reorder on the
+ // output side, we will prepare reorder primitive in case output
+ // reorder to user memory is required.
+ bool output_reorder_required =
+ output->PrepareReorderToUserMemIfReq(conv_pd.diff_src_primitive_desc());
+
+ net.push_back(convolution_backward_data(
+ conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem()));
+
+ // Insert reorder primitive in the net for output reorder if reorder is
+ // required.
+ if (output_reorder_required) {
+ output->InsertReorderToUserMem(&net);
+ }
+
+ // Handle output reorder
+ stream(stream::kind::eager).submit(net).wait();
+ }
+};
+
+#endif // INTEL_MKL_DNN
+
#define REGISTER_MKL_CPU_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
.Device(DEVICE_CPU) \
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 7f1555d325..57661e8b10 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -18,7 +18,9 @@ limitations under the License.
#include <string.h>
#include <map>
+#include <string>
#include <vector>
+
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -26,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -40,10 +43,23 @@ limitations under the License.
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+
+using mkldnn::prop_kind;
+using mkldnn::stream;
+
+using mkldnn::convolution_direct;
+using mkldnn::convolution_forward;
+#endif
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+// For now, MKL-ML is default. So making MKL-DNN not a default choice.
+#ifndef INTEL_MKL_DNN
+
template <typename Device, typename T, bool biasEnabled>
class MklConv2DOp : public OpKernel {
public:
@@ -461,6 +477,203 @@ class MklConv2DOp : public OpKernel {
TensorFormat data_format_;
};
+#else
+
+template <typename Device, typename T, bool biasEnabled>
+class MklConv2DOp : public OpKernel {
+ public:
+ ~MklConv2DOp() {}
+
+ explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(context, strides_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+
+ const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
+ const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
+ OP_REQUIRES(
+ context, stride_n == 1 && stride_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ try {
+ auto cpu_engine = engine(engine::cpu, 0);
+
+ // Input tensors
+ size_t src_idx = 0, filter_idx = 1;
+ const Tensor& src_tensor = MklGetInput(context, src_idx);
+ const Tensor& filter_tensor = MklGetInput(context, filter_idx);
+
+ MklDnnData<T> src(&cpu_engine);
+ MklDnnData<T> filter(&cpu_engine);
+ MklDnnData<T> output(&cpu_engine);
+
+ memory::dims src_dims, filter_dims, padding_l, padding_r, strides;
+ memory::dims output_dims_tf_order, output_dims_mkl_order;
+
+ // Get shapes of input tensors in MKL-DNN order
+ MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
+ conv_utl.GetConvFwdSizesInMklOrder(
+ src_tensor.shape(), filter_tensor.shape(), &src_dims, &filter_dims,
+ &strides, &output_dims_tf_order, &output_dims_mkl_order, &padding_l,
+ &padding_r);
+ if (!context->status().ok()) return;
+
+ // Check for corner case - if there is nothing to compute, return.
+ TensorShape tf_output_shape(
+ {output_dims_tf_order[0], output_dims_tf_order[1],
+ output_dims_tf_order[2], output_dims_tf_order[3]});
+ Tensor* output_tensor = nullptr;
+ MklShape mkl_output_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
+ mkl_output_mkl_shape);
+
+ // Forward filter in TF format from input at index 1 to output at index 1.
+ ForwardTfTensorInToOut(context, 1, 1);
+
+ if (tf_output_shape.num_elements() == 0) {
+ // TODO(jbobba): Verify correctness here
+ // Need semantics for Null MKL tensor
+ return;
+ }
+
+ // Corner case to handle 0 batch size.
+ if (output_dims_tf_order[0] == 0) {
+ // Nothing to do, allocate output tensor and return
+ // TODO(nhasabni): remove this code later once serialization
+ // in MKL-DNN is supported.
+ AllocateOutputSetMklShape(context, 0, &output_tensor,
+ src_tensor.shape(), mkl_output_mkl_shape);
+ return;
+ } else {
+ // Otherwise regular output tensor allocation
+ // Allocate output tensor.
+ }
+ CHECK_NOTNULL(output_tensor);
+
+ // Create memory for user data.
+ // Describe how the inputs and outputs of Convolution look like. Also
+ // specify buffers containing actual input and output data.
+ // Although input shape (src_dims) required is in MKL-DNN order,
+ // the layout is Tensorflow's layout (NHWC or NCHW depending on data
+ // format).
+ src.SetUsrMem(src_dims, TFDataFormatToMklDnnDataFormat(data_format_),
+ const_cast<void*>(
+ static_cast<const void*>(src_tensor.flat<T>().data())));
+ // Although filter shape (filter_dims) required is in MKL-DNN order,
+ // the layout is Tensorflow's layout (HWIO).
+ filter.SetUsrMem(filter_dims, memory::format::hwio,
+ const_cast<void*>(static_cast<const void*>(
+ filter_tensor.flat<T>().data())));
+ // Although output shape (output_dims) required is in MKL-DNN order,
+ // layout is Tensorflow's layout (NHWC or NCHW depending on data format).
+ output.SetUsrMem(output_dims_mkl_order,
+ TFDataFormatToMklDnnDataFormat(data_format_),
+ output_tensor->flat<T>().data());
+
+ // Create memory descriptors for convolution data w/ no specified format.
+ src.SetOpMemDesc(src_dims, memory::format::any);
+ filter.SetOpMemDesc(filter_dims, memory::format::any);
+ output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);
+
+ // If bias is enabled, then do the same steps as above for bias.
+ if (biasEnabled) {
+ MklDnnData<T> bias(&cpu_engine);
+ memory::dims bias_size;
+ conv_utl.GetBiasSizeInMklOrder(2 /* bias idx */, &bias_size);
+ const Tensor& bias_tensor = MklGetInput(context, 2);
+ bias.SetUsrMem(bias_size, memory::format::x,
+ const_cast<void*>(static_cast<const void*>(
+ bias_tensor.flat<T>().data())));
+ bias.SetOpMemDesc(bias_size, memory::format::any);
+
+ // Create convolution primitive with Bias.
+ auto conv_desc = convolution_forward::desc(
+ prop_kind::forward, convolution_direct, src.GetOpMemDesc(),
+ filter.GetOpMemDesc(), bias.GetOpMemDesc(), output.GetOpMemDesc(),
+ strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
+
+ auto conv_prim_desc =
+ convolution_forward::primitive_desc(conv_desc, cpu_engine);
+ PrepareAndExecuteNet(conv_prim_desc, &src, &filter, &bias, &output);
+ } else {
+ // Create convolution primitive without Bias.
+ auto conv_desc = convolution_forward::desc(
+ prop_kind::forward, convolution_direct, src.GetOpMemDesc(),
+ filter.GetOpMemDesc(), output.GetOpMemDesc(), strides, padding_l,
+ padding_r, TFPaddingToMklDnnPadding(padding_));
+
+ auto conv_prim_desc =
+ convolution_forward::primitive_desc(conv_desc, cpu_engine);
+ PrepareAndExecuteNet(conv_prim_desc, &src, &filter, nullptr, &output);
+ }
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + std::string(e.message) + ", in file " +
+ std::string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ TensorFormat data_format_;
+
+ // Prepare and execute net - checks for input and output reorders.
+ void PrepareAndExecuteNet(
+ const convolution_forward::primitive_desc& conv_prim_desc,
+ MklDnnData<T>* src, MklDnnData<T>* filter, MklDnnData<T>* bias,
+ MklDnnData<T>* output) {
+ // Create reorders between user layout and MKL layout if it is needed and
+ // add it to the net before convolution.
+ std::vector<primitive> net;
+ src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc(), &net);
+ filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(), &net);
+
+ // Memory for output of convolution. Since we may need reorder on the
+ // output side, we will prepare reorder primitive in case output
+ // reorder to user memory is required.
+ bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
+ conv_prim_desc.dst_primitive_desc());
+
+ // Create convolution primitive and add it to net.
+ if (bias) {
+ CHECK_EQ(biasEnabled, true);
+ net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
+ filter->GetOpMem(), bias->GetOpMem(),
+ output->GetOpMem()));
+ } else {
+ CHECK_EQ(biasEnabled, false);
+ net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
+ filter->GetOpMem(),
+ output->GetOpMem()));
+ }
+
+ // Insert reorder primitive in the net for output reorder if reorder is
+ // required.
+ if (output_reorder_required) {
+ output->InsertReorderToUserMem(&net);
+ }
+
+ // Handle output reorder
+ stream(stream::kind::eager).submit(net).wait();
+ }
+};
+
+#endif
+
#define REGISTER_MKL_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
.Device(DEVICE_CPU) \
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
new file mode 100644
index 0000000000..e29af19ca9
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -0,0 +1,308 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#include "tensorflow/core/util/mkl_util.h"
+
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+#endif
+
+namespace tensorflow {
+
+#ifdef INTEL_MKL_DNN
+
+class MklDnnConvUtil {
+ protected:
+ OpKernelContext *context_; // We don't own this.
+ std::vector<int32> strides_;
+ Padding padding_;
+ TensorFormat data_format_;
+
+ public:
+ MklDnnConvUtil(OpKernelContext *context, const std::vector<int32> &strides,
+ Padding pad, TensorFormat fm)
+ : context_(context), strides_(strides), padding_(pad), data_format_(fm) {}
+
+ virtual ~MklDnnConvUtil() { context_ = nullptr; }
+
+ // Calculate Convolution strides
+ virtual inline void GetStridesInMklOrder(memory::dims *strides) {
+ // For now we take the stride from the second and third dimensions only
+ // (we do not support striding on the batch or depth dimension).
+ CHECK_NOTNULL(strides);
+ int stride_rows = GetTensorDim(strides_, data_format_, 'H');
+ int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+ *strides = {stride_rows, stride_cols};
+ }
+
+ // Calculate Convolution input size in MKL-DNN order. MKL-DNN
+ // requires input in NCHW format. Function does not return anything.
+ // But errors arising from sanity checks are returned in context's
+ // status.
+ virtual inline void GetInputSizeInMklOrder(const TensorShape &input_shape,
+ memory::dims *input_dims) {
+#define CHECK_BOUNDS(val, err_msg) \
+ do { \
+ OP_REQUIRES(context_, \
+ FastBoundsCheck(val, std::numeric_limits<int>::max()), \
+ errors::InvalidArgument(err_msg)); \
+ } while (0)
+
+ CHECK_NOTNULL(input_dims);
+
+ // Input channel
+ int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
+ int input_depth = static_cast<int>(input_depth_raw);
+
+ // Input rows/height
+ int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
+ CHECK_BOUNDS(input_rows_raw, "Input rows too large");
+ int input_rows = static_cast<int>(input_rows_raw);
+
+ // Input columns/width
+ int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
+ CHECK_BOUNDS(input_cols_raw, "Input cols too large");
+ int input_cols = static_cast<int>(input_cols_raw);
+
+ // Input batch
+ int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
+ CHECK_BOUNDS(input_batch_raw, "Input batch too large");
+ int input_batch = static_cast<int>(input_batch_raw);
+
+#undef CHECK_BOUNDS
+
+ // MKL-DNN always requires input in NCHW format.
+ *input_dims = {input_batch, input_depth, input_rows, input_cols};
+ }
+
+ // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
+ // requires filter in OIHW format. Function does not return anything.
+ // But errors arising from sanity checks are returned in context's
+ // status.
+ //
+ // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
+ // requires filter in OIHW format. Function does not return anything.
+ // But errors arising from sanity checks are returned in context's
+ // status. This function differs from GetConvFilterSizeInMklOrder in
+ // parameter for input - it accepts src_shape since Convolution Backward
+ // Input gets shape of input tensor rather than actual tensor (Convolution
+ // forward gets actual tensor as input).
+ //
+ // TODO(nhasabni): Add similar function for input and filter in MklShape.
+ virtual inline void GetFilterSizeInMklOrder(const TensorShape &input_shape,
+ const TensorShape &filter_shape,
+ memory::dims *filter_dims) {
+ CHECK_NOTNULL(filter_dims);
+
+ OP_REQUIRES(context_, filter_shape.dims() == 4,
+ errors::InvalidArgument("filter must be 4-dimensional: ",
+ filter_shape.DebugString()));
+
+ for (int i = 0; i < 3; i++) {
+ OP_REQUIRES(context_,
+ FastBoundsCheck(filter_shape.dim_size(i),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("filter too large"));
+ }
+
+ int input_depth = GetTensorDim(input_shape, data_format_, 'C');
+
+ OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ", input_depth,
+ " vs ", filter_shape.dim_size(2)));
+
+ // TF filter is always in (rows, cols, in_depth, out_depth) order.
+ int filter_rows = static_cast<int>(filter_shape.dim_size(0));
+ int filter_cols = static_cast<int>(filter_shape.dim_size(1));
+ int in_depth = static_cast<int>(filter_shape.dim_size(2));
+ int out_depth = static_cast<int>(filter_shape.dim_size(3));
+
+ // MKL-DNN always needs filter in OIHW format.
+ // OIHW = (out_depth, in_depth, rows, cols)
+ *filter_dims = {out_depth, in_depth, filter_rows, filter_cols};
+ }
+
+ // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
+ // requires filter in OIHW format. Function does not return anything.
+ // But errors arising from sanity checks are returned in context's
+ // status.
+ virtual inline void GetFilterSizeInMklOrder(size_t src_index,
+ size_t filter_index,
+ memory::dims *filter_dims) {
+ CHECK_NOTNULL(filter_dims);
+ const Tensor &input = MklGetInput(context_, src_index);
+ const Tensor &filter = MklGetInput(context_, filter_index);
+ GetFilterSizeInMklOrder(input.shape(), filter.shape(), filter_dims);
+ }
+
+ // Calculate Bias size for 2D Convolution. Function does not return
+ // anything, but sets error in context status.
+ virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
+ memory::dims *bias_dims) {
+ const Tensor &bias = MklGetInput(context_, bias_index);
+ OP_REQUIRES(context_, bias.dims() == 1,
+ errors::InvalidArgument("bias must be 1-dimensional: ",
+ bias.shape().DebugString()));
+
+ *bias_dims = {static_cast<int>(bias.dim_size(0))};
+ }
+
+ // Function to calculate output and padding size for 2D convolution.
+ //
+ // Calculate output shape of Convolution in MKL-DNN and TensorFlow order.
+ // MKL-DNN uses NCHW for output order. But TensorFlow output will be in
+ // NHWC or NCHW format depending on data format. Function also calculates
+ // left, right, top and bottom pads. Function does not return any status -
+ // status is returned via context status.
+ //
+ // TODO(nhasabni): Add similar function for input and filter in MklShape.
+ virtual inline void GetOutputAndPadSizeInMklOrder(
+ const TensorShape &input_shape, const TensorShape &filter_shape,
+ const memory::dims &strides, memory::dims *output_dims_tf_order,
+ memory::dims *output_dims_mkl_order, memory::dims *pad_l,
+ memory::dims *pad_r) {
+ CHECK_NOTNULL(output_dims_tf_order);
+ CHECK_NOTNULL(output_dims_mkl_order);
+ CHECK_NOTNULL(pad_l);
+ CHECK_NOTNULL(pad_r);
+
+ int input_rows = GetTensorDim(input_shape, data_format_, 'H');
+ int input_cols = GetTensorDim(input_shape, data_format_, 'W');
+
+ // The first dimension for filter is rows/height.
+ int filter_rows = filter_shape.dim_size(0);
+ // The second dimension for filter is cols/width.
+ int filter_cols = filter_shape.dim_size(1);
+
+ // Stride is vector of 2 elements: {s_r, s_c}
+ int stride_rows = strides[0];
+ int stride_cols = strides[1];
+
+ // Output batch is same as input batch.
+ int out_batch = GetTensorDim(input_shape, data_format_, 'N');
+ // Output depth is same as last dimension for filter.
+ int out_depth = filter_shape.dim_size(3);
+
+ int64 out_rows = 0, out_cols = 0;
+ int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right;
+
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_rows, filter_rows, stride_rows, padding_,
+ &out_rows, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_cols, filter_cols, stride_cols, padding_,
+ &out_cols, &pad_left, &pad_right));
+
+ // Tensorflow output is in data_format order. (NHWC or NCHW)
+ TensorShape out_shape =
+ ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth);
+ *output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
+
+ // MKL-DNN always needs output in NCHW format.
+ *output_dims_mkl_order = {out_batch, out_depth, static_cast<int>(out_rows),
+ static_cast<int>(out_cols)};
+
+ // Now handle padding. MKL-DNN uses asymetric padding.
+ *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
+ *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
+ }
+
+ // Calculate output and pad size of forward Convolution operator.
+ // See comment on GetConvOutputAndPadSizeInMklOrder for parameters.
+ //
+ // Function does not return anything, but sets error in context status.
+ inline void GetOutputAndPadSizeInMklOrder(
+ size_t src_index, size_t filter_index, const memory::dims &strides,
+ memory::dims *output_dims_tf_order, memory::dims *output_dims_mkl_order,
+ memory::dims *pad_l, memory::dims *pad_r) {
+ CHECK_NOTNULL(output_dims_tf_order);
+ CHECK_NOTNULL(output_dims_mkl_order);
+ CHECK_NOTNULL(pad_l);
+ CHECK_NOTNULL(pad_r);
+
+ const Tensor &input = MklGetInput(context_, src_index);
+ const Tensor &filter = MklGetInput(context_, filter_index);
+
+ OP_REQUIRES(context_, input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().DebugString()));
+
+ GetOutputAndPadSizeInMklOrder(input.shape(), filter.shape(), strides,
+ output_dims_tf_order, output_dims_mkl_order,
+ pad_l, pad_r);
+ }
+
+ // Wrapper function to calculate input, filter, and output sizes of
+ // 2D Convolution in MKL order (NCHW for input and output; OIHW for filter.)
+ // Function also calculates output shape in Tensorflow order. Additionally, it
+ // also calculates strides and paddings for 2D Convolution.
+ //
+ // Function does not return anything, but sets error in context status.
+ inline void GetConvFwdSizesInMklOrder(
+ const TensorShape &input_shape, const TensorShape &filter_shape,
+ memory::dims *input_dims, memory::dims *filter_dims,
+ memory::dims *strides, memory::dims *output_dims_tf_order,
+ memory::dims *output_dims_mkl_order, memory::dims *pad_l,
+ memory::dims *pad_r) {
+ CHECK_NOTNULL(input_dims);
+ CHECK_NOTNULL(filter_dims);
+ CHECK_NOTNULL(strides);
+ CHECK_NOTNULL(output_dims_tf_order);
+ CHECK_NOTNULL(output_dims_mkl_order);
+ CHECK_NOTNULL(pad_l);
+ CHECK_NOTNULL(pad_r);
+
+ GetInputSizeInMklOrder(input_shape, input_dims);
+ if (!context_->status().ok()) return;
+ GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims);
+ if (!context_->status().ok()) return;
+ GetStridesInMklOrder(strides);
+ GetOutputAndPadSizeInMklOrder(input_shape, filter_shape, *strides,
+ output_dims_tf_order, output_dims_mkl_order,
+ pad_l, pad_r);
+ if (!context_->status().ok()) return;
+ }
+};
+
+#endif // INTEL_MKL_DNN
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
diff --git a/tensorflow/core/kernels/mkl_cwise_ops_common.cc b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
index 7fc633c254..c065724e0d 100644
--- a/tensorflow/core/kernels/mkl_cwise_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
@@ -48,7 +48,7 @@ class MklBinaryOp : public BinaryOp<Device, Functor> {
auto out = context->mutable_output(0);
VLOG(1) << "Shapes (output): " << out->shape().DebugString();
- // Pass input shape through to ouput shape
+ // Pass input shape through to output shape
ForwardMklMetaDataInToOut(context, 0, 0);
out = context->mutable_output(0);
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc
index 3c85737702..302a6967e3 100644
--- a/tensorflow/core/lib/strings/numbers.cc
+++ b/tensorflow/core/lib/strings/numbers.cc
@@ -340,7 +340,7 @@ char* FloatToBuffer(float value, char* buffer) {
float parsed_value;
if (!safe_strtof(buffer, &parsed_value) || parsed_value != value) {
snprintf_result =
- snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG + 2, value);
+ snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG + 3, value);
// Should never overflow; see above.
DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index df189af1b8..c0e84c8bb0 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -383,7 +383,8 @@ input_dataset: A handle to an input dataset. Must have a single component.
batch_size: A scalar representing the number of elements to accumulate in a
batch.
row_shape: A vector representing the dense shape of each row in the produced
- SparseTensor.
+ SparseTensor. The shape may be partially specified, using `-1` to indicate
+ that a particular dimension should use the maximum size of all batch elements.
)doc");
REGISTER_OP("RangeDataset")
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 3dc16ac457..b34dc1a008 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -29,22 +29,6 @@ using shape_inference::ShapeHandle;
namespace {
-// A shape function that uses the tensor value at <input_idx> as a shape for
-// output 0. If the tensor value is not available, it uses a shape with <ndims>
-// unknown dims.
-Status InputTensorShapeOrUnknown(InferenceContext* c, int input_idx,
- int ndims) {
- ShapeHandle out;
- const Tensor* input = c->input_tensor(input_idx);
- if (input == nullptr) {
- out = c->UnknownShapeOfRank(ndims);
- } else {
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(input_idx, &out));
- }
- c->set_output(0, out);
- return Status::OK();
-}
-
Status FractionalPoolShapeFn(InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
@@ -119,11 +103,11 @@ REGISTER_OP("AvgPoolGrad")
.Attr(GetConvnetDataFormatAttrString())
.Attr("T: {half, float, double}")
.SetShapeFn([](InferenceContext* c) {
- // NOTE(mrry): We could in principle work out the shape from the
- // gradients and the attrs, but if we do not know orig_input_shape
- // statically, then we are unlikely to know the shape of the
- // gradients either.
- return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
Computes gradients of the average pooling function.
@@ -583,11 +567,11 @@ REGISTER_OP("Conv2DBackpropInput")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
- // NOTE(mrry): We could in principle work out the shape from the
- // gradients and the attrs, but if we do not know orig_input_shape
- // statically, then we are unlikely to know the shape of the
- // gradients either.
- return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
Computes the gradients of convolution with respect to the input.
@@ -625,11 +609,11 @@ REGISTER_OP("Conv2DBackpropFilter")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
- // NOTE(mrry): We could in principle work out the shape from the
- // gradients and the attrs, but if we do not know orig_input_shape
- // statically, then we are unlikely to know the shape of the
- // gradients either.
- return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
Computes the gradients of convolution with respect to the filter.
@@ -882,11 +866,11 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
- // NOTE(mrry): We could in principle work out the shape from the
- // gradients and the attrs, but if we do not know orig_input_shape
- // statically, then we are unlikely to know the shape of the
- // gradients either.
- return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
Computes the gradients of depthwise convolution with respect to the input.
@@ -924,11 +908,11 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
- // NOTE(mrry): We could in principle work out the shape from the
- // gradients and the attrs, but if we do not know orig_input_shape
- // statically, then we are unlikely to know the shape of the
- // gradients either.
- return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
Computes the gradients of depthwise convolution with respect to the filter.
@@ -2870,7 +2854,11 @@ REGISTER_OP("_MklConv2DBackpropFilter")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
- return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the
@@ -2911,7 +2899,11 @@ REGISTER_OP("_MklConv2DBackpropInput")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
- return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
MKL version of Convolution2D backward input. Uses MKL DNN APIs to compute the
@@ -3034,7 +3026,11 @@ REGISTER_OP("_MklAvgPoolGrad")
.Attr(GetConvnetDataFormatAttrString())
.Attr("T: {float, half, double}")
.SetShapeFn([](InferenceContext* c) {
- return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
MKL version of AvgPoolGrad operator. Uses MKL DNN APIs to compute gradients
diff --git a/tensorflow/core/ops/nn_ops_test.cc b/tensorflow/core/ops/nn_ops_test.cc
index 51e4f8bffe..4628b725f8 100644
--- a/tensorflow/core/ops/nn_ops_test.cc
+++ b/tensorflow/core/ops/nn_ops_test.cc
@@ -81,55 +81,6 @@ TEST(NNOpsTest, TopKV2_ShapeFn) {
op, "[1,2,3,4];[]");
}
-TEST(NNOpsTest, InputTensorShapeOrUnknown2D_ShapeFn) {
- typedef std::pair<const char*, int> NameAndInputIndex;
- for (const auto& p :
- {NameAndInputIndex("AvgPoolGrad", 0),
- NameAndInputIndex("Conv2DBackpropInput", 0),
- NameAndInputIndex("Conv2DBackpropFilter", 1),
- NameAndInputIndex("DepthwiseConv2dNativeBackpropInput", 0),
- NameAndInputIndex("DepthwiseConv2dNativeBackpropFilter", 1)}) {
- ShapeInferenceTestOp op(p.first);
- op.input_tensors.resize(2);
-
- // Conv and Depthwise conv have three inputs.
- string extra_shapes = (op.name == "AvgPoolGrad" ? "" : ";?");
-
- // When the input tensor is not known, the output is 4 unknown dims.
- INFER_OK(op, "?;?" + extra_shapes, "[?,?,?,?]");
- INFER_OK(op, "[4];?" + extra_shapes, "[?,?,?,?]");
-
- // When input tensor is known, its values determine output shape.
- std::vector<int32> shape{1, 2, 3, 4};
- Tensor shape_t = test::AsTensor<int32>(shape);
- op.input_tensors[p.second] = &shape_t;
- INFER_OK(op, "[4];?" + extra_shapes, "[1,2,3,4]");
- }
-}
-
-TEST(NNOpsTest, InputTensorShapeOrUnknown3D_ShapeFn) {
- typedef std::pair<const char*, int> NameAndInputIndex;
- for (const auto& p : {NameAndInputIndex("AvgPool3DGrad", 0),
- NameAndInputIndex("Conv3DBackpropInputV2", 0),
- NameAndInputIndex("Conv3DBackpropFilterV2", 1)}) {
- ShapeInferenceTestOp op(p.first);
- op.input_tensors.resize(2);
-
- // Conv3D has an extra shape.
- string extra_shapes = (op.name == "AvgPool3DGrad" ? "" : ";?");
-
- // When the input tensor is not known, the output is 4 unknown dims.
- INFER_OK(op, "?;?" + extra_shapes, "[?,?,?,?,?]");
- INFER_OK(op, "[5];?" + extra_shapes, "[?,?,?,?,?]");
-
- // When input tensor is known, its values determine output shape.
- std::vector<int32> shape{1, 2, 3, 4, 5};
- Tensor shape_t = test::AsTensor<int32>(shape);
- op.input_tensors[p.second] = &shape_t;
- INFER_OK(op, "[5];?" + extra_shapes, "[1,2,3,4,5]");
- }
-}
-
TEST(NNOpsTest, BatchNormWithGlobalNormalization_ShapeFn) {
ShapeInferenceTestOp op("BatchNormWithGlobalNormalization");
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index f23ff083af..b44ea2e080 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -332,6 +332,7 @@ REGISTER_OP("DecodeCSV")
.Attr("OUT_TYPE: list({float,int32,int64,string})")
.Attr("field_delim: string = ','")
.Attr("use_quote_delim: bool = true")
+ .Attr("na_value: string = ''")
.SetShapeFn([](InferenceContext* c) {
// Validate the record_defaults inputs.
for (int i = 1; i < c->num_inputs(); ++i) {
@@ -362,6 +363,7 @@ field_delim: char delimiter to separate fields in a record.
use_quote_delim: If false, treats double quotation marks as regular
characters inside of the string fields (ignoring RFC 4180, Section 2,
Bullet 5).
+na_value: Additional string to recognize as NA/NaN.
output: Each tensor will have the same shape as records.
)doc");
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index f4bec9524a..1bfa4f83a3 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -26,13 +26,19 @@ limitations under the License.
#include "mkl_trans.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+#endif
// The file contains a number of utility classes and functions used by MKL
// enabled kernels
@@ -219,19 +225,18 @@ class MklShape {
// Location from start of buffer where isMklTensor_ is serialized
#define DIMS_OFFSET \
(IS_MKL_TENSOR_OFFSET + sizeof(size_t)) // Location of dimension_
-#define SIZES_OFFSET(dims) \
- (DIMS_OFFSET + \
- sizeof(size_t)) // Location of sizes. Note dim is not used here, left here
- // to make macros consistent.
+// Location of sizes. Note dim is not used here, left here
+// to make macros consistent.
+#define SIZES_OFFSET(dims) (DIMS_OFFSET + sizeof(size_t))
#define STRIDES_OFFSET(dims) \
(SIZES_OFFSET(dims) + dims * sizeof(size_t)) // Location of strides
#define MKL_LAYOUT_OFFSET(dims) \
(STRIDES_OFFSET(dims) + dims * sizeof(size_t)) // Location of mklLayout_
#define TF_LAYOUT_OFFSET(dims) \
(MKL_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF) // Location of tfLayout_
+// Location of tf_to_mkl_dim_map_
#define TF_TO_MKL_DIM_MAP_OFFSET(dims) \
- (TF_LAYOUT_OFFSET(dims) + \
- SIZE_OF_MKL_DNN_BUF) // Location of tf_to_mkl_dim_map_
+ (TF_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)
// TODO(agramesh1) make sure to create a const to share with rewrite pass
// for min size of MKL metadata tensor.
@@ -342,58 +347,6 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
return output_tensor;
}
-// Since our ops are going to produce and also consume N addition tensors
-// (Mkl) for N Tensorflow tensors, we can have following different
-// orderings among these 2N tensors.
-//
-// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
-// consume A_m, B_m, and C_m additionally.
-//
-// INTERLEAVED: in this case 2N tensors are interleaved. So for above
-// example, the ordering looks like: A, A_m, B, B_m, C, C_m.
-//
-// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed
-// by N Mkl tensors. So for above example, the ordering looks
-// like: A, B, C, A_m, B_m, C_m
-//
-// Following APIs map index of original Tensorflow tensors to their appropriate
-// position based on selected ordering. For contiguous ordering, we need to know
-// the total number of tensors (parameter total).
-//
-typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
-// NOTE: Currently, we use contiguous ordering. If you change this, then you
-// would need to change Mkl op definitions in nn_ops.cc.
-static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
-
-// Get index of MetaData tensor from index 'n' of Data tensor.
-inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
- if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
- // For interleaved ordering, Mkl tensor follows immediately after
- // Tensorflow tensor.
- return n + 1;
- } else {
- CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
- // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
- return n + total_tensors / 2;
- }
-}
-
-int inline GetTensorDataIndex(int n, int total_tensors) {
- if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
- return 2 * n; // index corresponding to nth input/output tensor
- } else {
- CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
- return n;
- }
-}
-
-int inline GetTensorMetaDataIndex(int n, int total_tensors) {
- // Get index for TensorData first and then use mapping function
- // to get TensorMetaData index from TensorData index.
- int tidx = GetTensorDataIndex(n, total_tensors);
- return DataIndexToMetaDataIndex(tidx, total_tensors);
-}
-
// Get the MKL shape from the second string tensor
inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
mklshape->DeSerializeMklShape(
@@ -480,6 +433,13 @@ inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
*buf_out = static_cast<void*>(tensor_out->flat<float>().data());
}
+template <typename T>
+inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
+ TensorShape tf_shape) {
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
+ tf_shape, tensor_out));
+}
+
inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides,
const size_t* sizes) {
// MKL requires strides in NCHW
@@ -743,56 +703,299 @@ inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
}
}
-namespace mkl_op_registry {
-static const char* kMklOpLabel = "MklOp";
-static const char* kMklOpLabelPattern = "label='MklOp'";
+// -------------------------------------------------------------------
+
+#ifdef INTEL_MKL_DNN
+
+using mkldnn::engine;
+using mkldnn::memory;
+using mkldnn::padding_kind;
+using mkldnn::primitive;
+using mkldnn::reorder;
+
+/// Return MKL-DNN data type (memory::data_type) for input type T
+///
+/// @input None
+/// @return memory::data_type corresponding to type T
+template <typename T>
+static memory::data_type MklDnnType();
+
+/// Instantiation for float type. Add similar instantiations for other
+/// type if needed.
+template <>
+memory::data_type MklDnnType<float>() {
+ return memory::data_type::f32;
+}
+
+/// Map TensorFlow's data format into MKL-DNN data format
+///
+/// @input: TensorFlow data format
+/// @return: memory::format corresponding to TensorFlow data format;
+/// Fails with an error if invalid data format.
+inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
+ if (format == FORMAT_NHWC)
+ return memory::format::nhwc;
+ else if (format == FORMAT_NCHW)
+ return memory::format::nchw;
+ TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
+ // Return to get rid of compiler warning
+ return memory::format::format_undef;
+}
-// Get the name of Mkl op from original TensorFlow op
-// We prefix 'Mkl' to the original op to get Mkl op.
-inline string GetMklOpName(const string& name) {
- // Prefix that we add to Tensorflow op name to construct Mkl op name.
- const char* const kMklOpPrefix = "_Mkl";
- return string(kMklOpPrefix) + name;
+/// Map TensorShape object into memory::dims required by MKL-DNN
+///
+/// This function will simply map input TensorShape into MKL-DNN dims
+/// naively. So it will preserve the order of dimensions. E.g., if
+/// input tensor is in NHWC format, then dims will be in NHWC format
+/// also.
+///
+/// @input TensorShape object in shape
+/// @return memory::dims corresponding to TensorShape
+inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
+ memory::dims dims(shape.dims());
+ for (unsigned int d = 0; d < shape.dims(); ++d) {
+ dims[d] = shape.dim_size(d);
+ }
+ return dims;
}
-// Check whether opname with type T is registered as MKL-compliant.
-//
-// @input: name of the op
-// @input: T datatype to be used for checking op
-// @return: true if opname is registered as Mkl op; false otherwise
-static inline bool IsMklOp(const std::string& op_name, DataType T) {
- string kernel = KernelsRegisteredForOp(op_name);
- bool result =
- kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
- if (result) {
- VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel;
- }
- return result;
+/// Map TensorShape object into memory::dims in NCHW format required by MKL-DNN
+///
+/// This function is a specific one than above function. It will map input
+/// TensorShape into MKL-DNN dims in NCHW format. So it may not preserve the
+/// order of dimensions. E.g., if input tensor is in NHWC format, then dims
+/// will be in NCHW format, and not in NHWC format.
+///
+/// @input TensorShape object in shape
+/// @return memory::dims in MKL-DNN required NCHW format
+inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
+ TensorFormat format) {
+ // Check validity of format.
+ CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
+ memory::format::format_undef);
+
+ int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
+ int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
+ int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
+ int w = shape.dim_size(GetTensorDimIndex(format, 'W'));
+
+ // MKL-DNN requires dimensions in NCHW format.
+ return memory::dims({n, c, h, w});
}
-// Check whether opname with type T is registered as MKL-compliant and
-// is element-wise.
-//
-// @input: name of the op
-// @input: T datatype to be used for checking op
-// @return: true if opname is registered as element-wise Mkl op; false otherwise
-static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) {
- if (!IsMklOp(op_name, T)) {
+inline padding_kind TFPaddingToMklDnnPadding(Padding pad) {
+ // MKL-DNN only supports zero padding.
+ return padding_kind::zero;
+}
+
+/*
+ * Class to represent all the resources corresponding to a tensor in TensorFlow
+ * that are required to execute an operation (such as Convolution).
+ */
+template <typename T>
+class MklDnnData {
+ private:
+ /// MKL-DNN memory primitive for input user memory
+ memory* user_memory_;
+
+ /// MKL-DNN memory primitive in case input or output reorder is needed.
+ memory* reorder_memory_;
+
+ /// Operations memory descriptor
+ memory::desc* op_md_;
+
+ /// CPU engine on which operation will be executed
+ const engine* cpu_engine_;
+
+ public:
+ explicit MklDnnData(const engine* e)
+ : user_memory_(nullptr),
+ reorder_memory_(nullptr),
+ op_md_(nullptr),
+ cpu_engine_(e) {}
+
+ ~MklDnnData() {
+ cpu_engine_ = nullptr; // We don't own this.
+ delete (user_memory_);
+ delete (reorder_memory_);
+ delete (op_md_);
+ }
+
+ void* GetTensorBuffer(const Tensor* tensor) {
+ CHECK_NOTNULL(tensor);
+ return const_cast<void*>(
+ static_cast<const void*>(tensor->flat<T>().data()));
+ }
+
+ /// Set user memory primitive using specified dimensions, memory format and
+ /// data_buffer. Function automatically uses element data type by using
+ /// input type T used for creating call object.
+ ///
+ /// In a nutshell, function allows user to describe the input tensor to
+ /// an operation. E.g., filter of Conv2D is of shape {1, 2, 3, 4}, and
+ /// memory format HWIO, and the buffer that contains actual values is
+ /// pointed by data_buffer.
+ void SetUsrMem(memory::dims dim, memory::format fm, void* data_buffer) {
+ CHECK_NOTNULL(data_buffer);
+ CHECK_NOTNULL(cpu_engine_);
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ user_memory_ =
+ new memory(memory::primitive_desc(
+ memory::desc(dim, MklDnnType<T>(), fm), *cpu_engine_),
+ data_buffer);
+ }
+
+ void SetUsrMem(memory::dims dim, memory::format fm, const Tensor* tensor) {
+ CHECK_NOTNULL(tensor);
+ SetUsrMem(dim, fm, GetTensorBuffer(tensor));
+ }
+
+ /// A version of function to set user memory primitive that accepts memory
+ /// descriptor directly, instead of accepting dimensions and format. This
+ /// function is more generic that the one above, but the function above is
+ /// sufficient in most cases.
+ void SetUsrMem(memory::desc md, void* data_buffer) {
+ CHECK_NOTNULL(data_buffer);
+ CHECK_NOTNULL(cpu_engine_);
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ user_memory_ =
+ new memory(memory::primitive_desc(md, *cpu_engine_), data_buffer);
+ }
+
+ /// A version of SetUsrMem with memory descriptor and tensor
+ void SetUsrMem(memory::desc md, const Tensor* tensor) {
+ CHECK_NOTNULL(tensor);
+ SetUsrMem(md, GetTensorBuffer(tensor));
+ }
+
+ /// A version of function to set user memory primitive that accepts primitive
+ /// descriptor directly, instead of accepting dimensions and format. This
+ /// function is more generic that the one above, but the function above is
+ /// sufficient in most cases.
+ void SetUsrMem(memory::primitive_desc pd, void* data_buffer) {
+ CHECK_NOTNULL(data_buffer);
+ CHECK_NOTNULL(cpu_engine_);
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ user_memory_ = new memory(pd, data_buffer);
+ }
+
+ /// A version of SetUsrMem with primitive descriptor and tensor
+ void SetUsrMem(memory::primitive_desc pd, const Tensor* tensor) {
+ CHECK_NOTNULL(tensor);
+ SetUsrMem(pd, GetTensorBuffer(tensor));
+ }
+
+ /// Get function for user memory primitive.
+ const memory* GetUsrMem() const { return user_memory_; }
+
+ /// Get function for primitive descriptor of user memory primitive.
+ const memory::primitive_desc GetUsrMemPrimDesc() const {
+ CHECK_NOTNULL(user_memory_);
+ return user_memory_->get_primitive_desc();
+ }
+
+ /// Get function for descriptor of user memory.
+ memory::desc GetUsrMemDesc() {
+ // This is ugly. Why MKL-DNN does not provide desc() method of const type??
+ const memory::primitive_desc pd = GetUsrMemPrimDesc();
+ return const_cast<memory::primitive_desc*>(&pd)->desc();
+ }
+
+ /// Get function for data buffer of user memory primitive.
+ void* GetUsrMemDataHandle() const {
+ CHECK_NOTNULL(user_memory_);
+ return user_memory_->get_data_handle();
+ }
+
+ /// Get the memory primitive for input and output of an op. If inputs
+ /// to an op require reorders, then this function returns memory primitive
+ /// for reorder. Otherwise, it will return memory primitive for user memory.
+ ///
+ /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
+ /// execute Conv2D, we need memory primitive for I and F. Buf if reorder is
+ /// required for I and F (say I_r is reorder primitive for I; F_r is reorder
+ /// primitive for F), then we need I_r and F_r to perform Conv2D.
+ const memory& GetOpMem() const {
+ return reorder_memory_ ? *reorder_memory_ : *user_memory_;
+ }
+
+ /// Set memory descriptor of an operation in terms of dimensions and memory
+ /// format. E.g., For Conv2D, the dimensions would be same as user dimensions
+ /// but memory::format would be mkldnn::any because we want MKL-DNN to choose
+ /// best layout/format for given input dimensions.
+ void SetOpMemDesc(const memory::dims& dim, memory::format fm) {
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
+ }
+
+ /// Get function for memory descriptor for an operation
+ const memory::desc& GetOpMemDesc() const { return *op_md_; }
+
+ /// Function to handle input reordering
+ ///
+ /// Check if we need to reorder this input of an operation.
+ /// Return true and allocate reorder memory primitive if reorder is needed.
+ /// Otherwise, return false and do not allocate reorder memory primitive.
+ ///
+ /// To check if reorder is needed, this function compares memory primitive
+ /// descriptor of an operation (op_pd) for the given input with the
+ /// user-specified memory primitive descriptor.
+ ///
+ /// @input: op_pd - memory primitive descriptor of the given input of an
+ /// operation
+ /// @input: net - net to which to add reorder primitive in case it is needed.
+ /// @return: true in case reorder of input is needed; false, otherwise.
+ bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
+ std::vector<primitive>* net) {
+ CHECK_NOTNULL(net);
+ CHECK_NOTNULL(user_memory_);
+ if (op_pd != user_memory_->get_primitive_desc()) {
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ reorder_memory_ = new memory(op_pd);
+ net->push_back(reorder(*user_memory_, *reorder_memory_));
+ return true;
+ }
return false;
}
- bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
- 0 == op_name.compare(GetMklOpName("Sub")) ||
- 0 == op_name.compare(GetMklOpName("Mul")) ||
- 0 == op_name.compare(GetMklOpName("Maximum")) ||
- 0 == op_name.compare(GetMklOpName("SquaredDifference")));
+ /// Function to handle output reorder
+ ///
+ /// This function performs very similar functionality as input reordering
+ /// function above. The only difference is that this function does not add
+ /// reorder primitive to the net. The reason for this is: the reorder
+ /// primitive for output needs to be added to the list only after operation
+ /// has executed. But we need to prepare a temporary buffer in case output
+ /// reorder is needed. And this temporary buffer will hold the output of
+ /// an operation before it is fed to reorder primitive.
+ ///
+ /// @input memory primitive descriptor for the given output of an operation
+ /// @return: true in case reorder of output is needed; false, otherwise.
+ bool PrepareReorderToUserMemIfReq(const memory::primitive_desc& op_pd) {
+ CHECK_NOTNULL(user_memory_);
+ if (op_pd != user_memory_->get_primitive_desc()) {
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ reorder_memory_ = new memory(op_pd);
+ return true;
+ }
+ return false;
+ }
- VLOG(1) << "mkl_op_registry::" << op_name
- << " is elementwise MKL op: " << result;
- return result;
-}
+ /// Function to actually insert reorder primitive in the net
+ ///
+ /// This function completes remaining part of output reordering. It inserts
+ /// a reordering primitive from the temporary buffer that holds the output
+ /// to the user-specified output buffer.
+ ///
+ /// @input: net - net to which to add reorder primitive
+ void InsertReorderToUserMem(std::vector<primitive>* net) {
+ CHECK_NOTNULL(net);
+ CHECK_NOTNULL(user_memory_);
+ CHECK_NOTNULL(reorder_memory_);
+ net->push_back(reorder(*reorder_memory_, *user_memory_));
+ }
+};
-} // namespace mkl_op_registry
+#endif // INTEL_MKL_DNN
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index d8925d3909..e6a4088656 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -429,3 +429,41 @@ Stack Overflow and specify the `tensorflow` tag.
<pre>ImportError: cannot import name pywrap_tensorflow</pre></td>
</tr>
</table>
+
+## Tested source configurations
+**Linux**
+<table>
+<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
+<tr><td>tensorflow-1.3.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.5</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.3.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.5</td><td>6</td><td>8</td></tr>
+<tr><td>tensorflow-1.2.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.5</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.2.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.5</td><td>5.1</td><td>8</td></tr>
+<tr><td>tensorflow-1.1.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.2</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.1.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.2</td><td>5.1</td><td>8</td></tr>
+<tr><td>tensorflow-1.0.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.2</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.0.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.2</td><td>5.1</td><td>8</td></tr>
+</table>
+
+**Mac**
+<table>
+<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
+<tr><td>tensorflow-1.3.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.4.5</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>ttensorflow-1.2.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.4.5</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>ttensorflow-1.1.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.4.2</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>ttensorflow_gpu-1.1.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.4.2</td><td>5.1</td><td>8</td></tr>
+<tr><td>ttensorflow-1.0.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.4.2</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>ttensorflow_gpu-1.0.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.4.2</td><td>5.1</td><td>8</td></tr>
+</table>
+
+**Windows**
+<table>
+<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
+<tr><td>tensorflow-1.3.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.3.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>6</td><td>8</td></tr>
+<tr><td>tensorflow-1.2.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.2.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>5.1</td><td>8</td></tr>
+<tr><td>tensorflow-1.1.0</td><td>CPU</td><td>3.5</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.1.0</td><td>GPU</td><td>3.5</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>5.1</td><td>8</td></tr>
+<tr><td>tensorflow-1.0.0</td><td>CPU</td><td>3.5</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.0.0</td><td>GPU</td><td>3.5</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>5.1</td><td>8</td></tr>
+</table>
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java
index eb4dc69d63..184df1bdb4 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java
@@ -37,6 +37,7 @@ import android.content.pm.PackageManager;
import android.media.AudioFormat;
import android.media.AudioRecord;
import android.media.MediaRecorder;
+import android.os.Build;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
@@ -151,12 +152,15 @@ public class SpeechActivity extends Activity {
// Start the recording and recognition threads.
requestMicrophonePermission();
+ startRecording();
startRecognition();
}
private void requestMicrophonePermission() {
- requestPermissions(
- new String[] {android.Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO);
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
+ requestPermissions(
+ new String[]{android.Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO);
+ }
}
@Override
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index 6d98c7b85d..1fa2b14869 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -89,7 +89,7 @@ def build_dataset(words, n_words):
# Filling 4 global variables:
# data - list of codes (integers from 0 to vocabulary_size-1).
# This is the original text but words are replaced by their codes
-# count - map of words(strings) to count of occurences
+# count - map of words(strings) to count of occurrences
# dictionary - map of words(strings) to their codes(integers)
# reverse_dictionary - maps codes(integers) to words(strings)
data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
diff --git a/tensorflow/go/example_inception_inference_test.go b/tensorflow/go/example_inception_inference_test.go
index 2162fbe484..f84a588899 100644
--- a/tensorflow/go/example_inception_inference_test.go
+++ b/tensorflow/go/example_inception_inference_test.go
@@ -28,8 +28,8 @@ import (
"os"
"path/filepath"
- "github.com/tensorflow/tensorflow/tensorflow/go/op"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
+ "github.com/tensorflow/tensorflow/tensorflow/go/op"
)
func Example() {
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index a534a0d659..e8fa21a62b 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -92,7 +92,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
raw := tensorData(t.c)
buf := bytes.NewBuffer(raw[:0:len(raw)])
if dataType != String {
- if err := encodeTensor(buf, val); err != nil {
+ if err := encodeTensor(buf, val, shape); err != nil {
return nil, err
}
if uintptr(buf.Len()) != nbytes {
@@ -100,7 +100,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
}
} else {
e := stringEncoder{offsets: buf, data: raw[nflattened*8 : len(raw)], status: newStatus()}
- if err := e.encode(reflect.ValueOf(value)); err != nil {
+ if err := e.encode(reflect.ValueOf(value), shape); err != nil {
return nil, err
}
if int64(buf.Len()) != nflattened*8 {
@@ -236,17 +236,11 @@ func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err erro
typ := val.Type()
for typ.Kind() == reflect.Array || typ.Kind() == reflect.Slice {
shape = append(shape, int64(val.Len()))
- // If slice elements are slices, verify that all of them have the same size.
- // Go's type system makes that guarantee for arrays.
if val.Len() > 0 {
- if val.Type().Elem().Kind() == reflect.Slice {
- expected := val.Index(0).Len()
- for i := 1; i < val.Len(); i++ {
- if val.Index(i).Len() != expected {
- return shape, dt, fmt.Errorf("mismatched slice lengths: %d and %d", val.Index(i).Len(), expected)
- }
- }
- }
+ // In order to check tensor structure properly in general case we need to iterate over all slices of the tensor to check sizes match
+ // Since we already going to iterate over all elements in encodeTensor() let's
+ // 1) do the actual check in encodeTensor() to save some cpu cycles here
+ // 2) assume the shape is represented by lengths of elements with zero index in each dimension
val = val.Index(0)
}
typ = typ.Elem()
@@ -302,7 +296,7 @@ func byteSizeOfEncodedStrings(val interface{}) uintptr {
// encodeTensor writes v to the specified buffer using the format specified in
// c_api.h. Use stringEncoder for String tensors.
-func encodeTensor(w *bytes.Buffer, v reflect.Value) error {
+func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
switch v.Kind() {
case reflect.Bool:
b := byte(0)
@@ -318,19 +312,18 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value) error {
}
case reflect.Array, reflect.Slice:
- // If slice elements are slices, verify that all of them have the same size.
+ // If current dimension is a slice, verify that it has the expected size
// Go's type system makes that guarantee for arrays.
- if v.Len() > 0 && v.Type().Elem().Kind() == reflect.Slice {
- expected := v.Index(0).Len()
- for i := 1; i < v.Len(); i++ {
- if v.Index(i).Len() != expected {
- return fmt.Errorf("mismatched slice lengths: %d and %d", v.Index(i).Len(), expected)
- }
+ if v.Kind() == reflect.Slice {
+ expected := int(shape[0])
+ if v.Len() != expected {
+ return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
}
}
+ subShape := shape[1:]
for i := 0; i < v.Len(); i++ {
- err := encodeTensor(w, v.Index(i))
+ err := encodeTensor(w, v.Index(i), subShape)
if err != nil {
return err
}
@@ -379,7 +372,7 @@ type stringEncoder struct {
status *status
}
-func (e *stringEncoder) encode(v reflect.Value) error {
+func (e *stringEncoder) encode(v reflect.Value, shape []int64) error {
if v.Kind() == reflect.String {
if err := binary.Write(e.offsets, nativeEndian, e.offset); err != nil {
return err
@@ -395,8 +388,17 @@ func (e *stringEncoder) encode(v reflect.Value) error {
C.free(unsafe.Pointer(src))
return e.status.Err()
}
+
+ if v.Kind() == reflect.Slice {
+ expected := int(shape[0])
+ if v.Len() != expected {
+ return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
+ }
+ }
+
+ subShape := shape[1:]
for i := 0; i < v.Len(); i++ {
- if err := e.encode(v.Index(i)); err != nil {
+ if err := e.encode(v.Index(i), subShape); err != nil {
return err
}
}
diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go
index 2fc7553f87..35bd2fd9a5 100644
--- a/tensorflow/go/tensor_test.go
+++ b/tensorflow/go/tensor_test.go
@@ -42,6 +42,10 @@ func TestNewTensor(t *testing.T) {
{[]int64{2}, []bool{true, false}},
{[]int64{1}, []float64{1}},
{[]int64{1}, [1]float64{1}},
+ {[]int64{1, 1}, [1][1]float64{{1}}},
+ {[]int64{1, 1, 1}, [1][1][]float64{{{1}}}},
+ {[]int64{1, 1, 2}, [1][][2]float64{{{1, 2}}}},
+ {[]int64{1, 1, 1, 1}, [1][][1][]float64{{{{1}}}}},
{[]int64{2}, []string{"string", "slice"}},
{[]int64{2}, [2]string{"string", "array"}},
{[]int64{3, 2}, [][]float64{{1, 2}, {3, 4}, {5, 6}}},
@@ -74,6 +78,12 @@ func TestNewTensor(t *testing.T) {
[]uint64{5},
// Mismatched dimensions
[][]float32{{1, 2, 3}, {4}},
+ // Mismatched dimensions. Should return "mismatched slice lengths" error instead of "BUG"
+ [][][]float32{{{1, 2}, {3, 4}}, {{1}, {3}}},
+ // Mismatched dimensions. Should return error instead of valid tensor
+ [][][]float32{{{1, 2}, {3, 4}}, {{1}, {3}}, {{1, 2, 3}, {2, 3, 4}}},
+ // Mismatched dimensions for strings
+ [][]string{{"abc"}, {"abcd", "abcd"}},
}
for _, test := range tests {
diff --git a/tensorflow/java/src/gen/perl/tftypes-runall.pl b/tensorflow/java/src/gen/perl/tftypes-runall.pl
index 258c1ff836..a451ce92aa 100644
--- a/tensorflow/java/src/gen/perl/tftypes-runall.pl
+++ b/tensorflow/java/src/gen/perl/tftypes-runall.pl
@@ -37,4 +37,4 @@ sub locchk {
&locchk("$rsrc/tftypes.csv");
system("perl $dir/tftypes.pl -t $rsrc/tftypes.csv $pkg/types");
-# system("perl $dir/tftypes.pl -c $rsrc/tftypes.csv $rsrc/Tensors.java.tmpl > $pkg/op/Tensors.java");
+system("perl $dir/tftypes.pl -c $rsrc/tftypes.csv $rsrc/Tensors.java.tmpl > $pkg/Tensors.java");
diff --git a/tensorflow/java/src/gen/perl/tftypes.pl b/tensorflow/java/src/gen/perl/tftypes.pl
index 86867335cb..115723ac8a 100644
--- a/tensorflow/java/src/gen/perl/tftypes.pl
+++ b/tensorflow/java/src/gen/perl/tftypes.pl
@@ -75,15 +75,23 @@ open (TYPEDESC, $typedesc);
my @info = ([]);
+sub trim {
+ (my $ret) = @_;
+ $ret =~ s/^\s*//g;
+ $ret =~ s/\s*$//g;
+ return $ret;
+}
+
while (<TYPEDESC>) {
chomp;
my $line = $_;
if ($line =~ m/^TF type/) { next }
$line =~ s/\r$//;
- (my $name, my $jtype, my $creat, my $default, my $desc) =
- split /,/, $line, 5;
- $desc =~ s/^ *//g;
- $desc =~ s/ *$//g;
+ my @items = split /,/, $line, 6;
+ for (my $i = 0; $i <= $#items; $i++) {
+ $items[$i] = trim $items[$i];
+ }
+ my $jtype = $items[2];
$jtypecount{$jtype}++;
if ($jtypecount{$jtype} > 1) {
# currently allowing Java types to stand for more than one TF type, but
@@ -92,63 +100,85 @@ while (<TYPEDESC>) {
# exit 1
}
- push @info, [$name, $jtype, $creat, $default, $desc];
+ push @info, \@items;
+}
+
+sub article {
+ (my $s) = @_;
+ if (substr($s, 0, 1) =~ m/^[aeoiu8]$/i) {
+ return "an $s"
+ } else {
+ return "a $s"
+ }
}
for (my $i = 1; $i <= $#info; $i++) {
- (my $name, my $jtype, my $creat, my $default, my $desc) =
+ (my $name, my $builtin, my $jtype, my $creat, my $default, my $desc) =
@{$info[$i]};
- my $tfname = "TF".$name;
+ my $tfname = $name;
my $ucname = uc $name;
+ print STDERR "$name $desc\n";
+
if ($option eq '-t') {
if ($jtype eq '') { next }
+ if ($builtin eq 'y') { next }
# Generate class declarations
# print STDERR "Creating $dirname/$tfname.java\n";
open (CLASSFILE, ">$dirname/$tfname.java") || die "Can't open $tfname.java";
- print CLASSFILE $copyright;
- print CLASSFILE "// GENERATED FILE. To update, edit tftypes.pl instead.\n\n";
-
- my $fulldesc = $desc;
- if (substr($desc, 0, 1) =~ m/^[aeoiu8]$/i) {
- $fulldesc = "an $desc"
- } else {
- $fulldesc = "a $desc"
- }
- print CLASSFILE "package org.tensorflow.types;\n\n"
- ."import org.tensorflow.DataType;\n\n";
+ print CLASSFILE $copyright, "\n";
+ # print CLASSFILE "// GENERATED FILE. To update, edit tftypes.pl instead.\n\n";
+
+ my $fulldesc = article($desc);
+ print CLASSFILE "package org.tensorflow.types;\n\n";
print CLASSFILE "/** Represents $fulldesc. */\n"
- ."public class $tfname implements TFType {\n"
- ." private $tfname() {}\n"
- ." static {\n"
- ." Types.typeCodes.put($tfname.class, DataType.$ucname);\n"
- ." }\n";
- if ($default ne '') {
- print CLASSFILE
- " static {\n"
- ." Types.scalars.put($tfname.class, $default);\n"
- ." }\n";
- }
- print CLASSFILE "}\n";
+ ."public class $tfname {\n"
+ ." private $tfname() {\n"
+ ." }\n"
+ ."}\n";
close(CLASSFILE);
} elsif ($option eq '-c') {
# Generate creator declarations for Tensors.java
if ($jtype ne '' && $creat eq 'y') {
- for (my $brackets = ''; length $brackets <= 12; $brackets .= '[]') {
+ for (my $brackets = '', my $rank = 0; length $brackets <= 12; $brackets .= '[]', $rank++) {
+ my $datainfo = " * \@param data An array containing the values to put into the new tensor.\n"
+ ." * The dimensions of the new tensor will match those of the array.\n";
+ if ($rank == 0) {
+ $datainfo = " * \@param data The value to put into the new scalar tensor.\n"
+ }
+
+ my $trank = $rank;
+ if ($tfname eq 'String') {
+ $trank = $rank-1;
+ next if $trank < 0;
+
+ $datainfo = " * \@param data An array containing the data to put into the new tensor.\n"
+ ." * String elements are sequences of bytes from the last array dimension.\n";
+ }
+
+
+ my $intro = ($trank > 0)
+ ? "Creates a rank-$trank tensor of {\@code $jtype} elements."
+ : "Creates a scalar tensor containing a single {\@code $jtype} element.";
$typeinfo .=
- " public static Tensor<$tfname> create($jtype$brackets data) {\n"
- ." return Tensor.create(data, $tfname.class);\n"
- ." }\n";
+ " /**\n"
+ ." * $intro\n"
+ ." * \n"
+ .$datainfo
+ ." */\n"
+ ." public static Tensor<$tfname> create($jtype$brackets data) {\n"
+ ." return Tensor.create(data, $tfname.class);\n"
+ ." }\n\n";
}
}
- if ($text =~ m/\b$tfname\b/ || $creat eq 'y') {
+ if ($text =~ m/\b$tfname\b/ && $builtin eq 'n' && $creat eq 'y') {
$imports .= "import org.tensorflow.types.$tfname;\n";
}
}
}
if ($option ne '-t') {
- print "// GENERATED FILE. Edits to this file will be lost -- edit $tmpl instead.\n";
+# print "// GENERATED FILE. Edits to this file will be lost -- edit $tmpl instead.\n";
$text =~ s/\@TYPEINFO\@/$typeinfo/;
$text =~ s/\@IMPORTS\@/$imports/;
diff --git a/tensorflow/java/src/gen/resources/Tensors.java.tmpl b/tensorflow/java/src/gen/resources/Tensors.java.tmpl
new file mode 100644
index 0000000000..98e1588559
--- /dev/null
+++ b/tensorflow/java/src/gen/resources/Tensors.java.tmpl
@@ -0,0 +1,31 @@
+package org.tensorflow;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+import org.tensorflow.Tensor;
+@IMPORTS@
+
+/**
+ * Type-safe factory methods for creating {@link Tensor} objects.
+ */
+public final class Tensors {
+ private Tensors() {}
+
+ /** Creates a scalar String tensor using the default, UTF-8 encoding.
+ *
+ * @param data The string to put into the new scalar tensor.
+ */
+ public static Tensor<String> create(String data) {
+ return Tensor.create(data.getBytes(UTF_8), String.class);
+ }
+
+ /** Creates a scalar String tensor using a specified encoding.
+ *
+ * @param charset The encoding from String to bytes.
+ * @param data The string to put into the new scalar tensor.
+ */
+ public static Tensor<String> create(String data, java.nio.charset.Charset charset) {
+ return Tensor.create(data.getBytes(charset), String.class);
+ }
+
+@TYPEINFO@}
+
diff --git a/tensorflow/java/src/gen/resources/tftypes.csv b/tensorflow/java/src/gen/resources/tftypes.csv
index 88acaafd3c..6f26230f27 100644
--- a/tensorflow/java/src/gen/resources/tftypes.csv
+++ b/tensorflow/java/src/gen/resources/tftypes.csv
@@ -1,21 +1,21 @@
-TF type,Java type,Creator?,Zero value,Description
-Float,float,y,0f,32-bit single precision floating point number
-Double,double,y,0.0,64-bit double precision floating point number
-Int32,int,y,0,32-bit signed integer
-UInt8,byte,n,(byte)0,8-bit unsigned integer
-Int16,,n,(short)0,16-bit signed integer
-Int8,,n,(byte)0,8-bit signed integer
-String,byte,n,,arbitrary sequence of bytes
-Complex64,,n,,single-precision complex number
-Int64,long,y,0L,64-bit signed integer
-Bool,boolean,y,false,boolean
-QInt8,,n,,quantized int8
-QUInt8,,n,,quantized uint8
-QInt32,,n,,quantized int32
-BFloat16,,n,,float32 truncated to 16 bits. Only for cast ops.
-QInt16,,n,,quantized int16
-QUInt16,,n,,quantized uint16
-UInt16,,n,,16-bit unsigned integer
-Complex128,,n,,double-precision complex number
-Half,,n,,
-Resource,,n,,
+TF type,Builtin,Java type,Creator?,Zero value,Description
+Float,y,float,y,0f,32-bit single precision floating point number
+Double,y,double,y,0.0,64-bit double precision floating point number
+Integer,y,int,y,0,32-bit signed integer
+UInt8,n,byte,n,(byte)0,8-bit unsigned integer
+Short,y,,n,(short)0,16-bit signed integer
+Byte,y,,n,(byte)0,8-bit signed integer
+String,y,byte,y,,arbitrary sequence of bytes
+Complex64,n,,n,,single-precision complex number
+Long,y,long,y,0L,64-bit signed integer
+Boolean,y,boolean,y,false,boolean
+QInt8,n,,n,,quantized int8
+QUInt8,n,,n,,quantized uint8
+QInt32,n,,n,,quantized int32
+BFloat16,n,,n,,float32 truncated to 16 bits. Only for cast ops.
+QInt16,n,,n,,quantized int16
+QUInt16,n,,n,,quantized uint16
+UInt16,n,,n,,16-bit unsigned integer
+Complex128,n,,n,,double-precision complex number
+Half,n,,n,,
+Resource,n,,n,,
diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
index e67e266ff7..e835101d08 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
@@ -15,7 +15,13 @@ limitations under the License.
package org.tensorflow;
-/** Type of elements in a {@link Tensor}. */
+import java.util.HashMap;
+import java.util.Map;
+import org.tensorflow.types.UInt8;
+
+/**
+ * Represents the type of elements in a {@link Tensor} as an enum.
+ */
public enum DataType {
/** 32-bit single precision floating point. */
FLOAT(1),
@@ -55,14 +61,41 @@ public enum DataType {
}
// Cached to avoid copying it
- final private static DataType[] values = values();
+ private static final DataType[] values = values();
static DataType fromC(int c) {
for (DataType t : values) {
- if (t.value == c)
+ if (t.value == c) {
return t;
+ }
}
throw new IllegalArgumentException(
"DataType " + c + " is not recognized in Java (version " + TensorFlow.version() + ")");
}
+
+ /**
+ * Returns the DataType of a Tensor whose elements have the type specified by class {@code c}.
+ *
+ * @param c The class describing the TensorFlow type of interest.
+ */
+ public static DataType fromClass(Class<?> c) {
+ DataType dtype = typeCodes.get(c);
+ if (dtype == null) {
+ throw new IllegalArgumentException(
+ c.getName() + " objects cannot be used as elements in a TensorFlow Tensor");
+ }
+ return dtype;
+ }
+
+ private static final Map<Class<?>, DataType> typeCodes = new HashMap<>();
+
+ static {
+ typeCodes.put(Float.class, DataType.FLOAT);
+ typeCodes.put(Double.class, DataType.DOUBLE);
+ typeCodes.put(Integer.class, DataType.INT32);
+ typeCodes.put(UInt8.class, DataType.UINT8);
+ typeCodes.put(Long.class, DataType.INT64);
+ typeCodes.put(Boolean.class, DataType.BOOL);
+ typeCodes.put(String.class, DataType.STRING);
+ }
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 58ad3ab193..d4fd3db5f7 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -81,8 +81,8 @@ public final class Graph implements AutoCloseable {
/**
* Iterator over all the {@link Operation}s in the graph.
*
- * The order of iteration is unspecified. Consumers of the iterator will received no notification
- * should the underlying graph change during iteration.
+ * <p>The order of iteration is unspecified. Consumers of the iterator will receive no
+ * notification should the underlying graph change during iteration.
*/
public Iterator<Operation> operations() {
return new OperationIterator(this);
@@ -245,7 +245,8 @@ public final class Graph implements AutoCloseable {
private static native long operation(long handle, String name);
- // This method returns the Operation native handle at index 0 and the new value for pos at index 1 (see TF_GraphNextOperation)
+ // This method returns the Operation native handle at index 0 and the new value for pos at index 1
+ // (see TF_GraphNextOperation)
private static native long[] nextOperation(long handle, int position);
private static native void importGraphDef(long handle, byte[] graphDef, String prefix)
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Input.java b/tensorflow/java/src/main/java/org/tensorflow/Input.java
index 8e6685ee0f..13bc463e7d 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Input.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Input.java
@@ -34,7 +34,7 @@ package org.tensorflow;
* ops.array().concat(0, split);
* }</pre>
*/
-public interface Input {
+public interface Input<T> {
/**
* Returns the symbolic handle of a tensor.
@@ -44,5 +44,5 @@ public interface Input {
*
* @see OperationBuilder#addInput(Output)
*/
- Output asOutput();
+ Output<T> asOutput();
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
index d2d019babb..2b431eebf5 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
@@ -122,8 +122,7 @@ final class NativeLibrary {
}
private static String extractResource(
- InputStream resource, String resourceName, String extractToDirectory)
- throws IOException {
+ InputStream resource, String resourceName, String extractToDirectory) throws IOException {
final File dst = new File(extractToDirectory, System.mapLibraryName(resourceName));
dst.deleteOnExit();
final String dstPath = dst.toString();
@@ -184,8 +183,7 @@ final class NativeLibrary {
// compatibility.
private static File createTemporaryDirectory() {
File baseDirectory = new File(System.getProperty("java.io.tmpdir"));
- String directoryName
- = "tensorflow_native_libraries-" + System.currentTimeMillis() + "-";
+ String directoryName = "tensorflow_native_libraries-" + System.currentTimeMillis() + "-";
for (int attempt = 0; attempt < 1000; attempt++) {
File temporaryDirectory = new File(baseDirectory, directoryName + attempt);
if (temporaryDirectory.mkdir()) {
@@ -194,7 +192,8 @@ final class NativeLibrary {
}
throw new IllegalStateException(
"Could not create a temporary directory (tried to make "
- + directoryName + "*) to extract TensorFlow native libraries.");
+ + directoryName
+ + "*) to extract TensorFlow native libraries.");
}
private NativeLibrary() {}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Operand.java b/tensorflow/java/src/main/java/org/tensorflow/Operand.java
index 695c4c1060..61082e83d5 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Operand.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Operand.java
@@ -22,19 +22,19 @@ package org.tensorflow;
*
* <pre>{@code
* // The "decodeJpeg" operation can be used as an operand to the "cast" operation
- * Operand decodeJpeg = ops.image().decodeJpeg(...);
+ * Operand<UInt8> decodeJpeg = ops.image().decodeJpeg(...);
* ops.math().cast(decodeJpeg, DataType.FLOAT);
*
* // The output "y" of the "unique" operation can be used as an operand to the "cast" operation
- * Output y = ops.array().unique(...).y();
- * ops.math().cast(y, DataType.FLOAT);
+ * Output<Integer> y = ops.array().unique(...).y();
+ * ops.math().cast(y, Float.class);
*
* // The "split" operation can be used as operand list to the "concat" operation
- * Iterable<? extends Operand> split = ops.array().split(...);
+ * Iterable<? extends Operand<Float>> split = ops.array().split(...);
* ops.array().concat(0, split);
* }</pre>
*/
-public interface Operand {
+public interface Operand<T> {
/**
* Returns the symbolic handle of a tensor.
@@ -44,5 +44,5 @@ public interface Operand {
*
* @see OperationBuilder#addInput(Output)
*/
- Output asOutput();
+ Output<T> asOutput();
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Operation.java b/tensorflow/java/src/main/java/org/tensorflow/Operation.java
index ec26309fba..6b82e5780b 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Operation.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Operation.java
@@ -98,16 +98,26 @@ public final class Operation {
* @param length number of tensors in the list
* @return array of {@code Output}
*/
- public Output[] outputList(int idx, int length) {
- Output[] outputs = new Output[length];
+ public Output<?>[] outputList(int idx, int length) {
+ Output<?>[] outputs = new Output<?>[length];
for (int i = 0; i < length; ++i) {
outputs[i] = output(idx + i);
}
return outputs;
}
- /** Returns a symbolic handle to one of the tensors produced by this operation. */
- public Output output(int idx) {
+ /**
+ * Returns a symbolic handle to one of the tensors produced by this operation.
+ *
+ * <p>Warning: Does not check that the type of the tensor matches T. It is recommended to call
+ * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
+ * operation.<Integer>output(0)}
+ *
+ * @param <T> The expected element type of the tensors produced by this output.
+ * @param idx The index of the output among the outputs produced by this operation.
+ */
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ public <T> Output<T> output(int idx) {
return new Output(this, idx);
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java
index 15077ce439..9a1b7592b3 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java
@@ -63,7 +63,6 @@ public final class OperationBuilder {
}
}
-
/**
* Returns the builder to create an operation.
*
@@ -73,7 +72,7 @@ public final class OperationBuilder {
* @param input {@link Output} supposed to be the input of the OperationBuilder.
* @return the OperationBuilder instance for chaining.
*/
- public OperationBuilder addInput(Output input) {
+ public OperationBuilder addInput(Output<?> input) {
Graph.Reference r = graph.ref();
try {
addInput(unsafeNativeHandle, input.op().getUnsafeNativeHandle(), input.index());
@@ -106,7 +105,7 @@ public final class OperationBuilder {
return this;
}
- public OperationBuilder addInputList(Output[] inputs) {
+ public OperationBuilder addInputList(Output<?>[] inputs) {
Graph.Reference r = graph.ref();
try {
long[] opHandles = new long[inputs.length];
@@ -231,7 +230,7 @@ public final class OperationBuilder {
return this;
}
- public OperationBuilder setAttr(String name, Tensor value) {
+ public OperationBuilder setAttr(String name, Tensor<?> value) {
Graph.Reference r = graph.ref();
try {
setAttrTensor(unsafeNativeHandle, name, value.getNativeHandle());
@@ -241,10 +240,10 @@ public final class OperationBuilder {
return this;
}
- public OperationBuilder setAttr(String name, Tensor[] value) {
+ public OperationBuilder setAttr(String name, Tensor<?>[] value) {
long[] handles = new long[value.length];
int idx = 0;
- for (Tensor t : value) {
+ for (Tensor<?> t : value) {
handles[idx++] = t.getNativeHandle();
}
Graph.Reference r = graph.ref();
@@ -266,7 +265,7 @@ public final class OperationBuilder {
return this;
}
- public OperationBuilder setAttr(String name, String[] value) {
+ public OperationBuilder setAttr(String name, String[] value) {
Charset utf8 = Charset.forName("UTF-8");
Object[] objects = new Object[value.length];
for (int i = 0; i < value.length; ++i) {
@@ -326,5 +325,4 @@ public final class OperationBuilder {
private static native void setAttrShape(long handle, String name, long[] shape, int numDims);
private static native void setAttrStringList(long handle, String name, Object[] value);
-
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Output.java b/tensorflow/java/src/main/java/org/tensorflow/Output.java
index 8dff50fafb..0e17a722ff 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Output.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Output.java
@@ -20,13 +20,13 @@ import java.util.Objects;
/**
* A symbolic handle to a tensor produced by an {@link Operation}.
*
- * <p>An Output is a symbolic handle to a tensor. The value of the Tensor is computed by executing
- * the {@link Operation} in a {@link Session}.
+ * <p>An Output<T> is a symbolic handle to a Tensor<T>. The value of the tensor is computed by
+ * executing the {@link Operation} in a {@link Session}.
*
* <p>By implementing the {@link Operand} interface, instances of this class also act as operands to
* {@link org.tensorflow.op.Op Op} instances.
*/
-public final class Output implements Operand {
+public final class Output<T> implements Operand<T> {
/** Handle to the idx-th output of the Operation {@code op}. */
public Output(Operation op, int idx) {
@@ -55,7 +55,7 @@ public final class Output implements Operand {
}
@Override
- public Output asOutput() {
+ public Output<T> asOutput() {
return this;
}
@@ -69,8 +69,8 @@ public final class Output implements Operand {
if (o == this) {
return true;
}
- if (o instanceof Output) {
- Output that = (Output) o;
+ if (o instanceof Output<?>) {
+ Output<?> that = (Output<?>) o;
return index == that.index && operation.equals(that.operation);
}
return false;
diff --git a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
index b4591dd869..c8b9126f03 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
@@ -27,8 +27,9 @@ package org.tensorflow;
public class SavedModelBundle implements AutoCloseable {
/**
- * Load a saved model from an export directory. The model that is being loaded should be created using
- * the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model API</a>.
+ * Load a saved model from an export directory. The model that is being loaded should be created
+ * using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
+ * API</a>.
*
* @param exportDir the directory path containing a saved model.
* @param tags the tags identifying the specific metagraphdef to load.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Session.java b/tensorflow/java/src/main/java/org/tensorflow/Session.java
index 83a300a560..73324f23e6 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Session.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Session.java
@@ -127,7 +127,7 @@ public final class Session implements AutoCloseable {
* {@code SignatureDef} protocol buffer messages that are included in {@link
* SavedModelBundle#metaGraphDef()}.
*/
- public Runner feed(String operation, Tensor t) {
+ public Runner feed(String operation, Tensor<?> t) {
return feed(parseOutput(operation), t);
}
@@ -138,7 +138,7 @@ public final class Session implements AutoCloseable {
* <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
* one {@code t} is being provided for.
*/
- public Runner feed(String operation, int index, Tensor t) {
+ public Runner feed(String operation, int index, Tensor<?> t) {
Operation op = operationByName(operation);
if (op != null) {
inputs.add(op.output(index));
@@ -151,7 +151,7 @@ public final class Session implements AutoCloseable {
* Use {@code t} instead of the Tensor referred to by executing the operation referred to by
* {@code output}.
*/
- public Runner feed(Output o, Tensor t) {
+ public Runner feed(Output<?> o, Tensor<?> t) {
inputs.add(o);
inputTensors.add(t);
return this;
@@ -186,7 +186,7 @@ public final class Session implements AutoCloseable {
}
/** Makes {@link #run()} return the Tensor referred to by {@code output}. */
- public Runner fetch(Output output) {
+ public Runner fetch(Output<?> output) {
outputs.add(output);
return this;
}
@@ -240,8 +240,11 @@ public final class Session implements AutoCloseable {
* easier for the caller to cleanup (perhaps returning something like AutoCloseableList in
* SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a
* {@code Map<Output, Tensor>}?
+ *
+ * <p>TODO(andrewmyers): It would also be good if whatever is returned here made it easier to
+ * extract output tensors in a type-safe way.
*/
- public List<Tensor> run() {
+ public List<Tensor<?>> run() {
return runHelper(false).outputs;
}
@@ -269,17 +272,17 @@ public final class Session implements AutoCloseable {
// It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the
// validity of the Graph and graphRef ensures that.
int idx = 0;
- for (Tensor t : inputTensors) {
+ for (Tensor<?> t : inputTensors) {
inputTensorHandles[idx++] = t.getNativeHandle();
}
idx = 0;
- for (Output o : inputs) {
+ for (Output<?> o : inputs) {
inputOpHandles[idx] = o.op().getUnsafeNativeHandle();
inputOpIndices[idx] = o.index();
idx++;
}
idx = 0;
- for (Output o : outputs) {
+ for (Output<?> o : outputs) {
outputOpHandles[idx] = o.op().getUnsafeNativeHandle();
outputOpIndices[idx] = o.index();
idx++;
@@ -306,12 +309,12 @@ public final class Session implements AutoCloseable {
} finally {
runRef.close();
}
- List<Tensor> outputs = new ArrayList<Tensor>();
+ List<Tensor<?>> outputs = new ArrayList<Tensor<?>>();
for (long h : outputTensorHandles) {
try {
outputs.add(Tensor.fromHandle(h));
} catch (Exception e) {
- for (Tensor t : outputs) {
+ for (Tensor<?> t : outputs) {
t.close();
}
outputs.clear();
@@ -355,7 +358,8 @@ public final class Session implements AutoCloseable {
return op;
}
- private Output parseOutput(String opName) {
+ @SuppressWarnings("rawtypes")
+ private Output<?> parseOutput(String opName) {
int colon = opName.lastIndexOf(':');
if (colon == -1 || colon == opName.length() - 1) {
return new Output(operationByName(opName), 0);
@@ -369,9 +373,9 @@ public final class Session implements AutoCloseable {
}
}
- private ArrayList<Output> inputs = new ArrayList<Output>();
- private ArrayList<Tensor> inputTensors = new ArrayList<Tensor>();
- private ArrayList<Output> outputs = new ArrayList<Output>();
+ private ArrayList<Output<?>> inputs = new ArrayList<Output<?>>();
+ private ArrayList<Tensor<?>> inputTensors = new ArrayList<Tensor<?>>();
+ private ArrayList<Output<?>> outputs = new ArrayList<Output<?>>();
private ArrayList<Operation> targets = new ArrayList<Operation>();
private byte[] runOptions = null;
}
@@ -388,7 +392,7 @@ public final class Session implements AutoCloseable {
*/
public static final class Run {
/** Tensors from requested fetches. */
- public List<Tensor> outputs;
+ public List<Tensor<?>> outputs;
/**
* (Experimental): Metadata about the run.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index c5ad1ee51c..d4b753628b 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -28,89 +28,117 @@ import java.util.Arrays;
import java.util.HashMap;
/**
- * A typed multi-dimensional array.
+ * A statically typed multi-dimensional array whose elements are of a type described by T.
*
* <p>Instances of a Tensor are <b>not</b> thread-safe.
*
* <p><b>WARNING:</b> Resources consumed by the Tensor object <b>must</b> be explicitly freed by
* invoking the {@link #close()} method when the object is no longer needed. For example, using a
- * try-with-resources block like:
+ * try-with-resources block:
*
* <pre>{@code
- * try(Tensor t = Tensor.create(...)) {
+ * try (Tensor t = Tensor.create(...)) {
* doSomethingWith(t);
* }
* }</pre>
*/
-public final class Tensor implements AutoCloseable {
+public final class Tensor<T> implements AutoCloseable {
/**
- * Create a Tensor from a Java object.
+ * Creates a Tensor from a Java object.
*
- * <p>A Tensor is a multi-dimensional array of elements of a limited set of types ({@link
- * DataType}). Thus, not all Java objects can be converted to a Tensor. In particular, {@code obj}
- * must be either a primitive (float, double, int, long, boolean) or a multi-dimensional array of
- * one of those primitives. For example:
+ * <p>A {@code Tensor} is a multi-dimensional array of elements of a limited set of types ({@link
+ * types}), so not all Java objects can be converted to a {@code Tensor}. In particular, the
+ * argument {@code obj} must be either a primitive (float, double, int, long, boolean, byte) or a
+ * multi-dimensional array of one of those primitives. The argument {@code type} specifies how to
+ * interpret the first argument as a TensorFlow type. For example:
*
* <pre>{@code
* // Valid: A 64-bit integer scalar.
- * Tensor s = Tensor.create(42L);
+ * Tensor<Long> s = Tensor.create(42L, Long.class);
*
* // Valid: A 3x2 matrix of floats.
* float[][] matrix = new float[3][2];
- * Tensor m = Tensor.create(matrix);
+ * Tensor<Float> m = Tensor.create(matrix, Float.class);
*
* // Invalid: Will throw an IllegalArgumentException as an arbitrary Object
* // does not fit into the TensorFlow type system.
- * Tensor o = Tensor.create(new Object());
+ * Tensor<?> o = Tensor.create(new Object())
*
* // Invalid: Will throw an IllegalArgumentException since there are
* // a differing number of elements in each row of this 2-D array.
* int[][] twoD = new int[2][];
* twoD[0] = new int[1];
* twoD[1] = new int[2];
- * Tensor x = Tensor.create(twoD);
+ * Tensor<Integer> x = Tensor.create(twoD, Integer.class);
* }</pre>
*
- * {@link DataType#STRING} typed Tensors are multi-dimensionary arrays of arbitrary byte sequences
- * and thus have {@code byte[]} and not {@code String}-valued elements. For example:
+ * {@link String}-typed Tensors are multi-dimensional arrays of arbitrary byte sequences, so can
+ * be initialized from arrays of {@code byte[]} elements. For example:
*
* <pre>{@code
- * // Valid: A DataType.STRING tensor.
- * Tensor s = Tensor.create(new byte[]{1, 2, 3});
+ * // Valid: A String tensor.
+ * Tensor<String> s = Tensor.create(new byte[]{1, 2, 3}, String.class);
*
* // Java Strings will need to be encoded into a byte-sequence.
* String mystring = "foo";
- * Tensor s = Tensor.create(mystring.getBytes("UTF-8"));
+ * Tensor<String> s = Tensor.create(mystring.getBytes("UTF-8"), String.class);
*
- * // Valid: Matrix of DataType.STRING tensors.
+ * // Valid: Matrix of String tensors.
* // Each element might have a different length.
* byte[][][] matrix = new byte[2][2][];
* matrix[0][0] = "this".getBytes("UTF-8");
* matrix[0][1] = "is".getBytes("UTF-8");
* matrix[1][0] = "a".getBytes("UTF-8");
* matrix[1][1] = "matrix".getBytes("UTF-8");
- * Tensor m = Tensor.create(matrix);
+ * Tensor<String> m = Tensor.create(matrix, String.class);
* }</pre>
*
+ * @param obj The object to convert to a Tensor<T>. Note that whether it is compatible with the
+ * type T is not checked by the type system. For type-safe creation of tensors, use {@link
+ * Tensors}.
+ * @param type The class object representing the type T.
* @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type
- * system, or if obj does not disambiguate between multiple DataTypes. In that case, consider
- * using {@link #create(DataType, long[], ByteBuffer)} instead.
+ * system.
*/
- public static Tensor create(Object obj) {
+ @SuppressWarnings("unchecked")
+ public static <T> Tensor<T> create(Object obj, Class<T> type) {
+ DataType dtype = DataType.fromClass(type);
+ if (!objectCompatWithType(obj, dtype)) {
+ throw new IllegalArgumentException(
+ "DataType of object does not match T (expected "
+ + dtype
+ + ", got "
+ + dataTypeOf(obj)
+ + ")");
+ }
+ return (Tensor<T>) create(obj, dtype);
+ }
+
+ /**
+ * Creates a tensor from an object whose class is inspected to figure out what the underlying data
+ * type should be.
+ *
+ * @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type
+ * system.
+ */
+ public static Tensor<?> create(Object obj) {
return create(obj, dataTypeOf(obj));
}
/**
- * Create a Tensor of data type {@code dtype} from a Java object.
+ * Create a Tensor of data type {@code dtype} from a Java object. Requires the parameter {@code T}
+ * to match {@code type}, but this condition is not checked.
*
- * @param dtype the intended tensor data type. It must match the the run-time type of the object.
+ * @param obj the object supplying the tensor data.
+ * @param dtype the data type of the tensor to create. It must be compatible with the run-time
+ * type of the object.
+ * @return the new tensor
*/
- static Tensor create(Object obj, DataType dtype) {
- Tensor t = new Tensor();
- t.dtype = dtype;
+ private static Tensor<?> create(Object obj, DataType dtype) {
+ @SuppressWarnings("rawtypes")
+ Tensor<?> t = new Tensor(dtype);
t.shapeCopy = new long[numDimensions(obj, dtype)];
- assert objectCompatWithType(obj, dtype);
fillShape(obj, 0, t.shapeCopy);
if (t.dtype != DataType.STRING) {
int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
@@ -125,7 +153,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Create an {@link DataType#INT32} Tensor with data from the given buffer.
+ * Create a {@link Integer} Tensor with data from the given buffer.
*
* <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
* current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
@@ -136,14 +164,14 @@ public final class Tensor implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor create(long[] shape, IntBuffer data) {
- Tensor t = allocateForBuffer(DataType.INT32, shape, data.remaining());
+ public static Tensor<Integer> create(long[] shape, IntBuffer data) {
+ Tensor<Integer> t = allocateForBuffer(DataType.INT32, shape, data.remaining());
t.buffer().asIntBuffer().put(data);
return t;
}
/**
- * Create a {@link DataType#FLOAT} Tensor with data from the given buffer.
+ * Create a {@link Float} Tensor with data from the given buffer.
*
* <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
* current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
@@ -154,14 +182,14 @@ public final class Tensor implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor create(long[] shape, FloatBuffer data) {
- Tensor t = allocateForBuffer(DataType.FLOAT, shape, data.remaining());
+ public static Tensor<Float> create(long[] shape, FloatBuffer data) {
+ Tensor<Float> t = allocateForBuffer(DataType.FLOAT, shape, data.remaining());
t.buffer().asFloatBuffer().put(data);
return t;
}
/**
- * Create a {@link DataType#DOUBLE} Tensor with data from the given buffer.
+ * Create a {@link Double} Tensor with data from the given buffer.
*
* <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
* current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
@@ -172,14 +200,14 @@ public final class Tensor implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor create(long[] shape, DoubleBuffer data) {
- Tensor t = allocateForBuffer(DataType.DOUBLE, shape, data.remaining());
+ public static Tensor<Double> create(long[] shape, DoubleBuffer data) {
+ Tensor<Double> t = allocateForBuffer(DataType.DOUBLE, shape, data.remaining());
t.buffer().asDoubleBuffer().put(data);
return t;
}
/**
- * Create an {@link DataType#INT64} Tensor with data from the given buffer.
+ * Create an {@link Long} Tensor with data from the given buffer.
*
* <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
* current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
@@ -190,47 +218,87 @@ public final class Tensor implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor create(long[] shape, LongBuffer data) {
- Tensor t = allocateForBuffer(DataType.INT64, shape, data.remaining());
+ public static Tensor<Long> create(long[] shape, LongBuffer data) {
+ Tensor<Long> t = allocateForBuffer(DataType.INT64, shape, data.remaining());
t.buffer().asLongBuffer().put(data);
return t;
}
/**
- * Create a Tensor with data from the given buffer.
+ * Create a Tensor of any type with data from the given buffer.
+ *
+ * <p>Creates a Tensor with the provided shape of any type where the tensor's data has been
+ * encoded into {@code data} as per the specification of the TensorFlow <a
+ * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>.
+ *
+ * @param <T> the tensor element type
+ * @param type the tensor element type, represented as a class object.
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
+ * buffer
+ */
+ public static <T> Tensor<T> create(Class<T> type, long[] shape, ByteBuffer data) {
+ @SuppressWarnings("unchecked")
+ Tensor<T> ret = (Tensor<T>) create(DataType.fromClass(type), shape, data);
+ return ret;
+ }
+
+ /**
+ * Creates a Tensor of any type with data from the given buffer.
*
* <p>Creates a Tensor with the provided shape of any type where the tensor's data has been
* encoded into {@code data} as per the specification of the TensorFlow <a
* href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>.
*
- * @param dataType the tensor datatype.
+ * @param <T> The tensor element type
+ * @param type the tensor element type, specified as a DataType. This must agree with T.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
- public static Tensor create(DataType dataType, long[] shape, ByteBuffer data) {
+ private static Tensor<?> create(DataType dtype, long[] shape, ByteBuffer data) {
int nremaining = 0;
- if (dataType != DataType.STRING) {
- int elemBytes = elemByteSize(dataType);
+ if (dtype != DataType.STRING) {
+ int elemBytes = elemByteSize(dtype);
if (data.remaining() % elemBytes != 0) {
throw new IllegalArgumentException(
String.format(
"ByteBuffer with %d bytes is not compatible with a %s Tensor (%d bytes/element)",
- data.remaining(), dataType.toString(), elemBytes));
+ data.remaining(), dtype.toString(), elemBytes));
}
nremaining = data.remaining() / elemBytes;
} else {
nremaining = data.remaining();
}
- Tensor t = allocateForBuffer(dataType, shape, nremaining);
+ Tensor<?> t = allocateForBuffer(dtype, shape, nremaining);
t.buffer().put(data);
return t;
}
+ /**
+ * Returns this Tensor object with the type {@code Tensor<U>}. This method is useful when given a
+ * value of type {@code Tensor<?>}.
+ *
+ * @param type any (non-null) array of the correct type.
+ * @throws IllegalArgumentException if the actual data type of this object does not match the type
+ * {@code U}.
+ */
+ @SuppressWarnings("unchecked")
+ public <U> Tensor<U> expect(Class<U> type) {
+ DataType dt = DataType.fromClass(type);
+ if (!dt.equals(dtype)) {
+ throw new IllegalArgumentException(
+ "Cannot cast from tensor of " + dtype + " to tensor of " + dt);
+ }
+ return ((Tensor<U>) this);
+ }
+
// Helper function to allocate a Tensor for the create() methods that create a Tensor from
// a java.nio.Buffer.
- private static Tensor allocateForBuffer(DataType dataType, long[] shape, int nBuffered) {
+ // Requires: dataType matches T
+ private static <T> Tensor<T> allocateForBuffer(DataType dataType, long[] shape, int nBuffered) {
final int nflattened = numElements(shape);
int nbytes = 0;
if (dataType != DataType.STRING) {
@@ -242,8 +310,7 @@ public final class Tensor implements AutoCloseable {
// DT_STRING tensor encoded in a ByteBuffer.
nbytes = nBuffered;
}
- Tensor t = new Tensor();
- t.dtype = dataType;
+ Tensor<T> t = new Tensor<T>(dataType);
t.shapeCopy = Arrays.copyOf(shape, shape.length);
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
return t;
@@ -300,7 +367,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the value in a scalar {@link DataType#FLOAT} tensor.
+ * Returns the value in a scalar {@link Float} tensor.
*
* @throws IllegalArgumentException if the Tensor does not represent a float scalar.
*/
@@ -309,7 +376,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the value in a scalar {@link DataType#DOUBLE} tensor.
+ * Returns the value in a scalar {@link Double} tensor.
*
* @throws IllegalArgumentException if the Tensor does not represent a double scalar.
*/
@@ -318,7 +385,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the value in a scalar {@link DataType#INT32} tensor.
+ * Returns the value in a scalar {@link Integer} tensor.
*
* @throws IllegalArgumentException if the Tensor does not represent a int scalar.
*/
@@ -327,7 +394,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the value in a scalar {@link DataType#INT64} tensor.
+ * Returns the value in a scalar {@link Long} tensor.
*
* @throws IllegalArgumentException if the Tensor does not represent a long scalar.
*/
@@ -336,7 +403,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the value in a scalar {@link DataType#BOOL} tensor.
+ * Returns the value in a scalar {@link Boolean} tensor.
*
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
*/
@@ -345,7 +412,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the value in a scalar {@link DataType#STRING} tensor.
+ * Returns the value in a scalar {@link String} tensor.
*
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
*/
@@ -377,21 +444,21 @@ public final class Tensor implements AutoCloseable {
* @throws IllegalArgumentException if the tensor is a scalar or if {@code dst} is not compatible
* with the tensor (for example, mismatched data types or shapes).
*/
- public <T> T copyTo(T dst) {
+ public <U> U copyTo(U dst) {
throwExceptionIfTypeIsIncompatible(dst);
readNDArray(nativeHandle, dst);
return dst;
}
/**
- * Write the data of a {@link DataType#INT32} tensor into the given buffer.
+ * Write the data of a {@link Integer} tensor into the given buffer.
*
* <p>Copies {@code numElements()} elements to the buffer.
*
* @param dst the destination buffer
* @throws BufferOverflowException If there is insufficient space in the given buffer for the data
* in this tensor
- * @throws IllegalArgumentException If the tensor datatype is not {@link DataType#INT32}
+ * @throws IllegalArgumentException If the tensor data type is not {@link Integer}
*/
public void writeTo(IntBuffer dst) {
if (dtype != DataType.INT32) {
@@ -402,14 +469,14 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Write the data of a {@link DataType#FLOAT} tensor into the given buffer.
+ * Write the data of a {@link Float} tensor into the given buffer.
*
* <p>Copies {@code numElements()} elements to the buffer.
*
* @param dst the destination buffer
* @throws BufferOverflowException If there is insufficient space in the given buffer for the data
* in this tensor
- * @throws IllegalArgumentException If the tensor datatype is not {@link DataType#FLOAT}
+ * @throws IllegalArgumentException If the tensor datatype is not {@link Float}
*/
public void writeTo(FloatBuffer dst) {
if (dtype != DataType.FLOAT) {
@@ -420,14 +487,14 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Write the data of a {@link DataType#DOUBLE} tensor into the given buffer.
+ * Write the data of a {@link Double} tensor into the given buffer.
*
* <p>Copies {@code numElements()} elements to the buffer.
*
* @param dst the destination buffer
* @throws BufferOverflowException If there is insufficient space in the given buffer for the data
* in this tensor
- * @throws IllegalArgumentException If the tensor datatype is not {@link DataType#DOUBLE}
+ * @throws IllegalArgumentException If the tensor datatype is not {@link Double}
*/
public void writeTo(DoubleBuffer dst) {
if (dtype != DataType.DOUBLE) {
@@ -438,14 +505,14 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Write the data of a {@link DataType#INT64} tensor into the given buffer.
+ * Write the data of a {@link Long} tensor into the given buffer.
*
* <p>Copies {@code numElements()} elements to the buffer.
*
* @param dst the destination buffer
* @throws BufferOverflowException If there is insufficient space in the given buffer for the data
* in this tensor
- * @throws IllegalArgumentException If the tensor datatype is not {@link DataType#INT64}
+ * @throws IllegalArgumentException If the tensor datatype is not {@link Long}
*/
public void writeTo(LongBuffer dst) {
if (dtype != DataType.INT64) {
@@ -480,9 +547,9 @@ public final class Tensor implements AutoCloseable {
*
* <p>Takes ownership of the handle.
*/
- static Tensor fromHandle(long handle) {
- Tensor t = new Tensor();
- t.dtype = DataType.fromC(dtype(handle));
+ static Tensor<?> fromHandle(long handle) {
+ @SuppressWarnings("rawtypes")
+ Tensor<?> t = new Tensor(DataType.fromC(dtype(handle)));
t.shapeCopy = shape(handle);
t.nativeHandle = handle;
return t;
@@ -496,7 +563,9 @@ public final class Tensor implements AutoCloseable {
private DataType dtype;
private long[] shapeCopy = null;
- private Tensor() {}
+ private Tensor(DataType t) {
+ dtype = t;
+ }
private ByteBuffer buffer() {
return buffer(nativeHandle).order(ByteOrder.nativeOrder());
@@ -564,11 +633,26 @@ public final class Tensor implements AutoCloseable {
classDataTypes.put(Boolean.class, DataType.BOOL);
}
- private static DataType dataTypeOf(Object o) {
+ /** The class for the data type to which Java object o corresponds. */
+ private static Class<?> baseObjType(Object o) {
Class<?> c = o.getClass();
while (c.isArray()) {
c = c.getComponentType();
}
+ return c;
+ }
+
+ /**
+ * The default TensorFlow data type to which Java object o corresponds. Some Java objects
+ * represent more than one TensorFlow data type; for example, 'byte' can represent both {@code
+ * uint8} and {@code string}, with the latter being the default interpretation.
+ */
+ private static DataType dataTypeOf(Object o) {
+ Class<?> c = baseObjType(o);
+ return dataTypeFromClass(c);
+ }
+
+ private static DataType dataTypeFromClass(Class<?> c) {
DataType ret = classDataTypes.get(c);
if (ret != null) {
return ret;
@@ -577,7 +661,12 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the number of dimensions of a tensor of type dtype when represented by the object o.
+ * Return the number of dimensions of the tensor that object {@code o} represents as a tensor
+ * whose datatype is {@code dtype}. Normally this is the same as the number of dimensions of o
+ * itself, but is one smaller for tensors of strings.
+ *
+ * @param o The object to inspect. It must be a valid representation of the given data type.
+ * @param dtype The expected data type of the tensor.
*/
private static int numDimensions(Object o, DataType dtype) {
int ret = numArrayDimensions(o);
@@ -624,7 +713,13 @@ public final class Tensor implements AutoCloseable {
/** Returns whether the object {@code obj} can represent a tensor with data type {@code dtype}. */
private static boolean objectCompatWithType(Object obj, DataType dtype) {
- DataType dto = dataTypeOf(obj);
+ Class<?> c = baseObjType(obj);
+ DataType dto = dataTypeFromClass(c);
+ int nd = numDimensions(obj, dto);
+ if (!c.isPrimitive() && c != String.class && nd != 0) {
+ throw new IllegalArgumentException(
+ "cannot create non-scalar Tensors from arrays of boxed values");
+ }
if (dto.equals(dtype)) {
return true;
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensors.java b/tensorflow/java/src/main/java/org/tensorflow/Tensors.java
new file mode 100644
index 0000000000..c828d23efc
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensors.java
@@ -0,0 +1,447 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+/** Type-safe factory methods for creating {@link org.tensorflow.Tensor} objects. */
+public final class Tensors {
+ private Tensors() {}
+
+ /**
+ * Creates a scalar String tensor using the default, UTF-8 encoding.
+ *
+ * @param data The string to put into the new scalar tensor.
+ */
+ public static Tensor<String> create(String data) {
+ return Tensor.create(data.getBytes(UTF_8), String.class);
+ }
+
+ /**
+ * Creates a scalar String tensor using a specified encoding.
+ *
+ * @param charset The encoding from String to bytes.
+ * @param data The string to put into the new scalar tensor.
+ */
+ public static Tensor<String> create(String data, java.nio.charset.Charset charset) {
+ return Tensor.create(data.getBytes(charset), String.class);
+ }
+
+ /**
+ * Creates a scalar tensor containing a single {@code float} element.
+ *
+ * @param data The value to put into the new scalar tensor.
+ */
+ public static Tensor<Float> create(float data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a rank-1 tensor of {@code float} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Float> create(float[] data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a rank-2 tensor of {@code float} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Float> create(float[][] data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a rank-3 tensor of {@code float} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Float> create(float[][][] data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a rank-4 tensor of {@code float} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Float> create(float[][][][] data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a rank-5 tensor of {@code float} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Float> create(float[][][][][] data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a rank-6 tensor of {@code float} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Float> create(float[][][][][][] data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a scalar tensor containing a single {@code double} element.
+ *
+ * @param data The value to put into the new scalar tensor.
+ */
+ public static Tensor<Double> create(double data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a rank-1 tensor of {@code double} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Double> create(double[] data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a rank-2 tensor of {@code double} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Double> create(double[][] data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a rank-3 tensor of {@code double} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Double> create(double[][][] data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a rank-4 tensor of {@code double} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Double> create(double[][][][] data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a rank-5 tensor of {@code double} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Double> create(double[][][][][] data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a rank-6 tensor of {@code double} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Double> create(double[][][][][][] data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a scalar tensor containing a single {@code int} element.
+ *
+ * @param data The value to put into the new scalar tensor.
+ */
+ public static Tensor<Integer> create(int data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-1 tensor of {@code int} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Integer> create(int[] data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-2 tensor of {@code int} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Integer> create(int[][] data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-3 tensor of {@code int} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Integer> create(int[][][] data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-4 tensor of {@code int} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Integer> create(int[][][][] data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-5 tensor of {@code int} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Integer> create(int[][][][][] data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-6 tensor of {@code int} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Integer> create(int[][][][][][] data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a scalar tensor containing a single {@code byte} element.
+ *
+ * @param data An array containing the data to put into the new tensor. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Tensor<String> create(byte[] data) {
+ return Tensor.create(data, String.class);
+ }
+
+ /**
+ * Creates a rank-1 tensor of {@code byte} elements.
+ *
+ * @param data An array containing the data to put into the new tensor. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Tensor<String> create(byte[][] data) {
+ return Tensor.create(data, String.class);
+ }
+
+ /**
+ * Creates a rank-2 tensor of {@code byte} elements.
+ *
+ * @param data An array containing the data to put into the new tensor. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Tensor<String> create(byte[][][] data) {
+ return Tensor.create(data, String.class);
+ }
+
+ /**
+ * Creates a rank-3 tensor of {@code byte} elements.
+ *
+ * @param data An array containing the data to put into the new tensor. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Tensor<String> create(byte[][][][] data) {
+ return Tensor.create(data, String.class);
+ }
+
+ /**
+ * Creates a rank-4 tensor of {@code byte} elements.
+ *
+ * @param data An array containing the data to put into the new tensor. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Tensor<String> create(byte[][][][][] data) {
+ return Tensor.create(data, String.class);
+ }
+
+ /**
+ * Creates a rank-5 tensor of {@code byte} elements.
+ *
+ * @param data An array containing the data to put into the new tensor. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Tensor<String> create(byte[][][][][][] data) {
+ return Tensor.create(data, String.class);
+ }
+
+ /**
+ * Creates a scalar tensor containing a single {@code long} element.
+ *
+ * @param data The value to put into the new scalar tensor.
+ */
+ public static Tensor<Long> create(long data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a rank-1 tensor of {@code long} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Long> create(long[] data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a rank-2 tensor of {@code long} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Long> create(long[][] data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a rank-3 tensor of {@code long} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Long> create(long[][][] data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a rank-4 tensor of {@code long} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Long> create(long[][][][] data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a rank-5 tensor of {@code long} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Long> create(long[][][][][] data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a rank-6 tensor of {@code long} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Long> create(long[][][][][][] data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a scalar tensor containing a single {@code boolean} element.
+ *
+ * @param data The value to put into the new scalar tensor.
+ */
+ public static Tensor<Boolean> create(boolean data) {
+ return Tensor.create(data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-1 tensor of {@code boolean} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Boolean> create(boolean[] data) {
+ return Tensor.create(data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-2 tensor of {@code boolean} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Boolean> create(boolean[][] data) {
+ return Tensor.create(data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-3 tensor of {@code boolean} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Boolean> create(boolean[][][] data) {
+ return Tensor.create(data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-4 tensor of {@code boolean} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Boolean> create(boolean[][][][] data) {
+ return Tensor.create(data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-5 tensor of {@code boolean} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Boolean> create(boolean[][][][][] data) {
+ return Tensor.create(data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-6 tensor of {@code boolean} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Boolean> create(boolean[][][][][][] data) {
+ return Tensor.create(data, Boolean.class);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java b/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
index 19929188a5..489e95c310 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
@@ -29,6 +29,7 @@ import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
+import org.tensorflow.types.UInt8;
/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */
public class LabelImage {
@@ -61,17 +62,17 @@ public class LabelImage {
readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
- try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
+ try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
float[] labelProbabilities = executeInceptionGraph(graphDef, image);
int bestLabelIdx = maxIndex(labelProbabilities);
System.out.println(
- String.format(
- "BEST MATCH: %s (%.2f%% likely)",
- labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
+ String.format("BEST MATCH: %s (%.2f%% likely)",
+ labels.get(bestLabelIdx),
+ labelProbabilities[bestLabelIdx] * 100f));
}
}
- private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
+ private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
try (Graph g = new Graph()) {
GraphBuilder b = new GraphBuilder(g);
// Some constants specific to the pre-trained model at:
@@ -88,28 +89,29 @@ public class LabelImage {
// Since the graph is being constructed once per execution here, we can use a constant for the
// input image. If the graph were to be re-used for multiple input images, a placeholder would
// have been more appropriate.
- final Output input = b.constant("input", imageBytes);
- final Output output =
+ final Output<String> input = b.constant("input", imageBytes);
+ final Output<Float> output =
b.div(
b.sub(
b.resizeBilinear(
b.expandDims(
- b.cast(b.decodeJpeg(input, 3), DataType.FLOAT),
+ b.cast(b.decodeJpeg(input, 3), Float.class),
b.constant("make_batch", 0)),
b.constant("size", new int[] {H, W})),
b.constant("mean", mean)),
b.constant("scale", scale));
try (Session s = new Session(g)) {
- return s.runner().fetch(output.op().name()).run().get(0);
+ return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
}
}
}
- private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
+ private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
try (Graph g = new Graph()) {
g.importGraphDef(graphDef);
try (Session s = new Session(g);
- Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) {
+ Tensor<Float> result =
+ s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) {
final long[] rshape = result.shape();
if (result.numDimensions() != 2 || rshape[0] != 1) {
throw new RuntimeException(
@@ -161,48 +163,71 @@ public class LabelImage {
this.g = g;
}
- Output div(Output x, Output y) {
+ Output<Float> div(Output<Float> x, Output<Float> y) {
return binaryOp("Div", x, y);
}
- Output sub(Output x, Output y) {
+ <T> Output<T> sub(Output<T> x, Output<T> y) {
return binaryOp("Sub", x, y);
}
- Output resizeBilinear(Output images, Output size) {
- return binaryOp("ResizeBilinear", images, size);
+ <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
+ return binaryOp3("ResizeBilinear", images, size);
}
- Output expandDims(Output input, Output dim) {
- return binaryOp("ExpandDims", input, dim);
+ <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
+ return binaryOp3("ExpandDims", input, dim);
}
- Output cast(Output value, DataType dtype) {
- return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0);
+ <T, U> Output<U> cast(Output<T> value, Class<U> type) {
+ DataType dtype = DataType.fromClass(type);
+ return g.opBuilder("Cast", "Cast")
+ .addInput(value)
+ .setAttr("DstT", dtype)
+ .build()
+ .<U>output(0);
}
- Output decodeJpeg(Output contents, long channels) {
+ Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
return g.opBuilder("DecodeJpeg", "DecodeJpeg")
.addInput(contents)
.setAttr("channels", channels)
.build()
- .output(0);
+ .<UInt8>output(0);
}
- Output constant(String name, Object value) {
- try (Tensor t = Tensor.create(value)) {
+ <T> Output<T> constant(String name, Object value, Class<T> type) {
+ try (Tensor<T> t = Tensor.<T>create(value, type)) {
return g.opBuilder("Const", name)
- .setAttr("dtype", t.dataType())
+ .setAttr("dtype", DataType.fromClass(type))
.setAttr("value", t)
.build()
- .output(0);
+ .<T>output(0);
}
}
+ Output<String> constant(String name, byte[] value) {
+ return this.constant(name, value, String.class);
+ }
- private Output binaryOp(String type, Output in1, Output in2) {
- return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0);
+ Output<Integer> constant(String name, int value) {
+ return this.constant(name, value, Integer.class);
}
+ Output<Integer> constant(String name, int[] value) {
+ return this.constant(name, value, Integer.class);
+ }
+
+ Output<Float> constant(String name, float value) {
+ return this.constant(name, value, Float.class);
+ }
+
+ private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
+ return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
+ }
+
+ private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
+ return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
+ }
private Graph g;
}
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Operands.java b/tensorflow/java/src/main/java/org/tensorflow/op/Operands.java
index 5971103d6d..ac48da8032 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/Operands.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/Operands.java
@@ -33,12 +33,12 @@ public final class Operands {
* @param inputs an iteration of input operands
* @return an array of outputs
*/
- public static Output[] asOutputs(Iterable<? extends Operand> inputs) {
- List<Output> outputList = new ArrayList<>();
- for (Operand input : inputs) {
+ public static Output<?>[] asOutputs(Iterable<? extends Operand<?>> inputs) {
+ List<Output<?>> outputList = new ArrayList<>();
+ for (Operand<?> input : inputs) {
outputList.add(input.asOutput());
}
- return outputList.toArray(new Output[outputList.size()]);
+ return outputList.toArray(new Output<?>[outputList.size()]);
}
// Disabled constructor
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
index cd7931d3bb..725c81765a 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
@@ -31,7 +31,7 @@ import org.tensorflow.op.annotation.Operator;
/** An operator producing a constant value. */
@Operator
-public final class Constant extends PrimitiveOp implements Operand {
+public final class Constant<T> extends PrimitiveOp implements Operand<T> {
/**
* Create a constant from a Java object.
*
@@ -47,8 +47,8 @@ public final class Constant extends PrimitiveOp implements Operand {
* @param object a Java object representing the constant.
* @see org.tensorflow.Tensor#create(Object) Tensor.create
*/
- public static Constant create(Scope scope, Object object) {
- try (Tensor value = Tensor.create(object)) {
+ public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) {
+ try (Tensor<T> value = Tensor.create(object, type)) {
return createWithTensor(scope, value);
}
}
@@ -66,8 +66,8 @@ public final class Constant extends PrimitiveOp implements Operand {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant create(Scope scope, long[] shape, IntBuffer data) {
- try (Tensor value = Tensor.create(shape, data)) {
+ public static Constant<Integer> create(Scope scope, long[] shape, IntBuffer data) {
+ try (Tensor<Integer> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -85,8 +85,8 @@ public final class Constant extends PrimitiveOp implements Operand {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant create(Scope scope, long[] shape, FloatBuffer data) {
- try (Tensor value = Tensor.create(shape, data)) {
+ public static Constant<Float> create(Scope scope, long[] shape, FloatBuffer data) {
+ try (Tensor<Float> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -104,8 +104,8 @@ public final class Constant extends PrimitiveOp implements Operand {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant create(Scope scope, long[] shape, DoubleBuffer data) {
- try (Tensor value = Tensor.create(shape, data)) {
+ public static Constant<Double> create(Scope scope, long[] shape, DoubleBuffer data) {
+ try (Tensor<Double> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -123,8 +123,8 @@ public final class Constant extends PrimitiveOp implements Operand {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant create(Scope scope, long[] shape, LongBuffer data) {
- try (Tensor value = Tensor.create(shape, data)) {
+ public static Constant<Long> create(Scope scope, long[] shape, LongBuffer data) {
+ try (Tensor<Long> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -143,14 +143,14 @@ public final class Constant extends PrimitiveOp implements Operand {
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
- public static Constant create(Scope scope, DataType dataType, long[] shape, ByteBuffer data) {
- try (Tensor value = Tensor.create(dataType, shape, data)) {
+ public static <T> Constant<T> create(Scope scope, Class<T> type, long[] shape, ByteBuffer data) {
+ try (Tensor<T> value = Tensor.create(type, shape, data)) {
return createWithTensor(scope, value);
}
}
- private static Constant createWithTensor(Scope scope, Tensor value) {
- return new Constant(
+ private static <T> Constant<T> createWithTensor(Scope scope, Tensor<T> value) {
+ return new Constant<T>(
scope
.graph()
.opBuilder("Const", scope.makeOpName("Const"))
@@ -160,7 +160,7 @@ public final class Constant extends PrimitiveOp implements Operand {
}
@Override
- public Output asOutput() {
+ public Output<T> asOutput() {
return output;
}
@@ -169,5 +169,5 @@ public final class Constant extends PrimitiveOp implements Operand {
output = operation.output(0);
}
- private final Output output;
+ private final Output<T> output;
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/UInt8.java b/tensorflow/java/src/main/java/org/tensorflow/types/UInt8.java
new file mode 100644
index 0000000000..0c751aed9f
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/UInt8.java
@@ -0,0 +1,21 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.types;
+
+/** Represents an 8-bit unsigned integer. */
+public class UInt8 {
+ private UInt8() {}
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/package-info.java b/tensorflow/java/src/main/java/org/tensorflow/types/package-info.java
index f1410a760e..96018c5366 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/types/package-info.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/package-info.java
@@ -15,13 +15,15 @@ limitations under the License.
/**
* Defines classes that represent TensorFlow data types. For each possible data type
- * that can be used in a tensor, there is a corresponding class in this package that
+ * that can be used in a tensor, there is a corresponding class that
* is used to represent it. For example, the TensorFlow int32 type is represented by
- * the type TFInt32 and by the class object TFInt32.class. The former is used to
- * support compile-time checking of tensor data types and the latter is used for
- * run-time checking of data types. All such classes implement the TFType interface.
- * TensorFlow data types are also separately represented by the DataType enum, with
- * one enum value per data type. The enum representation should rarely be needed, but
- * the Types class can be used to obtain it from the class object representation.
+ * the type {@link Integer} and by the class object {@code Integer.class}. The former is used to
+ * support compile-time checking of tensor element types and the latter is used for
+ * run-time checking of element types. Classes appearing in this package, such as
+ * UInt8, represent TensorFlow data types for which there is no existing Java equivalent.
+ *
+ * <p>TensorFlow element types are also separately represented by the {@link DataType} enum, with
+ * one enum value per element type. The enum representation is not usually needed, but
+ * can be obtained using {@link DataType.fromClass}.
*/
package org.tensorflow.types;
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index 4adc861bf1..c540299bdc 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,7 +22,6 @@ import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Iterator;
-
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
diff --git a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java b/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java
index b3bc3aaef9..6dc233987b 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java
@@ -34,8 +34,8 @@ public class OperationBuilderTest {
public void failWhenMixingOperationsOnDifferentGraphs() {
try (Graph g1 = new Graph();
Graph g2 = new Graph()) {
- Output c1 = TestUtil.constant(g1, "C1", 3);
- Output c2 = TestUtil.constant(g2, "C2", 3);
+ Output<Integer> c1 = TestUtil.constant(g1, "C1", 3);
+ Output<Integer> c2 = TestUtil.constant(g2, "C2", 3);
TestUtil.addN(g1, c1, c1);
try {
TestUtil.addN(g2, c1, c2);
@@ -48,7 +48,7 @@ public class OperationBuilderTest {
@Test
public void failOnUseAfterBuild() {
try (Graph g = new Graph();
- Tensor t = Tensor.create(1)) {
+ Tensor<Integer> t = Tensors.create(1)) {
OperationBuilder b =
g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t);
b.build();
@@ -64,7 +64,7 @@ public class OperationBuilderTest {
public void failOnUseAfterGraphClose() {
OperationBuilder b = null;
try (Graph g = new Graph();
- Tensor t = Tensor.create(1)) {
+ Tensor<Integer> t = Tensors.create(1)) {
b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t);
}
try {
@@ -85,7 +85,7 @@ public class OperationBuilderTest {
// types that aren't inferred from the input arguments.
try (Graph g = new Graph()) {
// dtype, tensor attributes.
- try (Tensor t = Tensor.create(1)) {
+ try (Tensor<Integer> t = Tensors.create(1)) {
g.opBuilder("Const", "DataTypeAndTensor")
.setAttr("dtype", DataType.INT32)
.setAttr("value", t)
@@ -101,7 +101,7 @@ public class OperationBuilderTest {
assertTrue(hasNode(g, "StringAndBool"));
// int (TF "int" attributes are 64-bit signed, so a Java long).
g.opBuilder("RandomUniform", "Int")
- .addInput(TestUtil.constant(g, "RandomUniformShape", new int[]{1}))
+ .addInput(TestUtil.constant(g, "RandomUniformShape", new int[] {1}))
.setAttr("seed", 10)
.setAttr("dtype", DataType.FLOAT)
.build();
@@ -127,7 +127,7 @@ public class OperationBuilderTest {
@Test
public void setAttrShape() {
try (Graph g = new Graph()) {
- Output n =
+ Output<?> n =
g.opBuilder("Placeholder", "unknown")
.setAttr("dtype", DataType.FLOAT)
.setAttr("shape", Shape.unknown())
@@ -136,8 +136,7 @@ public class OperationBuilderTest {
assertEquals(-1, n.shape().numDimensions());
assertEquals(DataType.FLOAT, n.dataType());
- n =
- g.opBuilder("Placeholder", "batch_of_vectors")
+ n = g.opBuilder("Placeholder", "batch_of_vectors")
.setAttr("dtype", DataType.FLOAT)
.setAttr("shape", Shape.make(-1, 784))
.build()
@@ -153,13 +152,13 @@ public class OperationBuilderTest {
public void addControlInput() {
try (Graph g = new Graph();
Session s = new Session(g);
- Tensor yes = Tensor.create(true);
- Tensor no = Tensor.create(false)) {
- Output placeholder = TestUtil.placeholder(g, "boolean", DataType.BOOL);
+ Tensor<Boolean> yes = Tensors.create(true);
+ Tensor<Boolean> no = Tensors.create(false)) {
+ Output<Boolean> placeholder = TestUtil.placeholder(g, "boolean", Boolean.class);
Operation check =
g.opBuilder("Assert", "assert")
.addInput(placeholder)
- .addInputList(new Output[] {placeholder})
+ .addInputList(new Output<?>[] {placeholder})
.build();
Operation noop = g.opBuilder("NoOp", "noop").addControlInput(check).build();
diff --git a/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java
index aade375db8..6fe3b3c327 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java
@@ -24,7 +24,6 @@ import static org.junit.Assert.fail;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
-
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -104,9 +103,9 @@ public class OperationTest {
@Test
public void outputEquality() {
try (Graph g = new Graph()) {
- Output output = TestUtil.constant(g, "c", 1);
- Output output1 = output.op().output(0);
- Output output2 = g.operation("c").output(0);
+ Output<Integer> output = TestUtil.constant(g, "c", 1);
+ Output<Integer> output1 = output.op().<Integer>output(0);
+ Output<Integer> output2 = g.operation("c").<Integer>output(0);
assertEquals(output, output1);
assertEquals(output.hashCode(), output1.hashCode());
assertEquals(output, output2);
@@ -117,10 +116,10 @@ public class OperationTest {
@Test
public void outputCollection() {
try (Graph g = new Graph()) {
- Output output = TestUtil.constant(g, "c", 1);
- Output output1 = output.op().output(0);
- Output output2 = g.operation("c").output(0);
- Set<Output> ops = new HashSet<>();
+ Output<Integer> output = TestUtil.constant(g, "c", 1);
+ Output<Integer> output1 = output.op().<Integer>output(0);
+ Output<Integer> output2 = g.operation("c").<Integer>output(0);
+ Set<Output<Integer>> ops = new HashSet<>();
ops.addAll(Arrays.asList(output, output1, output2));
assertEquals(1, ops.size());
assertTrue(ops.contains(output));
@@ -132,7 +131,7 @@ public class OperationTest {
@Test
public void outputToString() {
try (Graph g = new Graph()) {
- Output output = TestUtil.constant(g, "c", new int[] {1});
+ Output<Integer> output = TestUtil.constant(g, "c", new int[] {1});
assertNotNull(output.toString());
}
}
@@ -158,7 +157,7 @@ public class OperationTest {
public void outputList() {
try (Graph g = new Graph()) {
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
- Output[] outputs = split.outputList(1, 2);
+ Output<?>[] outputs = split.outputList(1, 2);
assertNotNull(outputs);
assertEquals(2, outputs.length);
for (int i = 0; i < outputs.length; ++i) {
diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
index 50bdf351e3..a86b4dd117 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
@@ -35,9 +35,9 @@ public class SessionTest {
try (Graph g = new Graph();
Session s = new Session(g)) {
TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
- try (Tensor x = Tensor.create(new int[][] {{5}, {7}});
- AutoCloseableList<Tensor> outputs =
- new AutoCloseableList<Tensor>(s.runner().feed("X", x).fetch("Y").run())) {
+ try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
+ AutoCloseableList<Tensor<?>> outputs =
+ new AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) {
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -50,11 +50,11 @@ public class SessionTest {
try (Graph g = new Graph();
Session s = new Session(g)) {
TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
- Output feed = g.operation("X").output(0);
- Output fetch = g.operation("Y").output(0);
- try (Tensor x = Tensor.create(new int[][] {{5}, {7}});
- AutoCloseableList<Tensor> outputs =
- new AutoCloseableList<Tensor>(s.runner().feed(feed, x).fetch(fetch).run())) {
+ Output<Integer> feed = g.operation("X").output(0);
+ Output<Integer> fetch = g.operation("Y").output(0);
+ try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
+ AutoCloseableList<Tensor<?>> outputs =
+ new AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) {
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -78,14 +78,21 @@ public class SessionTest {
.build()
.output(0);
// Fetch using colon separated names.
- try (Tensor fetched = s.runner().fetch("Split:1").run().get(0)) {
+ try (Tensor<Integer> fetched =
+ s.runner().fetch("Split:1").run().get(0).expect(Integer.class)) {
final int[] expected = {3, 4};
assertArrayEquals(expected, fetched.copyTo(new int[2]));
}
// Feed using colon separated names.
- try (Tensor fed = Tensor.create(new int[] {4, 3, 2, 1});
- Tensor fetched =
- s.runner().feed("Split:0", fed).feed("Split:1", fed).fetch("Add").run().get(0)) {
+ try (Tensor<Integer> fed = Tensors.create(new int[] {4, 3, 2, 1});
+ Tensor<Integer> fetched =
+ s.runner()
+ .feed("Split:0", fed)
+ .feed("Split:1", fed)
+ .fetch("Add")
+ .run()
+ .get(0)
+ .expect(Integer.class)) {
final int[] expected = {8, 6, 4, 2};
assertArrayEquals(expected, fetched.copyTo(new int[4]));
}
@@ -97,7 +104,7 @@ public class SessionTest {
try (Graph g = new Graph();
Session s = new Session(g)) {
TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
- try (Tensor x = Tensor.create(new int[][] {{5}, {7}})) {
+ try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}})) {
Session.Run result =
s.runner()
.feed("X", x)
@@ -105,7 +112,7 @@ public class SessionTest {
.setOptions(fullTraceRunOptions())
.runAndFetchMetadata();
// Sanity check on outputs.
- AutoCloseableList<Tensor> outputs = new AutoCloseableList<Tensor>(result.outputs);
+ AutoCloseableList<Tensor<?>> outputs = new AutoCloseableList<Tensor<?>>(result.outputs);
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -117,6 +124,7 @@ public class SessionTest {
assertTrue(md.toString(), md.hasStepStats());
*/
assertTrue(result.metadata.length > 0);
+ outputs.close();
}
}
}
@@ -127,11 +135,12 @@ public class SessionTest {
Session s = new Session(g)) {
TestUtil.constant(g, "c1", 2718);
TestUtil.constant(g, "c2", 31415);
- AutoCloseableList<Tensor> outputs =
- new AutoCloseableList<Tensor>(s.runner().fetch("c2").fetch("c1").run());
+ AutoCloseableList<Tensor<?>> outputs =
+ new AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run());
assertEquals(2, outputs.size());
assertEquals(31415, outputs.get(0).intValue());
assertEquals(2718, outputs.get(1).intValue());
+ outputs.close();
}
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java b/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java
index fe46c0184c..3b027700c5 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java
@@ -61,7 +61,7 @@ public class ShapeTest {
@Test
public void nodesInAGraph() {
try (Graph g = new Graph()) {
- Output n = TestUtil.placeholder(g, "feed", DataType.FLOAT);
+ Output<Float> n = TestUtil.placeholder(g, "feed", Float.class);
assertEquals(-1, n.shape().numDimensions());
n = TestUtil.constant(g, "scalar", 3);
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
index 036db04503..6538359d11 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
@@ -30,6 +30,7 @@ import java.nio.LongBuffer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+import org.tensorflow.types.UInt8;
/** Unit tests for {@link org.tensorflow.Tensor}. */
@RunWith(JUnit4.class)
@@ -47,7 +48,7 @@ public class TensorTest {
byte[] strings = "test".getBytes(UTF_8);
long[] strings_shape = {};
byte[] strings_; // raw TF_STRING
- try (Tensor t = Tensor.create(strings)) {
+ try (Tensor<String> t = Tensors.create(strings)) {
ByteBuffer to = ByteBuffer.allocate(t.numBytes());
t.writeTo(to);
strings_ = to.array();
@@ -55,7 +56,7 @@ public class TensorTest {
// validate creating a tensor using a byte buffer
{
- try (Tensor t = Tensor.create(DataType.BOOL, bools_shape, ByteBuffer.wrap(bools_))) {
+ try (Tensor<Boolean> t = Tensor.create(Boolean.class, bools_shape, ByteBuffer.wrap(bools_))) {
boolean[] actual = t.copyTo(new boolean[bools_.length]);
for (int i = 0; i < bools.length; ++i) {
assertEquals("" + i, bools[i], actual[i]);
@@ -63,7 +64,8 @@ public class TensorTest {
}
// note: the buffer is expected to contain raw TF_STRING (as per C API)
- try (Tensor t = Tensor.create(DataType.STRING, strings_shape, ByteBuffer.wrap(strings_))) {
+ try (Tensor<String> t =
+ Tensor.create(String.class, strings_shape, ByteBuffer.wrap(strings_))) {
assertArrayEquals(strings, t.bytesValue());
}
}
@@ -72,15 +74,15 @@ public class TensorTest {
{
ByteBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder());
buf.asDoubleBuffer().put(doubles);
- try (Tensor t = Tensor.create(DataType.DOUBLE, doubles_shape, buf)) {
+ try (Tensor<Double> t = Tensor.create(Double.class, doubles_shape, buf)) {
double[] actual = new double[doubles.length];
assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
}
}
// validate shape checking
- try (Tensor t =
- Tensor.create(DataType.BOOL, new long[bools_.length * 2], ByteBuffer.wrap(bools_))) {
+ try (Tensor<Boolean> t =
+ Tensor.create(Boolean.class, new long[bools_.length * 2], ByteBuffer.wrap(bools_))) {
fail("should have failed on incompatible buffer");
} catch (IllegalArgumentException e) {
// expected
@@ -99,7 +101,7 @@ public class TensorTest {
.asDoubleBuffer()
.put(doubles);
buf.flip();
- try (Tensor t = Tensor.create(new long[] {doubles.length}, buf)) {
+ try (Tensor<Double> t = Tensor.create(new long[] {doubles.length}, buf)) {
double[] actual = new double[doubles.length];
assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
}
@@ -115,19 +117,19 @@ public class TensorTest {
// validate creating a tensor using a typed buffer
{
- try (Tensor t = Tensor.create(shape, DoubleBuffer.wrap(doubles))) {
+ try (Tensor<Double> t = Tensor.create(shape, DoubleBuffer.wrap(doubles))) {
double[] actual = new double[doubles.length];
assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
}
- try (Tensor t = Tensor.create(shape, FloatBuffer.wrap(floats))) {
+ try (Tensor<Float> t = Tensor.create(shape, FloatBuffer.wrap(floats))) {
float[] actual = new float[floats.length];
assertArrayEquals(floats, t.copyTo(actual), EPSILON_F);
}
- try (Tensor t = Tensor.create(shape, IntBuffer.wrap(ints))) {
+ try (Tensor<Integer> t = Tensor.create(shape, IntBuffer.wrap(ints))) {
int[] actual = new int[ints.length];
assertArrayEquals(ints, t.copyTo(actual));
}
- try (Tensor t = Tensor.create(shape, LongBuffer.wrap(longs))) {
+ try (Tensor<Long> t = Tensor.create(shape, LongBuffer.wrap(longs))) {
long[] actual = new long[longs.length];
assertArrayEquals(longs, t.copyTo(actual));
}
@@ -135,22 +137,23 @@ public class TensorTest {
// validate shape-checking
{
- try (Tensor t = Tensor.create(new long[doubles.length + 1], DoubleBuffer.wrap(doubles))) {
+ try (Tensor<Double> t =
+ Tensor.create(new long[doubles.length + 1], DoubleBuffer.wrap(doubles))) {
fail("should have failed on incompatible buffer");
} catch (IllegalArgumentException e) {
// expected
}
- try (Tensor t = Tensor.create(new long[floats.length + 1], FloatBuffer.wrap(floats))) {
+ try (Tensor<Float> t = Tensor.create(new long[floats.length + 1], FloatBuffer.wrap(floats))) {
fail("should have failed on incompatible buffer");
} catch (IllegalArgumentException e) {
// expected
}
- try (Tensor t = Tensor.create(new long[ints.length + 1], IntBuffer.wrap(ints))) {
+ try (Tensor<Integer> t = Tensor.create(new long[ints.length + 1], IntBuffer.wrap(ints))) {
fail("should have failed on incompatible buffer");
} catch (IllegalArgumentException e) {
// expected
}
- try (Tensor t = Tensor.create(new long[longs.length + 1], LongBuffer.wrap(longs))) {
+ try (Tensor<Long> t = Tensor.create(new long[longs.length + 1], LongBuffer.wrap(longs))) {
fail("should have failed on incompatible buffer");
} catch (IllegalArgumentException e) {
// expected
@@ -166,11 +169,11 @@ public class TensorTest {
long[] longs = {1L, 2L, 3L};
boolean[] bools = {true, false, true};
- try (Tensor tints = Tensor.create(ints);
- Tensor tfloats = Tensor.create(floats);
- Tensor tdoubles = Tensor.create(doubles);
- Tensor tlongs = Tensor.create(longs);
- Tensor tbools = Tensor.create(bools)) {
+ try (Tensor<Integer> tints = Tensors.create(ints);
+ Tensor<Float> tfloats = Tensors.create(floats);
+ Tensor<Double> tdoubles = Tensors.create(doubles);
+ Tensor<Long> tlongs = Tensors.create(longs);
+ Tensor<Boolean> tbools = Tensors.create(bools)) {
// validate that any datatype is readable with ByteBuffer (content, position)
{
@@ -293,35 +296,35 @@ public class TensorTest {
@Test
public void scalars() {
- try (Tensor t = Tensor.create(2.718f)) {
+ try (Tensor<Float> t = Tensors.create(2.718f)) {
assertEquals(DataType.FLOAT, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
assertEquals(2.718f, t.floatValue(), EPSILON_F);
}
- try (Tensor t = Tensor.create(3.1415)) {
+ try (Tensor<Double> t = Tensors.create(3.1415)) {
assertEquals(DataType.DOUBLE, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
assertEquals(3.1415, t.doubleValue(), EPSILON);
}
- try (Tensor t = Tensor.create(-33)) {
+ try (Tensor<Integer> t = Tensors.create(-33)) {
assertEquals(DataType.INT32, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
assertEquals(-33, t.intValue());
}
- try (Tensor t = Tensor.create(8589934592L)) {
+ try (Tensor<Long> t = Tensors.create(8589934592L)) {
assertEquals(DataType.INT64, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
assertEquals(8589934592L, t.longValue());
}
- try (Tensor t = Tensor.create(true)) {
+ try (Tensor<Boolean> t = Tensors.create(true)) {
assertEquals(DataType.BOOL, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
@@ -329,7 +332,7 @@ public class TensorTest {
}
final byte[] bytes = {1, 2, 3, 4};
- try (Tensor t = Tensor.create(bytes)) {
+ try (Tensor<String> t = Tensors.create(bytes)) {
assertEquals(DataType.STRING, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
@@ -340,7 +343,7 @@ public class TensorTest {
@Test
public void nDimensional() {
double[] vector = {1.414, 2.718, 3.1415};
- try (Tensor t = Tensor.create(vector)) {
+ try (Tensor<Double> t = Tensors.create(vector)) {
assertEquals(DataType.DOUBLE, t.dataType());
assertEquals(1, t.numDimensions());
assertArrayEquals(new long[] {3}, t.shape());
@@ -350,7 +353,7 @@ public class TensorTest {
}
int[][] matrix = {{1, 2, 3}, {4, 5, 6}};
- try (Tensor t = Tensor.create(matrix)) {
+ try (Tensor<Integer> t = Tensors.create(matrix)) {
assertEquals(DataType.INT32, t.dataType());
assertEquals(2, t.numDimensions());
assertArrayEquals(new long[] {2, 3}, t.shape());
@@ -362,7 +365,7 @@ public class TensorTest {
long[][][] threeD = {
{{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}},
};
- try (Tensor t = Tensor.create(threeD)) {
+ try (Tensor<Long> t = Tensors.create(threeD)) {
assertEquals(DataType.INT64, t.dataType());
assertEquals(3, t.numDimensions());
assertArrayEquals(new long[] {2, 5, 1}, t.shape());
@@ -376,7 +379,7 @@ public class TensorTest {
{{{false, false, true, true}, {false, true, false, false}}},
{{{false, true, false, true}, {false, true, true, false}}},
};
- try (Tensor t = Tensor.create(fourD)) {
+ try (Tensor<Boolean> t = Tensors.create(fourD)) {
assertEquals(DataType.BOOL, t.dataType());
assertEquals(4, t.numDimensions());
assertArrayEquals(new long[] {3, 1, 2, 4}, t.shape());
@@ -394,7 +397,7 @@ public class TensorTest {
matrix[i][j] = String.format("(%d, %d) = %d", i, j, i << j).getBytes(UTF_8);
}
}
- try (Tensor t = Tensor.create(matrix)) {
+ try (Tensor<String> t = Tensors.create(matrix)) {
assertEquals(DataType.STRING, t.dataType());
assertEquals(2, t.numDimensions());
assertArrayEquals(new long[] {4, 3}, t.shape());
@@ -412,14 +415,24 @@ public class TensorTest {
@Test
public void testUInt8Tensor() {
- byte[] vector = new byte[] { 1, 2, 3, 4 };
- try (Tensor t = Tensor.create(vector, DataType.UINT8)) {
+ byte[] vector = new byte[] {1, 2, 3, 4};
+ try (Tensor<UInt8> t = Tensor.create(vector, UInt8.class)) {
assertEquals(DataType.UINT8, t.dataType());
assertEquals(1, t.numDimensions());
assertArrayEquals(new long[] {4}, t.shape());
byte[] got = t.copyTo(new byte[4]);
- assertArrayEquals(got, vector);
+ assertArrayEquals(vector, got);
+ }
+ }
+
+ @Test
+ public void testCreateFromArrayOfBoxed() {
+ Integer[] vector = new Integer[] {1, 2, 3, 4};
+ try (Tensor<Integer> t = Tensor.create(vector, Integer.class)) {
+ fail("Tensor.create() should fail because it was given an array of boxed values");
+ } catch (IllegalArgumentException e) {
+ // The expected exception
}
}
@@ -431,7 +444,7 @@ public class TensorTest {
invalid[x][y] = new int[x + y + 1];
}
}
- try (Tensor t = Tensor.create(invalid)) {
+ try (Tensor<?> t = Tensor.create(invalid)) {
fail("Tensor.create() should fail because of differing sizes in the 3rd dimension");
} catch (IllegalArgumentException e) {
// The expected exception.
@@ -440,7 +453,7 @@ public class TensorTest {
@Test
public void failCopyToOnIncompatibleDestination() {
- try (final Tensor matrix = Tensor.create(new int[][] {{1, 2}, {3, 4}})) {
+ try (final Tensor<Integer> matrix = Tensors.create(new int[][] {{1, 2}, {3, 4}})) {
try {
matrix.copyTo(new int[2]);
fail("should have failed on dimension mismatch");
@@ -466,7 +479,7 @@ public class TensorTest {
@Test
public void failCopyToOnScalar() {
- try (final Tensor scalar = Tensor.create(3)) {
+ try (final Tensor<Integer> scalar = Tensors.create(3)) {
try {
scalar.copyTo(3);
fail("copyTo should fail on scalar tensors, suggesting use of primitive accessors instead");
@@ -478,8 +491,8 @@ public class TensorTest {
@Test
public void failOnArbitraryObject() {
- try (Tensor t = Tensor.create(new Object())) {
- fail("should fail on creating a Tensor with a Java object that has not equivalent DataType");
+ try (Tensor<?> t = Tensor.create(new Object())) {
+ fail("should fail on creating a Tensor with a Java object that has no equivalent DataType");
} catch (IllegalArgumentException e) {
// The expected exception.
}
@@ -487,7 +500,7 @@ public class TensorTest {
@Test
public void failOnZeroDimension() {
- try (Tensor t = Tensor.create(new int[3][0][1])) {
+ try (Tensor<Integer> t = Tensors.create(new int[3][0][1])) {
fail("should fail on creating a Tensor where one of the dimensions is 0");
} catch (IllegalArgumentException e) {
// The expected exception.
@@ -497,7 +510,7 @@ public class TensorTest {
@Test
public void useAfterClose() {
int n = 4;
- Tensor t = Tensor.create(n);
+ Tensor<?> t = Tensor.create(n);
t.close();
try {
t.intValue();
@@ -515,8 +528,8 @@ public class TensorTest {
// An exception is made for this test, where the pitfalls of this is avoided by not calling
// close() on both Tensors.
final float[][] matrix = {{1, 2, 3}, {4, 5, 6}};
- try (Tensor src = Tensor.create(matrix)) {
- Tensor cpy = Tensor.fromHandle(src.getNativeHandle());
+ try (Tensor<Float> src = Tensors.create(matrix)) {
+ Tensor<Float> cpy = Tensor.fromHandle(src.getNativeHandle()).expect(Float.class);
assertEquals(src.dataType(), cpy.dataType());
assertEquals(src.numDimensions(), cpy.numDimensions());
assertArrayEquals(src.shape(), cpy.shape());
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index e3415a696d..c973b5a3d8 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -19,33 +19,36 @@ import java.lang.reflect.Array;
/** Static utility functions. */
public class TestUtil {
- public static Output constant(Graph g, String name, Object value) {
- try (Tensor t = Tensor.create(value)) {
+ public static <T> Output<T> constant(Graph g, String name, Object value) {
+ try (Tensor<?> t = Tensor.create(value)) {
return g.opBuilder("Const", name)
.setAttr("dtype", t.dataType())
.setAttr("value", t)
.build()
- .output(0);
+ .<T>output(0);
}
}
- public static Output placeholder(Graph g, String name, DataType dtype) {
- return g.opBuilder("Placeholder", name).setAttr("dtype", dtype).build().output(0);
+ public static <T> Output<T> placeholder(Graph g, String name, Class<T> type) {
+ return g.opBuilder("Placeholder", name)
+ .setAttr("dtype", DataType.fromClass(type))
+ .build()
+ .<T>output(0);
}
- public static Output addN(Graph g, Output... inputs) {
+ public static Output<?> addN(Graph g, Output<?>... inputs) {
return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
}
- public static Output matmul(
- Graph g, String name, Output a, Output b, boolean transposeA, boolean transposeB) {
+ public static <T> Output<T> matmul(
+ Graph g, String name, Output<T> a, Output<T> b, boolean transposeA, boolean transposeB) {
return g.opBuilder("MatMul", name)
.addInput(a)
.addInput(b)
.setAttr("transpose_a", transposeA)
.setAttr("transpose_b", transposeB)
.build()
- .output(0);
+ .<T>output(0);
}
public static Operation split(Graph g, String name, int[] values, int numSplit) {
@@ -57,7 +60,8 @@ public class TestUtil {
}
public static void transpose_A_times_X(Graph g, int[][] a) {
- matmul(g, "Y", constant(g, "A", a), placeholder(g, "X", DataType.INT32), true, false);
+ Output<Integer> aa = constant(g, "A", a);
+ matmul(g, "Y", aa, placeholder(g, "X", Integer.class), true, false);
}
/**
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/OperandsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/OperandsTest.java
index 4fdd150acc..79bfcc8354 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/OperandsTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/OperandsTest.java
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
@@ -36,8 +36,9 @@ public class OperandsTest {
public void createOutputArrayFromOperandList() {
try (Graph g = new Graph()) {
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
- List<Output> list = Arrays.asList(split.output(0), split.output(2));
- Output[] array = Operands.asOutputs(list);
+ List<Output<Integer>> list =
+ Arrays.asList(split.<Integer>output(0), split.<Integer>output(2));
+ Output<?>[] array = Operands.asOutputs(list);
assertEquals(list.size(), array.length);
assertSame(array[0], list.get(0));
assertSame(array[1], list.get(1));
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/PrimitiveOpTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/PrimitiveOpTest.java
index b24bf5a476..e02c38ed22 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/PrimitiveOpTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/PrimitiveOpTest.java
@@ -36,7 +36,7 @@ public class PrimitiveOpTest {
@Test
public void equalsHashcode() {
try (Graph g = new Graph()) {
- Output array = TestUtil.constant(g, "array", new int[2]);
+ Output<Integer> array = TestUtil.constant(g, "array", new int[2]);
PrimitiveOp test1 =
new PrimitiveOp(g.opBuilder("Shape", "shape1").addInput(array).build()) {};
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
index 9256cb281d..125de73554 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
@@ -19,6 +19,8 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
+import java.util.HashMap;
+import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -26,6 +28,8 @@ import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
+import org.tensorflow.Tensors;
+import org.tensorflow.types.UInt8;
/** Unit tests for {@link org.tensorflow.Scope}. */
@RunWith(JUnit4.class)
@@ -122,13 +126,13 @@ public class ScopeTest {
public void basic() {
try (Graph g = new Graph()) {
Scope s = new Scope(g);
- Const c1 = Const.create(s, 42);
+ Const<Integer> c1 = Const.create(s, 42);
assertEquals("Const", c1.output().op().name());
- Const c2 = Const.create(s, 7);
+ Const<Integer> c2 = Const.create(s, 7);
assertEquals("Const_1", c2.output().op().name());
- Const c3 = Const.create(s.withName("four"), 4);
+ Const<Integer> c3 = Const.create(s.withName("four"), 4);
assertEquals("four", c3.output().op().name());
- Const c4 = Const.create(s.withName("four"), 4);
+ Const<Integer> c4 = Const.create(s.withName("four"), 4);
assertEquals("four_1", c4.output().op().name());
}
}
@@ -148,122 +152,164 @@ public class ScopeTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope s = new Scope(g);
- Output data = Const.create(s.withName("data"), new int[] {600, 470, 170, 430, 300}).output();
+ Output<Integer> data =
+ Const.create(s.withName("data"), new int[] {600, 470, 170, 430, 300}).output();
// Create a composite op with a customized name
- Variance var1 = Variance.create(s.withName("example"), data);
+ Variance<Integer> var1 = Variance.create(s.withName("example"), data, Integer.class);
assertEquals("example/variance", var1.output().op().name());
// Confirm internally added ops have the right names.
assertNotNull(g.operation("example/squared_deviation"));
assertNotNull(g.operation("example/Mean"));
- assertNotNull(g.operation("example/zero"));
+ // assertNotNull(g.operation("example/zero"));
// Same composite op with a default name
- Variance var2 = Variance.create(s, data);
+ Variance<Integer> var2 = Variance.create(s, data, Integer.class);
assertEquals("variance/variance", var2.output().op().name());
// Confirm internally added ops have the right names.
assertNotNull(g.operation("variance/squared_deviation"));
assertNotNull(g.operation("variance/Mean"));
- assertNotNull(g.operation("variance/zero"));
+ // assertNotNull(g.operation("variance/zero"));
// Verify correct results as well.
- Tensor result = sess.runner().fetch(var1.output()).run().get(0);
+ Tensor<Integer> result =
+ sess.runner().fetch(var1.output()).run().get(0).expect(Integer.class);
assertEquals(21704, result.intValue());
- result = sess.runner().fetch(var2.output()).run().get(0);
+ result = sess.runner().fetch(var2.output()).run().get(0).expect(Integer.class);
assertEquals(21704, result.intValue());
}
}
// "handwritten" sample operator classes
- private static final class Const {
- private final Output output;
+ private static final class Const<T> {
+ private final Output<T> output;
- static Const create(Scope s, Object v) {
- try (Tensor value = Tensor.create(v)) {
- return new Const(
+ static Const<Integer> create(Scope s, int v) {
+ return create(s, Tensors.create(v));
+ }
+
+ static Const<Integer> create(Scope s, int[] v) {
+ return create(s, Tensors.create(v));
+ }
+
+ static <T> Const<T> create(Scope s, Tensor<T> value) {
+ return new Const<T>(
+ s.graph()
+ .opBuilder("Const", s.makeOpName("Const"))
+ .setAttr("dtype", value.dataType())
+ .setAttr("value", value)
+ .build()
+ .<T>output(0));
+ }
+
+ static <T> Const<T> create(Scope s, Object v, Class<T> type) {
+ try (Tensor<T> value = Tensor.create(v, type)) {
+ return new Const<T>(
s.graph()
.opBuilder("Const", s.makeOpName("Const"))
.setAttr("dtype", value.dataType())
.setAttr("value", value)
.build()
- .output(0));
+ .<T>output(0));
}
}
- Const(Output o) {
+ Const(Output<T> o) {
output = o;
}
- Output output() {
+ Output<T> output() {
return output;
}
}
- private static final class Mean {
- private final Output output;
+ private static final class Mean<T> {
+ private final Output<T> output;
- static Mean create(Scope s, Output input, Output reductionIndices) {
- return new Mean(
+ static <T> Mean<T> create(Scope s, Output<T> input, Output<T> reductionIndices) {
+ return new Mean<T>(
s.graph()
.opBuilder("Mean", s.makeOpName("Mean"))
.addInput(input)
.addInput(reductionIndices)
.build()
- .output(0));
+ .<T>output(0));
}
- Mean(Output o) {
+ Mean(Output<T> o) {
output = o;
}
- Output output() {
+ Output<T> output() {
return output;
}
}
- private static final class SquaredDifference {
- private final Output output;
+ private static final class SquaredDifference<T> {
+ private final Output<T> output;
- static SquaredDifference create(Scope s, Output x, Output y) {
- return new SquaredDifference(
+ static <T> SquaredDifference<T> create(Scope s, Output<T> x, Output<T> y) {
+ return new SquaredDifference<T>(
s.graph()
.opBuilder("SquaredDifference", s.makeOpName("SquaredDifference"))
.addInput(x)
.addInput(y)
.build()
- .output(0));
+ .<T>output(0));
}
- SquaredDifference(Output o) {
+ SquaredDifference(Output<T> o) {
output = o;
}
- Output output() {
+ Output<T> output() {
return output;
}
}
- private static final class Variance {
- private final Output output;
+ /**
+ * Returns the zero value of type described by {@code c}, or null if the type (e.g., string) is
+ * not numeric and therefore has no zero value.
+ *
+ * @param c The class describing the TensorFlow type of interest.
+ */
+ public static Object zeroValue(Class<?> c) {
+ return zeros.get(c);
+ }
+
+ private static final Map<Class<?>, Object> zeros = new HashMap<>();
+
+ static {
+ zeros.put(Float.class, 0.0f);
+ zeros.put(Double.class, 0.0);
+ zeros.put(Integer.class, 0);
+ zeros.put(UInt8.class, (byte) 0);
+ zeros.put(Long.class, 0L);
+ zeros.put(Boolean.class, false);
+ zeros.put(String.class, null); // no zero value
+ }
+
+ private static final class Variance<T> {
+ private final Output<T> output;
- static Variance create(Scope base, Output x) {
+ static <T> Variance<T> create(Scope base, Output<T> x, Class<T> type) {
Scope s = base.withSubScope("variance");
- Output zero = Const.create(s.withName("zero"), new int[] {0}).output();
- Output sqdiff =
+ Output<T> zero = Const.create(base, zeroValue(type), type).output();
+ Output<T> sqdiff =
SquaredDifference.create(
s.withName("squared_deviation"), x, Mean.create(s, x, zero).output())
.output();
- return new Variance(Mean.create(s.withName("variance"), sqdiff, zero).output());
+ return new Variance<T>(Mean.create(s.withName("variance"), sqdiff, zero).output());
}
- Variance(Output o) {
+ Variance(Output<T> o) {
output = o;
}
- Output output() {
+ Output<T> output() {
return output;
}
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
index ec23792485..ca54214e06 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
@@ -29,7 +29,6 @@ import java.nio.LongBuffer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
@@ -47,8 +46,9 @@ public class ConstantTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Constant op = Constant.create(scope, shape, IntBuffer.wrap(ints));
- Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints));
+ Tensor<Integer> result = sess.runner().fetch(op.asOutput())
+ .run().get(0).expect(Integer.class);
int[] actual = new int[ints.length];
assertArrayEquals(ints, result.copyTo(actual));
}
@@ -62,8 +62,8 @@ public class ConstantTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Constant op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
- Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ Constant<Float> op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
+ Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class);
float[] actual = new float[floats.length];
assertArrayEquals(floats, result.copyTo(actual), EPSILON);
}
@@ -77,8 +77,8 @@ public class ConstantTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Constant op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
- Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ Constant<Double> op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
+ Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class);
double[] actual = new double[doubles.length];
assertArrayEquals(doubles, result.copyTo(actual), EPSILON);
}
@@ -92,8 +92,8 @@ public class ConstantTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Constant op = Constant.create(scope, shape, LongBuffer.wrap(longs));
- Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ Constant<Long> op = Constant.create(scope, shape, LongBuffer.wrap(longs));
+ Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class);
long[] actual = new long[longs.length];
assertArrayEquals(longs, result.copyTo(actual));
}
@@ -123,8 +123,8 @@ public class ConstantTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Constant op = Constant.create(scope, DataType.STRING, shape, ByteBuffer.wrap(content));
- Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ Constant<String> op = Constant.create(scope, String.class, shape, ByteBuffer.wrap(content));
+ Tensor<String> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(String.class);
assertArrayEquals(data, result.bytesValue());
}
}
diff --git a/tensorflow/python/debug/lib/debug_graphs.py b/tensorflow/python/debug/lib/debug_graphs.py
index 486e659158..87033d53a4 100644
--- a/tensorflow/python/debug/lib/debug_graphs.py
+++ b/tensorflow/python/debug/lib/debug_graphs.py
@@ -231,8 +231,8 @@ def _infer_device_name(graph_def):
break
if device_name is None:
logging.warn(
- "Failed to infer device name from partiton GraphDef: none of the nodes "
- "of the GraphDef has a non-empty device name.")
+ "Failed to infer device name from partition GraphDef: none of the "
+ "nodes of the GraphDef has a non-empty device name.")
return device_name
diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
index d7fe4bbfa1..c0a287e922 100644
--- a/tensorflow/python/estimator/inputs/queues/feeding_functions.py
+++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
@@ -49,7 +49,7 @@ except ImportError:
def _fill_array(arr, seq, fillvalue=0):
"""
Recursively fills padded arr with elements from seq.
- If lenght of seq is less then arr padded length, fillvalue used.
+ If length of seq is less than arr padded length, fillvalue used.
Args:
arr: Padded tensor of shape [batch_size, ..., max_padded_dim_len].
diff --git a/tensorflow/python/keras/_impl/keras/engine/topology_test.py b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
index 97bef2965c..32e692ba7c 100644
--- a/tensorflow/python/keras/_impl/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
@@ -200,7 +200,7 @@ class TopologyConstructionTest(test.TestCase):
with self.assertRaises(ValueError):
_ = keras.layers.Input(shape=(32,), batch_shape=(10, 32))
with self.assertRaises(ValueError):
- _ = keras.layers.Input(shape=(32,), unknwon_kwarg=None)
+ _ = keras.layers.Input(shape=(32,), unknown_kwarg=None)
self.assertListEqual(a.get_shape().as_list(), [None, 32])
a_layer, a_node_index, a_tensor_index = a._keras_history
diff --git a/tensorflow/python/kernel_tests/conv2d_transpose_test.py b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
index 18184a0ee0..7d0bc54b69 100644
--- a/tensorflow/python/kernel_tests/conv2d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
@@ -24,8 +24,12 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.client import device_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
@@ -289,6 +293,16 @@ class Conv2DTransposeTest(test.TestCase):
self.assertAllClose(cache_values, value)
+ def testConv2DTransposeShapeInference(self):
+ # Test case for 8972
+ initializer = random_ops.truncated_normal(
+ [3, 3, 5, 1], mean=0.0, stddev=0.01, dtype=dtypes.float32)
+ x = variables.Variable(random_ops.random_normal([3, 10, 5, 1]))
+ f = variable_scope.get_variable("f", initializer=initializer)
+ f_shape = array_ops.stack([array_ops.shape(x)[0], 10, 5, 5])
+ output = nn_ops.conv2d_transpose(
+ x, f, f_shape, strides=[1, 1, 1, 1], padding="SAME")
+ self.assertEqual(output.get_shape().as_list(), [None, 10, 5, 5])
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py
index 3853379328..7d9e57c8e5 100644
--- a/tensorflow/python/kernel_tests/decode_csv_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py
@@ -116,6 +116,17 @@ class DecodeCSVOpTest(test.TestCase):
self._test(args, expected_out)
+ def testNA(self):
+ args = {
+ "records": ["2.0,NA,aa", "NA,5,bb", "3,6,NA"],
+ "record_defaults": [[0.0], [0], [""]],
+ "na_value": "NA"
+ }
+
+ expected_out = [[2.0, 0.0, 3], [0, 5, 6], [b"aa", b"bb", b""]]
+
+ self._test(args, expected_out)
+
def testWithDefaults(self):
args = {
"records": [",1,", "0.2,3,bcd", "3.0,,"],
diff --git a/tensorflow/python/kernel_tests/summary_tensor_op_test.py b/tensorflow/python/kernel_tests/summary_tensor_op_test.py
index 3584637865..d534aadb79 100644
--- a/tensorflow/python/kernel_tests/summary_tensor_op_test.py
+++ b/tensorflow/python/kernel_tests/summary_tensor_op_test.py
@@ -154,7 +154,7 @@ class SummaryOpsTest(test.TestCase):
self.assertEqual(descr.display_name, "my name")
self.assertEqual(descr.summary_description, "my description")
- # If both SummmaryMetadata and explicit args are provided, the args win
+ # If both SummaryMetadata and explicit args are provided, the args win
overwrite = summary_ops.tensor_summary(
"simple",
const,
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index 6e7122db5e..d27e867583 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -207,6 +207,7 @@ TextLineReaderV2
TFRecordReaderV2
WholeFileReaderV2
LMDBReader
+DecodeCSV
# linalg_ops
BatchCholesky
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index c5fd15bae4..ea7132791c 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -1166,3 +1166,42 @@ def _parse_single_sequence_example_raw(serialized,
feature_list_sparse_tensors + feature_list_dense_values))
return (context_output, feature_list_output)
+
+
+# Swap `name` and `na_value` for backward compatibility.
+def decode_csv(records, record_defaults, field_delim=",",
+ use_quote_delim=True, name=None, na_value=""):
+ # pylint: disable=protected-access
+ """Convert CSV records to tensors. Each column maps to one tensor.
+
+ RFC 4180 format is expected for the CSV records.
+ (https://tools.ietf.org/html/rfc4180)
+ Note that we allow leading and trailing spaces with int or float field.
+
+ Args:
+ records: A `Tensor` of type `string`.
+ Each string is a record/row in the csv and all records should have
+ the same format.
+ record_defaults: A list of `Tensor` objects with specific types.
+ Acceptable types are `float32`, `int32`, `int64`, `string`.
+ One tensor per column of the input record, with either a
+ scalar default value for that column or empty if the column is required.
+ field_delim: An optional `string`. Defaults to `","`.
+ char delimiter to separate fields in a record.
+ use_quote_delim: An optional `bool`. Defaults to `True`.
+ If false, treats double quotation marks as regular
+ characters inside of the string fields (ignoring RFC 4180, Section 2,
+ Bullet 5).
+ name: A name for the operation (optional).
+ na_value: Additional string to recognize as NA/NaN.
+
+ Returns:
+ A list of `Tensor` objects. Has the same type as `record_defaults`.
+ Each tensor will have the same shape as records.
+ """
+ # TODO(martinwicke), remove the wrapper when new Python API generator is done.
+ return gen_parsing_ops._decode_csv(
+ records=records, record_defaults=record_defaults,
+ field_delim=field_delim, use_quote_delim=use_quote_delim,
+ na_value=na_value, name=name)
+ # pylint: enable=protected-access
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index bf8380ebbd..0a1a748c40 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -562,7 +562,7 @@ static bool TensorOpMathEnabled() {
bool ret;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DISABLE_TENSOR_OP_MATH",
/*default=*/false, &ret));
- return ret;
+ return !ret;
}();
return is_enabled;
}
@@ -2474,58 +2474,73 @@ struct WinogradNonfused {
};
bool CudnnSupport::GetConvolveAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) {
- out_algorithms->assign({
- // clang-format off
- CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
- CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
- CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
- CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
- CUDNN_CONVOLUTION_FWD_ALGO_FFT,
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) {
+ std::vector<dnn::AlgorithmDesc::Index> algo_types = {
+ // clang-format off
+ CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
+ CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
+ CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
+ CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
+ CUDNN_CONVOLUTION_FWD_ALGO_FFT,
#if CUDNN_VERSION >= 5000
- CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
+ CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
#endif
- // clang-format on
- });
+ // clang-format on
+ };
if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
- out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
+ algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
}
#if CUDNN_VERSION >= 5100
if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
- out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
+ algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
}
#endif
+
+ out_algorithms->clear();
+ for (auto i : algo_types) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/false});
+ if (cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled()) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/true});
+ }
+ }
return true;
}
bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) {
- out_algorithms->assign({
- // clang-format off
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) {
+ std::vector<dnn::AlgorithmDesc::Index> algo_types = {
+ // clang-format off
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
#if CUDNN_VERSION >= 5000
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
#endif
- // clang-format on
- });
+ // clang-format on
+ };
#if CUDNN_VERSION >= 5100
if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
- out_algorithms->push_back(
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
+ algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
}
#endif
+
+ out_algorithms->clear();
+ for (auto i : algo_types) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/false});
+ if (cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled()) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/true});
+ }
+ }
return true;
}
bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) {
- out_algorithms->assign({
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) {
+ std::vector<dnn::AlgorithmDesc::Index> algo_types = {
// clang-format off
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
@@ -2534,13 +2549,20 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
// Based on cudnn.h, the following is not implemented.
// CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
// clang-format on
- });
+ };
#if CUDNN_VERSION >= 5110
if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
- out_algorithms->push_back(
- CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
+ algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
}
#endif
+
+ out_algorithms->clear();
+ for (auto i : algo_types) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/false});
+ if (cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled()) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/true});
+ }
+ }
return true;
}
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index beb2f7d050..8d7069a902 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -145,16 +145,16 @@ class CudnnSupport : public dnn::DnnSupport {
ScratchAllocator* workspace_allocator) override;
bool GetConvolveAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) override;
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
bool GetConvolveBackwardDataAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) override;
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
bool GetConvolveBackwardFilterAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) override;
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
bool DoBatchNormalizationForward(
Stream* stream, const DeviceMemory<float>& x,
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 2c40e18f5c..07fe8a85f4 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -23,20 +23,20 @@ namespace gputools {
namespace dnn {
bool DnnSupport::GetConvolveAlgorithms(
- bool with_winograd_nonfused,
- std::vector<AlgorithmDesc::Index>* out_algorithms) {
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<AlgorithmDesc>* out_algorithms) {
return false;
}
bool DnnSupport::GetConvolveBackwardDataAlgorithms(
- bool with_winograd_nonfused,
- std::vector<AlgorithmDesc::Index>* out_algorithms) {
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<AlgorithmDesc>* out_algorithms) {
return false;
}
bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
- bool with_winograd_nonfused,
- std::vector<AlgorithmDesc::Index>* out_algorithms) {
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<AlgorithmDesc>* out_algorithms) {
return false;
}
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 5fe523602a..624357b82f 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -1183,8 +1183,8 @@ class DnnSupport {
// Return a list of algorithms supported by the forward convolution pass.
virtual bool GetConvolveAlgorithms(
- bool with_winograd_nonfused,
- std::vector<AlgorithmDesc::Index>* out_algorithms);
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<AlgorithmDesc>* out_algorithms);
// Version of DoConvolve that uses pre-quantized 8 bit coefficients.
// coefficient_scales specifies the scaling of each column of coefficients:
@@ -1263,8 +1263,8 @@ class DnnSupport {
// Return a list of algorithms supported by the backward convolution pass for
// data.
virtual bool GetConvolveBackwardDataAlgorithms(
- bool with_winograd_nonfused,
- std::vector<AlgorithmDesc::Index>* out_algorithms);
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<AlgorithmDesc>* out_algorithms);
virtual bool DoConvolveBackwardData(
Stream* stream, const FilterDescriptor& filter_descriptor,
@@ -1312,8 +1312,8 @@ class DnnSupport {
// Return a list of algorithms supported by the backward convolution pass for
// filters.
virtual bool GetConvolveBackwardFilterAlgorithms(
- bool with_winograd_nonfused,
- std::vector<AlgorithmDesc::Index>* out_algorithms);
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<AlgorithmDesc>* out_algorithms);
virtual bool DoConvolveBackwardFilter(
Stream* stream, const BatchDescriptor& input_descriptor,
diff --git a/tensorflow/stream_executor/platform.h b/tensorflow/stream_executor/platform.h
index ed12982e30..f0a0e60e02 100644
--- a/tensorflow/stream_executor/platform.h
+++ b/tensorflow/stream_executor/platform.h
@@ -96,7 +96,7 @@ class Platform {
// each platform is required to expose an ID to ensure unique registration and
// as a target against which plugins can register.
//
- // The macro below is provided to help generate a [process-unique] identifer.
+ // The macro below is provided to help generate a [process-unique] identifier.
using Id = void*;
// Helper macro to define a plugin ID. To be used only inside plugin
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index a72ee804c1..21172d5a16 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -70,7 +70,7 @@ class BatchDescriptor;
class FilterDescriptor;
class ConvolutionDescriptor;
class ProfileResult;
-struct AlgorithmDesc;
+class AlgorithmDesc;
} // namespace dnn
class StreamExecutor;
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 199a908914..9bbfe7f04a 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -286,35 +286,41 @@ bool StreamExecutor::SupportsDnn() const {
bool StreamExecutor::GetConvolveAlgorithms(
bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index> *out_algorithms) {
+ std::vector<dnn::AlgorithmDesc> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return false;
}
- return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused,
- out_algorithms);
+ int cc_major, cc_minor;
+ GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
+ return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
+ cc_minor, out_algorithms);
}
bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index> *out_algorithms) {
+ std::vector<dnn::AlgorithmDesc> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return false;
}
- return dnn_support->GetConvolveBackwardDataAlgorithms(with_winograd_nonfused,
- out_algorithms);
+ int cc_major, cc_minor;
+ GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
+ return dnn_support->GetConvolveBackwardDataAlgorithms(
+ with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}
bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index> *out_algorithms) {
+ std::vector<dnn::AlgorithmDesc> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return false;
}
+ int cc_major, cc_minor;
+ GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
return dnn_support->GetConvolveBackwardFilterAlgorithms(
- with_winograd_nonfused, out_algorithms);
+ with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}
bool StreamExecutor::GetBlasGemmAlgorithms(
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 98136a92a0..f354317a6e 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -343,20 +343,19 @@ class StreamExecutor {
bool SupportsDnn() const;
// Get the list of supported algorithms for the forward convolution opeartion.
- bool GetConvolveAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index> *out_algorithms);
+ bool GetConvolveAlgorithms(bool with_winograd_nonfused,
+ std::vector<dnn::AlgorithmDesc> *out_algorithms);
// Get the list of supported algorithms for the backward convolution on data.
bool GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index> *out_algorithms);
+ std::vector<dnn::AlgorithmDesc> *out_algorithms);
// Get the list of supported algorithms for the backward convolution on the
// filter.
bool GetConvolveBackwardFilterAlgorithms(
bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index> *out_algorithms);
+ std::vector<dnn::AlgorithmDesc> *out_algorithms);
// Get the list of supported algorithms for BLAS gemm.
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index a308688790..0f074151db 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -526,6 +526,7 @@ def tf_cc_test(name,
extra_copts=[],
suffix="",
linkopts=[],
+ nocopts=None,
**kwargs):
native.cc_test(
name="%s%s" % (name, suffix),
@@ -547,6 +548,7 @@ def tf_cc_test(name,
clean_dep("//tensorflow:darwin"): 1,
"//conditions:default": 0,
}),
+ nocopts=nocopts,
**kwargs)
@@ -649,7 +651,8 @@ def tf_cc_tests(srcs,
tags=[],
size="medium",
args=None,
- linkopts=[]):
+ linkopts=[],
+ nocopts=None):
for src in srcs:
tf_cc_test(
name=src_to_test_name(src),
@@ -659,7 +662,8 @@ def tf_cc_tests(srcs,
tags=tags,
size=size,
args=args,
- linkopts=linkopts)
+ linkopts=linkopts,
+ nocopts=nocopts)
def tf_cc_test_mkl(srcs,
@@ -669,7 +673,7 @@ def tf_cc_test_mkl(srcs,
tags=[],
size="medium",
args=None):
- if_mkl(tf_cc_tests(srcs, deps, linkstatic, tags=tags, size=size, args=args))
+ if_mkl(tf_cc_tests(srcs, deps, name, linkstatic=linkstatic, tags=tags, size=size, args=args, nocopts="-fno-exceptions"))
def tf_cc_tests_gpu(srcs,
@@ -867,18 +871,33 @@ def tf_mkl_kernel_library(name,
deps=None,
alwayslink=1,
copts=tf_copts(),
+ nocopts="-fno-exceptions",
**kwargs):
+ """A rule to build MKL-based TensorFlow kernel libraries."""
+ gpu_srcs = gpu_srcs # unused argument
+ kwargs = kwargs # unused argument
+
+ if not bool(srcs):
+ srcs = []
+ if not bool(hdrs):
+ hdrs = []
+
+ if prefix:
+ srcs = srcs + native.glob(
+ [prefix + "*.cc"])
+ hdrs = hdrs + native.glob(
+ [prefix + "*.h"])
+
if_mkl(
- tf_kernel_library(
- name,
- prefix=prefix,
+ native.cc_library(
+ name=name,
srcs=srcs,
- gpu_srcs=gpu_srcs,
hdrs=hdrs,
deps=deps,
alwayslink=alwayslink,
copts=copts,
- **kwargs))
+ nocopts=nocopts
+ ))
# Bazel rules for building swig files.
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 32a86e420a..6e03f9e8fb 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -874,7 +874,7 @@ tf_module {
}
member_method {
name: "decode_csv"
- argspec: "args=[\'records\', \'record_defaults\', \'field_delim\', \'use_quote_delim\', \'name\'], varargs=None, keywords=None, defaults=[\',\', \'True\', \'None\'], "
+ argspec: "args=[\'records\', \'record_defaults\', \'field_delim\', \'use_quote_delim\', \'name\', \'na_value\'], varargs=None, keywords=None, defaults=[\',\', \'True\', \'None\', \'\'], "
}
member_method {
name: "decode_json_example"
diff --git a/tensorflow/tools/ci_build/install/install_golang.sh b/tensorflow/tools/ci_build/install/install_golang.sh
index 88bc2960e3..596265b069 100755
--- a/tensorflow/tools/ci_build/install/install_golang.sh
+++ b/tensorflow/tools/ci_build/install/install_golang.sh
@@ -16,7 +16,7 @@
set -ex
-GOLANG_URL="https://storage.googleapis.com/golang/go1.8.3.linux-amd64.tar.gz"
+GOLANG_URL="https://storage.googleapis.com/golang/go1.9.linux-amd64.tar.gz"
sudo mkdir -p /usr/local
wget -q -O - "${GOLANG_URL}" | sudo tar -C /usr/local -xz
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index f5364d803a..04773376e9 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -78,10 +78,12 @@ WORKDIR /tensorflow
# Configure the build for our CUDA configuration.
ENV CI_BUILD_PYTHON python
-ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
+ENV LD_LIBRARY_PATH /usr/local/cuda/lib64/stubs:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
ENV TF_NEED_CUDA 1
ENV TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2,6.0,6.1
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
+
RUN tensorflow/tools/ci_build/builds/configured GPU \
bazel build -c opt --config=cuda --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \
tensorflow/tools/pip_package:build_pip_package && \
diff --git a/tensorflow/tools/docker/jupyter_notebook_config.py b/tensorflow/tools/docker/jupyter_notebook_config.py
index 747beb8251..0acbf6fcee 100644
--- a/tensorflow/tools/docker/jupyter_notebook_config.py
+++ b/tensorflow/tools/docker/jupyter_notebook_config.py
@@ -18,7 +18,6 @@ from IPython.lib import passwd
c.NotebookApp.ip = '*'
c.NotebookApp.port = int(os.getenv('PORT', 8888))
c.NotebookApp.open_browser = False
-c.MultiKernelManager.default_kernel_name = 'python2'
# sets a password if PASSWORD is set in the environment
if 'PASSWORD' in os.environ:
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index ca3b778c29..1015103077 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -923,7 +923,7 @@ class _ClassPageInfo(object):
"""Sets the `aliases` list.
Args:
- aliases: A list of strings. Containing all the obejct's full names.
+ aliases: A list of strings. Containing all the object's full names.
"""
assert self.aliases is None
self._aliases = aliases
@@ -1438,7 +1438,7 @@ class _PythonBuiltin(object):
class _PythonFile(object):
"""This class indicates that the object is defined in a regular python file.
- This can be used for the `defined_in` slot of the `PageInfo` obejcts.
+ This can be used for the `defined_in` slot of the `PageInfo` objects.
"""
def __init__(self, path, parser_config):
diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc b/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc
index 81f85e0009..6f0b4f47de 100644
--- a/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc
+++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc
@@ -93,13 +93,15 @@ TEST(CreateProtoDebugStringLibTest, ValidSimpleTypes) {
proto.set_optional_int64(std::numeric_limits<protobuf_int64>::max());
proto.set_optional_uint32(std::numeric_limits<uint32>::max());
proto.set_optional_uint64(std::numeric_limits<uint64>::max());
- proto.set_optional_float(std::numeric_limits<float>::max());
+ // TODO(b/67475677): Re-enable after resolving float precision issue
+ // proto.set_optional_float(std::numeric_limits<float>::max());
proto.set_optional_double(std::numeric_limits<double>::max());
EXPECT_TEXT_TRANSFORMS_MATCH();
// Least positive numeric values.
proto.Clear();
- proto.set_optional_float(std::numeric_limits<float>::min());
+ // TODO(b/67475677): Re-enable after resolving float precision issue
+ // proto.set_optional_float(std::numeric_limits<float>::min());
proto.set_optional_double(std::numeric_limits<double>::min());
EXPECT_TEXT_TRANSFORMS_MATCH();
@@ -107,7 +109,8 @@ TEST(CreateProtoDebugStringLibTest, ValidSimpleTypes) {
proto.Clear();
proto.set_optional_int32(std::numeric_limits<int32>::lowest());
proto.set_optional_int64(std::numeric_limits<protobuf_int64>::lowest());
- proto.set_optional_float(std::numeric_limits<float>::lowest());
+ // TODO(b/67475677): Re-enable after resolving float precision issue
+ // proto.set_optional_float(std::numeric_limits<float>::lowest());
proto.set_optional_double(std::numeric_limits<double>::lowest());
EXPECT_TEXT_TRANSFORMS_MATCH();
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index b226184261..de0084613b 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -171,6 +171,17 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
"and will be removed in the future.")
native.new_http_archive(
+ name = "mkl_dnn",
+ urls = [
+ "https://github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz",
+ "http://mirror.bazel.build/github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz",
+ ],
+ sha256 = "0d529ad4c49dc799e6df07c2b88b115d0668735da15fb3b3862d28d33fa68165",
+ strip_prefix = "mkl-dnn-b01e3a55a07be62172e713bcd2644c5176360212",
+ build_file = str(Label("//third_party/mkl_dnn:mkldnn.BUILD")),
+ )
+
+ native.new_http_archive(
name = "eigen_archive",
urls = [
"https://bitbucket.org/eigen/eigen/get/429aa5254200.tar.gz",
@@ -373,10 +384,10 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
patched_http_archive(
name = "protobuf_archive",
urls = [
- "http://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz",
+ "http://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz",
],
- sha256 = "6d43b9d223ce09e5d4ce8b0060cb8a7513577a35a64c7e3dad10f0703bf3ad93",
- strip_prefix = "protobuf-0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66",
+ sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a",
+ strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9",
# TODO: remove patching when tensorflow stops linking same protos into
# multiple shared libraries loaded in runtime by python.
# This patch fixes a runtime crash when tensorflow is compiled
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index baa6e01bca..31a4bfabf6 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -117,7 +117,7 @@ def get_cxx_inc_directories(repository_ctx, cc):
includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)
- includes_cpp_set = set(includes_cpp)
+ includes_cpp_set = depset(includes_cpp)
return includes_cpp + [inc for inc in includes_c
if inc not in includes_cpp_set]
diff --git a/third_party/mkl_dnn/BUILD b/third_party/mkl_dnn/BUILD
new file mode 100644
index 0000000000..5b01f6e3e4
--- /dev/null
+++ b/third_party/mkl_dnn/BUILD
@@ -0,0 +1 @@
+licenses(["notice"])
diff --git a/third_party/mkl_dnn/mkldnn.BUILD b/third_party/mkl_dnn/mkldnn.BUILD
new file mode 100644
index 0000000000..58bb7a6a5d
--- /dev/null
+++ b/third_party/mkl_dnn/mkldnn.BUILD
@@ -0,0 +1,25 @@
+exports_files(["LICENSE"])
+
+cc_library(
+ name = "mkl_dnn",
+ srcs = glob([
+ "src/common/*.cpp",
+ "src/cpu/*.cpp",
+ ]),
+ hdrs = glob(["include/*"]),
+ copts = ["-fexceptions"] + select({
+ "@org_tensorflow//tensorflow:linux_x86_64": [
+ "-fopenmp",
+ ],
+ "//conditions:default": [],
+ }),
+ includes = [
+ "include",
+ "src",
+ "src/common",
+ "src/cpu",
+ "src/cpu/xbyak",
+ ],
+ nocopts = "-fno-exceptions",
+ visibility = ["//visibility:public"],
+)