aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-05-26 11:05:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-26 12:18:27 -0700
commit8cc567bf9703e14f0e16595eb9f220859a5eff20 (patch)
tree13694c01d607a2d2d2bad57ae533894214eee7f5
parent9a69f398e9958d975db0c651e9ec95762b2ff8b4 (diff)
Merge changes from github.
Change: 123342870
-rw-r--r--WORKSPACE2
-rw-r--r--tensorflow/contrib/util/BUILD4
-rw-r--r--tensorflow/core/BUILD4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/BUILD4
-rw-r--r--tensorflow/core/kernels/BUILD12
-rw-r--r--tensorflow/core/kernels/matmul_op.cc2
-rw-r--r--tensorflow/core/kernels/transpose_functor_gpu.cu.cc4
-rw-r--r--tensorflow/core/ops/compat/ops_history.v0.pbtxt88
-rw-r--r--tensorflow/core/ops/math_ops.cc34
-rw-r--r--tensorflow/examples/image_retraining/retrain.py6
-rw-r--r--tensorflow/g3doc/api_docs/python/constant_op.md4
-rw-r--r--tensorflow/g3doc/api_docs/python/nn.md6
-rw-r--r--tensorflow/g3doc/get_started/os_setup.md2
-rwxr-xr-xtensorflow/g3doc/tutorials/pdes/index.md10
-rw-r--r--tensorflow/models/rnn/translate/seq2seq_model.py14
-rw-r--r--tensorflow/python/framework/random_seed.py10
-rw-r--r--tensorflow/python/kernel_tests/matmul_op_test.py3
-rw-r--r--tensorflow/python/kernel_tests/random_ops_test.py7
-rw-r--r--tensorflow/python/kernel_tests/transpose_op_test.py11
-rw-r--r--tensorflow/python/ops/seq2seq.py15
-rw-r--r--tensorflow/python/training/adagrad_test.py45
-rw-r--r--tensorflow/python/training/adam_test.py41
-rw-r--r--tensorflow/python/training/gradient_descent_test.py45
-rw-r--r--tensorflow/python/training/learning_rate_decay.py60
-rw-r--r--tensorflow/python/training/learning_rate_decay_test.py35
-rw-r--r--tensorflow/python/training/momentum_test.py41
-rw-r--r--tensorflow/python/training/optimizer.py2
-rw-r--r--tensorflow/python/training/optimizer_test.py10
-rw-r--r--tensorflow/stream_executor/dso_loader.cc4
-rw-r--r--tensorflow/tensorflow.bzl21
-rwxr-xr-xtensorflow/tools/docs/gen_docs_test.sh7
-rwxr-xr-xtensorflow/tools/pip_package/build_pip_package.sh37
-rwxr-xr-xutil/python/python_config.sh15
33 files changed, 282 insertions, 323 deletions
diff --git a/WORKSPACE b/WORKSPACE
index 1156a45a39..ffebcde554 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,3 +1,5 @@
+workspace(name = "org_tensorflow")
+
# Uncomment and update the paths in these entries to build the Android demo.
#android_sdk_repository(
# name = "androidsdk",
diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD
index 80495c9b8a..7683cda797 100644
--- a/tensorflow/contrib/util/BUILD
+++ b/tensorflow/contrib/util/BUILD
@@ -36,6 +36,10 @@ cc_binary(
cc_test(
name = "convert_graphdef_memmapped_format_test",
srcs = ["convert_graphdef_memmapped_format_test.cc"],
+ linkopts = select({
+ "//tensorflow:darwin": ["-headerpad_max_install_names"],
+ "//conditions:default": [],
+ }),
deps = [
":convert_graphdef_memmapped_format_lib",
"//tensorflow/cc:cc_ops",
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index aa226886a8..81e2d08536 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1231,6 +1231,10 @@ cc_test(
# higher level tests
tf_cc_tests(
size = "small",
+ linkopts = select({
+ "//tensorflow:darwin": ["-headerpad_max_install_names"],
+ "//conditions:default": [],
+ }),
linkstatic = tf_kernel_tests_linkstatic(),
tests = glob(
[
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index f9201eeacf..8b7f08bc8a 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -324,6 +324,10 @@ cc_library(
tf_cuda_cc_tests(
size = "small",
+ linkopts = select({
+ "//tensorflow:darwin": ["-headerpad_max_install_names"],
+ "//conditions:default": [],
+ }),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + [],
tests = [
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index cb3b0a536d..06b7486997 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -473,6 +473,10 @@ tf_cc_test(
tf_cc_test(
name = "slice_op_test",
size = "small",
+ linkopts = select({
+ "//tensorflow:darwin": ["-headerpad_max_install_names"],
+ "//conditions:default": [],
+ }),
deps = [
":ops_testutil",
":ops_util",
@@ -768,6 +772,10 @@ tf_cc_tests(
)
tf_cc_tests(
+ linkopts = select({
+ "//tensorflow:darwin": ["-headerpad_max_install_names"],
+ "//conditions:default": [],
+ }),
tests = [
"adjust_contrast_op_test",
"colorspace_op_test",
@@ -1058,6 +1066,10 @@ tf_cuda_cc_test(
tf_cuda_cc_test(
name = "reduction_ops_test",
size = "small",
+ linkopts = select({
+ "//tensorflow:darwin": ["-headerpad_max_install_names"],
+ "//conditions:default": [],
+ }),
deps = [
":ops_testutil",
":ops_util",
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index f6420aaad8..ce57f61d43 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -210,7 +210,7 @@ REGISTER_CPU(complex64);
REGISTER_CPU(complex128);
#if GOOGLE_CUDA
REGISTER_GPU(float);
-// REGISTER_GPU(double);
+REGISTER_GPU(double);
#if CUDA_VERSION >= 7050
REGISTER_GPU(Eigen::half);
#endif
diff --git a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
index 3febba0441..8e6cbc4270 100644
--- a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
@@ -130,6 +130,10 @@ Status DoTranspose<Device>(const Device& d, const Tensor& in,
internal::Transpose<Device, uint64>(d, in, perm, out);
break;
+ case DT_COMPLEX128:
+ internal::Transpose<Device, float4>(d, in, perm, out);
+ break;
+
default:
return errors::Unimplemented("Unsupported dtype on GPU: ", in.dtype());
}
diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
index 0a15d861c2..f44b7ea05e 100644
--- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
@@ -6366,9 +6366,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -6391,9 +6389,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -6416,8 +6412,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@@ -6993,9 +6987,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -7018,9 +7009,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -7043,10 +7031,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
- type: DT_COMPLEX64
- type: DT_COMPLEX128
}
}
}
@@ -7762,9 +7746,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -7787,9 +7768,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -7812,10 +7790,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
- type: DT_COMPLEX64
- type: DT_COMPLEX128
}
}
}
@@ -7837,9 +7811,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -7862,9 +7833,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -7887,10 +7855,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
- type: DT_COMPLEX64
- type: DT_COMPLEX128
}
}
}
@@ -7927,9 +7891,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -7952,9 +7914,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -7977,8 +7937,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@@ -10263,9 +10221,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -10288,9 +10243,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -10313,10 +10265,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
- type: DT_COMPLEX64
- type: DT_COMPLEX128
}
}
}
@@ -10390,9 +10338,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -10415,9 +10361,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -10440,8 +10384,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@@ -16137,9 +16079,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -16162,9 +16102,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -16187,8 +16125,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@@ -17916,9 +17852,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -17941,9 +17875,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -17966,8 +17898,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@@ -18089,9 +18019,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -18114,9 +18042,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -18139,8 +18065,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@@ -20897,9 +20821,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -20922,9 +20844,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -20947,8 +20867,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
@@ -21696,9 +21614,7 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -21721,9 +21637,7 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
type: DT_COMPLEX64
- type: DT_INT64
}
}
}
@@ -21746,8 +21660,6 @@ op {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
}
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index fdb490df9e..15039911d3 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -136,6 +136,14 @@ tf.complex_abs(x) ==> [5.25594902, 6.60492229]
Input("x: T").Output("y: T").Attr( \
"T: {half, float, double, int32, int64, complex64, complex128}")
+#define UNARY_REAL() \
+ Input("x: T").Output("y: T").Attr( \
+ "T: {half, float, double}")
+
+#define UNARY_COMPLEX() \
+ Input("x: T").Output("y: T").Attr( \
+ "T: {half, float, double, complex64, complex128}")
+
REGISTER_OP("Neg")
.UNARY()
.Doc(R"doc(
@@ -158,65 +166,65 @@ I.e., \\(y = x * x = x^2\\).
)doc");
REGISTER_OP("Sqrt")
- .UNARY()
+ .UNARY_COMPLEX()
.Doc(R"doc(
Computes square root of x element-wise.
I.e., \\(y = \sqrt{x} = x^{1/2}\\).
)doc");
REGISTER_OP("Rsqrt")
- .UNARY()
+ .UNARY_COMPLEX()
.Doc(R"doc(
Computes reciprocal of square root of x element-wise.
I.e., \\(y = 1 / \sqrt{x}\\).
)doc");
REGISTER_OP("Exp")
- .UNARY()
+ .UNARY_COMPLEX()
.Doc(R"doc(
Computes exponential of x element-wise. \\(y = e^x\\).
)doc");
REGISTER_OP("Log")
- .UNARY()
+ .UNARY_COMPLEX()
.Doc(R"doc(
Computes natural logarithm of x element-wise.
I.e., \\(y = \log_e x\\).
)doc");
REGISTER_OP("Tanh")
- .UNARY()
+ .UNARY_COMPLEX()
.Doc(R"doc(
Computes hyperbolic tangent of `x` element-wise.
)doc");
REGISTER_OP("Lgamma")
- .UNARY()
+ .UNARY_REAL()
.Doc(R"doc(
Computes the log of the absolute value of `Gamma(x)` element-wise.
)doc");
REGISTER_OP("Digamma")
- .UNARY()
+ .UNARY_REAL()
.Doc(R"doc(
Computes Psi, the derivative of Lgamma (the log of the absolute value of
`Gamma(x)`), element-wise.
)doc");
REGISTER_OP("Erf")
- .UNARY()
+ .UNARY_REAL()
.Doc(R"doc(
Computes the Gauss error function of `x` element-wise.
)doc");
REGISTER_OP("Erfc")
- .UNARY()
+ .UNARY_REAL()
.Doc(R"doc(
Computes the complementary error function of `x` element-wise.
)doc");
REGISTER_OP("Sigmoid")
- .UNARY()
+ .UNARY_COMPLEX()
.Doc(R"doc(
Computes sigmoid of `x` element-wise.
@@ -224,18 +232,20 @@ Specifically, `y = 1 / (1 + exp(-x))`.
)doc");
REGISTER_OP("Sin")
- .UNARY()
+ .UNARY_COMPLEX()
.Doc(R"doc(
Computes sin of x element-wise.
)doc");
REGISTER_OP("Cos")
- .UNARY()
+ .UNARY_COMPLEX()
.Doc(R"doc(
Computes cos of x element-wise.
)doc");
#undef UNARY
+#undef UNARY_REAL
+#undef UNARY_COMPLEX
REGISTER_OP("IsNan")
.Input("x: T")
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 9ed8dbf13e..8f5f9c635c 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -521,13 +521,13 @@ def get_random_distorted_bottlenecks(
ground_truths = []
for unused_i in range(how_many):
label_index = random.randrange(class_count)
- label_name = image_lists.keys()[label_index]
+ label_name = list(image_lists.keys())[label_index]
image_index = random.randrange(65536)
image_path = get_image_path(image_lists, label_name, image_index, image_dir,
category)
if not gfile.Exists(image_path):
tf.logging.fatal('File does not exist %s', image_path)
- jpeg_data = gfile.FastGFile(image_path, 'r').read()
+ jpeg_data = gfile.FastGFile(image_path, 'rb').read()
# Note that we materialize the distorted_image_data as a numpy array before
# sending running inference on the image. This involves 2 memory copies and
# might be optimized in other implementations.
@@ -616,7 +616,7 @@ def add_input_distortions(flip_left_right, random_crop, random_scale,
"""
jpeg_data = tf.placeholder(tf.string, name='DistortJPGInput')
- decoded_image = tf.image.decode_jpeg(jpeg_data)
+ decoded_image = tf.image.decode_jpeg(jpeg_data, channels=MODEL_INPUT_DEPTH)
decoded_image_as_float = tf.cast(decoded_image, dtype=tf.float32)
decoded_image_4d = tf.expand_dims(decoded_image_as_float, 0)
margin_scale = 1.0 + (random_crop / 100.0)
diff --git a/tensorflow/g3doc/api_docs/python/constant_op.md b/tensorflow/g3doc/api_docs/python/constant_op.md
index 008174f9d6..1aaf39bd50 100644
--- a/tensorflow/g3doc/api_docs/python/constant_op.md
+++ b/tensorflow/g3doc/api_docs/python/constant_op.md
@@ -60,7 +60,7 @@ tf.zeros_like(tensor) ==> [[0, 0, 0], [0, 0, 0]]
* <b>`tensor`</b>: A `Tensor`.
* <b>`dtype`</b>: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
+ `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64`, or `complex128`.
* <b>`name`</b>: A name for the operation (optional).
@@ -119,7 +119,7 @@ tf.ones_like(tensor) ==> [[1, 1, 1], [1, 1, 1]]
* <b>`tensor`</b>: A `Tensor`.
* <b>`dtype`</b>: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `int16`, `int32`, `int64`, `uint8`, or `complex64`.
+ `int8`, `int16`, `int32`, `int64`, `uint8`, `complex64` or `complex128`.
* <b>`name`</b>: A name for the operation (optional).
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index e92d5b9410..9405ed74d1 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -163,7 +163,7 @@ case where both types are quantized.
* <b>`value`</b>: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
- `int16`, `int8`, or `complex64`.
+ `int16`, `int8`, `complex64` or `complex128`.
* <b>`bias`</b>: A 1-D `Tensor` with size matching the last dimension of `value`.
Must be the same type as `value` unless `value` is a quantized type,
in which case a different quantized type may be used.
@@ -186,7 +186,7 @@ Specifically, `y = 1 / (1 + exp(-x))`.
##### Args:
-* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `complex128`, `int64`,
or `qint32`.
* <b>`name`</b>: A name for the operation (optional).
@@ -205,7 +205,7 @@ Computes hyperbolic tangent of `x` element-wise.
##### Args:
-* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
+* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `complex128`, `int64`,
or `qint32`.
* <b>`name`</b>: A name for the operation (optional).
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index bdb94cda10..0d7b851bac 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -658,7 +658,7 @@ bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_packag
mkdir _python_build
cd _python_build
-ln -s ../bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/* .
+ln -s ../bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow* .
ln -s ../tensorflow/tools/pip_package/* .
python setup.py develop
```
diff --git a/tensorflow/g3doc/tutorials/pdes/index.md b/tensorflow/g3doc/tutorials/pdes/index.md
index ca24034759..9bc4340285 100755
--- a/tensorflow/g3doc/tutorials/pdes/index.md
+++ b/tensorflow/g3doc/tutorials/pdes/index.md
@@ -19,7 +19,7 @@ import numpy as np
#Imports for visualization
import PIL.Image
-from cStringIO import StringIO
+from io import BytesIO
from IPython.display import clear_output, Image, display
```
@@ -30,8 +30,9 @@ def DisplayArray(a, fmt='jpeg', rng=[0,1]):
"""Display an array as a picture."""
a = (a - rng[0])/float(rng[1] - rng[0])*255
a = np.uint8(np.clip(a, 0, 255))
- f = StringIO()
+ f = BytesIO()
PIL.Image.fromarray(a).save(f, fmt)
+ clear_output(wait = True)
display(Image(data=f.getvalue()))
```
@@ -132,10 +133,7 @@ tf.initialize_all_variables().run()
for i in range(1000):
# Step simulation
step.run({eps: 0.03, damping: 0.04})
- # Visualize every 50 steps
- if i % 50 == 0:
- clear_output()
- DisplayArray(U.eval(), rng=[-0.1, 0.1])
+ DisplayArray(U.eval(), rng=[-0.1, 0.1])
```
![jpeg](../../images/pde_output_2.jpg)
diff --git a/tensorflow/models/rnn/translate/seq2seq_model.py b/tensorflow/models/rnn/translate/seq2seq_model.py
index b0d8ff43db..a921f28c06 100644
--- a/tensorflow/models/rnn/translate/seq2seq_model.py
+++ b/tensorflow/models/rnn/translate/seq2seq_model.py
@@ -83,17 +83,15 @@ class Seq2SeqModel(object):
softmax_loss_function = None
# Sampled softmax only makes sense if we sample less than vocabulary size.
if num_samples > 0 and num_samples < self.target_vocab_size:
- with tf.device("/cpu:0"):
- w = tf.get_variable("proj_w", [size, self.target_vocab_size])
- w_t = tf.transpose(w)
- b = tf.get_variable("proj_b", [self.target_vocab_size])
+ w = tf.get_variable("proj_w", [size, self.target_vocab_size])
+ w_t = tf.transpose(w)
+ b = tf.get_variable("proj_b", [self.target_vocab_size])
output_projection = (w, b)
def sampled_loss(inputs, labels):
- with tf.device("/cpu:0"):
- labels = tf.reshape(labels, [-1, 1])
- return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
- self.target_vocab_size)
+ labels = tf.reshape(labels, [-1, 1])
+ return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
+ self.target_vocab_size)
softmax_loss_function = sampled_loss
# Create the internal multi-layer cell for our RNN.
diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py
index b70f626a9e..9f503b8f29 100644
--- a/tensorflow/python/framework/random_seed.py
+++ b/tensorflow/python/framework/random_seed.py
@@ -24,6 +24,10 @@ from tensorflow.python.framework import ops
_DEFAULT_GRAPH_SEED = 87654321
+_MAXINT32 = 2**31 - 1
+
+def _truncate_seed(seed):
+ return seed % _MAXINT32 # truncate to fit into 32-bit integer
def get_seed(op_seed):
@@ -47,12 +51,12 @@ def get_seed(op_seed):
graph_seed = ops.get_default_graph().seed
if graph_seed is not None:
if op_seed is not None:
- return graph_seed, op_seed
+ return _truncate_seed(graph_seed), _truncate_seed(op_seed)
else:
- return graph_seed, ops.get_default_graph()._last_id
+ return _truncate_seed(graph_seed), _truncate_seed(ops.get_default_graph()._last_id)
else:
if op_seed is not None:
- return _DEFAULT_GRAPH_SEED, op_seed
+ return _truncate_seed(_DEFAULT_GRAPH_SEED), _truncate_seed(op_seed)
else:
return None, None
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index 25553097a6..6c817a5da8 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -95,6 +95,7 @@ class MatMulTest(tf.test.TestCase):
x = np.arange(1., 5.).reshape([4, 1]).astype(np.float64)
y = np.arange(1., 3.).reshape([1, 2]).astype(np.float64)
self._testCpuMatmul(x, y)
+ self._testGpuMatmul(x, y)
def testHalfBasic(self):
x = np.arange(1., 5.).reshape([4, 1]).astype(np.float16)
@@ -135,6 +136,7 @@ class MatMulTest(tf.test.TestCase):
x = self._randMatrix(n, k, np.float64)
y = self._randMatrix(k, m, np.float64)
self._testCpuMatmul(x, y)
+ self._testGpuMatmul(x, y)
def testHalfRandom(self):
for _ in range(10):
@@ -185,6 +187,7 @@ class MatMulTest(tf.test.TestCase):
x = self._randMatrix(k, n, np.float64)
y = self._randMatrix(m, k, np.float64)
self._testCpuMatmul(x, y, True, True)
+ self._testGpuMatmul(x, y, True, True)
def testHalfRandomTransposeBoth(self):
for _ in range(10):
diff --git a/tensorflow/python/kernel_tests/random_ops_test.py b/tensorflow/python/kernel_tests/random_ops_test.py
index 45b61be0c3..f4ed26b1e2 100644
--- a/tensorflow/python/kernel_tests/random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random_ops_test.py
@@ -237,9 +237,10 @@ class RandomUniformTest(tf.test.TestCase):
def testSeed(self):
for use_gpu in False, True:
for dt in tf.float16, tf.float32, tf.float64, tf.int32, tf.int64:
- sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
- sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=345)
- self.assertAllEqual(sx(), sy())
+ for seed in [345, 2**100, -2**100]:
+ sx = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
+ sy = self._Sampler(1000, 0, 17, dtype=dt, use_gpu=use_gpu, seed=seed)
+ self.assertAllEqual(sx(), sy())
def testNoCSE(self):
shape = [2, 3, 4]
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 7881fe152d..e4ce40303e 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -131,7 +131,7 @@ class TransposeTest(tf.test.TestCase):
self._compare_cpu_gpu(
np.arange(0, 16).reshape([1, 2, 1, 2, 1, 2, 1, 2]).astype(np.float64))
- def testSComplex(self):
+ def testComplex64(self):
self._testBoth(np.complex(1, 2) *
np.arange(0, 21).reshape([3, 7]).astype(np.complex64))
self._testBoth(np.complex(1, 2) *
@@ -140,6 +140,15 @@ class TransposeTest(tf.test.TestCase):
np.complex(1, 2) *
np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(np.complex64))
+ def testComplex128(self):
+ self._testBoth(np.complex(1, 2) *
+ np.arange(0, 21).reshape([3, 7]).astype(np.complex128))
+ self._testBoth(np.complex(1, 2) *
+ np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.complex128))
+ self._testBoth(
+ np.complex(1, 2) *
+ np.arange(0, 1260).reshape([2, 3, 5, 7, 2, 3]).astype(np.complex128))
+
def testInt8(self):
self._testBoth(np.arange(0, 21).reshape([3, 7]).astype(np.int8))
self._testBoth(np.arange(0, 210).reshape([2, 3, 5, 7]).astype(np.int8))
diff --git a/tensorflow/python/ops/seq2seq.py b/tensorflow/python/ops/seq2seq.py
index e920c95dec..cb1773c7f1 100644
--- a/tensorflow/python/ops/seq2seq.py
+++ b/tensorflow/python/ops/seq2seq.py
@@ -260,9 +260,8 @@ def embedding_rnn_decoder(decoder_inputs, initial_state, cell, num_symbols,
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
with variable_scope.variable_scope(scope or "embedding_rnn_decoder"):
- with ops.device("/cpu:0"):
- embedding = variable_scope.get_variable("embedding",
- [num_symbols, embedding_size])
+ embedding = variable_scope.get_variable("embedding",
+ [num_symbols, embedding_size])
loop_function = _extract_argmax_and_embed(
embedding, output_projection,
update_embedding_for_previous) if feed_previous else None
@@ -398,9 +397,8 @@ def embedding_tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell,
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
with variable_scope.variable_scope(scope or "embedding_tied_rnn_seq2seq"):
- with ops.device("/cpu:0"):
- embedding = variable_scope.get_variable("embedding",
- [num_symbols, embedding_size])
+ embedding = variable_scope.get_variable("embedding",
+ [num_symbols, embedding_size])
emb_encoder_inputs = [embedding_ops.embedding_lookup(embedding, x)
for x in encoder_inputs]
@@ -636,9 +634,8 @@ def embedding_attention_decoder(decoder_inputs, initial_state, attention_states,
proj_biases.get_shape().assert_is_compatible_with([num_symbols])
with variable_scope.variable_scope(scope or "embedding_attention_decoder"):
- with ops.device("/cpu:0"):
- embedding = variable_scope.get_variable("embedding",
- [num_symbols, embedding_size])
+ embedding = variable_scope.get_variable("embedding",
+ [num_symbols, embedding_size])
loop_function = _extract_argmax_and_embed(
embedding, output_projection,
update_embedding_for_previous) if feed_previous else None
diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py
index c5e118bdc7..4125a7aa3c 100644
--- a/tensorflow/python/training/adagrad_test.py
+++ b/tensorflow/python/training/adagrad_test.py
@@ -25,7 +25,7 @@ import tensorflow as tf
class AdagradOptimizerTest(tf.test.TestCase):
def doTestBasic(self, use_locking=False):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -56,7 +56,7 @@ class AdagradOptimizerTest(tf.test.TestCase):
self.doTestBasic(use_locking=True)
def testTensorLearningRate(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -80,43 +80,8 @@ class AdagradOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType(
np.array([2.715679168701172, 3.715679168701172]), var1.eval())
- def testFloat64(self):
- with self.test_session():
- opt = tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
-
- # compute_gradients.
- values = [1.0, 3.0]
- good_vars = [tf.Variable([v]) for v in values]
- bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
- opt.compute_gradients, bad_loss, good_vars)
- bad_vars = [
- tf.Variable(np.array([v], np.float64), name="bad_var")
- for v in values
- ]
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
- opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
- bad_vars)
- opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
-
- # apply_gradients.
- bad_grads = [
- tf.constant([0.1], dtype=np.float64, name="bad_grad"),
- tf.constant([0.01])
- ]
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
- opt.apply_gradients, zip(bad_grads, good_vars))
- good_grads = [tf.constant([0.01]), tf.constant([0.02])]
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
- opt.apply_gradients, zip(good_grads, bad_vars))
- opt.apply_gradients(zip(good_grads, good_vars))
-
def testSparseBasic(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([[1.0], [2.0]], dtype=dtype)
var1 = tf.Variable([[3.0], [4.0]], dtype=dtype)
@@ -145,7 +110,7 @@ class AdagradOptimizerTest(tf.test.TestCase):
np.array([[3.0], [3.715679168701172]]), var1.eval())
def testSparseStability(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
shape = [1, 6]
var0 = tf.Variable(
@@ -175,7 +140,7 @@ class AdagradOptimizerTest(tf.test.TestCase):
0.0144573, -0.01029443]]), var0.eval())
def testSharing(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index ce6ec99f84..df97b1be50 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -36,7 +36,7 @@ def adam_update_numpy(param, g_t, t, m, v, alpha=0.001, beta1=0.9, beta2=0.999,
class AdamOptimizerTest(tf.test.TestCase):
def testSparse(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
@@ -79,7 +79,7 @@ class AdamOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType(var1_np, var1.eval())
def testBasic(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
@@ -116,7 +116,7 @@ class AdamOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType(var1_np, var1.eval())
def testTensorLearningRate(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
@@ -152,41 +152,8 @@ class AdamOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType(var0_np, var0.eval())
self.assertAllCloseAccordingToType(var1_np, var1.eval())
- def testFloat64(self):
- with self.test_session():
- opt = tf.train.AdamOptimizer()
-
- # compute_gradients.
- values = [1.0, 3.0]
- good_vars = [tf.Variable([v]) for v in values]
- bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
- opt.compute_gradients, bad_loss, good_vars)
- bad_vars = [
- tf.Variable(np.array([v], np.float64), name="bad_var")
- for v in values]
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
- opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
- bad_vars)
- opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
-
- # apply_gradients.
- bad_grads = [
- tf.constant([0.1], dtype=np.float64, name="bad_grad"),
- tf.constant([0.01])]
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
- opt.apply_gradients, zip(bad_grads, good_vars))
- good_grads = [tf.constant([0.01]), tf.constant([0.02])]
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
- opt.apply_gradients, zip(good_grads, bad_vars))
- opt.apply_gradients(zip(good_grads, good_vars))
-
def testSharing(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index 5161963032..26e1ae7dee 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -25,7 +25,7 @@ import tensorflow as tf
class GradientDescentOptimizerTest(tf.test.TestCase):
def testBasic(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -46,7 +46,7 @@ class GradientDescentOptimizerTest(tf.test.TestCase):
[3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval())
def testTensorLearningRate(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -67,43 +67,8 @@ class GradientDescentOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType(
[3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval())
- def testFloat64(self):
- with self.test_session():
- opt = tf.train.GradientDescentOptimizer(3.0)
-
- # compute_gradients.
- values = [1.0, 3.0]
- good_vars = [tf.Variable([v]) for v in values]
- bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
- opt.compute_gradients, bad_loss, good_vars)
- bad_vars = [
- tf.Variable(np.array([v], np.float64), name="bad_var")
- for v in values
- ]
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
- opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
- bad_vars)
- opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
-
- # apply_gradients.
- bad_grads = [
- tf.constant([0.1], dtype=np.float64, name="bad_grad"),
- tf.constant([0.01])
- ]
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
- opt.apply_gradients, zip(bad_grads, good_vars))
- good_grads = [tf.constant([0.01]), tf.constant([0.02])]
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
- opt.apply_gradients, zip(good_grads, bad_vars))
- opt.apply_gradients(zip(good_grads, good_vars))
-
def testGradWrtRef(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
opt = tf.train.GradientDescentOptimizer(3.0)
values = [1.0, 3.0]
@@ -114,7 +79,7 @@ class GradientDescentOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType([1.0], grad.eval())
def testWithGlobalStep(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
global_step = tf.Variable(0, trainable=False)
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
@@ -138,7 +103,7 @@ class GradientDescentOptimizerTest(tf.test.TestCase):
self.assertAllCloseAccordingToType(1, global_step.eval())
def testSparseBasic(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([[1.0], [2.0]], dtype=dtype)
var1 = tf.Variable([[3.0], [4.0]], dtype=dtype)
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index ab48d34782..203399eab3 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import control_flow_ops
def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
@@ -84,3 +85,62 @@ def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
if staircase:
p = math_ops.floor(p)
return math_ops.mul(learning_rate, math_ops.pow(decay_rate, p), name=name)
+
+
+def piecewise_constant(x, boundaries, values, name=None):
+ """ Piecewise constant from boundaries and interval values.
+
+ Example: use a learning rate that's 1.0 for the first 100000 steps, 0.5
+ for steps 100001 to 110000, and 0.1 for any additional steps.
+
+ ```python
+ global_step = tf.Variable(0, trainable=False)
+ boundaries = [100000, 110000]
+ values = [1.0, 0.5, 0.1]
+ learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)
+
+ # Later, whenever we perform an optimization step, we increment global_step.
+ ```
+
+ Args:
+ x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
+ `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
+ boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
+ increasing entries, and with all elements having the same type as `x`.
+ values: A list of `Tensor`s or float`s or `int`s that specifies the values
+ for the intervals defined by `boundaries`. It should have one more element
+ than `boundaries`, and all elements should have the same type.
+ name: A string. Optional name of the operation. Defaults to
+ 'PiecewiseConstant'.
+
+ Returns:
+ A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`,
+ `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
+ and values[-1] when `x > boundaries[-1]`.
+ """
+
+ with ops.op_scope([x, boundaries, values, name],
+ name, 'PiecewiseConstant') as name:
+ x = ops.convert_to_tensor(x)
+ # Avoid explicit conversion to x's dtype. This could result in faulty
+ # comparisons, for example if floats are converted to integers.
+ boundaries = ops.convert_n_to_tensor(boundaries)
+ if not all(b.dtype == x.dtype for b in boundaries):
+ raise ValueError('boundaries must have the same dtype as x.')
+ # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing.
+ values = ops.convert_n_to_tensor(values)
+ if not all(v.dtype == values[0].dtype for v in values):
+ raise ValueError('values must have elements all with the same dtype.')
+
+ pred_fn_pairs = {}
+ pred_fn_pairs[x <= boundaries[0]] = lambda: values[0]
+ pred_fn_pairs[x > boundaries[-1]] = lambda: values[-1]
+ for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
+ # Need to bind v here; can do this with lambda v=v: ...
+ pred = (x > low) & (x <= high)
+ pred_fn_pairs[pred] = lambda v=v: v
+
+ # The default isn't needed here because our conditions are mutually
+ # exclusive and exhaustive, but tf.case requires it.
+ default = lambda: values[0]
+ return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py
index 21ea03826c..6fabd58fe3 100644
--- a/tensorflow/python/training/learning_rate_decay_test.py
+++ b/tensorflow/python/training/learning_rate_decay_test.py
@@ -72,6 +72,41 @@ class LRDecayTest(test_util.TensorFlowTestCase):
expected = .1 * 0.96**(100 // 3)
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ def testPiecewiseConstant(self):
+ with self.test_session():
+ x = variables.Variable(-999)
+ assign_100 = x.assign(100)
+ assign_105 = x.assign(105)
+ assign_110 = x.assign(110)
+ assign_120 = x.assign(120)
+ assign_999 = x.assign(999)
+ pc = learning_rate_decay.piecewise_constant(x, [100, 110, 120],
+ [1.0, 0.1, 0.01, 0.001])
+
+ variables.initialize_all_variables().run()
+ self.assertAllClose(pc.eval(), 1.0, 1e-6)
+ assign_100.op.run()
+ self.assertAllClose(pc.eval(), 1.0, 1e-6)
+ assign_105.op.run()
+ self.assertAllClose(pc.eval(), 0.1, 1e-6)
+ assign_110.op.run()
+ self.assertAllClose(pc.eval(), 0.1, 1e-6)
+ assign_120.op.run()
+ self.assertAllClose(pc.eval(), 0.01, 1e-6)
+ assign_999.op.run()
+ self.assertAllClose(pc.eval(), 0.001, 1e-6)
+
+ def testPiecewiseConstantEdgeCases(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ x_int = variables.Variable(0, dtype=variables.dtypes.int32)
+ boundaries, values = [-1.0, 1.0], [1, 2, 3]
+ pc = learning_rate_decay.piecewise_constant(x_int, boundaries, values)
+ with self.assertRaises(ValueError):
+ x = variables.Variable(0.0)
+ boundaries, values = [-1.0, 1.0], [1.0, 2, 3]
+ pc = learning_rate_decay.piecewise_constant(x, boundaries, values)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index 72e0eed4c4..88468f56a8 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -26,7 +26,7 @@ import tensorflow as tf
class MomentumOptimizerTest(tf.test.TestCase):
def testBasic(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -81,7 +81,7 @@ class MomentumOptimizerTest(tf.test.TestCase):
var1.eval())
def testTensorLearningRateAndMomentum(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -136,39 +136,6 @@ class MomentumOptimizerTest(tf.test.TestCase):
3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]),
var1.eval())
- def testFloat64(self):
- with self.test_session():
- opt = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
-
- # compute_gradients.
- values = [1.0, 3.0]
- good_vars = [tf.Variable([v]) for v in values]
- bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
- opt.compute_gradients, bad_loss, good_vars)
- bad_vars = [
- tf.Variable(np.array([v], np.float64), name="bad_var")
- for v in values]
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
- opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
- bad_vars)
- opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
-
- # apply_gradients.
- bad_grads = [
- tf.constant([0.1], dtype=np.float64, name="bad_grad"),
- tf.constant([0.01])]
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
- opt.apply_gradients, zip(bad_grads, good_vars))
- good_grads = [tf.constant([0.01]), tf.constant([0.02])]
- self.assertRaisesRegexp(
- ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
- opt.apply_gradients, zip(good_grads, bad_vars))
- opt.apply_gradients(zip(good_grads, good_vars))
-
def _dbParamsMom01(self):
"""Return dist-belief momentum values.
@@ -222,7 +189,7 @@ class MomentumOptimizerTest(tf.test.TestCase):
self.assertAllClose(np.array(db_out[i]), var0.eval())
def testSparse(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable(tf.zeros([4, 2], dtype=dtype))
var1 = tf.Variable(tf.constant(1.0, dtype, [4, 2]))
@@ -290,7 +257,7 @@ class MomentumOptimizerTest(tf.test.TestCase):
var1.eval()[2])
def testSharing(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index a8b41be923..623c76e18b 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -376,7 +376,7 @@ class Optimizer(object):
Returns:
Valid types for loss, variables and gradients.
"""
- return set([dtypes.float16, dtypes.float32])
+ return set([dtypes.float16, dtypes.float32, dtypes.float64])
def _create_slots(self, var_list):
"""Create all slots needed by the variables.
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
index 54d400a51c..f87a207c60 100644
--- a/tensorflow/python/training/optimizer_test.py
+++ b/tensorflow/python/training/optimizer_test.py
@@ -23,7 +23,7 @@ import tensorflow as tf
class OptimizerTest(tf.test.TestCase):
def testBasic(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -43,7 +43,7 @@ class OptimizerTest(tf.test.TestCase):
self.assertAllClose([-6., -5.], var1.eval())
def testAggregationMethod(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -67,7 +67,7 @@ class OptimizerTest(tf.test.TestCase):
self.assertAllClose([-6., -5.], var1.eval())
def testPrecomputedGradient(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
@@ -92,7 +92,7 @@ class OptimizerTest(tf.test.TestCase):
[3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)], var1.eval())
def testNoVariables(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype, trainable=False)
var1 = tf.Variable([3.0, 4.0], dtype=dtype, trainable=False)
@@ -102,7 +102,7 @@ class OptimizerTest(tf.test.TestCase):
sgd_op.minimize(cost)
def testNoGradients(self):
- for dtype in [tf.half, tf.float32]:
+ for dtype in [tf.half, tf.float32, tf.float64]:
with self.test_session():
var0 = tf.Variable([1.0, 2.0], dtype=dtype)
var1 = tf.Variable([3.0, 4.0], dtype=dtype)
diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc
index bf7faef209..0f7d0eaeb3 100644
--- a/tensorflow/stream_executor/dso_loader.cc
+++ b/tensorflow/stream_executor/dso_loader.cc
@@ -138,9 +138,9 @@ string GetCudnnVersion() { return ""; }
static std::vector<string>* CreatePrimordialRpaths() {
auto rpaths = new std::vector<string>;
#if defined(__APPLE__)
- rpaths->push_back("driver/driver_sh.runfiles/third_party/gpus/cuda/lib");
+ rpaths->push_back("driver/driver_sh.runfiles/org_tensorflow/third_party/gpus/cuda/lib");
#else
- rpaths->push_back("driver/driver_sh.runfiles/third_party/gpus/cuda/lib64");
+ rpaths->push_back("driver/driver_sh.runfiles/org_tensorflow/third_party/gpus/cuda/lib64");
#endif
return rpaths;
}
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 1e73e00a3f..4eb5619ecd 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -234,7 +234,7 @@ def tf_gen_op_wrapper_py(name, out=None, hidden=[], visibility=None, deps=[],
# TODO(opensource): we need to enable this to work around the hidden symbol
# __cudaRegisterFatBinary error. Need more investigations.
def tf_cc_test(name, deps, linkstatic=0, tags=[], data=[], size="medium",
- suffix="", args=None):
+ suffix="", args=None, linkopts=[]):
name = name.replace(".cc", "")
native.cc_test(name="%s%s" % (name.replace("/", "_"), suffix),
size=size,
@@ -243,7 +243,7 @@ def tf_cc_test(name, deps, linkstatic=0, tags=[], data=[], size="medium",
copts=tf_copts(),
data=data,
deps=deps,
- linkopts=["-lpthread", "-lm"],
+ linkopts=["-lpthread", "-lm"] + linkopts,
linkstatic=linkstatic,
tags=tags,)
@@ -254,13 +254,15 @@ def tf_cc_test_gpu(name, deps, linkstatic=0, tags=[], data=[], size="medium",
tf_cc_test(name, deps, linkstatic=linkstatic, tags=tags, data=data,
size=size, suffix=suffix, args=args)
-def tf_cuda_cc_test(name, deps, tags=[], data=[], size="medium",linkstatic=0,args=[]):
+def tf_cuda_cc_test(name, deps, tags=[], data=[], size="medium", linkstatic=0,
+ args=[], linkopts=[]):
tf_cc_test(name=name,
deps=deps,
tags=tags + ["manual"],
data=data,
size=size,
linkstatic=linkstatic,
+ linkopts=linkopts,
args=args)
tf_cc_test(name=name,
suffix="_gpu",
@@ -269,21 +271,26 @@ def tf_cuda_cc_test(name, deps, tags=[], data=[], size="medium",linkstatic=0,arg
tags=tags + tf_cuda_tests_tags(),
data=data,
size=size,
+ linkopts=linkopts,
args=args)
# Create a cc_test for each of the tensorflow tests listed in "tests"
-def tf_cc_tests(tests, deps, linkstatic=0, tags=[], size="medium", args=None):
+def tf_cc_tests(tests, deps, linkstatic=0, tags=[], size="medium", args=None,
+ linkopts=[]):
for t in tests:
- tf_cc_test(t, deps, linkstatic, tags=tags, size=size, args=args)
+ tf_cc_test(t, deps, linkstatic, tags=tags, size=size, args=args,
+ linkopts=linkopts)
def tf_cc_tests_gpu(tests, deps, linkstatic=0, tags=[], size="medium", args=None):
tf_cc_tests(tests, deps, linkstatic, tags=tags, size=size, args=args)
-def tf_cuda_cc_tests(tests, deps, tags=[], size="medium", linkstatic=0, args=None):
+def tf_cuda_cc_tests(tests, deps, tags=[], size="medium", linkstatic=0,
+ args=None, linkopts=[]):
for t in tests:
- tf_cuda_cc_test(t, deps, tags=tags, size=size, linkstatic=linkstatic, args=args)
+ tf_cuda_cc_test(t, deps, tags=tags, size=size, linkstatic=linkstatic,
+ args=args, linkopts=linkopts)
def _cuda_copts():
"""Gets the appropriate set of copts for (maybe) CUDA compilation.
diff --git a/tensorflow/tools/docs/gen_docs_test.sh b/tensorflow/tools/docs/gen_docs_test.sh
index 9375784dc2..7023638322 100755
--- a/tensorflow/tools/docs/gen_docs_test.sh
+++ b/tensorflow/tools/docs/gen_docs_test.sh
@@ -16,7 +16,12 @@
set -eux
-TFDIR=$TEST_SRCDIR/tensorflow
+if [ -d $TEST_SRCDIR/org_tensorflow ]; then
+ TFDIR=$TEST_SRCDIR/org_tensorflow/tensorflow
+else
+ # Support 0.2.1- runfiles.
+ TFDIR=$TEST_SRCDIR/tensorflow
+fi
DOXYGEN=doxygen
DOXYGEN_CONFIG="tf-doxy_for_md-config"
TMP_DIR=/tmp/tensorflow-docs
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index 1ae6926b67..7f123937e8 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -32,17 +32,38 @@ function main() {
echo "Could not find bazel-bin. Did you run from the root of the build tree?"
exit 1
fi
- cp -R \
- bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/{tensorflow,external} \
- ${TMPDIR}
+
+ if [ ! -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow ]; then
+ # Really old (0.2.1-) runfiles, without workspace name.
+ cp -R \
+ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/{tensorflow,external} \
+ "${TMPDIR}"
+ RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles
+ else
+ if [ -d bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/external ]; then
+ # Old-style runfiles structure (--legacy_external_runfiles).
+ cp -R \
+ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/{tensorflow,external} \
+ "${TMPDIR}"
+ else
+ # New-style runfiles structure (--nolegacy_external_runfiles).
+ cp -R \
+ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow/tensorflow \
+ "${TMPDIR}"
+ mkdir "${TMPDIR}/external"
+ # Note: this makes an extra copy of org_tensorflow.
+ cp -R \
+ bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles \
+ "${TMPDIR}/external"
+ fi
+ RUNFILES=bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/org_tensorflow
+ fi
+
# protobuf pip package doesn't ship with header files. Copy the headers
# over so user defined ops can be compiled.
rsync --include "*/" --include "*.h" --exclude "*" --prune-empty-dirs -a \
- bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/google \
- ${TMPDIR}
- rsync -a \
- bazel-bin/tensorflow/tools/pip_package/build_pip_package.runfiles/third_party/eigen3 \
- ${TMPDIR}/third_party
+ $RUNFILES/google ${TMPDIR}
+ rsync -a $RUNFILES/third_party/eigen3 ${TMPDIR}/third_party
cp tensorflow/tools/pip_package/MANIFEST.in ${TMPDIR}
cp tensorflow/tools/pip_package/README ${TMPDIR}
diff --git a/util/python/python_config.sh b/util/python/python_config.sh
index 83e3856690..7554765003 100755
--- a/util/python/python_config.sh
+++ b/util/python/python_config.sh
@@ -16,11 +16,16 @@
set -e -o errexit
-# Prefix expected paths with ./ locally and external/reponame/ for remote repos.
-# TODO(kchodorow): remove once runfiles paths are fixed, see
-# https://github.com/bazelbuild/bazel/issues/848.
-script_path=$(dirname $(dirname $(dirname "$0")))
-script_path=${script_path:-.}
+if [ -d "../org_tensorflow" ]; then
+ script_path="../org_tensorflow"
+else
+ # Prefix expected paths with ./ locally and external/reponame/ for remote repos.
+ # TODO(kchodorow): remove once runfiles paths are fixed, see
+ # https://github.com/bazelbuild/bazel/issues/848.
+ script_path=$(dirname $(dirname $(dirname "$0")))
+ script_path=${script_path:-.}
+fi
+
EXPECTED_PATHS="$script_path/util/python/python_include"\
" $script_path/util/python/python_lib"\
" $script_path/third_party/py/numpy/numpy_include"