aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--README.md37
-rw-r--r--tensorflow/core/BUILD1
-rw-r--r--tensorflow/core/kernels/resize_area_op.cc7
-rw-r--r--tensorflow/core/kernels/resize_bicubic_op.cc7
-rw-r--r--tensorflow/core/kernels/resize_bilinear_op.cc7
-rw-r--r--tensorflow/core/kernels/resize_nearest_neighbor_op.cc7
-rw-r--r--tensorflow/core/ops/image_ops.cc8
-rw-r--r--tensorflow/core/ops/ops.pbtxt8
-rw-r--r--tensorflow/examples/tutorials/word2vec/word2vec_basic.py22
-rw-r--r--tensorflow/g3doc/api_docs/python/framework.md2
-rw-r--r--tensorflow/g3doc/resources/dims_types.md10
-rw-r--r--tensorflow/g3doc/tutorials/deep_cnn/index.md2
-rw-r--r--tensorflow/python/framework/ops.py2
-rw-r--r--tensorflow/python/ops/image_ops_test.py139
-rw-r--r--tensorflow/python/ops/template.py2
-rw-r--r--tensorflow/python/training/coordinator.py8
-rw-r--r--tensorflow/stream_executor/BUILD1
17 files changed, 124 insertions, 146 deletions
diff --git a/README.md b/README.md
index 5f57cd9c9e..52704e7010 100644
--- a/README.md
+++ b/README.md
@@ -16,11 +16,8 @@ organization for the purposes of conducting machine learning and deep neural
networks research. The system is general enough to be applicable in a wide
variety of other domains, as well.
-
-**Note: Currently we do not accept pull requests on github -- see
-[CONTRIBUTING.md](CONTRIBUTING.md) for information on how to contribute code
-changes to TensorFlow through
-[tensorflow.googlesource.com](https://tensorflow.googlesource.com/tensorflow)**
+**If you'd like to contribute to tensorflow, be sure to review the [contribution
+guidelines](CONTRIBUTING.md).**
**We use [github issues](https://github.com/tensorflow/tensorflow/issues) for
tracking requests and bugs, but please see
@@ -29,35 +26,7 @@ and discussion.**
# Download and Setup
-To install the CPU version of TensorFlow using a binary package, see the
-instructions below. For more detailed installation instructions, including
-installing from source, GPU-enabled support, etc., see
-[here](tensorflow/g3doc/get_started/os_setup.md).
-
-## Binary Installation
-
-The TensorFlow Python API supports Python 2.7 and Python 3.3+.
-
-The simplest way to install TensorFlow is using
-[pip](https://pypi.python.org/pypi/pip) for both Linux and Mac.
-
-For the GPU-enabled version, or if you encounter installation errors, or for
-more detailed installation instructions, see
-[here](tensorflow/g3doc/get_started/os_setup.md#detailed_install).
-
-### Ubuntu/Linux 64-bit
-
-```bash
-# For CPU-only version
-$ pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
-```
-
-### Mac OS X
-
-```bash
-# Only CPU-version is available at the moment.
-$ pip install https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl
-```
+See [install instructions](tensorflow/g3doc/get_started/os_setup.md).
### Try your first TensorFlow program
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index d2b5e65f12..33444cd45d 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -349,6 +349,7 @@ tf_gpu_kernel_library(
visibility = ["//visibility:public"],
deps = [
":cuda",
+ ":framework",
"//third_party/eigen3",
],
)
diff --git a/tensorflow/core/kernels/resize_area_op.cc b/tensorflow/core/kernels/resize_area_op.cc
index 11a6cb1da7..126e50ad73 100644
--- a/tensorflow/core/kernels/resize_area_op.cc
+++ b/tensorflow/core/kernels/resize_area_op.cc
@@ -144,11 +144,8 @@ class ResizeAreaOp : public OpKernel {
.HostMemory("size"), \
ResizeAreaOp<CPUDevice, T>);
-REGISTER_KERNEL(uint8);
-REGISTER_KERNEL(int8);
-REGISTER_KERNEL(int32);
-REGISTER_KERNEL(float);
-REGISTER_KERNEL(double);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
+
#undef REGISTER_KERNEL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc
index ea09c3be0c..83b5af31c5 100644
--- a/tensorflow/core/kernels/resize_bicubic_op.cc
+++ b/tensorflow/core/kernels/resize_bicubic_op.cc
@@ -131,11 +131,8 @@ class ResizeBicubicOp : public OpKernel {
.HostMemory("size"), \
ResizeBicubicOp<CPUDevice, T>);
-REGISTER_KERNEL(uint8);
-REGISTER_KERNEL(int8);
-REGISTER_KERNEL(int32);
-REGISTER_KERNEL(float);
-REGISTER_KERNEL(double);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
+
#undef REGISTER_KERNEL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/resize_bilinear_op.cc b/tensorflow/core/kernels/resize_bilinear_op.cc
index 1688be42f5..897e3a0f12 100644
--- a/tensorflow/core/kernels/resize_bilinear_op.cc
+++ b/tensorflow/core/kernels/resize_bilinear_op.cc
@@ -215,11 +215,8 @@ class ResizeBilinearOpGrad : public OpKernel {
.HostMemory("size"), \
ResizeBilinearOp<CPUDevice, T>);
-REGISTER_KERNEL(uint8);
-REGISTER_KERNEL(int8);
-REGISTER_KERNEL(int32);
-REGISTER_KERNEL(float);
-REGISTER_KERNEL(double);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
+
#undef REGISTER_KERNEL
REGISTER_KERNEL_BUILDER(Name("ResizeBilinearGrad")
diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
index c4eea44044..7f88241a36 100644
--- a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
+++ b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
@@ -178,11 +178,8 @@ class ResizeNearestNeighborOpGrad : public OpKernel {
.HostMemory("size"), \
ResizeNearestNeighborOpGrad<CPUDevice, T>);
-REGISTER_KERNEL(uint8);
-REGISTER_KERNEL(int8);
-REGISTER_KERNEL(int32);
-REGISTER_KERNEL(float);
-REGISTER_KERNEL(double);
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
+
#undef REGISTER_KERNEL
} // namespace tensorflow
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 51909fe6fc..10464133f9 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -22,7 +22,7 @@ REGISTER_OP("ResizeArea")
.Input("images: T")
.Input("size: int32")
.Output("resized_images: float")
- .Attr("T: {uint8, int8, int32, float, double}")
+ .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
.Doc(R"doc(
Resize `images` to `size` using area interpolation.
@@ -40,7 +40,7 @@ REGISTER_OP("ResizeBicubic")
.Input("images: T")
.Input("size: int32")
.Output("resized_images: float")
- .Attr("T: {uint8, int8, int32, float, double}")
+ .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
.Doc(R"doc(
Resize `images` to `size` using bicubic interpolation.
@@ -58,7 +58,7 @@ REGISTER_OP("ResizeBilinear")
.Input("images: T")
.Input("size: int32")
.Output("resized_images: float")
- .Attr("T: {uint8, int8, int32, float, double}")
+ .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
.Doc(R"doc(
Resize `images` to `size` using bilinear interpolation.
@@ -93,7 +93,7 @@ REGISTER_OP("ResizeNearestNeighbor")
.Input("images: T")
.Input("size: int32")
.Output("resized_images: T")
- .Attr("T: {uint8, int8, int32, float, double}")
+ .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
.Doc(R"doc(
Resize `images` to `size` using nearest neighbor interpolation.
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 761c261dda..b90d6b2ddc 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -5949,7 +5949,9 @@ op {
list {
type: DT_UINT8
type: DT_INT8
+ type: DT_INT16
type: DT_INT32
+ type: DT_INT64
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -5982,7 +5984,9 @@ op {
list {
type: DT_UINT8
type: DT_INT8
+ type: DT_INT16
type: DT_INT32
+ type: DT_INT64
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -6015,7 +6019,9 @@ op {
list {
type: DT_UINT8
type: DT_INT8
+ type: DT_INT16
type: DT_INT32
+ type: DT_INT64
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -6077,7 +6083,9 @@ op {
list {
type: DT_UINT8
type: DT_INT8
+ type: DT_INT16
type: DT_INT32
+ type: DT_INT64
type: DT_FLOAT
type: DT_DOUBLE
}
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index e04e86a100..714d9dccf4 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -141,16 +141,18 @@ with graph.as_default():
train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
valid_dataset = tf.constant(valid_examples, dtype=tf.int32)
- # Construct the variables.
- embeddings = tf.Variable(
- tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
- nce_weights = tf.Variable(
- tf.truncated_normal([vocabulary_size, embedding_size],
- stddev=1.0 / math.sqrt(embedding_size)))
- nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
-
- # Look up embeddings for inputs.
- embed = tf.nn.embedding_lookup(embeddings, train_inputs)
+ # Ops and variables pinned to the CPU because of missing GPU implementation
+ with tf.device('/cpu:0'):
+ # Look up embeddings for inputs.
+ embeddings = tf.Variable(
+ tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
+ embed = tf.nn.embedding_lookup(embeddings, train_inputs)
+
+ # Construct the variables for the NCE loss
+ nce_weights = tf.Variable(
+ tf.truncated_normal([vocabulary_size, embedding_size],
+ stddev=1.0 / math.sqrt(embedding_size)))
+ nce_biases = tf.Variable(tf.zeros([vocabulary_size]))
# Compute the average NCE loss for the batch.
# tf.nce_loss automatically draws a new sample of the negative labels each
diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md
index 0d572c942e..79d485cc60 100644
--- a/tensorflow/g3doc/api_docs/python/framework.md
+++ b/tensorflow/g3doc/api_docs/python/framework.md
@@ -1408,7 +1408,7 @@ and Python scalars. For example:
```python
import numpy as np
-array = np.random.rand((32, 100, 100))
+array = np.random.rand(32, 100, 100)
def my_func(arg):
arg = tf.convert_to_tensor(arg, dtype=tf.float32)
diff --git a/tensorflow/g3doc/resources/dims_types.md b/tensorflow/g3doc/resources/dims_types.md
index 875eba6863..c636f8a34a 100644
--- a/tensorflow/g3doc/resources/dims_types.md
+++ b/tensorflow/g3doc/resources/dims_types.md
@@ -54,14 +54,14 @@ Data type | Python type | Description
--- | --- | ---
`DT_FLOAT` | `tf.float32` | 32 bits floating point.
`DT_DOUBLE` | `tf.float64` | 64 bits floating point.
-`DT_INT64` | `tf.int64` | 64 bits signed integer.
-`DT_INT32` | `tf.int32` | 32 bits signed integer.
-`DT_INT16` | `tf.int16` | 16 bits signed integer.
`DT_INT8` | `tf.int8` | 8 bits signed integer.
+`DT_INT16` | `tf.int16` | 16 bits signed integer.
+`DT_INT32` | `tf.int32` | 32 bits signed integer.
+`DT_INT64` | `tf.int64` | 64 bits signed integer.
`DT_UINT8` | `tf.uint8` | 8 bits unsigned integer.
`DT_STRING` | `tf.string` | Variable length byte arrays. Each element of a Tensor is a byte array.
`DT_BOOL` | `tf.bool` | Boolean.
`DT_COMPLEX64` | `tf.complex64` | Complex number made of two 32 bits floating points: real and imaginary parts.
-`DT_QINT32` | `tf.qint32` | 32 bits signed integer used in quantized Ops.
`DT_QINT8` | `tf.qint8` | 8 bits signed integer used in quantized Ops.
-`DT_QUINT8` | `tf.quint8` | 8 bits unsigned integer used in quantized Ops.
+`DT_QINT32` | `tf.qint32` | 32 bits signed integer used in quantized Ops.
+`DT_QUINT8` | `tf.quint8` | 8 bits unsigned integer used in quantized Ops. \ No newline at end of file
diff --git a/tensorflow/g3doc/tutorials/deep_cnn/index.md b/tensorflow/g3doc/tutorials/deep_cnn/index.md
index 66614d402f..00d4383f20 100644
--- a/tensorflow/g3doc/tutorials/deep_cnn/index.md
+++ b/tensorflow/g3doc/tutorials/deep_cnn/index.md
@@ -126,7 +126,7 @@ artificially increase the data set size:
* [Randomly flip](../../api_docs/python/image.md#random_flip_left_right) the image from left to right.
* Randomly distort the [image brightness](../../api_docs/python/image.md#random_brightness).
-* Randomly distort the [image contrast](../../api_docs/python/image.md#tf_image_random_contrast).
+* Randomly distort the [image contrast](../../api_docs/python/image.md#random_contrast).
Please see the [Images](../../api_docs/python/image.md) page for the list of
available distortions. We also attach an
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index ac12551528..d08f98c61c 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -486,7 +486,7 @@ def convert_to_tensor(value, dtype=None, name=None, as_ref=False):
```python
import numpy as np
- array = np.random.rand((32, 100, 100))
+ array = np.random.rand(32, 100, 100)
def my_func(arg):
arg = tf.convert_to_tensor(arg, dtype=tf.float32)
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 3410689864..ad7612d8ab 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -564,49 +564,56 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
image_ops.ResizeMethod.BICUBIC,
image_ops.ResizeMethod.AREA]
+ TYPES = [np.uint8, np.int8, np.int16, np.int32, np.int64,
+ np.float, np.double]
+
def testNoOp(self):
img_shape = [1, 6, 4, 1]
single_shape = [6, 4, 1]
- data = [128, 128, 64, 64,
- 128, 128, 64, 64,
- 64, 64, 128, 128,
- 64, 64, 128, 128,
+ # This test is also conducted with int8, so 127 is the maximum
+ # value that can be used.
+ data = [127, 127, 64, 64,
+ 127, 127, 64, 64,
+ 64, 64, 127, 127,
+ 64, 64, 127, 127,
50, 50, 100, 100,
50, 50, 100, 100]
- img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
-
target_height = 6
target_width = 4
- for opt in self.OPTIONS:
- with self.test_session() as sess:
- image = constant_op.constant(img_np, shape=img_shape)
- y = image_ops.resize_images(image, target_height, target_width, opt)
- yshape = array_ops.shape(y)
- resized, newshape = sess.run([y, yshape])
- self.assertAllEqual(img_shape, newshape)
- self.assertAllClose(resized, img_np, atol=1e-5)
+ for nptype in self.TYPES:
+ img_np = np.array(data, dtype=nptype).reshape(img_shape)
- # Resizing with a single image must leave the shape unchanged also.
- with self.test_session():
- img_single = img_np.reshape(single_shape)
- image = constant_op.constant(img_single, shape=single_shape)
- y = image_ops.resize_images(image, target_height, target_width,
- self.OPTIONS[0])
- yshape = array_ops.shape(y)
- newshape = yshape.eval()
- self.assertAllEqual(single_shape, newshape)
+ for opt in self.OPTIONS:
+ with self.test_session() as sess:
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width, opt)
+ yshape = array_ops.shape(y)
+ resized, newshape = sess.run([y, yshape])
+ self.assertAllEqual(img_shape, newshape)
+ self.assertAllClose(resized, img_np, atol=1e-5)
- def testResizeDown(self):
+ # Resizing with a single image must leave the shape unchanged also.
+ with self.test_session():
+ img_single = img_np.reshape(single_shape)
+ image = constant_op.constant(img_single, shape=single_shape)
+ y = image_ops.resize_images(image, target_height, target_width,
+ self.OPTIONS[0])
+ yshape = array_ops.shape(y)
+ newshape = yshape.eval()
+ self.assertAllEqual(single_shape, newshape)
- data = [128, 128, 64, 64,
- 128, 128, 64, 64,
- 64, 64, 128, 128,
- 64, 64, 128, 128,
+ def testResizeDown(self):
+ # This test is also conducted with int8, so 127 is the maximum
+ # value that can be used.
+ data = [127, 127, 64, 64,
+ 127, 127, 64, 64,
+ 64, 64, 127, 127,
+ 64, 64, 127, 127,
50, 50, 100, 100,
50, 50, 100, 100]
- expected_data = [128, 64,
- 64, 128,
+ expected_data = [127, 64,
+ 64, 127,
50, 100]
target_height = 3
target_width = 2
@@ -617,59 +624,61 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
[target_height, target_width, 1]]
for target_shape, img_shape in zip(target_shapes, img_shapes):
- img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
- for opt in self.OPTIONS:
- with self.test_session():
- image = constant_op.constant(img_np, shape=img_shape)
- y = image_ops.resize_images(image, target_height, target_width, opt)
- expected = np.array(expected_data).reshape(target_shape)
- resized = y.eval()
- self.assertAllClose(resized, expected, atol=1e-5)
+ for nptype in self.TYPES:
+ img_np = np.array(data, dtype=nptype).reshape(img_shape)
+
+ for opt in self.OPTIONS:
+ with self.test_session():
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width, opt)
+ expected = np.array(expected_data).reshape(target_shape)
+ resized = y.eval()
+ self.assertAllClose(resized, expected, atol=1e-5)
def testResizeUp(self):
img_shape = [1, 3, 2, 1]
- data = [128, 64,
- 64, 128,
+ data = [64, 32,
+ 32, 64,
50, 100]
- img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
-
target_height = 6
target_width = 4
expected_data = {}
expected_data[image_ops.ResizeMethod.BILINEAR] = [
- 128.0, 96.0, 64.0, 64.0,
- 96.0, 96.0, 96.0, 96.0,
- 64.0, 96.0, 128.0, 128.0,
- 57.0, 85.5, 114.0, 114.0,
+ 64.0, 48.0, 32.0, 32.0,
+ 48.0, 48.0, 48.0, 48.0,
+ 32.0, 48.0, 64.0, 64.0,
+ 41.0, 61.5, 82.0, 82.0,
50.0, 75.0, 100.0, 100.0,
50.0, 75.0, 100.0, 100.0]
expected_data[image_ops.ResizeMethod.NEAREST_NEIGHBOR] = [
- 128.0, 128.0, 64.0, 64.0,
- 128.0, 128.0, 64.0, 64.0,
- 64.0, 64.0, 128.0, 128.0,
- 64.0, 64.0, 128.0, 128.0,
+ 64.0, 64.0, 32.0, 32.0,
+ 64.0, 64.0, 32.0, 32.0,
+ 32.0, 32.0, 64.0, 64.0,
+ 32.0, 32.0, 64.0, 64.0,
50.0, 50.0, 100.0, 100.0,
50.0, 50.0, 100.0, 100.0]
expected_data[image_ops.ResizeMethod.AREA] = [
- 128.0, 128.0, 64.0, 64.0,
- 128.0, 128.0, 64.0, 64.0,
- 64.0, 64.0, 128.0, 128.0,
- 64.0, 64.0, 128.0, 128.0,
+ 64.0, 64.0, 32.0, 32.0,
+ 64.0, 64.0, 32.0, 32.0,
+ 32.0, 32.0, 64.0, 64.0,
+ 32.0, 32.0, 64.0, 64.0,
50.0, 50.0, 100.0, 100.0,
50.0, 50.0, 100.0, 100.0]
- for opt in [
- image_ops.ResizeMethod.BILINEAR,
- image_ops.ResizeMethod.NEAREST_NEIGHBOR,
- image_ops.ResizeMethod.AREA]:
- with self.test_session():
- image = constant_op.constant(img_np, shape=img_shape)
- y = image_ops.resize_images(image, target_height, target_width, opt)
- resized = y.eval()
- expected = np.array(expected_data[opt]).reshape(
- [1, target_height, target_width, 1])
- self.assertAllClose(resized, expected, atol=1e-05)
+ for nptype in self.TYPES:
+ for opt in [
+ image_ops.ResizeMethod.BILINEAR,
+ image_ops.ResizeMethod.NEAREST_NEIGHBOR,
+ image_ops.ResizeMethod.AREA]:
+ with self.test_session():
+ img_np = np.array(data, dtype=nptype).reshape(img_shape)
+ image = constant_op.constant(img_np, shape=img_shape)
+ y = image_ops.resize_images(image, target_height, target_width, opt)
+ resized = y.eval()
+ expected = np.array(expected_data[opt]).reshape(
+ [1, target_height, target_width, 1])
+ self.assertAllClose(resized, expected, atol=1e-05)
def testResizeUpBicubic(self):
img_shape = [1, 6, 6, 1]
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index 60cff4b97b..78bff9f9db 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -187,7 +187,7 @@ class Template(object):
"meant tf.get_variable: %s",
variables[vars_at_start:])
return result
- except Exception, exc:
+ except Exception as exc:
# Reraise the exception, but append the original definition to the
# trace.
args = exc.args
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py
index 5030c63370..efd6f2a807 100644
--- a/tensorflow/python/training/coordinator.py
+++ b/tensorflow/python/training/coordinator.py
@@ -70,7 +70,7 @@ class Coordinator(object):
try:
while not coord.should_stop():
...do some work...
- except Exception, e:
+ except Exception as e:
coord.request_stop(e)
```
@@ -85,7 +85,7 @@ class Coordinator(object):
...start thread N...(coord, ...)
# Wait for all the threads to terminate.
coord.join(threads)
- except Exception, e:
+ except Exception as e:
...exception that was passed to coord.request_stop()
```
@@ -188,7 +188,7 @@ class Coordinator(object):
```python
try:
...body...
- exception Exception, ex:
+ exception Exception as ex:
coord.request_stop(ex)
```
@@ -198,7 +198,7 @@ class Coordinator(object):
# pylint: disable=broad-except
try:
yield
- except Exception, ex:
+ except Exception as ex:
self.request_stop(ex)
# pylint: enable=broad-except
diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD
index cb69c36b75..90186371c3 100644
--- a/tensorflow/stream_executor/BUILD
+++ b/tensorflow/stream_executor/BUILD
@@ -19,6 +19,7 @@ cc_library(
),
hdrs = glob([
"*.h",
+ "cuda/*.h",
"lib/*.h",
"platform/**/*.h",
]),