aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--ISSUE_TEMPLATE.md2
-rw-r--r--grpc.BUILD6
-rw-r--r--tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj16
-rw-r--r--tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj18
-rw-r--r--tensorflow/contrib/learn/python/learn/io/data_feeder.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/batch_norm_ops.py29
-rw-r--r--tensorflow/contrib/makefile/Makefile28
-rw-r--r--tensorflow/contrib/makefile/README.md29
-rwxr-xr-xtensorflow/contrib/makefile/download_dependencies.sh7
-rw-r--r--tensorflow/core/kernels/training_ops.cc137
-rw-r--r--tensorflow/core/ops/array_ops.cc2
-rw-r--r--tensorflow/core/ops/training_ops.cc44
-rw-r--r--tensorflow/core/platform/platform.h7
-rw-r--r--tensorflow/examples/skflow/README.md1
-rw-r--r--tensorflow/examples/skflow/mnist_rnn.py78
-rw-r--r--tensorflow/g3doc/api_docs/python/contrib.metrics.md2
-rw-r--r--tensorflow/g3doc/get_started/os_setup.md20
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py35
-rw-r--r--tensorflow/python/ops/array_ops.py79
-rw-r--r--tensorflow/python/ops/variable_scope.py6
-rw-r--r--tensorflow/python/training/adam.py4
-rw-r--r--tensorflow/python/training/rmsprop.py16
-rw-r--r--tensorflow/python/training/rmsprop_test.py125
-rw-r--r--tensorflow/python/training/training_ops.py17
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.cc12
-rw-r--r--tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts7
-rw-r--r--tensorflow/tensorboard/dist/tf-tensorboard.html7
-rwxr-xr-xtensorflow/tools/docker/docker_run_gpu.sh8
28 files changed, 676 insertions, 68 deletions
diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md
index 932206137e..9525df9a4b 100644
--- a/ISSUE_TEMPLATE.md
+++ b/ISSUE_TEMPLATE.md
@@ -16,7 +16,7 @@ Installed version of CUDA and cuDNN:
If installed from binary pip package, provide:
1. Which pip package you installed.
-2. The output from python -c "import tensorflow; print(tensorflow.__version__)".
+2. The output from `python -c "import tensorflow; print(tensorflow.__version__)"`.
If installed from sources, provide the commit hash:
diff --git a/grpc.BUILD b/grpc.BUILD
index 8c00c90da9..f79cc8b610 100644
--- a/grpc.BUILD
+++ b/grpc.BUILD
@@ -145,6 +145,12 @@ cc_library(
"include",
".",
],
+ defines = [
+ "GPR_BACKWARDS_COMPATIBILITY_MODE",
+ ],
+ copts = [
+ "-std=c99",
+ ],
deps = [
],
)
diff --git a/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj b/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj
index 0156188577..b39b5d6b21 100644
--- a/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj
+++ b/tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj/project.pbxproj
@@ -285,7 +285,7 @@
GCC_WARN_UNUSED_VARIABLE = YES;
HEADER_SEARCH_PATHS = (
"$(SRCROOT)/../../makefile/gen/proto",
- "$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f",
+ "$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30",
"$(SRCROOT)/../../makefile/downloads",
"$(SRCROOT)/../../makefile/downloads/protobuf/src/",
"$(SRCROOT)/../../../..",
@@ -300,6 +300,12 @@
OTHER_LDFLAGS = (
"-force_load",
"$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
+ "-Xlinker",
+ "-S",
+ "-Xlinker",
+ "-x",
+ "-Xlinker",
+ "-dead_strip",
);
PRODUCT_BUNDLE_IDENTIFIER = com.google.CameraExample;
PRODUCT_NAME = "$(TARGET_NAME)";
@@ -344,7 +350,7 @@
GCC_WARN_UNUSED_VARIABLE = YES;
HEADER_SEARCH_PATHS = (
"$(SRCROOT)/../../makefile/gen/proto",
- "$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f",
+ "$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30",
"$(SRCROOT)/../../makefile/downloads",
"$(SRCROOT)/../../makefile/downloads/protobuf/src/",
"$(SRCROOT)/../../../..",
@@ -359,6 +365,12 @@
OTHER_LDFLAGS = (
"-force_load",
"$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
+ "-Xlinker",
+ "-S",
+ "-Xlinker",
+ "-x",
+ "-Xlinker",
+ "-dead_strip",
);
PRODUCT_BUNDLE_IDENTIFIER = com.google.CameraExample;
PRODUCT_NAME = "$(TARGET_NAME)";
diff --git a/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj b/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj
index 91866cecac..ed2bd525f0 100644
--- a/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj
+++ b/tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj/project.pbxproj
@@ -276,7 +276,7 @@
"$(SRCROOT)/../../../..",
"$(SRCROOT)/../../makefile/downloads/protobuf/src/",
"$(SRCROOT)/../../makefile/downloads",
- "$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f",
+ "$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30",
"$(SRCROOT)/../../makefile/gen/proto",
);
INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist";
@@ -289,6 +289,12 @@
OTHER_LDFLAGS = (
"-force_load",
"$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
+ "-Xlinker",
+ "-S",
+ "-Xlinker",
+ "-x",
+ "-Xlinker",
+ "-dead_strip",
);
PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test";
PRODUCT_NAME = "$(TARGET_NAME)";
@@ -304,7 +310,7 @@
"$(SRCROOT)/../../../..",
"$(SRCROOT)/../../makefile/downloads/protobuf/src/",
"$(SRCROOT)/../../makefile/downloads",
- "$(SRCROOT)/../../makefile/downloads/eigen-eigen-f3a13643ac1f",
+ "$(SRCROOT)/../../makefile/downloads/eigen-eigen-d02e6a705c30",
"$(SRCROOT)/../../makefile/gen/proto",
);
INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist";
@@ -314,10 +320,16 @@
"$(SRCROOT)/../../makefile/gen/protobuf_ios/lib",
"$(SRCROOT)/../../makefile/gen/lib",
);
- ONLY_ACTIVE_ARCH = NO;
+ ONLY_ACTIVE_ARCH = YES;
OTHER_LDFLAGS = (
"-force_load",
"$(SRCROOT)/../../makefile/gen/lib/libtensorflow-core.a",
+ "-Xlinker",
+ "-S",
+ "-Xlinker",
+ "-x",
+ "-Xlinker",
+ "-dead_strip",
);
PRODUCT_BUNDLE_IDENTIFIER = "com.google.TF-Test";
PRODUCT_NAME = "$(TARGET_NAME)";
diff --git a/tensorflow/contrib/learn/python/learn/io/data_feeder.py b/tensorflow/contrib/learn/python/learn/io/data_feeder.py
index acfb6b48a9..f23a4a3bea 100644
--- a/tensorflow/contrib/learn/python/learn/io/data_feeder.py
+++ b/tensorflow/contrib/learn/python/learn/io/data_feeder.py
@@ -358,7 +358,7 @@ class DataFeeder(object):
else:
if self.n_classes > 1:
if len(self.output_shape) == 2:
- out.itemset((i, self.y[sample]), 1.0)
+ out.itemset((i, int(self.y[sample])), 1.0)
else:
for idx, value in enumerate(self.y[sample]):
out.itemset(tuple([i, idx, value]), 1.0)
diff --git a/tensorflow/contrib/learn/python/learn/ops/batch_norm_ops.py b/tensorflow/contrib/learn/python/learn/ops/batch_norm_ops.py
index ae8c79ebcd..de7a315657 100644
--- a/tensorflow/contrib/learn/python/learn/ops/batch_norm_ops.py
+++ b/tensorflow/contrib/learn/python/learn/ops/batch_norm_ops.py
@@ -54,22 +54,31 @@ def batch_normalize(tensor_in,
initializer=init_ops.random_normal_initializer(1., 0.02))
beta = vs.get_variable("beta", [shape[-1]],
initializer=init_ops.constant_initializer(0.))
- ema = moving_averages.ExponentialMovingAverage(decay=decay)
- if convnet:
- assign_mean, assign_var = nn.moments(tensor_in, [0, 1, 2])
- else:
- assign_mean, assign_var = nn.moments(tensor_in, [0])
- ema_assign_op = ema.apply([assign_mean, assign_var])
- ema_mean, ema_var = ema.average(assign_mean), ema.average(assign_var)
+ moving_mean = vs.get_variable(
+ 'moving_mean',
+ shape=[shape[-1]],
+ initializer=init_ops.zeros_initializer,
+ trainable=False)
+ moving_var = vs.get_variable(
+ 'moving_var',
+ shape=[shape[-1]],
+ initializer=init_ops.ones_initializer,
+ trainable=False)
def _update_mean_var():
"""Internal function that updates mean and variance during training."""
- with ops.control_dependencies([ema_assign_op]):
- return array_ops_.identity(assign_mean), array_ops_.identity(assign_var)
+ axis = [0, 1, 2] if convnet else [0]
+ mean, var = nn.moments(tensor_in, axis)
+ update_moving_mean = moving_averages.assign_moving_average(
+ moving_mean, mean, decay)
+ update_moving_var = moving_averages.assign_moving_average(
+ moving_var, var, decay)
+ with ops.control_dependencies([update_moving_mean, update_moving_var]):
+ return array_ops_.identity(mean), array_ops_.identity(var)
is_training = array_ops_.squeeze(ops.get_collection("IS_TRAINING"))
mean, variance = control_flow_ops.cond(is_training, _update_mean_var,
- lambda: (ema_mean, ema_var))
+ lambda: (moving_mean, moving_var))
return nn.batch_norm_with_global_normalization(
tensor_in,
mean,
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index 984250ce83..86ad8f41c1 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -35,8 +35,8 @@ HOST_OBJDIR := $(MAKEFILE_DIR)/gen/host_obj/
HOST_BINDIR := $(MAKEFILE_DIR)/gen/host_bin/
HOST_GENDIR := $(MAKEFILE_DIR)/gen/host_obj/
-# Which Eigen version we're using.
-EIGEN_HASH := d02e6a705c30
+# Find the current Eigen version name from the Bazel build file
+EIGEN_HASH := $(shell cat eigen.BUILD | grep archive_dir | head -1 | cut -f3 -d- | cut -f1 -d\")
# Settings for the host compiler.
HOST_CXX := gcc
@@ -56,9 +56,15 @@ HOST_LIBS := \
# If we're on Linux, also link in the dl library.
ifeq ($(HOST_OS),LINUX)
- HOST_LIBS += -ldl
+ HOST_LIBS += -ldl -lpthread
endif
+# If we're on a Pi, link in pthreads and dl
+ifeq ($(HOST_OS),PI)
+ HOST_LIBS += -ldl -lpthread
+endif
+
+
# proto_text is a tool that converts protobufs into a form we can use more
# compactly within TensorFlow. It's a bit like protoc, but is designed to
# produce a much more minimal result so we can save binary space.
@@ -125,13 +131,13 @@ ifeq ($(TARGET),LINUX)
endif
# If we're on Linux, also link in the dl library.
ifeq ($(TARGET),LINUX)
- LIBS += -ldl
+ LIBS += -ldl -lpthread
endif
# If we're cross-compiling for the Raspberry Pi, use the right gcc.
ifeq ($(TARGET),PI)
- CXX := arm-linux-gnueabihf-g++
- LDFLAGS := -L$(GENDIR)protobuf_pi/lib/ -Wl,--no-whole-archive
- LIBS += -ldl
+ CXXFLAGS += -D__ANDROID_TYPES_SLIM__
+ LDFLAGS := -Wl,--no-whole-archive
+ LIBS += -ldl -lpthread
LIBFLAGS += -Wl,--allow-multiple-definition -Wl,--whole-archive
endif
@@ -169,12 +175,16 @@ ifeq ($(TARGET),IOS)
-Wno-c++11-narrowing \
-mno-thumb \
-DTF_LEAN_BINARY \
+ -D__ANDROID_TYPES_SLIM__ \
-DMIN_LOG_LEVEL=0 \
-fno-exceptions \
-isysroot \
${IPHONEOS_SYSROOT}
LDFLAGS := -arch armv7 \
-miphoneos-version-min=${MIN_SDK_VERSION} \
+ -Xlinker -S \
+ -Xlinker -x \
+ -Xlinker -dead_strip \
-all_load \
-L$(GENDIR)protobuf_ios/lib \
-lz
@@ -186,6 +196,7 @@ ifeq ($(TARGET),IOS)
-Wno-c++11-narrowing \
-mno-thumb \
-DTF_LEAN_BINARY \
+ -D__ANDROID_TYPES_SLIM__ \
-DMIN_LOG_LEVEL=0 \
-fno-exceptions \
-isysroot \
@@ -205,6 +216,7 @@ ifeq ($(TARGET),IOS)
-D__thread= \
-Wno-c++11-narrowing \
-DTF_LEAN_BINARY \
+ -D__ANDROID_TYPES_SLIM__ \
-DMIN_LOG_LEVEL=0 \
-fno-exceptions \
-isysroot \
@@ -224,6 +236,7 @@ ifeq ($(TARGET),IOS)
-D__thread= \
-Wno-c++11-narrowing \
-DTF_LEAN_BINARY \
+ -D__ANDROID_TYPES_SLIM__ \
-DMIN_LOG_LEVEL=0 \
-fno-exceptions \
-isysroot \
@@ -243,6 +256,7 @@ ifeq ($(TARGET),IOS)
-D__thread= \
-Wno-c++11-narrowing \
-DTF_LEAN_BINARY \
+ -D__ANDROID_TYPES_SLIM__ \
-DMIN_LOG_LEVEL=0 \
-fno-exceptions \
-isysroot \
diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md
index f82b70187c..f7f46fca5c 100644
--- a/tensorflow/contrib/makefile/README.md
+++ b/tensorflow/contrib/makefile/README.md
@@ -42,7 +42,7 @@ at `tensorflow/contrib/makefile/gen/bin/benchmark`. To run the executable, use:
tensorflow/contrib/makefile/gen/bin/benchmark --graph=tensorflow_inception_graph.pb
```
-You should download the example graph from [http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz).
+You should download the example graph from [https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip).
## Supported Systems
@@ -132,24 +132,31 @@ static library in a simple app.
## Raspberry Pi
-The easiest way to build for the Raspberry Pi is to cross-compile from Linux.
-To use this makefile to do that, you first need to install the right version of
-the compiler to target the Pi, using a command like this on your Linux machine:
+Building on the Raspberry Pi is similar to a normal Linux system, though we
+recommend starting by compiling and installing protobuf:
```bash
-sudo apt-get install g++-arm-linux-gnueabihf
+cd tensorflow/contrib/makefile/downloads/protobuf/
+./autogen.sh
+./configure
+make
+sudo make install
+cd ../../../../..
```
-After that, run `tensorflow/contrib/makefile/compile_pi_protobuf.sh` to build a
-version of the protobuf library aimed at the Pi. Then you should be able to run:
+Once that's done, you can use make to build the library and example:
```bash
-make -f tensorflow/contrib/makefile/Makefile TARGET=PI
+make -f tensorflow/contrib/makefile/Makefile HOST_OS=PI TARGET=PI OPTFLAGS="-Os"
```
-This will build the static library, and the example benchmark executable. You
-can then copy the `tensorflow/contrib/makefile/gen/bin/benchmark` program over
-to your Raspberry Pi, and run it there.
+If you're only interested in building for the Raspberry Pi 2 and 3, you can supply
+some extra optimization flags to give you code that will run faster:
+
+```bash
+make -f tensorflow/contrib/makefile/Makefile HOST_OS=PI TARGET=PI \
+OPTFLAGS="-Os -mfpu=neon-vfpv4 -funsafe-math-optimizations -ftree-vectorize"
+```
## Dependencies
diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh
index 2ff9013804..3aad6270fe 100755
--- a/tensorflow/contrib/makefile/download_dependencies.sh
+++ b/tensorflow/contrib/makefile/download_dependencies.sh
@@ -18,7 +18,12 @@ DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads
mkdir ${DOWNLOADS_DIR}
-EIGEN_HASH=d02e6a705c30
+EIGEN_HASH=62a2305d5734
+if [ -f eigen.BUILD ]; then
+ # Grab the current Eigen version name from the Bazel build file
+ EIGEN_HASH=$(cat eigen.BUILD | grep archive_dir | head -1 | cut -f3 -d- | cut -f1 -d\")
+fi
+
curl "https://bitbucket.org/eigen/eigen/get/${EIGEN_HASH}.tar.gz" \
-o /tmp/eigen-${EIGEN_HASH}.tar.gz
tar xzf /tmp/eigen-${EIGEN_HASH}.tar.gz -C ${DOWNLOADS_DIR}
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index b16c9c860a..2f9714a37a 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -1937,4 +1937,141 @@ REGISTER_KERNELS(GPU, double);
#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS
+
+// Note, this op works on cpu only.
+template <typename T, typename Tindex>
+class SparseApplyRMSPropOp : public OpKernel {
+ public:
+ explicit SparseApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ }
+
+ void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
+ auto locks = MaybeLockMutexesInOrder(ctx, use_exclusive_lock_, {0, 1, 2});
+
+ Tensor var = ctx->mutable_input(0, use_exclusive_lock_);
+ Tensor ms = ctx->mutable_input(1, use_exclusive_lock_);
+ Tensor mom = ctx->mutable_input(2, use_exclusive_lock_);
+
+ OP_REQUIRES(
+ ctx, var.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(0)));
+ OP_REQUIRES(
+ ctx, ms.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(1)));
+ OP_REQUIRES(
+ ctx, mom.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", def().input(2)));
+
+ const Tensor& lr = ctx->input(3);
+ const Tensor& rho = ctx->input(4);
+ const Tensor& momentum = ctx->input(5);
+ const Tensor& epsilon = ctx->input(6);
+ const Tensor& grad = ctx->input(7);
+ const Tensor& indices = ctx->input(8);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho.shape()),
+ errors::InvalidArgument("rho is not a scalar: ",
+ rho.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum.shape()),
+ errors::InvalidArgument("momentum is not a scalar: ",
+ momentum.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+ errors::InvalidArgument("epsilon is not a scalar: ",
+ epsilon.shape().DebugString()));
+
+ OP_REQUIRES(ctx, var.shape().IsSameSize(ms.shape()),
+ errors::InvalidArgument("var and ms do not have the same shape",
+ var.shape().DebugString(), " ",
+ ms.shape().DebugString()));
+
+ OP_REQUIRES(ctx, var.shape().IsSameSize(mom.shape()),
+ errors::InvalidArgument(
+ "var and mom do not have the same shape",
+ var.shape().DebugString(), " ", mom.shape().DebugString()));
+
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(grad.shape()),
+ errors::InvalidArgument("var and grad do not have the same shape",
+ var.shape().DebugString(), " ",
+ grad.shape().DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(indices.shape()),
+ errors::InvalidArgument("indices must be one-dimensional"));
+
+ const Tindex N = indices.dim_size(0);
+ OP_REQUIRES(
+ ctx, grad.dim_size(0) == N,
+ errors::InvalidArgument(
+ "grad must be the same size as indices in the first dimension."));
+
+ if (N > 0) {
+ const Tindex first_dim_size = var.dim_size(0);
+ // Validate all the indices are in range
+ auto indices_vec = indices.vec<Tindex>();
+ for (Tindex i = 0; i < N; i++) {
+ const Tindex index = indices_vec(i);
+ OP_REQUIRES(ctx, index >= 0 && index < first_dim_size,
+ errors::InvalidArgument(
+ strings::StrCat("Index ", index, " at offset ", i,
+ " in indices is out of range")));
+ }
+
+ auto var_flat = var.flat_outer_dims<T>();
+ auto ms_flat = ms.flat_outer_dims<T>();
+ auto mom_flat = mom.flat_outer_dims<T>();
+ auto grad_flat = grad.flat_outer_dims<T>();
+ const T lr_scalar = lr.scalar<T>()();
+ const T rho_scalar = rho.scalar<T>()();
+ const T epsilon_scalar = epsilon.scalar<T>()();
+ const T momentum_scalar = momentum.scalar<T>()();
+
+ for (Tindex i = 0; i < N; i++) {
+ const Tindex index = indices_vec(i);
+
+ auto ms_ = ms_flat.template chip<0>(index);
+ auto mom_ = mom_flat.template chip<0>(index);
+ auto grad_ = grad_flat.template chip<0>(i);
+
+ ms_ = ms_ * ms_.constant(rho_scalar) +
+ grad_.square() * grad_.constant(T(1) - rho_scalar);
+ mom_ = mom_ * mom_.constant(momentum_scalar) +
+ (ms_ + ms_.constant(epsilon_scalar)).rsqrt() *
+ ms_.constant(lr_scalar) * grad_;
+
+ auto v = var_flat.template chip<0>(index);
+ v -= mom_;
+ }
+ }
+
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ }
+
+ private:
+ bool use_exclusive_lock_;
+};
+
+#define REGISTER_KERNELS(T, Tindices) \
+ REGISTER_KERNEL_BUILDER(Name("SparseApplyRMSProp") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .TypeConstraint<Tindices>("Tindices"), \
+ SparseApplyRMSPropOp<T, Tindices>);
+
+REGISTER_KERNELS(Eigen::half, int32);
+REGISTER_KERNELS(Eigen::half, int64);
+REGISTER_KERNELS(float, int32);
+REGISTER_KERNELS(float, int64);
+REGISTER_KERNELS(double, int32);
+REGISTER_KERNELS(double, int64);
+
+#undef REGISTER_KERNELS
+
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index b11264b4d4..33deb51e9c 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1715,7 +1715,7 @@ REGISTER_OP("ExtractImagePatches")
.Attr("T: realnumbertype")
.Attr(GetPaddingAttrString())
.Doc(R"doc(
-Extract `patches` from `images` and puth them in the "depth" output dimension.
+Extract `patches` from `images` and put them in the "depth" output dimension.
images: 4-D Tensor with shape `[batch, in_rows, in_cols, depth]`.
patches: 4-D Tensor with shape `[batch, out_rows, out_cols, ksize_rows *
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index 5eb011684b..eabec80c2e 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -441,6 +441,9 @@ REGISTER_OP("ApplyRMSProp")
.Attr("use_locking: bool = false")
.Doc(R"doc(
Update '*var' according to the RMSProp algorithm.
+Note that in the dense implementation of this algorithm, ms and mom will
+update even if the grad is zero, but in the sparse implementation, ms
+and mom will not update in iterations during which the grad is zero.
mean_square = decay * mean_square + (1-decay) * gradient ** 2
Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
@@ -461,5 +464,46 @@ use_locking: If `True`, updating of the var, m, and v tensors will be protected
by a lock; otherwise the behavior is undefined, but may exhibit less
contention.
)doc");
+
+REGISTER_OP("SparseApplyRMSProp")
+ .Input("var: Ref(T)")
+ .Input("ms: Ref(T)")
+ .Input("mom: Ref(T)")
+ .Input("lr: T")
+ .Input("rho: T")
+ .Input("momentum: T")
+ .Input("epsilon: T")
+ .Input("grad: T")
+ .Input("indices: Tindices")
+ .Output("out: Ref(T)")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = false")
+ .Doc(R"doc(
+Update '*var' according to the RMSProp algorithm.
+Note that in the dense implementation of this algorithm, ms and mom will
+update even if the grad is zero, but in this sparse implementation, ms
+and mom will not update in iterations during which the grad is zero.
+
+mean_square = decay * mean_square + (1-decay) * gradient ** 2
+Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
+
+ms <- rho * ms_{t-1} + (1-rho) * grad * grad
+mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
+var <- var - mom
+
+var: Should be from a Variable().
+ms: Should be from a Variable().
+mom: Should be from a Variable().
+lr: Scaling factor. Must be a scalar.
+epsilon: Ridge term. Must be a scalar.
+rho: Decay rate. Must be a scalar.
+grad: The gradient.
+indices: A vector of indices into the first dimension of var, ms and mom.
+out: Same as "var".
+use_locking: If `True`, updating of the var, ms, and mom tensors will be protected
+ by a lock; otherwise the behavior is undefined, but may exhibit less
+ contention.
+)doc");
} // namespace tensorflow
diff --git a/tensorflow/core/platform/platform.h b/tensorflow/core/platform/platform.h
index 8f71a617b6..02731d9275 100644
--- a/tensorflow/core/platform/platform.h
+++ b/tensorflow/core/platform/platform.h
@@ -37,6 +37,13 @@ limitations under the License.
#define IS_MOBILE_PLATFORM
#endif
+#elif defined(__arm__)
+#define PLATFORM_POSIX
+
+// Since there's no macro for the Raspberry Pi, assume we're on a mobile
+// platform if we're compiling for the ARM CPU.
+#define IS_MOBILE_PLATFORM
+
#else
// If no platform specified, use:
#define PLATFORM_POSIX
diff --git a/tensorflow/examples/skflow/README.md b/tensorflow/examples/skflow/README.md
index 756351c0b6..0d3d3cc1cf 100644
--- a/tensorflow/examples/skflow/README.md
+++ b/tensorflow/examples/skflow/README.md
@@ -33,6 +33,7 @@ Some examples use the `pandas` library for data processing (`sudo pip install pa
## Image classification
* [Convolutional Neural Networks on MNIST Data](mnist.py)
+* [Recurrent Neural Networks on MNIST Data](mnist_rnn.py)
* [Deep Residual Networks on MNIST Data](resnet.py)
diff --git a/tensorflow/examples/skflow/mnist_rnn.py b/tensorflow/examples/skflow/mnist_rnn.py
new file mode 100644
index 0000000000..a6a594fad5
--- /dev/null
+++ b/tensorflow/examples/skflow/mnist_rnn.py
@@ -0,0 +1,78 @@
+# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This example builds an RNN network for MNIST data.
+Borrowed structure from here: https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3%20-%20Neural%20Networks/recurrent_network.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from sklearn import metrics, preprocessing
+
+import tensorflow as tf
+from tensorflow.contrib import learn
+
+# Parameters
+learning_rate = 0.1
+training_steps = 3000
+batch_size = 128
+
+# Network Parameters
+n_input = 28 # MNIST data input (img shape: 28*28)
+n_steps = 28 # timesteps
+n_hidden = 128 # hidden layer num of features
+n_classes = 10 # MNIST total classes (0-9 digits)
+
+### Download and load MNIST data.
+mnist = learn.datasets.load_dataset('mnist')
+
+X_train = mnist.train.images
+y_train = mnist.train.labels
+X_test = mnist.test.images
+y_test = mnist.test.labels
+
+# It's useful to scale to ensure Stochastic Gradient Descent will do the right thing
+scaler = preprocessing.StandardScaler()
+X_train = scaler.fit_transform(X_train)
+X_test = scaler.fit_transform(X_test)
+
+
+def rnn_model(X, y):
+ X = tf.reshape(X, [-1, n_steps, n_input]) # (batch_size, n_steps, n_input)
+ # # permute n_steps and batch_size
+ X = tf.transpose(X, [1, 0, 2])
+ # # Reshape to prepare input to hidden activation
+ X = tf.reshape(X, [-1, n_input]) # (n_steps*batch_size, n_input)
+ # # Split data because rnn cell needs a list of inputs for the RNN inner loop
+ X = tf.split(0, n_steps, X) # n_steps * (batch_size, n_input)
+
+ # Define an LSTM cell with tensorflow
+ lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)
+ # Get lstm cell output
+ _, encoding = tf.nn.rnn(lstm_cell, X, dtype=tf.float32)
+
+ return learn.models.logistic_regression(encoding, y)
+
+
+classifier = learn.TensorFlowEstimator(model_fn=rnn_model, n_classes=n_classes,
+ batch_size=batch_size,
+ steps=training_steps,
+ learning_rate=learning_rate)
+
+classifier.fit(X_train, y_train, logdir="/tmp/mnist_rnn")
+score = metrics.accuracy_score(y_test, classifier.predict(X_test))
+print('Accuracy: {0:f}'.format(score))
diff --git a/tensorflow/g3doc/api_docs/python/contrib.metrics.md b/tensorflow/g3doc/api_docs/python/contrib.metrics.md
index 863400eaa6..9600b9a523 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.metrics.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.metrics.md
@@ -106,7 +106,7 @@ idempotent operation that simply divides `total` by `count`.
To facilitate the estimation of the accuracy over a stream of data, the
function utilizes two operations. First, an `is_correct` operation that
computes a tensor whose shape matches `predictions` and whose elements are
-set to 1.0 when the corresponding values of `predictions` and `labels match
+set to 1.0 when the corresponding values of `predictions` and `labels` match
and 0.0 otherwise. Second, an `update_op` operation whose behavior is
dependent on the value of `weights`. If `weights` is None, then `update_op`
increments `total` with the number of elements of `predictions` that match
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index 1eceeaeee5..74cdc4b518 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -62,7 +62,7 @@ Install TensorFlow:
# Ubuntu/Linux 64-bit, CPU only, Python 2.7:
$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and cuDNN v4.
# For other versions, see "Install from sources" below.
$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
@@ -77,7 +77,7 @@ For python3:
# Ubuntu/Linux 64-bit, CPU only, Python 3.4:
$ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and cuDNN v4.
# For other versions, see "Install from sources" below.
$ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
@@ -137,7 +137,7 @@ $ source ~/tensorflow/bin/activate.csh # If using csh
# Ubuntu/Linux 64-bit, CPU only, Python 2.7:
(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and cuDNN v4.
# For other versions, see "Install from sources" below.
(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
@@ -155,7 +155,7 @@ $ source ~/tensorflow/bin/activate.csh # If using csh
# Ubuntu/Linux 64-bit, CPU only, Python 3.4:
(tensorflow)$ pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and cuDNN v4.
# For other versions, see "Install from sources" below.
(tensorflow)$ pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
@@ -228,7 +228,7 @@ $ source activate tensorflow
# Ubuntu/Linux 64-bit, CPU only, Python 2.7:
(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and CuDNN v4.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7. Requires CUDA toolkit 7.5 and cuDNN v4.
# For other versions, see "Install from sources" below.
(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl
@@ -245,7 +245,7 @@ $ source activate tensorflow
# Ubuntu/Linux 64-bit, CPU only, Python 3.4:
(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
-# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and CuDNN v4.
+# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4. Requires CUDA toolkit 7.5 and cuDNN v4.
# For other versions, see "Install from sources" below.
(tensorflow)$ pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
@@ -314,7 +314,7 @@ $ docker run -it -p 8888:8888 gcr.io/tensorflow/tensorflow
The option `-p 8888:8888` is used to publish the Docker container᾿s internal port to the host machine, in this case to ensure Jupyter notebook connection.
-The format of the port mapping `hostPort:containerPort`. You can speficy any valid port number for the host port but has to be `8888` for the container port portion.
+The format of the port mapping is `hostPort:containerPort`. You can specify any valid port number for the host port but have to use `8888` for the container port portion.
If you're using a container with GPU support, some additional flags must be
passed to expose the GPU device to the container. For the default config, we
@@ -526,7 +526,7 @@ empty to use system default]: 7.5
Please specify the location where CUDA 7.5 toolkit is installed. Refer to
README.md for more details. [default is: /usr/local/cuda]: /usr/local/cuda
-Please specify the Cudnn version you want to use. [Leave empty to use system
+Please specify the cuDNN version you want to use. [Leave empty to use system
default]: 4.0.4
Please specify the location where the cuDNN 4.0.4 library is installed. Refer to
@@ -549,7 +549,7 @@ Configuration finished
This creates a canonical set of symbolic links to the Cuda libraries on your system.
Every time you change the Cuda library paths you need to run this step again before
-you invoke the bazel build command. For the Cudnn libraries, use '6.5' for R2, '7.0'
+you invoke the bazel build command. For the cuDNN libraries, use '6.5' for R2, '7.0'
for R3, and '4.0.4' for R4-RC.
@@ -672,7 +672,7 @@ GPU support will be enabled for TensorFlow
Please specify which gcc nvcc should use as the host compiler. [Default is /usr/bin/gcc]:
Please specify the Cuda SDK version you want to use, e.g. 7.0. [Leave empty to use system default]: 7.5
Please specify the location where CUDA 7.5 toolkit is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:
-Please specify the Cudnn version you want to use. [Leave empty to use system default]: 5
+Please specify the cuDNN version you want to use. [Leave empty to use system default]: 5
Please specify the location where cuDNN 5 library is installed. Refer to README.md for more details. [Default is /usr/local/cuda]:
Please specify a list of comma-separated Cuda compute capabilities you want to build with.
You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 9c9c8d4675..c9c8ecb72c 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -173,5 +173,40 @@ class ReverseTest(test_util.TensorFlowTestCase):
tf.reverse(data_2d_t, dims_3d_t)
+class MeshgridTest(test_util.TensorFlowTestCase):
+
+ def _compare(self, n, np_dtype, use_gpu):
+ inputs = []
+ for i in range(n):
+ x = np.linspace(-10, 10, 5).astype(np_dtype)
+ if np_dtype in (np.complex64, np.complex128):
+ x += 1j
+ inputs.append(x)
+
+ numpy_out = np.meshgrid(*inputs)
+ with self.test_session(use_gpu=use_gpu):
+ tf_out = array_ops.meshgrid(*inputs)
+ for X, _X in zip(numpy_out, tf_out):
+ self.assertAllEqual(X, _X.eval())
+
+ def testCompare(self):
+ for t in (np.float16, np.float32, np.float64, np.int32, np.int64,
+ np.complex64, np.complex128):
+ # Don't test the one-dimensional case, as
+ # old numpy versions don't support it
+ self._compare(2, t, False)
+ self._compare(3, t, False)
+ self._compare(4, t, False)
+ self._compare(5, t, False)
+
+ # Test for inputs with rank not equal to 1
+ x = [[1, 1], [1, 1]]
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "needs to have rank 1"):
+ with self.test_session():
+ X, _ = array_ops.meshgrid(x, x)
+ X.eval()
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 13b3180b12..fd442c6eb8 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -38,6 +38,7 @@ of a tensor and change the shape of a tensor.
@@reshape
@@squeeze
@@expand_dims
+@@meshgrid
## Slicing and Joining
@@ -125,7 +126,7 @@ def shape(input, name=None):
else:
return gen_array_ops.shape(input, name=name)
-
+
def rank(input, name=None):
"""Returns the rank of a tensor.
@@ -1047,6 +1048,82 @@ def pad(tensor, paddings, mode="CONSTANT", name=None): # pylint: disable=invali
raise ValueError("Unknown padding mode: %s" % mode)
+def meshgrid(*args, **kwargs):
+ """Broadcasts parameters for evaluation on an N-D grid.
+
+ Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
+ of N-D coordinate arrays for evaluating expressions on an N-D grid.
+
+ Notes:
+
+ `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions.
+ When the `indexing` argument is set to 'xy' (the default), the broadcasting
+ instructions for the first two dimensions are swapped.
+
+ Examples:
+
+ Calling `X, Y = meshgrid(x, y)` with the tensors
+ ```prettyprint
+ x = [1, 2, 3]
+ y = [4, 5, 6]
+ ```
+ results in
+ ```prettyprint
+ X = [[1, 2, 3],
+ [1, 2, 3],
+ [1, 2, 3]]
+ Y = [[4, 4, 4],
+ [5, 5, 5],
+ [6, 6, 6]]
+ ```
+
+ Args:
+ *args: `Tensor`s with rank 1
+ indexing: Either 'xy' or 'ij' (optional, default: 'xy')
+ name: A name for the operation (optional).
+
+ Returns:
+ outputs: A list of N `Tensor`s with rank N
+ """
+ indexing = kwargs.pop("indexing", "xy")
+ name = kwargs.pop("name", "meshgrid")
+ if len(kwargs) > 0:
+ key = list(kwargs.keys())[0]
+ raise TypeError("'{}' is an invalid keyword argument "
+ "for this function".format(key))
+
+ if indexing not in ("xy", "ij"):
+ raise ValueError("indexing parameter must be either 'xy' or 'ij'")
+
+ with ops.op_scope(args, name, "meshgrid") as name:
+ num_inputs = len(args)
+ ones = (1,) * num_inputs
+
+ asserts = [logging_ops.Assert(
+ gen_math_ops.equal(rank(x), 1),
+ ["Input %d needs to have rank 1: " % i, rank(x)],
+ ) for i, x in enumerate(args)]
+
+ # Prepare reshape by inserting dimensions with size 1 where needed
+ shapes = [ones[:i] + (-1,) + ones[i + 1:] for i in range(num_inputs)]
+ # Create parameters for broadcasting each tensor to the full size
+ sizes = [size(x) for x in args]
+ bcast = [sizes[:i] + [1] + sizes[i + 1:] for i in range(num_inputs)]
+
+ # By default, the numpy version swaps the instructions
+ # for the first and second dimension
+ if indexing == "xy" and num_inputs > 1:
+ shapes[0], shapes[1] = shapes[1], shapes[0]
+ bcast[0], bcast[1] = bcast[1], bcast[0]
+
+ results = []
+ with ops.control_dependencies(asserts):
+ for a, r, e in zip(args, shapes, bcast):
+ results.append(tile(reshape(a, r), e))
+
+ return results
+
+
@ops.RegisterShape("Placeholder")
def _PlaceholderShape(op):
given_shape = tensor_util.TensorShapeProtoToList(op.get_attr("shape"))
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index d0a36edbe7..6957745182 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -184,7 +184,7 @@ class _VariableStore(object):
If initializer is `None` (the default), the default initializer passed in
the constructor is used. If that one is `None` too, we use a new
- `UniformUnitScalingInitializer`. If initializer is a Tensor, we use
+ `uniform_unit_scaling_initializer`. If initializer is a Tensor, we use
it as a value and derive the shape from the initializer.
If the initializer is a callable, then it will be called for each
@@ -681,7 +681,7 @@ def get_variable(name, shape=None, dtype=dtypes.float32, initializer=None,
If initializer is `None` (the default), the default initializer passed in
the variable scope will be used. If that one is `None` too, a
- `UniformUnitScalingInitializer` will be used. The initializer can also be
+ `uniform_unit_scaling_initializer` will be used. The initializer can also be
a Tensor, in which case the variable is initialized to this value and shape.
Similarly, if the regularizer is `None` (the default), the default regularizer
@@ -757,7 +757,7 @@ def _get_partitioned_variable(
If initializer is `None` (the default), the default initializer passed in
the constructor is used. If that one is `None` too, we use a new
- `UniformUnitScalingInitializer`. If initializer is a Tensor, we use
+ `uniform_unit_scaling_initializer`. If initializer is a Tensor, we use
it as a value and derive the shape from the initializer.
If the initializer is a callable, then it will be called for each
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index 3bc35facf1..84ecaa5d74 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -64,6 +64,10 @@ class AdamOptimizer(optimizer.Optimizer):
general. For example, when training an Inception network on ImageNet a
current good choice is 1.0 or 0.1.
+ Note that in the dense implementation of this algorithm, m_t, v_t and the
+ variable are updated even if g is zero, but in the sparse implementation, m_t,
+ v_t and the variable are not updated in iterations where g is zero.
+
Args:
learning_rate: A Tensor or a floating point value. The learning rate.
beta1: A float value or a constant float tensor.
diff --git a/tensorflow/python/training/rmsprop.py b/tensorflow/python/training/rmsprop.py
index ad77c5ecc9..5100cedf88 100644
--- a/tensorflow/python/training/rmsprop.py
+++ b/tensorflow/python/training/rmsprop.py
@@ -57,6 +57,10 @@ class RMSPropOptimizer(optimizer.Optimizer):
name="RMSProp"):
"""Construct a new RMSProp optimizer.
+ Note that in the dense implementation of this algorithm, m_t and v_t are
+ updated even if g is zero, but in the sparse implementation, m_t and v_t
+ are not updated in iterations where g is zero.
+
Args:
learning_rate: A Tensor or a floating point value. The learning rate.
decay: Discounting factor for the history/coming gradient
@@ -105,4 +109,14 @@ class RMSPropOptimizer(optimizer.Optimizer):
grad, use_locking=self._use_locking).op
def _apply_sparse(self, grad, var):
- raise NotImplementedError()
+ rms = self.get_slot(var, "rms")
+ mom = self.get_slot(var, "momentum")
+ return training_ops.sparse_apply_rms_prop(
+ var, rms, mom,
+ math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
+ math_ops.cast(self._decay_tensor, var.dtype.base_dtype),
+ math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
+ math_ops.cast(self._epsilon_tensor, var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py
index 541b3e0942..499e452d90 100644
--- a/tensorflow/python/training/rmsprop_test.py
+++ b/tensorflow/python/training/rmsprop_test.py
@@ -26,6 +26,131 @@ import tensorflow as tf
class RMSPropOptimizerTest(tf.test.TestCase):
+ def _rmsprop_update_numpy(self, var, g, rms, mom, lr, decay, momentum,
+ epsilon):
+ rms_t = rms * decay + (1-decay) * g * g
+ mom_t = momentum * mom + lr * g / np.sqrt(rms_t + epsilon)
+ var_t = var - mom_t
+ return var_t, rms_t, mom_t
+
+ def testSparseWithMomentum(self):
+ for dtype in [tf.half, tf.float32]:
+ with self.test_session():
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = tf.Variable(var0_np)
+ var1 = tf.Variable(var1_np)
+ grads0_np_indices = np.array([0, 1], dtype=np.int32)
+ grads0 = tf.IndexedSlices(tf.constant(grads0_np),
+ tf.constant(grads0_np_indices),
+ tf.constant([2]))
+ grads1_np_indices = np.array([0, 1], dtype=np.int32)
+ grads1 = tf.IndexedSlices(tf.constant(grads1_np),
+ tf.constant(grads1_np_indices),
+ tf.constant([2]))
+ opt = tf.train.RMSPropOptimizer(learning_rate=2.0, decay=0.9,
+ momentum=0.5, epsilon=1e-5)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertTrue(rms0 is not None)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertTrue(rms1 is not None)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertTrue(mom0 is not None)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertTrue(mom1 is not None)
+
+ rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 4 steps of RMSProp
+ for t in range(1, 5):
+ update.run()
+
+ var0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(var0_np,
+ grads0_np, rms0_np, mom0_np, 2.0, 0.9, 0.5, 1e-5)
+ var1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(var1_np,
+ grads1_np, rms1_np, mom1_np, 2.0, 0.9, 0.5, 1e-5)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSparseWithoutMomentum(self):
+ for dtype in [tf.half, tf.float32]:
+ with self.test_session():
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = tf.Variable(var0_np)
+ var1 = tf.Variable(var1_np)
+ grads0_np_indices = np.array([0, 1], dtype=np.int32)
+ grads0 = tf.IndexedSlices(tf.constant(grads0_np),
+ tf.constant(grads0_np_indices),
+ tf.constant([2]))
+ grads1_np_indices = np.array([0, 1], dtype=np.int32)
+ grads1 = tf.IndexedSlices(tf.constant(grads1_np),
+ tf.constant(grads1_np_indices),
+ tf.constant([2]))
+ opt = tf.train.RMSPropOptimizer(learning_rate=2.0, decay=0.9,
+ momentum=0.0, epsilon=1.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertTrue(rms0 is not None)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertTrue(rms1 is not None)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertTrue(mom0 is not None)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertTrue(mom1 is not None)
+
+ rms0_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([1.0, 1.0], dtype=dtype.as_numpy_dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 4 steps of RMSProp
+ for t in range(1, 5):
+ update.run()
+
+ var0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(var0_np,
+ grads0_np, rms0_np, mom0_np, 2.0, 0.9, 0.0, 1.0)
+ var1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(var1_np,
+ grads1_np, rms1_np, mom1_np, 2.0, 0.9, 0.0, 1.0)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
def testWithoutMomentum(self):
for dtype in [tf.half, tf.float32]:
with self.test_session():
diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py
index 8619752338..1a96f77c1c 100644
--- a/tensorflow/python/training/training_ops.py
+++ b/tensorflow/python/training/training_ops.py
@@ -170,6 +170,23 @@ def _SparseApplyProximalGradientDescentShape(op):
return [var_shape]
+@ops.RegisterShape("SparseApplyRMSProp")
+def _SparseApplyRMSPropShape(op):
+ """Shape function for the SparseApplyRMSProp op."""
+ var_shape = op.inputs[0].get_shape()
+ ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
+ _AssertInputIsScalar(op, 3) # lr
+ _AssertInputIsScalar(op, 4) # rho
+ _AssertInputIsScalar(op, 5) # momentum
+ _AssertInputIsScalar(op, 6) # epsilon
+ grad_shape = op.inputs[7].get_shape().merge_with(
+ tensor_shape.TensorShape([None]).concatenate(mom_shape[1:]))
+ unused_indices_shape = op.inputs[8].get_shape().merge_with(
+ tensor_shape.vector(grad_shape[0]))
+ return [mom_shape]
+
+
@ops.RegisterShape("SparseApplyAdadelta")
def _SparseApplyAdadeltaShape(op):
"""Shape function for the SparseApplyAdadelta op."""
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 224803fc84..a9dd2953e5 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -25,6 +25,12 @@ limitations under the License.
#define EIGEN_HAS_CUDA_FP16
#endif
+#if CUDA_VERSION >= 8000
+#define SE_CUDA_DATA_HALF CUDA_R_16F
+#else
+#define SE_CUDA_DATA_HALF CUBLAS_DATA_HALF
+#endif
+
#include "tensorflow/stream_executor/cuda/cuda_blas.h"
#include <dlfcn.h>
@@ -1680,10 +1686,10 @@ bool CUDABlas::DoBlasGemm(
return DoBlasInternal(
dynload::cublasSgemmEx, stream, true /* = pointer_mode_host */,
CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
- CUDAMemory(a), CUBLAS_DATA_HALF, lda,
- CUDAMemory(b), CUBLAS_DATA_HALF, ldb,
+ CUDAMemory(a), SE_CUDA_DATA_HALF, lda,
+ CUDAMemory(b), SE_CUDA_DATA_HALF, ldb,
&beta,
- CUDAMemoryMutable(c), CUBLAS_DATA_HALF, ldc);
+ CUDAMemoryMutable(c), SE_CUDA_DATA_HALF, ldc);
#else
LOG(ERROR) << "fp16 sgemm is not implemented in this cuBLAS version "
<< "(need at least CUDA 7.5)";
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts
index 76f38814f2..299e536393 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts
@@ -267,8 +267,11 @@ export class Minimap {
downloadContext.drawImage(image, 0, 0,
this.downloadCanvas.width, this.downloadCanvas.height);
};
- let blob = new Blob([svgXml], {type: 'image/svg+xml;charset=utf-8'});
- image.src = URL.createObjectURL(blob);
+ image.onerror = () => {
+ let blob = new Blob([svgXml], {type: 'image/svg+xml;charset=utf-8'});
+ image.src = URL.createObjectURL(blob);
+ }
+ image.src = 'data:image/svg+xml;charset=utf-8,' + encodeURIComponent(svgXml);
}
/**
diff --git a/tensorflow/tensorboard/dist/tf-tensorboard.html b/tensorflow/tensorboard/dist/tf-tensorboard.html
index 014061a807..25048790ae 100644
--- a/tensorflow/tensorboard/dist/tf-tensorboard.html
+++ b/tensorflow/tensorboard/dist/tf-tensorboard.html
@@ -9987,8 +9987,11 @@ var tf;
downloadContext.clearRect(0, 0, _this.downloadCanvas.width, _this.downloadCanvas.height);
downloadContext.drawImage(image, 0, 0, _this.downloadCanvas.width, _this.downloadCanvas.height);
};
- var blob = new Blob([svgXml], { type: 'image/svg+xml;charset=utf-8' });
- image.src = URL.createObjectURL(blob);
+ image.onerror = function() {
+ var blob = new Blob([svgXml], {type: "image/svg+xml;charset=utf-8"});
+ image.src = URL.createObjectURL(blob);
+ };
+ image.src = "data:image/svg+xml;charset=utf-8," + encodeURIComponent(svgXml);
};
/**
* Handles changes in zooming/panning. Should be called from the main svg
diff --git a/tensorflow/tools/docker/docker_run_gpu.sh b/tensorflow/tools/docker/docker_run_gpu.sh
index 2fd885802b..08f391ddf9 100755
--- a/tensorflow/tools/docker/docker_run_gpu.sh
+++ b/tensorflow/tools/docker/docker_run_gpu.sh
@@ -14,16 +14,8 @@
# limitations under the License.
# ==============================================================================
-
set -e
-export CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
-
-if [ ! -d ${CUDA_HOME}/lib64 ]; then
- echo "Failed to locate CUDA libs at ${CUDA_HOME}/lib64."
- exit 1
-fi
-
export CUDA_SO=$(\ls /usr/lib/x86_64-linux-gnu/libcuda.* | \
xargs -I{} echo '-v {}:{}')
export DEVICES=$(\ls /dev/nvidia* | \