diff options
28 files changed, 676 insertions, 68 deletions
diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 932206137e..9525df9a4b 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -16,7 +16,7 @@ Installed version of CUDA and cuDNN: If installed from binary pip package, provide: 1. Which pip package you installed. -2. The output from python -c "import tensorflow; print(tensorflow.__version__)". +2. The output from `python -c "import tensorflow; print(tensorflow.__version__)"`. If installed from sources, provide the commit hash: diff --git a/grpc.BUILD b/grpc.BUILD index 8c00c90da9..f79cc8b610 100644 --- a/grpc.BUILD +++ b/grpc.BUILD @@ -145,6 +145,12 @@ cc_library( "include", ".", ], + defines = [ + "GPR_BACKWARDS_COMPATIBILITY_MODE", + ], + copts = [ + "-std=c99", + ], deps = [ ], ) diff --git a/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj b/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj index 0156188577..b39b5d6b21 100644 --- a/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj +++ b/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj @@ -285,7 +285,7 @@ GCC_WARN_UNUSED_VARIABLE = YES; HEADER_SEARCH_PATHS = ( "$(SRCROOT)/../../makefile/gen/proto", - "$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f", + "$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30", "$(SRCROOT)/../../makefile/downloads", "$(SRCROOT)/../../makefile/downloads/protobuf/src/", "$(SRCROOT)/../../../..", @@ -300,6 +300,12 @@ OTHER_LDFLAGS = ( "-force_load", "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a", + "-Xlinker", + "-S", + "-Xlinker", + "-x", + "-Xlinker", + "-dead_strip", ); PRODUCT_BUNDLE_IDENTIFIER = com.google.CameraExample; PRODUCT_NAME = "$(TARGET_NAME)"; @@ -344,7 +350,7 @@ GCC_WARN_UNUSED_VARIABLE = YES; HEADER_SEARCH_PATHS = ( "$(SRCROOT)/../../makefile/gen/proto", - "$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f", + 
"$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30", "$(SRCROOT)/../../makefile/downloads", "$(SRCROOT)/../../makefile/downloads/protobuf/src/", "$(SRCROOT)/../../../..", @@ -359,6 +365,12 @@ OTHER_LDFLAGS = ( "-force_load", "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a", + "-Xlinker", + "-S", + "-Xlinker", + "-x", + "-Xlinker", + "-dead_strip", ); PRODUCT_BUNDLE_IDENTIFIER = com.google.CameraExample; PRODUCT_NAME = "$(TARGET_NAME)"; diff --git a/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj b/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj index 91866cecac..ed2bd525f0 100644 --- a/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj +++ b/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj @@ -276,7 +276,7 @@ "$(SRCROOT)/../../../..", "$(SRCROOT)/../../makefile/downloads/protobuf/src/", "$(SRCROOT)/../../makefile/downloads", - "$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f", + "$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30", "$(SRCROOT)/../../makefile/gen/proto", ); INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; @@ -289,6 +289,12 @@ OTHER_LDFLAGS = ( "-force_load", "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a", + "-Xlinker", + "-S", + "-Xlinker", + "-x", + "-Xlinker", + "-dead_strip", ); PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test"; PRODUCT_NAME = "$(TARGET_NAME)"; @@ -304,7 +310,7 @@ "$(SRCROOT)/../../../..", "$(SRCROOT)/../../makefile/downloads/protobuf/src/", "$(SRCROOT)/../../makefile/downloads", - "$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f", + "$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30", "$(SRCROOT)/../../makefile/gen/proto", ); INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; @@ -314,10 +320,16 @@ "$(SRCROOT)/../../makefile/gen/protobuf_ios/lib", "$(SRCROOT)/../../makefile/gen/lib", ); - 
ONLY_ACTIVE_ARCH = NO; + ONLY_ACTIVE_ARCH = YES; OTHER_LDFLAGS = ( "-force_load", "$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a", + "-Xlinker", + "-S", + "-Xlinker", + "-x", + "-Xlinker", + "-dead_strip", ); PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test"; PRODUCT_NAME = "$(TARGET_NAME)"; diff --git a/tensorflow/contrib/learn/python/learn/io/data_feeder.py b/tensorflow/contrib/learn/python/learn/io/data_feeder.py index acfb6b48a9..f23a4a3bea 100644 --- a/tensorflow/contrib/learn/python/learn/io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/io/data_feeder.py @@ -358,7 +358,7 @@ class DataFeeder(object): else: if self.n_classes > 1: if len(self.output_shape) == 2: - out.itemset((i, self.y[sample]), 1.0) + out.itemset((i, int(self.y[sample])), 1.0) else: for idx, value in enumerate(self.y[sample]): out.itemset(tuple([i, idx, value]), 1.0) diff --git a/tensorflow/contrib/learn/python/learn/ops/batch_norm_ops.py b/tensorflow/contrib/learn/python/learn/ops/batch_norm_ops.py index ae8c79ebcd..de7a315657 100644 --- a/tensorflow/contrib/learn/python/learn/ops/batch_norm_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/batch_norm_ops.py @@ -54,22 +54,31 @@ def batch_normalize(tensor_in, initializer=init_ops.random_normal_initializer(1., 0.02)) beta = vs.get_variable("beta", [shape[-1]], initializer=init_ops.constant_initializer(0.)) - ema = moving_averages.ExponentialMovingAverage(decay=decay) - if convnet: - assign_mean, assign_var = nn.moments(tensor_in, [0, 1, 2]) - else: - assign_mean, assign_var = nn.moments(tensor_in, [0]) - ema_assign_op = ema.apply([assign_mean, assign_var]) - ema_mean, ema_var = ema.average(assign_mean), ema.average(assign_var) + moving_mean = vs.get_variable( + 'moving_mean', + shape=[shape[-1]], + initializer=init_ops.zeros_initializer, + trainable=False) + moving_var = vs.get_variable( + 'moving_var', + shape=[shape[-1]], + initializer=init_ops.ones_initializer, + trainable=False) def _update_mean_var(): """Internal 
function that updates mean and variance during training.""" - with ops.control_dependencies([ema_assign_op]): - return array_ops_.identity(assign_mean), array_ops_.identity(assign_var) + axis = [0, 1, 2] if convnet else [0] + mean, var = nn.moments(tensor_in, axis) + update_moving_mean = moving_averages.assign_moving_average( + moving_mean, mean, decay) + update_moving_var = moving_averages.assign_moving_average( + moving_var, var, decay) + with ops.control_dependencies([update_moving_mean, update_moving_var]): + return array_ops_.identity(mean), array_ops_.identity(var) is_training = array_ops_.squeeze(ops.get_collection("IS_TRAINING")) mean, variance = control_flow_ops.cond(is_training, _update_mean_var, - lambda: (ema_mean, ema_var)) + lambda: (moving_mean, moving_var)) return nn.batch_norm_with_global_normalization( tensor_in, mean, diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index 984250ce83..86ad8f41c1 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -35,8 +35,8 @@ HOST_OBJDIR := $(MAKEFILE_DIR)/gen/host_obj/ HOST_BINDIR := $(MAKEFILE_DIR)/gen/host_bin/ HOST_GENDIR := $(MAKEFILE_DIR)/gen/host_obj/ -# Which Eigen version we're using. -EIGEN_HASH := d02e6a705c30 +# Find the current Eigen version name from the Bazel build file +EIGEN_HASH := $(shell cat eigen.BUILD | grep archive_dir | head -1 | cut -f3 -d- | cut -f1 -d\") # Settings for the host compiler. HOST_CXX := gcc @@ -56,9 +56,15 @@ HOST_LIBS := \ # If we're on Linux, also link in the dl library. ifeq ($(HOST_OS),LINUX) - HOST_LIBS += -ldl + HOST_LIBS += -ldl -lpthread endif +# If we're on a Pi, link in pthreads and dl +ifeq ($(HOST_OS),PI) + HOST_LIBS += -ldl -lpthread +endif + + # proto_text is a tool that converts protobufs into a form we can use more # compactly within TensorFlow. It's a bit like protoc, but is designed to # produce a much more minimal result so we can save binary space. 
@@ -125,13 +131,13 @@ ifeq ($(TARGET),LINUX) endif # If we're on Linux, also link in the dl library. ifeq ($(TARGET),LINUX) - LIBS += -ldl + LIBS += -ldl -lpthread endif # If we're cross-compiling for the Raspberry Pi, use the right gcc. ifeq ($(TARGET),PI) - CXX := arm-linux-gnueabihf-g++ - LDFLAGS := -L$(GENDIR)protobuf_pi/lib/ -Wl,--no-whole-archive - LIBS += -ldl + CXXFLAGS += -D__ANDROID_TYPES_SLIM__ + LDFLAGS := -Wl,--no-whole-archive + LIBS += -ldl -lpthread LIBFLAGS += -Wl,--allow-multiple-definition -Wl,--whole-archive endif @@ -169,12 +175,16 @@ ifeq ($(TARGET),IOS) -Wno-c++11-narrowing \ -mno-thumb \ -DTF_LEAN_BINARY \ + -D__ANDROID_TYPES_SLIM__ \ -DMIN_LOG_LEVEL=0 \ -fno-exceptions \ -isysroot \ ${IPHONEOS_SYSROOT} LDFLAGS := -arch armv7 \ -miphoneos-version-min=${MIN_SDK_VERSION} \ + -Xlinker -S \ + -Xlinker -x \ + -Xlinker -dead_strip \ -all_load \ -L$(GENDIR)protobuf_ios/lib \ -lz @@ -186,6 +196,7 @@ ifeq ($(TARGET),IOS) -Wno-c++11-narrowing \ -mno-thumb \ -DTF_LEAN_BINARY \ + -D__ANDROID_TYPES_SLIM__ \ -DMIN_LOG_LEVEL=0 \ -fno-exceptions \ -isysroot \ @@ -205,6 +216,7 @@ ifeq ($(TARGET),IOS) -D__thread= \ -Wno-c++11-narrowing \ -DTF_LEAN_BINARY \ + -D__ANDROID_TYPES_SLIM__ \ -DMIN_LOG_LEVEL=0 \ -fno-exceptions \ -isysroot \ @@ -224,6 +236,7 @@ ifeq ($(TARGET),IOS) -D__thread= \ -Wno-c++11-narrowing \ -DTF_LEAN_BINARY \ + -D__ANDROID_TYPES_SLIM__ \ -DMIN_LOG_LEVEL=0 \ -fno-exceptions \ -isysroot \ @@ -243,6 +256,7 @@ ifeq ($(TARGET),IOS) -D__thread= \ -Wno-c++11-narrowing \ -DTF_LEAN_BINARY \ + -D__ANDROID_TYPES_SLIM__ \ -DMIN_LOG_LEVEL=0 \ -fno-exceptions \ -isysroot \ diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md index f82b70187c..f7f46fca5c 100644 --- a/tensorflow/contrib/makefile/README.md +++ b/tensorflow/contrib/makefile/README.md @@ -42,7 +42,7 @@ at `tensorflow/contrib/makefile/gen/bin/benchmark`. 
To run the executable, use: tensorflow/contrib/makefile/gen/bin/benchmark --graph=tensorflow_inception_graph.pb ``` -You should download the example graph from [http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz). +You should download the example graph from [https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip). ## Supported Systems @@ -132,24 +132,31 @@ static library in a simple app. ## Raspberry Pi -The easiest way to build for the Raspberry Pi is to cross-compile from Linux. -To use this makefile to do that, you first need to install the right version of -the compiler to target the Pi, using a command like this on your Linux machine: +Building on the Raspberry Pi is similar to a normal Linux system, though we +recommend starting by compiling and installing protobuf: ```bash -sudo apt-get install g++-arm-linux-gnueabihf +cd tensorflow/contrib/makefile/downloads/protobuf/ +./autogen.sh +./configure +make +sudo make install +cd ../../../../.. ``` -After that, run `tensorflow/contrib/makefile/compile_pi_protobuf.sh` to build a -version of the protobuf library aimed at the Pi. Then you should be able to run: +Once that's done, you can use make to build the library and example: ```bash -make -f tensorflow/contrib/makefile/Makefile TARGET=PI +make -f tensorflow/contrib/makefile/Makefile HOST_OS=PI TARGET=PI OPTFLAGS="-Os" ``` -This will build the static library, and the example benchmark executable. You -can then copy the `tensorflow/contrib/makefile/gen/bin/benchmark` program over -to your Raspberry Pi, and run it there. 
+If you're only interested in building for Raspberry Pi's 2 and 3, you can supply +some extra optimization flags to give you code that will run faster: + +```bash +make -f tensorflow/contrib/makefile/Makefile HOST_OS=PI TARGET=PI \ +OPTFLAGS="-Os -mfpu=neon-vfpv4 -funsafe-math-optimizations -ftree-vectorize" +``` ## Dependencies diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 2ff9013804..3aad6270fe 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -18,7 +18,12 @@ DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads mkdir ${DOWNLOADS_DIR} -EIGEN_HASH=d02e6a705c30 +EIGEN_HASH=62a2305d5734 +if [ -f eigen.BUILD ]; then + # Grab the current Eigen version name from the Bazel build file + EIGEN_HASH=$(cat eigen.BUILD | grep archive_dir | head -1 | cut -f3 -d- | cut -f1 -d\") +fi + curl "https://bitbucket.org/eigen/eigen/get/${EIGEN_HASH}.tar.gz" \ -o /tmp/eigen-${EIGEN_HASH}.tar.gz tar xzf /tmp/eigen-${EIGEN_HASH}.tar.gz -C ${DOWNLOADS_DIR} diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index b16c9c860a..2f9714a37a 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -1937,4 +1937,141 @@ REGISTER_KERNELS(GPU, double); #undef REGISTER_CPU_KERNELS #undef REGISTER_KERNELS + +// Note, this op works on cpu only. 
+template <typename T, typename Tindex> +class SparseApplyRMSPropOp : public OpKernel { + public: + explicit SparseApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { + auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2}); + + Tensor var = ctx->mutable_input(0, use_exclusive_lock_); + Tensor ms = ctx->mutable_input(1, use_exclusive_lock_); + Tensor mom = ctx->mutable_input(2, use_exclusive_lock_); + + OP_REQUIRES( + ctx, var.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(0))); + OP_REQUIRES( + ctx, ms.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(1))); + OP_REQUIRES( + ctx, mom.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", def().input(2))); + + const Tensor& lr = ctx->input(3); + const Tensor& rho = ctx->input(4); + const Tensor& momentum = ctx->input(5); + const Tensor& epsilon = ctx->input(6); + const Tensor& grad = ctx->input(7); + const Tensor& indices = ctx->input(8); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), + errors::InvalidArgument("lr is not a scalar: ", + lr.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()), + errors::InvalidArgument("rho is not a scalar: ", + rho.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()), + errors::InvalidArgument("momentum is not a scalar: ", + momentum.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon.shape().DebugString())); + + OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()), + errors::InvalidArgument("var and ms do not have the same shape", + var.shape().DebugString(), " ", + 
ms.shape().DebugString())); + + OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()), + errors::InvalidArgument( + "var and mom do not have the same shape", + var.shape().DebugString(), " ", mom.shape().DebugString())); + + OP_REQUIRES( + ctx, var.shape().IsSameSize(grad.shape()), + errors::InvalidArgument("var and grad do not have the same shape", + var.shape().DebugString(), " ", + grad.shape().DebugString())); + + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()), + errors::InvalidArgument("indices must be one-dimensional")); + + const Tindex N = indices.dim_size(0); + OP_REQUIRES( + ctx, grad.dim_size(0) == N, + errors::InvalidArgument( + "grad must be the same size as indices in the first dimension.")); + + if (N > 0) { + const Tindex first_dim_size = var.dim_size(0); + // Validate all the indices are in range + auto indices_vec = indices.vec<Tindex>(); + for (Tindex i = 0; i < N; i++) { + const Tindex index = indices_vec(i); + OP_REQUIRES(ctx, index >= 0 && index < first_dim_size, + errors::InvalidArgument( + strings::StrCat("Index ", index, " at offset ", i, + " in indices is out of range"))); + } + + auto var_flat = var.flat_outer_dims<T>(); + auto ms_flat = ms.flat_outer_dims<T>(); + auto mom_flat = mom.flat_outer_dims<T>(); + auto grad_flat = grad.flat_outer_dims<T>(); + const T lr_scalar = lr.scalar<T>()(); + const T rho_scalar = rho.scalar<T>()(); + const T epsilon_scalar = epsilon.scalar<T>()(); + const T momentum_scalar = momentum.scalar<T>()(); + + for (Tindex i = 0; i < N; i++) { + const Tindex index = indices_vec(i); + + auto ms_ = ms_flat.template chip<0>(index); + auto mom_ = mom_flat.template chip<0>(index); + auto grad_ = grad_flat.template chip<0>(i); + + ms_ = ms_ * ms_.constant(rho_scalar) + + grad_.square() * grad_.constant(T(1) - rho_scalar); + mom_ = mom_ * mom_.constant(momentum_scalar) + + (ms_ + ms_.constant(epsilon_scalar)).rsqrt() * + ms_.constant(lr_scalar) * grad_; + + auto v = var_flat.template chip<0>(index); + v -= 
mom_; + } + } + + ctx->forward_ref_input_to_ref_output(0, 0); + } + + private: + bool use_exclusive_lock_; +}; + +#define REGISTER_KERNELS(T, Tindices) \ + REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .TypeConstraint<Tindices>("Tindices"), \ + SparseApplyRMSPropOp<T, Tindices>); + +REGISTER_KERNELS(Eigen::half, int32); +REGISTER_KERNELS(Eigen::half, int64); +REGISTER_KERNELS(float, int32); +REGISTER_KERNELS(float, int64); +REGISTER_KERNELS(double, int32); +REGISTER_KERNELS(double, int64); + +#undef REGISTER_KERNELS + + } // namespace tensorflow diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index b11264b4d4..33deb51e9c 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -1715,7 +1715,7 @@ REGISTER_OP("ExtractImagePatches") .Attr("T: realnumbertype") .Attr(GetPaddingAttrString()) .Doc(R"doc( -Extract `patches` from `images` and puth them in the "depth" output dimension. +Extract `patches` from `images` and put them in the "depth" output dimension. images: 4-D Tensor with shape `[batch, in_rows, in_cols, depth]`. patches: 4-D Tensor with shape `[batch, out_rows, out_cols, ksize_rows * diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc index 5eb011684b..eabec80c2e 100644 --- a/tensorflow/core/ops/training_ops.cc +++ b/tensorflow/core/ops/training_ops.cc @@ -441,6 +441,9 @@ REGISTER_OP("ApplyRMSProp") .Attr("use_locking: bool = false") .Doc(R"doc( Update '*var' according to the RMSProp algorithm. +Note that in dense implement of this algorithm, ms and mom will +update even if the grad is zero, but in this sparse implement, ms +and mom will not update in iterations the grad is zero. 
mean_square = decay * mean_square + (1-decay) * gradient ** 2 Delta = learning_rate * gradient / sqrt(mean_square + epsilon) @@ -461,5 +464,46 @@ use_locking: If `True`, updating of the var, m, and v tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less contention. )doc"); + +REGISTER_OP("SparseApplyRMSProp") + .Input("var: Ref(T)") + .Input("ms: Ref(T)") + .Input("mom: Ref(T)") + .Input("lr: T") + .Input("rho: T") + .Input("momentum: T") + .Input("epsilon: T") + .Input("grad: T") + .Input("indices: Tindices") + .Output("out: Ref(T)") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("use_locking: bool = false") + .Doc(R"doc( +Update '*var' according to the RMSProp algorithm. +Note that in dense implement of this algorithm, ms and mom will +update even if the grad is zero, but in this sparse implement, ms +and mom will not update in iterations the grad is zero. + +mean_square = decay * mean_square + (1-decay) * gradient ** 2 +Delta = learning_rate * gradient / sqrt(mean_square + epsilon) + +ms <- rho * ms_{t-1} + (1-rho) * grad * grad +mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) +var <- var - mom + +var: Should be from a Variable(). +ms: Should be from a Variable(). +mom: Should be from a Variable(). +lr: Scaling factor. Must be a scalar. +epsilon: Ridge term. Must be a scalar. +rho: Decay rate. Must be a scalar. +grad: The gradient. +indices: A vector of indices into the first dimension of var, ms and mom. +out: Same as "var". +use_locking: If `True`, updating of the var, m, and v tensors will be protected + by a lock; otherwise the behavior is undefined, but may exhibit less + contention. +)doc"); } // namespace tensorflow diff --git a/tensorflow/core/platform/platform.h b/tensorflow/core/platform/platform.h index 8f71a617b6..02731d9275 100644 --- a/tensorflow/core/platform/platform.h +++ b/tensorflow/core/platform/platform.h @@ -37,6 +37,13 @@ limitations under the License. 
#define IS_MOBILE_PLATFORM #endif +#elif defined(__arm__) +#define PLATFORM_POSIX + +// Since there's no macro for the Raspberry Pi, assume we're on a mobile +// platform if we're compiling for the ARM CPU. +#define IS_MOBILE_PLATFORM + #else // If no platform specified, use: #define PLATFORM_POSIX diff --git a/tensorflow/examples/skflow/README.md b/tensorflow/examples/skflow/README.md index 756351c0b6..0d3d3cc1cf 100644 --- a/tensorflow/examples/skflow/README.md +++ b/tensorflow/examples/skflow/README.md @@ -33,6 +33,7 @@ Some examples use the `pandas` library for data processing (`sudo pip install pa ## Image classification * [Convolutional Neural Networks on MNIST Data](mnist.py) +* [Recurrent Neural Networks on MNIST Data](mnist_rnn.py) * [Deep Residual Networks on MNIST Data](resnet.py) diff --git a/tensorflow/examples/skflow/mnist_rnn.py b/tensorflow/examples/skflow/mnist_rnn.py new file mode 100644 index 0000000000..a6a594fad5 --- /dev/null +++ b/tensorflow/examples/skflow/mnist_rnn.py @@ -0,0 +1,78 @@ +# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This example builds rnn network for mnist data. 
+Borrowed structure from here: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3%20-%20Neural%20Networks/recurrent_network.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from sklearn import metrics, preprocessing + +import tensorflow as tf +from tensorflow.contrib import learn + +# Parameters +learning_rate = 0.1 +training_steps = 3000 +batch_size = 128 + +# Network Parameters +n_input = 28 # MNIST data input (img shape: 28*28) +n_steps = 28 # timesteps +n_hidden = 128 # hidden layer num of features +n_classes = 10 # MNIST total classes (0-9 digits) + +### Download and load MNIST data. +mnist = learn.datasets.load_dataset('mnist') + +X_train = mnist.train.images +y_train = mnist.train.labels +X_test = mnist.test.images +y_test = mnist.test.labels + +# It's useful to scale to ensure Stochastic Gradient Descent will do the right thing +scaler = preprocessing.StandardScaler() +X_train = scaler.fit_transform(X_train) +X_test = scaler.fit_transform(X_test) + + +def rnn_model(X, y): + X = tf.reshape(X, [-1, n_steps, n_input]) # (batch_size, n_steps, n_input) + # # permute n_steps and batch_size + X = tf.transpose(X, [1, 0, 2]) + # # Reshape to prepare input to hidden activation + X = tf.reshape(X, [-1, n_input]) # (n_steps*batch_size, n_input) + # # Split data because rnn cell needs a list of inputs for the RNN inner loop + X = tf.split(0, n_steps, X) # n_steps * (batch_size, n_input) + + # Define a GRU cell with tensorflow + lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden) + # Get lstm cell output + _, encoding = tf.nn.rnn(lstm_cell, X, dtype=tf.float32) + + return learn.models.logistic_regression(encoding, y) + + +classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=n_classes, + batch_size=batch_size, + steps=training_steps, + learning_rate=learning_rate) + +classifier.fit(X_train, y_train, logdir="/tmp/mnist_rnn") +score = 
metrics.accuracy_score(y_test, classifier.predict(X_test)) +print('Accuracy: {0:f}'.format(score)) diff --git a/tensorflow/g3doc/api_docs/python/contrib.metrics.md b/tensorflow/g3doc/api_docs/python/contrib.metrics.md index 863400eaa6..9600b9a523 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.metrics.md +++ b/tensorflow/g3doc/api_docs/python/contrib.metrics.md @@ -106,7 +106,7 @@ idempotent operation that simply divides `total` by `count`. To facilitate the estimation of the accuracy over a stream of data, the function utilizes two operations. First, an `is_correct` operation that computes a tensor whose shape matches `predictions` and whose elements are -set to 1.0 when the corresponding values of `predictions` and `labels match +set to 1.0 when the corresponding values of `predictions` and `labels` match and 0.0 otherwise. Second, an `update_op` operation whose behavior is dependent on the value of `weights`. If `weights` is None, then `update_op` increments `total` with the number of elements of `predictions` that match diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md index 1eceeaeee5..74cdc4b518 100644 --- a/tensorflow/g3doc/get_started/os_setup.md +++ b/tensorflow/g3doc/get_started/os_setup.md @@ -62,7 +62,7 @@ Install TensorFlow: # Ubuntu/Linux 64-bit, CPU only, Python 2.7: $ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4. +# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and cuDNN v4. # For other versions, see "Install from sources" below. 
$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl @@ -77,7 +77,7 @@ For python3: # Ubuntu/Linux 64-bit, CPU only, Python 3.4: $ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4. +# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and cuDNN v4. # For other versions, see "Install from sources" below. $ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl @@ -137,7 +137,7 @@ $ source ~/tensorflow/bin/activate.csh # If using csh # Ubuntu/Linux 64-bit, CPU only, Python 2.7: (tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4. +# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and cuDNN v4. # For other versions, see "Install from sources" below. (tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl @@ -155,7 +155,7 @@ $ source ~/tensorflow/bin/activate.csh # If using csh # Ubuntu/Linux 64-bit, CPU only, Python 3.4: (tensorflow)$ pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4. +# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and cuDNN v4. # For other versions, see "Install from sources" below. 
(tensorflow)$ pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl @@ -228,7 +228,7 @@ $ source activate tensorflow # Ubuntu/Linux 64-bit, CPU only, Python 2.7: (tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4. +# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and cuDNN v4. # For other versions, see "Install from sources" below. (tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl @@ -245,7 +245,7 @@ $ source activate tensorflow # Ubuntu/Linux 64-bit, CPU only, Python 3.4: (tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4. +# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and cuDNN v4. # For other versions, see "Install from sources" below. (tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl @@ -314,7 +314,7 @@ $ docker run -it -p 8888:8888 gcr.io/tensorflow/tensorflow The option `-p 8888:8888` is used to publish the Docker container᾿s internal port to the host machine, in this case to ensure Jupyter notebook connection. -The format of the port mapping `hostPort:containerPort`. You can speficy any valid port number for the host port but has to be `8888` for the container port portion. +The format of the port mapping is `hostPort:containerPort`. You can specify any valid port number for the host port but have to use `8888` for the container port portion. 
If you're using a container with GPU support, some additional flags must be passed to expose the GPU device to the container. For the default config, we @@ -526,7 +526,7 @@ empty to use system default]: 7.5 Please specify the location where CUDA 7.5 toolkit is installed. Refer to README.md for more details. [default is: /usr/local/cuda]: /usr/local/cuda -Please specify the Cudnn version you want to use. [Leave empty to use system +Please specify the cuDNN version you want to use. [Leave empty to use system default]: 4.0.4 Please specify the location where the cuDNN 4.0.4 library is installed. Refer to @@ -549,7 +549,7 @@ Configuration finished This creates a canonical set of symbolic links to the Cuda libraries on your system. Every time you change the Cuda library paths you need to run this step again before -you invoke the bazel build command. For the Cudnn libraries, use '6.5' for R2, '7.0' +you invoke the bazel build command. For the cuDNN libraries, use '6.5' for R2, '7.0' for R3, and '4.0.4' for R4-RC. @@ -672,7 +672,7 @@ GPU support will be enabled for TensorFlow Please specify which gcc nvcc should use as the host compiler. [Default is /usr/bin/gcc]: Please specify the Cuda SDK version you want to use, e.g. 7.0. [Leave empty to use system default]: 7.5 Please specify the location where CUDA 7.5 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]: -Please specify the Cudnn version you want to use. [Leave empty to use system default]: 5 +Please specify the cuDNN version you want to use. [Leave empty to use system default]: 5 Please specify the location where cuDNN 5 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda]: Please specify a list of comma-separated Cuda compute capabilities you want to build with. You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus. 
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 9c9c8d4675..c9c8ecb72c 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -173,5 +173,40 @@ class ReverseTest(test_util.TensorFlowTestCase): tf.reverse(data_2d_t, dims_3d_t) +class MeshgridTest(test_util.TensorFlowTestCase): + + def _compare(self, n, np_dtype, use_gpu): + inputs = [] + for i in range(n): + x = np.linspace(-10, 10, 5).astype(np_dtype) + if np_dtype in (np.complex64, np.complex128): + x += 1j + inputs.append(x) + + numpy_out = np.meshgrid(*inputs) + with self.test_session(use_gpu=use_gpu): + tf_out = array_ops.meshgrid(*inputs) + for X, _X in zip(numpy_out, tf_out): + self.assertAllEqual(X, _X.eval()) + + def testCompare(self): + for t in (np.float16, np.float32, np.float64, np.int32, np.int64, + np.complex64, np.complex128): + # Don't test the one-dimensional case, as + # old numpy versions don't support it + self._compare(2, t, False) + self._compare(3, t, False) + self._compare(4, t, False) + self._compare(5, t, False) + + # Test for inputs with rank not equal to 1 + x = [[1, 1], [1, 1]] + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "needs to have rank 1"): + with self.test_session(): + X, _ = array_ops.meshgrid(x, x) + X.eval() + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 13b3180b12..fd442c6eb8 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -38,6 +38,7 @@ of a tensor and change the shape of a tensor. @@reshape @@squeeze @@expand_dims +@@meshgrid ## Slicing and Joining @@ -125,7 +126,7 @@ def shape(input, name=None): else: return gen_array_ops.shape(input, name=name) - + def rank(input, name=None): """Returns the rank of a tensor. 
@@ -1047,6 +1048,82 @@ def pad(tensor, paddings, mode="CONSTANT", name=None): # pylint: disable=invali raise ValueError("Unknown padding mode: %s" % mode) +def meshgrid(*args, **kwargs): + """Broadcasts parameters for evaluation on an N-D grid. + + Given N one-dimensional coordinate arrays `*args`, returns a list `outputs` + of N-D coordinate arrays for evaluating expressions on an N-D grid. + + Notes: + + `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions. + When the `indexing` argument is set to 'xy' (the default), the broadcasting + instructions for the first two dimensions are swapped. + + Examples: + + Calling `X, Y = meshgrid(x, y)` with the tensors + ```prettyprint + x = [1, 2, 3] + y = [4, 5, 6] + ``` + results in + ```prettyprint + X = [[1, 2, 3], + [1, 2, 3], + [1, 2, 3]] + Y = [[4, 4, 4], + [5, 5, 5], + [6, 6, 6]] + ``` + + Args: + *args: `Tensor`s with rank 1 + indexing: Either 'xy' or 'ij' (optional, default: 'xy') + name: A name for the operation (optional).
+ + Returns: + outputs: A list of N `Tensor`s with rank N + """ + indexing = kwargs.pop("indexing", "xy") + name = kwargs.pop("name", "meshgrid") + if len(kwargs) > 0: + key = list(kwargs.keys())[0] + raise TypeError("'{}' is an invalid keyword argument " + "for this function".format(key)) + + if indexing not in ("xy", "ij"): + raise ValueError("indexing parameter must be either 'xy' or 'ij'") + + with ops.op_scope(args, name, "meshgrid") as name: + num_inputs = len(args) + ones = (1,) * num_inputs + + asserts = [logging_ops.Assert( + gen_math_ops.equal(rank(x), 1), + ["Input %d needs to have rank 1: " % i, rank(x)], + ) for i, x in enumerate(args)] + + # Prepare reshape by inserting dimensions with size 1 where needed + shapes = [ones[:i] + (-1,) + ones[i + 1:] for i in range(num_inputs)] + # Create parameters for broadcasting each tensor to the full size + sizes = [size(x) for x in args] + bcast = [sizes[:i] + [1] + sizes[i + 1:] for i in range(num_inputs)] + + # By default, the numpy version swaps the instructions + # for the first and second dimension + if indexing == "xy" and num_inputs > 1: + shapes[0], shapes[1] = shapes[1], shapes[0] + bcast[0], bcast[1] = bcast[1], bcast[0] + + results = [] + with ops.control_dependencies(asserts): + for a, r, e in zip(args, shapes, bcast): + results.append(tile(reshape(a, r), e)) + + return results + + @ops.RegisterShape("Placeholder") def _PlaceholderShape(op): given_shape = tensor_util.TensorShapeProtoToList(op.get_attr("shape")) diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index d0a36edbe7..6957745182 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -184,7 +184,7 @@ class _VariableStore(object): If initializer is `None` (the default), the default initializer passed in the constructor is used. If that one is `None` too, we use a new - `UniformUnitScalingInitializer`. 
If initializer is a Tensor, we use + `uniform_unit_scaling_initializer`. If initializer is a Tensor, we use it as a value and derive the shape from the initializer. If the initializer is a callable, then it will be called for each @@ -681,7 +681,7 @@ def get_variable(name, shape=None, dtype=dtypes.float32, initializer=None, If initializer is `None` (the default), the default initializer passed in the variable scope will be used. If that one is `None` too, a - `UniformUnitScalingInitializer` will be used. The initializer can also be + `uniform_unit_scaling_initializer` will be used. The initializer can also be a Tensor, in which case the variable is initialized to this value and shape. Similarly, if the regularizer is `None` (the default), the default regularizer @@ -757,7 +757,7 @@ def _get_partitioned_variable( If initializer is `None` (the default), the default initializer passed in the constructor is used. If that one is `None` too, we use a new - `UniformUnitScalingInitializer`. If initializer is a Tensor, we use + `uniform_unit_scaling_initializer`. If initializer is a Tensor, we use it as a value and derive the shape from the initializer. If the initializer is a callable, then it will be called for each diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py index 3bc35facf1..84ecaa5d74 100644 --- a/tensorflow/python/training/adam.py +++ b/tensorflow/python/training/adam.py @@ -64,6 +64,10 @@ class AdamOptimizer(optimizer.Optimizer): general. For example, when training an Inception network on ImageNet a current good choice is 1.0 or 0.1. + Note that in the dense implementation of this algorithm, m_t, v_t and the + variable are updated even if g is zero, but in the sparse implementation, + m_t, v_t and the variable are not updated in iterations where g is zero. + Args: learning_rate: A Tensor or a floating point value. The learning rate. beta1: A float value or a constant float tensor.
diff --git a/tensorflow/python/training/rmsprop.py b/tensorflow/python/training/rmsprop.py index ad77c5ecc9..5100cedf88 100644 --- a/tensorflow/python/training/rmsprop.py +++ b/tensorflow/python/training/rmsprop.py @@ -57,6 +57,10 @@ class RMSPropOptimizer(optimizer.Optimizer): name="RMSProp"): """Construct a new RMSProp optimizer. + Note that in the dense implementation of this algorithm, m_t and v_t are + updated even if g is zero, but in the sparse implementation, m_t and v_t + are not updated in iterations where g is zero. + Args: learning_rate: A Tensor or a floating point value. The learning rate. decay: Discounting factor for the history/coming gradient @@ -105,4 +109,14 @@ class RMSPropOptimizer(optimizer.Optimizer): grad, use_locking=self._use_locking).op def _apply_sparse(self, grad, var): - raise NotImplementedError() + rms = self.get_slot(var, "rms") + mom = self.get_slot(var, "momentum") + return training_ops.sparse_apply_rms_prop( + var, rms, mom, + math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), + math_ops.cast(self._decay_tensor, var.dtype.base_dtype), + math_ops.cast(self._momentum_tensor, var.dtype.base_dtype), + math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype), + grad.values, + grad.indices, + use_locking=self._use_locking) diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py index 541b3e0942..499e452d90 100644 --- a/tensorflow/python/training/rmsprop_test.py +++ b/tensorflow/python/training/rmsprop_test.py @@ -26,6 +26,131 @@ import tensorflow as tf class RMSPropOptimizerTest(tf.test.TestCase): + def _rmsprop_update_numpy(self, var, g, rms, mom, lr, decay, momentum, + epsilon): + rms_t = rms * decay + (1-decay) * g * g + mom_t = momentum * mom + lr * g / np.sqrt(rms_t + epsilon) + var_t = var - mom_t + return var_t, rms_t, mom_t + + def testSparseWithMomentum(self): + for dtype in [tf.half, tf.float32]: + with self.test_session(): + # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = tf.Variable(var0_np) + var1 = tf.Variable(var1_np) + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = tf.IndexedSlices(tf.constant(grads0_np), + tf.constant(grads0_np_indices), + tf.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = tf.IndexedSlices(tf.constant(grads1_np), + tf.constant(grads1_np_indices), + tf.constant([2])) + opt = tf.train.RMSPropOptimizer(learning_rate=2.0, decay=0.9, + momentum=0.5, epsilon=1e-5) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + tf.initialize_all_variables().run() + + rms0 = opt.get_slot(var0, "rms") + self.assertTrue(rms0 is not None) + rms1 = opt.get_slot(var1, "rms") + self.assertTrue(rms1 is not None) + mom0 = opt.get_slot(var0, "momentum") + self.assertTrue(mom0 is not None) + mom1 = opt.get_slot(var1, "momentum") + self.assertTrue(mom1 is not None) + + rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 4 steps of RMSProp + for t in range(1, 5): + update.run() + + var0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(var0_np, + grads0_np, rms0_np, mom0_np, 2.0, 0.9, 0.5, 1e-5) + var1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(var1_np, + grads1_np, rms1_np, mom1_np, 2.0, 0.9, 0.5, 1e-5) + + # Validate updated params + self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) + self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) + 
self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) + self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + + def testSparseWithoutMomentum(self): + for dtype in [tf.half, tf.float32]: + with self.test_session(): + # Initialize variables for numpy implementation. + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = tf.Variable(var0_np) + var1 = tf.Variable(var1_np) + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = tf.IndexedSlices(tf.constant(grads0_np), + tf.constant(grads0_np_indices), + tf.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = tf.IndexedSlices(tf.constant(grads1_np), + tf.constant(grads1_np_indices), + tf.constant([2])) + opt = tf.train.RMSPropOptimizer(learning_rate=2.0, decay=0.9, + momentum=0.0, epsilon=1.0) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + tf.initialize_all_variables().run() + + rms0 = opt.get_slot(var0, "rms") + self.assertTrue(rms0 is not None) + rms1 = opt.get_slot(var1, "rms") + self.assertTrue(rms1 is not None) + mom0 = opt.get_slot(var0, "momentum") + self.assertTrue(mom0 is not None) + mom1 = opt.get_slot(var1, "momentum") + self.assertTrue(mom1 is not None) + + rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype) + mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], var0.eval()) + self.assertAllClose([3.0, 4.0], var1.eval()) + + # Run 4 steps of RMSProp + for t in range(1, 5): + update.run() + + var0_np, rms0_np, 
mom0_np = self._rmsprop_update_numpy(var0_np, + grads0_np, rms0_np, mom0_np, 2.0, 0.9, 0.0, 1.0) + var1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(var1_np, + grads1_np, rms1_np, mom1_np, 2.0, 0.9, 0.0, 1.0) + + # Validate updated params + self.assertAllCloseAccordingToType(rms0_np, rms0.eval()) + self.assertAllCloseAccordingToType(rms1_np, rms1.eval()) + self.assertAllCloseAccordingToType(mom0_np, mom0.eval()) + self.assertAllCloseAccordingToType(mom1_np, mom1.eval()) + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) + def testWithoutMomentum(self): for dtype in [tf.half, tf.float32]: with self.test_session(): diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py index 8619752338..1a96f77c1c 100644 --- a/tensorflow/python/training/training_ops.py +++ b/tensorflow/python/training/training_ops.py @@ -170,6 +170,23 @@ def _SparseApplyProximalGradientDescentShape(op): return [var_shape] +@ops.RegisterShape("SparseApplyRMSProp") +def _SparseApplyRMSPropShape(op): + """Shape function for the SparseApplyRMSProp op.""" + var_shape = op.inputs[0].get_shape() + ms_shape = op.inputs[1].get_shape().merge_with(var_shape) + mom_shape = op.inputs[2].get_shape().merge_with(ms_shape) + _AssertInputIsScalar(op, 3) # lr + _AssertInputIsScalar(op, 4) # rho + _AssertInputIsScalar(op, 5) # momentum + _AssertInputIsScalar(op, 6) # epsilon + grad_shape = op.inputs[7].get_shape().merge_with( + tensor_shape.TensorShape([None]).concatenate(mom_shape[1:])) + unused_indices_shape = op.inputs[8].get_shape().merge_with( + tensor_shape.vector(grad_shape[0])) + return [mom_shape] + + @ops.RegisterShape("SparseApplyAdadelta") def _SparseApplyAdadeltaShape(op): """Shape function for the SparseApplyAdadelta op.""" diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 224803fc84..a9dd2953e5 100644 --- 
a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -25,6 +25,12 @@ limitations under the License. #define EIGEN_HAS_CUDA_FP16 #endif +#if CUDA_VERSION >= 8000 +#define SE_CUDA_DATA_HALF CUDA_R_16F +#else +#define SE_CUDA_DATA_HALF CUBLAS_DATA_HALF +#endif + #include "tensorflow/stream_executor/cuda/cuda_blas.h" #include <dlfcn.h> @@ -1680,10 +1686,10 @@ bool CUDABlas::DoBlasGemm( return DoBlasInternal( dynload::cublasSgemmEx, stream, true /* = pointer_mode_host */, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha, - CUDAMemory(a), CUBLAS_DATA_HALF, lda, - CUDAMemory(b), CUBLAS_DATA_HALF, ldb, + CUDAMemory(a), SE_CUDA_DATA_HALF, lda, + CUDAMemory(b), SE_CUDA_DATA_HALF, ldb, &beta, - CUDAMemoryMutable(c), CUBLAS_DATA_HALF, ldc); + CUDAMemoryMutable(c), SE_CUDA_DATA_HALF, ldc); #else LOG(ERROR) << "fp16 sgemm is not implemented in this cuBLAS version " << "(need at least CUDA 7.5)"; diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts index 76f38814f2..299e536393 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts @@ -267,8 +267,11 @@ export class Minimap { downloadContext.drawImage(image, 0, 0, this.downloadCanvas.width, this.downloadCanvas.height); }; - let blob = new Blob([svgXml], {type: 'image/svg+xml;charset=utf-8'}); - image.src = URL.createObjectURL(blob); + image.onerror = () => { + let blob = new Blob([svgXml], {type: 'image/svg+xml;charset=utf-8'}); + image.src = URL.createObjectURL(blob); + } + image.src = 'data:image/svg+xml;charset=utf-8,' + encodeURIComponent(svgXml); } /** diff --git a/tensorflow/tensorboard/dist/tf-tensorboard.html b/tensorflow/tensorboard/dist/tf-tensorboard.html index 014061a807..25048790ae 100644 --- a/tensorflow/tensorboard/dist/tf-tensorboard.html +++ 
b/tensorflow/tensorboard/dist/tf-tensorboard.html @@ -9987,8 +9987,11 @@ var tf; downloadContext.clearRect(0, 0, _this.downloadCanvas.width, _this.downloadCanvas.height); downloadContext.drawImage(image, 0, 0, _this.downloadCanvas.width, _this.downloadCanvas.height); }; - var blob = new Blob([svgXml], { type: 'image/svg+xml;charset=utf-8' }); - image.src = URL.createObjectURL(blob); + image.onerror = function() { + var blob = new Blob([svgXml], {type: "image/svg+xml;charset=utf-8"}); + image.src = URL.createObjectURL(blob); + }; + image.src = "data:image/svg+xml;charset=utf-8," + encodeURIComponent(svgXml); }; /** * Handles changes in zooming/panning. Should be called from the main svg diff --git a/tensorflow/tools/docker/docker_run_gpu.sh b/tensorflow/tools/docker/docker_run_gpu.sh index 2fd885802b..08f391ddf9 100755 --- a/tensorflow/tools/docker/docker_run_gpu.sh +++ b/tensorflow/tools/docker/docker_run_gpu.sh @@ -14,16 +14,8 @@ # limitations under the License. # ============================================================================== - set -e -export CUDA_HOME=${CUDA_HOME:-/usr/local/cuda} - -if [ ! -d ${CUDA_HOME}/lib64 ]; then - echo "Failed to locate CUDA libs at ${CUDA_HOME}/lib64." - exit 1 -fi - export CUDA_SO=$(\ls /usr/lib/x86_64-linux-gnu/libcuda.* | \ xargs -I{} echo '-v {}:{}') export DEVICES=$(\ls /dev/nvidia* | \ |