aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-01 18:13:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-01 18:33:19 -0800
commitd1ba01f81d8fa1d0171ba9ce871599063d5c7eb9 (patch)
treecd28fd2d32712c59f8452ede903cd592e0dc95bd
parentffc667757c6c328e48d80c14f97e32cf6a9d0f53 (diff)
Merge changes from github.
Change: 146316196
-rw-r--r--LICENSE4
-rw-r--r--tensorflow/BUILD16
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt3
-rw-r--r--tensorflow/contrib/cmake/README.md2
-rw-r--r--tensorflow/contrib/cmake/tf_core_cpu.cmake2
-rw-r--r--tensorflow/contrib/cmake/tf_core_distributed_runtime.cmake1
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake8
-rw-r--r--tensorflow/contrib/cmake/tf_tests.cmake2
-rw-r--r--tensorflow/contrib/cmake/tf_tools.cmake8
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py31
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py15
-rw-r--r--tensorflow/contrib/learn/python/learn/datasets/mnist.py38
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py2
-rw-r--r--tensorflow/contrib/slim/python/slim/evaluation_test.py32
-rw-r--r--tensorflow/contrib/slim/python/slim/learning.py17
-rw-r--r--tensorflow/contrib/slim/python/slim/learning_test.py26
-rw-r--r--tensorflow/core/kernels/decode_raw_op.cc6
-rw-r--r--tensorflow/core/ops/array_ops.cc8
-rw-r--r--tensorflow/core/ops/ctc_ops.cc2
-rw-r--r--tensorflow/core/platform/denormal.cc8
-rw-r--r--tensorflow/core/platform/hadoop/hadoop_file_system.cc10
-rw-r--r--tensorflow/examples/tutorials/input_fn/boston_predict.csv2
-rw-r--r--tensorflow/examples/tutorials/input_fn/boston_test.csv2
-rw-r--r--tensorflow/examples/tutorials/input_fn/boston_train.csv2
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist_with_summaries.py2
-rw-r--r--tensorflow/examples/udacity/5_word2vec.ipynb4
-rw-r--r--tensorflow/g3doc/get_started/os_setup.md2
-rw-r--r--tensorflow/g3doc/how_tos/distributed/index.md2
-rw-r--r--tensorflow/g3doc/how_tos/language_bindings/index.md2
-rw-r--r--tensorflow/g3doc/tutorials/estimators/index.md2
-rw-r--r--tensorflow/g3doc/tutorials/mnist/beginners/index.md6
-rw-r--r--tensorflow/java/BUILD2
-rw-r--r--tensorflow/java/README.md7
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Tensor.java376
-rw-r--r--tensorflow/java/src/main/native/tensor_jni.cc26
-rw-r--r--tensorflow/java/src/main/native/tensor_jni.h12
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/TensorTest.java256
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/TestUtil.java72
-rw-r--r--tensorflow/python/kernel_tests/decode_raw_op_test.py9
-rw-r--r--tensorflow/python/kernel_tests/denormal_test.py4
-rw-r--r--tensorflow/python/ops/control_flow_ops.py4
-rw-r--r--tensorflow/python/ops/ctc_ops.py4
-rw-r--r--tensorflow/python/ops/metrics_impl.py6
-rw-r--r--tensorflow/python/training/input.py2
-rw-r--r--tensorflow/python/util/deprecation.py5
-rw-r--r--tensorflow/stream_executor/dso_loader.cc12
-rw-r--r--tensorflow/tensorboard/plugins/projector/plugin.py2
-rw-r--r--tensorflow/tensorflow.bzl1
-rwxr-xr-xtensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh31
-rw-r--r--tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh1
-rw-r--r--tensorflow/tools/graph_transforms/BUILD13
-rw-r--r--third_party/gpus/cuda_configure.bzl17
52 files changed, 984 insertions, 145 deletions
diff --git a/LICENSE b/LICENSE
index d3da228420..15ae421404 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,4 +1,4 @@
-Copyright 2015 The TensorFlow Authors. All rights reserved.
+Copyright 2017 The TensorFlow Authors. All rights reserved.
Apache License
Version 2.0, January 2004
@@ -188,7 +188,7 @@ Copyright 2015 The TensorFlow Authors. All rights reserved.
same "printed page" as the copyright notice for easier
identification within third-party archives.
- Copyright 2015, The TensorFlow Authors.
+ Copyright 2017, The TensorFlow Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 9ef718ae05..9e556b6e4e 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -70,6 +70,22 @@ config_setting(
visibility = ["//visibility:public"],
)
+config_setting(
+ name = "debug",
+ values = {
+ "compilation_mode": "dbg",
+ },
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "optimized",
+ values = {
+ "compilation_mode": "opt",
+ },
+ visibility = ["//visibility:public"],
+)
+
package_group(
name = "internal",
packages = ["//tensorflow/..."],
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index 5ac9ec5681..64262fdce5 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -21,6 +21,7 @@ option(tensorflow_VERBOSE "Enable for verbose output" OFF)
option(tensorflow_ENABLE_GPU "Enable GPU support" OFF)
option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF)
option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON)
+option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF)
option(tensorflow_BUILD_CC_EXAMPLE "Build the C++ tutorial example" ON)
option(tensorflow_BUILD_PYTHON_BINDINGS "Build the Python bindings" ON)
option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON)
@@ -58,6 +59,7 @@ if(WIN32)
add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11 -DCOMPILER_MSVC -D__VERSION__=\"MSVC\")
add_definitions(-DWIN32 -DOS_WIN -D_MBCS -DWIN64 -DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS)
add_definitions(-DTENSORFLOW_USE_EIGEN_THREADPOOL -DEIGEN_HAS_C99_MATH -D_ITERATOR_DEBUG_LEVEL=0)
+ add_definitions(-DEIGEN_VECTORIZE_SSE3) # Needed to suppress denormals without __SSE3__ in MSVC
add_definitions(-DNDEBUG /O2) # Equivalent of -c opt in Bazel.
add_definitions(/bigobj /nologo /EHsc /GF /FC /MP /Gm-)
# Suppress warnings to reduce build log size.
@@ -161,6 +163,7 @@ if (tensorflow_ENABLE_GPU)
# CUDA_NVCC_FLAGS and cuda_config.h below
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_30,code=\"sm_30,compute_30\";-gencode arch=compute_35,code=\"sm_35,compute_35\";-gencode arch=compute_52,code=\"sm_52,compute_52\")
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--include-path ${PROJECT_BINARY_DIR}/$\{build_configuration\};--expt-relaxed-constexpr)
+ set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-ftz=true) # Flush denormals to zero
set(CUDA_INCLUDE ${CUDA_TOOLKIT_TARGET_DIR} ${CUDA_TOOLKIT_TARGET_DIR}/extras/CUPTI/include)
include_directories(${CUDA_INCLUDE})
add_definitions(-DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=3.0,3.5,5.2)
diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md
index 1383e0d983..8e7f43b511 100644
--- a/tensorflow/contrib/cmake/README.md
+++ b/tensorflow/contrib/cmake/README.md
@@ -18,7 +18,7 @@ for instructions on how to install a pre-built TensorFlow package on Windows.
### Current known limitations
* It is not possible to load a custom Op library.
-* GCS and HDFS file systems are not supported.
+* GCS file system is not supported.
* The following Ops are not currently implemented:
- Dequantize
- QuantizeAndDequantize
diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake
index 320dfbf68c..970d87748e 100644
--- a/tensorflow/contrib/cmake/tf_core_cpu.cmake
+++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake
@@ -6,6 +6,8 @@ file(GLOB_RECURSE tf_core_cpu_srcs
"${tensorflow_source_dir}/tensorflow/cc/saved_model/*.cc"
"${tensorflow_source_dir}/tensorflow/core/common_runtime/*.h"
"${tensorflow_source_dir}/tensorflow/core/common_runtime/*.cc"
+ "${tensorflow_source_dir}/tensorflow/core/distributed_runtime/server_lib.h"
+ "${tensorflow_source_dir}/tensorflow/core/distributed_runtime/server_lib.cc"
"${tensorflow_source_dir}/tensorflow/core/graph/*.h"
"${tensorflow_source_dir}/tensorflow/core/graph/*.cc"
"${tensorflow_source_dir}/tensorflow/core/public/*.h"
diff --git a/tensorflow/contrib/cmake/tf_core_distributed_runtime.cmake b/tensorflow/contrib/cmake/tf_core_distributed_runtime.cmake
index b3c06d2c6a..ffa5710534 100644
--- a/tensorflow/contrib/cmake/tf_core_distributed_runtime.cmake
+++ b/tensorflow/contrib/cmake/tf_core_distributed_runtime.cmake
@@ -7,6 +7,7 @@ file(GLOB_RECURSE tf_core_distributed_runtime_srcs
)
file(GLOB_RECURSE tf_core_distributed_runtime_exclude_srcs
+ "${tensorflow_source_dir}/tensorflow/core/distributed_runtime/server_lib.cc" # Build in tf_core_cpu instead.
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/*test*.h"
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/*test*.cc"
"${tensorflow_source_dir}/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc"
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index d5e02056c4..af0fadd0ce 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -158,6 +158,14 @@ if(tensorflow_ENABLE_SSL_SUPPORT)
list(APPEND tf_core_lib_srcs ${tf_core_platform_cloud_srcs})
endif()
+if (tensorflow_ENABLE_HDFS_SUPPORT)
+ list(APPEND tf_core_platform_hdfs_srcs
+ "${tensorflow_source_dir}/tensorflow/core/platform/hadoop/hadoop_file_system.cc"
+ "${tensorflow_source_dir}/tensorflow/core/platform/hadoop/hadoop_file_system.h"
+ )
+ list(APPEND tf_core_lib_srcs ${tf_core_platform_hdfs_srcs})
+endif()
+
file(GLOB_RECURSE tf_core_lib_test_srcs
"${tensorflow_source_dir}/tensorflow/core/lib/*test*.h"
"${tensorflow_source_dir}/tensorflow/core/lib/*test*.cc"
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 1705452909..68620dd7f9 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -169,6 +169,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
# tensor_forest tests (also note that we exclude the hybrid tests for now)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order.
+ "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py" # Bad placement.
+ "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/topn_test.py" # Results inaccurate
)
endif()
list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})
diff --git a/tensorflow/contrib/cmake/tf_tools.cmake b/tensorflow/contrib/cmake/tf_tools.cmake
index c291e10170..2aaa9ed53e 100644
--- a/tensorflow/contrib/cmake/tf_tools.cmake
+++ b/tensorflow/contrib/cmake/tf_tools.cmake
@@ -18,10 +18,10 @@ target_link_libraries(${proto_text} PUBLIC
tf_protos_cc
)
-add_dependencies(${proto_text}
- tf_core_lib
- grpc
-)
+add_dependencies(${proto_text} tf_core_lib)
+if(tensorflow_ENABLE_GRPC_SUPPORT)
+ add_dependencies(${proto_text} grpc)
+endif(tensorflow_ENABLE_GRPC_SUPPORT)
file(GLOB_RECURSE tf_tools_transform_graph_lib_srcs
"${tensorflow_source_dir}/tensorflow/tools/graph_transforms/*.h"
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index eb908e7672..e73b20d187 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -1195,19 +1195,34 @@ def flatten(inputs,
Returns:
A flattened tensor with shape [batch_size, k].
Raises:
- ValueError: If inputs.dense_shape is wrong.
+ ValueError: If inputs rank is unknown or less than 2.
"""
with ops.name_scope(scope, 'Flatten', [inputs]) as sc:
inputs = ops.convert_to_tensor(inputs)
- inputs_shape = inputs.get_shape()
- inputs_rank = inputs_shape.ndims
+ inputs_rank = inputs.get_shape().ndims
if (inputs_rank is None) or (inputs_rank < 2):
raise ValueError('Inputs must have a least 2 dimensions.')
- dims = inputs_shape[1:]
- if not dims.is_fully_defined():
- raise ValueError('Inputs 2nd dimension must be defined.')
- k = dims.num_elements()
- outputs = array_ops.reshape(inputs, [-1, k])
+
+ inputs_shape = array_ops.shape(inputs)
+
+ batch_dim = array_ops.slice(inputs_shape, [0], [1])
+ spatial_dims = array_ops.slice(inputs_shape, [1], [inputs_rank - 1])
+
+ flat_spatial_dim = math_ops.reduce_prod(spatial_dims)
+ flat_spatial_dim = array_ops.expand_dims(flat_spatial_dim, 0)
+ flat_shape = array_ops.concat([batch_dim, flat_spatial_dim], 0)
+
+ outputs = array_ops.reshape(inputs, flat_shape)
+
+ # Attempt to propagate shape information, if it is defined.
+ input_shape = inputs.get_shape().as_list()
+ batch_dim, spatial_dims = input_shape[0], input_shape[1:]
+ if all(spatial_dims):
+ outputs.set_shape([batch_dim,
+ functools.reduce(lambda x, y: x * y, spatial_dims)])
+ else:
+ outputs.set_shape([batch_dim, None])
+
return utils.collect_named_outputs(outputs_collections, sc, outputs)
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 6043d4dc0e..5561ccd5f5 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1342,8 +1342,8 @@ class FlattenTest(test.TestCase):
with ops.Graph().as_default() as g, self.test_session(g):
inputs = array_ops.placeholder(dtype=dtypes.float32)
inputs.set_shape(tensor_shape.TensorShape((5, None)))
- with self.assertRaisesRegexp(ValueError, '2nd dimension must be defined'):
- _layers.flatten(inputs)
+ output = _layers.flatten(inputs)
+ self.assertEqual(output.get_shape().as_list(), [5, None])
def testCollectOutputs(self):
height, width = 3, 3
@@ -1386,6 +1386,17 @@ class FlattenTest(test.TestCase):
self.assertEqual(output.size, images.get_shape().num_elements())
self.assertEqual(output.shape[0], images.get_shape()[0])
+ def testUnknownDims(self):
+ height = width = depth = 3
+ with self.test_session() as sess:
+ images = random_ops.random_uniform(
+ (5, height, width, depth), seed=1, name='images')
+ inputs = array_ops.placeholder(dtypes.int32, (None, None, None, None))
+ output = _layers.flatten(inputs)
+ output = sess.run(output, {inputs: images.eval()})
+ self.assertEqual(output.size, images.get_shape().num_elements())
+ self.assertEqual(output.shape[0], images.get_shape()[0])
+
def _sparsify(array, threshold=0.5):
array[array < threshold] = 0
diff --git a/tensorflow/contrib/learn/python/learn/datasets/mnist.py b/tensorflow/contrib/learn/python/learn/datasets/mnist.py
index f11e40e045..59bdea7293 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/mnist.py
+++ b/tensorflow/contrib/learn/python/learn/datasets/mnist.py
@@ -157,7 +157,7 @@ class DataSet(object):
def epochs_completed(self):
return self._epochs_completed
- def next_batch(self, batch_size, fake_data=False):
+ def next_batch(self, batch_size, fake_data=False, shuffle=True):
"""Return the next `batch_size` examples from this data set."""
if fake_data:
fake_image = [1] * 784
@@ -169,21 +169,37 @@ class DataSet(object):
fake_label for _ in xrange(batch_size)
]
start = self._index_in_epoch
- self._index_in_epoch += batch_size
- if self._index_in_epoch > self._num_examples:
+ # Shuffle for the first epoch
+ if self._epochs_completed == 0 and start == 0 and shuffle:
+ perm0 = numpy.arange(self._num_examples)
+ numpy.random.shuffle(perm0)
+ self._images = self.images[perm0]
+ self._labels = self.labels[perm0]
+ # Go to the next epoch
+ if start + batch_size > self._num_examples:
# Finished epoch
self._epochs_completed += 1
+ # Get the rest examples in this epoch
+ rest_num_examples = self._num_examples - start
+ images_rest_part = self._images[start:self._num_examples]
+ labels_rest_part = self._labels[start:self._num_examples]
# Shuffle the data
- perm = numpy.arange(self._num_examples)
- numpy.random.shuffle(perm)
- self._images = self._images[perm]
- self._labels = self._labels[perm]
+ if shuffle:
+ perm = numpy.arange(self._num_examples)
+ numpy.random.shuffle(perm)
+ self._images = self.images[perm]
+ self._labels = self.labels[perm]
# Start next epoch
start = 0
- self._index_in_epoch = batch_size
- assert batch_size <= self._num_examples
- end = self._index_in_epoch
- return self._images[start:end], self._labels[start:end]
+ self._index_in_epoch = batch_size - rest_num_examples
+ end = self._index_in_epoch
+ images_new_part = self.images[start:end]
+ labels_new_part = self.labels[start:end]
+ return numpy.concatenate((images_rest_part, images_new_part), axis=0) , numpy.concatenate((labels_rest_part, labels_new_part), axis=0)
+ else:
+ self._index_in_epoch += batch_size
+ end = self._index_in_epoch
+ return self._images[start:end], self._labels[start:end]
def read_data_sets(train_dir,
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 3ac413ec08..cf20c5da99 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -2029,7 +2029,7 @@ def streaming_concat(values,
For estimation of the metric over a stream of data, the function creates an
`update_op` operation that appends the values of a tensor and returns the
- `value` of the concatenated tensors.
+ length of the concatenated axis.
This op allows for evaluating metrics that cannot be updated incrementally
using the same framework as other streaming metrics.
diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py
index 3355f29894..18c97d75e5 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation_test.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py
@@ -39,6 +39,7 @@ from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
+from tensorflow.python.training import input
from tensorflow.python.training import saver as saver_lib
FLAGS = flags.FLAGS
@@ -87,13 +88,13 @@ class EvaluationTest(test.TestCase):
self._labels)
init_op = control_flow_ops.group(variables.global_variables_initializer(),
variables.local_variables_initializer())
- # Create Checkpoint and log directories
+ # Create checkpoint and log directories:
chkpt_dir = os.path.join(self.get_temp_dir(), 'tmp_logs/')
gfile.MakeDirs(chkpt_dir)
logdir = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
gfile.MakeDirs(logdir)
- # Save initialized variables to checkpoint directory
+ # Save initialized variables to a checkpoint directory:
saver = saver_lib.Saver()
with self.test_session() as sess:
init_op.run()
@@ -157,6 +158,33 @@ class EvaluationTest(test.TestCase):
'/non-existent-dir', timeout=0))
self.assertEqual(ret, [])
+ def testWithEpochLimit(self):
+ predictions_limited = input.limit_epochs(self._predictions, num_epochs=1)
+ labels_limited = input.limit_epochs(self._labels, num_epochs=1)
+
+ value_op, update_op = metric_ops.streaming_accuracy(
+ predictions_limited, labels_limited)
+
+ init_op = control_flow_ops.group(variables.global_variables_initializer(),
+ variables.local_variables_initializer())
+ # Create checkpoint and log directories:
+ chkpt_dir = os.path.join(self.get_temp_dir(), 'tmp_logs/')
+ gfile.MakeDirs(chkpt_dir)
+ logdir = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
+ gfile.MakeDirs(logdir)
+
+ # Save initialized variables to a checkpoint directory:
+ saver = saver_lib.Saver()
+ with self.test_session() as sess:
+ init_op.run()
+ saver.save(sess, os.path.join(chkpt_dir, 'chkpt'))
+
+ # Now, run the evaluation loop:
+ accuracy_value = evaluation.evaluation_loop(
+ '', chkpt_dir, logdir, eval_op=update_op, final_op=value_op,
+ max_number_of_evaluations=1, num_evals=10000)
+ self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
+
class SingleEvaluationTest(test.TestCase):
diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py
index 2a3ded6d80..b6ba71a22f 100644
--- a/tensorflow/contrib/slim/python/slim/learning.py
+++ b/tensorflow/contrib/slim/python/slim/learning.py
@@ -788,12 +788,17 @@ def train(train_op,
sv.start_queue_runners(sess, [chief_queue_runner])
sess.run(init_tokens_op)
try:
- while not sv.should_stop():
- total_loss, should_stop = train_step_fn(sess, train_op, global_step,
- train_step_kwargs)
- if should_stop:
- logging.info('Stopping Training.')
- break
+ try:
+ while not sv.should_stop():
+ total_loss, should_stop = train_step_fn(sess, train_op, global_step,
+ train_step_kwargs)
+ if should_stop:
+ logging.info('Stopping Training.')
+ break
+ except errors.OutOfRangeError:
+ # OutOfRangeError is thrown when epoch limit per
+ # tf.train.limit_epochs is reached.
+ logging.info('Caught OutOfRangeError. Stopping Training.')
if logdir and sv.is_chief:
logging.info('Finished training! Saving model to disk.')
sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
diff --git a/tensorflow/contrib/slim/python/slim/learning_test.py b/tensorflow/contrib/slim/python/slim/learning_test.py
index 305cb9a3c4..eb57102c63 100644
--- a/tensorflow/contrib/slim/python/slim/learning_test.py
+++ b/tensorflow/contrib/slim/python/slim/learning_test.py
@@ -40,9 +40,9 @@ from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import input as input_lib
from tensorflow.python.training import saver as saver_lib
-
class ClipGradientNormsTest(test.TestCase):
def clip_values(self, arr):
@@ -888,6 +888,30 @@ class TrainTest(test.TestCase):
# be smaller.
self.assertGreater(losses[0], losses[1])
+ def testTrainWithEpochLimit(self):
+ logdir = os.path.join(tempfile.mkdtemp(prefix=self.get_temp_dir()),
+ 'tmp_logs')
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(0)
+ tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
+ tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)
+ tf_inputs_limited = input_lib.limit_epochs(tf_inputs, num_epochs=300)
+ tf_labels_limited = input_lib.limit_epochs(tf_labels, num_epochs=300)
+
+ tf_predictions = LogisticClassifier(tf_inputs_limited)
+ loss_ops.log_loss(tf_predictions, tf_labels_limited)
+ total_loss = loss_ops.get_total_loss()
+
+ optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
+
+ train_op = learning.create_train_op(total_loss, optimizer)
+
+ loss = learning.train(train_op, logdir, log_every_n_steps=10)
+ self.assertIsNotNone(loss)
+ self.assertLess(loss, .015)
+ self.assertTrue(os.path.isfile('{}/model.ckpt-300.index'.format(logdir)))
+ self.assertTrue(os.path.isfile('{}/model.ckpt-300.data-00000-of-00001'.format(logdir)))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/core/kernels/decode_raw_op.cc b/tensorflow/core/kernels/decode_raw_op.cc
index 280c2dc71c..4247abcd71 100644
--- a/tensorflow/core/kernels/decode_raw_op.cc
+++ b/tensorflow/core/kernels/decode_raw_op.cc
@@ -69,12 +69,6 @@ class DecodeRawOp : public OpKernel {
context, context->allocate_output("output", out_shape, &output_tensor));
auto out = output_tensor->flat_inner_dims<T>();
DCHECK_EQ(flat_in.size(), out.dimensions()[0]);
- OP_REQUIRES(
- context,
- little_endian_ == ::tensorflow::port::kLittleEndian || sizeof(T) == 1,
- errors::Unimplemented("Unimplemented support for little_endian=",
- little_endian_ ? "true" : "false"));
- // Endianness matches, so just copy each string byte-for-byte.
T* out_data = out.data();
for (int64 i = 0; i < flat_in.size(); ++i) {
const T* in_data = reinterpret_cast<const T*>(flat_in(i).data());
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index d61e7b32de..f54c7b0cfd 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -3322,7 +3322,7 @@ x = [[[[1], [2], [3], [4]],
The output tensor has shape `[4, 2, 2, 1]` and value:
```prettyprint
-x = [[[[1], [3]], [[5], [7]]],
+x = [[[[1], [3]], [[9], [11]]],
[[[2], [4]], [[10], [12]]],
[[[5], [7]], [[13], [15]]],
[[[6], [8]], [[14], [16]]]]
@@ -3449,7 +3449,7 @@ x = [[[[1], [2], [3], [4]],
The output tensor has shape `[4, 2, 2, 1]` and value:
```prettyprint
-x = [[[[1], [3]], [[5], [7]]],
+x = [[[[1], [3]], [[9], [11]]],
[[[2], [4]], [[10], [12]]],
[[[5], [7]], [[13], [15]]],
[[[6], [8]], [[14], [16]]]]
@@ -3580,7 +3580,7 @@ x = [[[[1, 2, 3], [4, 5, 6]],
`crops = [[0, 0], [0, 0]]`:
```prettyprint
-x = [[[[1], [3]], [[5], [7]]],
+x = [[[[1], [3]], [[9], [11]]],
[[[2], [4]], [[10], [12]]],
[[[5], [7]], [[13], [15]]],
[[[6], [8]], [[14], [16]]]]
@@ -3698,7 +3698,7 @@ x = [[[[1, 2, 3], [4, 5, 6]],
(3) For the following input of shape `[4, 2, 2, 1]` and block_size of 2:
```prettyprint
-x = [[[[1], [3]], [[5], [7]]],
+x = [[[[1], [3]], [[9], [11]]],
[[[2], [4]], [[10], [12]]],
[[[5], [7]], [[13], [15]]],
[[[6], [8]], [[14], [16]]]]
diff --git a/tensorflow/core/ops/ctc_ops.cc b/tensorflow/core/ops/ctc_ops.cc
index 0b58a8d817..c94ce577c0 100644
--- a/tensorflow/core/ops/ctc_ops.cc
+++ b/tensorflow/core/ops/ctc_ops.cc
@@ -113,7 +113,7 @@ Performs greedy decoding on the logits given in inputs.
A note about the attribute merge_repeated: if enabled, when
consecutive logits' maximum indices are the same, only the first of
these is emitted. Labeling the blank '*', the sequence "A B B * B B"
-becomes "A B" if merge_repeated = True and "A B B B B" if
+becomes "A B B" if merge_repeated = True and "A B B B B" if
merge_repeated = False.
Regardless of the value of merge_repeated, if the maximum index of a given
diff --git a/tensorflow/core/platform/denormal.cc b/tensorflow/core/platform/denormal.cc
index 04079de05b..08e060bbbd 100644
--- a/tensorflow/core/platform/denormal.cc
+++ b/tensorflow/core/platform/denormal.cc
@@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/denormal.h"
-#ifdef __SSE3__
+#include "third_party/eigen3/Eigen/Core"
+// Check EIGEN_VECTORIZE_SSE3 since Windows doesn't define __SSE3__ properly
+#ifdef EIGEN_VECTORIZE_SSE3
#include <pmmintrin.h>
#endif
@@ -24,7 +26,7 @@ namespace port {
ScopedFlushDenormal::ScopedFlushDenormal() {
// For now, we flush denormals only on SSE 3. Other architectures such as ARM
// can be added as needed.
-#ifdef __SSE3__
+#ifdef EIGEN_VECTORIZE_SSE3
// Save existing flags
flush_zero_mode_ = _MM_GET_FLUSH_ZERO_MODE() == _MM_FLUSH_ZERO_ON;
denormals_zero_mode_ = _MM_GET_DENORMALS_ZERO_MODE() == _MM_DENORMALS_ZERO_ON;
@@ -38,7 +40,7 @@ ScopedFlushDenormal::ScopedFlushDenormal() {
}
ScopedFlushDenormal::~ScopedFlushDenormal() {
-#ifdef __SSE3__
+#ifdef EIGEN_VECTORIZE_SSE3
// Restore flags
_MM_SET_FLUSH_ZERO_MODE(flush_zero_mode_ ? _MM_FLUSH_ZERO_ON
: _MM_FLUSH_ZERO_OFF);
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index b0f0cbe3f1..16e401a54e 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/platform/posix/error.h"
#include "third_party/hadoop/hdfs.h"
+
namespace tensorflow {
template <typename R, typename... Args>
@@ -104,18 +105,23 @@ class LibHDFS {
// libhdfs.so won't be in the standard locations. Use the path as specified
// in the libhdfs documentation.
+#if defined(PLATFORM_WINDOWS)
+ const char *kLibHdfsDso = "hdfs.dll";
+#else
+ const char *kLibHdfsDso = "libhdfs.so";
+#endif
char* hdfs_home = getenv("HADOOP_HDFS_HOME");
if (hdfs_home == nullptr) {
status_ = errors::FailedPrecondition(
"Environment variable HADOOP_HDFS_HOME not set");
return;
}
- string path = io::JoinPath(hdfs_home, "lib", "native", "libhdfs.so");
+ string path = io::JoinPath(hdfs_home, "lib", "native", kLibHdfsDso);
status_ = TryLoadAndBind(path.c_str(), &handle_);
if (!status_.ok()) {
// try load libhdfs.so using dynamic loader's search path in case libhdfs.so
// is installed in non-standard location
- status_ = TryLoadAndBind("libhdfs.so", &handle_);
+ status_ = TryLoadAndBind(kLibHdfsDso, &handle_);
}
return;
}
diff --git a/tensorflow/examples/tutorials/input_fn/boston_predict.csv b/tensorflow/examples/tutorials/input_fn/boston_predict.csv
index 27b017155c..cc757a4a7d 100644
--- a/tensorflow/examples/tutorials/input_fn/boston_predict.csv
+++ b/tensorflow/examples/tutorials/input_fn/boston_predict.csv
@@ -1,4 +1,4 @@
-6,9,CRIM,ZN,INDUS,NOX,RM,AGE,DIS,TAX,PTRATIO
+CRIM,ZN,INDUS,NOX,RM,AGE,DIS,TAX,PTRATIO
0.03359,75.0,2.95,0.428,7.024,15.8,5.4011,252,18.3
5.09017,0.0,18.1,0.713,6.297,91.8,2.3682,666,20.2
0.1265,25.0,5.13,0.453,6.762,43.4,7.9809,284,19.7
diff --git a/tensorflow/examples/tutorials/input_fn/boston_test.csv b/tensorflow/examples/tutorials/input_fn/boston_test.csv
index 00cd0c6bb3..769aee040c 100644
--- a/tensorflow/examples/tutorials/input_fn/boston_test.csv
+++ b/tensorflow/examples/tutorials/input_fn/boston_test.csv
@@ -1,4 +1,4 @@
-100,9,CRIM,ZN,INDUS,NOX,RM,AGE,DIS,TAX,PTRATIO,MEDV
+CRIM,ZN,INDUS,NOX,RM,AGE,DIS,TAX,PTRATIO,MEDV
0.13587,0.0,10.59,0.489,6.064,59.1,4.2392,277,18.6,24.4
0.08664,45.0,3.44,0.437,7.178,26.3,6.4798,398,15.2,36.4
0.26938,0.0,9.9,0.544,6.266,82.8,3.2628,304,18.4,21.6
diff --git a/tensorflow/examples/tutorials/input_fn/boston_train.csv b/tensorflow/examples/tutorials/input_fn/boston_train.csv
index 5d30ebd3a6..e675a26817 100644
--- a/tensorflow/examples/tutorials/input_fn/boston_train.csv
+++ b/tensorflow/examples/tutorials/input_fn/boston_train.csv
@@ -1,4 +1,4 @@
-400,9,CRIM,ZN,INDUS,NOX,RM,AGE,DIS,TAX,PTRATIO,MEDV
+CRIM,ZN,INDUS,NOX,RM,AGE,DIS,TAX,PTRATIO,MEDV
2.3004,0.0,19.58,0.605,6.319,96.1,2.1,403,14.7,23.8
13.3598,0.0,18.1,0.693,5.887,94.7,1.7821,666,20.2,12.7
0.12744,0.0,6.91,0.448,6.77,2.9,5.7209,233,17.9,26.6
diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
index ff78f151c3..75ea0b9c67 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
@@ -135,7 +135,7 @@ def train():
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', accuracy)
- # Merge all the summaries and write them out to /tmp/mnist_logs (by default)
+ # Merge all the summaries and write them out to /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default)
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
diff --git a/tensorflow/examples/udacity/5_word2vec.ipynb b/tensorflow/examples/udacity/5_word2vec.ipynb
index ec6413a0a3..9d4243d7ae 100644
--- a/tensorflow/examples/udacity/5_word2vec.ipynb
+++ b/tensorflow/examples/udacity/5_word2vec.ipynb
@@ -442,8 +442,8 @@
" embed = tf.nn.embedding_lookup(embeddings, train_dataset)\n",
" # Compute the softmax loss, using a sample of the negative labels each time.\n",
" loss = tf.reduce_mean(\n",
- " tf.nn.sampled_softmax_loss(softmax_weights, softmax_biases, embed,\n",
- " train_labels, num_sampled, vocabulary_size))\n",
+ " tf.nn.sampled_softmax_loss(weights=softmax_weights, biases=softmax_biases, inputs=embed,\n",
+ " labels=train_labels, num_sampled=num_sampled, num_classes=vocabulary_size))\n",
"\n",
" # Optimizer.\n",
" # Note: The optimizer will optimize the softmax_weights AND the embeddings.\n",
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index c4c9bdfcaf..16f2686114 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -923,6 +923,8 @@ Note also that `bazel test` will not always properly resolve dependencies
through these symlinks, so test results may be unreliable. A workaround is to
remove the `_python_build` directory before running `bazel test`.
+One more thing is that `python setup.py development` does not install the packages listed in `REQUIRED_PACKAGES` part of `setup.py`. You might need to install them separately.
+
## Train your first TensorFlow neural net model
Start by cloning the [TensorFlow models repo](https://github.com/tensorflow/models) from GitHub. Run the following commands:
diff --git a/tensorflow/g3doc/how_tos/distributed/index.md b/tensorflow/g3doc/how_tos/distributed/index.md
index bce6af6f80..880976ca8d 100644
--- a/tensorflow/g3doc/how_tos/distributed/index.md
+++ b/tensorflow/g3doc/how_tos/distributed/index.md
@@ -144,7 +144,7 @@ applying gradients).
A common training configuration, called "data parallelism," involves multiple
tasks in a `worker` job training the same model on different mini-batches of
-data, updating shared parameters hosted in a one or more tasks in a `ps`
+data, updating shared parameters hosted in one or more tasks in a `ps`
job. All tasks typically run on different machines. There are many ways to
specify this structure in TensorFlow, and we are building libraries that will
simplify the work of specifying a replicated model. Possible approaches include:
diff --git a/tensorflow/g3doc/how_tos/language_bindings/index.md b/tensorflow/g3doc/how_tos/language_bindings/index.md
index b32c91354c..89d7d25162 100644
--- a/tensorflow/g3doc/how_tos/language_bindings/index.md
+++ b/tensorflow/g3doc/how_tos/language_bindings/index.md
@@ -172,7 +172,7 @@ are added to the graph and used as input to the op being instantiated.
If the language allows for optional parameters to a function (like keyword
arguments with defaults in Python), use them for optional attributes, operation
-names, devices, control inputs etc. In some langauges, these optional parameters
+names, devices, control inputs etc. In some languages, these optional parameters
can be set using dynamic scopes (like "with" blocks in Python). Without these
features, the library may resort to the "builder pattern", as is done in the C++
version of the TensorFlow API.
diff --git a/tensorflow/g3doc/tutorials/estimators/index.md b/tensorflow/g3doc/tutorials/estimators/index.md
index ed605f0915..01a2ed803c 100644
--- a/tensorflow/g3doc/tutorials/estimators/index.md
+++ b/tensorflow/g3doc/tutorials/estimators/index.md
@@ -380,7 +380,7 @@ with the `input_from_feature_columns()` function in
[tf.contrib.layers](../../api_docs/python/contrib.layers.md#layers-contrib).
```python
-input layer = tf.contrib.layers.input_from_feature_columns(
+input_layer = tf.contrib.layers.input_from_feature_columns(
columns_to_tensors=features, feature_columns=[age, height, weight])
```
diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/index.md b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
index 9b6caf8358..1d94d6f5b1 100644
--- a/tensorflow/g3doc/tutorials/mnist/beginners/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
@@ -54,8 +54,8 @@ What we will accomplish in this tutorial:
- Create a function that is a model for recognizing digits, based on looking at
every pixel in the image
-- Use Tensorflow to train the model to recognize digits by having it "look" at
- thousands of examples (and run our first Tensorflow session to do so)
+- Use TensorFlow to train the model to recognize digits by having it "look" at
+ thousands of examples (and run our first TensorFlow session to do so)
- Check the model's accuracy with our test data
@@ -223,7 +223,7 @@ More compactly, we can just write:
$$y = \text{softmax}(Wx + b)$$
-Now let's turn that into something that Tensorflow can use.
+Now let's turn that into something that TensorFlow can use.
## Implementing the Regression
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 134c173b62..80fb500f30 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -96,6 +96,7 @@ java_test(
test_class = "org.tensorflow.TensorTest",
deps = [
":tensorflow",
+ ":testutil",
"//external:junit",
],
)
@@ -118,6 +119,7 @@ cc_binary(
# symbols from the library. This reduces the size of the library
# considerably (~50% as of January 2017).
linkopts = select({
+ "//tensorflow:debug": [], # Disable all custom linker options in debug mode
"//tensorflow:darwin": [
"-Wl,-exported_symbols_list", # This line must be directly followed by LINKER_EXPORTED_SYMBOLS
LINKER_EXPORTED_SYMBOLS,
diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md
index 7edd109cfd..31c77903ad 100644
--- a/tensorflow/java/README.md
+++ b/tensorflow/java/README.md
@@ -32,9 +32,14 @@ Java bindings for TensorFlow.
## Installation
-Build the Java Archive (JAR) and native library:
+Configure and build the Java Archive (JAR) and native library:
```sh
+# Configure the build (e.g. GPU support etc.), as per
+# https://www.tensorflow.org/get_started/os_setup#configure_the_installation
+./configure
+
+# Build the JAR and native library
bazel build -c opt \
//tensorflow/java:tensorflow \
//tensorflow/java:libtensorflow_jni
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index 5478bb85e9..ebb930a869 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -16,6 +16,14 @@ limitations under the License.
package org.tensorflow;
import java.lang.reflect.Array;
+import java.nio.BufferOverflowException;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.nio.LongBuffer;
import java.util.Arrays;
/**
@@ -34,6 +42,145 @@ import java.util.Arrays;
* }</pre>
*/
public final class Tensor implements AutoCloseable {
+
+ /**
+ * Gets the size (in bytes) of the tensor data.
+ */
+ public int numBytes() {
+ return buffer().remaining();
+ }
+
+ /**
+ * Gets the total number of elements in the tensor.
+ *
+ * @return the number of elements in a flattened (1-D) view of the tensor.
+ */
+ public int numElements() {
+ switch(dataType()) {
+ case FLOAT:
+ case DOUBLE:
+ case INT32:
+ case INT64:
+ case BOOL:
+ return numBytes() / elemByteSize(dataType());
+ case STRING:
+ default:
+ throw new UnsupportedOperationException("unsupported operation on a tensor of type "
+ + dataType());
+ }
+ }
+
+ /**
+ * Write the tensor data into the given buffer.
+ *
+ * <p>This method copies {@code numElements()} elements to the buffer.
+ *
+ * <p>This method may be used to read tensor data of type {@link DataType#INT32}.
+ *
+ * @param dst the destination buffer
+ *
+ * @throws BufferOverflowException
+ * If there is insufficient space in the given buffer
+ * for the data in this tensor
+ * @throws IllegalArgumentException
+ * If the tensor datatype is not {@link DataType#INT32}
+ */
+ public void writeTo(IntBuffer dst) {
+ if(dtype != DataType.INT32) {
+ throw incompatibleBuffer(dst, dtype);
+ }
+ ByteBuffer src = buffer();
+ dst.put(src.asIntBuffer());
+ }
+
+ /**
+ * Write the tensor data into the given buffer.
+ *
+ * <p>This method copies {@code numElements()} elements to the buffer.
+ *
+ * <p>This method may be used to read tensor data of type {@link DataType#FLOAT}.
+ *
+ * @param dst the destination buffer
+ *
+ * @throws BufferOverflowException
+ * If there is insufficient space in the given buffer
+ * for the data in this tensor
+ * @throws IllegalArgumentException
+ * If the tensor datatype is not {@link DataType#FLOAT}
+ */
+ public void writeTo(FloatBuffer dst) {
+ if(dtype != DataType.FLOAT) {
+ throw incompatibleBuffer(dst, dtype);
+ }
+ ByteBuffer src = buffer();
+ dst.put(src.asFloatBuffer());
+ }
+
+ /**
+ * Write the tensor data into the given buffer.
+ *
+ * <p>This method copies {@code numElements()} elements to the buffer.
+ *
+ * <p>This method may be used to read tensor data of type {@link DataType#DOUBLE}.
+ *
+ * @param dst the destination buffer
+ *
+ * @throws BufferOverflowException
+ * If there is insufficient space in the given buffer
+ * for the data in this tensor
+ * @throws IllegalArgumentException
+ * If the tensor datatype is not {@link DataType#DOUBLE}
+ */
+ public void writeTo(DoubleBuffer dst) {
+ if(dtype != DataType.DOUBLE) {
+ throw incompatibleBuffer(dst, dtype);
+ }
+ ByteBuffer src = buffer();
+ dst.put(src.asDoubleBuffer());
+ }
+
+ /**
+ * Write the tensor data into the given buffer.
+ *
+ * <p>This method copies {@code numElements()} elements to the buffer.
+ *
+ * <p>This method may be used to read tensor data of type {@link DataType#INT64}.
+ *
+ * @param dst the destination buffer
+ *
+ * @throws BufferOverflowException
+ * If there is insufficient space in the given buffer
+ * for the data in this tensor
+ * @throws IllegalArgumentException
+ * If the tensor datatype is not {@link DataType#INT64}
+ */
+ public void writeTo(LongBuffer dst) {
+ if(dtype != DataType.INT64) {
+ throw incompatibleBuffer(dst, dtype);
+ }
+ ByteBuffer src = buffer();
+ dst.put(src.asLongBuffer());
+ }
+
+ /**
+ * Write the tensor data into the given buffer.
+ *
+ * <p>This method copies {@code byteSize()} bytes to the buffer.
+ *
+ * <p>This method may be used to read tensor data of any type. Note that
+ * primitive data is in native byte order.
+ *
+ * @param dst the destination buffer
+ *
+ * @throws BufferOverflowException
+ * If there is insufficient space in the given buffer
+ * for the data in this tensor
+ */
+ public void writeTo(ByteBuffer dst) {
+ ByteBuffer src = buffer();
+ dst.put(src);
+ }
+
/**
* Create a Tensor from a Java object.
*
@@ -71,7 +218,8 @@ public final class Tensor implements AutoCloseable {
t.shapeCopy = new long[numDimensions(obj)];
fillShape(obj, 0, t.shapeCopy);
if (t.dtype != DataType.STRING) {
- t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy);
+ int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
+ t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
setValue(t.nativeHandle, obj);
} else if (t.shapeCopy.length != 0) {
throw new UnsupportedOperationException(
@@ -85,6 +233,188 @@ public final class Tensor implements AutoCloseable {
}
/**
+ * Create an {@link DataType#INT32} Tensor with data from the given buffer.
+ *
+ * <p>Creates a Tensor with the given shape by copying elements
+ * from the buffer (starting from its current position) into the tensor.
+ * For example, if {@code shape = {2,3} } (which represents a 2x3 matrix)
+ * then the buffer must have 6 elements remaining, which will be
+ * consumed by this method.
+ *
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ *
+ * @throws IllegalArgumentException
+ * If the tensor shape is not compatible with the buffer
+ */
+ public static Tensor create(long[] shape, IntBuffer data) {
+ if(data.remaining() != numElements(shape)) {
+ throw incompatibleBuffer(data.remaining(), shape);
+ }
+ int elemSize = elemByteSize(DataType.INT32);
+ Tensor t = createHelper(DataType.INT32, shape, elemSize * data.remaining());
+ try {
+ ByteBuffer dst = t.buffer();
+ dst.asIntBuffer().put(data);
+ return t;
+ } catch(RuntimeException e) {
+ delete(t.nativeHandle);
+ throw e;
+ }
+ }
+
+ /**
+ * Create a {@link DataType#FLOAT} Tensor with data from the given buffer.
+ *
+ * <p>Creates a Tensor with the given shape by copying elements
+ * from the buffer (starting from its current position) into the tensor.
+ * For example, if {@code shape = {2,3} } (which represents a 2x3 matrix)
+ * then the buffer must have 6 elements remaining, which will be
+ * consumed by this method.
+ *
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ *
+ * @throws IllegalArgumentException
+ * If the tensor shape is not compatible with the buffer
+ */
+ public static Tensor create(long[] shape, FloatBuffer data) {
+ if(data.remaining() != numElements(shape)) {
+ throw incompatibleBuffer(data.remaining(), shape);
+ }
+ int elemSize = elemByteSize(DataType.FLOAT);
+ Tensor t = createHelper(DataType.FLOAT, shape, elemSize * data.remaining());
+ try {
+ ByteBuffer dst = t.buffer();
+ dst.asFloatBuffer().put(data);
+ return t;
+ } catch(RuntimeException e) {
+ delete(t.nativeHandle);
+ throw e;
+ }
+ }
+
+ /**
+ * Create a {@link DataType#DOUBLE} Tensor with data from the given buffer.
+ *
+ * <p>Creates a Tensor with the given shape by copying elements
+ * from the buffer (starting from its current position) into the tensor.
+ * For example, if {@code shape = {2,3} } (which represents a 2x3 matrix)
+ * then the buffer must have 6 elements remaining, which will be
+ * consumed by this method.
+ *
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ *
+ * @throws IllegalArgumentException
+ * If the tensor shape is not compatible with the buffer
+ */
+ public static Tensor create(long[] shape, DoubleBuffer data) {
+ if(data.remaining() != numElements(shape)) {
+ throw incompatibleBuffer(data.remaining(), shape);
+ }
+ int elemSize = elemByteSize(DataType.DOUBLE);
+ Tensor t = createHelper(DataType.DOUBLE, shape, elemSize * data.remaining());
+ try {
+ ByteBuffer dst = t.buffer();
+ dst.asDoubleBuffer().put(data);
+ return t;
+ } catch(RuntimeException e) {
+ delete(t.nativeHandle);
+ throw e;
+ }
+ }
+
+ /**
+ * Create an {@link DataType#INT64} Tensor with data from the given buffer.
+ *
+ * <p>Creates a Tensor with the given shape by copying elements
+ * from the buffer (starting from its current position) into the tensor.
+ * For example, if {@code shape = {2,3} } (which represents a 2x3 matrix)
+ * then the buffer must have 6 elements remaining, which will be
+ * consumed by this method.
+ *
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ *
+ * @throws IllegalArgumentException
+ * If the tensor shape is not compatible with the buffer
+ */
+ public static Tensor create(long[] shape, LongBuffer data) {
+ if(data.remaining() != numElements(shape)) {
+ throw incompatibleBuffer(data.remaining(), shape);
+ }
+ int elemSize = elemByteSize(DataType.INT64);
+ Tensor t = createHelper(DataType.INT64, shape, elemSize * data.remaining());
+ try {
+ ByteBuffer dst = t.buffer();
+ dst.asLongBuffer().put(data);
+ return t;
+ } catch(RuntimeException e) {
+ delete(t.nativeHandle);
+ throw e;
+ }
+ }
+
+ /**
+ * Create a Tensor with data from the given buffer.
+ *
+ * <p>Supports all datatypes. Note that primitive data must be in native byte order.
+ * Tensors of type {@link DataType#STRING} must be encoded as per the C API.
+ *
+ * <p>Creates a Tensor with the given shape by copying elements
+ * from the buffer (starting from its current position) into the tensor.
+ * For example, if {@code shape = {2,3} } (which represents a 2x3 matrix)
+ * then the buffer must have (6 elements x the byte size per element) bytes remaining,
+ * which will be consumed by this method.
+ *
+ * @param dataType the tensor datatype.
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ *
+ * @throws IllegalArgumentException
+ * If the tensor datatype or shape is not compatible with the buffer
+ */
+ public static Tensor create(DataType dataType, long[] shape, ByteBuffer data) {
+ switch(dataType) {
+ case FLOAT:
+ case DOUBLE:
+ case INT32:
+ case INT64:
+ case BOOL:
+ int elemSize = elemByteSize(dataType);
+ if(elemSize * numElements(shape) != data.remaining()) {
+ throw new IllegalArgumentException(String.format(
+ "byte buffer with %d bytes is not compatible with a Tensor of type %s with shape %s",
+ data.remaining(), dataType, Arrays.toString(shape)));
+ }
+ break;
+ case STRING:
+ default:
+ // not all types are checked
+ break;
+ }
+
+ Tensor t = createHelper(dataType, shape, data.remaining());
+ try {
+ ByteBuffer dst = t.buffer();
+ dst.put(data);
+ return t;
+ } catch(RuntimeException e) {
+ delete(t.nativeHandle);
+ throw e;
+ }
+ }
+
+ private static Tensor createHelper(DataType dataType, long[] shape, long byteSize) {
+ Tensor t = new Tensor();
+ t.dtype = dataType;
+ t.shapeCopy = Arrays.copyOf(shape, shape.length);
+ t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
+ return t;
+ }
+
+ /**
* Release resources associated with the Tensor.
*
* <p><b>WARNING:</b>If not invoked, memory will be leaked.
@@ -237,6 +567,46 @@ public final class Tensor implements AutoCloseable {
private Tensor() {}
+ private ByteBuffer buffer() {
+ return buffer(nativeHandle).order(ByteOrder.nativeOrder());
+ }
+
+ private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) {
+ return new IllegalArgumentException(String.format(
+ "cannot use %s with Tensor of type %s",
+ buf.getClass().getName(), dataType));
+ }
+
+ private static IllegalArgumentException incompatibleBuffer(int numElements, long[] shape) {
+ return new IllegalArgumentException(String.format(
+ "buffer with %d elements is not compatible with a Tensor with shape %s",
+ numElements, Arrays.toString(shape)));
+ }
+
+ private static int numElements(long[] shape) {
+ // assumes a fully-known shape
+ int n = 1;
+ for(int i = 0; i < shape.length; i++) {
+ n *= shape[i];
+ }
+ return n;
+ }
+
+ private static int elemByteSize(DataType dataType) {
+ switch(dataType) {
+ case FLOAT:
+ case INT32:
+ return 4;
+ case DOUBLE:
+ case INT64:
+ return 8;
+ case BOOL:
+ return 1;
+ default:
+ throw new IllegalArgumentException("unsupported DataType " + dataType);
+ }
+ }
+
private static DataType dataTypeOf(Object o) {
if (o.getClass().isArray()) {
if (Array.getLength(o) == 0) {
@@ -317,12 +687,14 @@ public final class Tensor implements AutoCloseable {
}
}
- private static native long allocate(int dtype, long[] shape);
+ private static native long allocate(int dtype, long[] shape, long byteSize);
private static native long allocateScalarBytes(byte[] value);
private static native void delete(long handle);
+ private static native ByteBuffer buffer(long handle);
+
private static native int dtype(long handle);
private static native long[] shape(long handle);
diff --git a/tensorflow/java/src/main/native/tensor_jni.cc b/tensorflow/java/src/main/native/tensor_jni.cc
index c98d6807ac..27897d2d12 100644
--- a/tensorflow/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/java/src/main/native/tensor_jni.cc
@@ -218,23 +218,14 @@ size_t readNDArray(JNIEnv* env, TF_DataType dtype, const char* src,
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
jclass clazz,
jint dtype,
- jlongArray shape) {
- size_t elem_size = elemByteSize(static_cast<TF_DataType>(dtype));
- if (elem_size == 0) {
- throwException(env, kIllegalArgumentException,
- "cannot allocate Tensor with DataType %d", dtype);
- return 0;
- }
+ jlongArray shape,
+ jlong sizeInBytes) {
int num_dims = static_cast<int>(env->GetArrayLength(shape));
jlong* dims = nullptr;
if (num_dims > 0) {
jboolean is_copy;
dims = env->GetLongArrayElements(shape, &is_copy);
}
- size_t num_elems = 1;
- for (int i = 0; i < num_dims; ++i) {
- num_elems *= dims[i];
- }
static_assert(sizeof(jlong) == sizeof(int64_t),
"Java long is not compatible with the TensorFlow C API");
// On some platforms "jlong" is a "long" while "int64_t" is a "long long".
@@ -250,7 +241,7 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
dims_copy[i] = static_cast<int64_t>(dims[i]);
}
TF_Tensor* t = TF_AllocateTensor(static_cast<TF_DataType>(dtype), dims_copy,
- num_dims, elem_size * num_elems);
+ num_dims, static_cast<size_t>(sizeInBytes));
delete[] dims_copy;
if (dims != nullptr) {
env->ReleaseLongArrayElements(shape, dims, JNI_ABORT);
@@ -303,6 +294,17 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env,
TF_DeleteTensor(reinterpret_cast<TF_Tensor*>(handle));
}
+JNIEXPORT jobject JNICALL Java_org_tensorflow_Tensor_buffer(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ TF_Tensor* t = requireHandle(env, handle);
+ if (t == nullptr) return nullptr;
+ void* data = TF_TensorData(t);
+ const size_t sz = TF_TensorByteSize(t);
+
+ return env->NewDirectByteBuffer(data, static_cast<jlong>(sz));
+}
+
JNIEXPORT jint JNICALL Java_org_tensorflow_Tensor_dtype(JNIEnv* env,
jclass clazz,
jlong handle) {
diff --git a/tensorflow/java/src/main/native/tensor_jni.h b/tensorflow/java/src/main/native/tensor_jni.h
index ea0dfc819e..70850d250b 100644
--- a/tensorflow/java/src/main/native/tensor_jni.h
+++ b/tensorflow/java/src/main/native/tensor_jni.h
@@ -25,10 +25,10 @@ extern "C" {
/*
* Class: org_tensorflow_Tensor
* Method: allocate
- * Signature: (I[J)J
+ * Signature: (I[JJ)J
*/
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv *, jclass,
- jint, jlongArray);
+ jint, jlongArray, jlong);
/*
* Class: org_tensorflow_Tensor
@@ -48,6 +48,14 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv *, jclass,
/*
* Class: org_tensorflow_Tensor
+ * Method: buffer
+ * Signature: (J)Ljava/nio/ByteBuffer;
+ */
+JNIEXPORT jobject JNICALL Java_org_tensorflow_Tensor_buffer(JNIEnv *, jclass,
+ jlong);
+
+/*
+ * Class: org_tensorflow_Tensor
* Method: dtype
* Signature: (J)I
*/
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
index ec1c8551a7..1fd7774345 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
@@ -24,23 +24,273 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+import java.nio.*;
+
/** Unit tests for {@link org.tensorflow.Tensor}. */
@RunWith(JUnit4.class)
public class TensorTest {
+ private static final double EPSILON = 1e-7;
+ private static final float EPSILON_F = 1e-7f;
+
+ @Test
+ public void createWithByteBuffer() {
+ double[] doubles = {1d, 2d, 3d, 4d};
+ long[] doubles_shape = {4};
+ boolean[] bools = {true, false, true, false};
+ long[] bools_shape = {4};
+ byte[] bools_ = TestUtil.bool2byte(bools);
+ byte[] strings = "test".getBytes();
+ long[] strings_shape = {};
+ byte[] strings_; // raw TF_STRING
+ try(Tensor t = Tensor.create(strings)) {
+ ByteBuffer to = ByteBuffer.allocate(t.numBytes());
+ t.writeTo(to);
+ strings_ = to.array();
+ }
+
+ // validate creating a tensor using a byte buffer
+ {
+ try(Tensor t = Tensor.create(DataType.BOOL, bools_shape, ByteBuffer.wrap(bools_))) {
+ boolean[] actual = new boolean[bools_.length];
+ assertEquals(bools[0], t.copyTo(actual)[0]);
+ }
+
+ // note: the buffer is expected to contain raw TF_STRING (as per C API)
+ try(Tensor t = Tensor.create(DataType.STRING, strings_shape, ByteBuffer.wrap(strings_))) {
+ assertArrayEquals(strings, t.bytesValue());
+ }
+ }
+
+ // validate creating a tensor using a direct byte buffer (in host order)
+ {
+ ByteBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder());
+ buf.asDoubleBuffer().put(doubles);
+ try(Tensor t = Tensor.create(DataType.DOUBLE, doubles_shape, buf)) {
+ double[] actual = new double[doubles.length];
+ assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
+ }
+ }
+
+ // validate shape checking
+ try(Tensor t = Tensor.create(DataType.BOOL, new long[bools_.length * 2], ByteBuffer.wrap(bools_))) {
+ fail("should have failed on incompatible buffer");
+ }
+ catch(IllegalArgumentException e) {
+ // expected
+ }
+ }
+
+ @Test
+ public void createWithTypedBuffer() {
+ int[] ints = {1, 2, 3, 4};
+ float[] floats = {1f, 2f, 3f, 4f};
+ double[] doubles = {1d, 2d, 3d, 4d};
+ long[] longs = {1L, 2L, 3L, 4L};
+ long[] shape = {4};
+
+ // validate byte order conversion
+ {
+ DoubleBuffer buf = ByteBuffer.allocate(8 * doubles.length)
+ .order(ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN)
+ .asDoubleBuffer()
+ .put(doubles);
+ buf.flip();
+ try(Tensor t = Tensor.create(shape, buf)) {
+ double[] actual = new double[doubles.length];
+ assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
+ }
+ }
+
+ // validate creating a tensor using a typed buffer
+ {
+ try(Tensor t = Tensor.create(shape, DoubleBuffer.wrap(doubles))) {
+ double[] actual = new double[doubles.length];
+ assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
+ }
+ try(Tensor t = Tensor.create(shape, FloatBuffer.wrap(floats))) {
+ float[] actual = new float[floats.length];
+ assertArrayEquals(floats, t.copyTo(actual), EPSILON_F);
+ }
+ try(Tensor t = Tensor.create(shape, IntBuffer.wrap(ints))) {
+ int[] actual = new int[ints.length];
+ assertArrayEquals(ints, t.copyTo(actual));
+ }
+ try(Tensor t = Tensor.create(shape, LongBuffer.wrap(longs))) {
+ long[] actual = new long[longs.length];
+ assertArrayEquals(longs, t.copyTo(actual));
+ }
+ }
+
+ // validate shape-checking
+ {
+ try(Tensor t = Tensor.create(new long[doubles.length + 1], DoubleBuffer.wrap(doubles))) {
+ fail("should have failed on incompatible buffer");
+ }
+ catch(IllegalArgumentException e) {
+ // expected
+ }
+ try(Tensor t = Tensor.create(new long[floats.length + 1], FloatBuffer.wrap(floats))) {
+ fail("should have failed on incompatible buffer");
+ }
+ catch(IllegalArgumentException e) {
+ // expected
+ }
+ try(Tensor t = Tensor.create(new long[ints.length + 1], IntBuffer.wrap(ints))) {
+ fail("should have failed on incompatible buffer");
+ }
+ catch(IllegalArgumentException e) {
+ // expected
+ }
+ try(Tensor t = Tensor.create(new long[longs.length + 1], LongBuffer.wrap(longs))) {
+ fail("should have failed on incompatible buffer");
+ }
+ catch(IllegalArgumentException e) {
+ // expected
+ }
+ }
+ }
+
+ @Test
+ public void writeTo() {
+ int[] ints = {1, 2, 3};
+ float[] floats = {1f, 2f, 3f};
+ double[] doubles = {1d, 2d, 3d};
+ long[] longs = {1L, 2L, 3L};
+ boolean[] bools = {true,false,true};
+
+ try(Tensor tints = Tensor.create(ints);
+ Tensor tfloats = Tensor.create(floats);
+ Tensor tdoubles = Tensor.create(doubles);
+ Tensor tlongs = Tensor.create(longs);
+ Tensor tbools = Tensor.create(bools)) {
+
+ // validate that any datatype is readable with ByteBuffer (content, position)
+ {
+ ByteBuffer bbuf = ByteBuffer.allocate(1024).order(ByteOrder.nativeOrder());
+
+ bbuf.clear(); // FLOAT
+ tfloats.writeTo(bbuf);
+ assertEquals(tfloats.numBytes(), bbuf.position());
+ bbuf.flip();
+ assertEquals(floats[0], bbuf.asFloatBuffer().get(0), EPSILON);
+ bbuf.clear(); // DOUBLE
+ tdoubles.writeTo(bbuf);
+ assertEquals(tdoubles.numBytes(), bbuf.position());
+ bbuf.flip();
+ assertEquals(doubles[0], bbuf.asDoubleBuffer().get(0), EPSILON);
+ bbuf.clear(); // INT32
+ tints.writeTo(bbuf);
+ assertEquals(tints.numBytes(), bbuf.position());
+ bbuf.flip();
+ assertEquals(ints[0], bbuf.asIntBuffer().get(0));
+ bbuf.clear(); // INT64
+ tlongs.writeTo(bbuf);
+ assertEquals(tlongs.numBytes(), bbuf.position());
+ bbuf.flip();
+ assertEquals(longs[0], bbuf.asLongBuffer().get(0));
+ bbuf.clear(); // BOOL
+ tbools.writeTo(bbuf);
+ assertEquals(tbools.numBytes(), bbuf.position());
+ bbuf.flip();
+ assertEquals(bools[0], bbuf.get(0) != 0);
+ }
+
+ // validate the use of direct buffers
+ {
+ DoubleBuffer buf = ByteBuffer.allocateDirect(tdoubles.numBytes())
+ .order(ByteOrder.nativeOrder()).asDoubleBuffer();
+ tdoubles.writeTo(buf);
+ assertTrue(buf.isDirect());
+ assertEquals(tdoubles.numElements(), buf.position());
+ assertEquals(doubles[0], buf.get(0), EPSILON);
+ }
+
+ // validate typed buffers (content, position)
+ {
+ FloatBuffer buf = FloatBuffer.allocate(tfloats.numElements());
+ tfloats.writeTo(buf);
+ assertEquals(tfloats.numElements(), buf.position());
+ assertEquals(floats[0], buf.get(0), EPSILON);
+ }
+ {
+ DoubleBuffer buf = DoubleBuffer.allocate(tdoubles.numElements());
+ tdoubles.writeTo(buf);
+ assertEquals(tdoubles.numElements(), buf.position());
+ assertEquals(doubles[0], buf.get(0), EPSILON);
+ }
+ {
+ IntBuffer buf = IntBuffer.allocate(tints.numElements());
+ tints.writeTo(buf);
+ assertEquals(tints.numElements(), buf.position());
+ assertEquals(ints[0], buf.get(0));
+ }
+ {
+ LongBuffer buf = LongBuffer.allocate(tlongs.numElements());
+ tlongs.writeTo(buf);
+ assertEquals(tlongs.numElements(), buf.position());
+ assertEquals(longs[0], buf.get(0));
+ }
+
+ // validate byte order conversion
+ {
+ DoubleBuffer foreignBuf = ByteBuffer.allocate(tdoubles.numBytes())
+ .order(ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN)
+ .asDoubleBuffer();
+ tdoubles.writeTo(foreignBuf);
+ foreignBuf.flip();
+ double[] actual = new double[foreignBuf.remaining()];
+ foreignBuf.get(actual);
+ assertArrayEquals(doubles, actual, EPSILON);
+ }
+
+ // validate that incompatible buffers are rejected
+ {
+ IntBuffer badbuf1 = IntBuffer.allocate(128);
+ try {
+ tbools.writeTo(badbuf1);
+ fail("should have failed on incompatible buffer");
+ } catch (IllegalArgumentException e) {
+ // expected
+ }
+ FloatBuffer badbuf2 = FloatBuffer.allocate(128);
+ try {
+ tbools.writeTo(badbuf2);
+ fail("should have failed on incompatible buffer");
+ } catch (IllegalArgumentException e) {
+ // expected
+ }
+ DoubleBuffer badbuf3 = DoubleBuffer.allocate(128);
+ try {
+ tbools.writeTo(badbuf3);
+ fail("should have failed on incompatible buffer");
+ } catch (IllegalArgumentException e) {
+ // expected
+ }
+ LongBuffer badbuf4 = LongBuffer.allocate(128);
+ try {
+ tbools.writeTo(badbuf4);
+ fail("should have failed on incompatible buffer");
+ } catch (IllegalArgumentException e) {
+ // expected
+ }
+ }
+ }
+ }
+
@Test
public void scalars() {
try (Tensor t = Tensor.create(2.718f)) {
assertEquals(DataType.FLOAT, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
- assertEquals(2.718f, t.floatValue(), 0);
+ assertEquals(2.718f, t.floatValue(), EPSILON_F);
}
try (Tensor t = Tensor.create(3.1415)) {
assertEquals(DataType.DOUBLE, t.dataType());
assertEquals(0, t.numDimensions());
assertEquals(0, t.shape().length);
- assertEquals(3.1415, t.doubleValue(), 0);
+ assertEquals(3.1415, t.doubleValue(), EPSILON);
}
try (Tensor t = Tensor.create(-33)) {
@@ -82,7 +332,7 @@ public class TensorTest {
assertArrayEquals(new long[] {3}, t.shape());
double[] got = new double[3];
- assertArrayEquals(vector, t.copyTo(got), 0);
+ assertArrayEquals(vector, t.copyTo(got), EPSILON);
}
int[][] matrix = {{1, 2, 3}, {4, 5, 6}};
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index 67d456202f..265e21203b 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -15,6 +15,8 @@ limitations under the License.
package org.tensorflow;
+import java.lang.reflect.Array;
+
/** Static utility functions. */
public class TestUtil {
public static Output constant(Graph g, String name, Object value) {
@@ -49,4 +51,74 @@ public class TestUtil {
public static void transpose_A_times_X(Graph g, int[][] a) {
matmul(g, "Y", constant(g, "A", a), placeholder(g, "X", DataType.INT32), true, false);
}
+
+ /**
+ * Counts the total number of elements in an ND array.
+ * @param array the array to count the elements of
+ * @return the number of elements
+ */
+ public static int flattenedNumElements(Object array) {
+ int count = 0;
+ for (int i = 0; i < Array.getLength(array); i++) {
+ Object e = Array.get(array, i);
+ if(!e.getClass().isArray()) {
+ count += 1;
+ }
+ else {
+ count += flattenedNumElements(e);
+ }
+ }
+ return count;
+ }
+
+ /**
+ * Flattens an ND-array into a 1D-array with the same elements.
+ * @param array the array to flatten
+ * @param elementType the element class (e.g. {@code Integer.TYPE} for an {@code int[]})
+ * @return a flattened array
+ */
+ public static Object flatten(Object array, Class<?> elementType) {
+ Object out = Array.newInstance(elementType, flattenedNumElements(array));
+ flatten(array, out, 0);
+ return out;
+ }
+
+ private static int flatten(Object array, Object out, int next) {
+ for (int i = 0; i < Array.getLength(array); i++) {
+ Object e = Array.get(array, i);
+ if(!e.getClass().isArray()) {
+ Array.set(out, next++, e);
+ }
+ else {
+ next = flatten(e, out, next);
+ }
+ }
+ return next;
+ }
+
+ /**
+ * Converts a {@code boolean[]} to a {@code byte[]}.
+ *
+ * <p>Suitable for creating tensors of type {@link DataType#BOOL} using {@link java.nio.ByteBuffer}.
+ */
+ public static byte[] bool2byte(boolean[] array) {
+ byte[] out = new byte[array.length];
+ for(int i = 0; i< array.length; i++) {
+ out[i] = array[i] ? (byte) 1 : (byte) 0;
+ }
+ return out;
+ }
+
+ /**
+ * Converts a {@code byte[]} to a {@code boolean[]}.
+ *
+ * <p>Suitable for reading tensors of type {@link DataType#BOOL} using {@link java.nio.ByteBuffer}.
+ */
+ public static boolean[] byte2bool(byte[] array) {
+ boolean[] out = new boolean[array.length];
+ for(int i = 0; i< array.length; i++) {
+ out[i] = array[i] != 0;
+ }
+ return out;
+ }
}
diff --git a/tensorflow/python/kernel_tests/decode_raw_op_test.py b/tensorflow/python/kernel_tests/decode_raw_op_test.py
index 472808c8f9..cd7216c527 100644
--- a/tensorflow/python/kernel_tests/decode_raw_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_raw_op_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+import sys
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -53,8 +54,12 @@ class DecodeRawOpTest(test.TestCase):
self.assertEqual([None, None], decode.get_shape().as_list())
result = decode.eval(feed_dict={in_bytes: ["AaBC"]})
- self.assertAllEqual(
- [[ord("A") + ord("a") * 256, ord("B") + ord("C") * 256]], result)
+ if sys.byteorder == "big":
+ self.assertAllEqual(
+ [[ord("A") * 256 + ord("a"), ord("B") * 256 + ord("C")]], result)
+ else:
+ self.assertAllEqual(
+ [[ord("A") + ord("a") * 256, ord("B") + ord("C") * 256]], result)
with self.assertRaisesOpError(
"Input to DecodeRaw has length 3 that is not a multiple of 2, the "
diff --git a/tensorflow/python/kernel_tests/denormal_test.py b/tensorflow/python/kernel_tests/denormal_test.py
index 7047c25555..f3b1a8768f 100644
--- a/tensorflow/python/kernel_tests/denormal_test.py
+++ b/tensorflow/python/kernel_tests/denormal_test.py
@@ -22,7 +22,6 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import control_imports
from tensorflow.python.platform import test
@@ -35,9 +34,6 @@ class DenormalTest(test.TestCase):
self.assertEqual(tiny, tiny / 16 * 16)
def _flushDenormalsTest(self, use_gpu, dtypes):
- if control_imports.USE_OSS:
- # TODO(irving): Fix denormal flushing for open source.
- return
with self.test_session(use_gpu=use_gpu):
array_ops.identity(7).eval()
for dtype in dtypes:
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index aa34446f26..bdb65e72a3 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1703,9 +1703,9 @@ def cond(pred, fn1, fn2, name=None):
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
```
- If x < y, the `tf.add` operation will be executed and tf.square
+ If x < y, the `tf.add` operation will be executed and `tf.square`
operation will not be executed. Since z is needed for at least one
- branch of the cond, the tf.mul operation is always executed, unconditionally.
+ branch of the cond, the `tf.multiply` operation is always executed, unconditionally.
Although this behavior is consistent with the dataflow model of TensorFlow,
it has occasionally surprised some users who expected a lazier semantics.
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index bea3f6f448..1ce5597e13 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -183,8 +183,8 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
only the first of these is emitted. The sequence `A B B * B * B` (where '*'
is the blank label) becomes
- * `A B` if `merge_repeated=True`.
- * `A B B B B B` if `merge_repeated=False`.
+ * `A B B B` if `merge_repeated=True`.
+ * `A B B B B` if `merge_repeated=False`.
Args:
inputs: 3-D `float` `Tensor` sized
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 33ce2b8b92..55aa425524 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -271,7 +271,8 @@ def mean(values, weights=None, metrics_collections=None,
num_values = math_ops.reduce_sum(weights)
update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
- update_count_op = state_ops.assign_add(count, num_values)
+ with ops.control_dependencies([values]):
+ update_count_op = state_ops.assign_add(count, num_values)
mean_t = _safe_div(total, count, 'value')
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
@@ -983,7 +984,8 @@ def mean_tensor(values, weights=None, metrics_collections=None,
num_values = math_ops.multiply(num_values, weights)
update_total_op = state_ops.assign_add(total, values)
- update_count_op = state_ops.assign_add(count, num_values)
+ with ops.control_dependencies([values]):
+ update_count_op = state_ops.assign_add(count, num_values)
def compute_mean(total, count, name):
non_zero_count = math_ops.maximum(count,
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index f535c692a6..042e68df76 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -260,7 +260,7 @@ def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
range_tensor = math_ops.range(limit)
return input_producer(
range_tensor, [], num_epochs, shuffle, seed, capacity,
- shared_name, name, "fraction_of_%d_full" % capacity)
+ shared_name, "fraction_of_%d_full" % capacity, name)
def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,
diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py
index c4afed649e..d09476a680 100644
--- a/tensorflow/python/util/deprecation.py
+++ b/tensorflow/python/util/deprecation.py
@@ -63,8 +63,9 @@ def _call_location():
if frame:
# CPython internals are available, use them for performance.
# walk back two frames to get to deprecated function caller.
- frame = frame.f_back
- frame = frame.f_back
+ first_frame = frame.f_back
+ second_frame = first_frame.f_back
+ frame = second_frame if second_frame else first_frame
return '%s:%d' % (frame.f_code.co_filename, frame.f_lineno)
else:
# Slow fallback path
diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc
index 9aa50e976f..db1f8d9ba9 100644
--- a/tensorflow/stream_executor/dso_loader.cc
+++ b/tensorflow/stream_executor/dso_loader.cc
@@ -78,10 +78,20 @@ string GetCudnnVersion() { return TF_CUDNN_VERSION; }
GetCudaDriverLibraryPath()),
dso_handle);
#else
- return GetDsoHandle(
+ port::Status status = GetDsoHandle(
FindDsoPath(port::Env::Default()->FormatLibraryFileName("cuda", "1"),
GetCudaDriverLibraryPath()),
dso_handle);
+#if defined(__APPLE__)
+ // On Mac OS X, CUDA sometimes installs libcuda.dylib instead of
+ // libcuda.1.dylib.
+ return status.ok() ? status : GetDsoHandle(
+ FindDsoPath(port::Env::Default()->FormatLibraryFileName("cuda", ""),
+ GetCudaDriverLibraryPath()),
+ dso_handle);
+#else
+ return status;
+#endif
#endif
}
diff --git a/tensorflow/tensorboard/plugins/projector/plugin.py b/tensorflow/tensorboard/plugins/projector/plugin.py
index 8e701a305c..38b0d2076f 100644
--- a/tensorflow/tensorboard/plugins/projector/plugin.py
+++ b/tensorflow/tensorboard/plugins/projector/plugin.py
@@ -58,7 +58,7 @@ def _read_tensor_file(fpath):
tensor = []
for line in f:
if line:
- tensor.append(map(float, line.rstrip('\n').split('\t')))
+ tensor.append(list(map(float, line.rstrip('\n').split('\t'))))
return np.array(tensor, dtype='float32')
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 30f3501c6b..2fece4731a 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -112,6 +112,7 @@ def tf_copts():
"/DPLATFORM_WINDOWS",
"/DEIGEN_HAS_C99_MATH",
"/DTENSORFLOW_USE_EIGEN_THREADPOOL",
+ "/DEIGEN_VECTORIZE_SSE3", # To flush denormals without __SSE3__ set.
],
"//tensorflow:ios": ["-std=c++11"],
"//conditions:default": ["-pthread"]}))
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index f1c3f38812..0c86db7119 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -49,35 +49,6 @@ set -e
# Install Python 3.5 and dev library
apt-get install -y --no-install-recommends python3.5 libpython3.5-dev
-# Install pip3.4 and numpy for Python 3.4
-# This strange-looking install step is a stopgap measure to make the genrule
-# contrib/session_bundle/example:half_plus_two pass. The genrule calls Python
-# (via bazel) directly, but calls the wrong version of Python (3.4) because
-# bazel does not support specification of Python minor versions yet. So we
-# install numpy for Python3.4 here so that the genrule will at least not
-# complain about missing numpy. Once we upgrade to 16.04 for Python 3.5 builds,
-# this will no longer be necessary.
-set +e
-pip3_version=$(pip3 --version | grep "python 3.4")
-if [[ -z $pip3_version ]]; then
- set -e
- wget -q https://bootstrap.pypa.io/get-pip.py
- python3.4 get-pip.py
- rm -f get-pip.py
-fi
-
-NUMPY_VERSION="1.11.0"
-numpy_ver_flat=$(echo $NUMPY_VERSION | sed 's/\.//g' | sed 's/^0*//g')
-local_numpy_ver=$(python3 -c "import numpy; print(numpy.__version__)")
-local_numpy_ver_flat=$(echo $local_numpy_ver | sed 's/\.//g' | sed 's/^0*//g')
-if [[ -z $local_numpy_ver_flat ]]; then
- local_numpy_ver_flat=0
-fi
-if (( $local_numpy_ver_flat < $numpy_ver_flat )); then
- set -e
- pip3 install --upgrade numpy==${NUMPY_VERSION}
-fi
-
# Install pip3.5
set +e
pip35_version=$(pip3.5 --version | grep "python 3.5")
@@ -125,3 +96,5 @@ pip3.5 install wheel==0.29.0
pip3.5 install --upgrade pandas==0.18.1
pip3.5 install portpicker
+
+pip3.5 install werkzeug
diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
index 27bbb3d1ab..e775a5791f 100644
--- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
@@ -100,6 +100,7 @@ exclude_gpu_cc_tests="${extra_failing_gpu_cc_tests} + ${exclude_cpu_cc_tests}"
function get_failing_cpu_py_tests() {
echo "
//$1/tensorflow/python:basic_session_run_hooks_test + \
+ //$1/tensorflow/python:bigquery_reader_ops_test + \
//$1/tensorflow/python:contrib_test + \
//$1/tensorflow/python:dequantize_op_test + \
//$1/tensorflow/python:directory_watcher_test + \
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index 4b91da5079..509c8c310f 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -7,6 +7,7 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
+ "if_not_windows",
"tf_copts",
"tf_cc_test",
"tf_py_test",
@@ -65,18 +66,19 @@ cc_library(
"fuse_convolutions.cc",
"insert_logging.cc",
"obsfucate_names.cc",
- "quantize_nodes.cc",
- "quantize_weights.cc",
"remove_attribute.cc",
"remove_device.cc",
"remove_nodes.cc",
"rename_attribute.cc",
"rename_op.cc",
- "round_weights.cc",
"set_device.cc",
"sort_by_execution_order.cc",
"strip_unused_nodes.cc",
- ],
+ ] + if_not_windows([
+ "quantize_nodes.cc",
+ "quantize_weights.cc",
+ "round_weights.cc",
+ ]),
hdrs = [
"fold_constants_lib.h",
],
@@ -91,8 +93,9 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
+ ] + if_not_windows([
"//tensorflow/core/kernels:quantized_ops",
- ],
+ ]),
alwayslink = 1,
)
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 1c008138ae..d58a32fb9d 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -249,7 +249,7 @@ _DEFINE_CUDNN_MINOR = "#define CUDNN_MINOR"
_DEFINE_CUDNN_PATCHLEVEL = "#define CUDNN_PATCHLEVEL"
-def _find_cuda_define(repository_ctx, cudnn_install_basedir, define):
+def _find_cuda_define(repository_ctx, cudnn_header_dir, define):
"""Returns the value of a #define in cudnn.h
Greps through cudnn.h and returns the value of the specified #define. If the
@@ -257,15 +257,14 @@ def _find_cuda_define(repository_ctx, cudnn_install_basedir, define):
Args:
repository_ctx: The repository context.
- cudnn_install_basedir: The install directory for cuDNN on the system.
+ cudnn_header_dir: The directory containing the cuDNN header.
define: The #define to search for.
Returns:
The value of the #define found in cudnn.h.
"""
- # Find cudnn.h and grep for the line defining CUDNN_MAJOR.
- cudnn_h_path = repository_ctx.path("%s/include/cudnn.h" %
- cudnn_install_basedir)
+ # Confirm location of cudnn.h and grep for the line defining CUDNN_MAJOR.
+ cudnn_h_path = repository_ctx.path("%s/cudnn.h" % cudnn_header_dir)
if not cudnn_h_path.exists:
auto_configure_fail("Cannot find cudnn.h at %s" % str(cudnn_h_path))
result = repository_ctx.execute(["grep", "-E", define, str(cudnn_h_path)])
@@ -292,11 +291,13 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
Returns:
A string containing the version of cuDNN.
"""
- major_version = _find_cuda_define(repository_ctx, cudnn_install_basedir,
+ cudnn_header_dir = _find_cudnn_header_dir(repository_ctx,
+ cudnn_install_basedir)
+ major_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
_DEFINE_CUDNN_MAJOR)
- minor_version = _find_cuda_define(repository_ctx, cudnn_install_basedir,
+ minor_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
_DEFINE_CUDNN_MINOR)
- patch_version = _find_cuda_define(repository_ctx, cudnn_install_basedir,
+ patch_version = _find_cuda_define(repository_ctx, cudnn_header_dir,
_DEFINE_CUDNN_PATCHLEVEL)
full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)