-rw-r--r--  README.md | 6
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc | 15
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_thunk.cc | 51
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_thunk.h | 4
-rw-r--r--  tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java | 23
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h | 2
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py | 40
-rw-r--r--  tensorflow/contrib/data/python/ops/dataset_ops.py | 2
-rw-r--r--  tensorflow/contrib/deprecated/__init__.py | 2
-rw-r--r--  tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc | 10
-rw-r--r--  tensorflow/contrib/framework/python/framework/tensor_util.py | 6
-rw-r--r--  tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc | 57
-rw-r--r--  tensorflow/contrib/memory_stats/__init__.py | 2
-rw-r--r--  tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc | 22
-rw-r--r--  tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc | 4
-rw-r--r--  tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py | 22
-rw-r--r--  tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py | 5
-rw-r--r--  tensorflow/contrib/resampler/kernels/resampler_ops.cc | 2
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn_cell.py | 10
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/helper.py | 2
-rw-r--r--  tensorflow/contrib/signal/BUILD | 1
-rw-r--r--  tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py | 5
-rw-r--r--  tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py | 45
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/BUILD | 48
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/ar_model.py | 2
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/estimators.py | 7
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/head.py | 375
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/head_test.py | 267
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/model_utils.py | 319
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py | 236
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py | 3
-rw-r--r--  tensorflow/core/BUILD | 22
-rw-r--r--  tensorflow/core/graph/mkl_graph_util.h | 128
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass.cc | 2
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass_test.cc | 2
-rw-r--r--  tensorflow/core/graph/mkl_tfconversion_pass.cc | 2
-rw-r--r--  tensorflow/core/graph/mkl_tfconversion_pass_test.cc | 2
-rw-r--r--  tensorflow/core/kernels/BUILD | 34
-rw-r--r--  tensorflow/core/kernels/bias_op.cc | 159
-rw-r--r--  tensorflow/core/kernels/conv_grad_filter_ops.cc | 55
-rw-r--r--  tensorflow/core/kernels/conv_grad_input_ops.cc | 53
-rw-r--r--  tensorflow/core/kernels/conv_grad_ops_3d.cc | 109
-rw-r--r--  tensorflow/core/kernels/conv_ops.cc | 51
-rw-r--r--  tensorflow/core/kernels/conv_ops_3d.cc | 51
-rw-r--r--  tensorflow/core/kernels/decode_csv_op.cc | 19
-rw-r--r--  tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc | 45
-rw-r--r--  tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc | 181
-rw-r--r--  tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | 190
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.cc | 213
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.h | 308
-rw-r--r--  tensorflow/core/kernels/mkl_cwise_ops_common.cc | 2
-rw-r--r--  tensorflow/core/lib/strings/numbers.cc | 2
-rw-r--r--  tensorflow/core/ops/dataset_ops.cc | 3
-rw-r--r--  tensorflow/core/ops/nn_ops.cc | 84
-rw-r--r--  tensorflow/core/ops/nn_ops_test.cc | 49
-rw-r--r--  tensorflow/core/ops/parsing_ops.cc | 2
-rw-r--r--  tensorflow/core/util/mkl_util.h | 401
-rw-r--r--  tensorflow/docs_src/install/install_sources.md | 38
-rw-r--r--  tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java | 8
-rw-r--r--  tensorflow/examples/tutorials/word2vec/word2vec_basic.py | 2
-rw-r--r--  tensorflow/go/example_inception_inference_test.go | 2
-rw-r--r--  tensorflow/go/tensor.go | 48
-rw-r--r--  tensorflow/go/tensor_test.go | 10
-rw-r--r--  tensorflow/java/src/gen/perl/tftypes-runall.pl | 2
-rw-r--r--  tensorflow/java/src/gen/perl/tftypes.pl | 102
-rw-r--r--  tensorflow/java/src/gen/resources/Tensors.java.tmpl | 31
-rw-r--r--  tensorflow/java/src/gen/resources/tftypes.csv | 42
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/DataType.java | 39
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/Graph.java | 7
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/Input.java | 4
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java | 9
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/Operand.java | 12
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/Operation.java | 18
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java | 14
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/Output.java | 12
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java | 5
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/Session.java | 34
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/Tensor.java | 241
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/Tensors.java | 447
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java | 79
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/op/Operands.java | 8
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java | 34
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/types/UInt8.java | 21
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/types/package-info.java | 16
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/GraphTest.java | 1
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java | 25
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/OperationTest.java | 19
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/SessionTest.java | 41
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java | 2
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/TensorTest.java | 99
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/TestUtil.java | 24
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/op/OperandsTest.java | 7
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/op/PrimitiveOpTest.java | 2
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java | 128
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java | 22
-rw-r--r--  tensorflow/python/debug/lib/debug_graphs.py | 4
-rw-r--r--  tensorflow/python/estimator/inputs/queues/feeding_functions.py | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/topology_test.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/conv2d_transpose_test.py | 14
-rw-r--r--  tensorflow/python/kernel_tests/decode_csv_op_test.py | 11
-rw-r--r--  tensorflow/python/kernel_tests/summary_tensor_op_test.py | 2
-rw-r--r--  tensorflow/python/ops/hidden_ops.txt | 1
-rw-r--r--  tensorflow/python/ops/parsing_ops.py | 39
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.cc | 90
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.h | 12
-rw-r--r--  tensorflow/stream_executor/dnn.cc | 12
-rw-r--r--  tensorflow/stream_executor/dnn.h | 12
-rw-r--r--  tensorflow/stream_executor/platform.h | 2
-rw-r--r--  tensorflow/stream_executor/stream.h | 2
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.cc | 22
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.h | 9
-rw-r--r--  tensorflow/tensorflow.bzl | 35
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.pbtxt | 2
-rwxr-xr-x  tensorflow/tools/ci_build/install/install_golang.sh | 2
-rw-r--r--  tensorflow/tools/docker/Dockerfile.devel-gpu | 4
-rw-r--r--  tensorflow/tools/docker/jupyter_notebook_config.py | 1
-rw-r--r--  tensorflow/tools/docs/parser.py | 4
-rw-r--r--  tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc | 9
-rw-r--r--  tensorflow/workspace.bzl | 17
-rw-r--r--  third_party/gpus/cuda_configure.bzl | 2
-rw-r--r--  third_party/mkl_dnn/BUILD | 1
-rw-r--r--  third_party/mkl_dnn/mkldnn.BUILD | 25
122 files changed, 4102 insertions(+), 1655 deletions(-)
diff --git a/README.md b/README.md
index 4cc53096e0..6339c57c95 100644
--- a/README.md
+++ b/README.md
@@ -48,9 +48,9 @@ GPU packages on all platforms will arrive soon!
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/))
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
-* Windows CPU-only: [Python 3.5 64-bit](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/))
-* Windows GPU: Coming soon!
-* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
+* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/))
+* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/))
+* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
#### *Try your first TensorFlow program*
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 1b5dd558dd..27c5da08c1 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -52,6 +52,11 @@ class XlaAllocator : public xla::DeviceMemoryAllocator {
bool retry_on_failure) override;
Status Deallocate(int device_ordinal, gpu::DeviceMemoryBase* mem) override;
+  // Register a Tensor (input or resource variable) with the allocator. If
+ // the operation returns an alias to one of its inputs, then the allocator
+ // needs to be able to handle it.
+ Status RegisterArgument(const Tensor* t);
+
// Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is
// interpreted as having data type 'dtype' and shape 'shape'.
Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype,
@@ -103,6 +108,14 @@ xla::StatusOr<gpu::DeviceMemoryBase> XlaAllocator::Allocate(
return gpu::DeviceMemoryBase(data, size);
}
+Status XlaAllocator::RegisterArgument(const Tensor* t) {
+ void* data =
+ reinterpret_cast<void*>(const_cast<char*>(t->tensor_data().data()));
+ TF_RET_CHECK(data != nullptr);
+ tensors_[data] = *t;
+ return Status::OK();
+}
+
Status XlaAllocator::Deallocate(int device_ordinal,
gpu::DeviceMemoryBase* mem) {
if (mem->opaque() != nullptr) {
@@ -284,6 +297,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
shape, client->platform(), client->default_device_ordinal(), dmem)
.ConsumeValueOrDie();
arg_ptrs[i] = arg_buffers[i].get();
+
+ OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t));
}
// Make the final parameter point at local_runtime_context.
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 89145a9038..7dd242425c 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -256,9 +256,9 @@ tensorflow::Status ConvolutionThunk::Convolve(
algorithm_config.algorithm_no_scratch().algo_id());
}
-std::vector<AlgorithmDesc::Index> ConvolutionThunk::GetAlgorithms(
+std::vector<AlgorithmDesc> ConvolutionThunk::GetAlgorithms(
se::StreamExecutor* stream_exec) const {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
// TODO(yangzihao): Currently disable the use of winograd nonfused in XLA
// by default. Should send in conv parameters and enable it when
// ShouldIncludeWinogradNonfusedAlgo() returns true.
@@ -297,32 +297,27 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
se::dnn::ProfileResult best_result;
se::dnn::ProfileResult best_result_without_scratch;
- std::vector<AlgorithmDesc::Index> algorithms =
- GetAlgorithms(stream->parent());
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- AlgorithmDesc algorithm(algo_index, use_tensor_ops);
- ConvolveScratchAllocator scratch_allocator(
- buffer_allocations.device_ordinal(),
- buffer_allocations.memory_allocator());
- se::dnn::ProfileResult profile_result;
- bool launch_ok =
- Convolve(input_descriptor, input_data, filter_descriptor,
- filter_data, output_descriptor, output_data,
- convolution_descriptor,
- se::dnn::AlgorithmConfig(algorithm, algorithm), stream,
- &scratch_allocator, &profile_result)
- .ok();
- if (launch_ok && profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalAllocatedBytes() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_without_scratch.elapsed_time_in_ms()) {
- best_result_without_scratch = profile_result;
- }
+ std::vector<AlgorithmDesc> algorithms = GetAlgorithms(stream->parent());
+ for (auto algorithm : algorithms) {
+ ConvolveScratchAllocator scratch_allocator(
+ buffer_allocations.device_ordinal(),
+ buffer_allocations.memory_allocator());
+ se::dnn::ProfileResult profile_result;
+ bool launch_ok =
+ Convolve(input_descriptor, input_data, filter_descriptor, filter_data,
+ output_descriptor, output_data, convolution_descriptor,
+ se::dnn::AlgorithmConfig(algorithm, algorithm), stream,
+ &scratch_allocator, &profile_result)
+ .ok();
+ if (launch_ok && profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalAllocatedBytes() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_without_scratch.elapsed_time_in_ms()) {
+ best_result_without_scratch = profile_result;
}
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index 509719c1fe..13432301b2 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -115,9 +115,7 @@ class ConvolutionThunk : public Thunk {
perftools::gputools::dnn::ProfileResult* profile_result);
// Returns the convolve algorithms that can be used for this ConvolutionThunk.
- // TODO(nluehr) GetAlgorithms should return AlgorithmDesc including both
- // tensor-op and non-tensor-op variants.
- std::vector<perftools::gputools::dnn::AlgorithmDesc::Index> GetAlgorithms(
+ std::vector<perftools::gputools::dnn::AlgorithmDesc> GetAlgorithms(
perftools::gputools::StreamExecutor* stream_exec) const;
// Fastest cuDNN convolution algorithm for this thunk learned from
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
index 395dd6c5d2..80e03f2036 100644
--- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
+++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
@@ -31,12 +31,13 @@ import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.List;
-import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
+import org.tensorflow.Tensors;
+import org.tensorflow.types.UInt8;
/**
* Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface
@@ -328,7 +329,7 @@ public class TensorFlowInferenceInterface {
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, byte[] src, long... dims) {
- addFeed(inputName, Tensor.create(DataType.UINT8, dims, ByteBuffer.wrap(src)));
+ addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src)));
}
/**
@@ -337,7 +338,7 @@ public class TensorFlowInferenceInterface {
* a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[] src) {
- addFeed(inputName, Tensor.create(src));
+ addFeed(inputName, Tensors.create(src));
}
/**
@@ -346,7 +347,7 @@ public class TensorFlowInferenceInterface {
* arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[][] src) {
- addFeed(inputName, Tensor.create(src));
+ addFeed(inputName, Tensors.create(src));
}
// Methods for taking a native Tensor and filling it with src from Java native IO buffers.
@@ -403,7 +404,7 @@ public class TensorFlowInferenceInterface {
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, ByteBuffer src, long... dims) {
- addFeed(inputName, Tensor.create(DataType.UINT8, dims, src));
+ addFeed(inputName, Tensor.create(UInt8.class, dims, src));
}
/**
@@ -544,7 +545,7 @@ public class TensorFlowInferenceInterface {
"Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
}
- private void addFeed(String inputName, Tensor t) {
+ private void addFeed(String inputName, Tensor<?> t) {
// The string format accepted by TensorFlowInferenceInterface is node_name[:output_index].
TensorId tid = TensorId.parse(inputName);
runner.feed(tid.name, tid.outputIndex, t);
@@ -578,7 +579,7 @@ public class TensorFlowInferenceInterface {
}
}
- private Tensor getTensor(String outputName) {
+ private Tensor<?> getTensor(String outputName) {
int i = 0;
for (String n : fetchNames) {
if (n.equals(outputName)) {
@@ -591,7 +592,7 @@ public class TensorFlowInferenceInterface {
}
private void closeFeeds() {
- for (Tensor t : feedTensors) {
+ for (Tensor<?> t : feedTensors) {
t.close();
}
feedTensors.clear();
@@ -599,7 +600,7 @@ public class TensorFlowInferenceInterface {
}
private void closeFetches() {
- for (Tensor t : fetchTensors) {
+ for (Tensor<?> t : fetchTensors) {
t.close();
}
fetchTensors.clear();
@@ -614,9 +615,9 @@ public class TensorFlowInferenceInterface {
// State reset on every call to run.
private Session.Runner runner;
private List<String> feedNames = new ArrayList<String>();
- private List<Tensor> feedTensors = new ArrayList<Tensor>();
+ private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>();
private List<String> fetchNames = new ArrayList<String>();
- private List<Tensor> fetchTensors = new ArrayList<Tensor>();
+ private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>();
// Mutable state.
private RunStats runStats;
diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
index dad3b4e10d..c329c6d4f7 100644
--- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
+++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
@@ -36,7 +36,7 @@ class WeightedQuantilesSummary {
struct SummaryEntry {
SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
const WeightType& max) {
- // Explicitely initialize all of memory (including padding from memory
+ // Explicitly initialize all of memory (including padding from memory
// alignment) to allow the struct to be msan-resistant "plain old data".
//
// POD = http://en.cppreference.com/w/cpp/concept/PODType
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 813c64d141..91f100e0f0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -253,6 +253,46 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testDenseToSparseBatchDatasetWithUnknownShape(self):
+ components = np.random.randint(5, size=(40,)).astype(np.int32)
+ iterator = (dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([x, x], x)).dense_to_sparse_batch(
+ 4, [5, -1]).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+
+ for start in range(0, len(components), 4):
+ results = sess.run(get_next)
+ self.assertAllEqual(
+ [[i, j, z] for i, c in enumerate(components[start:start+4])
+ for j in range(c) for z in range(c)], results.indices)
+ self.assertAllEqual(
+ [c for c in components[start:start+4]
+ for _ in range(c) for _ in range(c)],
+ results.values)
+ self.assertAllEqual(
+ [min(4, len(components) - start),
+ 5,
+ np.max(components[start:start+4])],
+ results.dense_shape)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testDenseToSparseBatchDatasetWithInvalidShape(self):
+ input_tensor = array_ops.constant([[1]])
+ iterator = (dataset_ops.Dataset.from_tensors(input_tensor)
+ .dense_to_sparse_batch(4, [-2]).make_initializable_iterator())
+ init_op = iterator.initializer
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "Dimension -2 must be >= -1"):
+ sess.run(init_op)
+
def testDenseToSparseBatchDatasetShapeErrors(self):
input_tensor = array_ops.placeholder(dtypes.int32)
iterator = (dataset_ops.Dataset.from_tensors(input_tensor).apply(
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
index ff89c47a2e..b74dcd3be2 100644
--- a/tensorflow/contrib/data/python/ops/dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -653,7 +653,7 @@ class Dataset(dataset_ops.Dataset):
```python
# Preprocess 4 files concurrently, and interleave blocks of 16 records from
# each file.
- filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ..."]
+ filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
dataset = (Dataset.from_tensor_slices(filenames)
.interleave(lambda x:
TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
diff --git a/tensorflow/contrib/deprecated/__init__.py b/tensorflow/contrib/deprecated/__init__.py
index bfea8445a7..7aff045de3 100644
--- a/tensorflow/contrib/deprecated/__init__.py
+++ b/tensorflow/contrib/deprecated/__init__.py
@@ -91,7 +91,7 @@ from __future__ import division
from __future__ import print_function
-# pylint: disable=unused-import,line-too-long
+# pylint: disable=unused-import
from tensorflow.python.ops.logging_ops import audio_summary
from tensorflow.python.ops.logging_ops import histogram_summary
from tensorflow.python.ops.logging_ops import image_summary
diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
index 888f5c38a2..b417a70b6e 100644
--- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
+++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
@@ -208,7 +208,15 @@ string GetTempFilename(const string& extension) {
}
struct stat statbuf;
if (!stat(dir, &statbuf) && S_ISDIR(statbuf.st_mode)) {
- return io::JoinPath(dir, StrCat("tmp_file_", getpid(), ".", extension));
+ string tmp_filepath =
+ io::JoinPath(dir, StrCat("tmp_file_XXXXXX", ".", extension));
+ int fd = mkstemps(&tmp_filepath[0], extension.length() + 1);
+ if (fd < 0) {
+ LOG(FATAL) << "Failed to create temp file.";
+ } else {
+ close(fd);
+ return tmp_filepath;
+ }
}
}
LOG(FATAL) << "No temp directory found.";
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py
index e595e4d90b..92a2a4ff2d 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util.py
@@ -78,9 +78,9 @@ def reduce_sum_n(tensors, name=None):
return math_ops.add_n(tensors, name=name_scope)
@deprecated(None,
- "Please switch to tf.confusion_matrix.remove_squeezable_dimensions. Note "
- "that order of the inputs and ouputs of labels and predictions have also "
- "been switched.")
+            'Please switch to tf.confusion_matrix.remove_squeezable_dimensions. '
+ 'Note that order of the inputs and outputs of labels and '
+ 'predictions have also been switched.')
def remove_squeezable_dimensions(predictions, labels, name=None):
"""Squeeze last dim if ranks of `predictions` and `labels` differ by 1.
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 9275d5a22b..256f200868 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -493,42 +493,37 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
dnn::AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find(
fused_conv_parameters, &algorithm_config)) {
- std::vector<dnn::AlgorithmDesc::Index> algorithms;
+ std::vector<dnn::AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveAlgorithms(
fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(),
&algorithms));
dnn::ProfileResult best_result;
dnn::ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- dnn::AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
- dnn::ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenFusedConvolveWithAlgorithm(
- conv_input_desc, conv_input_ptr, conv_input_scale,
- filter_desc, filter_ptr, conv_desc, side_input_ptr,
- side_input_scale, bias_desc, bias_ptr,
- dnn::ActivationMode::kRelu, output_desc, &output_ptr,
- &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm),
- &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+      // TODO(zhengxq): profile each algorithm multiple times for better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+ dnn::ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenFusedConvolveWithAlgorithm(
+ conv_input_desc, conv_input_ptr, conv_input_scale,
+ filter_desc, filter_ptr, conv_desc, side_input_ptr,
+ side_input_scale, bias_desc, bias_ptr,
+ dnn::ActivationMode::kRelu, output_desc, &output_ptr,
+ &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm),
+ &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
diff --git a/tensorflow/contrib/memory_stats/__init__.py b/tensorflow/contrib/memory_stats/__init__.py
index a2b2b65692..a32302c854 100644
--- a/tensorflow/contrib/memory_stats/__init__.py
+++ b/tensorflow/contrib/memory_stats/__init__.py
@@ -14,10 +14,12 @@
# ==============================================================================
"""Ops for memory statistics.
+@@BytesInUse
@@BytesLimit
@@MaxBytesInUse
"""
+from tensorflow.contrib.memory_stats.python.ops.memory_stats_ops import BytesInUse
from tensorflow.contrib.memory_stats.python.ops.memory_stats_ops import BytesLimit
from tensorflow.contrib.memory_stats.python.ops.memory_stats_ops import MaxBytesInUse
diff --git a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc
index 3b88535dce..7e2e96e160 100644
--- a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc
+++ b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc
@@ -40,6 +40,28 @@ class MemoryStatsOp : public OpKernel {
const AllocatorStats& allocator_stats) const = 0;
};
+// Op that measures current memory in bytes.
+class BytesInUseOp : public MemoryStatsOp {
+ public:
+ explicit BytesInUseOp(OpKernelConstruction* context)
+ : MemoryStatsOp(context) {}
+
+ private:
+ int64 ExtractAllocatorStats(
+ const AllocatorStats& allocator_stats) const override {
+ return allocator_stats.bytes_in_use;
+ }
+};
+
+// Register this op on GPU only; see the comment for MaxBytesInUse for the reason.
+REGISTER_KERNEL_BUILDER(Name("BytesInUse").Device(DEVICE_GPU).HostMemory("out"),
+ BytesInUseOp);
+
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(
+    Name("BytesInUse").Device(DEVICE_SYCL).HostMemory("out"), BytesInUseOp);
+#endif // TENSORFLOW_USE_SYCL
+
// Op that measures the total memory (in bytes) of a device.
class BytesLimitOp : public MemoryStatsOp {
public:
diff --git a/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc
index 08859c8613..42020cf7f6 100644
--- a/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc
+++ b/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc
@@ -17,6 +17,10 @@ limitations under the License.
namespace tensorflow {
+REGISTER_OP("BytesInUse")
+ .Output("out: int64")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("BytesLimit")
.Output("out: int64")
.SetIsStateful()
diff --git a/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py b/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py
index ec25c032f0..d1b430b803 100644
--- a/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py
+++ b/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.contrib.memory_stats.python.ops import memory_stats_ops
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
@@ -64,10 +65,29 @@ class MemoryStatsOpsTest(test_util.TensorFlowTestCase):
d = math_ops.matmul(c, b)
sess.run(d)
- max_bytes_in_use = sess.run(memory_stats_ops.MaxBytesInUse())
+ max_bytes_in_use_op = memory_stats_ops.MaxBytesInUse()
+ max_bytes_in_use = sess.run(max_bytes_in_use_op)
self.assertGreaterEqual(max_bytes_in_use, matrix_size_in_bytes * 3)
self.assertLess(max_bytes_in_use, matrix_size_in_bytes * 4)
+    # Run a chain of 2 ops and make sure BytesInUse captures the intermediate
+    # memory usage.
+ a = random_ops.random_uniform(matrix_shape, dtype=dtype)
+ with ops.control_dependencies([a]):
+ bytes_in_use_op = memory_stats_ops.BytesInUse()
+ with ops.control_dependencies([bytes_in_use_op]):
+ b = random_ops.random_uniform(matrix_shape, dtype=dtype)
+
+ _, bytes_in_use, max_bytes_in_use = sess.run([a, bytes_in_use_op,
+ max_bytes_in_use_op])
+
+ # intermediate result allocates 1 matrix, max usage is at least 2
+ self.assertGreaterEqual(bytes_in_use, matrix_size_in_bytes * 1)
+ self.assertLess(bytes_in_use, matrix_size_in_bytes * 2)
+
+    # max usage is still 3 because it reflects the maximum from previous .run call
+ self.assertGreaterEqual(max_bytes_in_use, matrix_size_in_bytes * 3)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py b/tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py
index d35c6583ed..c0f7788c1c 100644
--- a/tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py
+++ b/tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py
@@ -26,6 +26,11 @@ _memory_stats_ops_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_memory_stats_ops.so"))
+def BytesInUse():
+  """Generates an op that measures current memory (in bytes) of a device."""
+ return gen_memory_stats_ops.bytes_in_use()
+
+
def BytesLimit():
"""Generates an op that measures the total memory (in bytes) of a device."""
return gen_memory_stats_ops.bytes_limit()
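For context, a minimal usage sketch for the new `BytesInUse` op (not part of this patch). It assumes a CUDA-enabled build with a visible GPU, since the kernel above is registered for GPU only; the matrix size is arbitrary and mirrors the pattern used in memory_stats_ops_test.py:

```python
# Hedged sketch, not part of this patch: query current and total GPU memory.
import tensorflow as tf
from tensorflow.contrib.memory_stats.python.ops import memory_stats_ops

with tf.device("/gpu:0"):
  a = tf.random_uniform((1024, 1024))               # allocate something on the GPU
  with tf.control_dependencies([a]):
    bytes_in_use = memory_stats_ops.BytesInUse()    # current allocator usage
  bytes_limit = memory_stats_ops.BytesLimit()       # total device memory

with tf.Session() as sess:
  used, limit = sess.run([bytes_in_use, bytes_limit])
  print("GPU memory in use: %d of %d bytes" % (used, limit))
```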
diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops.cc b/tensorflow/contrib/resampler/kernels/resampler_ops.cc
index afc8bcd446..7d9ef14cef 100644
--- a/tensorflow/contrib/resampler/kernels/resampler_ops.cc
+++ b/tensorflow/contrib/resampler/kernels/resampler_ops.cc
@@ -122,7 +122,7 @@ struct Resampler2DFunctor<CPUDevice, T>{
};
// Rough estimate of work for each batch entry.
// From third_party/tensorflow/core/util/work_sharder.cc we gather that an
- // estimate of the cost of each work unit is needed to correclty shard the
+ // estimate of the cost of each work unit is needed to correctly shard the
// workload. Shard assumes each cost unit is 1ns, minimum cost per shard
// being 10us.
const int64 cost = static_cast<int64>(num_sampling_points) *
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 1b0327d62b..6702a89d22 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -525,7 +525,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell):
self._state_tuple_type = collections.namedtuple(
"GridLSTMStateTuple", state_names.strip(","))
self._state_size = self._state_tuple_type(
- *([num_units, num_units] * self._total_blocks))
+ *([num_units, num_units] * self._total_blocks))
else:
self._state_tuple_type = None
self._state_size = num_units * self._total_blocks * 2
@@ -2082,9 +2082,11 @@ def _conv(args,
shape_length = len(shapes[0])
for shape in shapes:
if len(shape) not in [3,4,5]:
- raise ValueError("Conv Linear expects 3D, 4D or 5D arguments: %s" % str(shapes))
+ raise ValueError("Conv Linear expects 3D, 4D "
+ "or 5D arguments: %s" % str(shapes))
if len(shape) != len(shapes[0]):
- raise ValueError("Conv Linear expects all args to be of same Dimensiton: %s" % str(shapes))
+ raise ValueError("Conv Linear expects all args "
+ "to be of same Dimension: %s" % str(shapes))
else:
total_arg_size_depth += shape[-1]
dtype = [a.dtype for a in args][0]
@@ -2102,7 +2104,7 @@ def _conv(args,
# Now the computation.
kernel = vs.get_variable(
- "kernel",
+ "kernel",
filter_size + [total_arg_size_depth, num_features],
dtype=dtype)
if len(args) == 1:
diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py
index 64e00c21c7..b55d90cbab 100644
--- a/tensorflow/contrib/seq2seq/python/ops/helper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/helper.py
@@ -309,7 +309,7 @@ class ScheduledEmbeddingTrainingHelper(TrainingHelper):
gen_array_ops.fill([self.batch_size], -1))
def next_inputs(self, time, outputs, state, sample_ids, name=None):
- with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
+ with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperNextInputs",
[time, outputs, state, sample_ids]):
(finished, base_next_inputs, state) = (
super(ScheduledEmbeddingTrainingHelper, self).next_inputs(
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD
index 43f24474ed..2204b684ac 100644
--- a/tensorflow/contrib/signal/BUILD
+++ b/tensorflow/contrib/signal/BUILD
@@ -5,6 +5,7 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+load("//tensorflow:tensorflow.bzl", "py_test") # @unused
py_library(
name = "signal_py",
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
index f9449095be..094568389c 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
@@ -135,7 +135,10 @@ class BoundingBox(ItemHandler):
"""
sides = []
for key in self._full_keys:
- side = array_ops.expand_dims(keys_to_tensors[key].values, 0)
+ side = keys_to_tensors[key]
+ if isinstance(side, sparse_tensor.SparseTensor):
+ side = side.values
+ side = array_ops.expand_dims(side, 0)
sides.append(side)
bounding_box = array_ops.concat(sides, 0)
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
index 96606b9c0e..60d1eba07f 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
@@ -692,7 +692,7 @@ class TFExampleDecoderTest(test.TestCase):
else:
self.assertAllClose(image, decoded_image, atol=0)
- def testDecodeExampleWithBoundingBox(self):
+ def testDecodeExampleWithBoundingBoxSparse(self):
num_bboxes = 10
np_ymin = np.random.rand(num_bboxes, 1)
np_xmin = np.random.rand(num_bboxes, 1)
@@ -731,6 +731,49 @@ class TFExampleDecoderTest(test.TestCase):
self.assertAllClose(np_bboxes, bboxes)
+ def testDecodeExampleWithBoundingBoxDense(self):
+ num_bboxes = 10
+ np_ymin = np.random.rand(num_bboxes, 1)
+ np_xmin = np.random.rand(num_bboxes, 1)
+ np_ymax = np.random.rand(num_bboxes, 1)
+ np_xmax = np.random.rand(num_bboxes, 1)
+ np_bboxes = np.hstack([np_ymin, np_xmin, np_ymax, np_xmax])
+
+ example = example_pb2.Example(features=feature_pb2.Features(feature={
+ 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
+ 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
+ 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
+ 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
+ }))
+ serialized_example = example.SerializeToString()
+
+ with self.test_session():
+ serialized_example = array_ops.reshape(serialized_example, shape=[])
+
+ keys_to_features = {
+ 'image/object/bbox/ymin': parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/xmin': parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/ymax': parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/xmax': parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ }
+
+ items_to_handlers = {
+ 'object/bbox':
+ tfexample_decoder.BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'],
+ 'image/object/bbox/'),
+ }
+
+ decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
+ items_to_handlers)
+ [tf_bboxes] = decoder.decode(serialized_example, ['object/bbox'])
+ bboxes = tf_bboxes.eval()
+
+ self.assertAllClose(np_bboxes, bboxes)
+
def testDecodeExampleWithRepeatedImages(self):
image_shape = (2, 3, 3)
image_format = 'png'
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index 2c4bed5db1..da583a2ba0 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -42,6 +42,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":feature_keys",
+ ":head",
":input_pipeline",
":model_utils",
"//tensorflow/python:util",
@@ -78,8 +79,8 @@ py_library(
deps = [
":ar_model",
":feature_keys",
+ ":head",
":math_utils",
- ":model_utils",
":state_management",
"//tensorflow/contrib/timeseries/python/timeseries/state_space_models:filtering_postprocessor",
"//tensorflow/contrib/timeseries/python/timeseries/state_space_models:state_space_model",
@@ -123,9 +124,9 @@ py_test(
)
py_library(
- name = "model_utils",
+ name = "head",
srcs = [
- "model_utils.py",
+ "head.py",
],
srcs_version = "PY2AND3",
deps = [
@@ -149,9 +150,9 @@ py_library(
)
py_test(
- name = "model_utils_test",
+ name = "head_test",
srcs = [
- "model_utils_test.py",
+ "head_test.py",
],
srcs_version = "PY2AND3",
tags = [
@@ -159,8 +160,8 @@ py_test(
],
deps = [
":feature_keys",
+ ":head",
":model",
- ":model_utils",
":state_management",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -175,6 +176,41 @@ py_test(
)
py_library(
+ name = "model_utils",
+ srcs = [
+ "model_utils.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":feature_keys",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:variable_scope",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "model_utils_test",
+ srcs = [
+ "model_utils_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip_gpu", # b/63391119
+ ],
+ deps = [
+ ":model_utils",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:variables",
+ ],
+)
+
+py_library(
name = "state_management",
srcs = [
"state_management.py",
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
index 267a5f88da..ff140efd48 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
@@ -374,7 +374,7 @@ class ARModel(model.TimeSeriesModel):
original_values = values
# Extra shape checking for the window size (above that in
- # model_utils.make_model_fn).
+ # `head.create_estimator_spec`).
expected_times_shape = [None, self.window_size]
if not times.get_shape().is_compatible_with(expected_times_shape):
raise ValueError(
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 4025a8f014..3738dfa154 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -20,8 +20,8 @@ from __future__ import print_function
from tensorflow.contrib.timeseries.python.timeseries import ar_model
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
+from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
from tensorflow.contrib.timeseries.python.timeseries import math_utils
-from tensorflow.contrib.timeseries.python.timeseries import model_utils
from tensorflow.contrib.timeseries.python.timeseries import state_management
from tensorflow.contrib.timeseries.python.timeseries.state_space_models import state_space_model
from tensorflow.contrib.timeseries.python.timeseries.state_space_models import structural_ensemble
@@ -59,9 +59,10 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
if optimizer is None:
optimizer = train.AdamOptimizer(0.02)
self._model = model
- model_fn = model_utils.make_model_fn(
+ ts_regression_head = ts_head_lib.time_series_regression_head(
model, state_manager, optimizer,
input_statistics_generator=input_statistics_generator)
+ model_fn = ts_regression_head.create_estimator_spec
super(TimeSeriesRegressor, self).__init__(
model_fn=model_fn,
model_dir=model_dir,
@@ -132,7 +133,7 @@ class TimeSeriesRegressor(estimator_lib.Estimator):
with ops.Graph().as_default():
self._model.initialize_graph()
model_start_state = self._model.get_start_state()
- for prefixed_state_name, state_tensor in model_utils.state_to_dictionary(
+ for prefixed_state_name, state_tensor in ts_head_lib.state_to_dictionary(
model_start_state).items():
state_shape_with_batch = tensor_shape.TensorShape(
(default_batch_size,)).concatenate(state_tensor.get_shape())
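For reference, a minimal sketch (not part of this patch) of wiring the new head from head.py (added below) into an `Estimator` the same way the estimators.py change above does. `my_model` and `my_state_manager` are hypothetical placeholders for a `TimeSeriesModel` and a state manager constructed elsewhere:

```python
# Hedged sketch, not part of this patch; my_model and my_state_manager are
# hypothetical placeholders for objects built elsewhere.
import tensorflow as tf
from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib

ts_regression_head = ts_head_lib.time_series_regression_head(
    model=my_model,
    state_manager=my_state_manager,
    optimizer=tf.train.AdamOptimizer(0.02))

# create_estimator_spec(features, mode, labels) uses the argument names that
# Estimator matches a model_fn against, so it can be passed in directly.
estimator = tf.estimator.Estimator(
    model_fn=ts_regression_head.create_estimator_spec)
```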
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
new file mode 100644
index 0000000000..5896fc2a20
--- /dev/null
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -0,0 +1,375 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Timeseries head."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from tensorflow.contrib.framework.python.ops import variables
+from tensorflow.contrib.layers.python.layers import optimizers
+
+from tensorflow.contrib.timeseries.python.timeseries import feature_keys
+
+from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.estimator.export import export_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import nest
+
+
+def time_series_regression_head(model,
+ state_manager,
+ optimizer,
+ input_statistics_generator=None):
+ """Creates a `_Head` for time series regression.
+
+ Args:
+ model: A model for time series regression.
+ state_manager: A state manager.
+ optimizer: An optimizer.
+    input_statistics_generator: An input statistics generator.
+
+ Returns:
+ An instance of `_Head` for time series regression.
+ """
+ return _TimeSeriesRegressionHead(model, state_manager, optimizer,
+ input_statistics_generator)
+
+
+class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access
+ """See `time_series_regression_head`."""
+
+ def __init__(self,
+ model,
+ state_manager,
+ optimizer,
+ input_statistics_generator=None,
+ name=None):
+ self.model = model
+ self.state_manager = state_manager
+ self.optimizer = optimizer
+ self.input_statistics_generator = input_statistics_generator
+ self._name = name
+
+ def _train_ops(self, features):
+ """Add training ops to the graph."""
+ with variable_scope.variable_scope("model"):
+ model_outputs = self.state_manager.define_loss(
+ self.model, features, estimator_lib.ModeKeys.TRAIN)
+
+ train_op = optimizers.optimize_loss(
+ model_outputs.loss,
+ global_step=variables.get_global_step(),
+ optimizer=self.optimizer,
+ # Learning rate is set in the Optimizer object
+ learning_rate=None)
+ return estimator_lib.EstimatorSpec(
+ loss=model_outputs.loss,
+ mode=estimator_lib.ModeKeys.TRAIN,
+ train_op=train_op)
+
+ # TODO(terrytangyuan): suffix summary and metrics keys by `"/" + name`
+ @property
+ def name(self):
+ return self._name
+
+ # TODO(terrytangyuan): unused for now. Need to decouple
+ # `state_manager.define_loss` to satisfy the extendable return signature of
+ # `_Head.create_loss`.
+ def create_loss(self, features, mode, logits, labels):
+ """See `_Head`."""
+ return None
+
+ # TODO(terrytangyuan): check label dimension
+ @property
+ def logits_dimension(self):
+ return None
+
+ def _evaluate_ops(self, features):
+ """Add ops for evaluation (aka filtering) to the graph."""
+ with variable_scope.variable_scope("model"):
+ model_outputs = self.state_manager.define_loss(
+ self.model, features, estimator_lib.ModeKeys.EVAL)
+ metrics = {}
+ # Just output in-sample predictions for the last chunk seen
+ for prediction_key, prediction_value in model_outputs.predictions.items():
+ metrics[prediction_key] = _identity_metric_single(prediction_key,
+ prediction_value)
+ metrics[feature_keys.FilteringResults.TIMES] = _identity_metric_single(
+ feature_keys.FilteringResults.TIMES, model_outputs.prediction_times)
+ metrics[feature_keys.FilteringResults.STATE_TUPLE] = (
+ _identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE,
+ model_outputs.end_state))
+ return estimator_lib.EstimatorSpec(
+ loss=model_outputs.loss,
+ mode=estimator_lib.ModeKeys.EVAL,
+ eval_metric_ops=metrics,
+ predictions={})
+
+ def _predict_ops(self, features):
+ """Add ops for prediction to the graph."""
+ with variable_scope.variable_scope("model"):
+ prediction = self.model.predict(features=features)
+ prediction[feature_keys.PredictionResults.TIMES] = features[
+ feature_keys.PredictionFeatures.TIMES]
+ return estimator_lib.EstimatorSpec(
+ predictions=prediction, mode=estimator_lib.ModeKeys.PREDICT)
+
+ def _serving_ops(self, features):
+ """Add ops for serving to the graph."""
+ with variable_scope.variable_scope("model"):
+ prediction_outputs = self.model.predict(features=features)
+ with variable_scope.variable_scope("model", reuse=True):
+ filtering_outputs = self.state_manager.define_loss(
+ self.model, features, estimator_lib.ModeKeys.EVAL)
+
+ return estimator_lib.EstimatorSpec(
+ mode=estimator_lib.ModeKeys.PREDICT,
+ export_outputs={
+ feature_keys.SavedModelLabels.PREDICT:
+ export_lib.PredictOutput(prediction_outputs),
+ feature_keys.SavedModelLabels.FILTER:
+ export_lib.PredictOutput(
+ state_to_dictionary(filtering_outputs.end_state))
+ },
+ # Likely unused, but it is necessary to return `predictions` to satisfy
+ # the Estimator's error checking.
+ predictions={})
+
+ def _convert_feature_to_tensor(self, name, value):
+ """Casts features to the correct dtype based on their name."""
+ if name in [
+ feature_keys.TrainEvalFeatures.TIMES,
+ feature_keys.PredictionFeatures.TIMES
+ ]:
+ return math_ops.cast(value, dtypes.int64)
+ if name == feature_keys.TrainEvalFeatures.VALUES:
+ return math_ops.cast(value, self.model.dtype)
+ if name == feature_keys.PredictionFeatures.STATE_TUPLE:
+ return value # Correct dtypes are model-dependent
+ return ops.convert_to_tensor(value)
+
+ def _gather_state(self, features):
+    """Returns `features` with state packed, and whether packing was done."""
+ prefixed_state_re = re.compile(r"^" + feature_keys.State.STATE_PREFIX +
+ r"_(\d+)$")
+ numbered_state = []
+ for key, tensor in features.items():
+ search_result = prefixed_state_re.search(key)
+ if search_result:
+ numbered_state.append((int(search_result.group(1)), key, tensor))
+ if not numbered_state:
+ return features, False
+ features = features.copy()
+ for _, key, _ in numbered_state:
+ del features[key]
+ numbered_state.sort(key=lambda number, *_: number)
+ features[feature_keys.State.STATE_TUPLE] = nest.pack_sequence_as(
+ structure=self.model.get_start_state(),
+ flat_sequence=[tensor for _, _, tensor in numbered_state])
+ return features, True
+
+ def create_estimator_spec(self, features, mode, labels=None):
+ """Performs basic error checking and returns an EstimatorSpec."""
+ with ops.name_scope("head"):
+ if labels:
+ raise ValueError(
+ "The model received a `labels` dictionary, which is "
+ "not supported. Pass '{}' and '{}' as "
+ "features.".format(feature_keys.TrainEvalFeatures.TIMES,
+ feature_keys.TrainEvalFeatures.VALUES))
+ del labels
+ features = {
+ name: self._convert_feature_to_tensor(name=name, value=value)
+ for name, value in features.items()
+ }
+ if self.input_statistics_generator is not None:
+ input_statistics = self.input_statistics_generator.initialize_graph(
+ features, update_statistics=(mode == estimator_lib.ModeKeys.TRAIN))
+ else:
+ input_statistics = None
+ self.model.initialize_graph(input_statistics=input_statistics)
+
+ # _gather_state requires the model to have its graph initialized (so it
+ # has access to the structure of the model's state)
+ features, passed_flat_state = self._gather_state(features)
+ if (mode == estimator_lib.ModeKeys.TRAIN or
+ mode == estimator_lib.ModeKeys.EVAL):
+ _check_train_eval_features(features, self.model)
+ elif mode == estimator_lib.ModeKeys.PREDICT:
+ _check_predict_features(features)
+ else:
+ raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode))
+
+ self.state_manager.initialize_graph(
+ model=self.model, input_statistics=input_statistics)
+
+ if mode == estimator_lib.ModeKeys.TRAIN:
+ return self._train_ops(features)
+ elif mode == estimator_lib.ModeKeys.EVAL:
+ return self._evaluate_ops(features)
+ elif mode == estimator_lib.ModeKeys.PREDICT and not passed_flat_state:
+ return self._predict_ops(features)
+ elif mode == estimator_lib.ModeKeys.PREDICT and passed_flat_state:
+ # The mode is PREDICT, but we're actually in export_savedmodel for
+ # serving. We want to return two graphs: one for filtering (state + data
+ # -> state) and one for predicting (state -> prediction).
+ return self._serving_ops(features)
+
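+# An illustrative sketch (not part of the exported API) of how the two serving
+# graphs mentioned above are meant to be used. `filter_fn` and `predict_fn`
+# are hypothetical stand-ins for the SavedModelLabels.FILTER and
+# SavedModelLabels.PREDICT signatures:
+#
+#   state = filter_fn(state=start_state, times=new_times, values=new_values)
+#   forecasts = predict_fn(state=state, times=future_times)
+#
+# Filtering consumes observed data and returns an updated model state
+# (flattened by `state_to_dictionary` below); prediction consumes a state and
+# query times and returns forecasts.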
+
+def _check_feature_shapes_compatible_with(features,
+ compatible_with_name,
+ compatible_with_value,
+ ignore=None):
+ """Checks all features are compatible with the given time-like feature."""
+ if ignore is None:
+ ignore = set()
+ for name, value in features.items():
+ if name in ignore:
+ continue
+ feature_shape = value.get_shape()
+ if feature_shape.ndims is None:
+ continue
+ if feature_shape.ndims < 2:
+ raise ValueError(
+ ("Features must have shape (batch dimension, window size, ...) "
+ "(got rank {} for feature '{}')").format(feature_shape.ndims, name))
+ if not feature_shape[:2].is_compatible_with(
+ compatible_with_value.get_shape()):
+ raise ValueError(
+ ("Features must have shape (batch dimension, window size, ...) "
+ "where batch dimension and window size match the "
+ "'{times_feature}' feature (got shape {feature_shape} for "
+ "feature '{feature_name}' but shape {times_shape} for feature "
+ "'{times_feature}')").format(
+ times_feature=compatible_with_name,
+ feature_shape=feature_shape,
+ feature_name=name,
+ times_shape=compatible_with_value.get_shape()))
+
+
+def _check_predict_features(features):
+ """Raises errors if features are not suitable for prediction."""
+ if feature_keys.PredictionFeatures.TIMES not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.TIMES))
+ if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
+ raise ValueError("Expected a '{}' feature for prediction.".format(
+ feature_keys.PredictionFeatures.STATE_TUPLE))
+ times_feature = features[feature_keys.PredictionFeatures.TIMES]
+ if not times_feature.get_shape().is_compatible_with([None, None]):
+ raise ValueError(
+ ("Expected shape (batch dimension, window size) for feature '{}' "
+ "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
+ times_feature.get_shape()))
+ _check_feature_shapes_compatible_with(
+ features=features,
+ compatible_with_name=feature_keys.PredictionFeatures.TIMES,
+ compatible_with_value=times_feature,
+ ignore=set([
+ feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes
+ ]))
+
+
+def _check_train_eval_features(features, model):
+ """Raise errors if features are not suitable for training/evaluation."""
+ if feature_keys.TrainEvalFeatures.TIMES not in features:
+ raise ValueError("Expected a '{}' feature for training/evaluation.".format(
+ feature_keys.TrainEvalFeatures.TIMES))
+ if feature_keys.TrainEvalFeatures.VALUES not in features:
+ raise ValueError("Expected a '{}' feature for training/evaluation.".format(
+ feature_keys.TrainEvalFeatures.VALUES))
+ times_feature = features[feature_keys.TrainEvalFeatures.TIMES]
+ if not times_feature.get_shape().is_compatible_with([None, None]):
+ raise ValueError(
+ ("Expected shape (batch dimension, window size) for feature '{}' "
+ "(got shape {})").format(feature_keys.TrainEvalFeatures.TIMES,
+ times_feature.get_shape()))
+ values_feature = features[feature_keys.TrainEvalFeatures.VALUES]
+ if not values_feature.get_shape().is_compatible_with(
+ [None, None, model.num_features]):
+ raise ValueError(
+ ("Expected shape (batch dimension, window size, {num_features}) "
+ "for feature '{feature_name}', since the model was configured "
+ "with num_features={num_features} (got shape {got_shape})").format(
+ num_features=model.num_features,
+ feature_name=feature_keys.TrainEvalFeatures.VALUES,
+ got_shape=times_feature.get_shape()))
+ _check_feature_shapes_compatible_with(
+ features=features,
+ compatible_with_name=feature_keys.TrainEvalFeatures.TIMES,
+ compatible_with_value=times_feature,
+ ignore=set([
+ feature_keys.State.STATE_TUPLE # Model-dependent shapes
+ ]))
+
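+# For illustration (shapes are arbitrary): with a model configured with
+# num_features=2, a valid training/evaluation feature dictionary looks like
+#
+#   {TrainEvalFeatures.TIMES:  <int64 Tensor, shape (batch=16, window=32)>,
+#    TrainEvalFeatures.VALUES: <float Tensor, shape (16, 32, 2)>,
+#    "exogenous":              <Tensor, shape (16, 32, ...)>}
+#
+# i.e. every feature other than the state tuple must share the leading
+# (batch size, window size) dimensions of the TIMES feature.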
+
+def _identity_metric_single(name, input_tensor):
+ """A metric which takes on its last updated value.
+
+ This keeps evaluation metrics in sync with one another, since update ops are
+ run separately from their result Tensors. Simply returning (input_tensor,
+ no_op) as a metric with a value but no update means that a metric will come
+ from a different batch of data than metrics which cache values in a Variable
+ (e.g. the default loss metric).
+
+ Args:
+ name: A name for the metric.
+ input_tensor: Any Tensor.
+ Returns:
+ A tuple of (value, update_op).
+ """
+ metric_variable = variable_scope.variable(
+ name="{}_identity_metric".format(name),
+ initial_value=array_ops.zeros([], dtype=input_tensor.dtype),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ validate_shape=False)
+ update_op = state_ops.assign(
+ metric_variable, input_tensor, validate_shape=False)
+ # This shape will be correct once the first update runs (but may be
+ # incomplete, so is not helpful for initializing the variable).
+ metric_variable.set_shape(input_tensor.get_shape())
+ return (metric_variable.value(), update_op)
+
+
+def _identity_metric_nested(name, input_tensors):
+ """Create identity metrics for a nested tuple of Tensors."""
+ update_ops = []
+ value_tensors = []
+ for tensor_number, tensor in enumerate(nest.flatten(input_tensors)):
+ value_tensor, update_op = _identity_metric_single(
+ name="{}_{}".format(name, tensor_number), input_tensor=tensor)
+ update_ops.append(update_op)
+ value_tensors.append(value_tensor)
+ return (nest.pack_sequence_as(input_tensors, value_tensors),
+ control_flow_ops.group(*update_ops))
+
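+# A concrete (purely illustrative) picture of why the assign-to-a-variable
+# pattern above matters: during evaluation the Estimator first runs every
+# metric's update op on a batch, then fetches the metric values.
+#
+#   session.run([loss_update_op, identity_update_op])  # both see batch N
+#   session.run([loss_value, identity_value])          # both report batch N
+#
+# Returning (input_tensor, no_op) instead would pull input_tensor from a later
+# batch on the second run, so the reported prediction would no longer
+# correspond to the reported loss.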
+
+def state_to_dictionary(state_tuple):
+ """Flatten model state into a dictionary with string keys."""
+ flattened = {}
+ for state_number, state_value in enumerate(nest.flatten(state_tuple)):
+ prefixed_state_name = "{}_{:02d}".format(feature_keys.State.STATE_PREFIX,
+ state_number)
+ flattened[prefixed_state_name] = state_value
+ return flattened
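+
+
+# Round-trip sketch (illustrative only): state_to_dictionary flattens a nested
+# model state into numbered keys, STATE_PREFIX + "_00", STATE_PREFIX + "_01",
+# and so on, which is how filtering results are exported. _gather_state above
+# inverts this: it matches those keys with a regex, sorts them by the numeric
+# suffix, and packs the tensors back into the structure of
+# model.get_start_state(). For a start state of (mean, covariance):
+#
+#   state_to_dictionary((mean, covariance))
+#       == {STATE_PREFIX + "_00": mean, STATE_PREFIX + "_01": covariance}
+#
+# and feeding that dictionary back through _gather_state recovers
+# (mean, covariance) under the State.STATE_TUPLE key.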
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
new file mode 100644
index 0000000000..3415061cfd
--- /dev/null
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -0,0 +1,267 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for head."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.timeseries.python.timeseries import feature_keys
+from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib
+from tensorflow.contrib.timeseries.python.timeseries import model
+from tensorflow.contrib.timeseries.python.timeseries import state_management
+
+from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import coordinator as coordinator_lib
+from tensorflow.python.training import queue_runner_impl
+from tensorflow.python.training import training as train
+
+
+class HeadTest(test.TestCase):
+
+ def test_labels_provided_error(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL,
+ estimator_lib.ModeKeys.PREDICT]:
+ with self.assertRaisesRegexp(ValueError, "labels"):
+ model_fn(features={}, labels={"a": "b"}, mode=mode)
+
+ def test_unknown_mode(self):
+ model_fn = _stub_model_fn()
+ with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"):
+ model_fn(features={}, labels={}, mode="Not a mode")
+
+
+class _TickerModel(object):
+ num_features = 1
+ dtype = dtypes.float32
+
+ def initialize_graph(self, input_statistics):
+ pass
+
+ def define_loss(self, features, mode):
+ del mode # unused
+ return model.ModelOutputs(
+ loss=features["ticker"],
+ end_state=(features["ticker"], features["ticker"]),
+ prediction_times=array_ops.zeros(()),
+ predictions={"ticker": features["ticker"]})
+
+
+class EvaluationMetricsTests(test.TestCase):
+
+ def test_metrics_consistent(self):
+ # Tests that the identity metrics used to report in-sample predictions match
+ # the behavior of standard metrics.
+ g = ops.Graph()
+ with g.as_default():
+ features = {
+ feature_keys.TrainEvalFeatures.TIMES:
+ array_ops.zeros((1, 1)),
+ feature_keys.TrainEvalFeatures.VALUES:
+ array_ops.zeros((1, 1, 1)),
+ "ticker":
+ array_ops.reshape(
+ math_ops.cast(
+ variables.Variable(
+ name="ticker",
+ initial_value=0,
+ dtype=dtypes.int64,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ .count_up_to(10),
+ dtype=dtypes.float32), (1, 1, 1))
+ }
+ model_fn = ts_head_lib.time_series_regression_head(
+ model=_TickerModel(),
+ state_manager=state_management.PassthroughStateManager(),
+ optimizer=train.GradientDescentOptimizer(0.001)).create_estimator_spec
+ outputs = model_fn(
+ features=features, labels=None, mode=estimator_lib.ModeKeys.EVAL)
+ metric_update_ops = [
+ metric[1] for metric in outputs.eval_metric_ops.values()]
+ loss_mean, loss_update = metrics.mean(outputs.loss)
+ metric_update_ops.append(loss_update)
+ with self.test_session() as sess:
+ coordinator = coordinator_lib.Coordinator()
+ queue_runner_impl.start_queue_runners(sess, coord=coordinator)
+ variables.local_variables_initializer().run()
+ sess.run(metric_update_ops)
+ loss_evaled, metric_evaled, nested_metric_evaled = sess.run(
+ (loss_mean, outputs.eval_metric_ops["ticker"][0],
+ outputs.eval_metric_ops[feature_keys.FilteringResults.STATE_TUPLE][
+ 0][0]))
+        # The custom identity metrics for in-sample predictions should be in
+        # sync with the Estimator's mean metric for model loss.
+ self.assertAllClose(0., loss_evaled)
+ self.assertAllClose((((0.,),),), metric_evaled)
+ self.assertAllClose((((0.,),),), nested_metric_evaled)
+ coordinator.request_stop()
+ coordinator.join()
+
+
+class _StubModel(object):
+ num_features = 3
+ dtype = dtypes.float64
+
+ def initialize_graph(self, input_statistics):
+ del input_statistics # unused
+
+
+def _stub_model_fn():
+ return ts_head_lib.time_series_regression_head(
+ model=_StubModel(),
+ state_manager=state_management.PassthroughStateManager(),
+ optimizer=train.AdamOptimizer(0.001)).create_estimator_spec
+
+
+class TrainEvalFeatureCheckingTests(test.TestCase):
+
+ def test_no_time_feature(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
+ with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
+ feature_keys.TrainEvalFeatures.TIMES)):
+ model_fn(
+ features={feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]},
+ labels=None,
+ mode=mode)
+
+ def test_no_value_feature(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
+ with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
+ feature_keys.TrainEvalFeatures.VALUES)):
+ model_fn(
+ features={feature_keys.TrainEvalFeatures.TIMES: [[1]]},
+ labels=None,
+ mode=mode)
+
+ def test_bad_time_rank(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
+ with self.assertRaisesRegexp(ValueError,
+ "Expected shape.*for feature '{}'".format(
+ feature_keys.TrainEvalFeatures.TIMES)):
+ model_fn(
+ features={
+ feature_keys.TrainEvalFeatures.TIMES: [[[1]]],
+ feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]
+ },
+ labels=None,
+ mode=mode)
+
+ def test_bad_value_rank(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
+ with self.assertRaisesRegexp(ValueError,
+ "Expected shape.*for feature '{}'".format(
+ feature_keys.TrainEvalFeatures.VALUES)):
+ model_fn(
+ features={
+ feature_keys.TrainEvalFeatures.TIMES: [[1]],
+ feature_keys.TrainEvalFeatures.VALUES: [[1.]]
+ },
+ labels=None,
+ mode=mode)
+
+ def test_bad_value_num_features(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
+ with self.assertRaisesRegexp(
+ ValueError, "Expected shape.*, 3.*for feature '{}'".format(
+ feature_keys.TrainEvalFeatures.VALUES)):
+ model_fn(
+ features={
+ feature_keys.TrainEvalFeatures.TIMES: [[1]],
+ feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]
+ },
+ labels=None,
+ mode=mode)
+
+ def test_bad_exogenous_shape(self):
+ model_fn = _stub_model_fn()
+ for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Features must have shape.*for feature 'exogenous'"):
+ model_fn(
+ features={
+ feature_keys.TrainEvalFeatures.TIMES: [[1]],
+ feature_keys.TrainEvalFeatures.VALUES: [[[1., 2., 3.]]],
+ "exogenous": [[1], [2]]
+ },
+ labels=None,
+ mode=mode)
+
+
+class PredictFeatureCheckingTests(test.TestCase):
+
+ def test_no_time_feature(self):
+ model_fn = _stub_model_fn()
+ with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
+ feature_keys.PredictionFeatures.TIMES)):
+ model_fn(
+ features={
+ feature_keys.PredictionFeatures.STATE_TUPLE: ([[[1.]]], 1.)
+ },
+ labels=None,
+ mode=estimator_lib.ModeKeys.PREDICT)
+
+ def test_no_start_state_feature(self):
+ model_fn = _stub_model_fn()
+ with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
+ feature_keys.PredictionFeatures.STATE_TUPLE)):
+ model_fn(
+ features={feature_keys.PredictionFeatures.TIMES: [[1]]},
+ labels=None,
+ mode=estimator_lib.ModeKeys.PREDICT)
+
+ def test_bad_time_rank(self):
+ model_fn = _stub_model_fn()
+ with self.assertRaisesRegexp(ValueError,
+ "Expected shape.*for feature '{}'".format(
+ feature_keys.PredictionFeatures.TIMES)):
+ model_fn(
+ features={
+ feature_keys.PredictionFeatures.TIMES: 1,
+ feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.))
+ },
+ labels=None,
+ mode=estimator_lib.ModeKeys.PREDICT)
+
+ def test_bad_exogenous_shape(self):
+ model_fn = _stub_model_fn()
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Features must have shape.*for feature 'exogenous'"):
+ model_fn(
+ features={
+ feature_keys.PredictionFeatures.TIMES: [[1]],
+ feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.)),
+ "exogenous": 1.
+ },
+ labels=None,
+ mode=estimator_lib.ModeKeys.PREDICT)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/model_utils.py b/tensorflow/contrib/timeseries/python/timeseries/model_utils.py
index addcdb0575..b5d7cb376b 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/model_utils.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/model_utils.py
@@ -18,334 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import re
-
import numpy
-from tensorflow.contrib.framework.python.ops import variables
-from tensorflow.contrib.layers.python.layers import optimizers
-
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
-from tensorflow.python.estimator import estimator_lib
-from tensorflow.python.estimator.export import export_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
-
-
-def _check_feature_shapes_compatible_with(
- features, compatible_with_name, compatible_with_value, ignore=None):
- """Checks all features are compatible with the given time-like feature."""
- if ignore is None:
- ignore = set()
- for name, value in features.items():
- if name in ignore:
- continue
- feature_shape = value.get_shape()
- if feature_shape.ndims is None:
- continue
- if feature_shape.ndims < 2:
- raise ValueError(
- ("Features must have shape (batch dimension, window size, ...) "
- "(got rank {} for feature '{}')").format(
- feature_shape.ndims, name))
- if not feature_shape[:2].is_compatible_with(
- compatible_with_value.get_shape()):
- raise ValueError(
- ("Features must have shape (batch dimension, window size, ...) "
- "where batch dimension and window size match the "
- "'{times_feature}' feature (got shape {feature_shape} for "
- "feature '{feature_name}' but shape {times_shape} for feature "
- "'{times_feature}')").format(
- times_feature=compatible_with_name,
- feature_shape=feature_shape,
- feature_name=name,
- times_shape=compatible_with_value.get_shape()))
-
-
-def _check_predict_features(features):
- """Raises errors if features are not suitable for prediction."""
- if feature_keys.PredictionFeatures.TIMES not in features:
- raise ValueError("Expected a '{}' feature for prediction.".format(
- feature_keys.PredictionFeatures.TIMES))
- if feature_keys.PredictionFeatures.STATE_TUPLE not in features:
- raise ValueError("Expected a '{}' feature for prediction.".format(
- feature_keys.PredictionFeatures.STATE_TUPLE))
- times_feature = features[feature_keys.PredictionFeatures.TIMES]
- if not times_feature.get_shape().is_compatible_with([None, None]):
- raise ValueError(
- ("Expected shape (batch dimension, window size) for feature '{}' "
- "(got shape {})").format(feature_keys.PredictionFeatures.TIMES,
- times_feature.get_shape()))
- _check_feature_shapes_compatible_with(
- features=features,
- compatible_with_name=feature_keys.PredictionFeatures.TIMES,
- compatible_with_value=times_feature,
- ignore=set([
- feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes
- ]))
-
-
-def _check_train_eval_features(features, model):
- """Raise errors if features are not suitable for training/evaluation."""
- if feature_keys.TrainEvalFeatures.TIMES not in features:
- raise ValueError("Expected a '{}' feature for training/evaluation.".format(
- feature_keys.TrainEvalFeatures.TIMES))
- if feature_keys.TrainEvalFeatures.VALUES not in features:
- raise ValueError("Expected a '{}' feature for training/evaluation.".format(
- feature_keys.TrainEvalFeatures.VALUES))
- times_feature = features[feature_keys.TrainEvalFeatures.TIMES]
- if not times_feature.get_shape().is_compatible_with([None, None]):
- raise ValueError(
- ("Expected shape (batch dimension, window size) for feature '{}' "
- "(got shape {})").format(feature_keys.TrainEvalFeatures.TIMES,
- times_feature.get_shape()))
- values_feature = features[feature_keys.TrainEvalFeatures.VALUES]
- if not values_feature.get_shape().is_compatible_with(
- [None, None, model.num_features]):
- raise ValueError(
- ("Expected shape (batch dimension, window size, {num_features}) "
- "for feature '{feature_name}', since the model was configured "
- "with num_features={num_features} (got shape {got_shape})").format(
- num_features=model.num_features,
- feature_name=feature_keys.TrainEvalFeatures.VALUES,
- got_shape=times_feature.get_shape()))
- _check_feature_shapes_compatible_with(
- features=features,
- compatible_with_name=feature_keys.TrainEvalFeatures.TIMES,
- compatible_with_value=times_feature,
- ignore=set([
- feature_keys.State.STATE_TUPLE # Model-dependent shapes
- ]))
-
-
-def _identity_metric_single(name, input_tensor):
- """A metric which takes on its last updated value.
-
- This keeps evaluation metrics in sync with one another, since update ops are
- run separately from their result Tensors. Simply returning (input_tensor,
- no_op) as a metric with a value but no update means that a metric will come
- from a different batch of data than metrics which cache values in a Variable
- (e.g. the default loss metric).
-
- Args:
- name: A name for the metric.
- input_tensor: Any Tensor.
- Returns:
- A tuple of (value, update_op).
- """
- metric_variable = variable_scope.variable(
- name="{}_identity_metric".format(name),
- initial_value=array_ops.zeros([], dtype=input_tensor.dtype),
- collections=[ops.GraphKeys.LOCAL_VARIABLES],
- validate_shape=False)
- update_op = state_ops.assign(metric_variable, input_tensor,
- validate_shape=False)
- # This shape will be correct once the first update runs (but may be
- # incomplete, so is not helpful for initializing the variable).
- metric_variable.set_shape(input_tensor.get_shape())
- return (metric_variable.value(), update_op)
-
-
-def _identity_metric_nested(name, input_tensors):
- """Create identity metrics for a nested tuple of Tensors."""
- update_ops = []
- value_tensors = []
- for tensor_number, tensor in enumerate(nest.flatten(input_tensors)):
- value_tensor, update_op = _identity_metric_single(
- name="{}_{}".format(name, tensor_number),
- input_tensor=tensor)
- update_ops.append(update_op)
- value_tensors.append(value_tensor)
- return (nest.pack_sequence_as(input_tensors, value_tensors),
- control_flow_ops.group(*update_ops))
-
-
-def state_to_dictionary(state_tuple):
- """Flatten model state into a dictionary with string keys."""
- flattened = {}
- for state_number, state_value in enumerate(nest.flatten(state_tuple)):
- prefixed_state_name = "{}_{:02d}".format(feature_keys.State.STATE_PREFIX,
- state_number)
- flattened[prefixed_state_name] = state_value
- return flattened
-
-
-def make_model_fn(
- model, state_manager, optimizer, input_statistics_generator=None):
- """Returns a model function suitable for use with a tf.estimator.
-
- Args:
- model: The object (inheriting from Model) to create a function for.
- state_manager: A state manager to wrap the model with (or
- PassthroughStateManager if no state needs to be managed).
- optimizer: An instance of `tf.train.Optimizer` to use for training.
- input_statistics_generator: An InputStatisticsFromMiniBatch object from
- math_utils.py, used for collecting statistics about input data during
- training.
- Returns:
- The model function, suitable for passing to a tf.estimator.Estimator.
- """
-
- def _convert_feature_to_tensor(name, value):
- """Casts features to the correct dtype based on their name."""
- if name in [
- feature_keys.TrainEvalFeatures.TIMES,
- feature_keys.PredictionFeatures.TIMES
- ]:
- return math_ops.cast(value, dtypes.int64)
- if name == feature_keys.TrainEvalFeatures.VALUES:
- return math_ops.cast(value, model.dtype)
- if name == feature_keys.PredictionFeatures.STATE_TUPLE:
- return value # Correct dtypes are model-dependent
- return ops.convert_to_tensor(value)
-
- def _gather_state(features):
- """Returns `features` with state packed, indicates if packing was done."""
- prefixed_state_re = re.compile(r"^" + feature_keys.State.STATE_PREFIX +
- r"_(\d+)$")
- numbered_state = []
- for key, tensor in features.items():
- search_result = prefixed_state_re.search(key)
- if search_result:
- numbered_state.append((int(search_result.group(1)), key, tensor))
- if not numbered_state:
- return features, False
- features = features.copy()
- for _, key, _ in numbered_state:
- del features[key]
- numbered_state.sort(key=lambda number, *_: number)
- features[feature_keys.State.STATE_TUPLE] = nest.pack_sequence_as(
- structure=model.get_start_state(),
- flat_sequence=[tensor for _, _, tensor in numbered_state])
- return features, True
-
- def _train(features):
- """Add training ops to the graph."""
- with variable_scope.variable_scope("model"):
- model_outputs = state_manager.define_loss(model, features,
- estimator_lib.ModeKeys.TRAIN)
- train_op = optimizers.optimize_loss(
- model_outputs.loss,
- global_step=variables.get_global_step(),
- optimizer=optimizer,
- # Learning rate is set in the Optimizer object
- learning_rate=None)
- return estimator_lib.EstimatorSpec(
- loss=model_outputs.loss,
- mode=estimator_lib.ModeKeys.TRAIN,
- train_op=train_op)
-
- def _evaluate(features):
- """Add ops for evaluation (aka filtering) to the graph."""
- with variable_scope.variable_scope("model"):
- model_outputs = state_manager.define_loss(model, features,
- estimator_lib.ModeKeys.EVAL)
- metrics = {}
- # Just output in-sample predictions for the last chunk seen
- for prediction_key, prediction_value in model_outputs.predictions.items():
- metrics[prediction_key] = _identity_metric_single(prediction_key,
- prediction_value)
- metrics[feature_keys.FilteringResults.TIMES] = _identity_metric_single(
- feature_keys.FilteringResults.TIMES, model_outputs.prediction_times)
- metrics[feature_keys.FilteringResults.STATE_TUPLE] = (
- _identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE,
- model_outputs.end_state))
- return estimator_lib.EstimatorSpec(
- loss=model_outputs.loss,
- mode=estimator_lib.ModeKeys.EVAL,
- eval_metric_ops=metrics,
- predictions={})
-
- def _predict(features):
- """Add ops for prediction to the graph."""
- with variable_scope.variable_scope("model"):
- prediction = model.predict(features=features)
- prediction[feature_keys.PredictionResults.TIMES] = features[
- feature_keys.PredictionFeatures.TIMES]
- return estimator_lib.EstimatorSpec(
- predictions=prediction, mode=estimator_lib.ModeKeys.PREDICT)
-
- def _serving(features):
- with variable_scope.variable_scope("model"):
- prediction_outputs = model.predict(features=features)
- with variable_scope.variable_scope("model", reuse=True):
- filtering_outputs = state_manager.define_loss(model, features,
- estimator_lib.ModeKeys.EVAL)
- return estimator_lib.EstimatorSpec(
- mode=estimator_lib.ModeKeys.PREDICT,
- export_outputs={
- feature_keys.SavedModelLabels.PREDICT:
- export_lib.PredictOutput(prediction_outputs),
- feature_keys.SavedModelLabels.FILTER:
- export_lib.PredictOutput(
- state_to_dictionary(filtering_outputs.end_state))
- },
- # Likely unused, but it is necessary to return `predictions` to satisfy
- # the Estimator's error checking.
- predictions={})
-
- def _model_fn(features, labels, mode):
- """Given a time series in `features`, define a loss for `mode`.
-
- Args:
- features: A dictionary, the output of a chunker (typically with keys
- feature_keys.TrainEvalFeatures.TIMES and
- feature_keys.TrainEvalFeatures.VALUES).
- labels: Not used; included for compatibility with tf.learn.
- mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL, INFER).
- Returns:
- A tuple of predictions, a loss Tensor, and a train op.
- Raises:
- ValueError: If the model makes predictions which do not have static shape
- information.
- """
- if labels:
- raise ValueError("The model received a `labels` dictionary, which is not"
- " supported. Pass '{}' and '{}' as features.".format(
- feature_keys.TrainEvalFeatures.TIMES,
- feature_keys.TrainEvalFeatures.VALUES))
- del labels
- features = {name: _convert_feature_to_tensor(name=name, value=value)
- for name, value in features.items()}
- if input_statistics_generator is not None:
- input_statistics = input_statistics_generator.initialize_graph(
- features, update_statistics=(mode == estimator_lib.ModeKeys.TRAIN))
- else:
- input_statistics = None
- model.initialize_graph(input_statistics=input_statistics)
- # _gather_state requires the model to have its graph initialized (so it has
- # access to the structure of the model's state)
- features, passed_flat_state = _gather_state(features)
- if (mode == estimator_lib.ModeKeys.TRAIN
- or mode == estimator_lib.ModeKeys.EVAL):
- _check_train_eval_features(features, model)
- elif mode == estimator_lib.ModeKeys.PREDICT:
- _check_predict_features(features)
- else:
- raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode))
- state_manager.initialize_graph(
- model=model, input_statistics=input_statistics)
- if mode == estimator_lib.ModeKeys.TRAIN:
- return _train(features)
- elif mode == estimator_lib.ModeKeys.EVAL:
- return _evaluate(features)
- elif mode == estimator_lib.ModeKeys.PREDICT and not passed_flat_state:
- return _predict(features)
- elif mode == estimator_lib.ModeKeys.PREDICT and passed_flat_state:
- # The mode is PREDICT, but we're actually in export_savedmodel for
- # serving. We want to return two graphs: one for filtering (state + data
- # -> state) and one for predicting (state -> prediction).
- return _serving(features)
- return _model_fn
# TODO(agarwal): Remove and replace with functionality from tf.slim
diff --git a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
index 2998689554..cfd31cc70d 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
@@ -18,22 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.timeseries.python.timeseries import feature_keys
-from tensorflow.contrib.timeseries.python.timeseries import model
from tensorflow.contrib.timeseries.python.timeseries import model_utils
-from tensorflow.contrib.timeseries.python.timeseries import state_management
-from tensorflow.python.estimator import estimator_lib
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import metrics
-from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.training import coordinator as coordinator_lib
-from tensorflow.python.training import queue_runner_impl
-from tensorflow.python.training import training as train
class ModelUtilsTest(test.TestCase):
@@ -46,230 +34,6 @@ class ModelUtilsTest(test.TestCase):
self.assertEqual(5, getter(parameter))
self.assertEqual(4, getter(overridden_parameter))
- def test_labels_provided_error(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL,
- estimator_lib.ModeKeys.PREDICT]:
- with self.assertRaisesRegexp(ValueError, "labels"):
- model_fn(features={}, labels={"a": "b"}, mode=mode)
-
- def test_unknown_mode(self):
- model_fn = _stub_model_fn()
- with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"):
- model_fn(features={}, labels={}, mode="Not a mode")
-
-
-class _TickerModel(object):
- num_features = 1
- dtype = dtypes.float32
-
- def initialize_graph(self, input_statistics):
- pass
-
- def define_loss(self, features, mode):
- del mode # unused
- return model.ModelOutputs(
- loss=features["ticker"],
- end_state=(features["ticker"], features["ticker"]),
- prediction_times=array_ops.zeros(()),
- predictions={"ticker": features["ticker"]})
-
-
-class EvaluationMetricsTests(test.TestCase):
-
- def test_metrics_consistent(self):
- # Tests that the identity metrics used to report in-sample predictions match
- # the behavior of standard metrics.
- g = ops.Graph()
- with g.as_default():
- features = {
- feature_keys.TrainEvalFeatures.TIMES:
- array_ops.zeros((1, 1)),
- feature_keys.TrainEvalFeatures.VALUES:
- array_ops.zeros((1, 1, 1)),
- "ticker":
- array_ops.reshape(
- math_ops.cast(
- variables.Variable(
- name="ticker",
- initial_value=0,
- dtype=dtypes.int64,
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
- .count_up_to(10),
- dtype=dtypes.float32), (1, 1, 1))
- }
- model_fn = model_utils.make_model_fn(
- model=_TickerModel(),
- state_manager=state_management.PassthroughStateManager(),
- optimizer=train.GradientDescentOptimizer(0.001))
- outputs = model_fn(
- features=features, labels=None, mode=estimator_lib.ModeKeys.EVAL)
- metric_update_ops = [
- metric[1] for metric in outputs.eval_metric_ops.values()]
- loss_mean, loss_update = metrics.mean(outputs.loss)
- metric_update_ops.append(loss_update)
- with self.test_session() as sess:
- coordinator = coordinator_lib.Coordinator()
- queue_runner_impl.start_queue_runners(sess, coord=coordinator)
- variables.local_variables_initializer().run()
- sess.run(metric_update_ops)
- loss_evaled, metric_evaled, nested_metric_evaled = sess.run(
- (loss_mean, outputs.eval_metric_ops["ticker"][0],
- outputs.eval_metric_ops[feature_keys.FilteringResults.STATE_TUPLE][
- 0][0]))
- # The custom model_utils metrics for in-sample predictions should be in
- # sync with the Estimator's mean metric for model loss.
- self.assertAllClose(0., loss_evaled)
- self.assertAllClose((((0.,),),), metric_evaled)
- self.assertAllClose((((0.,),),), nested_metric_evaled)
- coordinator.request_stop()
- coordinator.join()
-
-
-class _StubModel(object):
- num_features = 3
- dtype = dtypes.float64
-
- def initialize_graph(self, input_statistics):
- del input_statistics # unused
-
-
-def _stub_model_fn():
- return model_utils.make_model_fn(
- model=_StubModel(),
- state_manager=state_management.PassthroughStateManager(),
- optimizer=train.AdamOptimizer(0.001))
-
-
-class TrainEvalFeatureCheckingTests(test.TestCase):
-
- def test_no_time_feature(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
- with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
- feature_keys.TrainEvalFeatures.TIMES)):
- model_fn(
- features={feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]},
- labels=None,
- mode=mode)
-
- def test_no_value_feature(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
- with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
- feature_keys.TrainEvalFeatures.VALUES)):
- model_fn(
- features={feature_keys.TrainEvalFeatures.TIMES: [[1]]},
- labels=None,
- mode=mode)
-
- def test_bad_time_rank(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
- with self.assertRaisesRegexp(ValueError,
- "Expected shape.*for feature '{}'".format(
- feature_keys.TrainEvalFeatures.TIMES)):
- model_fn(
- features={
- feature_keys.TrainEvalFeatures.TIMES: [[[1]]],
- feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]
- },
- labels=None,
- mode=mode)
-
- def test_bad_value_rank(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
- with self.assertRaisesRegexp(ValueError,
- "Expected shape.*for feature '{}'".format(
- feature_keys.TrainEvalFeatures.VALUES)):
- model_fn(
- features={
- feature_keys.TrainEvalFeatures.TIMES: [[1]],
- feature_keys.TrainEvalFeatures.VALUES: [[1.]]
- },
- labels=None,
- mode=mode)
-
- def test_bad_value_num_features(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
- with self.assertRaisesRegexp(
- ValueError, "Expected shape.*, 3.*for feature '{}'".format(
- feature_keys.TrainEvalFeatures.VALUES)):
- model_fn(
- features={
- feature_keys.TrainEvalFeatures.TIMES: [[1]],
- feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]
- },
- labels=None,
- mode=mode)
-
- def test_bad_exogenous_shape(self):
- model_fn = _stub_model_fn()
- for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:
- with self.assertRaisesRegexp(
- ValueError,
- "Features must have shape.*for feature 'exogenous'"):
- model_fn(
- features={
- feature_keys.TrainEvalFeatures.TIMES: [[1]],
- feature_keys.TrainEvalFeatures.VALUES: [[[1., 2., 3.]]],
- "exogenous": [[1], [2]]
- },
- labels=None,
- mode=mode)
-
-
-class PredictFeatureCheckingTests(test.TestCase):
-
- def test_no_time_feature(self):
- model_fn = _stub_model_fn()
- with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
- feature_keys.PredictionFeatures.TIMES)):
- model_fn(
- features={
- feature_keys.PredictionFeatures.STATE_TUPLE: ([[[1.]]], 1.)
- },
- labels=None,
- mode=estimator_lib.ModeKeys.PREDICT)
-
- def test_no_start_state_feature(self):
- model_fn = _stub_model_fn()
- with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format(
- feature_keys.PredictionFeatures.STATE_TUPLE)):
- model_fn(
- features={feature_keys.PredictionFeatures.TIMES: [[1]]},
- labels=None,
- mode=estimator_lib.ModeKeys.PREDICT)
-
- def test_bad_time_rank(self):
- model_fn = _stub_model_fn()
- with self.assertRaisesRegexp(ValueError,
- "Expected shape.*for feature '{}'".format(
- feature_keys.PredictionFeatures.TIMES)):
- model_fn(
- features={
- feature_keys.PredictionFeatures.TIMES: 1,
- feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.))
- },
- labels=None,
- mode=estimator_lib.ModeKeys.PREDICT)
-
- def test_bad_exogenous_shape(self):
- model_fn = _stub_model_fn()
- with self.assertRaisesRegexp(
- ValueError,
- "Features must have shape.*for feature 'exogenous'"):
- model_fn(
- features={
- feature_keys.PredictionFeatures.TIMES: [[1]],
- feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.)),
- "exogenous": 1.
- },
- labels=None,
- mode=estimator_lib.ModeKeys.PREDICT)
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py b/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py
index 16e29f5e68..97f6d36a87 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py
@@ -23,6 +23,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.timeseries.python.timeseries import feature_keys as _feature_keys
+from tensorflow.contrib.timeseries.python.timeseries import head as _head
from tensorflow.contrib.timeseries.python.timeseries import input_pipeline as _input_pipeline
from tensorflow.contrib.timeseries.python.timeseries import model_utils as _model_utils
@@ -34,7 +35,7 @@ def _colate_features_to_feeds_and_fetches(continue_from, signature, features,
"""Uses a saved model signature to construct feed and fetch dictionaries."""
if _feature_keys.FilteringResults.STATE_TUPLE in continue_from:
# We're continuing from an evaluation, so we need to unpack/flatten state.
- state_values = _model_utils.state_to_dictionary(
+ state_values = _head.state_to_dictionary(
continue_from[_feature_keys.FilteringResults.STATE_TUPLE])
else:
state_values = continue_from
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index eb66d8e329..f3e43dd552 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1773,6 +1773,7 @@ tf_cuda_library(
) + if_mkl(
[
"//third_party/mkl:intel_binary_blob",
+ "@mkl_dnn//:mkl_dnn",
],
),
alwayslink = 1,
@@ -1933,7 +1934,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/visitable_allocator.h",
"graph/gradients.h",
"graph/quantize_training.h",
-]
+] + if_mkl(["graph/mkl_graph_util.h"])
tf_cuda_library(
name = "core_cpu_impl",
@@ -2034,7 +2035,10 @@ tf_cuda_library(
"//third_party/eigen3",
"//tensorflow/core/kernels:required",
] + if_mkl(
- ["//third_party/mkl:intel_binary_blob"],
+ [
+ "//third_party/mkl:intel_binary_blob",
+ "@mkl_dnn//:mkl_dnn",
+ ],
) + tf_additional_core_deps() + if_static([":core_cpu_impl"]),
alwayslink = 1,
)
@@ -2670,7 +2674,7 @@ tf_cc_test_mkl(
"graph/mkl_layout_pass_test.cc",
"graph/mkl_tfconversion_pass_test.cc",
],
- linkstatic = tf_kernel_tests_linkstatic(),
+ linkstatic = 1,
deps = [
":core",
":core_cpu",
@@ -2688,18 +2692,6 @@ tf_cc_test_mkl(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/core/kernels:mkl_aggregate_ops",
- "//tensorflow/core/kernels:mkl_concat_op",
- "//tensorflow/core/kernels:mkl_conv_op",
- "//tensorflow/core/kernels:mkl_cwise_ops_common",
- "//tensorflow/core/kernels:mkl_fused_batch_norm_op",
- "//tensorflow/core/kernels:mkl_identity_op",
- "//tensorflow/core/kernels:mkl_input_conversion_op",
- "//tensorflow/core/kernels:mkl_lrn_op",
- "//tensorflow/core/kernels:mkl_pooling_ops",
- "//tensorflow/core/kernels:mkl_relu_op",
- "//tensorflow/core/kernels:mkl_reshape_op",
- "//tensorflow/core/kernels:mkl_tfconv_op",
"//tensorflow/core/kernels:ops_util",
"//third_party/eigen3",
],
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
new file mode 100644
index 0000000000..cb32d64334
--- /dev/null
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -0,0 +1,128 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
+#define TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
+#ifdef INTEL_MKL
+
+#include <string>
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+// Since our ops are going to produce and also consume N additional (Mkl)
+// tensors for N Tensorflow tensors, we can have the following different
+// orderings among these 2N tensors.
+//
+// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
+// consume A_m, B_m, and C_m additionally.
+//
+// INTERLEAVED: in this case the 2N tensors are interleaved. So for the above
+//              example, the ordering looks like: A, A_m, B, B_m, C, C_m.
+//
+// CONTIGUOUS: in this case the N Tensorflow tensors are laid out contiguously,
+//             followed by the N Mkl tensors. So for the above example, the
+//             ordering looks like: A, B, C, A_m, B_m, C_m.
+//
+// The following APIs map the index of an original Tensorflow tensor to its
+// appropriate position based on the selected ordering. For contiguous
+// ordering, we need to know the total number of tensors (parameter total).
+//
+typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
+// NOTE: Currently, we use contiguous ordering. If you change this, then you
+// would need to change Mkl op definitions in nn_ops.cc.
+static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
+
+// Get index of MetaData tensor from index 'n' of Data tensor.
+inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
+ if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+ // For interleaved ordering, Mkl tensor follows immediately after
+ // Tensorflow tensor.
+ return n + 1;
+ } else {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+    // For contiguous ordering, the Mkl tensor is at index n + total_tensors / 2.
+ return n + total_tensors / 2;
+ }
+}
+
+int inline GetTensorDataIndex(int n, int total_tensors) {
+ if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+ return 2 * n; // index corresponding to nth input/output tensor
+ } else {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ return n;
+ }
+}
+
+int inline GetTensorMetaDataIndex(int n, int total_tensors) {
+ // Get index for TensorData first and then use mapping function
+ // to get TensorMetaData index from TensorData index.
+ int tidx = GetTensorDataIndex(n, total_tensors);
+ return DataIndexToMetaDataIndex(tidx, total_tensors);
+}
+
+namespace mkl_op_registry {
+static const char* kMklOpLabel = "MklOp";
+static const char* kMklOpLabelPattern = "label='MklOp'";
+
+// Get the name of the Mkl op from the original TensorFlow op name.
+// We prefix the original op name with '_Mkl' to get the Mkl op name.
+inline string GetMklOpName(const string& name) {
+ // Prefix that we add to Tensorflow op name to construct Mkl op name.
+ const char* const kMklOpPrefix = "_Mkl";
+ return string(kMklOpPrefix) + name;
+}
+
+// Check whether opname with type T is registered as MKL-compliant.
+//
+// @input: name of the op
+// @input: T datatype to be used for checking op
+// @return: true if opname is registered as Mkl op; false otherwise
+static inline bool IsMklOp(const std::string& op_name, DataType T) {
+ string kernel = KernelsRegisteredForOp(op_name);
+ bool result =
+ kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
+ if (result) {
+ VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel;
+ }
+ return result;
+}
+
+// Check whether opname with type T is registered as MKL-compliant and
+// is element-wise.
+//
+// @input: name of the op
+// @input: T datatype to be used for checking op
+// @return: true if opname is registered as element-wise Mkl op;
+// false otherwise
+static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) {
+ if (!IsMklOp(op_name, T)) {
+ return false;
+ }
+
+ bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
+ 0 == op_name.compare(GetMklOpName("Sub")) ||
+ 0 == op_name.compare(GetMklOpName("Mul")) ||
+ 0 == op_name.compare(GetMklOpName("Maximum")) ||
+ 0 == op_name.compare(GetMklOpName("SquaredDifference")));
+
+ VLOG(1) << "mkl_op_registry::" << op_name
+ << " is elementwise MKL op: " << result;
+ return result;
+}
+} // namespace mkl_op_registry
+} // namespace tensorflow
+#endif // INTEL_MKL
+#endif // TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
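A quick worked example of the index mapping above, written as a plain-Python mirror of GetTensorDataIndex and GetTensorMetaDataIndex (illustrative only; the authoritative helpers are the C++ functions in this header):

    # Mirrors the C++ helpers for total_tensors == 6 (tensors A, B, C plus
    # their Mkl metadata tensors A_m, B_m, C_m).
    def data_index(n, total_tensors, interleaved=False):
        return 2 * n if interleaved else n

    def metadata_index(n, total_tensors, interleaved=False):
        d = data_index(n, total_tensors, interleaved)
        return d + 1 if interleaved else d + total_tensors // 2

    # Contiguous ordering A, B, C, A_m, B_m, C_m: B is at 1, B_m at 1 + 3 == 4.
    assert data_index(1, 6) == 1 and metadata_index(1, 6) == 4
    # Interleaved ordering A, A_m, B, B_m, C, C_m: B is at 2, B_m at 3.
    assert data_index(1, 6, interleaved=True) == 2
    assert metadata_index(1, 6, interleaved=True) == 3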
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 90377e54c7..f87a94a76a 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -37,8 +37,8 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/mkl_layout_pass.h"
-#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 6a41e3965a..a2b2f6530d 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#ifdef INTEL_MKL
#include "tensorflow/core/graph/mkl_layout_pass.h"
-#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
#include <algorithm>
#include <string>
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index 3f8b0e86d0..fe4588389e 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -33,8 +33,8 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
-#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index b01818f746..bbdbe78bbd 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#ifdef INTEL_MKL
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
-#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
#include <algorithm>
#include <string>
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 36fbf6b023..bdc6faefbc 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -820,6 +820,7 @@ tf_kernel_library(
hdrs = ["transpose_op.h"],
deps = ARRAY_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
+ "@mkl_dnn//:mkl_dnn",
]),
)
@@ -2596,6 +2597,7 @@ tf_kernel_library(
"//conditions:default": [],
}) + if_mkl([
"//third_party/mkl:intel_binary_blob",
+ "@mkl_dnn//:mkl_dnn",
]) + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]),
@@ -5501,8 +5503,10 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
+ ] + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
@@ -5516,8 +5520,10 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
+ ] + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
@@ -5566,16 +5572,19 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
+ ] + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
name = "mkl_fused_batch_norm_op",
srcs = ["mkl_fused_batch_norm_op.cc"],
- deps = NN_DEPS + [
+ deps = NN_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
@@ -5589,9 +5598,10 @@ tf_mkl_kernel_library(
tf_mkl_kernel_library(
name = "mkl_concat_op",
prefix = "mkl_concat_op",
- deps = ARRAY_DEPS + [
+ deps = ARRAY_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
@@ -5605,17 +5615,19 @@ tf_mkl_kernel_library(
tf_mkl_kernel_library(
name = "mkl_identity_op",
prefix = "mkl_identity_op",
- deps = ARRAY_DEPS + [
+ deps = ARRAY_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
name = "mkl_lrn_op",
prefix = "mkl_lrn_op",
- deps = NN_DEPS + [
+ deps = NN_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
- ],
+ "@mkl_dnn//:mkl_dnn",
+ ]),
)
tf_mkl_kernel_library(
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 1bdfafb89b..368993c827 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -39,6 +39,48 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
+namespace {
+
+void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
+ int32* batch, int32* height, int32* width,
+ int32* channel) {
+ *batch = 1;
+ *width = 1;
+ *height = 1;
+ *channel = 1;
+ if (data_format == FORMAT_NHWC) {
+ int32 channel_dim = value_tensor.dims() - 1;
+ *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
+ for (int32 i = 0; i < channel_dim; i++) {
+ *batch *= static_cast<int32>(value_tensor.dim_size(i));
+ }
+ } else if (data_format == FORMAT_NCHW) {
+ int32 channel_dim = value_tensor.dims() - 3;
+ int32 height_dim = value_tensor.dims() - 2;
+ int32 width_dim = value_tensor.dims() - 1;
+ *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
+ *height = static_cast<int32>(value_tensor.dim_size(height_dim));
+ *width = static_cast<int32>(value_tensor.dim_size(width_dim));
+ for (int32 i = 0; i < channel_dim; i++) {
+ *batch *= static_cast<int32>(value_tensor.dim_size(i));
+ }
+ }
+}
+
+template <class T>
+struct AccumulatorType {
+ typedef T type;
+};
+
+// float is faster on the CPU than half, and also more precise,
+// so use float for the temporary accumulators.
+template <>
+struct AccumulatorType<Eigen::half> {
+ typedef float type;
+};
+
+} // namespace
+
template <typename Device, typename T>
class BiasOp : public BinaryOp<T> {
public:
@@ -50,9 +92,6 @@ class BiasOp : public BinaryOp<T> {
} else {
data_format_ = FORMAT_NHWC;
}
- OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
- errors::InvalidArgument(context->device()->name() +
- " BiasOp only supports NHWC."));
}
void Compute(OpKernelContext* context) override {
@@ -65,9 +104,21 @@ class BiasOp : public BinaryOp<T> {
OP_REQUIRES(context, TensorShapeUtils::IsVector(bias.shape()),
errors::InvalidArgument("Biases must be 1D: ",
bias.shape().DebugString()));
- const auto last_dim = input.shape().dims() - 1;
+
+    // Added by intel_tf to support NCHW on CPU regardless of whether MKL is used.
+ size_t channel_dim;
+ if (data_format_ == FORMAT_NCHW) {
+ OP_REQUIRES(context, input.dims() == 4,
+ errors::InvalidArgument(
+ "NCHW format supports only 4D input tensor."));
+ channel_dim = 1;
+ } else {
+ channel_dim = input.shape().dims() - 1; // End of code by intel_tf.
+ }
+
OP_REQUIRES(
- context, bias.shape().dim_size(0) == input.shape().dim_size(last_dim),
+ context,
+ bias.shape().dim_size(0) == input.shape().dim_size(channel_dim),
errors::InvalidArgument(
"Must provide as many biases as the last dimension "
"of the input tensor: ",
@@ -78,6 +129,19 @@ class BiasOp : public BinaryOp<T> {
{0}, 0, input.shape(), &output));
if (input.NumElements() == 0) return;
+    // Added by intel_tf to support NCHW on CPU regardless of whether MKL is used.
+ if (data_format_ == FORMAT_NCHW) {
+ int32 batch, height, width, channel;
+ GetBiasValueDims(input, data_format_, &batch, &height, &width, &channel);
+ Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1);
+ Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width);
+ const Device& d = context->eigen_device<Device>();
+ output->tensor<T, 4>().device(d) =
+ input.tensor<T, 4>() +
+ bias.tensor<T, 1>().reshape(four_dims).broadcast(broad_cast_dims);
+ return;
+ } // End of code by intel_tf.
+
switch (input.shape().dims()) {
case 2:
Compute<2>(context, input, bias, output);
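The NCHW branch added above is, shape-wise, a standard broadcasted bias add. A NumPy sketch of the same computation (illustrative shapes only; the kernel itself uses Eigen's reshape/broadcast on the device):

    import numpy as np

    batch, channel, height, width = 2, 3, 4, 5
    x = np.random.rand(batch, channel, height, width).astype(np.float32)
    bias = np.random.rand(channel).astype(np.float32)

    # Reshape bias to (1, C, 1, 1) so it broadcasts over batch, height and
    # width, mirroring four_dims / broad_cast_dims in the Eigen expression.
    y = x + bias.reshape(1, channel, 1, 1)
    assert y.shape == (batch, channel, height, width)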
@@ -137,48 +201,6 @@ REGISTER_KERNEL(double);
#undef REGISTER_KERNEL
#endif // TENSORFLOW_USE_SYCL
-namespace {
-
-void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
- int32* batch, int32* height, int32* width,
- int32* channel) {
- *batch = 1;
- *width = 1;
- *height = 1;
- *channel = 1;
- if (data_format == FORMAT_NHWC) {
- int32 channel_dim = value_tensor.dims() - 1;
- *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
- for (int32 i = 0; i < channel_dim; i++) {
- *batch *= static_cast<int32>(value_tensor.dim_size(i));
- }
- } else if (data_format == FORMAT_NCHW) {
- int32 channel_dim = value_tensor.dims() - 3;
- int32 height_dim = value_tensor.dims() - 2;
- int32 width_dim = value_tensor.dims() - 1;
- *channel = static_cast<int32>(value_tensor.dim_size(channel_dim));
- *height = static_cast<int32>(value_tensor.dim_size(height_dim));
- *width = static_cast<int32>(value_tensor.dim_size(width_dim));
- for (int32 i = 0; i < channel_dim; i++) {
- *batch *= static_cast<int32>(value_tensor.dim_size(i));
- }
- }
-}
-
-template <class T>
-struct AccumulatorType {
- typedef T type;
-};
-
-// float is faster on the CPU than half, and also more precise,
-// so use float for the temporary accumulators.
-template <>
-struct AccumulatorType<Eigen::half> {
- typedef float type;
-};
-
-} // namespace
-
template <typename Device, typename T>
class BiasGradOp : public OpKernel {
public:
@@ -190,9 +212,6 @@ class BiasGradOp : public OpKernel {
} else {
data_format_ = FORMAT_NHWC;
}
- OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
- errors::InvalidArgument(context->device()->name() +
- " BiasGradOp only supports NHWC."));
}
void Compute(OpKernelContext* context) override {
@@ -222,18 +241,40 @@ class BiasGradOp : public OpKernel {
// Eigen often crashes by design on empty tensors, but setZero is safe
output->template flat<T>().setZero();
} else {
- Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
+      // Added by intel_tf to support NCHW on CPU regardless of whether MKL is used.
+ if (data_format_ == FORMAT_NCHW) {
+ OP_REQUIRES(context, output_backprop.dims() == 4,
+ errors::InvalidArgument(
+ "NCHW format supports only 4D input/output tensor."));
+ Eigen::DSizes<int, 4> four_dims(batch, channel, height, width);
+#ifdef EIGEN_HAS_INDEX_LIST
+ using idx0 = Eigen::type2index<0>;
+ using idx2 = Eigen::type2index<2>;
+ using idx3 = Eigen::type2index<3>;
+ Eigen::IndexList<idx0, idx2, idx3> reduction_axes;
+#else
+ Eigen::array<int, 3> reduction_axes = {0, 2, 3};
+#endif
+ output->template flat<T>().device(context->eigen_device<Device>()) =
+ output_backprop.flat<T>()
+ .template cast<typename AccumulatorType<T>::type>()
+ .reshape(four_dims)
+ .sum(reduction_axes)
+ .template cast<T>(); // End of code by intel_tf.
+ } else {
+ Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
#ifdef EIGEN_HAS_INDEX_LIST
- Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
+ Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
#else
- Eigen::array<int, 1> reduction_axis = {0};
+ Eigen::array<int, 1> reduction_axis = {0};
#endif
- output->template flat<T>().device(context->eigen_device<Device>()) =
- output_backprop.flat<T>()
- .template cast<typename AccumulatorType<T>::type>()
- .reshape(two_dims)
- .sum(reduction_axis)
- .template cast<T>();
+ output->template flat<T>().device(context->eigen_device<Device>()) =
+ output_backprop.flat<T>()
+ .template cast<typename AccumulatorType<T>::type>()
+ .reshape(two_dims)
+ .sum(reduction_axis)
+ .template cast<T>();
+ }
}
}
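
The NCHW support added to BiasOp and BiasGradOp above is plain Eigen arithmetic: the forward pass reshapes the bias to [1, C, 1, 1] and broadcasts it across the batch and spatial dims, and the backward pass sums the incoming gradient over axes {0, 2, 3}. A minimal standalone sketch of the same arithmetic using bare Eigen tensors (sizes and values below are illustrative, not taken from the kernels):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  const int N = 2, C = 3, H = 4, W = 5;
  Eigen::Tensor<float, 4> input(N, C, H, W);  // NCHW activations
  Eigen::Tensor<float, 1> bias(C);            // one bias value per channel
  input.setRandom();
  bias.setRandom();

  // Forward pass: view the bias as [1, C, 1, 1] and broadcast it to
  // [N, 1, H, W], the same reshape/broadcast pattern as the FORMAT_NCHW
  // branch added to BiasOp.
  Eigen::DSizes<int, 4> four_dims(1, C, 1, 1);
  Eigen::DSizes<int, 4> bcast_dims(N, 1, H, W);
  Eigen::Tensor<float, 4> output =
      input + bias.reshape(four_dims).broadcast(bcast_dims);

  // Backward pass: the bias gradient sums the incoming gradient over the
  // batch and spatial axes {0, 2, 3}, leaving one value per channel, as in
  // the FORMAT_NCHW branch added to BiasGradOp.
  Eigen::array<int, 3> reduction_axes{{0, 2, 3}};
  Eigen::Tensor<float, 1> bias_grad = output.sum(reduction_axes);

  // The difference at any element equals the bias of its channel.
  std::cout << output(1, 2, 3, 4) - input(1, 2, 3, 4) << " == " << bias(2)
            << ", grad entries: " << bias_grad.dimension(0) << std::endl;
  return 0;
}
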
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 641077ca65..5e09963d2d 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -816,40 +816,35 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- CudnnScratchAllocator scratch_allocator(
- ConvolveBackwardFilterScratchSize, ctx);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveBackwardFilterWithAlgorithm(
- input_desc, input_ptr, output_desc, out_backprop_ptr,
- conv_desc, filter_desc, &filter_backprop_ptr,
- &scratch_allocator, AlgorithmConfig(profile_algorithm),
- &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+      // TODO(zhengxq): profile each algorithm multiple times for better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
+ ctx);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardFilterWithAlgorithm(
+ input_desc, input_ptr, output_desc, out_backprop_ptr,
+ conv_desc, filter_desc, &filter_backprop_ptr,
+ &scratch_allocator, AlgorithmConfig(profile_algorithm),
+ &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
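
The same autotuning rewrite appears in every cuDNN convolution kernel touched by this commit (backprop filter, backprop input, the 3D variants, and the forward ops): the profiling loop now walks the AlgorithmDesc list returned by the stream executor directly instead of rebuilding descriptors from an index plus a tensor-op flag, while the selection of the fastest result and the fastest scratch-free result is unchanged. A simplified, self-contained sketch of that selection step, with a hypothetical Result struct standing in for ProfileResult:

#include <cstddef>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

// Stand-in for ProfileResult: one timed run of one candidate algorithm.
struct Result {
  bool valid;
  float elapsed_ms;
  size_t scratch_bytes;
  std::string algo;
};

// Keep the fastest run and the fastest run that needed no scratch memory,
// mirroring best_result / best_result_no_scratch in the autotuning loops.
void PickBest(const std::vector<Result>& runs, Result* best,
              Result* best_no_scratch) {
  for (const Result& r : runs) {
    if (!r.valid) continue;  // launch failed or produced no timing
    if (r.elapsed_ms < best->elapsed_ms) *best = r;
    if (r.scratch_bytes == 0 && r.elapsed_ms < best_no_scratch->elapsed_ms) {
      *best_no_scratch = r;
    }
  }
}

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  std::vector<Result> runs = {{true, 1.8f, 4096, "implicit_gemm"},
                              {true, 1.2f, 1 << 20, "winograd"},
                              {true, 2.5f, 0, "direct"}};
  Result best{false, inf, 0, ""};
  Result best_no_scratch{false, inf, 0, ""};
  PickBest(runs, &best, &best_no_scratch);
  std::cout << best.algo << " / " << best_no_scratch.algo << std::endl;
  // Prints: winograd / direct
  return 0;
}
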
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 0732bf4046..0b2d01afa9 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -870,39 +870,34 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
- ctx);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveBackwardDataWithAlgorithm(
- filter_desc, filter_ptr, output_desc, out_backprop_ptr,
- conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
- AlgorithmConfig(profile_algorithm), &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+      // TODO(zhengxq): profile each algorithm multiple times for better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
+ ctx);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardDataWithAlgorithm(
+ filter_desc, filter_ptr, output_desc, out_backprop_ptr,
+ conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
+ AlgorithmConfig(profile_algorithm), &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 8ad56053a8..21f5cb1716 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -654,40 +654,34 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- CudnnScratchAllocator scratch_allocator(
- ConvolveBackwardDataScratchSize, context);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveBackwardDataWithAlgorithm(
- filter_desc, filter_ptr, output_desc, out_backprop_ptr,
- conv_desc, input_desc, &in_backprop_ptr,
- &scratch_allocator, AlgorithmConfig(profile_algorithm),
- &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+      // TODO(zhengxq): profile each algorithm multiple times for better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
+ context);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardDataWithAlgorithm(
+ filter_desc, filter_ptr, output_desc, out_backprop_ptr,
+ conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
+ AlgorithmConfig(profile_algorithm), &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
@@ -1026,40 +1020,35 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- CudnnScratchAllocator scratch_allocator(
- ConvolveBackwardFilterScratchSize, context);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveBackwardFilterWithAlgorithm(
- input_desc, input_ptr, output_desc, out_backprop_ptr,
- conv_desc, filter_desc, &filter_backprop_ptr,
- &scratch_allocator, AlgorithmConfig(profile_algorithm),
- &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+      // TODO(zhengxq): profile each algorithm multiple times for better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(
+ ConvolveBackwardFilterScratchSize, context);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveBackwardFilterWithAlgorithm(
+ input_desc, input_ptr, output_desc, out_backprop_ptr,
+ conv_desc, filter_desc, &filter_backprop_ptr,
+ &scratch_allocator, AlgorithmConfig(profile_algorithm),
+ &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index dc03eeb658..bb67113fb0 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -662,38 +662,33 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune &&
!AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveWithAlgorithm(
- input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
- output_desc, &output_ptr, &scratch_allocator,
- AlgorithmConfig(profile_algorithm), &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+      // TODO(zhengxq): profile each algorithm multiple times for better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveWithAlgorithm(
+ input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
+ output_desc, &output_ptr, &scratch_allocator,
+ AlgorithmConfig(profile_algorithm), &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 72758f707a..8a89d564de 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -390,38 +390,33 @@ struct LaunchConvOp<GPUDevice, T> {
if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmDesc::Index> algorithms;
+ std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(), &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
- // TODO(benbarsdell): Ideally this should not attempt using tensor op math
- // if it's not enabled.
- for (bool use_tensor_ops : {false, true}) {
- for (auto algo_index : algorithms) {
- AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops);
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveWithAlgorithm(
- input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
- output_desc, &output_ptr, &scratch_allocator,
- AlgorithmConfig(profile_algorithm), &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ for (auto profile_algorithm : algorithms) {
+      // TODO(zhengxq): profile each algorithm multiple times for better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveWithAlgorithm(
+ input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
+ output_desc, &output_ptr, &scratch_allocator,
+ AlgorithmConfig(profile_algorithm), &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc
index 42ea23553b..5e48ae9766 100644
--- a/tensorflow/core/kernels/decode_csv_op.cc
+++ b/tensorflow/core/kernels/decode_csv_op.cc
@@ -36,8 +36,8 @@ class DecodeCSVOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_quote_delim", &use_quote_delim_));
OP_REQUIRES(ctx, delim.size() == 1,
errors::InvalidArgument("field_delim should be only 1 char"));
-
delim_ = delim[0];
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("na_value", &na_value_));
}
void Compute(OpKernelContext* ctx) override {
@@ -79,9 +79,9 @@ class DecodeCSVOp : public OpKernel {
const DataType& dtype = out_type_[f];
switch (dtype) {
case DT_INT32: {
- // If this field is empty, check if default is given:
+            // If this field is empty or the NA value, check if default is given:
// If yes, use default value; Otherwise report error.
- if (fields[f].empty()) {
+ if (fields[f].empty() || fields[f] == na_value_) {
OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
errors::InvalidArgument(
"Field ", f,
@@ -99,9 +99,9 @@ class DecodeCSVOp : public OpKernel {
break;
}
case DT_INT64: {
- // If this field is empty, check if default is given:
+            // If this field is empty or the NA value, check if default is given:
// If yes, use default value; Otherwise report error.
- if (fields[f].empty()) {
+ if (fields[f].empty() || fields[f] == na_value_) {
OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
errors::InvalidArgument(
"Field ", f,
@@ -119,9 +119,9 @@ class DecodeCSVOp : public OpKernel {
break;
}
case DT_FLOAT: {
- // If this field is empty, check if default is given:
+            // If this field is empty or the NA value, check if default is given:
// If yes, use default value; Otherwise report error.
- if (fields[f].empty()) {
+ if (fields[f].empty() || fields[f] == na_value_) {
OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
errors::InvalidArgument(
"Field ", f,
@@ -138,9 +138,9 @@ class DecodeCSVOp : public OpKernel {
break;
}
case DT_STRING: {
- // If this field is empty, check if default is given:
+            // If this field is empty or the NA value, check if default is given:
// If yes, use default value; Otherwise report error.
- if (fields[f].empty()) {
+ if (fields[f].empty() || fields[f] == na_value_) {
OP_REQUIRES(ctx, record_defaults[f].NumElements() == 1,
errors::InvalidArgument(
"Field ", f,
@@ -165,6 +165,7 @@ class DecodeCSVOp : public OpKernel {
std::vector<DataType> out_type_;
char delim_;
bool use_quote_delim_;
+ string na_value_;
void ExtractFields(OpKernelContext* ctx, StringPiece input,
std::vector<string>* result) {
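
The DecodeCSV change above makes a field that matches the new na_value attribute behave like an empty field: use the per-column default when one exists, otherwise raise InvalidArgument. A rough standalone sketch of that fallback for a single int32 column (the ParseInt32Field helper below is hypothetical and illustrative only; the real op works on record_defaults tensors per dtype):

#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical helper: parse one CSV field as int32. An empty field, or one
// equal to na_value, falls back to default_value when has_default is set and
// is an error otherwise -- the behaviour the patch adds to DecodeCSVOp.
int ParseInt32Field(const std::string& field, const std::string& na_value,
                    bool has_default, int default_value) {
  if (field.empty() || field == na_value) {
    if (!has_default) {
      throw std::invalid_argument("field is empty/NA and no default is given");
    }
    return default_value;
  }
  return std::stoi(field);
}

int main() {
  const std::string na = "NA";
  std::cout << ParseInt32Field("42", na, true, -1) << std::endl;  // 42
  std::cout << ParseInt32Field("", na, true, -1) << std::endl;    // -1
  std::cout << ParseInt32Field("NA", na, true, -1) << std::endl;  // -1
  return 0;
}
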
diff --git a/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc
index 25a6813d59..0174c8dfc8 100644
--- a/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/dense_to_sparse_batch_dataset_op.cc
@@ -49,10 +49,10 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, ctx->input("row_shape", &row_shape_t));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(row_shape_t->shape()),
errors::InvalidArgument("row_shape must be a vector"));
- TensorShape row_shape;
- for (size_t i = 0; i < row_shape_t->dim_size(0); ++i) {
- row_shape.AddDim(row_shape_t->vec<int64>()(i));
- }
+ PartialTensorShape row_shape;
+ OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape(
+ row_shape_t->vec<int64>().data(),
+ row_shape_t->NumElements(), &row_shape));
*output = nullptr;
@@ -78,7 +78,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
template <class T>
class Dataset : public DatasetBase {
public:
- Dataset(int64 batch_size, const TensorShape& row_shape,
+ Dataset(int64 batch_size, const PartialTensorShape& row_shape,
const DatasetBase* input)
: batch_size_(batch_size), row_shape_(row_shape), input_(input) {
input_->Ref();
@@ -129,9 +129,22 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
int64 total_elements = 0;
batch_elements.reserve(
DatasetIterator<Dataset<T>>::dataset()->batch_size_);
- const TensorShape& row_shape =
+ const PartialTensorShape& row_shape =
DatasetIterator<Dataset<T>>::dataset()->row_shape_;
const int row_ndims = row_shape.dims();
+
+ // Determine the size of the output tensors:
+ // * dense_shape will be [`row_shape + 1`].
+ Tensor dense_shape(cpu_allocator(), DT_INT64, {row_ndims + 1});
+ auto dense_shape_vec = dense_shape.vec<int64>();
+ for (size_t i = 0; i < row_ndims; ++i) {
+ if (row_shape.dim_size(i) == -1) {
+ dense_shape_vec(i + 1) = 0;
+ } else {
+ dense_shape_vec(i + 1) = row_shape.dim_size(i);
+ }
+ }
+
{
mutex_lock l(mu_);
*end_of_sequence = false;
@@ -156,9 +169,14 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
") that is incompatible with the row shape (",
row_shape.DebugString(), ").");
}
- for (int i = 0; i < row_ndims; ++i) {
- if (batch_element_tuple[0].shape().dim_size(i) >
- row_shape.dim_size(i)) {
+ for (int j = 0; j < row_ndims; ++j) {
+ // Take the maximum in the dimension if -1 is given.
+ if (row_shape.dim_size(j) == -1) {
+ dense_shape_vec(j + 1) =
+ std::max(batch_element_tuple[0].dim_size(j),
+ dense_shape_vec(j + 1));
+ } else if (batch_element_tuple[0].dim_size(j) >
+ row_shape.dim_size(j)) {
return errors::DataLoss(
"Input element had shape (",
batch_element_tuple[0].shape().DebugString(),
@@ -175,20 +193,16 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- // Determine the size of the output tensors:
// * indices will be [`total_elements`, `row_shape + 1`].
// * values will be [`total_elements`].
- // * dense_shape will be [`row_shape + 1`].
Tensor indices(cpu_allocator(), DT_INT64,
{total_elements, row_ndims + 1});
Tensor values(
cpu_allocator(),
DatasetIterator<Dataset<T>>::dataset()->output_dtypes()[1],
{total_elements});
- Tensor dense_shape(cpu_allocator(), DT_INT64, {row_ndims + 1});
auto indices_matrix = indices.matrix<int64>();
auto values_flat = values.flat<T>();
- auto dense_shape_vec = dense_shape.vec<int64>();
int64 current_position_in_values = 0;
for (int64 i = 0; i < batch_elements.size(); ++i) {
@@ -220,9 +234,6 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
}
dense_shape_vec(0) = batch_elements.size();
- for (size_t i = 0; i < row_ndims; ++i) {
- dense_shape_vec(i + 1) = row_shape.dim_size(i);
- }
out_tensors->push_back(std::move(indices));
out_tensors->push_back(std::move(values));
@@ -239,7 +250,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
};
const int64 batch_size_;
- const TensorShape row_shape_;
+ const PartialTensorShape row_shape_;
const DatasetBase* const input_;
std::vector<PartialTensorShape> output_shapes_;
};
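
Switching row_shape to a PartialTensorShape lets a dimension be -1: the iterator then grows the corresponding dense_shape entry to the largest size seen in the batch, while fixed dimensions still reject larger elements. A small sketch of that shape bookkeeping over plain vectors, assuming batch elements are represented by their shapes only:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Compute the dense_shape of one sparse batch: entry 0 is the batch size and
// entry i+1 comes from row_shape[i]. A -1 ("unknown") dim grows to the
// largest size seen in the batch; a fixed dim rejects larger elements.
std::vector<int64_t> DenseShape(
    const std::vector<int64_t>& row_shape,
    const std::vector<std::vector<int64_t>>& element_shapes) {
  std::vector<int64_t> dense_shape(row_shape.size() + 1, 0);
  dense_shape[0] = static_cast<int64_t>(element_shapes.size());
  for (size_t i = 0; i < row_shape.size(); ++i) {
    dense_shape[i + 1] = row_shape[i] == -1 ? 0 : row_shape[i];
  }
  for (const auto& shape : element_shapes) {
    for (size_t i = 0; i < row_shape.size(); ++i) {
      if (row_shape[i] == -1) {
        dense_shape[i + 1] = std::max(dense_shape[i + 1], shape[i]);
      } else if (shape[i] > row_shape[i]) {
        throw std::invalid_argument("element does not fit in row_shape");
      }
    }
  }
  return dense_shape;
}

int main() {
  // row_shape [-1, 3]: dim 0 is padded to the longest element, dim 1 is fixed.
  for (int64_t d : DenseShape({-1, 3}, {{2, 3}, {5, 2}, {4, 3}})) {
    std::cout << d << " ";  // prints: 3 5 3
  }
  std::cout << std::endl;
  return 0;
}
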
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index f81a448e51..9080bf7be8 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -41,10 +42,24 @@ limitations under the License.
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+
+using mkldnn::prop_kind;
+using mkldnn::stream;
+
+using mkldnn::convolution_backward_weights;
+using mkldnn::convolution_direct;
+using mkldnn::convolution_forward;
+
+#endif
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifndef INTEL_MKL_DNN
+
template <typename Device, class T>
class MklConv2DCustomBackpropFilterOp : public OpKernel {
public:
@@ -411,6 +426,172 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
TensorFormat data_format_;
};
+#else
+
+template <typename Device, class T>
+class MklConv2DCustomBackpropFilterOp : public OpKernel {
+ public:
+ explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ int stride_n = GetTensorDim(strides_, data_format_, 'N');
+ int stride_c = GetTensorDim(strides_, data_format_, 'C');
+ OP_REQUIRES(
+ context, (stride_n == 1 && stride_c == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ try {
+ auto cpu_engine = engine(engine::cpu, 0);
+
+ MklDnnData<T> input(&cpu_engine);
+ MklDnnData<T> outbackprop(&cpu_engine);
+ MklDnnData<T> output(&cpu_engine);
+
+ // Input tensors
+ const Tensor& input_tensor = MklGetInput(context, 0);
+ const Tensor& filter_tensor = MklGetInput(context, 1);
+ const Tensor& obp_tensor = MklGetInput(context, 2); // Outbackprop
+
+ // Generate input shapes.
+ TensorShape filter_shape;
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(filter_tensor.shape()),
+ errors::InvalidArgument(
+ "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
+ filter_tensor.dims()));
+ OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+ filter_tensor.vec<int32>(), &filter_shape));
+ TensorShape input_shape = input_tensor.shape();
+ TensorShape obp_shape = obp_tensor.shape();
+
+      // By default, all dims are in MKL order. The only dims in TF order
+      // are those with the tf_order prefix.
+ memory::dims obp_dims, fwd_input_dims, fwd_filter_dims;
+ memory::dims padding_l, padding_r, strides, fwd_output_dims;
+ memory::dims fwd_output_dims_tf_order;
+
+ // Get forward convolution parameters.
+ MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
+ conv_utl.GetConvFwdSizesInMklOrder(
+ input_shape, filter_shape, &fwd_input_dims, &fwd_filter_dims,
+ &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l,
+ &padding_r);
+ if (!context->status().ok()) return;
+
+ // Create Convolution forward descriptor since Convolution backward
+ // API needs it. For that, we first need to create input, filter
+ // and output memory descriptors.
+ auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_);
+ auto fwd_src_md =
+ memory::desc(fwd_input_dims, MklDnnType<T>(), mkl_data_format);
+ auto fwd_filter_md =
+ memory::desc(fwd_filter_dims, MklDnnType<T>(), memory::format::hwio);
+ auto fwd_out_md =
+ memory::desc(fwd_output_dims, MklDnnType<T>(), mkl_data_format);
+ auto fwd_desc = convolution_forward::desc(
+ prop_kind::forward, convolution_direct, fwd_src_md, fwd_filter_md,
+ fwd_out_md, strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
+ auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
+
+ // Allocate output tensor and shape
+ // TODO(nhasabni): Update this when support for MKL layout is added.
+      // Shape of output of Conv2DBackpropFilter is same as the 'filter' of Conv2D.
+ TensorShape tf_output_shape(filter_shape);
+ MklShape mkl_output_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(false);
+ Tensor* output_tensor = nullptr;
+ AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
+ mkl_output_mkl_shape);
+
+ // Create memory for user data.
+      // Describe what the inputs and outputs of Convolution look like. Also
+ // specify buffers containing actual input and output data.
+ // Although input shape required is in MKL-DNN order, the layout is
+ // Tensorflow's layout (NHWC or NCHW depending on data format).
+ input.SetUsrMem(fwd_input_dims, mkl_data_format, &input_tensor);
+      // Outbackprop shape is NHWC or NCHW depending on data format. Since the
+      // GetInputSizeInMklOrder function returns the size in that order, we
+      // just use that function directly.
+ conv_utl.GetInputSizeInMklOrder(obp_shape, &obp_dims);
+ if (!context->status().ok()) return;
+ outbackprop.SetUsrMem(obp_dims, mkl_data_format, &obp_tensor);
+ // Although output shape required is in MKL-DNN order,
+ // layout is Tensorflow's filter layout (HWIO)
+      // Shape of output of Conv2DBackpropFilter is same as shape of filter.
+ memory::dims bwd_output_dims = fwd_filter_dims;
+ output.SetUsrMem(bwd_output_dims, memory::format::hwio, output_tensor);
+
+ // Create memory descriptors for convolution data w/ no specified format.
+ input.SetOpMemDesc(fwd_input_dims, memory::format::any);
+ outbackprop.SetOpMemDesc(obp_dims, memory::format::any);
+ output.SetOpMemDesc(bwd_output_dims, memory::format::any);
+
+ // Create convolution backward weights primitive.
+ auto bwd_desc = convolution_backward_weights::desc(
+ convolution_direct, input.GetOpMemDesc(), output.GetOpMemDesc(),
+ outbackprop.GetOpMemDesc(), strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
+
+ auto bwd_pd = convolution_backward_weights::primitive_desc(
+ bwd_desc, cpu_engine, fwd_pd);
+
+ PrepareAndExecutePrimitive(bwd_pd, &input, &outbackprop, &output);
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ TensorFormat data_format_;
+
+ // Prepare and execute net - checks for input and output reorders.
+ void PrepareAndExecutePrimitive(
+ const convolution_backward_weights::primitive_desc& conv_pd,
+ MklDnnData<T>* input, MklDnnData<T>* obp, MklDnnData<T>* output) {
+ // Create reorders between user layout and MKL layout if it is needed and
+ // add it to the net before convolution.
+ std::vector<primitive> net;
+ input->CheckReorderToOpMem(conv_pd.src_primitive_desc(), &net);
+ obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);
+
+ // Memory for output of convolution. Since we may need reorder on the
+ // output side, we will prepare reorder primitive in case output
+ // reorder to user memory is required.
+ bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
+ conv_pd.diff_weights_primitive_desc());
+
+ net.push_back(convolution_backward_weights(
+ conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem()));
+
+ // Insert reorder primitive in the net for output reorder if reorder is
+ // required.
+ if (output_reorder_required) {
+ output->InsertReorderToUserMem(&net);
+ }
+
+ // Handle output reorder
+ stream(stream::kind::eager).submit(net).wait();
+ }
+};
+#endif
+
#define REGISTER_MKL_FILTER_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
.Device(DEVICE_CPU) \
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 00884d0981..4b6bf92e42 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -23,6 +23,8 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <algorithm>
#include <vector>
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -30,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -40,13 +43,24 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
+
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+
+using mkldnn::prop_kind;
+using mkldnn::stream;
+
+using mkldnn::convolution_backward_data;
+using mkldnn::convolution_direct;
+using mkldnn::convolution_forward;
+#endif
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifndef INTEL_MKL_DNN
+
template <typename Device, class T>
class MklConv2DCustomBackpropInputOp : public OpKernel {
public:
@@ -345,6 +359,178 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
TensorFormat data_format;
};
+#else
+
+template <typename Device, class T>
+class MklConv2DCustomBackpropInputOp : public OpKernel {
+ public:
+ ~MklConv2DCustomBackpropInputOp() {}
+ explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ string data_format_str;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
+ OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ int stride_n = GetTensorDim(strides_, data_format_, 'N');
+ int stride_c = GetTensorDim(strides_, data_format_, 'C');
+ OP_REQUIRES(
+ context, (stride_n == 1 && stride_c == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ try {
+ auto cpu_engine = engine(engine::cpu, 0);
+
+ MklDnnData<T> filter(&cpu_engine);
+ MklDnnData<T> outbackprop(&cpu_engine);
+ MklDnnData<T> output(&cpu_engine);
+
+ // Input tensors
+ const Tensor& input_tensor = MklGetInput(context, 0);
+ const Tensor& filter_tensor = MklGetInput(context, 1);
+ const Tensor& obp_tensor = MklGetInput(context, 2); // Outbackprop
+
+ // Generate input shape.
+ TensorShape input_shape;
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsVector(input_tensor.shape()),
+ errors::InvalidArgument(
+ "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
+ input_tensor.dims()));
+ OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+ input_tensor.vec<int32>(), &input_shape));
+ TensorShape filter_shape = filter_tensor.shape();
+ TensorShape obp_shape = obp_tensor.shape();
+
+      // By default, all dims are in MKL order. The only dims in TF order
+      // are those with the tf_order prefix.
+ memory::dims obp_dims, fwd_input_dims, fwd_filter_dims;
+ memory::dims padding_l, padding_r, strides, fwd_output_dims;
+ memory::dims fwd_output_dims_tf_order;
+
+ // Get forward convolution parameters.
+ MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
+ conv_utl.GetConvFwdSizesInMklOrder(
+ input_shape, filter_shape, &fwd_input_dims, &fwd_filter_dims,
+ &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l,
+ &padding_r);
+ if (!context->status().ok()) return;
+
+ // Create Convolution forward descriptor since Convolution backward
+ // API needs it. For that, we first need to create input, filter
+ // and output memory descriptors.
+ auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_);
+ auto fwd_src_md =
+ memory::desc(fwd_input_dims, MklDnnType<T>(), mkl_data_format);
+ auto fwd_filter_md =
+ memory::desc(fwd_filter_dims, MklDnnType<T>(), memory::format::hwio);
+ auto fwd_out_md =
+ memory::desc(fwd_output_dims, MklDnnType<T>(), mkl_data_format);
+ auto fwd_desc = convolution_forward::desc(
+ prop_kind::forward, convolution_direct, fwd_src_md, fwd_filter_md,
+ fwd_out_md, strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
+ auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
+
+ // Allocate output tensor and shape
+ // TODO(nhasabni): Update this when support for MKL layout is added.
+ // Shape of output of Conv2DBackpropInput is same as 'input' of Conv2D.
+ TensorShape tf_output_shape(input_shape);
+ MklShape mkl_output_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(false);
+ Tensor* output_tensor = nullptr;
+ AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
+ mkl_output_mkl_shape);
+
+ // Create memory for user data.
+      // Describe what the inputs and outputs of Convolution look like. Also
+ // specify buffers containing actual input and output data.
+ // Although input shape required is in MKL-DNN order, the layout is
+ // Tensorflow's layout (NHWC or NCHW depending on data format).
+ // Although filter shape (filter_dims) required is in MKL-DNN order,
+ // the layout is Tensorflow's layout (HWIO).
+ // Shape of Conv2DBackpropInput's filter is same as that of Conv2D filter.
+ filter.SetUsrMem(fwd_filter_dims, memory::format::hwio, &filter_tensor);
+      // Outbackprop shape is NHWC or NCHW depending on data format. Since the
+      // GetInputSizeInMklOrder function returns the size in that order, we
+      // just use that function directly.
+ conv_utl.GetInputSizeInMklOrder(obp_shape, &obp_dims);
+ if (!context->status().ok()) return;
+ outbackprop.SetUsrMem(obp_dims, mkl_data_format, &obp_tensor);
+ // Although output shape required is in MKL-DNN order,
+ // layout is Tensorflow's layout (NHWC or NCHW depending on data format).
+ // Shape of output of Conv2DBackpropInput is same as shape of 'input'
+ // of Conv2D.
+ memory::dims bwd_output_dims = fwd_input_dims;
+ output.SetUsrMem(bwd_output_dims, mkl_data_format, output_tensor);
+
+ // Create memory descriptors for convolution data w/ no specified format.
+ filter.SetOpMemDesc(fwd_filter_dims, memory::format::any);
+ outbackprop.SetOpMemDesc(obp_dims, memory::format::any);
+ output.SetOpMemDesc(bwd_output_dims, memory::format::any);
+
+ // Create convolution backward data primitive.
+ auto bwd_desc = convolution_backward_data::desc(
+ convolution_direct, output.GetOpMemDesc(), filter.GetOpMemDesc(),
+ outbackprop.GetOpMemDesc(), strides, padding_l, padding_r,
+ TFPaddingToMklDnnPadding(padding_));
+
+ auto bwd_pd = convolution_backward_data::primitive_desc(
+ bwd_desc, cpu_engine, fwd_pd);
+
+ PrepareAndExecutePrimitive(bwd_pd, &filter, &outbackprop, &output);
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ TensorFormat data_format_;
+
+ // Prepare and execute net - checks for input and output reorders.
+ void PrepareAndExecutePrimitive(
+ const convolution_backward_data::primitive_desc& conv_pd,
+ MklDnnData<T>* filter, MklDnnData<T>* obp, MklDnnData<T>* output) {
+ // Create reorders between user layout and MKL layout if it is needed and
+ // add it to the net before convolution.
+ std::vector<primitive> net;
+ filter->CheckReorderToOpMem(conv_pd.weights_primitive_desc(), &net);
+ obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);
+
+ // Memory for output of convolution. Since we may need reorder on the
+ // output side, we will prepare reorder primitive in case output
+ // reorder to user memory is required.
+ bool output_reorder_required =
+ output->PrepareReorderToUserMemIfReq(conv_pd.diff_src_primitive_desc());
+
+ net.push_back(convolution_backward_data(
+ conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem()));
+
+ // Insert reorder primitive in the net for output reorder if reorder is
+ // required.
+ if (output_reorder_required) {
+ output->InsertReorderToUserMem(&net);
+ }
+
+ // Handle output reorder
+ stream(stream::kind::eager).submit(net).wait();
+ }
+};
+
+#endif // INTEL_MKL_DNN
+
#define REGISTER_MKL_CPU_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
.Device(DEVICE_CPU) \
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 7f1555d325..57661e8b10 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -18,7 +18,9 @@ limitations under the License.
#include <string.h>
#include <map>
+#include <string>
#include <vector>
+
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -26,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -40,10 +43,23 @@ limitations under the License.
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+
+using mkldnn::prop_kind;
+using mkldnn::stream;
+
+using mkldnn::convolution_direct;
+using mkldnn::convolution_forward;
+#endif
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+// For now, MKL-ML is the default, so MKL-DNN is not the default choice.
+#ifndef INTEL_MKL_DNN
+
template <typename Device, typename T, bool biasEnabled>
class MklConv2DOp : public OpKernel {
public:
@@ -461,6 +477,203 @@ class MklConv2DOp : public OpKernel {
TensorFormat data_format_;
};
+#else
+
+template <typename Device, typename T, bool biasEnabled>
+class MklConv2DOp : public OpKernel {
+ public:
+ ~MklConv2DOp() {}
+
+ explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(context, strides_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+
+ const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
+ const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
+ OP_REQUIRES(
+ context, stride_n == 1 && stride_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ try {
+ auto cpu_engine = engine(engine::cpu, 0);
+
+ // Input tensors
+ size_t src_idx = 0, filter_idx = 1;
+ const Tensor& src_tensor = MklGetInput(context, src_idx);
+ const Tensor& filter_tensor = MklGetInput(context, filter_idx);
+
+ MklDnnData<T> src(&cpu_engine);
+ MklDnnData<T> filter(&cpu_engine);
+ MklDnnData<T> output(&cpu_engine);
+
+ memory::dims src_dims, filter_dims, padding_l, padding_r, strides;
+ memory::dims output_dims_tf_order, output_dims_mkl_order;
+
+ // Get shapes of input tensors in MKL-DNN order
+ MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
+ conv_utl.GetConvFwdSizesInMklOrder(
+ src_tensor.shape(), filter_tensor.shape(), &src_dims, &filter_dims,
+ &strides, &output_dims_tf_order, &output_dims_mkl_order, &padding_l,
+ &padding_r);
+ if (!context->status().ok()) return;
+
+ // Check for corner case - if there is nothing to compute, return.
+ TensorShape tf_output_shape(
+ {output_dims_tf_order[0], output_dims_tf_order[1],
+ output_dims_tf_order[2], output_dims_tf_order[3]});
+ Tensor* output_tensor = nullptr;
+ MklShape mkl_output_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
+ mkl_output_mkl_shape);
+
+ // Forward filter in TF format from input at index 1 to output at index 1.
+ ForwardTfTensorInToOut(context, 1, 1);
+
+ if (tf_output_shape.num_elements() == 0) {
+ // TODO(jbobba): Verify correctness here
+ // Need semantics for Null MKL tensor
+ return;
+ }
+
+ // Corner case to handle 0 batch size.
+ if (output_dims_tf_order[0] == 0) {
+ // Nothing to do, allocate output tensor and return
+ // TODO(nhasabni): remove this code later once serialization
+ // in MKL-DNN is supported.
+ AllocateOutputSetMklShape(context, 0, &output_tensor,
+ src_tensor.shape(), mkl_output_mkl_shape);
+ return;
+ } else {
+ // Otherwise regular output tensor allocation
+ // Allocate output tensor.
+ }
+ CHECK_NOTNULL(output_tensor);
+
+ // Create memory for user data.
+ // Describe how the inputs and outputs of Convolution look like. Also
+      // Describe what the inputs and outputs of Convolution look like. Also
+ // Although input shape (src_dims) required is in MKL-DNN order,
+ // the layout is Tensorflow's layout (NHWC or NCHW depending on data
+ // format).
+ src.SetUsrMem(src_dims, TFDataFormatToMklDnnDataFormat(data_format_),
+ const_cast<void*>(
+ static_cast<const void*>(src_tensor.flat<T>().data())));
+ // Although filter shape (filter_dims) required is in MKL-DNN order,
+ // the layout is Tensorflow's layout (HWIO).
+ filter.SetUsrMem(filter_dims, memory::format::hwio,
+ const_cast<void*>(static_cast<const void*>(
+ filter_tensor.flat<T>().data())));
+ // Although output shape (output_dims) required is in MKL-DNN order,
+ // layout is Tensorflow's layout (NHWC or NCHW depending on data format).
+ output.SetUsrMem(output_dims_mkl_order,
+ TFDataFormatToMklDnnDataFormat(data_format_),
+ output_tensor->flat<T>().data());
+
+ // Create memory descriptors for convolution data w/ no specified format.
+ src.SetOpMemDesc(src_dims, memory::format::any);
+ filter.SetOpMemDesc(filter_dims, memory::format::any);
+ output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);
+
+ // If bias is enabled, then do the same steps as above for bias.
+ if (biasEnabled) {
+ MklDnnData<T> bias(&cpu_engine);
+ memory::dims bias_size;
+ conv_utl.GetBiasSizeInMklOrder(2 /* bias idx */, &bias_size);
+ const Tensor& bias_tensor = MklGetInput(context, 2);
+ bias.SetUsrMem(bias_size, memory::format::x,
+ const_cast<void*>(static_cast<const void*>(
+ bias_tensor.flat<T>().data())));
+ bias.SetOpMemDesc(bias_size, memory::format::any);
+
+ // Create convolution primitive with Bias.
+ auto conv_desc = convolution_forward::desc(
+ prop_kind::forward, convolution_direct, src.GetOpMemDesc(),
+ filter.GetOpMemDesc(), bias.GetOpMemDesc(), output.GetOpMemDesc(),
+ strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
+
+ auto conv_prim_desc =
+ convolution_forward::primitive_desc(conv_desc, cpu_engine);
+ PrepareAndExecuteNet(conv_prim_desc, &src, &filter, &bias, &output);
+ } else {
+ // Create convolution primitive without Bias.
+ auto conv_desc = convolution_forward::desc(
+ prop_kind::forward, convolution_direct, src.GetOpMemDesc(),
+ filter.GetOpMemDesc(), output.GetOpMemDesc(), strides, padding_l,
+ padding_r, TFPaddingToMklDnnPadding(padding_));
+
+ auto conv_prim_desc =
+ convolution_forward::primitive_desc(conv_desc, cpu_engine);
+ PrepareAndExecuteNet(conv_prim_desc, &src, &filter, nullptr, &output);
+ }
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + std::string(e.message) + ", in file " +
+ std::string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ TensorFormat data_format_;
+
+ // Prepare and execute net - checks for input and output reorders.
+ void PrepareAndExecuteNet(
+ const convolution_forward::primitive_desc& conv_prim_desc,
+ MklDnnData<T>* src, MklDnnData<T>* filter, MklDnnData<T>* bias,
+ MklDnnData<T>* output) {
+ // Create reorders between user layout and MKL layout if it is needed and
+ // add it to the net before convolution.
+ std::vector<primitive> net;
+ src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc(), &net);
+ filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(), &net);
+
+ // Memory for output of convolution. Since we may need reorder on the
+ // output side, we will prepare reorder primitive in case output
+ // reorder to user memory is required.
+ bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
+ conv_prim_desc.dst_primitive_desc());
+
+ // Create convolution primitive and add it to net.
+ if (bias) {
+ CHECK_EQ(biasEnabled, true);
+ net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
+ filter->GetOpMem(), bias->GetOpMem(),
+ output->GetOpMem()));
+ } else {
+ CHECK_EQ(biasEnabled, false);
+ net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
+ filter->GetOpMem(),
+ output->GetOpMem()));
+ }
+
+ // Insert reorder primitive in the net for output reorder if reorder is
+ // required.
+ if (output_reorder_required) {
+ output->InsertReorderToUserMem(&net);
+ }
+
+ // Handle output reorder
+ stream(stream::kind::eager).submit(net).wait();
+ }
+};
+
+#endif
+
#define REGISTER_MKL_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
.Device(DEVICE_CPU) \
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
new file mode 100644
index 0000000000..e29af19ca9
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -0,0 +1,308 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#include "tensorflow/core/util/mkl_util.h"
+
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+#endif
+
+namespace tensorflow {
+
+#ifdef INTEL_MKL_DNN
+
+class MklDnnConvUtil {
+ protected:
+ OpKernelContext *context_; // We don't own this.
+ std::vector<int32> strides_;
+ Padding padding_;
+ TensorFormat data_format_;
+
+ public:
+ MklDnnConvUtil(OpKernelContext *context, const std::vector<int32> &strides,
+ Padding pad, TensorFormat fm)
+ : context_(context), strides_(strides), padding_(pad), data_format_(fm) {}
+
+ virtual ~MklDnnConvUtil() { context_ = nullptr; }
+
+ // Calculate Convolution strides
+ virtual inline void GetStridesInMklOrder(memory::dims *strides) {
+ // For now we take the stride from the second and third dimensions only
+ // (we do not support striding on the batch or depth dimension).
+ CHECK_NOTNULL(strides);
+ int stride_rows = GetTensorDim(strides_, data_format_, 'H');
+ int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+ *strides = {stride_rows, stride_cols};
+ }
+
+ // Calculate Convolution input size in MKL-DNN order. MKL-DNN
+ // requires input in NCHW format. Function does not return anything.
+ // But errors arising from sanity checks are returned in context's
+ // status.
+ virtual inline void GetInputSizeInMklOrder(const TensorShape &input_shape,
+ memory::dims *input_dims) {
+#define CHECK_BOUNDS(val, err_msg) \
+ do { \
+ OP_REQUIRES(context_, \
+ FastBoundsCheck(val, std::numeric_limits<int>::max()), \
+ errors::InvalidArgument(err_msg)); \
+ } while (0)
+
+ CHECK_NOTNULL(input_dims);
+
+ // Input channel
+ int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
+ int input_depth = static_cast<int>(input_depth_raw);
+
+ // Input rows/height
+ int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
+ CHECK_BOUNDS(input_rows_raw, "Input rows too large");
+ int input_rows = static_cast<int>(input_rows_raw);
+
+ // Input columns/width
+ int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
+ CHECK_BOUNDS(input_cols_raw, "Input cols too large");
+ int input_cols = static_cast<int>(input_cols_raw);
+
+ // Input batch
+ int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
+ CHECK_BOUNDS(input_batch_raw, "Input batch too large");
+ int input_batch = static_cast<int>(input_batch_raw);
+
+#undef CHECK_BOUNDS
+
+ // MKL-DNN always requires input in NCHW format.
+ *input_dims = {input_batch, input_depth, input_rows, input_cols};
+ }
+
+  // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
+  // requires filter in OIHW format. Function does not return anything.
+  // But errors arising from sanity checks are returned in context's
+  // status. This variant accepts the source shape (src_shape) because
+  // Convolution Backward Input gets the shape of the input tensor rather
+  // than the actual tensor (Convolution forward gets the actual tensor
+  // as input).
+  //
+  // TODO(nhasabni): Add similar function for input and filter in MklShape.
+ virtual inline void GetFilterSizeInMklOrder(const TensorShape &input_shape,
+ const TensorShape &filter_shape,
+ memory::dims *filter_dims) {
+ CHECK_NOTNULL(filter_dims);
+
+ OP_REQUIRES(context_, filter_shape.dims() == 4,
+ errors::InvalidArgument("filter must be 4-dimensional: ",
+ filter_shape.DebugString()));
+
+ for (int i = 0; i < 3; i++) {
+ OP_REQUIRES(context_,
+ FastBoundsCheck(filter_shape.dim_size(i),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("filter too large"));
+ }
+
+ int input_depth = GetTensorDim(input_shape, data_format_, 'C');
+
+ OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ", input_depth,
+ " vs ", filter_shape.dim_size(2)));
+
+ // TF filter is always in (rows, cols, in_depth, out_depth) order.
+ int filter_rows = static_cast<int>(filter_shape.dim_size(0));
+ int filter_cols = static_cast<int>(filter_shape.dim_size(1));
+ int in_depth = static_cast<int>(filter_shape.dim_size(2));
+ int out_depth = static_cast<int>(filter_shape.dim_size(3));
+
+ // MKL-DNN always needs filter in OIHW format.
+ // OIHW = (out_depth, in_depth, rows, cols)
+ *filter_dims = {out_depth, in_depth, filter_rows, filter_cols};
+ }
+
+ // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
+ // requires filter in OIHW format. Function does not return anything.
+ // But errors arising from sanity checks are returned in context's
+ // status.
+ virtual inline void GetFilterSizeInMklOrder(size_t src_index,
+ size_t filter_index,
+ memory::dims *filter_dims) {
+ CHECK_NOTNULL(filter_dims);
+ const Tensor &input = MklGetInput(context_, src_index);
+ const Tensor &filter = MklGetInput(context_, filter_index);
+ GetFilterSizeInMklOrder(input.shape(), filter.shape(), filter_dims);
+ }
+
+ // Calculate Bias size for 2D Convolution. Function does not return
+ // anything, but sets error in context status.
+ virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
+ memory::dims *bias_dims) {
+ const Tensor &bias = MklGetInput(context_, bias_index);
+ OP_REQUIRES(context_, bias.dims() == 1,
+ errors::InvalidArgument("bias must be 1-dimensional: ",
+ bias.shape().DebugString()));
+
+ *bias_dims = {static_cast<int>(bias.dim_size(0))};
+ }
+
+ // Function to calculate output and padding size for 2D convolution.
+ //
+ // Calculate output shape of Convolution in MKL-DNN and TensorFlow order.
+ // MKL-DNN uses NCHW for output order. But TensorFlow output will be in
+ // NHWC or NCHW format depending on data format. Function also calculates
+  // left, right, top and bottom pads. The function does not return a status;
+  // errors are reported via the context's status.
+ //
+ // TODO(nhasabni): Add similar function for input and filter in MklShape.
+ virtual inline void GetOutputAndPadSizeInMklOrder(
+ const TensorShape &input_shape, const TensorShape &filter_shape,
+ const memory::dims &strides, memory::dims *output_dims_tf_order,
+ memory::dims *output_dims_mkl_order, memory::dims *pad_l,
+ memory::dims *pad_r) {
+ CHECK_NOTNULL(output_dims_tf_order);
+ CHECK_NOTNULL(output_dims_mkl_order);
+ CHECK_NOTNULL(pad_l);
+ CHECK_NOTNULL(pad_r);
+
+ int input_rows = GetTensorDim(input_shape, data_format_, 'H');
+ int input_cols = GetTensorDim(input_shape, data_format_, 'W');
+
+ // The first dimension for filter is rows/height.
+ int filter_rows = filter_shape.dim_size(0);
+ // The second dimension for filter is cols/width.
+ int filter_cols = filter_shape.dim_size(1);
+
+ // Stride is vector of 2 elements: {s_r, s_c}
+ int stride_rows = strides[0];
+ int stride_cols = strides[1];
+
+ // Output batch is same as input batch.
+ int out_batch = GetTensorDim(input_shape, data_format_, 'N');
+ // Output depth is same as last dimension for filter.
+ int out_depth = filter_shape.dim_size(3);
+
+ int64 out_rows = 0, out_cols = 0;
+    int64 pad_top = 0, pad_bottom = 0, pad_left = 0, pad_right = 0;
+
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_rows, filter_rows, stride_rows, padding_,
+ &out_rows, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_cols, filter_cols, stride_cols, padding_,
+ &out_cols, &pad_left, &pad_right));
+
+    // TensorFlow output is in data_format order (NHWC or NCHW).
+ TensorShape out_shape =
+ ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth);
+ *output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
+
+ // MKL-DNN always needs output in NCHW format.
+ *output_dims_mkl_order = {out_batch, out_depth, static_cast<int>(out_rows),
+ static_cast<int>(out_cols)};
+
+    // Now handle padding. MKL-DNN uses asymmetric padding.
+ *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
+ *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
+ }
+
+ // Calculate output and pad size of forward Convolution operator.
+ // See comment on GetConvOutputAndPadSizeInMklOrder for parameters.
+ //
+ // Function does not return anything, but sets error in context status.
+ inline void GetOutputAndPadSizeInMklOrder(
+ size_t src_index, size_t filter_index, const memory::dims &strides,
+ memory::dims *output_dims_tf_order, memory::dims *output_dims_mkl_order,
+ memory::dims *pad_l, memory::dims *pad_r) {
+ CHECK_NOTNULL(output_dims_tf_order);
+ CHECK_NOTNULL(output_dims_mkl_order);
+ CHECK_NOTNULL(pad_l);
+ CHECK_NOTNULL(pad_r);
+
+ const Tensor &input = MklGetInput(context_, src_index);
+ const Tensor &filter = MklGetInput(context_, filter_index);
+
+ OP_REQUIRES(context_, input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().DebugString()));
+
+ GetOutputAndPadSizeInMklOrder(input.shape(), filter.shape(), strides,
+ output_dims_tf_order, output_dims_mkl_order,
+ pad_l, pad_r);
+ }
+
+  // Wrapper function to calculate input, filter, and output sizes of
+  // 2D Convolution in MKL order (NCHW for input and output; OIHW for filter).
+  // The function also calculates the output shape in TensorFlow order, as well
+  // as the strides and paddings for the 2D Convolution.
+ //
+ // Function does not return anything, but sets error in context status.
+ inline void GetConvFwdSizesInMklOrder(
+ const TensorShape &input_shape, const TensorShape &filter_shape,
+ memory::dims *input_dims, memory::dims *filter_dims,
+ memory::dims *strides, memory::dims *output_dims_tf_order,
+ memory::dims *output_dims_mkl_order, memory::dims *pad_l,
+ memory::dims *pad_r) {
+ CHECK_NOTNULL(input_dims);
+ CHECK_NOTNULL(filter_dims);
+ CHECK_NOTNULL(strides);
+ CHECK_NOTNULL(output_dims_tf_order);
+ CHECK_NOTNULL(output_dims_mkl_order);
+ CHECK_NOTNULL(pad_l);
+ CHECK_NOTNULL(pad_r);
+
+ GetInputSizeInMklOrder(input_shape, input_dims);
+ if (!context_->status().ok()) return;
+ GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims);
+ if (!context_->status().ok()) return;
+ GetStridesInMklOrder(strides);
+ GetOutputAndPadSizeInMklOrder(input_shape, filter_shape, *strides,
+ output_dims_tf_order, output_dims_mkl_order,
+ pad_l, pad_r);
+ if (!context_->status().ok()) return;
+ }
+};
+
+#endif // INTEL_MKL_DNN
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
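For reference, a minimal standalone C++ sketch of the size arithmetic performed by the helpers above. The dimension values are invented for illustration, and the code mirrors only the 'SAME'-padding math and the NCHW/OIHW reordering, not the TensorFlow API itself:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      // Hypothetical NHWC input and HWIO filter, as TensorFlow stores them.
      int batch = 8, in_rows = 32, in_cols = 32, in_depth = 3;
      int filter_rows = 3, filter_cols = 3, out_depth = 16;
      int stride_rows = 2, stride_cols = 2;

      // 'SAME' padding: output extent is ceil(input / stride); the padding
      // needed to achieve it is split across both sides (extra goes to the end).
      int out_rows = (in_rows + stride_rows - 1) / stride_rows;
      int out_cols = (in_cols + stride_cols - 1) / stride_cols;
      int pad_rows = std::max(0, (out_rows - 1) * stride_rows + filter_rows - in_rows);
      int pad_cols = std::max(0, (out_cols - 1) * stride_cols + filter_cols - in_cols);
      int pad_top = pad_rows / 2, pad_bottom = pad_rows - pad_top;
      int pad_left = pad_cols / 2, pad_right = pad_cols - pad_left;

      // MKL-DNN ordering: NCHW for input/output, OIHW for the filter.
      std::vector<int> input_nchw = {batch, in_depth, in_rows, in_cols};
      std::vector<int> filter_oihw = {out_depth, in_depth, filter_rows, filter_cols};
      std::vector<int> output_nchw = {batch, out_depth, out_rows, out_cols};

      std::printf("out: %dx%d  pads: t=%d b=%d l=%d r=%d  (dims sizes: %zu %zu %zu)\n",
                  out_rows, out_cols, pad_top, pad_bottom, pad_left, pad_right,
                  input_nchw.size(), filter_oihw.size(), output_nchw.size());
      return 0;
    }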
diff --git a/tensorflow/core/kernels/mkl_cwise_ops_common.cc b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
index 7fc633c254..c065724e0d 100644
--- a/tensorflow/core/kernels/mkl_cwise_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
@@ -48,7 +48,7 @@ class MklBinaryOp : public BinaryOp<Device, Functor> {
auto out = context->mutable_output(0);
VLOG(1) << "Shapes (output): " << out->shape().DebugString();
- // Pass input shape through to ouput shape
+ // Pass input shape through to output shape
ForwardMklMetaDataInToOut(context, 0, 0);
out = context->mutable_output(0);
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc
index 3c85737702..302a6967e3 100644
--- a/tensorflow/core/lib/strings/numbers.cc
+++ b/tensorflow/core/lib/strings/numbers.cc
@@ -340,7 +340,7 @@ char* FloatToBuffer(float value, char* buffer) {
float parsed_value;
if (!safe_strtof(buffer, &parsed_value) || parsed_value != value) {
snprintf_result =
- snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG + 2, value);
+ snprintf(buffer, kFastToBufferSize, "%.*g", FLT_DIG + 3, value);
// Should never overflow; see above.
DCHECK(snprintf_result > 0 && snprintf_result < kFastToBufferSize);
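The change above raises the fallback precision from FLT_DIG + 2 (8 significant digits) to FLT_DIG + 3 (9 digits, i.e. FLT_DECIMAL_DIG for IEEE-754 single precision), the smallest precision guaranteed to round-trip every finite float through decimal text. A minimal standard-C++ sketch of the mechanism; the test value is arbitrary:

    #include <cfloat>
    #include <cstdio>
    #include <cstdlib>

    int main() {
      float value = 1.0f / 3.0f;  // arbitrary test value
      char buf[64];

      // Print with 9 significant digits (FLT_DIG + 3) and parse it back.
      std::snprintf(buf, sizeof(buf), "%.*g", FLT_DIG + 3, value);
      float parsed = std::strtof(buf, nullptr);
      std::printf("%s round-trips exactly: %s\n", buf, parsed == value ? "yes" : "no");
      return 0;
    }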
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index df189af1b8..c0e84c8bb0 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -383,7 +383,8 @@ input_dataset: A handle to an input dataset. Must have a single component.
batch_size: A scalar representing the number of elements to accumulate in a
batch.
row_shape: A vector representing the dense shape of each row in the produced
- SparseTensor.
+ SparseTensor. The shape may be partially specified, using `-1` to indicate
+ that a particular dimension should use the maximum size of all batch elements.
)doc");
REGISTER_OP("RangeDataset")
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 3dc16ac457..b34dc1a008 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -29,22 +29,6 @@ using shape_inference::ShapeHandle;
namespace {
-// A shape function that uses the tensor value at <input_idx> as a shape for
-// output 0. If the tensor value is not available, it uses a shape with <ndims>
-// unknown dims.
-Status InputTensorShapeOrUnknown(InferenceContext* c, int input_idx,
- int ndims) {
- ShapeHandle out;
- const Tensor* input = c->input_tensor(input_idx);
- if (input == nullptr) {
- out = c->UnknownShapeOfRank(ndims);
- } else {
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(input_idx, &out));
- }
- c->set_output(0, out);
- return Status::OK();
-}
-
Status FractionalPoolShapeFn(InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
@@ -119,11 +103,11 @@ REGISTER_OP("AvgPoolGrad")
.Attr(GetConvnetDataFormatAttrString())
.Attr("T: {half, float, double}")
.SetShapeFn([](InferenceContext* c) {
- // NOTE(mrry): We could in principle work out the shape from the
- // gradients and the attrs, but if we do not know orig_input_shape
- // statically, then we are unlikely to know the shape of the
- // gradients either.
- return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
Computes gradients of the average pooling function.
@@ -583,11 +567,11 @@ REGISTER_OP("Conv2DBackpropInput")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
- // NOTE(mrry): We could in principle work out the shape from the
- // gradients and the attrs, but if we do not know orig_input_shape
- // statically, then we are unlikely to know the shape of the
- // gradients either.
- return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
Computes the gradients of convolution with respect to the input.
@@ -625,11 +609,11 @@ REGISTER_OP("Conv2DBackpropFilter")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
- // NOTE(mrry): We could in principle work out the shape from the
- // gradients and the attrs, but if we do not know orig_input_shape
- // statically, then we are unlikely to know the shape of the
- // gradients either.
- return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
Computes the gradients of convolution with respect to the filter.
@@ -882,11 +866,11 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
- // NOTE(mrry): We could in principle work out the shape from the
- // gradients and the attrs, but if we do not know orig_input_shape
- // statically, then we are unlikely to know the shape of the
- // gradients either.
- return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
Computes the gradients of depthwise convolution with respect to the input.
@@ -924,11 +908,11 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
- // NOTE(mrry): We could in principle work out the shape from the
- // gradients and the attrs, but if we do not know orig_input_shape
- // statically, then we are unlikely to know the shape of the
- // gradients either.
- return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
Computes the gradients of depthwise convolution with respect to the filter.
@@ -2870,7 +2854,11 @@ REGISTER_OP("_MklConv2DBackpropFilter")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
- return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the
@@ -2911,7 +2899,11 @@ REGISTER_OP("_MklConv2DBackpropInput")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn([](InferenceContext* c) {
- return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
MKL version of Convolution2D backward input. Uses MKL DNN APIs to compute the
@@ -3034,7 +3026,11 @@ REGISTER_OP("_MklAvgPoolGrad")
.Attr(GetConvnetDataFormatAttrString())
.Attr("T: {float, half, double}")
.SetShapeFn([](InferenceContext* c) {
- return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
+ c->set_output(0, s);
+ return Status::OK();
})
.Doc(R"doc(
MKL version of AvgPoolGrad operator. Uses MKL DNN APIs to compute gradients
diff --git a/tensorflow/core/ops/nn_ops_test.cc b/tensorflow/core/ops/nn_ops_test.cc
index 51e4f8bffe..4628b725f8 100644
--- a/tensorflow/core/ops/nn_ops_test.cc
+++ b/tensorflow/core/ops/nn_ops_test.cc
@@ -81,55 +81,6 @@ TEST(NNOpsTest, TopKV2_ShapeFn) {
op, "[1,2,3,4];[]");
}
-TEST(NNOpsTest, InputTensorShapeOrUnknown2D_ShapeFn) {
- typedef std::pair<const char*, int> NameAndInputIndex;
- for (const auto& p :
- {NameAndInputIndex("AvgPoolGrad", 0),
- NameAndInputIndex("Conv2DBackpropInput", 0),
- NameAndInputIndex("Conv2DBackpropFilter", 1),
- NameAndInputIndex("DepthwiseConv2dNativeBackpropInput", 0),
- NameAndInputIndex("DepthwiseConv2dNativeBackpropFilter", 1)}) {
- ShapeInferenceTestOp op(p.first);
- op.input_tensors.resize(2);
-
- // Conv and Depthwise conv have three inputs.
- string extra_shapes = (op.name == "AvgPoolGrad" ? "" : ";?");
-
- // When the input tensor is not known, the output is 4 unknown dims.
- INFER_OK(op, "?;?" + extra_shapes, "[?,?,?,?]");
- INFER_OK(op, "[4];?" + extra_shapes, "[?,?,?,?]");
-
- // When input tensor is known, its values determine output shape.
- std::vector<int32> shape{1, 2, 3, 4};
- Tensor shape_t = test::AsTensor<int32>(shape);
- op.input_tensors[p.second] = &shape_t;
- INFER_OK(op, "[4];?" + extra_shapes, "[1,2,3,4]");
- }
-}
-
-TEST(NNOpsTest, InputTensorShapeOrUnknown3D_ShapeFn) {
- typedef std::pair<const char*, int> NameAndInputIndex;
- for (const auto& p : {NameAndInputIndex("AvgPool3DGrad", 0),
- NameAndInputIndex("Conv3DBackpropInputV2", 0),
- NameAndInputIndex("Conv3DBackpropFilterV2", 1)}) {
- ShapeInferenceTestOp op(p.first);
- op.input_tensors.resize(2);
-
- // Conv3D has an extra shape.
- string extra_shapes = (op.name == "AvgPool3DGrad" ? "" : ";?");
-
- // When the input tensor is not known, the output is 4 unknown dims.
- INFER_OK(op, "?;?" + extra_shapes, "[?,?,?,?,?]");
- INFER_OK(op, "[5];?" + extra_shapes, "[?,?,?,?,?]");
-
- // When input tensor is known, its values determine output shape.
- std::vector<int32> shape{1, 2, 3, 4, 5};
- Tensor shape_t = test::AsTensor<int32>(shape);
- op.input_tensors[p.second] = &shape_t;
- INFER_OK(op, "[5];?" + extra_shapes, "[1,2,3,4,5]");
- }
-}
-
TEST(NNOpsTest, BatchNormWithGlobalNormalization_ShapeFn) {
ShapeInferenceTestOp op("BatchNormWithGlobalNormalization");
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index f23ff083af..b44ea2e080 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -332,6 +332,7 @@ REGISTER_OP("DecodeCSV")
.Attr("OUT_TYPE: list({float,int32,int64,string})")
.Attr("field_delim: string = ','")
.Attr("use_quote_delim: bool = true")
+ .Attr("na_value: string = ''")
.SetShapeFn([](InferenceContext* c) {
// Validate the record_defaults inputs.
for (int i = 1; i < c->num_inputs(); ++i) {
@@ -362,6 +363,7 @@ field_delim: char delimiter to separate fields in a record.
use_quote_delim: If false, treats double quotation marks as regular
characters inside of the string fields (ignoring RFC 4180, Section 2,
Bullet 5).
+na_value: Additional string to recognize as NA/NaN.
output: Each tensor will have the same shape as records.
)doc");
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index f4bec9524a..1bfa4f83a3 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -26,13 +26,19 @@ limitations under the License.
#include "mkl_trans.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+#endif
// The file contains a number of utility classes and functions used by MKL
// enabled kernels
@@ -219,19 +225,18 @@ class MklShape {
// Location from start of buffer where isMklTensor_ is serialized
#define DIMS_OFFSET \
(IS_MKL_TENSOR_OFFSET + sizeof(size_t)) // Location of dimension_
-#define SIZES_OFFSET(dims) \
- (DIMS_OFFSET + \
- sizeof(size_t)) // Location of sizes. Note dim is not used here, left here
- // to make macros consistent.
+// Location of sizes. Note dim is not used here, left here
+// to make macros consistent.
+#define SIZES_OFFSET(dims) (DIMS_OFFSET + sizeof(size_t))
#define STRIDES_OFFSET(dims) \
(SIZES_OFFSET(dims) + dims * sizeof(size_t)) // Location of strides
#define MKL_LAYOUT_OFFSET(dims) \
(STRIDES_OFFSET(dims) + dims * sizeof(size_t)) // Location of mklLayout_
#define TF_LAYOUT_OFFSET(dims) \
(MKL_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF) // Location of tfLayout_
+// Location of tf_to_mkl_dim_map_
#define TF_TO_MKL_DIM_MAP_OFFSET(dims) \
- (TF_LAYOUT_OFFSET(dims) + \
- SIZE_OF_MKL_DNN_BUF) // Location of tf_to_mkl_dim_map_
+ (TF_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)
// TODO(agramesh1) make sure to create a const to share with rewrite pass
// for min size of MKL metadata tensor.
@@ -342,58 +347,6 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
return output_tensor;
}
-// Since our ops are going to produce and also consume N addition tensors
-// (Mkl) for N Tensorflow tensors, we can have following different
-// orderings among these 2N tensors.
-//
-// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
-// consume A_m, B_m, and C_m additionally.
-//
-// INTERLEAVED: in this case 2N tensors are interleaved. So for above
-// example, the ordering looks like: A, A_m, B, B_m, C, C_m.
-//
-// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed
-// by N Mkl tensors. So for above example, the ordering looks
-// like: A, B, C, A_m, B_m, C_m
-//
-// Following APIs map index of original Tensorflow tensors to their appropriate
-// position based on selected ordering. For contiguous ordering, we need to know
-// the total number of tensors (parameter total).
-//
-typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
-// NOTE: Currently, we use contiguous ordering. If you change this, then you
-// would need to change Mkl op definitions in nn_ops.cc.
-static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
-
-// Get index of MetaData tensor from index 'n' of Data tensor.
-inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
- if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
- // For interleaved ordering, Mkl tensor follows immediately after
- // Tensorflow tensor.
- return n + 1;
- } else {
- CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
- // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
- return n + total_tensors / 2;
- }
-}
-
-int inline GetTensorDataIndex(int n, int total_tensors) {
- if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
- return 2 * n; // index corresponding to nth input/output tensor
- } else {
- CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
- return n;
- }
-}
-
-int inline GetTensorMetaDataIndex(int n, int total_tensors) {
- // Get index for TensorData first and then use mapping function
- // to get TensorMetaData index from TensorData index.
- int tidx = GetTensorDataIndex(n, total_tensors);
- return DataIndexToMetaDataIndex(tidx, total_tensors);
-}
-
// Get the MKL shape from the second string tensor
inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
mklshape->DeSerializeMklShape(
@@ -480,6 +433,13 @@ inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
*buf_out = static_cast<void*>(tensor_out->flat<float>().data());
}
+template <typename T>
+inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
+ TensorShape tf_shape) {
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
+ tf_shape, tensor_out));
+}
+
inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides,
const size_t* sizes) {
// MKL requires strides in NCHW
@@ -743,56 +703,299 @@ inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
}
}
-namespace mkl_op_registry {
-static const char* kMklOpLabel = "MklOp";
-static const char* kMklOpLabelPattern = "label='MklOp'";
+// -------------------------------------------------------------------
+
+#ifdef INTEL_MKL_DNN
+
+using mkldnn::engine;
+using mkldnn::memory;
+using mkldnn::padding_kind;
+using mkldnn::primitive;
+using mkldnn::reorder;
+
+/// Return MKL-DNN data type (memory::data_type) for input type T
+///
+/// @input None
+/// @return memory::data_type corresponding to type T
+template <typename T>
+static memory::data_type MklDnnType();
+
+/// Instantiation for float type. Add similar instantiations for other
+/// types if needed.
+template <>
+memory::data_type MklDnnType<float>() {
+ return memory::data_type::f32;
+}
+
+/// Map TensorFlow's data format into MKL-DNN data format
+///
+/// @input: TensorFlow data format
+/// @return: memory::format corresponding to TensorFlow data format;
+/// Fails with an error if invalid data format.
+inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
+ if (format == FORMAT_NHWC)
+ return memory::format::nhwc;
+ else if (format == FORMAT_NCHW)
+ return memory::format::nchw;
+ TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
+ // Return to get rid of compiler warning
+ return memory::format::format_undef;
+}
-// Get the name of Mkl op from original TensorFlow op
-// We prefix 'Mkl' to the original op to get Mkl op.
-inline string GetMklOpName(const string& name) {
- // Prefix that we add to Tensorflow op name to construct Mkl op name.
- const char* const kMklOpPrefix = "_Mkl";
- return string(kMklOpPrefix) + name;
+/// Map TensorShape object into memory::dims required by MKL-DNN
+///
+/// This function will simply map input TensorShape into MKL-DNN dims
+/// naively. So it will preserve the order of dimensions. E.g., if
+/// input tensor is in NHWC format, then dims will be in NHWC format
+/// also.
+///
+/// @input TensorShape object in shape
+/// @return memory::dims corresponding to TensorShape
+inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
+ memory::dims dims(shape.dims());
+ for (unsigned int d = 0; d < shape.dims(); ++d) {
+ dims[d] = shape.dim_size(d);
+ }
+ return dims;
}
-// Check whether opname with type T is registered as MKL-compliant.
-//
-// @input: name of the op
-// @input: T datatype to be used for checking op
-// @return: true if opname is registered as Mkl op; false otherwise
-static inline bool IsMklOp(const std::string& op_name, DataType T) {
- string kernel = KernelsRegisteredForOp(op_name);
- bool result =
- kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
- if (result) {
- VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel;
- }
- return result;
+/// Map TensorShape object into memory::dims in NCHW format required by MKL-DNN
+///
+/// This function is more specific than the one above. It maps the input
+/// TensorShape into MKL-DNN dims in NCHW format, so it may not preserve the
+/// order of dimensions: e.g., if the input tensor is in NHWC format, the
+/// returned dims will be in NCHW format, not NHWC.
+///
+/// @input TensorShape object in shape
+/// @return memory::dims in the NCHW format required by MKL-DNN
+inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
+ TensorFormat format) {
+ // Check validity of format.
+ CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
+ memory::format::format_undef);
+
+ int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
+ int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
+ int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
+ int w = shape.dim_size(GetTensorDimIndex(format, 'W'));
+
+ // MKL-DNN requires dimensions in NCHW format.
+ return memory::dims({n, c, h, w});
}
-// Check whether opname with type T is registered as MKL-compliant and
-// is element-wise.
-//
-// @input: name of the op
-// @input: T datatype to be used for checking op
-// @return: true if opname is registered as element-wise Mkl op; false otherwise
-static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) {
- if (!IsMklOp(op_name, T)) {
+inline padding_kind TFPaddingToMklDnnPadding(Padding pad) {
+ // MKL-DNN only supports zero padding.
+ return padding_kind::zero;
+}
+
+/*
+ * Class to represent all the resources corresponding to a tensor in TensorFlow
+ * that are required to execute an operation (such as Convolution).
+ */
+template <typename T>
+class MklDnnData {
+ private:
+ /// MKL-DNN memory primitive for input user memory
+ memory* user_memory_;
+
+ /// MKL-DNN memory primitive in case input or output reorder is needed.
+ memory* reorder_memory_;
+
+ /// Operations memory descriptor
+ memory::desc* op_md_;
+
+ /// CPU engine on which operation will be executed
+ const engine* cpu_engine_;
+
+ public:
+ explicit MklDnnData(const engine* e)
+ : user_memory_(nullptr),
+ reorder_memory_(nullptr),
+ op_md_(nullptr),
+ cpu_engine_(e) {}
+
+ ~MklDnnData() {
+ cpu_engine_ = nullptr; // We don't own this.
+ delete (user_memory_);
+ delete (reorder_memory_);
+ delete (op_md_);
+ }
+
+ void* GetTensorBuffer(const Tensor* tensor) {
+ CHECK_NOTNULL(tensor);
+ return const_cast<void*>(
+ static_cast<const void*>(tensor->flat<T>().data()));
+ }
+
+ /// Set user memory primitive using specified dimensions, memory format and
+  /// data_buffer. The element data type is deduced automatically from the
+  /// template type T used to create this object.
+ ///
+  /// In a nutshell, the function allows the user to describe the input tensor
+  /// to an operation. E.g., the filter of Conv2D has shape {1, 2, 3, 4} and
+  /// memory format HWIO, and the buffer that contains the actual values is
+  /// pointed to by data_buffer.
+ void SetUsrMem(memory::dims dim, memory::format fm, void* data_buffer) {
+ CHECK_NOTNULL(data_buffer);
+ CHECK_NOTNULL(cpu_engine_);
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ user_memory_ =
+ new memory(memory::primitive_desc(
+ memory::desc(dim, MklDnnType<T>(), fm), *cpu_engine_),
+ data_buffer);
+ }
+
+ void SetUsrMem(memory::dims dim, memory::format fm, const Tensor* tensor) {
+ CHECK_NOTNULL(tensor);
+ SetUsrMem(dim, fm, GetTensorBuffer(tensor));
+ }
+
+ /// A version of function to set user memory primitive that accepts memory
+ /// descriptor directly, instead of accepting dimensions and format. This
+  /// function is more generic than the one above, but the function above is
+ /// sufficient in most cases.
+ void SetUsrMem(memory::desc md, void* data_buffer) {
+ CHECK_NOTNULL(data_buffer);
+ CHECK_NOTNULL(cpu_engine_);
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ user_memory_ =
+ new memory(memory::primitive_desc(md, *cpu_engine_), data_buffer);
+ }
+
+ /// A version of SetUsrMem with memory descriptor and tensor
+ void SetUsrMem(memory::desc md, const Tensor* tensor) {
+ CHECK_NOTNULL(tensor);
+ SetUsrMem(md, GetTensorBuffer(tensor));
+ }
+
+ /// A version of function to set user memory primitive that accepts primitive
+ /// descriptor directly, instead of accepting dimensions and format. This
+  /// function is more generic than the one above, but the function above is
+ /// sufficient in most cases.
+ void SetUsrMem(memory::primitive_desc pd, void* data_buffer) {
+ CHECK_NOTNULL(data_buffer);
+ CHECK_NOTNULL(cpu_engine_);
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ user_memory_ = new memory(pd, data_buffer);
+ }
+
+ /// A version of SetUsrMem with primitive descriptor and tensor
+ void SetUsrMem(memory::primitive_desc pd, const Tensor* tensor) {
+ CHECK_NOTNULL(tensor);
+ SetUsrMem(pd, GetTensorBuffer(tensor));
+ }
+
+ /// Get function for user memory primitive.
+ const memory* GetUsrMem() const { return user_memory_; }
+
+ /// Get function for primitive descriptor of user memory primitive.
+ const memory::primitive_desc GetUsrMemPrimDesc() const {
+ CHECK_NOTNULL(user_memory_);
+ return user_memory_->get_primitive_desc();
+ }
+
+ /// Get function for descriptor of user memory.
+ memory::desc GetUsrMemDesc() {
+    // This is ugly. Why doesn't MKL-DNN provide a const desc() method?
+ const memory::primitive_desc pd = GetUsrMemPrimDesc();
+ return const_cast<memory::primitive_desc*>(&pd)->desc();
+ }
+
+ /// Get function for data buffer of user memory primitive.
+ void* GetUsrMemDataHandle() const {
+ CHECK_NOTNULL(user_memory_);
+ return user_memory_->get_data_handle();
+ }
+
+ /// Get the memory primitive for input and output of an op. If inputs
+ /// to an op require reorders, then this function returns memory primitive
+ /// for reorder. Otherwise, it will return memory primitive for user memory.
+ ///
+ /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
+  /// execute Conv2D, we need memory primitives for I and F. But if a reorder is
+ /// required for I and F (say I_r is reorder primitive for I; F_r is reorder
+ /// primitive for F), then we need I_r and F_r to perform Conv2D.
+ const memory& GetOpMem() const {
+ return reorder_memory_ ? *reorder_memory_ : *user_memory_;
+ }
+
+ /// Set memory descriptor of an operation in terms of dimensions and memory
+  /// format. E.g., for Conv2D, the dimensions would be the same as the user dimensions
+ /// but memory::format would be mkldnn::any because we want MKL-DNN to choose
+ /// best layout/format for given input dimensions.
+ void SetOpMemDesc(const memory::dims& dim, memory::format fm) {
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
+ }
+
+ /// Get function for memory descriptor for an operation
+ const memory::desc& GetOpMemDesc() const { return *op_md_; }
+
+ /// Function to handle input reordering
+ ///
+ /// Check if we need to reorder this input of an operation.
+ /// Return true and allocate reorder memory primitive if reorder is needed.
+ /// Otherwise, return false and do not allocate reorder memory primitive.
+ ///
+ /// To check if reorder is needed, this function compares memory primitive
+ /// descriptor of an operation (op_pd) for the given input with the
+ /// user-specified memory primitive descriptor.
+ ///
+ /// @input: op_pd - memory primitive descriptor of the given input of an
+ /// operation
+ /// @input: net - net to which to add reorder primitive in case it is needed.
+ /// @return: true in case reorder of input is needed; false, otherwise.
+ bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
+ std::vector<primitive>* net) {
+ CHECK_NOTNULL(net);
+ CHECK_NOTNULL(user_memory_);
+ if (op_pd != user_memory_->get_primitive_desc()) {
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ reorder_memory_ = new memory(op_pd);
+ net->push_back(reorder(*user_memory_, *reorder_memory_));
+ return true;
+ }
return false;
}
- bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
- 0 == op_name.compare(GetMklOpName("Sub")) ||
- 0 == op_name.compare(GetMklOpName("Mul")) ||
- 0 == op_name.compare(GetMklOpName("Maximum")) ||
- 0 == op_name.compare(GetMklOpName("SquaredDifference")));
+ /// Function to handle output reorder
+ ///
+  /// This function performs functionality very similar to the input reordering
+  /// function above. The only difference is that this function does not add
+ /// reorder primitive to the net. The reason for this is: the reorder
+ /// primitive for output needs to be added to the list only after operation
+ /// has executed. But we need to prepare a temporary buffer in case output
+ /// reorder is needed. And this temporary buffer will hold the output of
+ /// an operation before it is fed to reorder primitive.
+ ///
+ /// @input memory primitive descriptor for the given output of an operation
+ /// @return: true in case reorder of output is needed; false, otherwise.
+ bool PrepareReorderToUserMemIfReq(const memory::primitive_desc& op_pd) {
+ CHECK_NOTNULL(user_memory_);
+ if (op_pd != user_memory_->get_primitive_desc()) {
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ reorder_memory_ = new memory(op_pd);
+ return true;
+ }
+ return false;
+ }
- VLOG(1) << "mkl_op_registry::" << op_name
- << " is elementwise MKL op: " << result;
- return result;
-}
+ /// Function to actually insert reorder primitive in the net
+ ///
+  /// This function completes the remaining part of output reordering. It inserts
+ /// a reordering primitive from the temporary buffer that holds the output
+ /// to the user-specified output buffer.
+ ///
+ /// @input: net - net to which to add reorder primitive
+ void InsertReorderToUserMem(std::vector<primitive>* net) {
+ CHECK_NOTNULL(net);
+ CHECK_NOTNULL(user_memory_);
+ CHECK_NOTNULL(reorder_memory_);
+ net->push_back(reorder(*reorder_memory_, *user_memory_));
+ }
+};
-} // namespace mkl_op_registry
+#endif // INTEL_MKL_DNN
} // namespace tensorflow
#endif // INTEL_MKL
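A short usage sketch of the reorder flow that the MklDnnData&lt;T&gt; helper above wraps (SetUsrMem, SetOpMemDesc, CheckReorderToOpMem, InsertReorderToUserMem). The snippet uses the MKL-DNN 0.x C++ API directly rather than the helper class, and the dims, buffer contents, and layout choice are illustrative assumptions, not part of this patch:

    #include <vector>
    #include "mkldnn.hpp"
    using namespace mkldnn;

    int main() {
      engine cpu_engine(engine::cpu, 0);
      // Illustrative {N, C, H, W} dims; user buffer holds NHWC data.
      memory::dims nchw_dims = {2, 3, 4, 4};
      std::vector<float> user_buf(2 * 3 * 4 * 4, 1.0f);

      // User memory described as NHWC; the op wants NCHW, so a reorder is needed.
      auto user_md = memory::desc(nchw_dims, memory::data_type::f32, memory::format::nhwc);
      memory user_mem(memory::primitive_desc(user_md, cpu_engine), user_buf.data());

      auto op_md = memory::desc(nchw_dims, memory::data_type::f32, memory::format::nchw);
      memory::primitive_desc op_pd(op_md, cpu_engine);

      std::vector<primitive> net;
      memory op_mem(op_pd);                      // temporary buffer for reordered data
      net.push_back(reorder(user_mem, op_mem));  // the step CheckReorderToOpMem performs
      stream(stream::kind::eager).submit(net).wait();
      return 0;
    }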
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index d8925d3909..e6a4088656 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -429,3 +429,41 @@ Stack Overflow and specify the `tensorflow` tag.
<pre>ImportError: cannot import name pywrap_tensorflow</pre></td>
</tr>
</table>
+
+## Tested source configurations
+**Linux**
+<table>
+<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
+<tr><td>tensorflow-1.3.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.5</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.3.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.5</td><td>6</td><td>8</td></tr>
+<tr><td>tensorflow-1.2.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.5</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.2.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.5</td><td>5.1</td><td>8</td></tr>
+<tr><td>tensorflow-1.1.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.2</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.1.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.2</td><td>5.1</td><td>8</td></tr>
+<tr><td>tensorflow-1.0.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.2</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.0.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.4.2</td><td>5.1</td><td>8</td></tr>
+</table>
+
+**Mac**
+<table>
+<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
+<tr><td>tensorflow-1.3.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from Xcode</td><td>Bazel 0.4.5</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow-1.2.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from Xcode</td><td>Bazel 0.4.5</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow-1.1.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from Xcode</td><td>Bazel 0.4.2</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.1.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>Clang from Xcode</td><td>Bazel 0.4.2</td><td>5.1</td><td>8</td></tr>
+<tr><td>tensorflow-1.0.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from Xcode</td><td>Bazel 0.4.2</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.0.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>Clang from Xcode</td><td>Bazel 0.4.2</td><td>5.1</td><td>8</td></tr>
+</table>
+
+**Windows**
+<table>
+<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
+<tr><td>tensorflow-1.3.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>CMake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.3.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>CMake v3.6.3</td><td>6</td><td>8</td></tr>
+<tr><td>tensorflow-1.2.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>CMake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.2.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>CMake v3.6.3</td><td>5.1</td><td>8</td></tr>
+<tr><td>tensorflow-1.1.0</td><td>CPU</td><td>3.5</td><td>MSVC 2015 update 3</td><td>CMake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.1.0</td><td>GPU</td><td>3.5</td><td>MSVC 2015 update 3</td><td>CMake v3.6.3</td><td>5.1</td><td>8</td></tr>
+<tr><td>tensorflow-1.0.0</td><td>CPU</td><td>3.5</td><td>MSVC 2015 update 3</td><td>CMake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.0.0</td><td>GPU</td><td>3.5</td><td>MSVC 2015 update 3</td><td>CMake v3.6.3</td><td>5.1</td><td>8</td></tr>
+</table>
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java
index eb4dc69d63..184df1bdb4 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java
@@ -37,6 +37,7 @@ import android.content.pm.PackageManager;
import android.media.AudioFormat;
import android.media.AudioRecord;
import android.media.MediaRecorder;
+import android.os.Build;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
@@ -151,12 +152,15 @@ public class SpeechActivity extends Activity {
// Start the recording and recognition threads.
requestMicrophonePermission();
+ startRecording();
startRecognition();
}
private void requestMicrophonePermission() {
- requestPermissions(
- new String[] {android.Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO);
+ if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
+ requestPermissions(
+ new String[]{android.Manifest.permission.RECORD_AUDIO}, REQUEST_RECORD_AUDIO);
+ }
}
@Override
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index 6d98c7b85d..1fa2b14869 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -89,7 +89,7 @@ def build_dataset(words, n_words):
# Filling 4 global variables:
# data - list of codes (integers from 0 to vocabulary_size-1).
# This is the original text but words are replaced by their codes
-# count - map of words(strings) to count of occurences
+# count - map of words(strings) to count of occurrences
# dictionary - map of words(strings) to their codes(integers)
# reverse_dictionary - maps codes(integers) to words(strings)
data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
diff --git a/tensorflow/go/example_inception_inference_test.go b/tensorflow/go/example_inception_inference_test.go
index 2162fbe484..f84a588899 100644
--- a/tensorflow/go/example_inception_inference_test.go
+++ b/tensorflow/go/example_inception_inference_test.go
@@ -28,8 +28,8 @@ import (
"os"
"path/filepath"
- "github.com/tensorflow/tensorflow/tensorflow/go/op"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
+ "github.com/tensorflow/tensorflow/tensorflow/go/op"
)
func Example() {
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index a534a0d659..e8fa21a62b 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -92,7 +92,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
raw := tensorData(t.c)
buf := bytes.NewBuffer(raw[:0:len(raw)])
if dataType != String {
- if err := encodeTensor(buf, val); err != nil {
+ if err := encodeTensor(buf, val, shape); err != nil {
return nil, err
}
if uintptr(buf.Len()) != nbytes {
@@ -100,7 +100,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
}
} else {
e := stringEncoder{offsets: buf, data: raw[nflattened*8 : len(raw)], status: newStatus()}
- if err := e.encode(reflect.ValueOf(value)); err != nil {
+ if err := e.encode(reflect.ValueOf(value), shape); err != nil {
return nil, err
}
if int64(buf.Len()) != nflattened*8 {
@@ -236,17 +236,11 @@ func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err erro
typ := val.Type()
for typ.Kind() == reflect.Array || typ.Kind() == reflect.Slice {
shape = append(shape, int64(val.Len()))
- // If slice elements are slices, verify that all of them have the same size.
- // Go's type system makes that guarantee for arrays.
if val.Len() > 0 {
- if val.Type().Elem().Kind() == reflect.Slice {
- expected := val.Index(0).Len()
- for i := 1; i < val.Len(); i++ {
- if val.Index(i).Len() != expected {
- return shape, dt, fmt.Errorf("mismatched slice lengths: %d and %d", val.Index(i).Len(), expected)
- }
- }
- }
+			// Checking the tensor structure properly in the general case requires iterating over all
+			// slices of the tensor to verify that their sizes match. Since encodeTensor() already
+			// iterates over all elements, we 1) do the actual check there to save some CPU cycles here,
+			// and 2) assume the shape is given by the lengths of the zero-index elements in each dimension.
val = val.Index(0)
}
typ = typ.Elem()
@@ -302,7 +296,7 @@ func byteSizeOfEncodedStrings(val interface{}) uintptr {
// encodeTensor writes v to the specified buffer using the format specified in
// c_api.h. Use stringEncoder for String tensors.
-func encodeTensor(w *bytes.Buffer, v reflect.Value) error {
+func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
switch v.Kind() {
case reflect.Bool:
b := byte(0)
@@ -318,19 +312,18 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value) error {
}
case reflect.Array, reflect.Slice:
- // If slice elements are slices, verify that all of them have the same size.
+		// If the current dimension is a slice, verify that it has the expected size.
// Go's type system makes that guarantee for arrays.
- if v.Len() > 0 && v.Type().Elem().Kind() == reflect.Slice {
- expected := v.Index(0).Len()
- for i := 1; i < v.Len(); i++ {
- if v.Index(i).Len() != expected {
- return fmt.Errorf("mismatched slice lengths: %d and %d", v.Index(i).Len(), expected)
- }
+ if v.Kind() == reflect.Slice {
+ expected := int(shape[0])
+ if v.Len() != expected {
+ return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
}
}
+ subShape := shape[1:]
for i := 0; i < v.Len(); i++ {
- err := encodeTensor(w, v.Index(i))
+ err := encodeTensor(w, v.Index(i), subShape)
if err != nil {
return err
}
@@ -379,7 +372,7 @@ type stringEncoder struct {
status *status
}
-func (e *stringEncoder) encode(v reflect.Value) error {
+func (e *stringEncoder) encode(v reflect.Value, shape []int64) error {
if v.Kind() == reflect.String {
if err := binary.Write(e.offsets, nativeEndian, e.offset); err != nil {
return err
@@ -395,8 +388,17 @@ func (e *stringEncoder) encode(v reflect.Value) error {
C.free(unsafe.Pointer(src))
return e.status.Err()
}
+
+ if v.Kind() == reflect.Slice {
+ expected := int(shape[0])
+ if v.Len() != expected {
+ return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
+ }
+ }
+
+ subShape := shape[1:]
for i := 0; i < v.Len(); i++ {
- if err := e.encode(v.Index(i)); err != nil {
+ if err := e.encode(v.Index(i), subShape); err != nil {
return err
}
}
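The Go change above validates every slice against the expected length taken from the precomputed shape, instead of comparing siblings against element 0. A minimal C++ analogue of that check, with hypothetical names and a fixed two-level nesting for brevity:

    #include <cstdio>
    #include <vector>

    // Check a two-level nested container against a declared shape
    // (hypothetical helper; the real code recurses over arbitrary rank).
    bool MatchesShape(const std::vector<std::vector<int>>& v,
                      const std::vector<long long>& shape) {
      if (shape.size() != 2 || static_cast<long long>(v.size()) != shape[0]) return false;
      for (const auto& row : v)
        if (static_cast<long long>(row.size()) != shape[1]) return false;
      return true;
    }

    int main() {
      std::vector<std::vector<int>> ok = {{1, 2}, {3, 4}, {5, 6}};
      std::vector<std::vector<int>> bad = {{1, 2, 3}, {4}};  // mismatched lengths
      std::printf("ok: %d  bad: %d\n", MatchesShape(ok, {3, 2}), MatchesShape(bad, {2, 2}));
      return 0;
    }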
diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go
index 2fc7553f87..35bd2fd9a5 100644
--- a/tensorflow/go/tensor_test.go
+++ b/tensorflow/go/tensor_test.go
@@ -42,6 +42,10 @@ func TestNewTensor(t *testing.T) {
{[]int64{2}, []bool{true, false}},
{[]int64{1}, []float64{1}},
{[]int64{1}, [1]float64{1}},
+ {[]int64{1, 1}, [1][1]float64{{1}}},
+ {[]int64{1, 1, 1}, [1][1][]float64{{{1}}}},
+ {[]int64{1, 1, 2}, [1][][2]float64{{{1, 2}}}},
+ {[]int64{1, 1, 1, 1}, [1][][1][]float64{{{{1}}}}},
{[]int64{2}, []string{"string", "slice"}},
{[]int64{2}, [2]string{"string", "array"}},
{[]int64{3, 2}, [][]float64{{1, 2}, {3, 4}, {5, 6}}},
@@ -74,6 +78,12 @@ func TestNewTensor(t *testing.T) {
[]uint64{5},
// Mismatched dimensions
[][]float32{{1, 2, 3}, {4}},
+ // Mismatched dimensions. Should return "mismatched slice lengths" error instead of "BUG"
+ [][][]float32{{{1, 2}, {3, 4}}, {{1}, {3}}},
+ // Mismatched dimensions. Should return error instead of valid tensor
+ [][][]float32{{{1, 2}, {3, 4}}, {{1}, {3}}, {{1, 2, 3}, {2, 3, 4}}},
+ // Mismatched dimensions for strings
+ [][]string{{"abc"}, {"abcd", "abcd"}},
}
for _, test := range tests {
diff --git a/tensorflow/java/src/gen/perl/tftypes-runall.pl b/tensorflow/java/src/gen/perl/tftypes-runall.pl
index 258c1ff836..a451ce92aa 100644
--- a/tensorflow/java/src/gen/perl/tftypes-runall.pl
+++ b/tensorflow/java/src/gen/perl/tftypes-runall.pl
@@ -37,4 +37,4 @@ sub locchk {
&locchk("$rsrc/tftypes.csv");
system("perl $dir/tftypes.pl -t $rsrc/tftypes.csv $pkg/types");
-# system("perl $dir/tftypes.pl -c $rsrc/tftypes.csv $rsrc/Tensors.java.tmpl > $pkg/op/Tensors.java");
+system("perl $dir/tftypes.pl -c $rsrc/tftypes.csv $rsrc/Tensors.java.tmpl > $pkg/Tensors.java");
diff --git a/tensorflow/java/src/gen/perl/tftypes.pl b/tensorflow/java/src/gen/perl/tftypes.pl
index 86867335cb..115723ac8a 100644
--- a/tensorflow/java/src/gen/perl/tftypes.pl
+++ b/tensorflow/java/src/gen/perl/tftypes.pl
@@ -75,15 +75,23 @@ open (TYPEDESC, $typedesc);
my @info = ([]);
+sub trim {
+ (my $ret) = @_;
+ $ret =~ s/^\s*//g;
+ $ret =~ s/\s*$//g;
+ return $ret;
+}
+
while (<TYPEDESC>) {
chomp;
my $line = $_;
if ($line =~ m/^TF type/) { next }
$line =~ s/\r$//;
- (my $name, my $jtype, my $creat, my $default, my $desc) =
- split /,/, $line, 5;
- $desc =~ s/^ *//g;
- $desc =~ s/ *$//g;
+ my @items = split /,/, $line, 6;
+ for (my $i = 0; $i <= $#items; $i++) {
+ $items[$i] = trim $items[$i];
+ }
+ my $jtype = $items[2];
$jtypecount{$jtype}++;
if ($jtypecount{$jtype} > 1) {
# currently allowing Java types to stand for more than one TF type, but
@@ -92,63 +100,85 @@ while (<TYPEDESC>) {
# exit 1
}
- push @info, [$name, $jtype, $creat, $default, $desc];
+ push @info, \@items;
+}
+
+sub article {
+ (my $s) = @_;
+ if (substr($s, 0, 1) =~ m/^[aeoiu8]$/i) {
+ return "an $s"
+ } else {
+ return "a $s"
+ }
}
for (my $i = 1; $i <= $#info; $i++) {
- (my $name, my $jtype, my $creat, my $default, my $desc) =
+ (my $name, my $builtin, my $jtype, my $creat, my $default, my $desc) =
@{$info[$i]};
- my $tfname = "TF".$name;
+ my $tfname = $name;
my $ucname = uc $name;
+ print STDERR "$name $desc\n";
+
if ($option eq '-t') {
if ($jtype eq '') { next }
+ if ($builtin eq 'y') { next }
# Generate class declarations
# print STDERR "Creating $dirname/$tfname.java\n";
open (CLASSFILE, ">$dirname/$tfname.java") || die "Can't open $tfname.java";
- print CLASSFILE $copyright;
- print CLASSFILE "// GENERATED FILE. To update, edit tftypes.pl instead.\n\n";
-
- my $fulldesc = $desc;
- if (substr($desc, 0, 1) =~ m/^[aeoiu8]$/i) {
- $fulldesc = "an $desc"
- } else {
- $fulldesc = "a $desc"
- }
- print CLASSFILE "package org.tensorflow.types;\n\n"
- ."import org.tensorflow.DataType;\n\n";
+ print CLASSFILE $copyright, "\n";
+ # print CLASSFILE "// GENERATED FILE. To update, edit tftypes.pl instead.\n\n";
+
+ my $fulldesc = article($desc);
+ print CLASSFILE "package org.tensorflow.types;\n\n";
print CLASSFILE "/** Represents $fulldesc. */\n"
- ."public class $tfname implements TFType {\n"
- ." private $tfname() {}\n"
- ." static {\n"
- ." Types.typeCodes.put($tfname.class, DataType.$ucname);\n"
- ." }\n";
- if ($default ne '') {
- print CLASSFILE
- " static {\n"
- ." Types.scalars.put($tfname.class, $default);\n"
- ." }\n";
- }
- print CLASSFILE "}\n";
+ ."public class $tfname {\n"
+ ." private $tfname() {\n"
+ ." }\n"
+ ."}\n";
close(CLASSFILE);
} elsif ($option eq '-c') {
# Generate creator declarations for Tensors.java
if ($jtype ne '' && $creat eq 'y') {
- for (my $brackets = ''; length $brackets <= 12; $brackets .= '[]') {
+ for (my $brackets = '', my $rank = 0; length $brackets <= 12; $brackets .= '[]', $rank++) {
+ my $datainfo = " * \@param data An array containing the values to put into the new tensor.\n"
+ ." * The dimensions of the new tensor will match those of the array.\n";
+ if ($rank == 0) {
+ $datainfo = " * \@param data The value to put into the new scalar tensor.\n"
+ }
+
+ my $trank = $rank;
+ if ($tfname eq 'String') {
+ $trank = $rank-1;
+ next if $trank < 0;
+
+ $datainfo = " * \@param data An array containing the data to put into the new tensor.\n"
+ ." * String elements are sequences of bytes from the last array dimension.\n";
+ }
+
+
+ my $intro = ($trank > 0)
+ ? "Creates a rank-$trank tensor of {\@code $jtype} elements."
+ : "Creates a scalar tensor containing a single {\@code $jtype} element.";
$typeinfo .=
- " public static Tensor<$tfname> create($jtype$brackets data) {\n"
- ." return Tensor.create(data, $tfname.class);\n"
- ." }\n";
+ " /**\n"
+ ." * $intro\n"
+ ." * \n"
+ .$datainfo
+ ." */\n"
+ ." public static Tensor<$tfname> create($jtype$brackets data) {\n"
+ ." return Tensor.create(data, $tfname.class);\n"
+ ." }\n\n";
}
}
- if ($text =~ m/\b$tfname\b/ || $creat eq 'y') {
+ if ($text =~ m/\b$tfname\b/ && $builtin eq 'n' && $creat eq 'y') {
$imports .= "import org.tensorflow.types.$tfname;\n";
}
}
}
if ($option ne '-t') {
- print "// GENERATED FILE. Edits to this file will be lost -- edit $tmpl instead.\n";
+# print "// GENERATED FILE. Edits to this file will be lost -- edit $tmpl instead.\n";
$text =~ s/\@TYPEINFO\@/$typeinfo/;
$text =~ s/\@IMPORTS\@/$imports/;
diff --git a/tensorflow/java/src/gen/resources/Tensors.java.tmpl b/tensorflow/java/src/gen/resources/Tensors.java.tmpl
new file mode 100644
index 0000000000..98e1588559
--- /dev/null
+++ b/tensorflow/java/src/gen/resources/Tensors.java.tmpl
@@ -0,0 +1,31 @@
+package org.tensorflow;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+import org.tensorflow.Tensor;
+@IMPORTS@
+
+/**
+ * Type-safe factory methods for creating {@link Tensor} objects.
+ */
+public final class Tensors {
+ private Tensors() {}
+
+ /** Creates a scalar String tensor using the default, UTF-8 encoding.
+ *
+ * @param data The string to put into the new scalar tensor.
+ */
+ public static Tensor<String> create(String data) {
+ return Tensor.create(data.getBytes(UTF_8), String.class);
+ }
+
+ /** Creates a scalar String tensor using a specified encoding.
+ *
+ * @param charset The encoding from String to bytes.
+ * @param data The string to put into the new scalar tensor.
+ */
+ public static Tensor<String> create(String data, java.nio.charset.Charset charset) {
+ return Tensor.create(data.getBytes(charset), String.class);
+ }
+
+@TYPEINFO@}
+
diff --git a/tensorflow/java/src/gen/resources/tftypes.csv b/tensorflow/java/src/gen/resources/tftypes.csv
index 88acaafd3c..6f26230f27 100644
--- a/tensorflow/java/src/gen/resources/tftypes.csv
+++ b/tensorflow/java/src/gen/resources/tftypes.csv
@@ -1,21 +1,21 @@
-TF type,Java type,Creator?,Zero value,Description
-Float,float,y,0f,32-bit single precision floating point number
-Double,double,y,0.0,64-bit double precision floating point number
-Int32,int,y,0,32-bit signed integer
-UInt8,byte,n,(byte)0,8-bit unsigned integer
-Int16,,n,(short)0,16-bit signed integer
-Int8,,n,(byte)0,8-bit signed integer
-String,byte,n,,arbitrary sequence of bytes
-Complex64,,n,,single-precision complex number
-Int64,long,y,0L,64-bit signed integer
-Bool,boolean,y,false,boolean
-QInt8,,n,,quantized int8
-QUInt8,,n,,quantized uint8
-QInt32,,n,,quantized int32
-BFloat16,,n,,float32 truncated to 16 bits. Only for cast ops.
-QInt16,,n,,quantized int16
-QUInt16,,n,,quantized uint16
-UInt16,,n,,16-bit unsigned integer
-Complex128,,n,,double-precision complex number
-Half,,n,,
-Resource,,n,,
+TF type,Builtin,Java type,Creator?,Zero value,Description
+Float,y,float,y,0f,32-bit single precision floating point number
+Double,y,double,y,0.0,64-bit double precision floating point number
+Integer,y,int,y,0,32-bit signed integer
+UInt8,n,byte,n,(byte)0,8-bit unsigned integer
+Short,y,,n,(short)0,16-bit signed integer
+Byte,y,,n,(byte)0,8-bit signed integer
+String,y,byte,y,,arbitrary sequence of bytes
+Complex64,n,,n,,single-precision complex number
+Long,y,long,y,0L,64-bit signed integer
+Boolean,y,boolean,y,false,boolean
+QInt8,n,,n,,quantized int8
+QUInt8,n,,n,,quantized uint8
+QInt32,n,,n,,quantized int32
+BFloat16,n,,n,,float32 truncated to 16 bits. Only for cast ops.
+QInt16,n,,n,,quantized int16
+QUInt16,n,,n,,quantized uint16
+UInt16,n,,n,,16-bit unsigned integer
+Complex128,n,,n,,double-precision complex number
+Half,n,,n,,
+Resource,n,,n,,
diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
index e67e266ff7..e835101d08 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
@@ -15,7 +15,13 @@ limitations under the License.
package org.tensorflow;
-/** Type of elements in a {@link Tensor}. */
+import java.util.HashMap;
+import java.util.Map;
+import org.tensorflow.types.UInt8;
+
+/**
+ * Represents the type of elements in a {@link Tensor} as an enum.
+ */
public enum DataType {
/** 32-bit single precision floating point. */
FLOAT(1),
@@ -55,14 +61,41 @@ public enum DataType {
}
// Cached to avoid copying it
- final private static DataType[] values = values();
+ private static final DataType[] values = values();
static DataType fromC(int c) {
for (DataType t : values) {
- if (t.value == c)
+ if (t.value == c) {
return t;
+ }
}
throw new IllegalArgumentException(
"DataType " + c + " is not recognized in Java (version " + TensorFlow.version() + ")");
}
+
+ /**
+ * Returns the DataType of a Tensor whose elements have the type specified by class {@code c}.
+ *
+ * @param c The class describing the TensorFlow type of interest.
+ */
+ public static DataType fromClass(Class<?> c) {
+ DataType dtype = typeCodes.get(c);
+ if (dtype == null) {
+ throw new IllegalArgumentException(
+ c.getName() + " objects cannot be used as elements in a TensorFlow Tensor");
+ }
+ return dtype;
+ }
+
+ private static final Map<Class<?>, DataType> typeCodes = new HashMap<>();
+
+ static {
+ typeCodes.put(Float.class, DataType.FLOAT);
+ typeCodes.put(Double.class, DataType.DOUBLE);
+ typeCodes.put(Integer.class, DataType.INT32);
+ typeCodes.put(UInt8.class, DataType.UINT8);
+ typeCodes.put(Long.class, DataType.INT64);
+ typeCodes.put(Boolean.class, DataType.BOOL);
+ typeCodes.put(String.class, DataType.STRING);
+ }
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 58ad3ab193..d4fd3db5f7 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -81,8 +81,8 @@ public final class Graph implements AutoCloseable {
/**
* Iterator over all the {@link Operation}s in the graph.
*
- * The order of iteration is unspecified. Consumers of the iterator will received no notification
- * should the underlying graph change during iteration.
+ * <p>The order of iteration is unspecified. Consumers of the iterator will receive no
+ * notification should the underlying graph change during iteration.
*/
public Iterator<Operation> operations() {
return new OperationIterator(this);
@@ -245,7 +245,8 @@ public final class Graph implements AutoCloseable {
private static native long operation(long handle, String name);
- // This method returns the Operation native handle at index 0 and the new value for pos at index 1 (see TF_GraphNextOperation)
+ // This method returns the Operation native handle at index 0 and the new value for pos at index 1
+ // (see TF_GraphNextOperation)
private static native long[] nextOperation(long handle, int position);
private static native void importGraphDef(long handle, byte[] graphDef, String prefix)
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Input.java b/tensorflow/java/src/main/java/org/tensorflow/Input.java
index 8e6685ee0f..13bc463e7d 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Input.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Input.java
@@ -34,7 +34,7 @@ package org.tensorflow;
* ops.array().concat(0, split);
* }</pre>
*/
-public interface Input {
+public interface Input<T> {
/**
* Returns the symbolic handle of a tensor.
@@ -44,5 +44,5 @@ public interface Input {
*
* @see OperationBuilder#addInput(Output)
*/
- Output asOutput();
+ Output<T> asOutput();
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
index d2d019babb..2b431eebf5 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
@@ -122,8 +122,7 @@ final class NativeLibrary {
}
private static String extractResource(
- InputStream resource, String resourceName, String extractToDirectory)
- throws IOException {
+ InputStream resource, String resourceName, String extractToDirectory) throws IOException {
final File dst = new File(extractToDirectory, System.mapLibraryName(resourceName));
dst.deleteOnExit();
final String dstPath = dst.toString();
@@ -184,8 +183,7 @@ final class NativeLibrary {
// compatibility.
private static File createTemporaryDirectory() {
File baseDirectory = new File(System.getProperty("java.io.tmpdir"));
- String directoryName
- = "tensorflow_native_libraries-" + System.currentTimeMillis() + "-";
+ String directoryName = "tensorflow_native_libraries-" + System.currentTimeMillis() + "-";
for (int attempt = 0; attempt < 1000; attempt++) {
File temporaryDirectory = new File(baseDirectory, directoryName + attempt);
if (temporaryDirectory.mkdir()) {
@@ -194,7 +192,8 @@ final class NativeLibrary {
}
throw new IllegalStateException(
"Could not create a temporary directory (tried to make "
- + directoryName + "*) to extract TensorFlow native libraries.");
+ + directoryName
+ + "*) to extract TensorFlow native libraries.");
}
private NativeLibrary() {}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Operand.java b/tensorflow/java/src/main/java/org/tensorflow/Operand.java
index 695c4c1060..61082e83d5 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Operand.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Operand.java
@@ -22,19 +22,19 @@ package org.tensorflow;
*
* <pre>{@code
* // The "decodeJpeg" operation can be used as an operand to the "cast" operation
- * Operand decodeJpeg = ops.image().decodeJpeg(...);
+ * Operand<UInt8> decodeJpeg = ops.image().decodeJpeg(...);
* ops.math().cast(decodeJpeg, DataType.FLOAT);
*
* // The output "y" of the "unique" operation can be used as an operand to the "cast" operation
- * Output y = ops.array().unique(...).y();
- * ops.math().cast(y, DataType.FLOAT);
+ * Output<Integer> y = ops.array().unique(...).y();
+ * ops.math().cast(y, Float.class);
*
* // The "split" operation can be used as operand list to the "concat" operation
- * Iterable<? extends Operand> split = ops.array().split(...);
+ * Iterable<? extends Operand<Float>> split = ops.array().split(...);
* ops.array().concat(0, split);
* }</pre>
*/
-public interface Operand {
+public interface Operand<T> {
/**
* Returns the symbolic handle of a tensor.
@@ -44,5 +44,5 @@ public interface Operand {
*
* @see OperationBuilder#addInput(Output)
*/
- Output asOutput();
+ Output<T> asOutput();
}
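
To show what the new type parameter buys, here is a hypothetical helper (not part of this change) whose signature constrains callers to float-valued operands; asOutput() now carries that element type through:

    import org.tensorflow.Operand;
    import org.tensorflow.Output;

    final class OperandSketch {
      // The compiler rejects an Operand<Integer> argument here; with the raw
      // Operand interface this mistake would only surface at graph-execution time.
      static Output<Float> firstAsOutput(Iterable<? extends Operand<Float>> operands) {
        for (Operand<Float> operand : operands) {
          return operand.asOutput();
        }
        throw new IllegalArgumentException("no operands given");
      }

      private OperandSketch() {}
    }
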
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Operation.java b/tensorflow/java/src/main/java/org/tensorflow/Operation.java
index ec26309fba..6b82e5780b 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Operation.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Operation.java
@@ -98,16 +98,26 @@ public final class Operation {
* @param length number of tensors in the list
* @return array of {@code Output}
*/
- public Output[] outputList(int idx, int length) {
- Output[] outputs = new Output[length];
+ public Output<?>[] outputList(int idx, int length) {
+ Output<?>[] outputs = new Output<?>[length];
for (int i = 0; i < length; ++i) {
outputs[i] = output(idx + i);
}
return outputs;
}
- /** Returns a symbolic handle to one of the tensors produced by this operation. */
- public Output output(int idx) {
+ /**
+ * Returns a symbolic handle to one of the tensors produced by this operation.
+ *
+ * <p>Warning: Does not check that the type of the tensor matches T. It is recommended to call
+ * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
+ * operation.<Integer>output(0)}
+ *
+ * @param <T> The expected element type of the tensors produced by this output.
+ * @param idx The index of the output among the outputs produced by this operation.
+ */
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ public <T> Output<T> output(int idx) {
return new Output(this, idx);
}
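
A small sketch of the recommendation in the new output() Javadoc: state the type argument explicitly, since nothing checks it at runtime. The Const construction is only there to give output() something to return, and uses the builder calls that appear elsewhere in this change:

    import org.tensorflow.Graph;
    import org.tensorflow.Operation;
    import org.tensorflow.Output;
    import org.tensorflow.Tensor;
    import org.tensorflow.Tensors;

    public class OutputTypeSketch {
      public static void main(String[] args) {
        try (Graph g = new Graph(); Tensor<Integer> value = Tensors.create(42)) {
          Operation op = g.opBuilder("Const", "answer")
              .setAttr("dtype", value.dataType())
              .setAttr("value", value)
              .build();
          // Explicit type argument, as the Javadoc suggests; an inferred or wrong
          // choice of T would compile just as happily.
          Output<Integer> out = op.<Integer>output(0);
          System.out.println(out.op().name() + ":" + out.index());
        }
      }
    }
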
diff --git a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java
index 15077ce439..9a1b7592b3 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java
@@ -63,7 +63,6 @@ public final class OperationBuilder {
}
}
-
/**
* Returns the builder to create an operation.
*
@@ -73,7 +72,7 @@ public final class OperationBuilder {
* @param input {@link Output} supposed to be the input of the OperationBuilder.
* @return the OperationBuilder instance for chaining.
*/
- public OperationBuilder addInput(Output input) {
+ public OperationBuilder addInput(Output<?> input) {
Graph.Reference r = graph.ref();
try {
addInput(unsafeNativeHandle, input.op().getUnsafeNativeHandle(), input.index());
@@ -106,7 +105,7 @@ public final class OperationBuilder {
return this;
}
- public OperationBuilder addInputList(Output[] inputs) {
+ public OperationBuilder addInputList(Output<?>[] inputs) {
Graph.Reference r = graph.ref();
try {
long[] opHandles = new long[inputs.length];
@@ -231,7 +230,7 @@ public final class OperationBuilder {
return this;
}
- public OperationBuilder setAttr(String name, Tensor value) {
+ public OperationBuilder setAttr(String name, Tensor<?> value) {
Graph.Reference r = graph.ref();
try {
setAttrTensor(unsafeNativeHandle, name, value.getNativeHandle());
@@ -241,10 +240,10 @@ public final class OperationBuilder {
return this;
}
- public OperationBuilder setAttr(String name, Tensor[] value) {
+ public OperationBuilder setAttr(String name, Tensor<?>[] value) {
long[] handles = new long[value.length];
int idx = 0;
- for (Tensor t : value) {
+ for (Tensor<?> t : value) {
handles[idx++] = t.getNativeHandle();
}
Graph.Reference r = graph.ref();
@@ -266,7 +265,7 @@ public final class OperationBuilder {
return this;
}
- public OperationBuilder setAttr(String name, String[] value) {
+ public OperationBuilder setAttr(String name, String[] value) {
Charset utf8 = Charset.forName("UTF-8");
Object[] objects = new Object[value.length];
for (int i = 0; i < value.length; ++i) {
@@ -326,5 +325,4 @@ public final class OperationBuilder {
private static native void setAttrShape(long handle, String name, long[] shape, int numDims);
private static native void setAttrStringList(long handle, String name, Object[] value);
-
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Output.java b/tensorflow/java/src/main/java/org/tensorflow/Output.java
index 8dff50fafb..0e17a722ff 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Output.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Output.java
@@ -20,13 +20,13 @@ import java.util.Objects;
/**
* A symbolic handle to a tensor produced by an {@link Operation}.
*
- * <p>An Output is a symbolic handle to a tensor. The value of the Tensor is computed by executing
- * the {@link Operation} in a {@link Session}.
+ * <p>An Output<T> is a symbolic handle to a Tensor<T>. The value of the tensor is computed by
+ * executing the {@link Operation} in a {@link Session}.
*
* <p>By implementing the {@link Operand} interface, instances of this class also act as operands to
* {@link org.tensorflow.op.Op Op} instances.
*/
-public final class Output implements Operand {
+public final class Output<T> implements Operand<T> {
/** Handle to the idx-th output of the Operation {@code op}. */
public Output(Operation op, int idx) {
@@ -55,7 +55,7 @@ public final class Output implements Operand {
}
@Override
- public Output asOutput() {
+ public Output<T> asOutput() {
return this;
}
@@ -69,8 +69,8 @@ public final class Output implements Operand {
if (o == this) {
return true;
}
- if (o instanceof Output) {
- Output that = (Output) o;
+ if (o instanceof Output<?>) {
+ Output<?> that = (Output<?>) o;
return index == that.index && operation.equals(that.operation);
}
return false;
diff --git a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
index b4591dd869..c8b9126f03 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
@@ -27,8 +27,9 @@ package org.tensorflow;
public class SavedModelBundle implements AutoCloseable {
/**
- * Load a saved model from an export directory. The model that is being loaded should be created using
- * the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model API</a>.
+ * Load a saved model from an export directory. The model that is being loaded should be created
+ * using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
+ * API</a>.
*
* @param exportDir the directory path containing a saved model.
* @param tags the tags identifying the specific metagraphdef to load.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Session.java b/tensorflow/java/src/main/java/org/tensorflow/Session.java
index 83a300a560..73324f23e6 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Session.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Session.java
@@ -127,7 +127,7 @@ public final class Session implements AutoCloseable {
* {@code SignatureDef} protocol buffer messages that are included in {@link
* SavedModelBundle#metaGraphDef()}.
*/
- public Runner feed(String operation, Tensor t) {
+ public Runner feed(String operation, Tensor<?> t) {
return feed(parseOutput(operation), t);
}
@@ -138,7 +138,7 @@ public final class Session implements AutoCloseable {
* <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
* one {@code t} is being provided for.
*/
- public Runner feed(String operation, int index, Tensor t) {
+ public Runner feed(String operation, int index, Tensor<?> t) {
Operation op = operationByName(operation);
if (op != null) {
inputs.add(op.output(index));
@@ -151,7 +151,7 @@ public final class Session implements AutoCloseable {
* Use {@code t} instead of the Tensor referred to by executing the operation referred to by
* {@code output}.
*/
- public Runner feed(Output o, Tensor t) {
+ public Runner feed(Output<?> o, Tensor<?> t) {
inputs.add(o);
inputTensors.add(t);
return this;
@@ -186,7 +186,7 @@ public final class Session implements AutoCloseable {
}
/** Makes {@link #run()} return the Tensor referred to by {@code output}. */
- public Runner fetch(Output output) {
+ public Runner fetch(Output<?> output) {
outputs.add(output);
return this;
}
@@ -240,8 +240,11 @@ public final class Session implements AutoCloseable {
* easier for the caller to cleanup (perhaps returning something like AutoCloseableList in
* SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a
* {@code Map<Output, Tensor>}?
+ *
+ * <p>TODO(andrewmyers): It would also be good if whatever is returned here made it easier to
+ * extract output tensors in a type-safe way.
*/
- public List<Tensor> run() {
+ public List<Tensor<?>> run() {
return runHelper(false).outputs;
}
@@ -269,17 +272,17 @@ public final class Session implements AutoCloseable {
// It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the
// validity of the Graph and graphRef ensures that.
int idx = 0;
- for (Tensor t : inputTensors) {
+ for (Tensor<?> t : inputTensors) {
inputTensorHandles[idx++] = t.getNativeHandle();
}
idx = 0;
- for (Output o : inputs) {
+ for (Output<?> o : inputs) {
inputOpHandles[idx] = o.op().getUnsafeNativeHandle();
inputOpIndices[idx] = o.index();
idx++;
}
idx = 0;
- for (Output o : outputs) {
+ for (Output<?> o : outputs) {
outputOpHandles[idx] = o.op().getUnsafeNativeHandle();
outputOpIndices[idx] = o.index();
idx++;
@@ -306,12 +309,12 @@ public final class Session implements AutoCloseable {
} finally {
runRef.close();
}
- List<Tensor> outputs = new ArrayList<Tensor>();
+ List<Tensor<?>> outputs = new ArrayList<Tensor<?>>();
for (long h : outputTensorHandles) {
try {
outputs.add(Tensor.fromHandle(h));
} catch (Exception e) {
- for (Tensor t : outputs) {
+ for (Tensor<?> t : outputs) {
t.close();
}
outputs.clear();
@@ -355,7 +358,8 @@ public final class Session implements AutoCloseable {
return op;
}
- private Output parseOutput(String opName) {
+ @SuppressWarnings("rawtypes")
+ private Output<?> parseOutput(String opName) {
int colon = opName.lastIndexOf(':');
if (colon == -1 || colon == opName.length() - 1) {
return new Output(operationByName(opName), 0);
@@ -369,9 +373,9 @@ public final class Session implements AutoCloseable {
}
}
- private ArrayList<Output> inputs = new ArrayList<Output>();
- private ArrayList<Tensor> inputTensors = new ArrayList<Tensor>();
- private ArrayList<Output> outputs = new ArrayList<Output>();
+ private ArrayList<Output<?>> inputs = new ArrayList<Output<?>>();
+ private ArrayList<Tensor<?>> inputTensors = new ArrayList<Tensor<?>>();
+ private ArrayList<Output<?>> outputs = new ArrayList<Output<?>>();
private ArrayList<Operation> targets = new ArrayList<Operation>();
private byte[] runOptions = null;
}
@@ -388,7 +392,7 @@ public final class Session implements AutoCloseable {
*/
public static final class Run {
/** Tensors from requested fetches. */
- public List<Tensor> outputs;
+ public List<Tensor<?>> outputs;
/**
* (Experimental): Metadata about the run.
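
Putting the typed Runner pieces together, a minimal sketch of a feed/fetch/run round trip. The graph contents are assumed (a float placeholder named "x" and a rank-1 float output named "y"); the point is that run() now yields List<Tensor<?>> and each element is narrowed with expect():

    import java.nio.FloatBuffer;
    import java.util.List;
    import org.tensorflow.Graph;
    import org.tensorflow.Session;
    import org.tensorflow.Tensor;
    import org.tensorflow.Tensors;

    public class RunnerSketch {
      // Assumes g defines a float placeholder "x" and a rank-1 float output "y".
      static float[] runOnce(Graph g) {
        try (Session s = new Session(g);
             Tensor<Float> x = Tensors.create(new float[] {1f, 2f, 3f})) {
          List<Tensor<?>> results = s.runner().feed("x", x).fetch("y").run();
          try (Tensor<Float> y = results.get(0).expect(Float.class)) {
            float[] copy = new float[(int) y.shape()[0]];
            y.writeTo(FloatBuffer.wrap(copy));
            return copy;
          }
        }
      }
    }
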
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index c5ad1ee51c..d4b753628b 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -28,89 +28,117 @@ import java.util.Arrays;
import java.util.HashMap;
/**
- * A typed multi-dimensional array.
+ * A statically typed multi-dimensional array whose elements are of a type described by T.
*
* <p>Instances of a Tensor are <b>not</b> thread-safe.
*
* <p><b>WARNING:</b> Resources consumed by the Tensor object <b>must</b> be explicitly freed by
* invoking the {@link #close()} method when the object is no longer needed. For example, using a
- * try-with-resources block like:
+ * try-with-resources block:
*
* <pre>{@code
- * try(Tensor t = Tensor.create(...)) {
+ * try (Tensor t = Tensor.create(...)) {
* doSomethingWith(t);
* }
* }</pre>
*/
-public final class Tensor implements AutoCloseable {
+public final class Tensor<T> implements AutoCloseable {
/**
- * Create a Tensor from a Java object.
+ * Creates a Tensor from a Java object.
*
- * <p>A Tensor is a multi-dimensional array of elements of a limited set of types ({@link
- * DataType}). Thus, not all Java objects can be converted to a Tensor. In particular, {@code obj}
- * must be either a primitive (float, double, int, long, boolean) or a multi-dimensional array of
- * one of those primitives. For example:
+ * <p>A {@code Tensor} is a multi-dimensional array of elements of a limited set of types ({@link
+ * types}), so not all Java objects can be converted to a {@code Tensor}. In particular, the
+ * argument {@code obj} must be either a primitive (float, double, int, long, boolean, byte) or a
+ * multi-dimensional array of one of those primitives. The argument {@code type} specifies how to
+ * interpret the first argument as a TensorFlow type. For example:
*
* <pre>{@code
* // Valid: A 64-bit integer scalar.
- * Tensor s = Tensor.create(42L);
+ * Tensor<Long> s = Tensor.create(42L, Long.class);
*
* // Valid: A 3x2 matrix of floats.
* float[][] matrix = new float[3][2];
- * Tensor m = Tensor.create(matrix);
+ * Tensor<Float> m = Tensor.create(matrix, Float.class);
*
* // Invalid: Will throw an IllegalArgumentException as an arbitrary Object
* // does not fit into the TensorFlow type system.
- * Tensor o = Tensor.create(new Object());
+ * Tensor<?> o = Tensor.create(new Object());
*
* // Invalid: Will throw an IllegalArgumentException since there are
* // a differing number of elements in each row of this 2-D array.
* int[][] twoD = new int[2][];
* twoD[0] = new int[1];
* twoD[1] = new int[2];
- * Tensor x = Tensor.create(twoD);
+ * Tensor<Integer> x = Tensor.create(twoD, Integer.class);
* }</pre>
*
- * {@link DataType#STRING} typed Tensors are multi-dimensionary arrays of arbitrary byte sequences
- * and thus have {@code byte[]} and not {@code String}-valued elements. For example:
+ * {@link String}-typed Tensors are multi-dimensional arrays of arbitrary byte sequences, so can
+ * be initialized from arrays of {@code byte[]} elements. For example:
*
* <pre>{@code
- * // Valid: A DataType.STRING tensor.
- * Tensor s = Tensor.create(new byte[]{1, 2, 3});
+ * // Valid: A String tensor.
+ * Tensor<String> s = Tensor.create(new byte[]{1, 2, 3}, String.class);
*
* // Java Strings will need to be encoded into a byte-sequence.
* String mystring = "foo";
- * Tensor s = Tensor.create(mystring.getBytes("UTF-8"));
+ * Tensor<String> s = Tensor.create(mystring.getBytes("UTF-8"), String.class);
*
- * // Valid: Matrix of DataType.STRING tensors.
+ * // Valid: Matrix of String tensors.
* // Each element might have a different length.
* byte[][][] matrix = new byte[2][2][];
* matrix[0][0] = "this".getBytes("UTF-8");
* matrix[0][1] = "is".getBytes("UTF-8");
* matrix[1][0] = "a".getBytes("UTF-8");
* matrix[1][1] = "matrix".getBytes("UTF-8");
- * Tensor m = Tensor.create(matrix);
+ * Tensor<String> m = Tensor.create(matrix, String.class);
* }</pre>
*
+ * @param obj The object to convert to a Tensor<T>. Note that whether it is compatible with the
+ * type T is not checked by the type system. For type-safe creation of tensors, use {@link
+ * Tensors}.
+ * @param type The class object representing the type T.
* @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type
- * system, or if obj does not disambiguate between multiple DataTypes. In that case, consider
- * using {@link #create(DataType, long[], ByteBuffer)} instead.
+ * system.
*/
- public static Tensor create(Object obj) {
+ @SuppressWarnings("unchecked")
+ public static <T> Tensor<T> create(Object obj, Class<T> type) {
+ DataType dtype = DataType.fromClass(type);
+ if (!objectCompatWithType(obj, dtype)) {
+ throw new IllegalArgumentException(
+ "DataType of object does not match T (expected "
+ + dtype
+ + ", got "
+ + dataTypeOf(obj)
+ + ")");
+ }
+ return (Tensor<T>) create(obj, dtype);
+ }
+
+ /**
+ * Creates a tensor from an object whose class is inspected to figure out what the underlying data
+ * type should be.
+ *
+ * @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type
+ * system.
+ */
+ public static Tensor<?> create(Object obj) {
return create(obj, dataTypeOf(obj));
}
/**
- * Create a Tensor of data type {@code dtype} from a Java object.
+ * Creates a Tensor of data type {@code dtype} from a Java object. Requires the type parameter
+ * {@code T} to match {@code dtype}, but this condition is not checked.
*
- * @param dtype the intended tensor data type. It must match the the run-time type of the object.
+ * @param obj the object supplying the tensor data.
+ * @param dtype the data type of the tensor to create. It must be compatible with the run-time
+ * type of the object.
+ * @return the new tensor
*/
- static Tensor create(Object obj, DataType dtype) {
- Tensor t = new Tensor();
- t.dtype = dtype;
+ private static Tensor<?> create(Object obj, DataType dtype) {
+ @SuppressWarnings("rawtypes")
+ Tensor<?> t = new Tensor(dtype);
t.shapeCopy = new long[numDimensions(obj, dtype)];
- assert objectCompatWithType(obj, dtype);
fillShape(obj, 0, t.shapeCopy);
if (t.dtype != DataType.STRING) {
int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
@@ -125,7 +153,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Create an {@link DataType#INT32} Tensor with data from the given buffer.
+ * Create an {@link Integer} Tensor with data from the given buffer.
*
* <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
* current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
@@ -136,14 +164,14 @@ public final class Tensor implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor create(long[] shape, IntBuffer data) {
- Tensor t = allocateForBuffer(DataType.INT32, shape, data.remaining());
+ public static Tensor<Integer> create(long[] shape, IntBuffer data) {
+ Tensor<Integer> t = allocateForBuffer(DataType.INT32, shape, data.remaining());
t.buffer().asIntBuffer().put(data);
return t;
}
/**
- * Create a {@link DataType#FLOAT} Tensor with data from the given buffer.
+ * Create a {@link Float} Tensor with data from the given buffer.
*
* <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
* current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
@@ -154,14 +182,14 @@ public final class Tensor implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor create(long[] shape, FloatBuffer data) {
- Tensor t = allocateForBuffer(DataType.FLOAT, shape, data.remaining());
+ public static Tensor<Float> create(long[] shape, FloatBuffer data) {
+ Tensor<Float> t = allocateForBuffer(DataType.FLOAT, shape, data.remaining());
t.buffer().asFloatBuffer().put(data);
return t;
}
/**
- * Create a {@link DataType#DOUBLE} Tensor with data from the given buffer.
+ * Create a {@link Double} Tensor with data from the given buffer.
*
* <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
* current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
@@ -172,14 +200,14 @@ public final class Tensor implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor create(long[] shape, DoubleBuffer data) {
- Tensor t = allocateForBuffer(DataType.DOUBLE, shape, data.remaining());
+ public static Tensor<Double> create(long[] shape, DoubleBuffer data) {
+ Tensor<Double> t = allocateForBuffer(DataType.DOUBLE, shape, data.remaining());
t.buffer().asDoubleBuffer().put(data);
return t;
}
/**
- * Create an {@link DataType#INT64} Tensor with data from the given buffer.
+ * Create a {@link Long} Tensor with data from the given buffer.
*
* <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
* current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
@@ -190,47 +218,87 @@ public final class Tensor implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor create(long[] shape, LongBuffer data) {
- Tensor t = allocateForBuffer(DataType.INT64, shape, data.remaining());
+ public static Tensor<Long> create(long[] shape, LongBuffer data) {
+ Tensor<Long> t = allocateForBuffer(DataType.INT64, shape, data.remaining());
t.buffer().asLongBuffer().put(data);
return t;
}
/**
- * Create a Tensor with data from the given buffer.
+ * Create a Tensor of any type with data from the given buffer.
+ *
+ * <p>Creates a Tensor with the provided shape of any type where the tensor's data has been
+ * encoded into {@code data} as per the specification of the TensorFlow <a
+ * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>.
+ *
+ * @param <T> the tensor element type
+ * @param type the tensor element type, represented as a class object.
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
+ * buffer
+ */
+ public static <T> Tensor<T> create(Class<T> type, long[] shape, ByteBuffer data) {
+ @SuppressWarnings("unchecked")
+ Tensor<T> ret = (Tensor<T>) create(DataType.fromClass(type), shape, data);
+ return ret;
+ }
+
+ /**
+ * Creates a Tensor of any type with data from the given buffer.
*
* <p>Creates a Tensor with the provided shape of any type where the tensor's data has been
* encoded into {@code data} as per the specification of the TensorFlow <a
* href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>.
*
- * @param dataType the tensor datatype.
+ * @param dtype the tensor element type, specified as a DataType. This must agree with the
+ *     caller's expected element type.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
- public static Tensor create(DataType dataType, long[] shape, ByteBuffer data) {
+ private static Tensor<?> create(DataType dtype, long[] shape, ByteBuffer data) {
int nremaining = 0;
- if (dataType != DataType.STRING) {
- int elemBytes = elemByteSize(dataType);
+ if (dtype != DataType.STRING) {
+ int elemBytes = elemByteSize(dtype);
if (data.remaining() % elemBytes != 0) {
throw new IllegalArgumentException(
String.format(
"ByteBuffer with %d bytes is not compatible with a %s Tensor (%d bytes/element)",
- data.remaining(), dataType.toString(), elemBytes));
+ data.remaining(), dtype.toString(), elemBytes));
}
nremaining = data.remaining() / elemBytes;
} else {
nremaining = data.remaining();
}
- Tensor t = allocateForBuffer(dataType, shape, nremaining);
+ Tensor<?> t = allocateForBuffer(dtype, shape, nremaining);
t.buffer().put(data);
return t;
}
+ /**
+ * Returns this Tensor object with the type {@code Tensor<U>}. This method is useful when given a
+ * value of type {@code Tensor<?>}.
+ *
+ * @param type the class object representing the expected element type {@code U}.
+ * @throws IllegalArgumentException if the actual data type of this object does not match the type
+ * {@code U}.
+ */
+ @SuppressWarnings("unchecked")
+ public <U> Tensor<U> expect(Class<U> type) {
+ DataType dt = DataType.fromClass(type);
+ if (!dt.equals(dtype)) {
+ throw new IllegalArgumentException(
+ "Cannot cast from tensor of " + dtype + " to tensor of " + dt);
+ }
+ return ((Tensor<U>) this);
+ }
+
// Helper function to allocate a Tensor for the create() methods that create a Tensor from
// a java.nio.Buffer.
- private static Tensor allocateForBuffer(DataType dataType, long[] shape, int nBuffered) {
+ // Requires: dataType matches T
+ private static <T> Tensor<T> allocateForBuffer(DataType dataType, long[] shape, int nBuffered) {
final int nflattened = numElements(shape);
int nbytes = 0;
if (dataType != DataType.STRING) {
@@ -242,8 +310,7 @@ public final class Tensor implements AutoCloseable {
// DT_STRING tensor encoded in a ByteBuffer.
nbytes = nBuffered;
}
- Tensor t = new Tensor();
- t.dtype = dataType;
+ Tensor<T> t = new Tensor<T>(dataType);
t.shapeCopy = Arrays.copyOf(shape, shape.length);
t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
return t;
@@ -300,7 +367,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the value in a scalar {@link DataType#FLOAT} tensor.
+ * Returns the value in a scalar {@link Float} tensor.
*
* @throws IllegalArgumentException if the Tensor does not represent a float scalar.
*/
@@ -309,7 +376,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the value in a scalar {@link DataType#DOUBLE} tensor.
+ * Returns the value in a scalar {@link Double} tensor.
*
* @throws IllegalArgumentException if the Tensor does not represent a double scalar.
*/
@@ -318,7 +385,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the value in a scalar {@link DataType#INT32} tensor.
+ * Returns the value in a scalar {@link Integer} tensor.
*
 * @throws IllegalArgumentException if the Tensor does not represent an int scalar.
*/
@@ -327,7 +394,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the value in a scalar {@link DataType#INT64} tensor.
+ * Returns the value in a scalar {@link Long} tensor.
*
* @throws IllegalArgumentException if the Tensor does not represent a long scalar.
*/
@@ -336,7 +403,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the value in a scalar {@link DataType#BOOL} tensor.
+ * Returns the value in a scalar {@link Boolean} tensor.
*
* @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
*/
@@ -345,7 +412,7 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the value in a scalar {@link DataType#STRING} tensor.
+ * Returns the value in a scalar {@link String} tensor.
*
 * @throws IllegalArgumentException if the Tensor does not represent a String (arbitrary byte sequence) scalar.
*/
@@ -377,21 +444,21 @@ public final class Tensor implements AutoCloseable {
* @throws IllegalArgumentException if the tensor is a scalar or if {@code dst} is not compatible
* with the tensor (for example, mismatched data types or shapes).
*/
- public <T> T copyTo(T dst) {
+ public <U> U copyTo(U dst) {
throwExceptionIfTypeIsIncompatible(dst);
readNDArray(nativeHandle, dst);
return dst;
}
/**
- * Write the data of a {@link DataType#INT32} tensor into the given buffer.
+ * Write the data of an {@link Integer} tensor into the given buffer.
*
* <p>Copies {@code numElements()} elements to the buffer.
*
* @param dst the destination buffer
* @throws BufferOverflowException If there is insufficient space in the given buffer for the data
* in this tensor
- * @throws IllegalArgumentException If the tensor datatype is not {@link DataType#INT32}
+ * @throws IllegalArgumentException If the tensor data type is not {@link Integer}
*/
public void writeTo(IntBuffer dst) {
if (dtype != DataType.INT32) {
@@ -402,14 +469,14 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Write the data of a {@link DataType#FLOAT} tensor into the given buffer.
+ * Write the data of a {@link Float} tensor into the given buffer.
*
* <p>Copies {@code numElements()} elements to the buffer.
*
* @param dst the destination buffer
* @throws BufferOverflowException If there is insufficient space in the given buffer for the data
* in this tensor
- * @throws IllegalArgumentException If the tensor datatype is not {@link DataType#FLOAT}
+ * @throws IllegalArgumentException If the tensor datatype is not {@link Float}
*/
public void writeTo(FloatBuffer dst) {
if (dtype != DataType.FLOAT) {
@@ -420,14 +487,14 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Write the data of a {@link DataType#DOUBLE} tensor into the given buffer.
+ * Write the data of a {@link Double} tensor into the given buffer.
*
* <p>Copies {@code numElements()} elements to the buffer.
*
* @param dst the destination buffer
* @throws BufferOverflowException If there is insufficient space in the given buffer for the data
* in this tensor
- * @throws IllegalArgumentException If the tensor datatype is not {@link DataType#DOUBLE}
+ * @throws IllegalArgumentException If the tensor datatype is not {@link Double}
*/
public void writeTo(DoubleBuffer dst) {
if (dtype != DataType.DOUBLE) {
@@ -438,14 +505,14 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Write the data of a {@link DataType#INT64} tensor into the given buffer.
+ * Write the data of a {@link Long} tensor into the given buffer.
*
* <p>Copies {@code numElements()} elements to the buffer.
*
* @param dst the destination buffer
* @throws BufferOverflowException If there is insufficient space in the given buffer for the data
* in this tensor
- * @throws IllegalArgumentException If the tensor datatype is not {@link DataType#INT64}
+ * @throws IllegalArgumentException If the tensor datatype is not {@link Long}
*/
public void writeTo(LongBuffer dst) {
if (dtype != DataType.INT64) {
@@ -480,9 +547,9 @@ public final class Tensor implements AutoCloseable {
*
* <p>Takes ownership of the handle.
*/
- static Tensor fromHandle(long handle) {
- Tensor t = new Tensor();
- t.dtype = DataType.fromC(dtype(handle));
+ static Tensor<?> fromHandle(long handle) {
+ @SuppressWarnings("rawtypes")
+ Tensor<?> t = new Tensor(DataType.fromC(dtype(handle)));
t.shapeCopy = shape(handle);
t.nativeHandle = handle;
return t;
@@ -496,7 +563,9 @@ public final class Tensor implements AutoCloseable {
private DataType dtype;
private long[] shapeCopy = null;
- private Tensor() {}
+ private Tensor(DataType t) {
+ dtype = t;
+ }
private ByteBuffer buffer() {
return buffer(nativeHandle).order(ByteOrder.nativeOrder());
@@ -564,11 +633,26 @@ public final class Tensor implements AutoCloseable {
classDataTypes.put(Boolean.class, DataType.BOOL);
}
- private static DataType dataTypeOf(Object o) {
+ /** The class for the data type to which Java object o corresponds. */
+ private static Class<?> baseObjType(Object o) {
Class<?> c = o.getClass();
while (c.isArray()) {
c = c.getComponentType();
}
+ return c;
+ }
+
+ /**
+ * The default TensorFlow data type to which Java object o corresponds. Some Java objects
+ * represent more than one TensorFlow data type; for example, 'byte' can represent both {@code
+ * uint8} and {@code string}, with the latter being the default interpretation.
+ */
+ private static DataType dataTypeOf(Object o) {
+ Class<?> c = baseObjType(o);
+ return dataTypeFromClass(c);
+ }
+
+ private static DataType dataTypeFromClass(Class<?> c) {
DataType ret = classDataTypes.get(c);
if (ret != null) {
return ret;
@@ -577,7 +661,12 @@ public final class Tensor implements AutoCloseable {
}
/**
- * Returns the number of dimensions of a tensor of type dtype when represented by the object o.
+ * Return the number of dimensions of the tensor that object {@code o} represents as a tensor
+ * whose datatype is {@code dtype}. Normally this is the same as the number of dimensions of o
+ * itself, but is one smaller for tensors of strings.
+ *
+ * @param o The object to inspect. It must be a valid representation of the given data type.
+ * @param dtype The expected data type of the tensor.
*/
private static int numDimensions(Object o, DataType dtype) {
int ret = numArrayDimensions(o);
@@ -624,7 +713,13 @@ public final class Tensor implements AutoCloseable {
/** Returns whether the object {@code obj} can represent a tensor with data type {@code dtype}. */
private static boolean objectCompatWithType(Object obj, DataType dtype) {
- DataType dto = dataTypeOf(obj);
+ Class<?> c = baseObjType(obj);
+ DataType dto = dataTypeFromClass(c);
+ int nd = numDimensions(obj, dto);
+ if (!c.isPrimitive() && c != String.class && nd != 0) {
+ throw new IllegalArgumentException(
+ "cannot create non-scalar Tensors from arrays of boxed values");
+ }
if (dto.equals(dtype)) {
return true;
}
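
A short sketch contrasting the typed and untyped creation paths and the new expect() narrowing; the scalar longValue() accessor and dataType() are existing Tensor methods:

    import org.tensorflow.Tensor;

    public class TypedTensorSketch {
      public static void main(String[] args) {
        // Typed creation: the Class argument pins down T and is checked against the object.
        try (Tensor<Float> m = Tensor.create(new float[3][2], Float.class)) {
          System.out.println(m.dataType()); // FLOAT
        }
        // Untyped creation infers the data type from the object; expect() narrows it later.
        try (Tensor<?> t = Tensor.create(42L)) {
          Tensor<Long> l = t.expect(Long.class); // ok: the underlying type is INT64
          System.out.println(l.longValue());
          // t.expect(Float.class) would throw IllegalArgumentException instead.
        }
      }
    }
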
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensors.java b/tensorflow/java/src/main/java/org/tensorflow/Tensors.java
new file mode 100644
index 0000000000..c828d23efc
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensors.java
@@ -0,0 +1,447 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+/** Type-safe factory methods for creating {@link org.tensorflow.Tensor} objects. */
+public final class Tensors {
+ private Tensors() {}
+
+ /**
+ * Creates a scalar String tensor using the default, UTF-8 encoding.
+ *
+ * @param data The string to put into the new scalar tensor.
+ */
+ public static Tensor<String> create(String data) {
+ return Tensor.create(data.getBytes(UTF_8), String.class);
+ }
+
+ /**
+ * Creates a scalar String tensor using a specified encoding.
+ *
+ * @param charset The encoding from String to bytes.
+ * @param data The string to put into the new scalar tensor.
+ */
+ public static Tensor<String> create(String data, java.nio.charset.Charset charset) {
+ return Tensor.create(data.getBytes(charset), String.class);
+ }
+
+ /**
+ * Creates a scalar tensor containing a single {@code float} element.
+ *
+ * @param data The value to put into the new scalar tensor.
+ */
+ public static Tensor<Float> create(float data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a rank-1 tensor of {@code float} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Float> create(float[] data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a rank-2 tensor of {@code float} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Float> create(float[][] data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a rank-3 tensor of {@code float} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Float> create(float[][][] data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a rank-4 tensor of {@code float} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Float> create(float[][][][] data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a rank-5 tensor of {@code float} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Float> create(float[][][][][] data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a rank-6 tensor of {@code float} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Float> create(float[][][][][][] data) {
+ return Tensor.create(data, Float.class);
+ }
+
+ /**
+ * Creates a scalar tensor containing a single {@code double} element.
+ *
+ * @param data The value to put into the new scalar tensor.
+ */
+ public static Tensor<Double> create(double data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a rank-1 tensor of {@code double} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Double> create(double[] data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a rank-2 tensor of {@code double} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Double> create(double[][] data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a rank-3 tensor of {@code double} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Double> create(double[][][] data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a rank-4 tensor of {@code double} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Double> create(double[][][][] data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a rank-5 tensor of {@code double} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Double> create(double[][][][][] data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a rank-6 tensor of {@code double} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Double> create(double[][][][][][] data) {
+ return Tensor.create(data, Double.class);
+ }
+
+ /**
+ * Creates a scalar tensor containing a single {@code int} element.
+ *
+ * @param data The value to put into the new scalar tensor.
+ */
+ public static Tensor<Integer> create(int data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-1 tensor of {@code int} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Integer> create(int[] data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-2 tensor of {@code int} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Integer> create(int[][] data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-3 tensor of {@code int} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Integer> create(int[][][] data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-4 tensor of {@code int} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Integer> create(int[][][][] data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-5 tensor of {@code int} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Integer> create(int[][][][][] data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-6 tensor of {@code int} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Integer> create(int[][][][][][] data) {
+ return Tensor.create(data, Integer.class);
+ }
+
+ /**
+ * Creates a scalar tensor containing a single {@code byte} element.
+ *
+ * @param data An array containing the data to put into the new tensor. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Tensor<String> create(byte[] data) {
+ return Tensor.create(data, String.class);
+ }
+
+ /**
+ * Creates a rank-1 tensor of {@code byte} elements.
+ *
+ * @param data An array containing the data to put into the new tensor. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Tensor<String> create(byte[][] data) {
+ return Tensor.create(data, String.class);
+ }
+
+ /**
+ * Creates a rank-2 tensor of {@code byte} elements.
+ *
+ * @param data An array containing the data to put into the new tensor. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Tensor<String> create(byte[][][] data) {
+ return Tensor.create(data, String.class);
+ }
+
+ /**
+ * Creates a rank-3 tensor of {@code byte} elements.
+ *
+ * @param data An array containing the data to put into the new tensor. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Tensor<String> create(byte[][][][] data) {
+ return Tensor.create(data, String.class);
+ }
+
+ /**
+ * Creates a rank-4 tensor of {@code byte} elements.
+ *
+ * @param data An array containing the data to put into the new tensor. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Tensor<String> create(byte[][][][][] data) {
+ return Tensor.create(data, String.class);
+ }
+
+ /**
+ * Creates a rank-5 tensor of {@code byte} elements.
+ *
+ * @param data An array containing the data to put into the new tensor. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Tensor<String> create(byte[][][][][][] data) {
+ return Tensor.create(data, String.class);
+ }
+
+ /**
+ * Creates a scalar tensor containing a single {@code long} element.
+ *
+ * @param data The value to put into the new scalar tensor.
+ */
+ public static Tensor<Long> create(long data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a rank-1 tensor of {@code long} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Long> create(long[] data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a rank-2 tensor of {@code long} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Long> create(long[][] data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a rank-3 tensor of {@code long} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Long> create(long[][][] data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a rank-4 tensor of {@code long} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Long> create(long[][][][] data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a rank-5 tensor of {@code long} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Long> create(long[][][][][] data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a rank-6 tensor of {@code long} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Long> create(long[][][][][][] data) {
+ return Tensor.create(data, Long.class);
+ }
+
+ /**
+ * Creates a scalar tensor containing a single {@code boolean} element.
+ *
+ * @param data The value to put into the new scalar tensor.
+ */
+ public static Tensor<Boolean> create(boolean data) {
+ return Tensor.create(data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-1 tensor of {@code boolean} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Boolean> create(boolean[] data) {
+ return Tensor.create(data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-2 tensor of {@code boolean} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Boolean> create(boolean[][] data) {
+ return Tensor.create(data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-3 tensor of {@code boolean} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Boolean> create(boolean[][][] data) {
+ return Tensor.create(data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-4 tensor of {@code boolean} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Boolean> create(boolean[][][][] data) {
+ return Tensor.create(data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-5 tensor of {@code boolean} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Boolean> create(boolean[][][][][] data) {
+ return Tensor.create(data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-6 tensor of {@code boolean} elements.
+ *
+ * @param data An array containing the values to put into the new tensor. The dimensions of the
+ * new tensor will match those of the array.
+ */
+ public static Tensor<Boolean> create(boolean[][][][][][] data) {
+ return Tensor.create(data, Boolean.class);
+ }
+}
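
A quick sketch of how these factories read at a call site, compared with passing a Class token to Tensor.create; intValue(), numDimensions() and bytesValue() are existing Tensor accessors:

    import static java.nio.charset.StandardCharsets.UTF_8;

    import org.tensorflow.Tensor;
    import org.tensorflow.Tensors;

    public class TensorsSketch {
      public static void main(String[] args) {
        try (Tensor<Integer> scalar = Tensors.create(3);
             Tensor<Double> matrix = Tensors.create(new double[][] {{1.0, 2.0}, {3.0, 4.0}});
             Tensor<String> text = Tensors.create("hello")) { // UTF-8 encoded scalar string
          System.out.println(scalar.intValue());      // 3
          System.out.println(matrix.numDimensions()); // 2
          System.out.println(new String(text.bytesValue(), UTF_8)); // hello
        }
      }
    }
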
diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java b/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
index 19929188a5..489e95c310 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
@@ -29,6 +29,7 @@ import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
+import org.tensorflow.types.UInt8;
/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */
public class LabelImage {
@@ -61,17 +62,17 @@ public class LabelImage {
readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
- try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
+ try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
float[] labelProbabilities = executeInceptionGraph(graphDef, image);
int bestLabelIdx = maxIndex(labelProbabilities);
System.out.println(
- String.format(
- "BEST MATCH: %s (%.2f%% likely)",
- labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
+ String.format("BEST MATCH: %s (%.2f%% likely)",
+ labels.get(bestLabelIdx),
+ labelProbabilities[bestLabelIdx] * 100f));
}
}
- private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
+ private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
try (Graph g = new Graph()) {
GraphBuilder b = new GraphBuilder(g);
// Some constants specific to the pre-trained model at:
@@ -88,28 +89,29 @@ public class LabelImage {
// Since the graph is being constructed once per execution here, we can use a constant for the
// input image. If the graph were to be re-used for multiple input images, a placeholder would
// have been more appropriate.
- final Output input = b.constant("input", imageBytes);
- final Output output =
+ final Output<String> input = b.constant("input", imageBytes);
+ final Output<Float> output =
b.div(
b.sub(
b.resizeBilinear(
b.expandDims(
- b.cast(b.decodeJpeg(input, 3), DataType.FLOAT),
+ b.cast(b.decodeJpeg(input, 3), Float.class),
b.constant("make_batch", 0)),
b.constant("size", new int[] {H, W})),
b.constant("mean", mean)),
b.constant("scale", scale));
try (Session s = new Session(g)) {
- return s.runner().fetch(output.op().name()).run().get(0);
+ return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
}
}
}
- private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
+ private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
try (Graph g = new Graph()) {
g.importGraphDef(graphDef);
try (Session s = new Session(g);
- Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) {
+ Tensor<Float> result =
+ s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) {
final long[] rshape = result.shape();
if (result.numDimensions() != 2 || rshape[0] != 1) {
throw new RuntimeException(
@@ -161,48 +163,71 @@ public class LabelImage {
this.g = g;
}
- Output div(Output x, Output y) {
+ Output<Float> div(Output<Float> x, Output<Float> y) {
return binaryOp("Div", x, y);
}
- Output sub(Output x, Output y) {
+ <T> Output<T> sub(Output<T> x, Output<T> y) {
return binaryOp("Sub", x, y);
}
- Output resizeBilinear(Output images, Output size) {
- return binaryOp("ResizeBilinear", images, size);
+ <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
+ return binaryOp3("ResizeBilinear", images, size);
}
- Output expandDims(Output input, Output dim) {
- return binaryOp("ExpandDims", input, dim);
+ <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
+ return binaryOp3("ExpandDims", input, dim);
}
- Output cast(Output value, DataType dtype) {
- return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0);
+ <T, U> Output<U> cast(Output<T> value, Class<U> type) {
+ DataType dtype = DataType.fromClass(type);
+ return g.opBuilder("Cast", "Cast")
+ .addInput(value)
+ .setAttr("DstT", dtype)
+ .build()
+ .<U>output(0);
}
- Output decodeJpeg(Output contents, long channels) {
+ Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
return g.opBuilder("DecodeJpeg", "DecodeJpeg")
.addInput(contents)
.setAttr("channels", channels)
.build()
- .output(0);
+ .<UInt8>output(0);
}
- Output constant(String name, Object value) {
- try (Tensor t = Tensor.create(value)) {
+ <T> Output<T> constant(String name, Object value, Class<T> type) {
+ try (Tensor<T> t = Tensor.<T>create(value, type)) {
return g.opBuilder("Const", name)
- .setAttr("dtype", t.dataType())
+ .setAttr("dtype", DataType.fromClass(type))
.setAttr("value", t)
.build()
- .output(0);
+ .<T>output(0);
}
}
+ Output<String> constant(String name, byte[] value) {
+ return this.constant(name, value, String.class);
+ }
- private Output binaryOp(String type, Output in1, Output in2) {
- return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0);
+ Output<Integer> constant(String name, int value) {
+ return this.constant(name, value, Integer.class);
}
+ Output<Integer> constant(String name, int[] value) {
+ return this.constant(name, value, Integer.class);
+ }
+
+ Output<Float> constant(String name, float value) {
+ return this.constant(name, value, Float.class);
+ }
+
+ private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
+ return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
+ }
+
+ private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
+ return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
+ }
private Graph g;
}
}
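
Taken together, the typed helpers above let element types flow through graph construction. A minimal sketch of that flow, assuming the helper class is LabelImage's GraphBuilder shown here (the class name is inferred from context) and that a JPEG file exists at the hypothetical path img.jpg:

    import java.nio.file.Files;
    import java.nio.file.Paths;
    import org.tensorflow.Graph;
    import org.tensorflow.Output;
    import org.tensorflow.types.UInt8;

    class TypedGraphSketch {
      public static void main(String[] args) throws Exception {
        byte[] jpeg = Files.readAllBytes(Paths.get("img.jpg"));      // hypothetical input file
        try (Graph g = new Graph()) {
          GraphBuilder b = new GraphBuilder(g);                      // assumed name of the helper above
          Output<String> contents = b.constant("jpeg_bytes", jpeg);  // byte[] overload yields Output<String>
          Output<UInt8> pixels = b.decodeJpeg(contents, 3);          // element type tracked as UInt8
          Output<Float> floats = b.cast(pixels, Float.class);        // Class<U> replaces the DataType attr
          // b.sub(floats, pixels) would no longer compile: Float vs UInt8 is caught statically.
        }
      }
    }
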
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Operands.java b/tensorflow/java/src/main/java/org/tensorflow/op/Operands.java
index 5971103d6d..ac48da8032 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/Operands.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/Operands.java
@@ -33,12 +33,12 @@ public final class Operands {
* @param inputs an iteration of input operands
* @return an array of outputs
*/
- public static Output[] asOutputs(Iterable<? extends Operand> inputs) {
- List<Output> outputList = new ArrayList<>();
- for (Operand input : inputs) {
+ public static Output<?>[] asOutputs(Iterable<? extends Operand<?>> inputs) {
+ List<Output<?>> outputList = new ArrayList<>();
+ for (Operand<?> input : inputs) {
outputList.add(input.asOutput());
}
- return outputList.toArray(new Output[outputList.size()]);
+ return outputList.toArray(new Output<?>[outputList.size()]);
}
// Disabled constructor
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
index cd7931d3bb..725c81765a 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
@@ -31,7 +31,7 @@ import org.tensorflow.op.annotation.Operator;
/** An operator producing a constant value. */
@Operator
-public final class Constant extends PrimitiveOp implements Operand {
+public final class Constant<T> extends PrimitiveOp implements Operand<T> {
/**
* Create a constant from a Java object.
*
@@ -47,8 +47,8 @@ public final class Constant extends PrimitiveOp implements Operand {
* @param object a Java object representing the constant.
* @see org.tensorflow.Tensor#create(Object) Tensor.create
*/
- public static Constant create(Scope scope, Object object) {
- try (Tensor value = Tensor.create(object)) {
+ public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) {
+ try (Tensor<T> value = Tensor.create(object, type)) {
return createWithTensor(scope, value);
}
}
@@ -66,8 +66,8 @@ public final class Constant extends PrimitiveOp implements Operand {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant create(Scope scope, long[] shape, IntBuffer data) {
- try (Tensor value = Tensor.create(shape, data)) {
+ public static Constant<Integer> create(Scope scope, long[] shape, IntBuffer data) {
+ try (Tensor<Integer> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -85,8 +85,8 @@ public final class Constant extends PrimitiveOp implements Operand {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant create(Scope scope, long[] shape, FloatBuffer data) {
- try (Tensor value = Tensor.create(shape, data)) {
+ public static Constant<Float> create(Scope scope, long[] shape, FloatBuffer data) {
+ try (Tensor<Float> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -104,8 +104,8 @@ public final class Constant extends PrimitiveOp implements Operand {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant create(Scope scope, long[] shape, DoubleBuffer data) {
- try (Tensor value = Tensor.create(shape, data)) {
+ public static Constant<Double> create(Scope scope, long[] shape, DoubleBuffer data) {
+ try (Tensor<Double> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -123,8 +123,8 @@ public final class Constant extends PrimitiveOp implements Operand {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant create(Scope scope, long[] shape, LongBuffer data) {
- try (Tensor value = Tensor.create(shape, data)) {
+ public static Constant<Long> create(Scope scope, long[] shape, LongBuffer data) {
+ try (Tensor<Long> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -143,14 +143,14 @@ public final class Constant extends PrimitiveOp implements Operand {
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
- public static Constant create(Scope scope, DataType dataType, long[] shape, ByteBuffer data) {
- try (Tensor value = Tensor.create(dataType, shape, data)) {
+ public static <T> Constant<T> create(Scope scope, Class<T> type, long[] shape, ByteBuffer data) {
+ try (Tensor<T> value = Tensor.create(type, shape, data)) {
return createWithTensor(scope, value);
}
}
- private static Constant createWithTensor(Scope scope, Tensor value) {
- return new Constant(
+ private static <T> Constant<T> createWithTensor(Scope scope, Tensor<T> value) {
+ return new Constant<T>(
scope
.graph()
.opBuilder("Const", scope.makeOpName("Const"))
@@ -160,7 +160,7 @@ public final class Constant extends PrimitiveOp implements Operand {
}
@Override
- public Output asOutput() {
+ public Output<T> asOutput() {
return output;
}
@@ -169,5 +169,5 @@ public final class Constant extends PrimitiveOp implements Operand {
output = operation.output(0);
}
- private final Output output;
+ private final Output<T> output;
}
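
With Constant parameterized by T, the element type fixed at creation follows the op's output. A small sketch, assuming a fresh Graph and Scope as in the tests later in this change:

    import org.tensorflow.Graph;
    import org.tensorflow.Output;
    import org.tensorflow.op.Scope;
    import org.tensorflow.op.core.Constant;

    class ConstantSketch {
      static void build() {
        try (Graph g = new Graph()) {
          Scope scope = new Scope(g);
          Constant<Integer> c = Constant.create(scope, new int[] {1, 2, 3}, Integer.class);
          Output<Integer> out = c.asOutput();  // element type is now part of the signature
          // Passing out where an Output<Float> is required fails at compile time rather than
          // with a runtime DataType mismatch.
        }
      }
    }
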
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/UInt8.java b/tensorflow/java/src/main/java/org/tensorflow/types/UInt8.java
new file mode 100644
index 0000000000..0c751aed9f
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/UInt8.java
@@ -0,0 +1,21 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.types;
+
+/** Represents an 8-bit unsigned integer. */
+public class UInt8 {
+ private UInt8() {}
+}
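
UInt8 has no instances (its constructor is private); it exists purely as a compile-time tag for element types with no boxed Java counterpart. A sketch mirroring the TensorTest case added below in this patch:

    import org.tensorflow.Tensor;
    import org.tensorflow.types.UInt8;

    class UInt8Sketch {
      static void roundTrip() {
        byte[] pixels = {1, 2, 3, 4};
        try (Tensor<UInt8> t = Tensor.create(pixels, UInt8.class)) {
          byte[] back = t.copyTo(new byte[4]);  // stored as DataType.UINT8, read back as bytes
        }
      }
    }
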
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/package-info.java b/tensorflow/java/src/main/java/org/tensorflow/types/package-info.java
index f1410a760e..96018c5366 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/types/package-info.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/package-info.java
@@ -15,13 +15,15 @@ limitations under the License.
/**
* Defines classes that represent TensorFlow data types. For each possible data type
- * that can be used in a tensor, there is a corresponding class in this package that
+ * that can be used in a tensor, there is a corresponding class that
* is used to represent it. For example, the TensorFlow int32 type is represented by
- * the type TFInt32 and by the class object TFInt32.class. The former is used to
- * support compile-time checking of tensor data types and the latter is used for
- * run-time checking of data types. All such classes implement the TFType interface.
- * TensorFlow data types are also separately represented by the DataType enum, with
- * one enum value per data type. The enum representation should rarely be needed, but
- * the Types class can be used to obtain it from the class object representation.
+ * the type {@link Integer} and by the class object {@code Integer.class}. The former is used to
+ * support compile-time checking of tensor element types and the latter is used for
+ * run-time checking of element types. Classes appearing in this package, such as
+ * UInt8, represent TensorFlow data types for which there is no existing Java equivalent.
+ *
+ * <p>TensorFlow element types are also separately represented by the {@link DataType} enum, with
+ * one enum value per element type. The enum representation is not usually needed, but
+ * can be obtained using {@link DataType#fromClass(Class)}.
*/
package org.tensorflow.types;
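
A compact sketch of the two representations the Javadoc above distinguishes, using only calls that appear elsewhere in this patch (class objects for static and dynamic checks, DataType for the rare cases that need the enum):

    import org.tensorflow.DataType;
    import org.tensorflow.Tensor;
    import org.tensorflow.Tensors;

    class TypeRepresentations {
      static void demo() {
        try (Tensor<Integer> t = Tensors.create(42)) {       // Integer tracks the element type
          DataType dt = DataType.fromClass(Integer.class);   // enum form, DataType.INT32
          assert t.dataType() == dt;
        }
      }
    }
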
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index 4adc861bf1..c540299bdc 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,7 +22,6 @@ import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Iterator;
-
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
diff --git a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java b/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java
index b3bc3aaef9..6dc233987b 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java
@@ -34,8 +34,8 @@ public class OperationBuilderTest {
public void failWhenMixingOperationsOnDifferentGraphs() {
try (Graph g1 = new Graph();
Graph g2 = new Graph()) {
- Output c1 = TestUtil.constant(g1, "C1", 3);
- Output c2 = TestUtil.constant(g2, "C2", 3);
+ Output<Integer> c1 = TestUtil.constant(g1, "C1", 3);
+ Output<Integer> c2 = TestUtil.constant(g2, "C2", 3);
TestUtil.addN(g1, c1, c1);
try {
TestUtil.addN(g2, c1, c2);
@@ -48,7 +48,7 @@ public class OperationBuilderTest {
@Test
public void failOnUseAfterBuild() {
try (Graph g = new Graph();
- Tensor t = Tensor.create(1)) {
+ Tensor<Integer> t = Tensors.create(1)) {
OperationBuilder b =
g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t);
b.build();
@@ -64,7 +64,7 @@ public class OperationBuilderTest {
public void failOnUseAfterGraphClose() {
OperationBuilder b = null;
try (Graph g = new Graph();
- Tensor t = Tensor.create(1)) {
+ Tensor<Integer> t = Tensors.create(1)) {
b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t);
}
try {
@@ -85,7 +85,7 @@ public class OperationBuilderTest {
// types that aren't inferred from the input arguments.
try (Graph g = new Graph()) {
// dtype, tensor attributes.
- try (Tensor t = Tensor.create(1)) {
+ try (Tensor<Integer> t = Tensors.create(1)) {
g.opBuilder("Const", "DataTypeAndTensor")
.setAttr("dtype", DataType.INT32)
.setAttr("value", t)
@@ -101,7 +101,7 @@ public class OperationBuilderTest {
assertTrue(hasNode(g, "StringAndBool"));
// int (TF "int" attributes are 64-bit signed, so a Java long).
g.opBuilder("RandomUniform", "Int")
- .addInput(TestUtil.constant(g, "RandomUniformShape", new int[]{1}))
+ .addInput(TestUtil.constant(g, "RandomUniformShape", new int[] {1}))
.setAttr("seed", 10)
.setAttr("dtype", DataType.FLOAT)
.build();
@@ -127,7 +127,7 @@ public class OperationBuilderTest {
@Test
public void setAttrShape() {
try (Graph g = new Graph()) {
- Output n =
+ Output<?> n =
g.opBuilder("Placeholder", "unknown")
.setAttr("dtype", DataType.FLOAT)
.setAttr("shape", Shape.unknown())
@@ -136,8 +136,7 @@ public class OperationBuilderTest {
assertEquals(-1, n.shape().numDimensions());
assertEquals(DataType.FLOAT, n.dataType());
- n =
- g.opBuilder("Placeholder", "batch_of_vectors")
+ n = g.opBuilder("Placeholder", "batch_of_vectors")
.setAttr("dtype", DataType.FLOAT)
.setAttr("shape", Shape.make(-1, 784))
.build()
@@ -153,13 +152,13 @@ public class OperationBuilderTest {
public void addControlInput() {
try (Graph g = new Graph();
Session s = new Session(g);
- Tensor yes = Tensor.create(true);
- Tensor no = Tensor.create(false)) {
- Output placeholder = TestUtil.placeholder(g, "boolean", DataType.BOOL);
+ Tensor<Boolean> yes = Tensors.create(true);
+ Tensor<Boolean> no = Tensors.create(false)) {
+ Output<Boolean> placeholder = TestUtil.placeholder(g, "boolean", Boolean.class);
Operation check =
g.opBuilder("Assert", "assert")
.addInput(placeholder)
- .addInputList(new Output[] {placeholder})
+ .addInputList(new Output<?>[] {placeholder})
.build();
Operation noop = g.opBuilder("NoOp", "noop").addControlInput(check).build();
diff --git a/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java b/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java
index aade375db8..6fe3b3c327 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/OperationTest.java
@@ -24,7 +24,6 @@ import static org.junit.Assert.fail;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
-
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -104,9 +103,9 @@ public class OperationTest {
@Test
public void outputEquality() {
try (Graph g = new Graph()) {
- Output output = TestUtil.constant(g, "c", 1);
- Output output1 = output.op().output(0);
- Output output2 = g.operation("c").output(0);
+ Output<Integer> output = TestUtil.constant(g, "c", 1);
+ Output<Integer> output1 = output.op().<Integer>output(0);
+ Output<Integer> output2 = g.operation("c").<Integer>output(0);
assertEquals(output, output1);
assertEquals(output.hashCode(), output1.hashCode());
assertEquals(output, output2);
@@ -117,10 +116,10 @@ public class OperationTest {
@Test
public void outputCollection() {
try (Graph g = new Graph()) {
- Output output = TestUtil.constant(g, "c", 1);
- Output output1 = output.op().output(0);
- Output output2 = g.operation("c").output(0);
- Set<Output> ops = new HashSet<>();
+ Output<Integer> output = TestUtil.constant(g, "c", 1);
+ Output<Integer> output1 = output.op().<Integer>output(0);
+ Output<Integer> output2 = g.operation("c").<Integer>output(0);
+ Set<Output<Integer>> ops = new HashSet<>();
ops.addAll(Arrays.asList(output, output1, output2));
assertEquals(1, ops.size());
assertTrue(ops.contains(output));
@@ -132,7 +131,7 @@ public class OperationTest {
@Test
public void outputToString() {
try (Graph g = new Graph()) {
- Output output = TestUtil.constant(g, "c", new int[] {1});
+ Output<Integer> output = TestUtil.constant(g, "c", new int[] {1});
assertNotNull(output.toString());
}
}
@@ -158,7 +157,7 @@ public class OperationTest {
public void outputList() {
try (Graph g = new Graph()) {
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
- Output[] outputs = split.outputList(1, 2);
+ Output<?>[] outputs = split.outputList(1, 2);
assertNotNull(outputs);
assertEquals(2, outputs.length);
for (int i = 0; i < outputs.length; ++i) {
diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
index 50bdf351e3..a86b4dd117 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
@@ -35,9 +35,9 @@ public class SessionTest {
try (Graph g = new Graph();
Session s = new Session(g)) {
TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
- try (Tensor x = Tensor.create(new int[][] {{5}, {7}});
- AutoCloseableList<Tensor> outputs =
- new AutoCloseableList<Tensor>(s.runner().feed("X", x).fetch("Y").run())) {
+ try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
+ AutoCloseableList<Tensor<?>> outputs =
+ new AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) {
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -50,11 +50,11 @@ public class SessionTest {
try (Graph g = new Graph();
Session s = new Session(g)) {
TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
- Output feed = g.operation("X").output(0);
- Output fetch = g.operation("Y").output(0);
- try (Tensor x = Tensor.create(new int[][] {{5}, {7}});
- AutoCloseableList<Tensor> outputs =
- new AutoCloseableList<Tensor>(s.runner().feed(feed, x).fetch(fetch).run())) {
+ Output<Integer> feed = g.operation("X").output(0);
+ Output<Integer> fetch = g.operation("Y").output(0);
+ try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
+ AutoCloseableList<Tensor<?>> outputs =
+ new AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) {
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -78,14 +78,21 @@ public class SessionTest {
.build()
.output(0);
// Fetch using colon separated names.
- try (Tensor fetched = s.runner().fetch("Split:1").run().get(0)) {
+ try (Tensor<Integer> fetched =
+ s.runner().fetch("Split:1").run().get(0).expect(Integer.class)) {
final int[] expected = {3, 4};
assertArrayEquals(expected, fetched.copyTo(new int[2]));
}
// Feed using colon separated names.
- try (Tensor fed = Tensor.create(new int[] {4, 3, 2, 1});
- Tensor fetched =
- s.runner().feed("Split:0", fed).feed("Split:1", fed).fetch("Add").run().get(0)) {
+ try (Tensor<Integer> fed = Tensors.create(new int[] {4, 3, 2, 1});
+ Tensor<Integer> fetched =
+ s.runner()
+ .feed("Split:0", fed)
+ .feed("Split:1", fed)
+ .fetch("Add")
+ .run()
+ .get(0)
+ .expect(Integer.class)) {
final int[] expected = {8, 6, 4, 2};
assertArrayEquals(expected, fetched.copyTo(new int[4]));
}
@@ -97,7 +104,7 @@ public class SessionTest {
try (Graph g = new Graph();
Session s = new Session(g)) {
TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
- try (Tensor x = Tensor.create(new int[][] {{5}, {7}})) {
+ try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}})) {
Session.Run result =
s.runner()
.feed("X", x)
@@ -105,7 +112,7 @@ public class SessionTest {
.setOptions(fullTraceRunOptions())
.runAndFetchMetadata();
// Sanity check on outputs.
- AutoCloseableList<Tensor> outputs = new AutoCloseableList<Tensor>(result.outputs);
+ AutoCloseableList<Tensor<?>> outputs = new AutoCloseableList<Tensor<?>>(result.outputs);
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -117,6 +124,7 @@ public class SessionTest {
assertTrue(md.toString(), md.hasStepStats());
*/
assertTrue(result.metadata.length > 0);
+ outputs.close();
}
}
}
@@ -127,11 +135,12 @@ public class SessionTest {
Session s = new Session(g)) {
TestUtil.constant(g, "c1", 2718);
TestUtil.constant(g, "c2", 31415);
- AutoCloseableList<Tensor> outputs =
- new AutoCloseableList<Tensor>(s.runner().fetch("c2").fetch("c1").run());
+ AutoCloseableList<Tensor<?>> outputs =
+ new AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run());
assertEquals(2, outputs.size());
assertEquals(31415, outputs.get(0).intValue());
assertEquals(2718, outputs.get(1).intValue());
+ outputs.close();
}
}
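
Session.Runner still returns a list of Tensor<?>; expect(Class) is the new way to narrow a fetched value, throwing if the runtime DataType disagrees with the requested class. A sketch, assuming the graph already contains an INT32 operation named "c":

    import org.tensorflow.Graph;
    import org.tensorflow.Session;
    import org.tensorflow.Tensor;

    class ExpectSketch {
      static int fetchScalar(Graph g) {
        try (Session s = new Session(g);
            Tensor<Integer> out = s.runner().fetch("c").run().get(0).expect(Integer.class)) {
          return out.intValue();  // expect() already verified the tensor holds INT32 data
        }
      }
    }
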
diff --git a/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java b/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java
index fe46c0184c..3b027700c5 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java
@@ -61,7 +61,7 @@ public class ShapeTest {
@Test
public void nodesInAGraph() {
try (Graph g = new Graph()) {
- Output n = TestUtil.placeholder(g, "feed", DataType.FLOAT);
+ Output<Float> n = TestUtil.placeholder(g, "feed", Float.class);
assertEquals(-1, n.shape().numDimensions());
n = TestUtil.constant(g, "scalar", 3);
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
index 036db04503..6538359d11 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
@@ -30,6 +30,7 @@ import java.nio.LongBuffer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+import org.tensorflow.types.UInt8;
/** Unit tests for {@link org.tensorflow.Tensor}. */
@RunWith(JUnit4.class)
@@ -47,7 +48,7 @@ public class TensorTest {
byte[] strings = "test".getBytes(UTF_8);
long[] strings_shape = {};
byte[] strings_; // raw TF_STRING
- try (Tensor t = Tensor.create(strings)) {
+ try (Tensor<String> t = Tensors.create(strings)) {
ByteBuffer to = ByteBuffer.allocate(t.numBytes());
t.writeTo(to);
strings_ = to.array();
@@ -55,7 +56,7 @@ public class TensorTest {
// validate creating a tensor using a byte buffer
{
- try (Tensor t = Tensor.create(DataType.BOOL, bools_shape, ByteBuffer.wrap(bools_))) {
+ try (Tensor<Boolean> t = Tensor.create(Boolean.class, bools_shape, ByteBuffer.wrap(bools_))) {
boolean[] actual = t.copyTo(new boolean[bools_.length]);
for (int i = 0; i < bools.length; ++i) {
assertEquals("" + i, bools[i], actual[i]);
@@ -63,7 +64,8 @@ public class TensorTest {
}
// note: the buffer is expected to contain raw TF_STRING (as per C API)
- try (Tensor t = Tensor.create(DataType.STRING, strings_shape, ByteBuffer.wrap(strings_))) {
+ try (Tensor<String> t =
+ Tensor.create(String.class, strings_shape, ByteBuffer.wrap(strings_))) {
assertArrayEquals(strings, t.bytesValue());
}
}
@@ -72,15 +74,15 @@ public class TensorTest {
{
ByteBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder());
buf.asDoubleBuffer().put(doubles);
- try (Tensor t = Tensor.create(DataType.DOUBLE, doubles_shape, buf)) {
+ try (Tensor<Double> t = Tensor.create(Double.class, doubles_shape, buf)) {
double[] actual = new double[doubles.length];
assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
}
}
// validate shape checking
- try (Tensor t =
- Tensor.create(DataType.BOOL, new long[bools_.length * 2], ByteBuffer.wrap(bools_))) {
+ try (Tensor<Boolean> t =
+ Tensor.create(Boolean.class, new long[bools_.length * 2], ByteBuffer.wrap(bools_))) {
fail("should have failed on incompatible buffer");
} catch (IllegalArgumentException e) {
// expected
@@ -99,7 +101,7 @@ public class TensorTest {
.asDoubleBuffer()
.put(doubles);
buf.flip();
- try (Tensor t = Tensor.create(new long[] {doubles.length}, buf)) {
+ try (Tensor<Double> t = Tensor.create(new long[] {doubles.length}, buf)) {
double[] actual = new double[doubles.length];
assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
}
@@ -115,19 +117,19 @@ public class TensorTest {
// validate creating a tensor using a typed buffer
{
- try (Tensor t = Tensor.create(shape, DoubleBuffer.wrap(doubles))) {
+ try (Tensor<Double> t = Tensor.create(shape, DoubleBuffer.wrap(doubles))) {
double[] actual = new double[doubles.length];
assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
}
- try (Tensor t = Tensor.create(shape, FloatBuffer.wrap(floats))) {
+ try (Tensor<Float> t = Tensor.create(shape, FloatBuffer.wrap(floats))) {
float[] actual = new float[floats.length];
assertArrayEquals(floats, t.copyTo(actual), EPSILON_F);
}
- try (Tensor t = Tensor.create(shape, IntBuffer.wrap(ints))) {
+ try (Tensor<Integer> t = Tensor.create(shape, IntBuffer.wrap(ints))) {
int[] actual = new int[ints.length];
assertArrayEquals(ints, t.copyTo(actual));
}
- try (Tensor t = Tensor.create(shape, LongBuffer.wrap(longs))) {
+ try (Tensor<Long> t = Tensor.create(shape, LongBuffer.wrap(longs))) {
long[] actual = new long[longs.length];
assertArrayEquals(longs, t.copyTo(actual));
}
@@ -135,22 +137,23 @@ public class TensorTest {
// validate shape-checking
{
- try (Tensor t = Tensor.create(new long[doubles.length + 1], DoubleBuffer.wrap(doubles))) {
+ try (Tensor<Double> t =
+ Tensor.create(new long[doubles.length + 1], DoubleBuffer.wrap(doubles))) {
fail("should have failed on incompatible buffer");
} catch (IllegalArgumentException e) {
// expected
}
- try (Tensor t = Tensor.create(new long[floats.length + 1], FloatBuffer.wrap(floats))) {
+ try (Tensor<Float> t = Tensor.create(new long[floats.length + 1], FloatBuffer.wrap(floats))) {
fail("should have failed on incompatible buffer");
} catch (IllegalArgumentException e) {
// expected
}
- try (Tensor t = Tensor.create(new long[ints.length + 1], IntBuffer.wrap(ints))) {
+ try (Tensor<Integer> t = Tensor.create(new long[ints.length + 1], IntBuffer.wrap(ints))) {
fail("should have failed on incompatible buffer");
} catch (IllegalArgumentException e) {
// expected
}
- try (Tensor t = Tensor.create(new long[longs.length + 1], LongBuffer.wrap(longs))) {
+ try (Tensor<Long> t = Tensor.create(new long[longs.length + 1], LongBuffer.wrap(longs))) {
fail("should have failed on incompatible buffer");
} catch (IllegalArgumentException e) {
// expected
@@ -166,11 +169,11 @@ public class TensorTest {
long[] longs = {1L, 2L, 3L};
boolean[] bools = {true, false, true};
- try (Tensor tints = Tensor.create(ints);
- Tensor tfloats = Tensor.create(floats);
- Tensor tdoubles = Tensor.create(doubles);
- Tensor tlongs = Tensor.create(longs);
- Tensor tbools = Tensor.create(bools)) {
+ try (Tensor<Integer> tints = Tensors.create(ints);
+ Tensor<Float> tfloats = Tensors.create(floats);
+ Tensor<Double> tdoubles = Tensors.create(doubles);
+ Tensor<Long> tlongs = Tensors.create(longs);
+ Tensor<Boolean> tbools = Tensors.create(bools)) {
// validate that any datatype is readable with ByteBuffer (content, position)
{
@@ -293,35 +296,35 @@ public class TensorTest {
@Test
public void scalars() {
- try (Tensor t = Tensor.create(2.718f)) {
+ try (Tensor<Float> t = Tensors.create(2.718f)) {
assertEquals(DataType.FLOAT, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
assertEquals(2.718f, t.floatValue(), EPSILON_F);
}
- try (Tensor t = Tensor.create(3.1415)) {
+ try (Tensor<Double> t = Tensors.create(3.1415)) {
assertEquals(DataType.DOUBLE, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
assertEquals(3.1415, t.doubleValue(), EPSILON);
}
- try (Tensor t = Tensor.create(-33)) {
+ try (Tensor<Integer> t = Tensors.create(-33)) {
assertEquals(DataType.INT32, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
assertEquals(-33, t.intValue());
}
- try (Tensor t = Tensor.create(8589934592L)) {
+ try (Tensor<Long> t = Tensors.create(8589934592L)) {
assertEquals(DataType.INT64, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
assertEquals(8589934592L, t.longValue());
}
- try (Tensor t = Tensor.create(true)) {
+ try (Tensor<Boolean> t = Tensors.create(true)) {
assertEquals(DataType.BOOL, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
@@ -329,7 +332,7 @@ public class TensorTest {
}
final byte[] bytes = {1, 2, 3, 4};
- try (Tensor t = Tensor.create(bytes)) {
+ try (Tensor<String> t = Tensors.create(bytes)) {
assertEquals(DataType.STRING, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
@@ -340,7 +343,7 @@ public class TensorTest {
@Test
public void nDimensional() {
double[] vector = {1.414, 2.718, 3.1415};
- try (Tensor t = Tensor.create(vector)) {
+ try (Tensor<Double> t = Tensors.create(vector)) {
assertEquals(DataType.DOUBLE, t.dataType());
assertEquals(1, t.numDimensions());
assertArrayEquals(new long[] {3}, t.shape());
@@ -350,7 +353,7 @@ public class TensorTest {
}
int[][] matrix = {{1, 2, 3}, {4, 5, 6}};
- try (Tensor t = Tensor.create(matrix)) {
+ try (Tensor<Integer> t = Tensors.create(matrix)) {
assertEquals(DataType.INT32, t.dataType());
assertEquals(2, t.numDimensions());
assertArrayEquals(new long[] {2, 3}, t.shape());
@@ -362,7 +365,7 @@ public class TensorTest {
long[][][] threeD = {
{{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}},
};
- try (Tensor t = Tensor.create(threeD)) {
+ try (Tensor<Long> t = Tensors.create(threeD)) {
assertEquals(DataType.INT64, t.dataType());
assertEquals(3, t.numDimensions());
assertArrayEquals(new long[] {2, 5, 1}, t.shape());
@@ -376,7 +379,7 @@ public class TensorTest {
{{{false, false, true, true}, {false, true, false, false}}},
{{{false, true, false, true}, {false, true, true, false}}},
};
- try (Tensor t = Tensor.create(fourD)) {
+ try (Tensor<Boolean> t = Tensors.create(fourD)) {
assertEquals(DataType.BOOL, t.dataType());
assertEquals(4, t.numDimensions());
assertArrayEquals(new long[] {3, 1, 2, 4}, t.shape());
@@ -394,7 +397,7 @@ public class TensorTest {
matrix[i][j] = String.format("(%d, %d) = %d", i, j, i << j).getBytes(UTF_8);
}
}
- try (Tensor t = Tensor.create(matrix)) {
+ try (Tensor<String> t = Tensors.create(matrix)) {
assertEquals(DataType.STRING, t.dataType());
assertEquals(2, t.numDimensions());
assertArrayEquals(new long[] {4, 3}, t.shape());
@@ -412,14 +415,24 @@ public class TensorTest {
@Test
public void testUInt8Tensor() {
- byte[] vector = new byte[] { 1, 2, 3, 4 };
- try (Tensor t = Tensor.create(vector, DataType.UINT8)) {
+ byte[] vector = new byte[] {1, 2, 3, 4};
+ try (Tensor<UInt8> t = Tensor.create(vector, UInt8.class)) {
assertEquals(DataType.UINT8, t.dataType());
assertEquals(1, t.numDimensions());
assertArrayEquals(new long[] {4}, t.shape());
byte[] got = t.copyTo(new byte[4]);
- assertArrayEquals(got, vector);
+ assertArrayEquals(vector, got);
+ }
+ }
+
+ @Test
+ public void testCreateFromArrayOfBoxed() {
+ Integer[] vector = new Integer[] {1, 2, 3, 4};
+ try (Tensor<Integer> t = Tensor.create(vector, Integer.class)) {
+ fail("Tensor.create() should fail because it was given an array of boxed values");
+ } catch (IllegalArgumentException e) {
+ // The expected exception
}
}
@@ -431,7 +444,7 @@ public class TensorTest {
invalid[x][y] = new int[x + y + 1];
}
}
- try (Tensor t = Tensor.create(invalid)) {
+ try (Tensor<?> t = Tensor.create(invalid)) {
fail("Tensor.create() should fail because of differing sizes in the 3rd dimension");
} catch (IllegalArgumentException e) {
// The expected exception.
@@ -440,7 +453,7 @@ public class TensorTest {
@Test
public void failCopyToOnIncompatibleDestination() {
- try (final Tensor matrix = Tensor.create(new int[][] {{1, 2}, {3, 4}})) {
+ try (final Tensor<Integer> matrix = Tensors.create(new int[][] {{1, 2}, {3, 4}})) {
try {
matrix.copyTo(new int[2]);
fail("should have failed on dimension mismatch");
@@ -466,7 +479,7 @@ public class TensorTest {
@Test
public void failCopyToOnScalar() {
- try (final Tensor scalar = Tensor.create(3)) {
+ try (final Tensor<Integer> scalar = Tensors.create(3)) {
try {
scalar.copyTo(3);
fail("copyTo should fail on scalar tensors, suggesting use of primitive accessors instead");
@@ -478,8 +491,8 @@ public class TensorTest {
@Test
public void failOnArbitraryObject() {
- try (Tensor t = Tensor.create(new Object())) {
- fail("should fail on creating a Tensor with a Java object that has not equivalent DataType");
+ try (Tensor<?> t = Tensor.create(new Object())) {
+ fail("should fail on creating a Tensor with a Java object that has no equivalent DataType");
} catch (IllegalArgumentException e) {
// The expected exception.
}
@@ -487,7 +500,7 @@ public class TensorTest {
@Test
public void failOnZeroDimension() {
- try (Tensor t = Tensor.create(new int[3][0][1])) {
+ try (Tensor<Integer> t = Tensors.create(new int[3][0][1])) {
fail("should fail on creating a Tensor where one of the dimensions is 0");
} catch (IllegalArgumentException e) {
// The expected exception.
@@ -497,7 +510,7 @@ public class TensorTest {
@Test
public void useAfterClose() {
int n = 4;
- Tensor t = Tensor.create(n);
+ Tensor<?> t = Tensor.create(n);
t.close();
try {
t.intValue();
@@ -515,8 +528,8 @@ public class TensorTest {
// An exception is made for this test, where the pitfalls of this is avoided by not calling
// close() on both Tensors.
final float[][] matrix = {{1, 2, 3}, {4, 5, 6}};
- try (Tensor src = Tensor.create(matrix)) {
- Tensor cpy = Tensor.fromHandle(src.getNativeHandle());
+ try (Tensor<Float> src = Tensors.create(matrix)) {
+ Tensor<Float> cpy = Tensor.fromHandle(src.getNativeHandle()).expect(Float.class);
assertEquals(src.dataType(), cpy.dataType());
assertEquals(src.numDimensions(), cpy.numDimensions());
assertArrayEquals(src.shape(), cpy.shape());
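
Two creation styles run through the updated tests: the Tensors factory, whose overloads pin the element type from the argument, and Tensor.create with an explicit Class token. A sketch of both under this patch's API:

    import org.tensorflow.Tensor;
    import org.tensorflow.Tensors;

    class CreateSketch {
      static void create() {
        try (Tensor<Float> a = Tensors.create(2.718f);                             // overload infers Float
            Tensor<Float> b = Tensor.create(new float[] {1f, 2f}, Float.class)) {  // explicit type token
          // Both are Tensor<Float>; boxed arrays such as Integer[] are rejected at runtime
          // (see testCreateFromArrayOfBoxed above).
        }
      }
    }
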
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index e3415a696d..c973b5a3d8 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -19,33 +19,36 @@ import java.lang.reflect.Array;
/** Static utility functions. */
public class TestUtil {
- public static Output constant(Graph g, String name, Object value) {
- try (Tensor t = Tensor.create(value)) {
+ public static <T> Output<T> constant(Graph g, String name, Object value) {
+ try (Tensor<?> t = Tensor.create(value)) {
return g.opBuilder("Const", name)
.setAttr("dtype", t.dataType())
.setAttr("value", t)
.build()
- .output(0);
+ .<T>output(0);
}
}
- public static Output placeholder(Graph g, String name, DataType dtype) {
- return g.opBuilder("Placeholder", name).setAttr("dtype", dtype).build().output(0);
+ public static <T> Output<T> placeholder(Graph g, String name, Class<T> type) {
+ return g.opBuilder("Placeholder", name)
+ .setAttr("dtype", DataType.fromClass(type))
+ .build()
+ .<T>output(0);
}
- public static Output addN(Graph g, Output... inputs) {
+ public static Output<?> addN(Graph g, Output<?>... inputs) {
return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
}
- public static Output matmul(
- Graph g, String name, Output a, Output b, boolean transposeA, boolean transposeB) {
+ public static <T> Output<T> matmul(
+ Graph g, String name, Output<T> a, Output<T> b, boolean transposeA, boolean transposeB) {
return g.opBuilder("MatMul", name)
.addInput(a)
.addInput(b)
.setAttr("transpose_a", transposeA)
.setAttr("transpose_b", transposeB)
.build()
- .output(0);
+ .<T>output(0);
}
public static Operation split(Graph g, String name, int[] values, int numSplit) {
@@ -57,7 +60,8 @@ public class TestUtil {
}
public static void transpose_A_times_X(Graph g, int[][] a) {
- matmul(g, "Y", constant(g, "A", a), placeholder(g, "X", DataType.INT32), true, false);
+ Output<Integer> aa = constant(g, "A", a);
+ matmul(g, "Y", aa, placeholder(g, "X", Integer.class), true, false);
}
/**
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/OperandsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/OperandsTest.java
index 4fdd150acc..79bfcc8354 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/OperandsTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/OperandsTest.java
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
@@ -36,8 +36,9 @@ public class OperandsTest {
public void createOutputArrayFromOperandList() {
try (Graph g = new Graph()) {
Operation split = TestUtil.split(g, "split", new int[] {0, 1, 2}, 3);
- List<Output> list = Arrays.asList(split.output(0), split.output(2));
- Output[] array = Operands.asOutputs(list);
+ List<Output<Integer>> list =
+ Arrays.asList(split.<Integer>output(0), split.<Integer>output(2));
+ Output<?>[] array = Operands.asOutputs(list);
assertEquals(list.size(), array.length);
assertSame(array[0], list.get(0));
assertSame(array[1], list.get(1));
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/PrimitiveOpTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/PrimitiveOpTest.java
index b24bf5a476..e02c38ed22 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/PrimitiveOpTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/PrimitiveOpTest.java
@@ -36,7 +36,7 @@ public class PrimitiveOpTest {
@Test
public void equalsHashcode() {
try (Graph g = new Graph()) {
- Output array = TestUtil.constant(g, "array", new int[2]);
+ Output<Integer> array = TestUtil.constant(g, "array", new int[2]);
PrimitiveOp test1 =
new PrimitiveOp(g.opBuilder("Shape", "shape1").addInput(array).build()) {};
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
index 9256cb281d..125de73554 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
@@ -19,6 +19,8 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
+import java.util.HashMap;
+import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -26,6 +28,8 @@ import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
+import org.tensorflow.Tensors;
+import org.tensorflow.types.UInt8;
/** Unit tests for {@link org.tensorflow.Scope}. */
@RunWith(JUnit4.class)
@@ -122,13 +126,13 @@ public class ScopeTest {
public void basic() {
try (Graph g = new Graph()) {
Scope s = new Scope(g);
- Const c1 = Const.create(s, 42);
+ Const<Integer> c1 = Const.create(s, 42);
assertEquals("Const", c1.output().op().name());
- Const c2 = Const.create(s, 7);
+ Const<Integer> c2 = Const.create(s, 7);
assertEquals("Const_1", c2.output().op().name());
- Const c3 = Const.create(s.withName("four"), 4);
+ Const<Integer> c3 = Const.create(s.withName("four"), 4);
assertEquals("four", c3.output().op().name());
- Const c4 = Const.create(s.withName("four"), 4);
+ Const<Integer> c4 = Const.create(s.withName("four"), 4);
assertEquals("four_1", c4.output().op().name());
}
}
@@ -148,122 +152,164 @@ public class ScopeTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope s = new Scope(g);
- Output data = Const.create(s.withName("data"), new int[] {600, 470, 170, 430, 300}).output();
+ Output<Integer> data =
+ Const.create(s.withName("data"), new int[] {600, 470, 170, 430, 300}).output();
// Create a composite op with a customized name
- Variance var1 = Variance.create(s.withName("example"), data);
+ Variance<Integer> var1 = Variance.create(s.withName("example"), data, Integer.class);
assertEquals("example/variance", var1.output().op().name());
// Confirm internally added ops have the right names.
assertNotNull(g.operation("example/squared_deviation"));
assertNotNull(g.operation("example/Mean"));
- assertNotNull(g.operation("example/zero"));
+ // assertNotNull(g.operation("example/zero"));
// Same composite op with a default name
- Variance var2 = Variance.create(s, data);
+ Variance<Integer> var2 = Variance.create(s, data, Integer.class);
assertEquals("variance/variance", var2.output().op().name());
// Confirm internally added ops have the right names.
assertNotNull(g.operation("variance/squared_deviation"));
assertNotNull(g.operation("variance/Mean"));
- assertNotNull(g.operation("variance/zero"));
+ // assertNotNull(g.operation("variance/zero"));
// Verify correct results as well.
- Tensor result = sess.runner().fetch(var1.output()).run().get(0);
+ Tensor<Integer> result =
+ sess.runner().fetch(var1.output()).run().get(0).expect(Integer.class);
assertEquals(21704, result.intValue());
- result = sess.runner().fetch(var2.output()).run().get(0);
+ result = sess.runner().fetch(var2.output()).run().get(0).expect(Integer.class);
assertEquals(21704, result.intValue());
}
}
// "handwritten" sample operator classes
- private static final class Const {
- private final Output output;
+ private static final class Const<T> {
+ private final Output<T> output;
- static Const create(Scope s, Object v) {
- try (Tensor value = Tensor.create(v)) {
- return new Const(
+ static Const<Integer> create(Scope s, int v) {
+ return create(s, Tensors.create(v));
+ }
+
+ static Const<Integer> create(Scope s, int[] v) {
+ return create(s, Tensors.create(v));
+ }
+
+ static <T> Const<T> create(Scope s, Tensor<T> value) {
+ return new Const<T>(
+ s.graph()
+ .opBuilder("Const", s.makeOpName("Const"))
+ .setAttr("dtype", value.dataType())
+ .setAttr("value", value)
+ .build()
+ .<T>output(0));
+ }
+
+ static <T> Const<T> create(Scope s, Object v, Class<T> type) {
+ try (Tensor<T> value = Tensor.create(v, type)) {
+ return new Const<T>(
s.graph()
.opBuilder("Const", s.makeOpName("Const"))
.setAttr("dtype", value.dataType())
.setAttr("value", value)
.build()
- .output(0));
+ .<T>output(0));
}
}
- Const(Output o) {
+ Const(Output<T> o) {
output = o;
}
- Output output() {
+ Output<T> output() {
return output;
}
}
- private static final class Mean {
- private final Output output;
+ private static final class Mean<T> {
+ private final Output<T> output;
- static Mean create(Scope s, Output input, Output reductionIndices) {
- return new Mean(
+ static <T> Mean<T> create(Scope s, Output<T> input, Output<T> reductionIndices) {
+ return new Mean<T>(
s.graph()
.opBuilder("Mean", s.makeOpName("Mean"))
.addInput(input)
.addInput(reductionIndices)
.build()
- .output(0));
+ .<T>output(0));
}
- Mean(Output o) {
+ Mean(Output<T> o) {
output = o;
}
- Output output() {
+ Output<T> output() {
return output;
}
}
- private static final class SquaredDifference {
- private final Output output;
+ private static final class SquaredDifference<T> {
+ private final Output<T> output;
- static SquaredDifference create(Scope s, Output x, Output y) {
- return new SquaredDifference(
+ static <T> SquaredDifference<T> create(Scope s, Output<T> x, Output<T> y) {
+ return new SquaredDifference<T>(
s.graph()
.opBuilder("SquaredDifference", s.makeOpName("SquaredDifference"))
.addInput(x)
.addInput(y)
.build()
- .output(0));
+ .<T>output(0));
}
- SquaredDifference(Output o) {
+ SquaredDifference(Output<T> o) {
output = o;
}
- Output output() {
+ Output<T> output() {
return output;
}
}
- private static final class Variance {
- private final Output output;
+ /**
+ * Returns the zero value of type described by {@code c}, or null if the type (e.g., string) is
+ * not numeric and therefore has no zero value.
+ *
+ * @param c The class describing the TensorFlow type of interest.
+ */
+ public static Object zeroValue(Class<?> c) {
+ return zeros.get(c);
+ }
+
+ private static final Map<Class<?>, Object> zeros = new HashMap<>();
+
+ static {
+ zeros.put(Float.class, 0.0f);
+ zeros.put(Double.class, 0.0);
+ zeros.put(Integer.class, 0);
+ zeros.put(UInt8.class, (byte) 0);
+ zeros.put(Long.class, 0L);
+ zeros.put(Boolean.class, false);
+ zeros.put(String.class, null); // no zero value
+ }
+
+ private static final class Variance<T> {
+ private final Output<T> output;
- static Variance create(Scope base, Output x) {
+ static <T> Variance<T> create(Scope base, Output<T> x, Class<T> type) {
Scope s = base.withSubScope("variance");
- Output zero = Const.create(s.withName("zero"), new int[] {0}).output();
- Output sqdiff =
+ Output<T> zero = Const.create(base, zeroValue(type), type).output();
+ Output<T> sqdiff =
SquaredDifference.create(
s.withName("squared_deviation"), x, Mean.create(s, x, zero).output())
.output();
- return new Variance(Mean.create(s.withName("variance"), sqdiff, zero).output());
+ return new Variance<T>(Mean.create(s.withName("variance"), sqdiff, zero).output());
}
- Variance(Output o) {
+ Variance(Output<T> o) {
output = o;
}
- Output output() {
+ Output<T> output() {
return output;
}
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
index ec23792485..ca54214e06 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
@@ -29,7 +29,6 @@ import java.nio.LongBuffer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
@@ -47,8 +46,9 @@ public class ConstantTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Constant op = Constant.create(scope, shape, IntBuffer.wrap(ints));
- Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints));
+ Tensor<Integer> result = sess.runner().fetch(op.asOutput())
+ .run().get(0).expect(Integer.class);
int[] actual = new int[ints.length];
assertArrayEquals(ints, result.copyTo(actual));
}
@@ -62,8 +62,8 @@ public class ConstantTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Constant op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
- Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ Constant<Float> op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
+ Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class);
float[] actual = new float[floats.length];
assertArrayEquals(floats, result.copyTo(actual), EPSILON);
}
@@ -77,8 +77,8 @@ public class ConstantTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Constant op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
- Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ Constant<Double> op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
+ Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class);
double[] actual = new double[doubles.length];
assertArrayEquals(doubles, result.copyTo(actual), EPSILON);
}
@@ -92,8 +92,8 @@ public class ConstantTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Constant op = Constant.create(scope, shape, LongBuffer.wrap(longs));
- Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ Constant<Long> op = Constant.create(scope, shape, LongBuffer.wrap(longs));
+ Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class);
long[] actual = new long[longs.length];
assertArrayEquals(longs, result.copyTo(actual));
}
@@ -123,8 +123,8 @@ public class ConstantTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Constant op = Constant.create(scope, DataType.STRING, shape, ByteBuffer.wrap(content));
- Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ Constant<String> op = Constant.create(scope, String.class, shape, ByteBuffer.wrap(content));
+ Tensor<String> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(String.class);
assertArrayEquals(data, result.bytesValue());
}
}
diff --git a/tensorflow/python/debug/lib/debug_graphs.py b/tensorflow/python/debug/lib/debug_graphs.py
index 486e659158..87033d53a4 100644
--- a/tensorflow/python/debug/lib/debug_graphs.py
+++ b/tensorflow/python/debug/lib/debug_graphs.py
@@ -231,8 +231,8 @@ def _infer_device_name(graph_def):
break
if device_name is None:
logging.warn(
- "Failed to infer device name from partiton GraphDef: none of the nodes "
- "of the GraphDef has a non-empty device name.")
+ "Failed to infer device name from partition GraphDef: none of the "
+ "nodes of the GraphDef has a non-empty device name.")
return device_name
diff --git a/tensorflow/python/estimator/inputs/queues/feeding_functions.py b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
index d7fe4bbfa1..c0a287e922 100644
--- a/tensorflow/python/estimator/inputs/queues/feeding_functions.py
+++ b/tensorflow/python/estimator/inputs/queues/feeding_functions.py
@@ -49,7 +49,7 @@ except ImportError:
def _fill_array(arr, seq, fillvalue=0):
"""
Recursively fills padded arr with elements from seq.
- If lenght of seq is less then arr padded length, fillvalue used.
+ If the length of seq is less than the padded length of arr, fillvalue is used.
Args:
arr: Padded tensor of shape [batch_size, ..., max_padded_dim_len].
diff --git a/tensorflow/python/keras/_impl/keras/engine/topology_test.py b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
index 97bef2965c..32e692ba7c 100644
--- a/tensorflow/python/keras/_impl/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
@@ -200,7 +200,7 @@ class TopologyConstructionTest(test.TestCase):
with self.assertRaises(ValueError):
_ = keras.layers.Input(shape=(32,), batch_shape=(10, 32))
with self.assertRaises(ValueError):
- _ = keras.layers.Input(shape=(32,), unknwon_kwarg=None)
+ _ = keras.layers.Input(shape=(32,), unknown_kwarg=None)
self.assertListEqual(a.get_shape().as_list(), [None, 32])
a_layer, a_node_index, a_tensor_index = a._keras_history
diff --git a/tensorflow/python/kernel_tests/conv2d_transpose_test.py b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
index 18184a0ee0..7d0bc54b69 100644
--- a/tensorflow/python/kernel_tests/conv2d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
@@ -24,8 +24,12 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.client import device_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
@@ -289,6 +293,16 @@ class Conv2DTransposeTest(test.TestCase):
self.assertAllClose(cache_values, value)
+ def testConv2DTransposeShapeInference(self):
+ # Test case for 8972
+ initializer = random_ops.truncated_normal(
+ [3, 3, 5, 1], mean=0.0, stddev=0.01, dtype=dtypes.float32)
+ x = variables.Variable(random_ops.random_normal([3, 10, 5, 1]))
+ f = variable_scope.get_variable("f", initializer=initializer)
+ f_shape = array_ops.stack([array_ops.shape(x)[0], 10, 5, 5])
+ output = nn_ops.conv2d_transpose(
+ x, f, f_shape, strides=[1, 1, 1, 1], padding="SAME")
+ self.assertEqual(output.get_shape().as_list(), [None, 10, 5, 5])
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py
index 3853379328..7d9e57c8e5 100644
--- a/tensorflow/python/kernel_tests/decode_csv_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py
@@ -116,6 +116,17 @@ class DecodeCSVOpTest(test.TestCase):
self._test(args, expected_out)
+ def testNA(self):
+ args = {
+ "records": ["2.0,NA,aa", "NA,5,bb", "3,6,NA"],
+ "record_defaults": [[0.0], [0], [""]],
+ "na_value": "NA"
+ }
+
+ expected_out = [[2.0, 0.0, 3], [0, 5, 6], [b"aa", b"bb", b""]]
+
+ self._test(args, expected_out)
+
def testWithDefaults(self):
args = {
"records": [",1,", "0.2,3,bcd", "3.0,,"],
diff --git a/tensorflow/python/kernel_tests/summary_tensor_op_test.py b/tensorflow/python/kernel_tests/summary_tensor_op_test.py
index 3584637865..d534aadb79 100644
--- a/tensorflow/python/kernel_tests/summary_tensor_op_test.py
+++ b/tensorflow/python/kernel_tests/summary_tensor_op_test.py
@@ -154,7 +154,7 @@ class SummaryOpsTest(test.TestCase):
self.assertEqual(descr.display_name, "my name")
self.assertEqual(descr.summary_description, "my description")
- # If both SummmaryMetadata and explicit args are provided, the args win
+ # If both SummaryMetadata and explicit args are provided, the args win
overwrite = summary_ops.tensor_summary(
"simple",
const,
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index 6e7122db5e..d27e867583 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -207,6 +207,7 @@ TextLineReaderV2
TFRecordReaderV2
WholeFileReaderV2
LMDBReader
+DecodeCSV
# linalg_ops
BatchCholesky
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index c5fd15bae4..ea7132791c 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -1166,3 +1166,42 @@ def _parse_single_sequence_example_raw(serialized,
feature_list_sparse_tensors + feature_list_dense_values))
return (context_output, feature_list_output)
+
+
+# Swap `name` and `na_value` for backward compatibility.
+def decode_csv(records, record_defaults, field_delim=",",
+ use_quote_delim=True, name=None, na_value=""):
+ # pylint: disable=protected-access
+ """Convert CSV records to tensors. Each column maps to one tensor.
+
+ RFC 4180 format is expected for the CSV records.
+ (https://tools.ietf.org/html/rfc4180)
+ Note that we allow leading and trailing spaces with int or float fields.
+
+ Args:
+ records: A `Tensor` of type `string`.
+ Each string is a record/row in the csv and all records should have
+ the same format.
+ record_defaults: A list of `Tensor` objects with specific types.
+ Acceptable types are `float32`, `int32`, `int64`, `string`.
+ One tensor per column of the input record, with either a
+ scalar default value for that column or empty if the column is required.
+ field_delim: An optional `string`. Defaults to `","`.
+ char delimiter to separate fields in a record.
+ use_quote_delim: An optional `bool`. Defaults to `True`.
+ If false, treats double quotation marks as regular
+ characters inside of the string fields (ignoring RFC 4180, Section 2,
+ Bullet 5).
+ name: A name for the operation (optional).
+ na_value: Additional string to recognize as NA/NaN.
+
+ Returns:
+ A list of `Tensor` objects. Has the same type as `record_defaults`.
+ Each tensor will have the same shape as records.
+ """
+ # TODO(martinwicke), remove the wrapper when new Python API generator is done.
+ return gen_parsing_ops._decode_csv(
+ records=records, record_defaults=record_defaults,
+ field_delim=field_delim, use_quote_delim=use_quote_delim,
+ na_value=na_value, name=name)
+ # pylint: enable=protected-access
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index bf8380ebbd..0a1a748c40 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -562,7 +562,7 @@ static bool TensorOpMathEnabled() {
bool ret;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DISABLE_TENSOR_OP_MATH",
/*default=*/false, &ret));
- return ret;
+ return !ret;
}();
return is_enabled;
}
@@ -2474,58 +2474,73 @@ struct WinogradNonfused {
};
bool CudnnSupport::GetConvolveAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) {
- out_algorithms->assign({
- // clang-format off
- CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
- CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
- CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
- CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
- CUDNN_CONVOLUTION_FWD_ALGO_FFT,
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) {
+ std::vector<dnn::AlgorithmDesc::Index> algo_types = {
+ // clang-format off
+ CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
+ CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
+ CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
+ CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
+ CUDNN_CONVOLUTION_FWD_ALGO_FFT,
#if CUDNN_VERSION >= 5000
- CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
+ CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
#endif
- // clang-format on
- });
+ // clang-format on
+ };
if (CudnnEnvVar<FftTilingForward>::IsEnabled()) {
- out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
+ algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING);
}
#if CUDNN_VERSION >= 5100
if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
- out_algorithms->push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
+ algo_types.push_back(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED);
}
#endif
+
+ out_algorithms->clear();
+ for (auto i : algo_types) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/false});
+ if (cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled()) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/true});
+ }
+ }
return true;
}
bool CudnnSupport::GetConvolveBackwardDataAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) {
- out_algorithms->assign({
- // clang-format off
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) {
+ std::vector<dnn::AlgorithmDesc::Index> algo_types = {
+ // clang-format off
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
#if CUDNN_VERSION >= 5000
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
#endif
- // clang-format on
- });
+ // clang-format on
+ };
#if CUDNN_VERSION >= 5100
if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
- out_algorithms->push_back(
- CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
+ algo_types.push_back(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED);
}
#endif
+
+ out_algorithms->clear();
+ for (auto i : algo_types) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/false});
+ if (cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled()) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/true});
+ }
+ }
return true;
}
bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) {
- out_algorithms->assign({
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) {
+ std::vector<dnn::AlgorithmDesc::Index> algo_types = {
// clang-format off
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
@@ -2534,13 +2549,20 @@ bool CudnnSupport::GetConvolveBackwardFilterAlgorithms(
// Based on cudnn.h, the following is not implemented.
// CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD,
// clang-format on
- });
+ };
#if CUDNN_VERSION >= 5110
if (CudnnEnvVar<WinogradNonfused>::IsEnabled() && with_winograd_nonfused) {
- out_algorithms->push_back(
- CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
+ algo_types.push_back(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED);
}
#endif
+
+ out_algorithms->clear();
+ for (auto i : algo_types) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/false});
+ if (cc_major >= 7 && CUDNN_VERSION >= 7000 && TensorOpMathEnabled()) {
+ out_algorithms->push_back({i, /*use_tensor_ops=*/true});
+ }
+ }
return true;
}
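All three GetConvolve*Algorithms overloads in this file now follow the same pattern: collect the raw cuDNN algorithm indices into a local list, then emit each one as a descriptor without tensor ops, plus a tensor-op variant when the device's compute capability is at least 7.0, cuDNN is at least 7, and the env-var gate allows it. A self-contained sketch of that expansion follows; AlgoDesc, cudnn_major, and tensor_op_math_enabled are simplified stand-ins for dnn::AlgorithmDesc, CUDNN_VERSION >= 7000, and TensorOpMathEnabled().

#include <iostream>
#include <vector>

struct AlgoDesc {       // stand-in for dnn::AlgorithmDesc (index + tensor-op flag)
  int index;
  bool use_tensor_ops;
};

// Expand raw algorithm indices into descriptors, adding a tensor-op variant
// only on Volta-class (cc_major >= 7) devices with cuDNN 7+.
std::vector<AlgoDesc> ExpandAlgorithms(const std::vector<int>& algo_types,
                                       int cc_major, int cudnn_major,
                                       bool tensor_op_math_enabled) {
  std::vector<AlgoDesc> out;
  for (int i : algo_types) {
    out.push_back({i, /*use_tensor_ops=*/false});
    if (cc_major >= 7 && cudnn_major >= 7 && tensor_op_math_enabled) {
      out.push_back({i, /*use_tensor_ops=*/true});
    }
  }
  return out;
}

int main() {
  for (const AlgoDesc& a : ExpandAlgorithms({0, 1, 2}, /*cc_major=*/7,
                                             /*cudnn_major=*/7, true)) {
    std::cout << "algo " << a.index
              << (a.use_tensor_ops ? " (tensor ops)" : "") << "\n";
  }
}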
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index beb2f7d050..8d7069a902 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -145,16 +145,16 @@ class CudnnSupport : public dnn::DnnSupport {
ScratchAllocator* workspace_allocator) override;
bool GetConvolveAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) override;
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
bool GetConvolveBackwardDataAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) override;
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
bool GetConvolveBackwardFilterAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index>* out_algorithms) override;
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
bool DoBatchNormalizationForward(
Stream* stream, const DeviceMemory<float>& x,
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 2c40e18f5c..07fe8a85f4 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -23,20 +23,20 @@ namespace gputools {
namespace dnn {
bool DnnSupport::GetConvolveAlgorithms(
- bool with_winograd_nonfused,
- std::vector<AlgorithmDesc::Index>* out_algorithms) {
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<AlgorithmDesc>* out_algorithms) {
return false;
}
bool DnnSupport::GetConvolveBackwardDataAlgorithms(
- bool with_winograd_nonfused,
- std::vector<AlgorithmDesc::Index>* out_algorithms) {
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<AlgorithmDesc>* out_algorithms) {
return false;
}
bool DnnSupport::GetConvolveBackwardFilterAlgorithms(
- bool with_winograd_nonfused,
- std::vector<AlgorithmDesc::Index>* out_algorithms) {
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<AlgorithmDesc>* out_algorithms) {
return false;
}
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 5fe523602a..624357b82f 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -1183,8 +1183,8 @@ class DnnSupport {
// Return a list of algorithms supported by the forward convolution pass.
virtual bool GetConvolveAlgorithms(
- bool with_winograd_nonfused,
- std::vector<AlgorithmDesc::Index>* out_algorithms);
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<AlgorithmDesc>* out_algorithms);
// Version of DoConvolve that uses pre-quantized 8 bit coefficients.
// coefficient_scales specifies the scaling of each column of coefficients:
@@ -1263,8 +1263,8 @@ class DnnSupport {
// Return a list of algorithms supported by the backward convolution pass for
// data.
virtual bool GetConvolveBackwardDataAlgorithms(
- bool with_winograd_nonfused,
- std::vector<AlgorithmDesc::Index>* out_algorithms);
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<AlgorithmDesc>* out_algorithms);
virtual bool DoConvolveBackwardData(
Stream* stream, const FilterDescriptor& filter_descriptor,
@@ -1312,8 +1312,8 @@ class DnnSupport {
// Return a list of algorithms supported by the backward convolution pass for
// filters.
virtual bool GetConvolveBackwardFilterAlgorithms(
- bool with_winograd_nonfused,
- std::vector<AlgorithmDesc::Index>* out_algorithms);
+ bool with_winograd_nonfused, int cc_major, int cc_minor,
+ std::vector<AlgorithmDesc>* out_algorithms);
virtual bool DoConvolveBackwardFilter(
Stream* stream, const BatchDescriptor& input_descriptor,
diff --git a/tensorflow/stream_executor/platform.h b/tensorflow/stream_executor/platform.h
index ed12982e30..f0a0e60e02 100644
--- a/tensorflow/stream_executor/platform.h
+++ b/tensorflow/stream_executor/platform.h
@@ -96,7 +96,7 @@ class Platform {
// each platform is required to expose an ID to ensure unique registration and
// as a target against which plugins can register.
//
- // The macro below is provided to help generate a [process-unique] identifer.
+ // The macro below is provided to help generate a [process-unique] identifier.
using Id = void*;
// Helper macro to define a plugin ID. To be used only inside plugin
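The comment fixed above concerns plugin IDs, which platform.h represents as a bare void* that only has to be process-unique. One common way to mint such an identifier, and roughly what a helper macro like the one mentioned can automate, is to take the address of a dedicated namespace-scope object. The sketch below uses invented names and is not the actual TensorFlow macro.

#include <iostream>

using Id = void*;  // as in platform.h: an opaque, process-unique identifier

// Every distinct static object has a distinct address for the lifetime of
// the process, so its address can serve as a registration key.
namespace my_plugin_internal {
static int plugin_id_storage = 0;
}  // namespace my_plugin_internal
static const Id kMyPluginId = &my_plugin_internal::plugin_id_storage;

namespace other_plugin_internal {
static int plugin_id_storage = 0;
}  // namespace other_plugin_internal
static const Id kOtherPluginId = &other_plugin_internal::plugin_id_storage;

int main() {
  // Distinct plugins get distinct identifiers without a central registry.
  std::cout << (kMyPluginId != kOtherPluginId) << "\n";  // prints 1
}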
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index a72ee804c1..21172d5a16 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -70,7 +70,7 @@ class BatchDescriptor;
class FilterDescriptor;
class ConvolutionDescriptor;
class ProfileResult;
-struct AlgorithmDesc;
+class AlgorithmDesc;
} // namespace dnn
class StreamExecutor;
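The one-line stream.h change keeps the forward declaration of AlgorithmDesc in sync with its definition, which this patch turns into a class. Mixing class-keys is accepted by most compilers but MSVC warns about it (C4099). A minimal illustration follows; the members are stand-ins, not the real dnn::AlgorithmDesc interface.

// struct AlgorithmDesc;  // old, mismatched forward declaration
class AlgorithmDesc;      // matches the class definition below

class AlgorithmDesc {
 public:
  AlgorithmDesc(int index, bool use_tensor_ops)
      : index_(index), use_tensor_ops_(use_tensor_ops) {}
  int index() const { return index_; }
  bool use_tensor_ops() const { return use_tensor_ops_; }

 private:
  int index_;
  bool use_tensor_ops_;
};

int main() { return AlgorithmDesc(0, false).index(); }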
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 199a908914..9bbfe7f04a 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -286,35 +286,41 @@ bool StreamExecutor::SupportsDnn() const {
bool StreamExecutor::GetConvolveAlgorithms(
bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index> *out_algorithms) {
+ std::vector<dnn::AlgorithmDesc> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return false;
}
- return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused,
- out_algorithms);
+ int cc_major, cc_minor;
+ GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
+ return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
+ cc_minor, out_algorithms);
}
bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index> *out_algorithms) {
+ std::vector<dnn::AlgorithmDesc> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return false;
}
- return dnn_support->GetConvolveBackwardDataAlgorithms(with_winograd_nonfused,
- out_algorithms);
+ int cc_major, cc_minor;
+ GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
+ return dnn_support->GetConvolveBackwardDataAlgorithms(
+ with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}
bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index> *out_algorithms) {
+ std::vector<dnn::AlgorithmDesc> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return false;
}
+ int cc_major, cc_minor;
+ GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
return dnn_support->GetConvolveBackwardFilterAlgorithms(
- with_winograd_nonfused, out_algorithms);
+ with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}
bool StreamExecutor::GetBlasGemmAlgorithms(
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 98136a92a0..f354317a6e 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -343,20 +343,19 @@ class StreamExecutor {
bool SupportsDnn() const;
// Get the list of supported algorithms for the forward convolution operation.
- bool GetConvolveAlgorithms(
- bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index> *out_algorithms);
+ bool GetConvolveAlgorithms(bool with_winograd_nonfused,
+ std::vector<dnn::AlgorithmDesc> *out_algorithms);
// Get the list of supported algorithms for the backward convolution on data.
bool GetConvolveBackwardDataAlgorithms(
bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index> *out_algorithms);
+ std::vector<dnn::AlgorithmDesc> *out_algorithms);
// Get the list of supported algorithms for the backward convolution on the
// filter.
bool GetConvolveBackwardFilterAlgorithms(
bool with_winograd_nonfused,
- std::vector<dnn::AlgorithmDesc::Index> *out_algorithms);
+ std::vector<dnn::AlgorithmDesc> *out_algorithms);
// Get the list of supported algorithms for BLAS gemm.
bool GetBlasGemmAlgorithms(std::vector<blas::AlgorithmType> *out_algorithms);
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index a308688790..0f074151db 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -526,6 +526,7 @@ def tf_cc_test(name,
extra_copts=[],
suffix="",
linkopts=[],
+ nocopts=None,
**kwargs):
native.cc_test(
name="%s%s" % (name, suffix),
@@ -547,6 +548,7 @@ def tf_cc_test(name,
clean_dep("//tensorflow:darwin"): 1,
"//conditions:default": 0,
}),
+ nocopts=nocopts,
**kwargs)
@@ -649,7 +651,8 @@ def tf_cc_tests(srcs,
tags=[],
size="medium",
args=None,
- linkopts=[]):
+ linkopts=[],
+ nocopts=None):
for src in srcs:
tf_cc_test(
name=src_to_test_name(src),
@@ -659,7 +662,8 @@ def tf_cc_tests(srcs,
tags=tags,
size=size,
args=args,
- linkopts=linkopts)
+ linkopts=linkopts,
+ nocopts=nocopts)
def tf_cc_test_mkl(srcs,
@@ -669,7 +673,7 @@ def tf_cc_test_mkl(srcs,
tags=[],
size="medium",
args=None):
- if_mkl(tf_cc_tests(srcs, deps, linkstatic, tags=tags, size=size, args=args))
+ if_mkl(tf_cc_tests(srcs, deps, name, linkstatic=linkstatic, tags=tags, size=size, args=args, nocopts="-fno-exceptions"))
def tf_cc_tests_gpu(srcs,
@@ -867,18 +871,33 @@ def tf_mkl_kernel_library(name,
deps=None,
alwayslink=1,
copts=tf_copts(),
+ nocopts="-fno-exceptions",
**kwargs):
+ """A rule to build MKL-based TensorFlow kernel libraries."""
+ gpu_srcs = gpu_srcs # unused argument
+ kwargs = kwargs # unused argument
+
+ if not bool(srcs):
+ srcs = []
+ if not bool(hdrs):
+ hdrs = []
+
+ if prefix:
+ srcs = srcs + native.glob(
+ [prefix + "*.cc"])
+ hdrs = hdrs + native.glob(
+ [prefix + "*.h"])
+
if_mkl(
- tf_kernel_library(
- name,
- prefix=prefix,
+ native.cc_library(
+ name=name,
srcs=srcs,
- gpu_srcs=gpu_srcs,
hdrs=hdrs,
deps=deps,
alwayslink=alwayslink,
copts=copts,
- **kwargs))
+ nocopts=nocopts
+ ))
# Bazel rules for building swig files.
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 32a86e420a..6e03f9e8fb 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -874,7 +874,7 @@ tf_module {
}
member_method {
name: "decode_csv"
- argspec: "args=[\'records\', \'record_defaults\', \'field_delim\', \'use_quote_delim\', \'name\'], varargs=None, keywords=None, defaults=[\',\', \'True\', \'None\'], "
+ argspec: "args=[\'records\', \'record_defaults\', \'field_delim\', \'use_quote_delim\', \'name\', \'na_value\'], varargs=None, keywords=None, defaults=[\',\', \'True\', \'None\', \'\'], "
}
member_method {
name: "decode_json_example"
diff --git a/tensorflow/tools/ci_build/install/install_golang.sh b/tensorflow/tools/ci_build/install/install_golang.sh
index 88bc2960e3..596265b069 100755
--- a/tensorflow/tools/ci_build/install/install_golang.sh
+++ b/tensorflow/tools/ci_build/install/install_golang.sh
@@ -16,7 +16,7 @@
set -ex
-GOLANG_URL="https://storage.googleapis.com/golang/go1.8.3.linux-amd64.tar.gz"
+GOLANG_URL="https://storage.googleapis.com/golang/go1.9.linux-amd64.tar.gz"
sudo mkdir -p /usr/local
wget -q -O - "${GOLANG_URL}" | sudo tar -C /usr/local -xz
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index f5364d803a..04773376e9 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -78,10 +78,12 @@ WORKDIR /tensorflow
# Configure the build for our CUDA configuration.
ENV CI_BUILD_PYTHON python
-ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
+ENV LD_LIBRARY_PATH /usr/local/cuda/lib64/stubs:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
ENV TF_NEED_CUDA 1
ENV TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2,6.0,6.1
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
+
RUN tensorflow/tools/ci_build/builds/configured GPU \
bazel build -c opt --config=cuda --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \
tensorflow/tools/pip_package:build_pip_package && \
diff --git a/tensorflow/tools/docker/jupyter_notebook_config.py b/tensorflow/tools/docker/jupyter_notebook_config.py
index 747beb8251..0acbf6fcee 100644
--- a/tensorflow/tools/docker/jupyter_notebook_config.py
+++ b/tensorflow/tools/docker/jupyter_notebook_config.py
@@ -18,7 +18,6 @@ from IPython.lib import passwd
c.NotebookApp.ip = '*'
c.NotebookApp.port = int(os.getenv('PORT', 8888))
c.NotebookApp.open_browser = False
-c.MultiKernelManager.default_kernel_name = 'python2'
# sets a password if PASSWORD is set in the environment
if 'PASSWORD' in os.environ:
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index ca3b778c29..1015103077 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -923,7 +923,7 @@ class _ClassPageInfo(object):
"""Sets the `aliases` list.
Args:
- aliases: A list of strings. Containing all the obejct's full names.
+ aliases: A list of strings. Containing all the object's full names.
"""
assert self.aliases is None
self._aliases = aliases
@@ -1438,7 +1438,7 @@ class _PythonBuiltin(object):
class _PythonFile(object):
"""This class indicates that the object is defined in a regular python file.
- This can be used for the `defined_in` slot of the `PageInfo` obejcts.
+ This can be used for the `defined_in` slot of the `PageInfo` objects.
"""
def __init__(self, path, parser_config):
diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc b/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc
index 81f85e0009..6f0b4f47de 100644
--- a/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc
+++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc
@@ -93,13 +93,15 @@ TEST(CreateProtoDebugStringLibTest, ValidSimpleTypes) {
proto.set_optional_int64(std::numeric_limits<protobuf_int64>::max());
proto.set_optional_uint32(std::numeric_limits<uint32>::max());
proto.set_optional_uint64(std::numeric_limits<uint64>::max());
- proto.set_optional_float(std::numeric_limits<float>::max());
+ // TODO(b/67475677): Re-enable after resolving float precision issue
+ // proto.set_optional_float(std::numeric_limits<float>::max());
proto.set_optional_double(std::numeric_limits<double>::max());
EXPECT_TEXT_TRANSFORMS_MATCH();
// Least positive numeric values.
proto.Clear();
- proto.set_optional_float(std::numeric_limits<float>::min());
+ // TODO(b/67475677): Re-enable after resolving float precision issue
+ // proto.set_optional_float(std::numeric_limits<float>::min());
proto.set_optional_double(std::numeric_limits<double>::min());
EXPECT_TEXT_TRANSFORMS_MATCH();
@@ -107,7 +109,8 @@ TEST(CreateProtoDebugStringLibTest, ValidSimpleTypes) {
proto.Clear();
proto.set_optional_int32(std::numeric_limits<int32>::lowest());
proto.set_optional_int64(std::numeric_limits<protobuf_int64>::lowest());
- proto.set_optional_float(std::numeric_limits<float>::lowest());
+ // TODO(b/67475677): Re-enable after resolving float precision issue
+ // proto.set_optional_float(std::numeric_limits<float>::lowest());
proto.set_optional_double(std::numeric_limits<double>::lowest());
EXPECT_TEXT_TRANSFORMS_MATCH();
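The assertions disabled above exercise proto text round-tripping of extreme float values. The referenced bug is not public, but the usual failure mode with float max/min in text form is printing too few significant digits, so the parsed value no longer equals the original. A standalone illustration of that effect (not the proto_text code itself):

#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>

int main() {
  const float v = std::numeric_limits<float>::max();

  // Default-ish precision: the text form does not round-trip exactly.
  std::ostringstream lossy;
  lossy << std::setprecision(6) << v;  // e.g. "3.40282e+38"
  const float back_lossy = std::stof(lossy.str());

  // max_digits10 (9 for float) guarantees an exact round trip.
  std::ostringstream exact;
  exact << std::setprecision(std::numeric_limits<float>::max_digits10) << v;
  const float back_exact = std::stof(exact.str());

  std::cout << "lossy round-trips: " << (back_lossy == v) << "\n";  // 0
  std::cout << "exact round-trips: " << (back_exact == v) << "\n";  // 1
}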
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index b226184261..de0084613b 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -171,6 +171,17 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
"and will be removed in the future.")
native.new_http_archive(
+ name = "mkl_dnn",
+ urls = [
+ "https://github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz",
+ "http://mirror.bazel.build/github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz",
+ ],
+ sha256 = "0d529ad4c49dc799e6df07c2b88b115d0668735da15fb3b3862d28d33fa68165",
+ strip_prefix = "mkl-dnn-b01e3a55a07be62172e713bcd2644c5176360212",
+ build_file = str(Label("//third_party/mkl_dnn:mkldnn.BUILD")),
+ )
+
+ native.new_http_archive(
name = "eigen_archive",
urls = [
"https://bitbucket.org/eigen/eigen/get/429aa5254200.tar.gz",
@@ -373,10 +384,10 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
patched_http_archive(
name = "protobuf_archive",
urls = [
- "http://mirror.bazel.build/github.com/google/protobuf/archive/0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66.tar.gz",
+ "http://mirror.bazel.build/github.com/google/protobuf/archive/b04e5cba356212e4e8c66c61bbe0c3a20537c5b9.tar.gz",
],
- sha256 = "6d43b9d223ce09e5d4ce8b0060cb8a7513577a35a64c7e3dad10f0703bf3ad93",
- strip_prefix = "protobuf-0b059a3d8a8f8aa40dde7bea55edca4ec5dfea66",
+ sha256 = "e178a25c52efcb6b05988bdbeace4c0d3f2d2fe5b46696d1d9898875c3803d6a",
+ strip_prefix = "protobuf-b04e5cba356212e4e8c66c61bbe0c3a20537c5b9",
# TODO: remove patching when tensorflow stops linking same protos into
# multiple shared libraries loaded in runtime by python.
# This patch fixes a runtime crash when tensorflow is compiled
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index baa6e01bca..31a4bfabf6 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -117,7 +117,7 @@ def get_cxx_inc_directories(repository_ctx, cc):
includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)
- includes_cpp_set = set(includes_cpp)
+ includes_cpp_set = depset(includes_cpp)
return includes_cpp + [inc for inc in includes_c
if inc not in includes_cpp_set]
diff --git a/third_party/mkl_dnn/BUILD b/third_party/mkl_dnn/BUILD
new file mode 100644
index 0000000000..5b01f6e3e4
--- /dev/null
+++ b/third_party/mkl_dnn/BUILD
@@ -0,0 +1 @@
+licenses(["notice"])
diff --git a/third_party/mkl_dnn/mkldnn.BUILD b/third_party/mkl_dnn/mkldnn.BUILD
new file mode 100644
index 0000000000..58bb7a6a5d
--- /dev/null
+++ b/third_party/mkl_dnn/mkldnn.BUILD
@@ -0,0 +1,25 @@
+exports_files(["LICENSE"])
+
+cc_library(
+ name = "mkl_dnn",
+ srcs = glob([
+ "src/common/*.cpp",
+ "src/cpu/*.cpp",
+ ]),
+ hdrs = glob(["include/*"]),
+ copts = ["-fexceptions"] + select({
+ "@org_tensorflow//tensorflow:linux_x86_64": [
+ "-fopenmp",
+ ],
+ "//conditions:default": [],
+ }),
+ includes = [
+ "include",
+ "src",
+ "src/common",
+ "src/cpu",
+ "src/cpu/xbyak",
+ ],
+ nocopts = "-fno-exceptions",
+ visibility = ["//visibility:public"],
+)