diff options
22 files changed, 156 insertions, 76 deletions
diff --git a/google/protobuf b/google/protobuf -Subproject bd8a476510d17d3841ff2509fbd67b7f4b543c1 +Subproject 0906f5d18a2548024b511eadcbb4cfc0ca56cd6 diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c5dfeb36b7..257280559b 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -29,7 +29,7 @@ # ":all_kernels" - The cpu-specific kernels, plus ":gpu_kernels" if # built with Cuda # ":tensorflow_opensource" - The complete open-source package, including -# ":kernels", ":core", and a Session implementation. +# ":all_kernels", ":core", and a Session implementation. # ":tensorflow" - "tensorflow_opensource" plus some Google-internal libraries. # ":testlib" - TensorFlow-specific test support, e.g. utilities for testing # kernels. @@ -457,19 +457,6 @@ tf_cuda_library( ], ) -# DEPRECATED: Use either ":all_kernels" and/or ":kernel_lib" instead. -# We need to get rid of this library before we can make the kernels -# directory its own package, due to name conflicts. -tf_cuda_library( - name = "kernels", - hdrs = glob(["kernels/**/*.h"]), - deprecation = "use ':all_kernels' or ':kernel_lib' instead", - visibility = ["//visibility:public"], - deps = [ - ":kernel_lib", - ], -) - tf_cuda_library( name = "tensorflow_opensource", copts = tf_copts(), @@ -1094,8 +1081,9 @@ tf_cc_tests( # TODO(opensource): fix "common_runtime/gpu/*_test.cc", # Run by tests below - "common_runtime/gpu/gpu_region_allocator_test.cc", + "common_runtime/gpu/gpu_allocator_retry_test.cc", "common_runtime/gpu/gpu_bfc_allocator_test.cc", + "common_runtime/gpu/gpu_region_allocator_test.cc", ], ), deps = [ @@ -1129,6 +1117,10 @@ tf_cc_tests( "user_ops/**/*_test.cc", "common_runtime/gpu/*_test.cc", ], + exclude = [ + # Run by tests below + "common_runtime/gpu/gpu_allocator_retry_test.cc", + ], ), deps = [ ":all_kernels", @@ -1149,6 +1141,29 @@ tf_cc_tests( ], ) +tf_cc_tests( + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["nomac"], + tests = ["common_runtime/gpu/gpu_allocator_retry_test.cc"], + deps = [ + ":all_kernels", + ":core_cpu", + ":core_cpu_internal", + ":direct_session", + ":framework", + ":framework_internal", + ":gpu_runtime", + ":kernel_lib", + ":lib", + ":lib_internal", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/cc:cc_ops", + ], +) + # Test data filegroup( name = "image_testdata", diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index f98b7bb32b..d49d0d6005 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -108,6 +108,12 @@ class Allocator { // TracksAlloctionSizes is overridden to return true. virtual bool TracksAllocationSizes() { return false; } + // Returns true if this allocator requires tensors with 0 elements + // to allocate buffers. This is false for most allocators, but may + // be used by special-case allocators that want to track tensor + // usage. + virtual bool ShouldAllocateEmptyTensors() { return false; } + // Returns the user-requested size of the data allocated at // 'ptr'. Note that the actual buffer allocated might be larger // than requested, but this function returns the size requested by diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 8526d26148..8eee2126d9 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -381,7 +381,7 @@ void Tensor::CopyFromInternal(const Tensor& other, const TensorShape& shape) { Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape) : type_(type), shape_(shape), buf_(nullptr) { CHECK_NOTNULL(a); - if (shape_.num_elements() > 0) { + if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) { CASES(type, buf_ = new Buffer<T>(a, shape.num_elements())); } } @@ -390,7 +390,7 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape, const AllocationAttributes& allocation_attr) : type_(type), shape_(shape), buf_(nullptr) { CHECK_NOTNULL(a); - if (shape_.num_elements() > 0) { + if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) { CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr)); } } diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h index 577ed057f6..27ae3561e4 100644 --- a/tensorflow/core/kernels/reshape_op.h +++ b/tensorflow/core/kernels/reshape_op.h @@ -66,19 +66,22 @@ class ReshapeOp : public OpKernel { if (unknown_index != -1) { OP_REQUIRES( context, product > 0, - errors::InvalidArgument("cannot infer the missing input size for " - "an empty tensor unless all specified " + errors::InvalidArgument("Reshape cannot infer the missing input size " + "for an empty tensor unless all specified " "input sizes are non-zero")); const int32 missing = input.NumElements() / product; - OP_REQUIRES(context, product * missing == input.NumElements(), - errors::InvalidArgument("Input has ", input.NumElements(), - " values, which isn't divisible by ", - product)); + OP_REQUIRES( + context, product * missing == input.NumElements(), + errors::InvalidArgument( + "Input to reshape is a tensor with ", input.NumElements(), + " values, but the requested shape requires a multiple of ", + product)); shape.set_dim(unknown_index, missing); } OP_REQUIRES(context, shape.num_elements() == input.NumElements(), - errors::InvalidArgument("Input has ", input.NumElements(), - " values, which isn't the same as ", + errors::InvalidArgument("Input to reshape is a tensor with ", + input.NumElements(), + " values, but the requested shape has ", shape.num_elements())); // Actually produce the reshaped output. diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h index 82026c3cb4..7ef04bee94 100644 --- a/tensorflow/core/kernels/tensor_array.h +++ b/tensorflow/core/kernels/tensor_array.h @@ -122,7 +122,7 @@ class TensorArray : public ResourceBase { mutex_lock l(mu_); values->clear(); values->resize(tensors_.size()); - for (int32 i = 0; i < tensors_.size(); ++i) { + for (std::size_t i = 0; i < tensors_.size(); ++i) { TF_RETURN_IF_ERROR(LockedRead(i, &(*values)[i])); } return Status::OK(); diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc index 27c00e912c..d5ea5e15a8 100644 --- a/tensorflow/core/kernels/topk_op.cc +++ b/tensorflow/core/kernels/topk_op.cc @@ -33,6 +33,7 @@ class TopK : public OpKernel { explicit TopK(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_)); if (num_inputs() < 2) { // k is an attr (TopK). + OP_DEPRECATED(context, 7, "Use TopKV2 instead"); OP_REQUIRES_OK(context, context->GetAttr("k", &k_)); } else { // k is an input (TopKV2), so we won't know it until Compute. k_ = -1; diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 5d0ea8b750..9ec1aa9b7f 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -60,8 +60,9 @@ limitations under the License. // 111635679, 7jan2016). // 5. Graphs are wholly-validated during Session::Create() (7jan2016). // 6. TensorFlow is scalar strict within Google (27jan2016). +// 7. Remove TopK in favor of TopKV2 (5feb2016). #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 6 +#define TF_GRAPH_DEF_VERSION 7 #endif // TENSORFLOW_CORE_PUBLIC_VERSION_H_ diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py index e736fcb467..7493b092fe 100644 --- a/tensorflow/models/image/cifar10/cifar10.py +++ b/tensorflow/models/image/cifar10/cifar10.py @@ -265,18 +265,10 @@ def loss(logits, labels): Returns: Loss tensor of type float. """ - # Reshape the labels into a dense Tensor of - # shape [batch_size, NUM_CLASSES]. - sparse_labels = tf.reshape(labels, [FLAGS.batch_size, 1]) - indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1]) - concated = tf.concat(1, [indices, sparse_labels]) - dense_labels = tf.sparse_to_dense(concated, - [FLAGS.batch_size, NUM_CLASSES], - 1.0, 0.0) - # Calculate the average cross entropy loss across the batch. - cross_entropy = tf.nn.softmax_cross_entropy_with_logits( - logits, dense_labels, name='cross_entropy_per_example') + labels = tf.cast(labels, tf.int64) + cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits, labels, name='cross_entropy_per_example') cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') tf.add_to_collection('losses', cross_entropy_mean) diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py index df0ca22063..edceb2a1ec 100644 --- a/tensorflow/models/image/mnist/convolutional.py +++ b/tensorflow/models/image/mnist/convolutional.py @@ -81,14 +81,13 @@ def extract_data(filename, num_images): def extract_labels(filename, num_images): - """Extract the labels into a 1-hot matrix [image index, label index].""" + """Extract the labels into a vector of int64 label IDs.""" print('Extracting', filename) with gzip.open(filename) as bytestream: bytestream.read(8) buf = bytestream.read(1 * num_images) - labels = numpy.frombuffer(buf, dtype=numpy.uint8) - # Convert to dense 1-hot representation. - return (numpy.arange(NUM_LABELS) == labels[:, None]).astype(numpy.float32) + labels = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.int64) + return labels def fake_data(num_images): @@ -96,19 +95,19 @@ def fake_data(num_images): data = numpy.ndarray( shape=(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS), dtype=numpy.float32) - labels = numpy.zeros(shape=(num_images, NUM_LABELS), dtype=numpy.float32) + labels = numpy.zeros(shape=(num_images,), dtype=numpy.int64) for image in xrange(num_images): label = image % 2 data[image, :, :, 0] = label - 0.5 - labels[image, label] = 1.0 + labels[image] = label return data, labels def error_rate(predictions, labels): - """Return the error rate based on dense predictions and 1-hot labels.""" + """Return the error rate based on dense predictions and sparse labels.""" return 100.0 - ( 100.0 * - numpy.sum(numpy.argmax(predictions, 1) == numpy.argmax(labels, 1)) / + numpy.sum(numpy.argmax(predictions, 1) == labels) / predictions.shape[0]) @@ -146,8 +145,7 @@ def main(argv=None): # pylint: disable=unused-argument train_data_node = tf.placeholder( tf.float32, shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)) - train_labels_node = tf.placeholder(tf.float32, - shape=(BATCH_SIZE, NUM_LABELS)) + train_labels_node = tf.placeholder(tf.int64, shape=(BATCH_SIZE,)) eval_data = tf.placeholder( tf.float32, shape=(EVAL_BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)) @@ -222,7 +220,7 @@ def main(argv=None): # pylint: disable=unused-argument # Training computation: logits + cross-entropy loss. logits = model(train_data_node, True) - loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( + loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( logits, train_labels_node)) # L2 regularization for the fully connected parameters. diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index b65ecf9aeb..2eeef95d99 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -563,30 +563,53 @@ def resize_images(images, _, height, width, depth = _ImageDimensions(images) - if width == new_width and height == new_height: + # Handle tensor-valued sizes as well as Python integers. + try: + new_width = ops.convert_to_tensor(new_width, dtypes.int32, + name='new_width') + new_width.get_shape().assert_has_rank(0) + except (TypeError, ValueError): + raise ValueError('new_width must be a scalar integer') + try: + new_height = ops.convert_to_tensor(new_height, dtypes.int32, + name='new_height') + new_height.get_shape().assert_has_rank(0) + except (TypeError, ValueError): + raise ValueError('new_height must be a scalar integer') + + new_width_const = tensor_util.constant_value(new_width) + new_height_const = tensor_util.constant_value(new_height) + + if width == new_width_const and height == new_height_const: if not is_batch: images = array_ops.squeeze(images, squeeze_dims=[0]) return images + new_size = array_ops.pack([new_height, new_width]) + if method == ResizeMethod.BILINEAR: images = gen_image_ops.resize_bilinear(images, - [new_height, new_width], + new_size, align_corners=align_corners) elif method == ResizeMethod.NEAREST_NEIGHBOR: images = gen_image_ops.resize_nearest_neighbor(images, - [new_height, new_width], + new_size, align_corners=align_corners) elif method == ResizeMethod.BICUBIC: images = gen_image_ops.resize_bicubic(images, - [new_height, new_width], + new_size, align_corners=align_corners) elif method == ResizeMethod.AREA: images = gen_image_ops.resize_area(images, - [new_height, new_width], + new_size, align_corners=align_corners) else: raise ValueError('Resize method is not implemented.') + # NOTE(mrry): The shape functions for the resize ops cannot unpack + # the packed values in `new_size`, so set the shape here. + images.set_shape([None, new_height_const, new_width_const, None]) + if not is_batch: images = array_ops.squeeze(images, squeeze_dims=[0]) return images @@ -779,6 +802,7 @@ ops.RegisterShape('AdjustContrastv2')( def _ResizeShape(op): """Shape function for the resize_bilinear and resize_nearest_neighbor ops.""" input_shape = op.inputs[0].get_shape().with_rank(4) + unused_size_shape = op.inputs[1].get_shape().merge_with([2]) size = tensor_util.constant_value(op.inputs[1]) if size is not None: height = size[0] diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index f44fe344e6..d09004556c 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -654,6 +654,58 @@ class ResizeImagesTest(test_util.TensorFlowTestCase): newshape = yshape.eval() self.assertAllEqual(single_shape, newshape) + def testTensorArguments(self): + img_shape = [1, 6, 4, 1] + single_shape = [6, 4, 1] + # This test is also conducted with int8, so 127 is the maximum + # value that can be used. + data = [127, 127, 64, 64, + 127, 127, 64, 64, + 64, 64, 127, 127, + 64, 64, 127, 127, + 50, 50, 100, 100, + 50, 50, 100, 100] + target_height = array_ops.placeholder(dtypes.int32) + target_width = array_ops.placeholder(dtypes.int32) + + img_np = np.array(data, dtype=np.uint8).reshape(img_shape) + + for opt in self.OPTIONS: + with self.test_session() as sess: + image = constant_op.constant(img_np, shape=img_shape) + y = image_ops.resize_images(image, target_height, target_width, opt) + yshape = array_ops.shape(y) + resized, newshape = sess.run([y, yshape], {target_height: 6, + target_width: 4}) + self.assertAllEqual(img_shape, newshape) + self.assertAllClose(resized, img_np, atol=1e-5) + + # Resizing with a single image must leave the shape unchanged also. + with self.test_session(): + img_single = img_np.reshape(single_shape) + image = constant_op.constant(img_single, shape=single_shape) + y = image_ops.resize_images(image, target_height, target_width, + self.OPTIONS[0]) + yshape = array_ops.shape(y) + newshape = yshape.eval(feed_dict={target_height: 6, target_width: 4}) + self.assertAllEqual(single_shape, newshape) + + # Incorrect shape. + with self.assertRaises(ValueError): + _ = image_ops.resize_images( + image, [12, 32], 4, image_ops.ResizeMethod.BILINEAR) + with self.assertRaises(ValueError): + _ = image_ops.resize_images( + image, 6, [12, 32], image_ops.ResizeMethod.BILINEAR) + + # Incorrect dtypes. + with self.assertRaises(ValueError): + _ = image_ops.resize_images( + image, 6.0, 4, image_ops.ResizeMethod.BILINEAR) + with self.assertRaises(ValueError): + _ = image_ops.resize_images( + image, 6, 4.0, image_ops.ResizeMethod.BILINEAR) + def testResizeDown(self): # This test is also conducted with int8, so 127 is the maximum # value that can be used. diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index b01d6d0cf9..06e625e695 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -614,12 +614,7 @@ def top_k(input, k=1, sorted=True, name=None): values: The `k` largest elements along each last dimensional slice. indices: The indices of `values` within the last dimension of `input`. """ - # TODO(irving): Always use v2 once the GraphDef mechanism is unstuck. - if isinstance(k, ops.Tensor): - op = gen_nn_ops._top_kv2 - else: - op = gen_nn_ops._top_k - return op(input, k=k, sorted=sorted, name=name) + return gen_nn_ops._top_kv2(input, k=k, sorted=sorted, name=name) # pylint: enable=invalid-name diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index 2565078bb2..f98eec3b33 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -208,15 +208,6 @@ static string GetBinaryDir(bool strip_exe) { return exe_path; } -// Returns the location of the runfiles directory. -// This is the directory which "bazel run" sets as the current working directory -// before the program starts. -// N.B. This doesn't have to be running under "bazel run" in order to get the -// appropriate runfiles directory. -static string GetRunfilesDir() { - return port::StrCat(GetBinaryDir(false), ".runfiles"); -} - bool CUDAExecutor::GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel) { CUDAKernel *cuda_kernel = AsCUDAKernel(kernel); diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 21ed3a8369..c9bf1ac35a 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -260,7 +260,7 @@ def _py_wrap_cc_impl(ctx): ctx.action(executable=ctx.executable.swig_binary, arguments=args, mnemonic="PythonSwig", - inputs=list(set([src]) + cc_includes + ctx.files.swig_includes + + inputs=sorted(set([src]) + cc_includes + ctx.files.swig_includes + ctx.attr.swig_deps.files), outputs=outputs, progress_message="SWIGing {input}".format(input=src.path)) diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky index afc87201eb..ef31fc971b 100644 --- a/third_party/eigen3/Eigen/Cholesky +++ b/third_party/eigen3/Eigen/Cholesky @@ -1 +1 @@ -#include "external/eigen_archive/eigen-eigen-8cd7c2c6e9e1/Eigen/Cholesky" +#include "eigen-eigen-8cd7c2c6e9e1/Eigen/Cholesky"
\ No newline at end of file diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core index 6527c9f80a..a330b6166f 100644 --- a/third_party/eigen3/Eigen/Core +++ b/third_party/eigen3/Eigen/Core @@ -1 +1 @@ -#include "external/eigen_archive/eigen-eigen-8cd7c2c6e9e1/Eigen/Core" +#include "eigen-eigen-8cd7c2c6e9e1/Eigen/Core" diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues index d8c61044c5..30158ba1ea 100644 --- a/third_party/eigen3/Eigen/Eigenvalues +++ b/third_party/eigen3/Eigen/Eigenvalues @@ -1 +1 @@ -#include "external/eigen_archive/eigen-eigen-8cd7c2c6e9e1/Eigen/Eigenvalues" +#include "eigen-eigen-8cd7c2c6e9e1/Eigen/Eigenvalues" diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU index a290471c8d..5637771a51 100644 --- a/third_party/eigen3/Eigen/LU +++ b/third_party/eigen3/Eigen/LU @@ -1 +1 @@ -#include "external/eigen_archive/eigen-eigen-8cd7c2c6e9e1/Eigen/LU" +#include "eigen-eigen-8cd7c2c6e9e1/Eigen/LU" diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR index c67724defb..360ba8e5e3 100644 --- a/third_party/eigen3/Eigen/QR +++ b/third_party/eigen3/Eigen/QR @@ -1 +1 @@ -#include "external/eigen_archive/eigen-eigen-8cd7c2c6e9e1/Eigen/QR" +#include "eigen-eigen-8cd7c2c6e9e1/Eigen/QR" diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor index 6f97c57e43..eb293afd04 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor +++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor @@ -1 +1,2 @@ -#include "external/eigen_archive/eigen-eigen-8cd7c2c6e9e1/unsupported/Eigen/CXX11/Tensor" + +#include "eigen-eigen-8cd7c2c6e9e1/unsupported/Eigen/CXX11/Tensor" diff --git a/third_party/gpus/crosstool/CROSSTOOL b/third_party/gpus/crosstool/CROSSTOOL index 629dc32f30..dfde7cd216 100644 --- a/third_party/gpus/crosstool/CROSSTOOL +++ b/third_party/gpus/crosstool/CROSSTOOL @@ -144,4 +144,5 @@ toolchain { compiler_flag: "-fdata-sections" linker_flag: "-Wl,--gc-sections" } + linking_mode_flags { mode: DYNAMIC } } |