-rw-r--r--  tensorflow/BUILD                                             |  1
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc          | 27
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr.h           | 51
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc     |  8
-rw-r--r--  tensorflow/core/kernels/scatter_op.cc                        |  2
-rw-r--r--  tensorflow/core/ops/ops.pbtxt                                | 24
-rw-r--r--  tensorflow/models/rnn/ptb/ptb_word_lm.py                     | 12
-rw-r--r--  tensorflow/models/rnn/translate/translate.py                 |  4
-rw-r--r--  tensorflow/python/__init__.py                                | 10
-rw-r--r--  tensorflow/python/framework/framework_lib.py                 |  2
-rw-r--r--  tensorflow/python/framework/importer_test.py                 |  3
-rw-r--r--  tensorflow/python/framework/ops.py                           | 40
-rw-r--r--  tensorflow/python/framework/ops_test.py                      | 33
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py   | 11
-rw-r--r--  tensorflow/python/kernel_tests/linear_test.py                |  2
-rw-r--r--  tensorflow/python/kernel_tests/parsing_ops_test.py           |  4
-rw-r--r--  tensorflow/python/kernel_tests/reverse_sequence_op_test.py   |  1
-rw-r--r--  tensorflow/python/kernel_tests/rnn_cell_test.py              | 18
-rw-r--r--  tensorflow/python/kernel_tests/scatter_ops_test.py           | 11
-rw-r--r--  tensorflow/python/kernel_tests/seq2seq_test.py               | 10
-rw-r--r--  tensorflow/python/kernel_tests/variables_test.py             | 16
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py                    |  9
-rw-r--r--  tensorflow/python/ops/variables.py                           | 54
-rw-r--r--  tensorflow/python/training/adam.py                           |  8
-rw-r--r--  tensorflow/python/training/coordinator.py                    |  4
-rw-r--r--  tensorflow/python/training/moving_averages_test.py           | 86
-rw-r--r--  tensorflow/python/util/compat.py                             | 15
-rw-r--r--  tensorflow/tensorboard/gulpfile.js                           | 17
-rw-r--r--  tensorflow/tensorboard/tests.html                            | 31
-rw-r--r--  tools/bazel.rc.template                                      |  4
-rwxr-xr-x  util/python/python_config.sh                                 | 12
31 files changed, 346 insertions(+), 184 deletions(-)
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 01872b4f54..ace5d8f5c0 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -40,6 +40,7 @@ filegroup(
py_library(
name = "tensorflow_py",
srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = ["//tensorflow/python"],
)
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
index 1821289f4b..962848ad17 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
@@ -40,13 +40,13 @@ EventMgr::~EventMgr() {
delete e;
}
while (!used_events_.empty()) {
- delete used_events_[0].event;
- delete used_events_[0].mem;
- if (used_events_[0].bufrec.buf) {
- used_events_[0].bufrec.alloc->DeallocateRaw(used_events_[0].bufrec.buf);
+ InUse* ue = &used_events_[0];
+ delete ue->event;
+ delete ue->mem;
+ if (ue->bufrec.buf) {
+ ue->bufrec.alloc->DeallocateRaw(ue->bufrec.buf);
}
- if (used_events_[0].func != nullptr)
- threadpool_.Schedule(used_events_[0].func);
+ if (ue->func != nullptr) threadpool_.Schedule(ue->func);
used_events_.pop_front();
}
}
@@ -60,10 +60,12 @@ EventMgr::~EventMgr() {
void EventMgr::PollLoop() {
while (!stop_polling_.HasBeenNotified()) {
Env::Default()->SleepForMicroseconds(1 * 1000);
+ ToFreeVector to_free;
{
mutex_lock l(mu_);
- PollEvents(true);
+ PollEvents(true, &to_free);
}
+ FreeMemory(to_free);
}
polling_stopped_.Notify();
}
@@ -103,7 +105,8 @@ void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu) {
// GPU memory use to spike needlessly. An alternative strategy would
// be to throttle new Op execution until the pending event queue
// clears.
-void EventMgr::PollEvents(bool is_dedicated_poller) {
+void EventMgr::PollEvents(bool is_dedicated_poller,
+ gtl::InlinedVector<InUse, 4>* to_free) {
VLOG(2) << "PollEvents free_events_ " << free_events_.size()
<< " used_events_ " << used_events_.size();
// Sweep the remaining events in order. If this is the dedicated
@@ -123,11 +126,9 @@ void EventMgr::PollEvents(bool is_dedicated_poller) {
if (!is_dedicated_poller) return; // quit processing queue
break;
case gpu::Event::Status::kComplete:
- delete iu.mem;
- if (iu.bufrec.buf) iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf);
- // The function must be called in another thread, outside of
- // the mutex held here.
- if (iu.func != nullptr) threadpool_.Schedule(iu.func);
+ // Make a copy of the InUse record so we can free it after releasing
+ // the lock
+ to_free->push_back(iu);
free_events_.push_back(iu.event);
// Mark this InUse record as completed.
iu.event = nullptr;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
index 5fe9fd782d..443664beef 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
@@ -18,8 +18,10 @@ limitations under the License.
#include <deque>
#include <vector>
+#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/tensor.h"
@@ -47,9 +49,13 @@ class EventMgr {
// currently enqueued on *stream have completed.
inline void ThenDeleteTensors(perftools::gputools::Stream* stream,
std::vector<Tensor>* tensors) {
- mutex_lock l(mu_);
- QueueTensors(stream, tensors);
- PollEvents(false);
+ ToFreeVector to_free;
+ {
+ mutex_lock l(mu_);
+ QueueTensors(stream, tensors);
+ PollEvents(false, &to_free);
+ }
+ FreeMemory(to_free);
}
struct BufRec {
@@ -61,16 +67,24 @@ class EventMgr {
// on it as soon as all events currently enqueued on *stream have completed.
inline void ThenDeleteBuffer(perftools::gputools::Stream* stream,
BufRec bufrec) {
- mutex_lock l(mu_);
- QueueBuffer(stream, bufrec);
- PollEvents(false);
+ ToFreeVector to_free;
+ {
+ mutex_lock l(mu_);
+ QueueBuffer(stream, bufrec);
+ PollEvents(false, &to_free);
+ }
+ FreeMemory(to_free);
}
inline void ThenExecute(perftools::gputools::Stream* stream,
std::function<void()> func) {
- mutex_lock l(mu_);
- QueueFunc(stream, func);
- PollEvents(false);
+ ToFreeVector to_free;
+ {
+ mutex_lock l(mu_);
+ QueueFunc(stream, func);
+ PollEvents(false, &to_free);
+ }
+ FreeMemory(to_free);
}
private:
@@ -85,10 +99,22 @@ class EventMgr {
std::function<void()> func;
};
+ typedef gtl::InlinedVector<InUse, 4> ToFreeVector;
+
+ void FreeMemory(const ToFreeVector& to_free) {
+ for (const auto& iu : to_free) {
+ delete iu.mem;
+ if (iu.bufrec.buf) iu.bufrec.alloc->DeallocateRaw(iu.bufrec.buf);
+ // The function must be called in another thread.
+ if (iu.func != nullptr) threadpool_.Schedule(iu.func);
+ }
+ }
+
// Stream-enqueue an unused Event and save with it a collection of
// Tensors and/or a BufRec to be deleted only after the Event
// records.
void QueueInUse(perftools::gputools::Stream* stream, InUse in_use)
+
EXCLUSIVE_LOCKS_REQUIRED(mu_);
void QueueTensors(perftools::gputools::Stream* stream,
@@ -109,8 +135,11 @@ class EventMgr {
// This function should be called at roughly the same tempo as
// QueueTensors() to check whether pending events have recorded,
- // and then retire them.
- void PollEvents(bool is_dedicated_poller) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ // and then retire them. It appends InUse elements that need cleanup
+ // to "*to_free". The caller should call FreeMemory(to_free)
+ // when this returns.
+ void PollEvents(bool is_dedicated_poller, ToFreeVector* to_free)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
// An internal polling loop that runs at a low frequency to clear
// straggler Events.
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
index 6956ead643..90d26a34cd 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
@@ -47,8 +47,12 @@ class TEST_EventMgrHelper {
}
void PollEvents(bool is_dedicated_poller) {
- mutex_lock l(em_->mu_);
- em_->PollEvents(is_dedicated_poller);
+ EventMgr::ToFreeVector to_free;
+ {
+ mutex_lock l(em_->mu_);
+ em_->PollEvents(is_dedicated_poller, &to_free);
+ }
+ em_->FreeMemory(to_free);
}
private:
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index e47b07a534..59315876aa 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -140,6 +140,8 @@ class ScatterUpdateOp : public OpKernel {
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_INT32);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_INT64);
+REGISTER_SCATTER_UPDATE_INT32(bool)
+REGISTER_SCATTER_UPDATE_INT64(bool)
#undef REGISTER_SCATTER_UPDATE_INT64
#undef REGISTER_SCATTER_UPDATE_INT32
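
With the bool kernels registered above, tf.scatter_update now works on boolean
variables. A minimal usage sketch (illustrative only; assumes the tf.* Python
API in this tree):

    import tensorflow as tf

    with tf.Session() as sess:
        flags = tf.Variable([False, False])
        # Indices may be int32 or int64; updates use the variable's bool dtype.
        update = tf.scatter_update(flags, [1], [True])
        sess.run(flags.initializer)
        sess.run(update)
        print(sess.run(flags))  # expected: [False  True]
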
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 59a0ee62ee..871fe17638 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -2039,7 +2039,7 @@ op {
type: "type"
}
summary: "Partitions `data` into `num_partitions` tensors using indices from `partitions`."
- description: "For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]`\nbecomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i`\nare placed in `outputs[i]` in lexicographic order of `js`, and the first\ndimension of `outputs[i]` is the number of entries in `partitions` equal to `i`.\nIn detail,\n\n outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:]\n\n outputs[i] = pack([data[js, ...] for js if partitions[js] == i])\n\n`data.shape` must start with `partitions.shape`.\n\nFor example:\n\n # Scalar partitions\n partitions = 1\n num_partitions = 2\n data = [10, 20]\n outputs[0] = [] # Empty with shape [0, 2]\n outputs[1] = [[10, 20]]\n\n # Vector partitions\n partitions = [0, 0, 1, 1, 0]\n num_partitions = 2\n data = [10, 20, 30, 40, 50]\n outputs[0] = [10, 20, 50]\n outputs[1] = [30, 40]\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../images/DynamicPartition.png\" alt>\n</div>"
+ description: "For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]`\nbecomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i`\nare placed in `outputs[i]` in lexicographic order of `js`, and the first\ndimension of `outputs[i]` is the number of entries in `partitions` equal to `i`.\nIn detail,\n\n outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:]\n\n outputs[i] = pack([data[js, ...] for js if partitions[js] == i])\n\n`data.shape` must start with `partitions.shape`.\n\nFor example:\n\n # Scalar partitions\n partitions = 1\n num_partitions = 2\n data = [10, 20]\n outputs[0] = [] # Empty with shape [0, 2]\n outputs[1] = [[10, 20]]\n\n # Vector partitions\n partitions = [0, 0, 1, 1, 0]\n num_partitions = 2\n data = [10, 20, 30, 40, 50]\n outputs[0] = [10, 20, 50]\n outputs[1] = [30, 40]\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/DynamicPartition.png\" alt>\n</div>"
}
op {
name: "DynamicStitch"
@@ -2068,7 +2068,7 @@ op {
type: "type"
}
summary: "Interleave the values from the `data` tensors into a single tensor."
- description: "Builds a merged tensor such that\n\n merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]\n\nFor example, if each `indices[m]` is scalar or vector, we have\n\n # Scalar indices\n merged[indices[m], ...] = data[m][...]\n\n # Vector indices\n merged[indices[m][i], ...] = data[m][i, ...]\n\nEach `data[i].shape` must start with the corresponding `indices[i].shape`,\nand the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we\nmust have `data[i].shape = indices[i].shape + constant`. In terms of this\n`constant`, the output shape is\n\n merged.shape = [max(indices)] + constant\n\nValues are merged in order, so if an index appears in both `indices[m][i]` and\n`indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the\nmerged result.\n\nFor example:\n\n indices[0] = 6\n indices[1] = [4, 1]\n indices[2] = [[5, 2], [0, 3]]\n data[0] = [61, 62]\n data[1] = [[41, 42], [11, 12]]\n data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]\n merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42],\n [51, 52], [61, 62]]\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../images/DynamicStitch.png\" alt>\n</div>"
+ description: "Builds a merged tensor such that\n\n merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]\n\nFor example, if each `indices[m]` is scalar or vector, we have\n\n # Scalar indices\n merged[indices[m], ...] = data[m][...]\n\n # Vector indices\n merged[indices[m][i], ...] = data[m][i, ...]\n\nEach `data[i].shape` must start with the corresponding `indices[i].shape`,\nand the rest of `data[i].shape` must be constant w.r.t. `i`. That is, we\nmust have `data[i].shape = indices[i].shape + constant`. In terms of this\n`constant`, the output shape is\n\n merged.shape = [max(indices)] + constant\n\nValues are merged in order, so if an index appears in both `indices[m][i]` and\n`indices[n][j]` for `(m,i) < (n,j)` the slice `data[n][j]` will appear in the\nmerged result.\n\nFor example:\n\n indices[0] = 6\n indices[1] = [4, 1]\n indices[2] = [[5, 2], [0, 3]]\n data[0] = [61, 62]\n data[1] = [[41, 42], [11, 12]]\n data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]\n merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42],\n [51, 52], [61, 62]]\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/DynamicStitch.png\" alt>\n</div>"
}
op {
name: "EditDistance"
@@ -2784,7 +2784,7 @@ op {
}
}
summary: "Gather slices from `params` according to `indices`."
- description: "`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).\nProduces an output tensor with shape `indices.shape + params.shape[1:]` where:\n\n # Scalar indices\n output[:, ..., :] = params[indices, :, ... :]\n\n # Vector indices\n output[i, :, ..., :] = params[indices[i], :, ... :]\n\n # Higher rank indices\n output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]\n\nIf `indices` is a permutation and `len(indices) == params.shape[0]` then\nthis operation will permute `params` accordingly.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../images/Gather.png\" alt>\n</div>"
+ description: "`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).\nProduces an output tensor with shape `indices.shape + params.shape[1:]` where:\n\n # Scalar indices\n output[:, ..., :] = params[indices, :, ... :]\n\n # Vector indices\n output[i, :, ..., :] = params[indices[i], :, ... :]\n\n # Higher rank indices\n output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]\n\nIf `indices` is a permutation and `len(indices) == params.shape[0]` then\nthis operation will permute `params` accordingly.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/Gather.png\" alt>\n</div>"
}
op {
name: "Greater"
@@ -6182,7 +6182,7 @@ op {
description: "If True, the addition will be protected by a lock;\notherwise the behavior is undefined, but may exhibit less contention."
}
summary: "Adds sparse updates to a variable reference."
- description: "This operation computes\n\n # Scalar indices\n ref[indices, ...] += updates[...]\n\n # Vector indices (for each i)\n ref[indices[i], ...] += updates[i, ...]\n\n # High rank indices (for each i, ..., j)\n ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nDuplicate entries are handled correctly: if multiple `indices` reference\nthe same location, their contributions add.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../images/ScatterAdd.png\" alt>\n</div>"
+ description: "This operation computes\n\n # Scalar indices\n ref[indices, ...] += updates[...]\n\n # Vector indices (for each i)\n ref[indices[i], ...] += updates[i, ...]\n\n # High rank indices (for each i, ..., j)\n ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nDuplicate entries are handled correctly: if multiple `indices` reference\nthe same location, their contributions add.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/ScatterAdd.png\" alt>\n</div>"
}
op {
name: "ScatterSub"
@@ -6246,7 +6246,7 @@ op {
description: "If True, the subtraction will be protected by a lock;\notherwise the behavior is undefined, but may exhibit less contention."
}
summary: "Subtracts sparse updates to a variable reference."
- description: " # Scalar indices\n ref[indices, ...] -= updates[...]\n\n # Vector indices (for each i)\n ref[indices[i], ...] -= updates[i, ...]\n\n # High rank indices (for each i, ..., j)\n ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nDuplicate entries are handled correctly: if multiple `indices` reference\nthe same location, their (negated) contributions add.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../images/ScatterSub.png\" alt>\n</div>"
+ description: " # Scalar indices\n ref[indices, ...] -= updates[...]\n\n # Vector indices (for each i)\n ref[indices[i], ...] -= updates[i, ...]\n\n # High rank indices (for each i, ..., j)\n ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nDuplicate entries are handled correctly: if multiple `indices` reference\nthe same location, their (negated) contributions add.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/ScatterSub.png\" alt>\n</div>"
}
op {
name: "ScatterUpdate"
@@ -6295,7 +6295,7 @@ op {
description: "If True, the assignment will be protected by a lock;\notherwise the behavior is undefined, but may exhibit less contention."
}
summary: "Applies sparse updates to a variable reference."
- description: "This operation computes\n\n # Scalar indices\n ref[indices, ...] = updates[...]\n\n # Vector indices (for each i)\n ref[indices[i], ...] = updates[i, ...]\n\n # High rank indices (for each i, ..., j)\n ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nIf `indices` contains duplicate entries, lexicographically later entries\noverride earlier entries.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../images/ScatterUpdate.png\" alt>\n</div>"
+ description: "This operation computes\n\n # Scalar indices\n ref[indices, ...] = updates[...]\n\n # Vector indices (for each i)\n ref[indices[i], ...] = updates[i, ...]\n\n # High rank indices (for each i, ..., j)\n ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]\n\nThis operation outputs `ref` after the update is done.\nThis makes it easier to chain operations that need to use the reset value.\n\nIf `indices` contains duplicate entries, lexicographically later entries\noverride earlier entries.\n\nRequires `updates.shape = indices.shape + ref.shape[1:]`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/ScatterUpdate.png\" alt>\n</div>"
}
op {
name: "SegmentMax"
@@ -6339,7 +6339,7 @@ op {
}
}
summary: "Computes the maximum along segments of a tensor."
- description: "Read [the section on Segmentation](../../api_docs/python/math_ops.md#segmentation)\nfor an explanation of segments.\n\nComputes a tensor such that\n\\\\(output_i = \\max_j(data_j)\\\\) where `max` is over `j` such\nthat `segment_ids[j] == i`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../images/SegmentMax.png\" alt>\n</div>"
+ description: "Read [the section on Segmentation](../../api_docs/python/math_ops.md#segmentation)\nfor an explanation of segments.\n\nComputes a tensor such that\n\\\\(output_i = \\max_j(data_j)\\\\) where `max` is over `j` such\nthat `segment_ids[j] == i`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/SegmentMax.png\" alt>\n</div>"
}
op {
name: "SegmentMean"
@@ -6383,7 +6383,7 @@ op {
}
}
summary: "Computes the mean along segments of a tensor."
- description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\frac{\\sum_j data_j}{N}\\\\) where `mean` is\nover `j` such that `segment_ids[j] == i` and `N` is the total number of\nvalues summed.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../images/SegmentMean.png\" alt>\n</div>"
+ description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\frac{\\sum_j data_j}{N}\\\\) where `mean` is\nover `j` such that `segment_ids[j] == i` and `N` is the total number of\nvalues summed.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/SegmentMean.png\" alt>\n</div>"
}
op {
name: "SegmentMin"
@@ -6427,7 +6427,7 @@ op {
}
}
summary: "Computes the minimum along segments of a tensor."
- description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\min_j(data_j)\\\\) where `min` is over `j` such\nthat `segment_ids[j] == i`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../images/SegmentMin.png\" alt>\n</div>"
+ description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\min_j(data_j)\\\\) where `min` is over `j` such\nthat `segment_ids[j] == i`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/SegmentMin.png\" alt>\n</div>"
}
op {
name: "SegmentProd"
@@ -6471,7 +6471,7 @@ op {
}
}
summary: "Computes the product along segments of a tensor."
- description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\prod_j data_j\\\\) where the product is over `j` such\nthat `segment_ids[j] == i`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../images/SegmentProd.png\" alt>\n</div>"
+ description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\prod_j data_j\\\\) where the product is over `j` such\nthat `segment_ids[j] == i`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/SegmentProd.png\" alt>\n</div>"
}
op {
name: "SegmentSum"
@@ -6515,7 +6515,7 @@ op {
}
}
summary: "Computes the sum along segments of a tensor."
- description: "Read [the section on Segmentation](../../api_docs/python/math_ops.md#segmentation)\nfor an explanation of segments.\n\nComputes a tensor such that\n\\\\(output_i = \\sum_j data_j\\\\) where sum is over `j` such\nthat `segment_ids[j] == i`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../images/SegmentSum.png\" alt>\n</div>"
+ description: "Read [the section on Segmentation](../../api_docs/python/math_ops.md#segmentation)\nfor an explanation of segments.\n\nComputes a tensor such that\n\\\\(output_i = \\sum_j data_j\\\\) where sum is over `j` such\nthat `segment_ids[j] == i`.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/SegmentSum.png\" alt>\n</div>"
}
op {
name: "Select"
@@ -8321,7 +8321,7 @@ op {
}
}
summary: "Computes the sum along segments of a tensor."
- description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\sum_j data_j\\\\) where sum is over `j` such\nthat `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\n range of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../images/UnsortedSegmentSum.png\" alt>\n</div>"
+ description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\sum_j data_j\\\\) where sum is over `j` such\nthat `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\n range of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>"
}
op {
name: "Variable"
diff --git a/tensorflow/models/rnn/ptb/ptb_word_lm.py b/tensorflow/models/rnn/ptb/ptb_word_lm.py
index a9e8f8ddf3..3380a4fc92 100644
--- a/tensorflow/models/rnn/ptb/ptb_word_lm.py
+++ b/tensorflow/models/rnn/ptb/ptb_word_lm.py
@@ -106,12 +106,10 @@ class PTBModel(object):
with tf.device("/cpu:0"):
embedding = tf.get_variable("embedding", [vocab_size, size])
- inputs = tf.split(
- 1, num_steps, tf.nn.embedding_lookup(embedding, self._input_data))
- inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
+ inputs = tf.nn.embedding_lookup(embedding, self._input_data)
if is_training and config.keep_prob < 1:
- inputs = [tf.nn.dropout(input_, config.keep_prob) for input_ in inputs]
+ inputs = tf.nn.dropout(inputs, config.keep_prob)
# Simplified version of tensorflow.models.rnn.rnn.py's rnn().
# This builds an unrolled LSTM for tutorial purposes only.
@@ -120,14 +118,16 @@ class PTBModel(object):
# The alternative version of the code below is:
#
# from tensorflow.models.rnn import rnn
+ # inputs = [tf.squeeze(input_, [1])
+ # for input_ in tf.split(1, num_steps, inputs)]
# outputs, states = rnn.rnn(cell, inputs, initial_state=self._initial_state)
outputs = []
states = []
state = self._initial_state
with tf.variable_scope("RNN"):
- for time_step, input_ in enumerate(inputs):
+ for time_step in range(num_steps):
if time_step > 0: tf.get_variable_scope().reuse_variables()
- (cell_output, state) = cell(input_, state)
+ (cell_output, state) = cell(inputs[:, time_step, :], state)
outputs.append(cell_output)
states.append(state)
diff --git a/tensorflow/models/rnn/translate/translate.py b/tensorflow/models/rnn/translate/translate.py
index c10eeefcb6..7e1e616fd7 100644
--- a/tensorflow/models/rnn/translate/translate.py
+++ b/tensorflow/models/rnn/translate/translate.py
@@ -128,7 +128,7 @@ def create_model(session, forward_only):
model.saver.restore(session, ckpt.model_checkpoint_path)
else:
print("Created model with fresh parameters.")
- session.run(tf.variables.initialize_all_variables())
+ session.run(tf.initialize_all_variables())
return model
@@ -254,7 +254,7 @@ def self_test():
# Create model with vocabularies of 10, 2 small buckets, 2 layers of 32.
model = seq2seq_model.Seq2SeqModel(10, 10, [(3, 3), (6, 6)], 32, 2,
5.0, 32, 0.3, 0.99, num_samples=8)
- sess.run(tf.variables.initialize_all_variables())
+ sess.run(tf.initialize_all_variables())
# Fake data set for both the (3, 3) and (6, 6) bucket.
data_set = ([([1, 1], [2, 2]), ([3, 3], [4]), ([5], [6])],
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 0e0c4081df..718ab5cd93 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -28,6 +28,7 @@ import tensorflow as tf
"""
+import inspect
import traceback
try:
@@ -47,6 +48,7 @@ from tensorflow.core.util.event_pb2 import *
# Framework
from tensorflow.python.framework.framework_lib import *
+from tensorflow.python.framework import errors
# Session
from tensorflow.python.client.client_lib import *
@@ -71,3 +73,11 @@ from tensorflow.python.platform import app
from tensorflow.python.platform import flags
from tensorflow.python.platform import logging
from tensorflow.python.platform import test
+
+# Don't export modules except for the few we really want
+_whitelist = set([app, compat, errors, flags, image, logging, nn,
+ python_io, test, train, user_ops])
+# TODO(b/25561952): tf.ops and tf.tensor_util are DEPRECATED. Please avoid.
+_whitelist.update([ops, tensor_util]) # pylint: disable=undefined-variable
+__all__ = [name for name, x in locals().items() if not name.startswith('_') and
+ (not inspect.ismodule(x) or x in _whitelist)]
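
A small illustration of the whitelist's effect on the public namespace
(illustrative only; the names checked are taken from the imports visible in
this file):

    import tensorflow as tf

    # Whitelisted modules such as `train` remain exported ...
    print('train' in tf.__all__)      # expected: True
    # ... while helper modules imported above, e.g. `traceback`, are filtered out.
    print('traceback' in tf.__all__)  # expected: False
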
diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py
index e85a72e0e1..84163ca1c0 100644
--- a/tensorflow/python/framework/framework_lib.py
+++ b/tensorflow/python/framework/framework_lib.py
@@ -33,6 +33,7 @@
@@name_scope
@@control_dependencies
@@convert_to_tensor
+@@convert_to_tensor_or_indexed_slices
@@get_default_graph
@@import_graph_def
@@ -75,6 +76,7 @@ from tensorflow.python.framework.ops import GraphKeys
from tensorflow.python.framework.ops import add_to_collection
from tensorflow.python.framework.ops import get_collection
from tensorflow.python.framework.ops import convert_to_tensor
+from tensorflow.python.framework.ops import convert_to_tensor_or_indexed_slices
from tensorflow.python.framework.random_seed import get_seed
from tensorflow.python.framework.random_seed import set_random_seed
from tensorflow.python.framework.importer import import_graph_def
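
The newly exported wrapper is available from the top-level namespace; a minimal
sketch (illustrative only; assumes the tf.* API in this tree):

    import tensorflow as tf

    # Plain Python values are converted to a Tensor, as with tf.convert_to_tensor.
    t = tf.convert_to_tensor_or_indexed_slices([1.0, 2.0])
    # An IndexedSlices value (e.g. a sparse gradient) is passed through unchanged.
    print(t)
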
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 188ec2edcc..154c550ed5 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import tensorflow.python.platform
+import numpy as np
import tensorflow as tf
from google.protobuf import text_format
@@ -604,7 +605,7 @@ class ImportGraphDefTest(tf.test.TestCase):
# Adding a 150M entries float32 tensor should blow through the warning,
# but not the hard limit.
input_shape = [150, 1024, 1024]
- tensor_input = tf.np.random.rand(*input_shape).astype(tf.np.float32)
+ tensor_input = np.random.rand(*input_shape).astype(np.float32)
t = tf.constant(tensor_input, shape=input_shape)
g = tf.identity(t)
g.eval()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 352c73c0f7..d3527c693e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2344,17 +2344,25 @@ class Graph(object):
class _ControlDependenciesController(object):
"""Context manager for `control_dependencies()`."""
- def __init__(self, graph, control_inputs):
+ def __init__(self, graph, control_inputs, new_stack):
self._graph = graph
self._control_inputs = control_inputs
+ self._new_stack = new_stack
self._seen_nodes = set()
+ self._old_stack = None
# pylint: disable=protected-access
def __enter__(self):
+ if self._new_stack:
+ self._old_stack = self._graph._control_dependencies_stack
+ self._graph._control_dependencies_stack = []
self._graph._push_control_dependencies_controller(self)
def __exit__(self, unused_type, unused_value, unused_traceback):
self._graph._pop_control_dependencies_controller(self)
+ if self._new_stack:
+ self._graph._control_dependencies_stack = self._old_stack
+
# pylint: enable=protected-access
@property
@@ -2445,9 +2453,21 @@ class Graph(object):
```python
with g.control_dependencies([a, b]):
- # Ops declared here run after `a` and `b`.
+ # Ops constructed here run after `a` and `b`.
with g.control_dependencies([c, d]):
- # Ops declared here run after `a`, `b`, `c`, and `d`.
+ # Ops constructed here run after `a`, `b`, `c`, and `d`.
+ ```
+
+ You can pass None to clear the control dependencies:
+
+ ```python
+ with g.control_dependencies([a, b]):
+ # Ops constructed here run after `a` and `b`.
+ with g.control_dependencies(None):
+ # Ops constructed here run normally, not waiting for either `a` or `b`.
+ with g.control_dependencies([c, d]):
+ # Ops constructed here run after `c` and `d`, also not waiting
+ # for either `a` or `b`.
```
*N.B.* The control dependencies context applies *only* to ops that
@@ -2473,9 +2493,10 @@ class Graph(object):
```
Args:
- control_inputs: A list of `Operation` or `Tensor` objects, which
+ control_inputs: A list of `Operation` or `Tensor` objects which
must be executed or computed before running the operations
- defined in the context.
+ defined in the context. Can also be `None` to clear the control
+ dependencies.
Returns:
A context manager that specifies control dependencies for all
@@ -2485,6 +2506,8 @@ class Graph(object):
TypeError: If `control_inputs` is not a list of `Operation` or
`Tensor` objects.
"""
+ if control_inputs is None:
+ return self._ControlDependenciesController(self, [], True)
# First convert the inputs to ops, and deduplicate them.
# NOTE(mrry): Other than deduplication, we do not currently track direct
# or indirect dependencies between control_inputs, which may result in
@@ -2500,7 +2523,7 @@ class Graph(object):
if c not in current:
control_ops.append(c)
current.add(c)
- return self._ControlDependenciesController(self, control_ops)
+ return self._ControlDependenciesController(self, control_ops, False)
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
@@ -2670,9 +2693,10 @@ def control_dependencies(control_inputs):
for more details.
Args:
- control_inputs: A list of `Operation` or `Tensor` objects, which
+ control_inputs: A list of `Operation` or `Tensor` objects which
must be executed or computed before running the operations
- defined in the context.
+ defined in the context. Can also be `None` to clear the control
+ dependencies.
Returns:
A context manager that specifies control dependencies for all
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index b6dab94102..8eafddca32 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -681,6 +681,39 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
[a_1.op, a_2.op, a_3.op, a_4.op], b_1.op.control_inputs)
self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)
+ def testClear(self):
+ g = ops.Graph()
+ a_1 = _apply_op(g, "const", [], [dtypes.float32])
+ a_2 = _apply_op(g, "const", [], [dtypes.float32])
+ a_3 = _apply_op(g, "const", [], [dtypes.float32])
+ a_4 = _apply_op(g, "const", [], [dtypes.float32])
+
+ with g.control_dependencies([a_1]):
+ with g.control_dependencies([a_2]):
+ with g.control_dependencies(None):
+ with g.control_dependencies([a_3]):
+ with g.control_dependencies([a_4]):
+ # deps [a_3, a_4]
+ b_3_4 = _apply_op(g, "const", [], [dtypes.float32])
+ # deps = [a_3]
+ b_3 = _apply_op(g, "const", [], [dtypes.float32])
+ # deps back to None
+ b_none = _apply_op(g, "const", [], [dtypes.float32])
+ # deps back to [a_1, a_2]
+ b_1_2 = _apply_op(g, "const", [], [dtypes.float32])
+ # deps back to [a_1]
+ b_1 = _apply_op(g, "const", [], [dtypes.float32])
+ with g.control_dependencies(None):
+ # deps are None again
+ b_none2 = _apply_op(g, "const", [], [dtypes.float32])
+
+ self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
+ self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
+ self.assertItemsEqual([], b_none.op.control_inputs)
+ self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
+ self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
+ self.assertItemsEqual([], b_none2.op.control_inputs)
+
def testComplex(self):
g = ops.Graph()
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index b70ec134ab..2a3acd6b76 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1325,5 +1325,16 @@ class TupleTest(tf.test.TestCase):
self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
v1.eval())
+ def testAcceptTensorsAsControlInputs(self):
+ with self.test_session():
+ var = tf.Variable(0)
+ assign = tf.assign(var, 1)
+ t, = tf.tuple([tf.constant(0)], control_inputs=[assign])
+
+ # Should trigger the assign.
+ t.eval()
+
+ self.assertEquals(1, var.eval())
+
if __name__ == "__main__":
tf.test.main()
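
Outside the test harness, the same behavior can be exercised through the public
API; a minimal sketch (illustrative only):

    import tensorflow as tf

    with tf.Session() as sess:
        var = tf.Variable(0)
        assign = tf.assign(var, 1)  # a Tensor, not an Operation
        # tf.tuple now accepts Tensors in control_inputs and gates on their producing ops.
        t, = tf.tuple([tf.constant(0)], control_inputs=[assign])
        sess.run(var.initializer)
        sess.run(t)           # evaluating t triggers the assign as a side effect
        print(sess.run(var))  # expected: 1
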
diff --git a/tensorflow/python/kernel_tests/linear_test.py b/tensorflow/python/kernel_tests/linear_test.py
index fdb4541114..dbaa332287 100644
--- a/tensorflow/python/kernel_tests/linear_test.py
+++ b/tensorflow/python/kernel_tests/linear_test.py
@@ -31,7 +31,7 @@ class LinearTest(tf.test.TestCase):
with tf.variable_scope("root", initializer=tf.constant_initializer(1.0)):
x = tf.zeros([1, 2])
l = tf.nn.rnn_cell.linear([x], 2, False)
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run([l], {x.name: np.array([[1., 2.]])})
self.assertAllClose(res[0], [[3.0, 3.0]])
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 331c62edf2..22ab1716ca 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -488,8 +488,8 @@ class ParseSequenceExampleTest(tf.test.TestCase):
}),
feature_lists=feature_lists({
"repeated_feature_2_frames": feature_list([
- bytes_feature(["a", "b", "c"]),
- bytes_feature(["a", "d", "e"])]),
+ bytes_feature([b"a", b"b", b"c"]),
+ bytes_feature([b"a", b"d", b"e"])]),
"repeated_feature_3_frames": feature_list([
int64_feature([3, 4, 5, 6, 7]),
int64_feature([-1, 0, 0, 0, 0]),
diff --git a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
index f2bc964109..cf75c95b25 100644
--- a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
+++ b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import tensorflow.python.platform
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index fefe4b078d..c3a4de1b55 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -37,7 +37,7 @@ class RNNCellTest(tf.test.TestCase):
x = tf.zeros([1, 2])
m = tf.zeros([1, 2])
g, _ = rnn_cell.BasicRNNCell(2)(x, m)
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run([g], {x.name: np.array([[1., 1.]]),
m.name: np.array([[0.1, 0.1]])})
self.assertEqual(res[0].shape, (1, 2))
@@ -48,7 +48,7 @@ class RNNCellTest(tf.test.TestCase):
x = tf.zeros([1, 2])
m = tf.zeros([1, 2])
g, _ = rnn_cell.GRUCell(2)(x, m)
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run([g], {x.name: np.array([[1., 1.]]),
m.name: np.array([[0.1, 0.1]])})
# Smoke test
@@ -60,7 +60,7 @@ class RNNCellTest(tf.test.TestCase):
x = tf.zeros([1, 2])
m = tf.zeros([1, 8])
g, out_m = rnn_cell.MultiRNNCell([rnn_cell.BasicLSTMCell(2)] * 2)(x, m)
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run([g, out_m], {x.name: np.array([[1., 1.]]),
m.name: 0.1 * np.ones([1, 8])})
self.assertEqual(len(res), 2)
@@ -84,7 +84,7 @@ class RNNCellTest(tf.test.TestCase):
m = tf.zeros([batch_size, state_size])
output, state = rnn_cell.LSTMCell(
num_units=num_units, input_size=input_size, num_proj=num_proj)(x, m)
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run([output, state],
{x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
m.name: 0.1 * np.ones((batch_size, state_size))})
@@ -107,7 +107,7 @@ class RNNCellTest(tf.test.TestCase):
m = tf.zeros([1, 3])
cell = rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(3), 2)
g, new_m = cell(x, m)
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run([g, new_m], {x.name: np.array([[1., 1., 1.]]),
m.name: np.array([[0.1, 0.1, 0.1]])})
self.assertEqual(res[1].shape, (1, 3))
@@ -121,7 +121,7 @@ class RNNCellTest(tf.test.TestCase):
m = tf.zeros([1, 3])
cell = rnn_cell.InputProjectionWrapper(rnn_cell.GRUCell(3), 2)
g, new_m = cell(x, m)
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run([g, new_m], {x.name: np.array([[1., 1.]]),
m.name: np.array([[0.1, 0.1, 0.1]])})
self.assertEqual(res[1].shape, (1, 3))
@@ -136,7 +136,7 @@ class RNNCellTest(tf.test.TestCase):
keep = tf.zeros([]) + 1
g, new_m = rnn_cell.DropoutWrapper(rnn_cell.GRUCell(3),
keep, keep)(x, m)
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run([g, new_m], {x.name: np.array([[1., 1., 1.]]),
m.name: np.array([[0.1, 0.1, 0.1]])})
self.assertEqual(res[1].shape, (1, 3))
@@ -149,7 +149,7 @@ class RNNCellTest(tf.test.TestCase):
x = tf.zeros([1, 1], dtype=tf.int32)
m = tf.zeros([1, 2])
g, new_m = rnn_cell.EmbeddingWrapper(rnn_cell.GRUCell(2), 3)(x, m)
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run([g, new_m], {x.name: np.array([[1]]),
m.name: np.array([[0.1, 0.1]])})
self.assertEqual(res[1].shape, (1, 2))
@@ -162,7 +162,7 @@ class RNNCellTest(tf.test.TestCase):
x = tf.zeros([1, 2])
m = tf.zeros([1, 4])
_, ml = rnn_cell.MultiRNNCell([rnn_cell.GRUCell(2)] * 2)(x, m)
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run(ml, {x.name: np.array([[1., 1.]]),
m.name: np.array([[0.1, 0.1, 0.1, 0.1]])})
# The numbers in results were not calculated, this is just a smoke test.
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
index af541a96c1..c4c248186c 100644
--- a/tensorflow/python/kernel_tests/scatter_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -63,6 +63,17 @@ class ScatterTest(tf.test.TestCase):
ref[indices] -= updates
self._VariableRankTest(sub, tf.scatter_sub)
+ def testBooleanScatterUpdate(self):
+ with self.test_session() as session:
+ var = tf.Variable([True, False])
+ update0 = tf.scatter_update(var, 1, True)
+ update1 = tf.scatter_update(var, tf.constant(0, dtype=tf.int64), False)
+ var.initializer.run()
+
+ session.run([update0, update1])
+
+ self.assertAllEqual([False, True], var.eval())
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/seq2seq_test.py b/tensorflow/python/kernel_tests/seq2seq_test.py
index 5ee2845780..1582d8d2ff 100644
--- a/tensorflow/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/python/kernel_tests/seq2seq_test.py
@@ -110,7 +110,7 @@ class Seq2SeqTest(tf.test.TestCase):
cell = tf.nn.rnn_cell.BasicLSTMCell(2)
dec, mem = tf.nn.seq2seq.embedding_rnn_seq2seq(
enc_inp, dec_inp, cell, 2, 5)
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 5))
@@ -125,7 +125,7 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("proj_seq2seq"):
dec, _ = tf.nn.seq2seq.embedding_rnn_seq2seq(
enc_inp, dec_inp, cell, 2, 5, output_projection=(w, b))
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 2))
@@ -156,7 +156,7 @@ class Seq2SeqTest(tf.test.TestCase):
cell = tf.nn.rnn_cell.BasicLSTMCell(2)
dec, mem = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
enc_inp, dec_inp, cell, 5)
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 5))
@@ -171,7 +171,7 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("proj_seq2seq"):
dec, _ = tf.nn.seq2seq.embedding_tied_rnn_seq2seq(
enc_inp, dec_inp, cell, 5, output_projection=(w, b))
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 2))
@@ -281,7 +281,7 @@ class Seq2SeqTest(tf.test.TestCase):
with tf.variable_scope("proj_seq2seq"):
dec, _ = tf.nn.seq2seq.embedding_attention_seq2seq(
enc_inp, dec_inp, cell, 2, 5, output_projection=(w, b))
- sess.run([tf.variables.initialize_all_variables()])
+ sess.run([tf.initialize_all_variables()])
res = sess.run(dec)
self.assertEqual(len(res), 3)
self.assertEqual(res[0].shape, (2, 2))
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 913959ae93..4b38bfb7e7 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -132,6 +132,22 @@ class VariablesTestCase(tf.test.TestCase):
def testCountUpToInt64(self):
self._countUpToTest(tf.int64)
+ def testControlDepsNone(self):
+ with self.test_session():
+ c = tf.constant(1.0)
+ with tf.control_dependencies([c]):
+ # d get the control dep.
+ d = tf.constant(2.0)
+ # variables do not.
+ var_x = tf.Variable(2.0)
+ # initialized_value do not either.
+ inited_x = var_x.initialized_value()
+ self.assertEqual([c.op], d.op.control_inputs)
+ self.assertEqual([], var_x.initializer.control_inputs)
+ self.assertEqual([], var_x.value().op.control_inputs)
+ self.assertEqual([], var_x.ref().op.control_inputs)
+ self.assertEqual([var_x.initializer], inited_x.op.control_inputs)
+
def testUseVariableAsTensor(self):
with self.test_session():
var_x = tf.Variable(2.0)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index b2660c210a..779ba1e131 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1252,12 +1252,19 @@ def tuple(tensors, name=None, control_inputs=None):
Raises:
ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
+ TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
+ objects.
"""
with ops.op_scope(tensors, name, "tuple") as name:
gating_ops = [t.op for t in tensors if t]
if control_inputs:
- gating_ops += control_inputs
+ for c in control_inputs:
+ if isinstance(c, ops.Tensor):
+ c = c.op
+ elif not isinstance(c, ops.Operation):
+ raise TypeError("Control input must be Operation or Tensor: %s" % c)
+ gating_ops.append(c)
# Note that in order to ensure ordering in the pbtxt, we must take care to
# ensure the order here.
gating_ops = sorted(set(gating_ops), key=lambda op: op._id) # Uniquify ops.
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 3840971d76..35e52b878a 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -187,30 +187,31 @@ class Variable(object):
# modify the value of the variable, not the list.
collections = collections + [ops.GraphKeys.TRAINABLE_VARIABLES]
# pylint: enable=g-no-augmented-assignment
- with ops.op_scope([initial_value], name, "Variable") as name:
- self._initial_value = ops.convert_to_tensor(initial_value,
- name="initial_value")
- if not self._initial_value.get_shape().is_fully_defined():
- if validate_shape:
- raise ValueError(
- "initial_value must have a shape specified: %s"
- % self._initial_value)
- self._variable = state_ops.variable_op(
- [], self._initial_value.dtype.base_dtype, set_shape=False,
- name=name)
- with ops.device(self._variable.device):
- self._initializer_op = state_ops.assign(
- self._variable, self._initial_value, validate_shape=False).op
- self._snapshot = array_ops.identity(self._variable, name="read")
- else:
- self._variable = state_ops.variable_op(
- self._initial_value.get_shape(),
- self._initial_value.dtype.base_dtype,
- name=name)
- with ops.device(self._variable.device):
- self._initializer_op = state_ops.assign(
- self._variable, self._initial_value).op
- self._snapshot = array_ops.identity(self._variable, name="read")
+ with ops.control_dependencies(None):
+ with ops.op_scope([initial_value], name, "Variable") as name:
+ self._initial_value = ops.convert_to_tensor(initial_value,
+ name="initial_value")
+ if not self._initial_value.get_shape().is_fully_defined():
+ if validate_shape:
+ raise ValueError(
+ "initial_value must have a shape specified: %s"
+ % self._initial_value)
+ self._variable = state_ops.variable_op(
+ [], self._initial_value.dtype.base_dtype, set_shape=False,
+ name=name)
+ with ops.device(self._variable.device):
+ self._initializer_op = state_ops.assign(
+ self._variable, self._initial_value, validate_shape=False).op
+ self._snapshot = array_ops.identity(self._variable, name="read")
+ else:
+ self._variable = state_ops.variable_op(
+ self._initial_value.get_shape(),
+ self._initial_value.dtype.base_dtype,
+ name=name)
+ with ops.device(self._variable.device):
+ self._initializer_op = state_ops.assign(
+ self._variable, self._initial_value).op
+ self._snapshot = array_ops.identity(self._variable, name="read")
for key in collections:
ops.add_to_collection(key, self)
self._save_slice_info = None
@@ -317,8 +318,9 @@ class Variable(object):
A `Tensor` holding the value of this variable after its initializer
has run.
"""
- return control_flow_ops.with_dependencies(
- [self._initializer_op], self._variable)
+ with ops.control_dependencies(None):
+ with ops.control_dependencies([self._initializer_op]):
+ return array_ops.identity(self._variable)
def assign(self, value, use_locking=False):
"""Assigns a new value to the variable.
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index 6729394083..55079b6c4c 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -103,8 +103,12 @@ class AdamOptimizer(optimizer.Optimizer):
# variable.
if self._beta1_power is None:
with ops.device(var_list[0].device):
- self._beta1_power = variables.Variable(self._beta1, name="beta1_power")
- self._beta2_power = variables.Variable(self._beta2, name="beta2_power")
+ self._beta1_power = variables.Variable(self._beta1,
+ name="beta1_power",
+ trainable=False)
+ self._beta2_power = variables.Variable(self._beta2,
+ name="beta2_power",
+ trainable=False)
# Create slots for the first and second moments.
for v in var_list:
self._zeros_slot(v, "m", self._name)
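
A quick way to observe the effect of marking the beta accumulators
non-trainable (a sketch, assuming the tf.train API in this tree):

    import tensorflow as tf

    w = tf.Variable([1.0], name="w")
    loss = tf.reduce_sum(tf.square(w))
    train_op = tf.train.AdamOptimizer(0.1).minimize(loss)
    # Before this change, beta1_power and beta2_power also appeared here.
    print([v.name for v in tf.trainable_variables()])  # expected: ['w:0']
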
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py
index 805d00a441..0ee4012bdc 100644
--- a/tensorflow/python/training/coordinator.py
+++ b/tensorflow/python/training/coordinator.py
@@ -136,11 +136,11 @@ class Coordinator(object):
if ex and self._exc_info_to_raise is None:
if isinstance(ex, tuple):
logging.info("Error reported to Coordinator: %s",
- compat.as_str(unicode(ex[1])))
+ compat.as_str_any(ex[1]))
self._exc_info_to_raise = ex
else:
logging.info("Error reported to Coordinator: %s",
- compat.as_str(unicode(ex)))
+ compat.as_str_any(ex))
self._exc_info_to_raise = sys.exc_info()
self._stop_event.set()
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index a2ad3a51b0..11c4a27deb 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -20,26 +20,20 @@ from __future__ import print_function
import tensorflow.python.platform
-from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import constant_op
+import tensorflow as tf
from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import googletest
from tensorflow.python.training import moving_averages
-class MovingAveragesTest(test_util.TensorFlowTestCase):
+class MovingAveragesTest(tf.test.TestCase):
def testAssignMovingAverage(self):
with self.test_session():
- var = variables.Variable([10.0, 11.0])
- val = constant_op.constant([1.0, 2.0], dtypes.float32)
+ var = tf.Variable([10.0, 11.0])
+ val = tf.constant([1.0, 2.0], tf.float32)
decay = 0.25
assign = moving_averages.assign_moving_average(var, val, decay)
- variables.initialize_all_variables().run()
+ tf.initialize_all_variables().run()
self.assertAllClose([10.0, 11.0], var.eval())
assign.op.run()
self.assertAllClose([10.0 * 0.25 + 1.0 * (1.0 - 0.25),
@@ -49,16 +43,16 @@ class MovingAveragesTest(test_util.TensorFlowTestCase):
def _Repeat(value, dim):
if dim == 1:
return value
- return [value for _ in xrange(dim)]
+ return [value] * dim
-class ExponentialMovingAverageTest(test_util.TensorFlowTestCase):
+class ExponentialMovingAverageTest(tf.test.TestCase):
def _CheckDecay(self, ema, actual_decay, dim):
tens = _Repeat(10.0, dim)
thirties = _Repeat(30.0, dim)
- var0 = variables.Variable(tens, name="v0")
- var1 = variables.Variable(thirties, name="v1")
- variables.initialize_all_variables().run()
+ var0 = tf.Variable(tens, name="v0")
+ var1 = tf.Variable(thirties, name="v1")
+ tf.initialize_all_variables().run()
# Note that tensor2 is not a Variable but just a plain Tensor resulting
# from the sum operation.
tensor2 = var0 + var1
@@ -67,10 +61,10 @@ class ExponentialMovingAverageTest(test_util.TensorFlowTestCase):
avg1 = ema.average(var1)
avg2 = ema.average(tensor2)
- self.assertFalse(avg0 in variables.trainable_variables())
- self.assertFalse(avg1 in variables.trainable_variables())
- self.assertFalse(avg2 in variables.trainable_variables())
- variables.initialize_all_variables().run()
+ self.assertFalse(avg0 in tf.trainable_variables())
+ self.assertFalse(avg1 in tf.trainable_variables())
+ self.assertFalse(avg2 in tf.trainable_variables())
+ tf.initialize_all_variables().run()
self.assertEqual("v0/ExponentialMovingAverage:0", avg0.name)
self.assertEqual("v1/ExponentialMovingAverage:0", avg1.name)
@@ -114,31 +108,55 @@ class ExponentialMovingAverageTest(test_util.TensorFlowTestCase):
def testAverageVariablesNoNumUpdates_Scalar(self):
with self.test_session():
- ema = moving_averages.ExponentialMovingAverage(0.25)
+ ema = tf.train.ExponentialMovingAverage(0.25)
self._CheckDecay(ema, actual_decay=0.25, dim=1)
def testAverageVariablesNoNumUpdates_Vector(self):
with self.test_session():
- ema = moving_averages.ExponentialMovingAverage(0.25)
+ ema = tf.train.ExponentialMovingAverage(0.25)
self._CheckDecay(ema, actual_decay=0.25, dim=5)
def testAverageVariablesNumUpdates_Scalar(self):
with self.test_session():
# With num_updates 1, the decay applied is 0.1818
- ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
+ ema = tf.train.ExponentialMovingAverage(0.25, num_updates=1)
self._CheckDecay(ema, actual_decay=0.181818, dim=1)
def testAverageVariablesNumUpdates_Vector(self):
with self.test_session():
# With num_updates 1, the decay applied is 0.1818
- ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
+ ema = tf.train.ExponentialMovingAverage(0.25, num_updates=1)
self._CheckDecay(ema, actual_decay=0.181818, dim=5)
+ def testAverageVariablesWithControlDeps(self):
+ with self.test_session() as sess:
+ v0 = tf.Variable(0, name="v0")
+ add_to_v0 = v0.assign_add(1)
+ v1 = tf.Variable([10.0], name="v1")
+ assign_to_v1 = v1.assign([20.0])
+ ema = tf.train.ExponentialMovingAverage(0.25)
+ with tf.control_dependencies([add_to_v0]):
+ ema_op = ema.apply([v1])
+ # the moving average of v1 should not have any control inputs
+ v1_avg = ema.average(v1)
+ self.assertEqual([], v1_avg.initializer.control_inputs)
+ self.assertEqual([], v1_avg.value().op.control_inputs)
+ self.assertEqual([], v1_avg.ref().op.control_inputs)
+ # We should be able to initialize v1_avg before v0.
+ sess.run(v1_avg.initializer)
+ sess.run(v0.initializer)
+ self.assertEqual([10.0], sess.run(v1_avg))
+ # running ema_op should add to v0 (in addition to updating v1_avg)
+ sess.run(assign_to_v1)
+ sess.run(ema_op)
+ self.assertEqual(1, sess.run(v0))
+ self.assertEqual([17.5], sess.run(v1_avg))
+
def testAverageVariablesNames(self):
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(30.0, name="v1")
+ v0 = tf.Variable(10.0, name="v0")
+ v1 = tf.Variable(30.0, name="v1")
tensor2 = v0 + v1
- ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
+ ema = tf.train.ExponentialMovingAverage(0.25, name="foo_avg")
self.assertEqual("v0/foo_avg", ema.average_name(v0))
self.assertEqual("v1/foo_avg", ema.average_name(v1))
self.assertEqual("add/foo_avg", ema.average_name(tensor2))
@@ -148,13 +166,13 @@ class ExponentialMovingAverageTest(test_util.TensorFlowTestCase):
self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
def testAverageVariablesDeviceAssignment(self):
- with ops.device("dev_v0"):
- v0 = variables.Variable(10.0, name="v0")
- with ops.device("dev_v1"):
- v1 = state_ops.variable_op(shape=[1], dtype=dtypes.float32, name="v1")
+ with tf.device("dev_v0"):
+ v0 = tf.Variable(10.0, name="v0")
+ with tf.device("dev_v1"):
+ v1 = state_ops.variable_op(shape=[1], dtype=tf.float32, name="v1")
tensor2 = v0 + v1
- ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
- with ops.device("default"):
+ ema = tf.train.ExponentialMovingAverage(0.25, name="foo_avg")
+ with tf.device("default"):
ema.apply([v0, v1, tensor2])
self.assertEqual("dev_v0", ema.average(v0).device)
self.assertEqual("dev_v1", ema.average(v1).device)
@@ -162,4 +180,4 @@ class ExponentialMovingAverageTest(test_util.TensorFlowTestCase):
if __name__ == "__main__":
- googletest.main()
+ tf.test.main()
diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py
index 3154527bb6..0b936a20a3 100644
--- a/tensorflow/python/util/compat.py
+++ b/tensorflow/python/util/compat.py
@@ -70,6 +70,21 @@ else:
as_str = as_text
+def as_str_any(value):
+ """Converts to `str` as `str(value)`, but use `as_str` for `bytes`.
+
+ Args:
+ value: A object that can be converted to `str`.
+
+ Returns:
+ A `str` object.
+ """
+ if isinstance(value, bytes):
+ return as_str(value)
+ else:
+ return str(value)
+
+
# Numpy 1.8 scalars don't inherit from numbers.Integral in Python 3, so we
# need to check them specifically. The same goes from Real and Complex.
integral_types = (numbers.Integral, np.integer)
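
The new helper mirrors str() but routes bytes through as_str; its behavior in a
nutshell (follows directly from the definition above):

    from tensorflow.python.util import compat

    print(compat.as_str_any(b"worker:0"))         # bytes are decoded: 'worker:0'
    print(compat.as_str_any(ValueError("boom")))  # everything else uses str(): 'boom'
    print(compat.as_str_any(42))                  # '42'
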
diff --git a/tensorflow/tensorboard/gulpfile.js b/tensorflow/tensorboard/gulpfile.js
index e01af56c9c..e6546297dd 100644
--- a/tensorflow/tensorboard/gulpfile.js
+++ b/tensorflow/tensorboard/gulpfile.js
@@ -64,26 +64,11 @@ gulp.task('compile.all', function() {
.pipe(ts(tsProject))
.on('error', onError);
return merge([
- // Send concatenated component code to build/component
- tsResult.js
- .pipe(isComponent)
- .pipe(concat('components.js'))
- .pipe(gulp.dest('build')),
-
// Duplicate all component code to live next to the ts file
// (makes polymer imports very clean)
tsResult.js
.pipe(isComponent)
- .pipe(gulp.dest('.')),
-
- tsResult.js
- .pipe(isApp)
- .pipe(gulp.dest('.')),
-
- // Create a unified defintions file at build/all.d.ts
- tsResult.dts
- .pipe(concat('all.d.ts'))
- .pipe(gulp.dest('build')),
+ .pipe(gulp.dest('.'))
]);
});
diff --git a/tensorflow/tensorboard/tests.html b/tensorflow/tensorboard/tests.html
deleted file mode 100644
index 31773f705c..0000000000
--- a/tensorflow/tensorboard/tests.html
+++ /dev/null
@@ -1,31 +0,0 @@
-<!DOCTYPE html>
-<html>
- <head>
- <title>Mocha</title>
- <meta http-equiv="Content-Type" content="text/html; charset=UTF-8">
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
- <link rel="stylesheet" href="node_modules/mocha/mocha.css" />
- </head>
- <body>
- <div id="mocha"></div>
- <script src="node_modules/chai/chai.js"></script>
- <script src="node_modules/mocha/mocha.js"></script>
- <script>mocha.setup('bdd')</script>
- <script>Polymer = function() {}
- // hack hack - can't get polymer to run in phantomjs, so mock it out
- </script>
- <script src="bower_components/d3/d3.js"></script>
- <script src="bower_components/svg-typewriter/svgtypewriter.js"></script>
- <script src="bower_components/plottable/plottable.js"></script>
- <script src="build/components.js"></script>
- <script src="build/test.js"></script>
- <script>
- if (window.mochaPhantomJS) {
- mochaPhantomJS.run();
- } else {
- mocha.run();
- }
- </script>
- </body>
-</html>
-
diff --git a/tools/bazel.rc.template b/tools/bazel.rc.template
new file mode 100644
index 0000000000..0a97daa4a8
--- /dev/null
+++ b/tools/bazel.rc.template
@@ -0,0 +1,4 @@
+build:cuda --crosstool_top=//third_party/gpus/crosstool
+
+build --force_python=py$PYTHON_MAJOR_VERSION
+build --python$PYTHON_MAJOR_VERSION_path=$PYTHON_BINARY
diff --git a/util/python/python_config.sh b/util/python/python_config.sh
index 27b20949e8..dae157766b 100755
--- a/util/python/python_config.sh
+++ b/util/python/python_config.sh
@@ -45,6 +45,12 @@ function setup_python {
exit 1
fi
+ local python_major_version=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; import sys; print(sys.version_info[0]);')
+ if [ "$python_major_version" == "" ]; then
+ echo -e "\n\nERROR: Problem getting python version. Is $PYTHON_BIN_PATH the correct python binary?"
+ exit 1
+ fi
+
local python_include=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; from distutils import sysconfig; print(sysconfig.get_python_inc());')
if [ "$python_include" == "" ]; then
echo -e "\n\nERROR: Problem getting python include path. Is distutils installed?"
@@ -70,6 +76,12 @@ function setup_python {
ln -s "${python_include}" util/python/python_include
ln -s "${python_lib}" util/python/python_lib
ln -s "${numpy_include}" third_party/py/numpy/numpy_include
+
+ # Write tools/bazel.rc
+ echo "# Autogenerated by configure: DO NOT EDIT" > tools/bazel.rc
+ sed -e "s/\$PYTHON_MAJOR_VERSION/$python_major_version/g" \
+ -e "s[\$PYTHON_BINARY[$PYTHON_BIN_PATH[g" \
+ tools/bazel.rc.template >> tools/bazel.rc
}
function check_python {