author    A. Unique TensorFlower <nobody@tensorflow.org>  2016-01-05 14:05:27 -0800
committer Vijay Vasudevan <vrv@google.com>  2016-01-05 14:05:27 -0800
commit    1c579361cd1e088dd5e05a394b1561a73e3667ba (patch)
tree      ec464b9ac18113dc052744b6714eebbc7c6cc34d
parent    208350a6092f9faa473daf8b6eb6a80e9f9518f1 (diff)
Added 'logging' import to control_flow_ops, which is used in the file but not imported.
Change: 110842260
-rw-r--r--  WORKSPACE  4
-rw-r--r--  eigen.BUILD  2
-rw-r--r--  tensorflow/core/BUILD  8
-rw-r--r--  tensorflow/core/client/tensor_c_api.cc  27
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc  26
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_device.cc  89
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc  42
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr.h  33
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc  120
-rw-r--r--  tensorflow/core/example/example.proto  19
-rw-r--r--  tensorflow/core/example/feature.proto  2
-rw-r--r--  tensorflow/core/framework/attr_value_util.cc  2
-rw-r--r--  tensorflow/core/framework/config.proto  18
-rw-r--r--  tensorflow/core/framework/function.cc  42
-rw-r--r--  tensorflow/core/framework/graph.proto  1
-rw-r--r--  tensorflow/core/framework/load_library.cc  76
-rw-r--r--  tensorflow/core/framework/op.cc  47
-rw-r--r--  tensorflow/core/framework/op.h  14
-rw-r--r--  tensorflow/core/framework/op_def_builder.cc  2
-rw-r--r--  tensorflow/core/framework/op_kernel.cc  18
-rw-r--r--  tensorflow/core/framework/op_kernel.h  7
-rw-r--r--  tensorflow/core/framework/rendezvous.cc  35
-rw-r--r--  tensorflow/core/framework/tensor_reference.h  11
-rw-r--r--  tensorflow/core/graph/algorithm.cc  12
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc  13
-rw-r--r--  tensorflow/core/graph/graph_constructor_test.cc  27
-rw-r--r--  tensorflow/core/kernels/adjust_contrast_op.cc  2
-rw-r--r--  tensorflow/core/kernels/cwise_op_erf.cc  23
-rw-r--r--  tensorflow/core/kernels/cwise_op_erfc.cc  23
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_erf.cu.cc  26
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_erfc.cu.cc  26
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_lgamma.cu.cc  26
-rw-r--r--  tensorflow/core/kernels/cwise_op_lgamma.cc  23
-rw-r--r--  tensorflow/core/kernels/cwise_ops.h  9
-rw-r--r--  tensorflow/core/kernels/example_parsing_ops.cc  102
-rw-r--r--  tensorflow/core/kernels/queue_base.cc  14
-rw-r--r--  tensorflow/core/kernels/reduction_ops_common.h  68
-rw-r--r--  tensorflow/core/kernels/reshape_op.h  3
-rw-r--r--  tensorflow/core/kernels/sparse_to_dense_op.cc  18
-rw-r--r--  tensorflow/core/kernels/summary_op.cc  12
-rw-r--r--  tensorflow/core/kernels/tile_ops.cc  11
-rw-r--r--  tensorflow/core/kernels/transpose_op.cc  55
-rw-r--r--  tensorflow/core/kernels/transpose_op.h  6
-rw-r--r--  tensorflow/core/kernels/transpose_op_gpu.cu.cc  10
-rw-r--r--  tensorflow/core/lib/core/errors.h  6
-rw-r--r--  tensorflow/core/lib/io/inputbuffer.cc  5
-rw-r--r--  tensorflow/core/lib/io/inputbuffer_test.cc  26
-rw-r--r--  tensorflow/core/lib/strings/regexp.h  33
-rw-r--r--  tensorflow/core/lib/strings/str_util.h  25
-rw-r--r--  tensorflow/core/ops/data_flow_ops.cc  1
-rw-r--r--  tensorflow/core/ops/function_ops.cc  70
-rw-r--r--  tensorflow/core/ops/math_ops.cc  18
-rw-r--r--  tensorflow/core/ops/ops.pbtxt  202
-rw-r--r--  tensorflow/core/ops/parsing_ops.cc  34
-rw-r--r--  tensorflow/core/ops/script_ops.cc  37
-rw-r--r--  tensorflow/core/ops/sparse_ops.cc  9
-rw-r--r--  tensorflow/core/ops/summary_ops.cc  6
-rw-r--r--  tensorflow/core/platform/env.cc  42
-rw-r--r--  tensorflow/core/platform/env_test.cc  3
-rw-r--r--  tensorflow/core/platform/load_library.cc  44
-rw-r--r--  tensorflow/core/platform/load_library.h  33
-rw-r--r--  tensorflow/core/platform/posix/env.cc  12
-rw-r--r--  tensorflow/core/platform/regexp.h  12
-rw-r--r--  tensorflow/core/platform/tracing.h  9
-rw-r--r--  tensorflow/core/public/env.h  28
-rw-r--r--  tensorflow/core/public/tensor_c_api.h  41
-rw-r--r--  tensorflow/core/public/version.h  2
-rw-r--r--  tensorflow/core/util/bcast.h  3
-rw-r--r--  tensorflow/examples/tutorials/mnist/BUILD  2
-rw-r--r--  tensorflow/examples/tutorials/mnist/input_data.py  42
-rw-r--r--  tensorflow/g3doc/api_docs/cc/ClassEnv.md  28
-rw-r--r--  tensorflow/g3doc/api_docs/cc/ClassEnvWrapper.md  28
-rw-r--r--  tensorflow/g3doc/api_docs/cc/ClassSession.md  2
-rw-r--r--  tensorflow/g3doc/api_docs/cc/StructTF_Buffer.md  24
-rw-r--r--  tensorflow/g3doc/api_docs/cc/index.md  2
-rw-r--r--  tensorflow/g3doc/api_docs/index.md  2
-rw-r--r--  tensorflow/g3doc/api_docs/python/client.md  3
-rw-r--r--  tensorflow/g3doc/api_docs/python/framework.md  31
-rw-r--r--  tensorflow/g3doc/api_docs/python/index.md  8
-rw-r--r--  tensorflow/g3doc/api_docs/python/math_ops.md  57
-rw-r--r--  tensorflow/g3doc/api_docs/python/script_ops.md  46
-rw-r--r--  tensorflow/g3doc/api_docs/python/sparse_ops.md  24
-rw-r--r--  tensorflow/g3doc/api_docs/python/state_ops.md  2
-rw-r--r--  tensorflow/g3doc/api_docs/python/train.md  219
-rw-r--r--  tensorflow/g3doc/extras/README.txt  3
-rw-r--r--  tensorflow/g3doc/how_tos/adding_an_op/index.md  6
-rw-r--r--  tensorflow/g3doc/how_tos/index.md  2
-rw-r--r--  tensorflow/g3doc/resources/index.md  7
-rw-r--r--  tensorflow/g3doc/resources/leftnav_files  1
-rw-r--r--  tensorflow/g3doc/resources/versions.md  143
-rw-r--r--  tensorflow/g3doc/tutorials/index.md  2
-rw-r--r--  tensorflow/models/embedding/word2vec_kernels.cc  2
-rw-r--r--  tensorflow/models/image/cifar10/BUILD  1
-rw-r--r--  tensorflow/models/image/cifar10/cifar10.py  151
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_input.py  156
-rw-r--r--  tensorflow/python/BUILD  59
-rw-r--r--  tensorflow/python/client/session.py  5
-rw-r--r--  tensorflow/python/client/tf_session.i  11
-rw-r--r--  tensorflow/python/framework/framework_lib.py  4
-rw-r--r--  tensorflow/python/framework/gen_docs_combined.py  6
-rw-r--r--  tensorflow/python/framework/importer_test.py  6
-rw-r--r--  tensorflow/python/framework/load_library.py  74
-rw-r--r--  tensorflow/python/framework/ops.py  33
-rw-r--r--  tensorflow/python/framework/python_op_gen.cc  92
-rw-r--r--  tensorflow/python/framework/python_op_gen.h  11
-rw-r--r--  tensorflow/python/framework/python_op_gen.i  24
-rw-r--r--  tensorflow/python/framework/test_util.py  2
-rw-r--r--  tensorflow/python/kernel_tests/concat_op_test.py  9
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py  7
-rw-r--r--  tensorflow/python/kernel_tests/cwise_ops_test.py  24
-rw-r--r--  tensorflow/python/kernel_tests/fifo_queue_test.py  27
-rw-r--r--  tensorflow/python/kernel_tests/gradient_checker_test.py  8
-rw-r--r--  tensorflow/python/kernel_tests/parsing_ops_test.py  77
-rw-r--r--  tensorflow/python/kernel_tests/py_func_test.py  84
-rw-r--r--  tensorflow/python/kernel_tests/reader_ops_test.py  13
-rw-r--r--  tensorflow/python/kernel_tests/reduction_ops_test.py  22
-rw-r--r--  tensorflow/python/kernel_tests/shape_ops_test.py  14
-rw-r--r--  tensorflow/python/kernel_tests/sparse_ops_test.py (renamed from tensorflow/python/ops/sparse_ops_test.py)  15
-rw-r--r--  tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py  22
-rw-r--r--  tensorflow/python/kernel_tests/transpose_op_test.py  4
-rw-r--r--  tensorflow/python/lib/core/py_func.cc  338
-rw-r--r--  tensorflow/python/lib/core/py_func.h  47
-rw-r--r--  tensorflow/python/lib/core/py_func.i  29
-rw-r--r--  tensorflow/python/ops/array_grad.py  18
-rw-r--r--  tensorflow/python/ops/array_ops.py  2
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py  8
-rw-r--r--  tensorflow/python/ops/gradients_test.py  16
-rw-r--r--  tensorflow/python/ops/math_grad.py  24
-rw-r--r--  tensorflow/python/ops/math_ops.py  57
-rw-r--r--  tensorflow/python/ops/nn.py  3
-rw-r--r--  tensorflow/python/ops/op_def_library.py  11
-rw-r--r--  tensorflow/python/ops/parsing_ops.py  63
-rw-r--r--  tensorflow/python/ops/script_ops.py  135
-rw-r--r--  tensorflow/python/ops/sparse_ops.py  37
-rw-r--r--  tensorflow/python/ops/standard_ops.py  1
-rw-r--r--  tensorflow/python/ops/summary_ops.py  4
-rw-r--r--  tensorflow/python/platform/default/_flags.py  2
-rw-r--r--  tensorflow/python/platform/default/flags_test.py  3
-rw-r--r--  tensorflow/python/tensorflow.i  3
-rw-r--r--  tensorflow/python/training/coordinator.py  94
-rw-r--r--  tensorflow/python/training/optimizer.py  2
-rw-r--r--  tensorflow/python/training/saver.py  3
-rw-r--r--  tensorflow/python/training/training.py  1
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_driver.cc  10
-rw-r--r--  tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts  8
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/common.ts  13
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts  12
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts  17
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts  347
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/render.ts  196
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts  19
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts  37
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts  4
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts  62
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts  38
-rw-r--r--  tensorflow/tensorboard/components/tf-graph-common/lib/template.ts  40
-rw-r--r--  tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html  2
-rw-r--r--  tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html  8
-rw-r--r--  tensorflow/tensorboard/components/tf-graph/tf-graph.html  4
-rw-r--r--  tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html  66
-rw-r--r--  tensorflow/tensorboard/dist/index.html  1
-rw-r--r--  tensorflow/tensorboard/dist/tf-tensorboard.html  2581
-rw-r--r--  tensorflow/tensorboard/gulpfile.js  12
-rw-r--r--  tensorflow/tensorboard/package.json  26
-rw-r--r--  tensorflow/tensorboard/tensorboard_handler.py  53
-rw-r--r--  third_party/eigen3/Eigen/Cholesky  2
-rw-r--r--  third_party/eigen3/Eigen/Core  2
-rw-r--r--  third_party/eigen3/Eigen/Eigenvalues  2
-rw-r--r--  third_party/eigen3/Eigen/LU  2
-rw-r--r--  third_party/eigen3/unsupported/Eigen/CXX11/Tensor  2
170 files changed, 5924 insertions(+), 2045 deletions(-)
diff --git a/WORKSPACE b/WORKSPACE
index 876919fd1e..99c0da542a 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -21,8 +21,8 @@ new_http_archive(
new_http_archive(
name = "eigen_archive",
- url = "https://bitbucket.org/eigen/eigen/get/3.3-beta1.tar.gz",
- sha256 = "2d6533e86ed6b54d30ae1d6c10808533b335d1c570c5e4c58ce2f03da99c134b",
+ url = "https://bitbucket.org/eigen/eigen/get/a0661a2.tar.gz",
+ sha256 = "d4d13995a0b3a2d80189f83d28647eb35819a478522149c15a761d91f53579b1",
build_file = "eigen.BUILD",
)
diff --git a/eigen.BUILD b/eigen.BUILD
index c33fe7186e..5c6127e6a9 100644
--- a/eigen.BUILD
+++ b/eigen.BUILD
@@ -1,6 +1,6 @@
package(default_visibility = ["//visibility:public"])
-archive_dir = "eigen-eigen-ce5a455b34c0"
+archive_dir = "eigen-eigen-a0661a2bb165"
cc_library(
name = "eigen",
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 33444cd45d..b9f253740e 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -111,6 +111,7 @@ cc_library(
"public/tensor_shape.h",
],
copts = tf_copts(),
+ linkopts = ["-ldl"],
visibility = [
":friends",
"//tensorflow:internal",
@@ -171,8 +172,12 @@ tf_cuda_library(
hdrs = glob([
"public/**/*.h",
"util/device_name_utils.h",
- ]),
+ ]) + [
+ "framework/op.h",
+ "framework/op_kernel.h",
+ ],
copts = tf_copts(),
+ linkopts = ["-ldl"],
visibility = ["//visibility:public"],
deps = [
":lib",
@@ -422,6 +427,7 @@ tf_gen_op_libs(
"no_op",
"parsing_ops",
"random_ops",
+ "script_ops",
"sendrecv_ops",
"sparse_ops",
"state_ops",
diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/core/client/tensor_c_api.cc
index 87bdc664fb..bc6c88f8c0 100644
--- a/tensorflow/core/client/tensor_c_api.cc
+++ b/tensorflow/core/client/tensor_c_api.cc
@@ -52,6 +52,11 @@ struct TF_Status {
Status status;
};
+struct TF_Library {
+ void* lib_handle;
+ TF_Buffer op_list;
+};
+
TF_Status* TF_NewStatus() { return new TF_Status; }
void TF_DeleteStatus(TF_Status* s) { delete s; }
@@ -304,6 +309,10 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, const TensorShape& shape) {
[](void*, size_t, void*) {}, nullptr);
}
+// Helpers for loading a TensorFlow plugin (a .so file).
+Status LoadLibrary(const char* library_filename, void** result,
+ const void** buf, size_t* len);
+
} // namespace tensorflow
extern "C" {
@@ -382,4 +391,22 @@ void TF_Run(TF_Session* s,
}
}
+const void* TF_BufferData(TF_Buffer* buffer) { return buffer->data; }
+
+size_t TF_BufferLength(TF_Buffer* buffer) { return buffer->length; }
+
+TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
+ TF_Library* lib_handle = new TF_Library;
+ status->status = tensorflow::LoadLibrary(
+ library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
+ &lib_handle->op_list.length);
+ if (!status->status.ok()) {
+ delete lib_handle;
+ return nullptr;
+ }
+ return lib_handle;
+}
+
+TF_Buffer TF_GetOpList(TF_Library* lib_handle) { return lib_handle->op_list; }
+
} // end extern "C"
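
A minimal sketch of how a client might drive the two new entry points, assuming the corresponding declarations land in tensor_c_api.h and using the pre-existing status helpers; the plugin path "my_ops.so" and the main() harness are illustrative, not part of this change:

#include <stdio.h>
#include "tensorflow/core/public/tensor_c_api.h"

int main() {
  TF_Status* status = TF_NewStatus();
  // Hypothetical plugin path; TF_LoadLibrary dlopens the .so and collects
  // the ops it registers.
  TF_Library* lib = TF_LoadLibrary("my_ops.so", status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "load failed: %s\n", TF_Message(status));
  } else {
    // The returned buffer holds a serialized OpList proto describing the
    // ops registered by the plugin.
    TF_Buffer op_list = TF_GetOpList(lib);
    printf("OpList: %zu bytes\n", TF_BufferLength(&op_list));
  }
  TF_DeleteStatus(status);
  return 0;
}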
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 94818e3938..26b2948166 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -454,6 +455,31 @@ Status DirectSession::CreateGraphs(
std::unique_ptr<FunctionLibraryDefinition> fdefs;
std::unique_ptr<Graph> graph;
GraphConstructorOptions opts;
+ if (options_.config.has_graph_options()) {
+ opts.optimizer_do_cse = !options_.config.graph_options()
+ .skip_common_subexpression_elimination();
+ } else {
+ opts.optimizer_do_cse = true;
+ }
+
+ if (opts.optimizer_do_cse) {
+ // Prevent CSE from eliminating nodes that will be required during
+ // RewriteGraphForExecution, below.
+ std::unordered_set<StringPiece, StringPiece::Hasher> no_cse_nodes;
+ for (const string& feed : feeds) {
+ no_cse_nodes.insert(ParseTensorName(feed).first);
+ }
+ for (const string& fetch : fetches) {
+ no_cse_nodes.insert(ParseTensorName(fetch).first);
+ }
+ for (const string& target_node : target_nodes) {
+ no_cse_nodes.insert(target_node);
+ }
+ opts.cse_consider_function = [no_cse_nodes](const Node* n) {
+ return n->type_string() != "Const" && !no_cse_nodes.count(n->name());
+ };
+ }
+
{
mutex_lock l(graph_def_lock_);
fdefs.reset(new FunctionLibraryDefinition(graph_def_.library()));
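
The exclusion set above keys on node names, so feed and fetch strings are first stripped of their output index. A sketch of the assumed tensor_id.h helper (the feed name "input:0" is illustrative):

// ParseTensorName splits a "node:output" string; only the node-name half
// goes into the no-CSE set.
TensorId id = ParseTensorName("input:0");
// id.first == "input", id.second == 0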
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 9b6c2de473..c0376e29fa 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -185,7 +185,7 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
LOG(ERROR) << "Failed to get StreamExecutor for device " << gpu_id_;
return;
}
- em_.reset(new EventMgr(executor));
+ em_.reset(new EventMgr(executor, options.config.gpu_options()));
if (FLAGS_brain_gpu_max_streams < 1) {
LOG(FATAL) << "Invalid value for brain_gpu_max_streams.";
@@ -262,9 +262,14 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
gpu::Stream* stream = gpu_device_context->stream();
const auto stream_id = gpu_device_context->stream_id();
- VLOG(1) << "GpuDevice::Compute " << op_kernel->name() << " op "
- << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream["
- << stream_id << "]";
+ const bool vlog_1 = VLOG_IS_ON(1);
+ const bool vlog_2 = vlog_1 && VLOG_IS_ON(2);
+
+ if (vlog_1) {
+ VLOG(1) << "GpuDevice::Compute " << op_kernel->name() << " op "
+ << op_kernel->def().op() << " on GPU" << gpu_id_ << " stream["
+ << stream_id << "]";
+ }
// NOTE(tucker): We need to discriminate between Eigen GPU
// operations and all others. If an operation is Eigen
@@ -292,7 +297,7 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
OP_REQUIRES(context, idc != nullptr,
errors::Internal("Input device context ", i,
" was not set properly."));
- if (VLOG_IS_ON(2)) {
+ if (vlog_2) {
const void* base;
size_t len;
if (context->has_input(i)) {
@@ -316,35 +321,36 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
}
gpu::cuda::ScopedActivateExecutorContext scoped_activation{
stream->parent(), gpu::cuda::MultiOpActivation::kYes};
- // Keep a copy of the inputs before Compute runs, in case they get
- // deleted. TODO(misard) this will be fixed when the tracking is
- // done right.
- EventMgr::TensorReferenceVector* tensor_refs = nullptr;
- if (!FLAGS_brain_gpu_sync_every_op) {
+
+ if (FLAGS_brain_gpu_sync_every_op) {
+ op_kernel->Compute(context);
+ if (context->status().ok()) {
+ // Note: GPUUtil::Sync() only syncs the default stream.
+ // We need to either sync the stream used by this op, or
+ // all streams. Given that this flag is typically used for
+ // debugging it makes more sense to sync all GPU activity.
+ context->SetStatus(GPUUtil::SyncAll(this));
+ }
+ } else {
+ // Keep a copy of the inputs before Compute runs, in case they get
+ // deleted. TODO(misard) this will be fixed when the tracking is
+ // done right.
+ EventMgr::TensorReferenceVector tensor_refs;
const int N_inputs = context->num_inputs();
- tensor_refs = new EventMgr::TensorReferenceVector;
- tensor_refs->reserve(N_inputs + context->num_outputs());
+ tensor_refs.reserve(N_inputs + context->num_outputs());
for (int ii = 0; ii < N_inputs; ++ii) {
if (context->has_input(ii)) {
if (IsRefType(context->input_dtype(ii))) {
Tensor in = context->mutable_input(ii, false);
- tensor_refs->push_back(TensorReference(in));
+ tensor_refs.push_back(TensorReference(in));
} else {
const Tensor& in = context->input(ii);
- tensor_refs->push_back(TensorReference(in));
+ tensor_refs.push_back(TensorReference(in));
}
}
}
- }
- op_kernel->Compute(context);
- if (context->status().ok()) {
- if (FLAGS_brain_gpu_sync_every_op) {
- // Note: GPUUtil::Sync() only syncs the default stream.
- // We need to either sync the stream used by this op, or
- // all streams. Given that this flag is typically used for
- // debugging it makes more sense to sync all GPU activity.
- context->SetStatus(GPUUtil::SyncAll(this));
- } else {
+ op_kernel->Compute(context);
+ if (context->status().ok()) {
// The GPU kernel has been queued, but may not complete for some
// time. As soon as this function completes, the caller will
// discard its refs on the inputs, outputs and any scratch
@@ -352,21 +358,19 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
// held until the kernel completes.
for (int ii = 0; ii < context->num_temps(); ++ii) {
Tensor* temp = context->temp(ii);
- VLOG(2) << "Saving ref to temp Tensor @ " << DMAHelper::base(temp);
- tensor_refs->push_back(TensorReference(*temp));
+ if (vlog_2) {
+ VLOG(2) << "Saving ref to temp Tensor @ " << DMAHelper::base(temp);
+ }
+ tensor_refs.push_back(TensorReference(*temp));
}
for (int ii = 0; ii < context->num_outputs(); ++ii) {
Tensor* temp = context->mutable_output(ii);
if (nullptr != temp) {
- tensor_refs->push_back(TensorReference(*temp));
+ tensor_refs.push_back(TensorReference(*temp));
}
}
em_->ThenDeleteTensors(stream, tensor_refs);
}
- } else {
- if (!FLAGS_brain_gpu_sync_every_op) {
- delete tensor_refs;
- }
}
}
}
@@ -431,28 +435,29 @@ namespace {
class ConcretePerOpGpuDevice : public PerOpGpuDevice {
public:
explicit ConcretePerOpGpuDevice(gpu::Stream* stream,
- EigenAllocator* allocator)
- : device_(stream, allocator), allocator_(allocator) {}
- ~ConcretePerOpGpuDevice() { delete allocator_; }
+ Allocator* base_allocator,
+ ::tensorflow::EventMgr* em)
+ : allocator_(stream, base_allocator, em), device_(stream, &allocator_) {}
const Eigen::GpuDevice& device() const override { return device_; }
private:
+ EigenAllocator allocator_;
Eigen::GpuDevice device_;
- EigenAllocator* allocator_;
};
#else
class ConcretePerOpGpuDevice : public PerOpGpuDevice {
public:
- explicit ConcretePerOpGpuDevice(EigenCudaStreamDevice* stream_device)
- : device_(stream_device), stream_device_(stream_device) {}
- ~ConcretePerOpGpuDevice() { delete stream_device_; }
+ explicit ConcretePerOpGpuDevice(const cudaStream_t* cuda_stream, int gpu_id,
+ Allocator* base_allocator)
+ : stream_device_(cuda_stream, gpu_id, base_allocator),
+ device_(&stream_device_) {}
const Eigen::GpuDevice& device() const override { return device_; }
private:
+ EigenCudaStreamDevice stream_device_;
Eigen::GpuDevice device_;
- EigenCudaStreamDevice* stream_device_;
};
#endif
} // namespace
@@ -460,13 +465,11 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
const PerOpGpuDevice* BaseGPUDevice::NewDevice(int stream_id,
Allocator* allocator) {
#if defined(__GCUDACC__) || defined(__GCUDACC_HOST__)
- auto ea = new EigenAllocator(streams_[stream_id], allocator, em_.get());
- return new ConcretePerOpGpuDevice(streams_[stream_id], ea);
+ return new ConcretePerOpGpuDevice(streams_[stream_id], allocator, em_.get());
#else
const cudaStream_t* cuda_stream = reinterpret_cast<const cudaStream_t*>(
streams_[stream_id]->implementation()->CudaStreamMemberHack());
- auto es = new EigenCudaStreamDevice(cuda_stream, gpu_id_, allocator);
- return new ConcretePerOpGpuDevice(es);
+ return new ConcretePerOpGpuDevice(cuda_stream, gpu_id_, allocator);
#endif
}
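
The restructured branch hands a TensorReference for every live input, temp, and output to the EventMgr, because Compute() returns as soon as the kernel is queued rather than when it finishes. A sketch of the ref-counting contract this relies on:

// A TensorReference takes its own ref on the tensor's buffer, so the
// buffer outlives the caller's Tensor; the EventMgr calls Unref() once
// the stream's pending events have completed.
Tensor t(DT_FLOAT, TensorShape({1024}));
TensorReference ref(t);  // buffer now holds an extra ref
// ... t may go out of scope here; the buffer stays alive ...
ref.Unref();             // buffer is freed once no refs remain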
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
index 962848ad17..6dd7c3c235 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
@@ -17,13 +17,20 @@ limitations under the License.
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/stream.h"
+#include "tensorflow/core/framework/config.pb.h"
namespace gpu = ::perftools::gputools;
namespace tensorflow {
-EventMgr::EventMgr(gpu::StreamExecutor* se)
+EventMgr::EventMgr(gpu::StreamExecutor* se, const GPUOptions& gpu_options)
: exec_(se),
+ deferred_bytes_threshold_(gpu_options.deferred_deletion_bytes()
+ ? gpu_options.deferred_deletion_bytes()
+ : 8 * 1048576),
+ accumulated_stream_(nullptr),
+ accumulated_tensors_(new TensorReferenceVector),
+ accumulated_tensor_bytes_(0),
// threadpool_ has 1 thread for the polling loop, and one to execute
// event callback functions. Maybe we should have more?
threadpool_(Env::Default(), "GPU_Event_Manager", 2) {
@@ -39,6 +46,10 @@ EventMgr::~EventMgr() {
for (auto& e : free_events_) {
delete e;
}
+ for (auto& t : *(accumulated_tensors_)) {
+ t.Unref();
+ }
+ delete accumulated_tensors_;
while (!used_events_.empty()) {
InUse* ue = &used_events_[0];
delete ue->event;
@@ -51,6 +62,35 @@ EventMgr::~EventMgr() {
}
}
+void EventMgr::ThenDeleteTensors(perftools::gputools::Stream* stream,
+ const TensorReferenceVector& tensors) {
+ mutex_lock l(mu_);
+ // TODO(jeff): We currently keep one accumulated_tensors_ object.
+ // If we start to use multiple streams heavily, we might want to keep
+ // separate vectors/byte counters per stream
+ if (!accumulated_tensors_->empty() && stream != accumulated_stream_) {
+ FlushAccumulatedTensors();
+ }
+ accumulated_stream_ = stream;
+ for (auto t : tensors) {
+ // accumulated_tensors_ takes over ownership of the reference to "t"
+ accumulated_tensors_->push_back(t);
+ accumulated_tensor_bytes_ += t.TotalBytes();
+ }
+ if (accumulated_tensor_bytes_ >= deferred_bytes_threshold_) {
+ FlushAccumulatedTensors();
+ }
+}
+
+void EventMgr::FlushAccumulatedTensors() {
+ DCHECK(!accumulated_tensors_->empty());
+ DCHECK(accumulated_stream_ != nullptr);
+ QueueTensors(accumulated_stream_, accumulated_tensors_);
+ accumulated_tensors_ = new TensorReferenceVector;
+ accumulated_tensor_bytes_ = 0;
+ accumulated_stream_ = nullptr;
+}
+
// This polling loop runs at a relatively low frequency. Most calls to
// PollEvents() should come directly from Compute() via
// ThenDeleteTensors(). This function's purpose is to ensure that
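
With the default threshold of 8 MB (8 * 1048576 bytes, used when deferred_deletion_bytes is 0), callers see the following batching behavior; em, stream, and some_tensor are as in the tests below, with some_tensor purely illustrative:

EventMgr::TensorReferenceVector v;
v.push_back(TensorReference(some_tensor));
// References are buffered inside the EventMgr; an event is queued only
// once the accumulated TotalBytes() crosses the threshold or a different
// stream is passed in, which bounds the number of driver interactions.
em.ThenDeleteTensors(stream, v);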
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
index 3faee71614..09d785d792 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
@@ -37,29 +37,24 @@ class StreamExecutor;
namespace tensorflow {
+class GPUOptions;
+
// An object to keep track of pending Events in the StreamExecutor streams
// and associated Tensors that cannot safely be deleted until the associated
// Events are recorded.
class EventMgr {
public:
- explicit EventMgr(perftools::gputools::StreamExecutor* se);
+ EventMgr(perftools::gputools::StreamExecutor* se,
+ const GPUOptions& gpu_options);
~EventMgr();
typedef gtl::InlinedVector<TensorReference, 4> TensorReferenceVector;
- // Takes ownership of *tensors and deletes it as soon as all events
- // currently enqueued on *stream have completed.
- inline void ThenDeleteTensors(perftools::gputools::Stream* stream,
- TensorReferenceVector* tensors) {
- ToFreeVector to_free;
- {
- mutex_lock l(mu_);
- QueueTensors(stream, tensors);
- PollEvents(false, &to_free);
- }
- FreeMemory(to_free);
- }
+ // Releases the references on the elements of "tensors" as soon as
+ // all events currently enqueued on "stream" have completed.
+ void ThenDeleteTensors(perftools::gputools::Stream* stream,
+ const TensorReferenceVector& tensors);
struct BufRec {
Allocator* alloc;
@@ -92,8 +87,11 @@ class EventMgr {
private:
friend class TEST_EventMgrHelper;
+ perftools::gputools::StreamExecutor* const exec_;
+ const int64 deferred_bytes_threshold_;
mutex mu_;
- perftools::gputools::StreamExecutor* exec_;
+
+ void FlushAccumulatedTensors() EXCLUSIVE_LOCKS_REQUIRED(mu_);
struct InUse {
perftools::gputools::Event* event;
@@ -122,7 +120,6 @@ class EventMgr {
// Tensors and/or a BufRec to be deleted only after the Event
// records.
void QueueInUse(perftools::gputools::Stream* stream, InUse in_use)
-
EXCLUSIVE_LOCKS_REQUIRED(mu_);
void QueueTensors(perftools::gputools::Stream* stream,
@@ -156,6 +153,12 @@ class EventMgr {
// A stack of unused events
std::vector<perftools::gputools::Event*> free_events_ GUARDED_BY(mu_);
+ // Buffered list of tensors waiting to have an event queued for deletion
+ perftools::gputools::Stream* accumulated_stream_ GUARDED_BY(mu_);
+ TensorReferenceVector* accumulated_tensors_ GUARDED_BY(mu_);
+ // Sum of the TotalBytes() of the tensors in "accumulated_tensors_"
+ int64 accumulated_tensor_bytes_ GUARDED_BY(mu_);
+
// A FIFO queue of InUse events and associated tensors.
std::deque<InUse> used_events_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
index 910093a069..57c1554678 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
@@ -17,10 +17,12 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
+#include <atomic>
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include <gtest/gtest.h>
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/framework/config.pb.h"
namespace gpu = ::perftools::gputools;
@@ -59,11 +61,32 @@ class TEST_EventMgrHelper {
EventMgr* em_;
};
+static std::atomic_int_fast64_t live_tensor_bytes(0);
+
+// A TensorBuffer that counts live memory usage for testing
+class TestTensorBuffer : public TensorBuffer {
+ public:
+ TestTensorBuffer(size_t bytes) : bytes_(bytes) {
+ live_tensor_bytes += bytes_;
+ }
+ ~TestTensorBuffer() { live_tensor_bytes -= bytes_; }
+
+ size_t size() const override { return bytes_; }
+
+ // Not used in this test
+ void* data() const override { return nullptr; }
+ TensorBuffer* root_buffer() override { return nullptr; }
+ void FillAllocationDescription(AllocationDescription* arg) const override {}
+
+ private:
+ size_t bytes_;
+};
+
namespace {
TEST(EventMgr, Empty) {
auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
- EventMgr em(stream_exec);
+ EventMgr em(stream_exec, GPUOptions());
TEST_EventMgrHelper th(&em);
EXPECT_EQ(0, th.queue_size());
EXPECT_EQ(0, th.free_size());
@@ -74,7 +97,7 @@ TEST(EventMgr, Empty) {
// the max simultaneously pending, we should not allocate any more.
TEST(EventMgr, DelayedPolling) {
auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
- EventMgr em(stream_exec);
+ EventMgr em(stream_exec, GPUOptions());
TEST_EventMgrHelper th(&em);
EXPECT_EQ(0, th.queue_size());
EventMgr::TensorReferenceVector* v = nullptr;
@@ -103,22 +126,87 @@ TEST(EventMgr, DelayedPolling) {
}
}
-// Immediate polling should require only one event to be allocated.
-TEST(EventMgr, ImmediatePolling) {
+static void AddTensorReference(EventMgr::TensorReferenceVector* v, int64 size) {
+ TestTensorBuffer* buf = new TestTensorBuffer(size);
+ v->push_back(TensorReference(buf));
+ buf->Unref();
+}
+
+TEST(EventMgr, FlushLargeTensorImmediately) {
auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
- EventMgr em(stream_exec);
+ EventMgr em(stream_exec, GPUOptions());
TEST_EventMgrHelper th(&em);
- EXPECT_EQ(0, th.queue_size());
- EXPECT_EQ(0, th.free_size());
- EventMgr::TensorReferenceVector* v = nullptr;
+ EXPECT_EQ(0, live_tensor_bytes);
std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
CHECK(stream.get());
stream->Init();
for (int i = 0; i < 5; ++i) {
- v = new EventMgr::TensorReferenceVector;
+ EventMgr::TensorReferenceVector v;
+ AddTensorReference(&v, 100 * 1048576);
em.ThenDeleteTensors(stream.get(), v);
- EXPECT_EQ(0, th.queue_size());
- EXPECT_EQ(1, th.free_size());
+ th.PollEvents(false); // Ensure things get registered to be freed by Poll
+ EXPECT_EQ(0, live_tensor_bytes);
+ }
+}
+
+TEST(EventMgr, ManySmallTensorsFlushedImmediately) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec, GPUOptions());
+ TEST_EventMgrHelper th(&em);
+ EXPECT_EQ(0, live_tensor_bytes);
+ std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ CHECK(stream.get());
+ stream->Init();
+ for (int i = 0; i < 5; ++i) {
+ EventMgr::TensorReferenceVector v;
+ for (int i = 0; i < 1000; i++) {
+ AddTensorReference(&v, 100 * 1024);
+ }
+ em.ThenDeleteTensors(stream.get(), v);
+ th.PollEvents(false); // Ensure things get registered to be freed by Poll
+ EXPECT_EQ(0, live_tensor_bytes);
+ }
+}
+
+TEST(EventMgr, StreamSwitchingFlushesImmediately) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec, GPUOptions());
+ TEST_EventMgrHelper th(&em);
+ EXPECT_EQ(0, live_tensor_bytes);
+ std::unique_ptr<gpu::Stream> stream1(new gpu::Stream(stream_exec));
+ std::unique_ptr<gpu::Stream> stream2(new gpu::Stream(stream_exec));
+ stream1->Init();
+ stream2->Init();
+ EventMgr::TensorReferenceVector v1;
+ AddTensorReference(&v1, 1024);
+ em.ThenDeleteTensors(stream1.get(), v1);
+
+ EventMgr::TensorReferenceVector v2;
+ AddTensorReference(&v2, 1024);
+ int64 initial_live_bytes = live_tensor_bytes;
+ em.ThenDeleteTensors(stream2.get(), v2);
+ th.PollEvents(false); // Ensure things get registered to be freed by Poll
+ // Different stream should cause first tensor to get deleted
+ EXPECT_GT(initial_live_bytes, live_tensor_bytes);
+}
+
+TEST(EventMgr, ManySmallTensorsSeperateCallsFlushed) {
+ auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
+ EventMgr em(stream_exec, GPUOptions());
+ TEST_EventMgrHelper th(&em);
+ EXPECT_EQ(0, live_tensor_bytes);
+ std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ CHECK(stream.get());
+ stream->Init();
+ for (int i = 0; i < 5; ++i) {
+ for (int i = 0; i < 1000; i++) {
+ EventMgr::TensorReferenceVector v;
+ AddTensorReference(&v, 100 * 1024);
+ em.ThenDeleteTensors(stream.get(), v);
+ }
+ th.PollEvents(false); // Ensure things get registered to be freed by Poll
+ // Some of the tensors at least should be flushed
+ EXPECT_GT(1000 * 100 * 1024, live_tensor_bytes);
}
}
@@ -126,16 +214,15 @@ TEST(EventMgr, ImmediatePolling) {
// should clear the queue.
TEST(EventMgr, LongDelayedPolling) {
auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
- EventMgr em(stream_exec);
+ EventMgr em(stream_exec, GPUOptions());
TEST_EventMgrHelper th(&em);
EXPECT_EQ(0, th.queue_size());
EXPECT_EQ(0, th.free_size());
- EventMgr::TensorReferenceVector* v = nullptr;
std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
CHECK(stream.get());
stream->Init();
for (int i = 0; i < 5; ++i) {
- v = new EventMgr::TensorReferenceVector;
+ EventMgr::TensorReferenceVector* v = new EventMgr::TensorReferenceVector;
th.QueueTensors(stream.get(), v);
EXPECT_EQ(1 + i, th.queue_size());
EXPECT_EQ(0, th.free_size());
@@ -149,16 +236,15 @@ TEST(EventMgr, LongDelayedPolling) {
// down gracefully.
TEST(EventMgr, NonEmptyShutdown) {
auto stream_exec = GPUMachineManager()->ExecutorForDevice(0).ValueOrDie();
- EventMgr em(stream_exec);
+ EventMgr em(stream_exec, GPUOptions());
TEST_EventMgrHelper th(&em);
EXPECT_EQ(0, th.queue_size());
EXPECT_EQ(0, th.free_size());
- EventMgr::TensorReferenceVector* v = nullptr;
std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
CHECK(stream.get());
stream->Init();
for (int i = 0; i < 5; ++i) {
- v = new EventMgr::TensorReferenceVector;
+ EventMgr::TensorReferenceVector* v = new EventMgr::TensorReferenceVector;
th.QueueTensors(stream.get(), v);
EXPECT_EQ(1 + i, th.queue_size());
EXPECT_EQ(0, th.free_size());
diff --git a/tensorflow/core/example/example.proto b/tensorflow/core/example/example.proto
index f4d946dcf0..d2e9f24563 100644
--- a/tensorflow/core/example/example.proto
+++ b/tensorflow/core/example/example.proto
@@ -4,6 +4,8 @@ syntax = "proto3";
import "tensorflow/core/example/feature.proto";
// option cc_enable_arenas = true;
+option java_multiple_files = true;
+option java_package = "org.tensorflow.example";
package tensorflow;
@@ -163,12 +165,13 @@ message Example {
// an empty list (zero length).
// - If a FeatureList L exists, it may be empty (zero length).
// - If a FeatureList L is non-empty, all features within the FeatureList
-// must have data type T, and all features within the FeatureList must
-// have the same size.
+// must have data type T.
+// - If a FeatureList L is non-empty, it is up to the parser configuration
+// to determine if all features within the FeatureList must
+// have the same size. The same holds for this FeatureList across multiple
+// examples.
// - If a FeatureList L exists in one example with data type T,
// it must be of type T in all other examples when present.
-// - If a FeatureList L exists in one example having features' sizes all S,
-// these sizes must be S in all other examples when present.
//
// Examples of conformant and non-conformant examples' FeatureLists:
//
@@ -186,7 +189,8 @@ message Example {
// feature: { int64_list: { value: [ 5 ] } } }
// } }
//
-// Non-conformant FeatureLists (mismatched sizes):
+// Conditionally conformant FeatureLists, the parser configuration determines
+// if the feature sizes must match:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
@@ -244,7 +248,8 @@ message Example {
// feature: { int64_list: { value: [ 2 ] } } }
// } }
//
-// Non-conformant pair of SequenceExample (mismatched sizes)
+// Conditionally conformant pair of SequenceExample; the parser configuration
+// determines if the feature sizes must match:
// feature_lists: { feature_list: {
// key: "movie_ratings"
// value: { feature: { float_list: { value: [ 4.5 ] } }
@@ -253,7 +258,7 @@ message Example {
// and:
// feature_lists: { feature_list: {
// key: "movie_ratings"
-// value: { feature: { float_list: { value: [ 4.0, 5.0 ] } }
+// value: { feature: { float_list: { value: [ 4.0 ] } }
// feature: { float_list: { value: [ 5.0, 3.0 ] } }
// } }
diff --git a/tensorflow/core/example/feature.proto b/tensorflow/core/example/feature.proto
index 52d5fac441..130e142503 100644
--- a/tensorflow/core/example/feature.proto
+++ b/tensorflow/core/example/feature.proto
@@ -55,6 +55,8 @@
syntax = "proto3";
// option cc_enable_arenas = true;
+option java_multiple_files = true;
+option java_package = "org.tensorflow.example";
package tensorflow;
diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc
index ccdcf35b34..004f65fe62 100644
--- a/tensorflow/core/framework/attr_value_util.cc
+++ b/tensorflow/core/framework/attr_value_util.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/regexp.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/platform/regexp.h"
namespace tensorflow {
diff --git a/tensorflow/core/framework/config.proto b/tensorflow/core/framework/config.proto
index 3f5d01fb8d..167dd632ac 100644
--- a/tensorflow/core/framework/config.proto
+++ b/tensorflow/core/framework/config.proto
@@ -19,6 +19,21 @@ message GPUOptions {
// "BFC": A "Best-fit with coalescing" algorithm, simplified from a
// version of dlmalloc.
string allocator_type = 2;
+
+ // Delay deletion of up to this many bytes to reduce the number of
+ // interactions with gpu driver code. If 0, the system chooses
+ // a reasonable default (several MBs).
+ int64 deferred_deletion_bytes = 3;
+};
+
+message GraphOptions {
+ // If true, do not attempt to optimize the graph using common
+ // subexpression elimination.
+ bool skip_common_subexpression_elimination = 1;
+
+ // If true, use control flow to schedule the activation of Recv nodes.
+ // (Currently ignored.)
+ bool enable_recv_scheduling = 2;
};
// Session configuration parameters.
@@ -75,4 +90,7 @@ message ConfigProto {
// Whether device placements should be logged.
bool log_device_placement = 8;
+
+ // Options that apply to all graphs.
+ GraphOptions graph_options = 10;
};
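
A minimal sketch of setting the new fields from C++, assuming the generated proto accessors and that ConfigProto carries the GPUOptions message under its gpu_options field:

#include "tensorflow/core/framework/config.pb.h"

tensorflow::ConfigProto config;
// 0 means "use the default (several MBs)"; a positive value overrides it.
config.mutable_gpu_options()->set_deferred_deletion_bytes(16 * 1048576);
// Opt out of common subexpression elimination for all graphs.
config.mutable_graph_options()
    ->set_skip_common_subexpression_elimination(true);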
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 14ffeca6e4..fc9f1d0324 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -25,48 +25,6 @@ limitations under the License.
namespace tensorflow {
-REGISTER_OP("_Arg")
- .Output("output: T")
- .Attr("T: type")
- .Attr("index: int >= 0")
- .Doc(R"doc(
-A graph node which represents an argument to a function.
-
-output: The argument.
-index: This argument is the index-th argument of the function.
-)doc");
-
-REGISTER_OP("_Retval")
- .Input("input: T")
- .Attr("T: type")
- .Attr("index: int >= 0")
- .Doc(R"doc(
-A graph node which represents a return value of a function.
-
-input: The return value.
-index: This return value is the index-th return value of the function.
-)doc");
-
-REGISTER_OP("_ListToArray")
- .Input("input: Tin")
- .Output("output: N * T")
- .Attr("Tin: list(type)")
- .Attr("T: type")
- .Attr("N: int >= 1")
- .Doc(R"doc(
-Converts a list of tensors to an array of tensors.
-)doc");
-
-REGISTER_OP("_ArrayToList")
- .Input("input: N * T")
- .Output("output: out_types")
- .Attr("T: type")
- .Attr("N: int >= 1")
- .Attr("out_types: list(type)")
- .Doc(R"doc(
-Converts an array of tensors to a list of tensors.
-)doc");
-
namespace {
// Extracts the actual type from "attr_values" based on its definition
diff --git a/tensorflow/core/framework/graph.proto b/tensorflow/core/framework/graph.proto
index 8bf4fd5e5f..d18dd81912 100644
--- a/tensorflow/core/framework/graph.proto
+++ b/tensorflow/core/framework/graph.proto
@@ -21,6 +21,7 @@ message GraphDef {
// 0. Graphs created before GraphDef versioning
// 1. First real version (2dec2015)
// 2. adjust_contrast only takes float, doesn't perform clamping (11dec2015)
+ // 3. Remove TileGrad, since it was equivalent to reduce_sum (30dec2015)
//
// The GraphDef version is distinct from the TensorFlow version.
// Each released version of TensorFlow will support a range of
diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc
new file mode 100644
index 0000000000..0d6b8563b0
--- /dev/null
+++ b/tensorflow/core/framework/load_library.cc
@@ -0,0 +1,76 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <dlfcn.h>
+#include <memory>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+
+namespace {
+
+template <typename R, typename... Args>
+Status GetSymbolFromLibrary(void* handle, const char* symbol_name,
+ R (**symbol)(Args...)) {
+ Env* env = Env::Default();
+ void* symbol_ptr;
+ Status status = env->GetSymbolFromLibrary(handle, symbol_name, &symbol_ptr);
+ *symbol = reinterpret_cast<R (*)(Args...)>(symbol_ptr);
+ return status;
+}
+
+} // namespace
+
+// Load a dynamic library and register the ops and kernels defined in that file.
+// Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList" to be
+// defined in the library.
+// On success, returns the handle to library in result, copies the serialized
+// OpList of OpDefs registered in the library to *buf and the length to *len,
+// and returns OK from the function. Otherwise return nullptr in result
+// and an error status from the function, leaving buf and len untouched.
+Status LoadLibrary(const char* library_filename, void** result,
+ const void** buf, size_t* len) {
+ Env* env = Env::Default();
+ void* lib;
+ TF_RETURN_IF_ERROR(env->LoadLibrary(library_filename, &lib));
+
+ typedef void (*FuncType)(void*);
+ FuncType RegisterOps, RegisterKernels, GetOpList;
+ TF_RETURN_IF_ERROR(GetSymbolFromLibrary(lib, "RegisterOps", &RegisterOps));
+ TF_RETURN_IF_ERROR(
+ GetSymbolFromLibrary(lib, "RegisterKernels", &RegisterKernels));
+ TF_RETURN_IF_ERROR(GetSymbolFromLibrary(lib, "GetOpList", &GetOpList));
+
+ *buf = nullptr;
+ *len = 0;
+
+ RegisterOps(OpRegistry::Global());
+ RegisterKernels(GlobalKernelRegistry());
+ string str;
+ GetOpList(&str);
+ char* str_buf = reinterpret_cast<char*>(operator new(str.length()));
+ strncpy(str_buf, str.data(), str.length());
+ *buf = str_buf;
+ *len = str.length();
+
+ *result = lib;
+ return Status::OK();
+}
+
+} // namespace tensorflow
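
Under the contract above, a plugin only needs ordinary registrations; the extern "C" symbols RegisterOps, RegisterKernels, and GetOpList come from the framework code changed below (op.cc and op_kernel.cc) when the .so links against it. A sketch with a hypothetical op:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

REGISTER_OP("MyHypotheticalOp")
    .Input("in: float")
    .Output("out: float")
    .Doc("Example op compiled into a dynamically loaded library.");

class MyHypotheticalOp : public tensorflow::OpKernel {
 public:
  explicit MyHypotheticalOp(tensorflow::OpKernelConstruction* ctx)
      : OpKernel(ctx) {}
  void Compute(tensorflow::OpKernelContext* ctx) override {
    ctx->set_output(0, ctx->input(0));  // identity, for illustration
  }
};

REGISTER_KERNEL_BUILDER(
    Name("MyHypotheticalOp").Device(tensorflow::DEVICE_CPU),
    MyHypotheticalOp);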
diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc
index bd246cda80..06e913ec4e 100644
--- a/tensorflow/core/framework/op.cc
+++ b/tensorflow/core/framework/op.cc
@@ -33,14 +33,13 @@ OpRegistryInterface::~OpRegistryInterface() {}
OpRegistry::OpRegistry() : initialized_(false) {}
-void OpRegistry::Register(std::function<OpDef(void)> func) {
+void OpRegistry::Register(const OpDef& op_def) {
mutex_lock lock(mu_);
if (initialized_) {
- OpDef def = func();
- TF_QCHECK_OK(RegisterAlreadyLocked(def)) << "Attempting to register: "
- << SummarizeOpDef(def);
+ TF_QCHECK_OK(RegisterAlreadyLocked(op_def)) << "Attempting to register: "
+ << SummarizeOpDef(op_def);
} else {
- deferred_.push_back(func);
+ deferred_.push_back(op_def);
}
}
@@ -75,6 +74,14 @@ const OpDef* OpRegistry::LookUp(const string& op_type_name,
return op_def;
}
+void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) {
+ mutex_lock lock(mu_);
+ CallDeferred();
+ for (auto p : registry_) {
+ op_defs->push_back(*p.second);
+ }
+}
+
void OpRegistry::Export(bool include_internal, OpList* ops) const {
mutex_lock lock(mu_);
CallDeferred();
@@ -107,10 +114,9 @@ string OpRegistry::DebugString(bool include_internal) const {
bool OpRegistry::CallDeferred() const {
if (initialized_) return false;
initialized_ = true;
- for (const auto& fn : deferred_) {
- OpDef def = fn();
- TF_QCHECK_OK(RegisterAlreadyLocked(def)) << "Attempting to register: "
- << SummarizeOpDef(def);
+ for (const auto& op_def : deferred_) {
+ TF_QCHECK_OK(RegisterAlreadyLocked(op_def)) << "Attempting to register: "
+ << SummarizeOpDef(op_def);
}
deferred_.clear();
return true;
@@ -136,12 +142,25 @@ OpRegistry* OpRegistry::Global() {
namespace register_op {
OpDefBuilderReceiver::OpDefBuilderReceiver(const OpDefBuilder& builder) {
- OpRegistry::Global()->Register([builder]() {
- OpDef op_def;
- TF_QCHECK_OK(builder.Finalize(&op_def));
- return op_def;
- });
+ OpDef op_def;
+ builder.Finalize(&op_def);
+ OpRegistry::Global()->Register(op_def);
}
} // namespace register_op
+extern "C" void RegisterOps(void* registry_ptr) {
+ OpRegistry* op_registry = static_cast<OpRegistry*>(registry_ptr);
+ std::vector<OpDef> op_defs;
+ OpRegistry::Global()->GetRegisteredOps(&op_defs);
+ for (auto const& op_def : op_defs) {
+ op_registry->Register(op_def);
+ }
+}
+
+extern "C" void GetOpList(void* str) {
+ OpList op_list;
+ OpRegistry::Global()->Export(true, &op_list);
+ op_list.SerializeToString(reinterpret_cast<string*>(str));
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index 6e20a0fb4a..2a6a34f28e 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -16,7 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_FRAMEWORK_OP_H_
#define TENSORFLOW_FRAMEWORK_OP_H_
-#include <functional>
#include <unordered_map>
#include "tensorflow/core/framework/op_def.pb.h"
@@ -65,7 +64,7 @@ class OpRegistry : public OpRegistryInterface {
// we defer calling func() until the first call to LookUp() or
// Export() (if one of those has already been called, func() is
// called immediately).
- void Register(std::function<OpDef(void)> func);
+ void Register(const OpDef& op_def);
const OpDef* LookUp(const string& op_type_name,
Status* status) const override;
@@ -81,6 +80,9 @@ class OpRegistry : public OpRegistryInterface {
// A singleton available at startup.
static OpRegistry* Global();
+ // Get all registered ops.
+ void GetRegisteredOps(std::vector<OpDef>* op_defs);
+
private:
// Ensures that all the functions in deferred_ get called, their OpDef's
// registered, and returns with deferred_ empty. Returns true the first
@@ -94,11 +96,17 @@ class OpRegistry : public OpRegistryInterface {
mutable mutex mu_;
// Functions in deferred_ may only be called with mu_ held.
- mutable std::vector<std::function<OpDef(void)>> deferred_ GUARDED_BY(mu_);
+ mutable std::vector<OpDef> deferred_ GUARDED_BY(mu_);
mutable std::unordered_map<string, OpDef*> registry_ GUARDED_BY(mu_);
mutable bool initialized_ GUARDED_BY(mu_);
};
+// Treats 'registry_ptr' as a pointer to OpRegistry, and calls
+// registry_ptr->Register(op_def) for each op_def that has been registered with
+// the current library's global op registry (obtained by calling
+// OpRegistry::Global().
+extern "C" void RegisterOps(void* registry_ptr);
+
// Support for defining the OpDef (specifying the semantics of the Op and how
// it should be created) and registering it in the OpRegistry::Global()
// registry. Usage:
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 767d2a0466..a59ddb3e9b 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -20,9 +20,9 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/regexp.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/regexp.h"
namespace tensorflow {
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 6947ebfc7f..b984966148 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -442,17 +442,27 @@ struct KernelRegistration {
// KernelDef.
typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;
-static KernelRegistry* GlobalKernelRegistry() {
+void* GlobalKernelRegistry() {
static KernelRegistry* global_kernel_registry = new KernelRegistry;
return global_kernel_registry;
}
+static KernelRegistry* GlobalKernelRegistryTyped() {
+ return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
+}
+
static string Key(const string& op_type, DeviceType device_type,
const string& label) {
return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":",
label);
}
+extern "C" void RegisterKernels(void* registry_ptr) {
+ KernelRegistry* kernel_registry = static_cast<KernelRegistry*>(registry_ptr);
+ kernel_registry->insert(GlobalKernelRegistryTyped()->begin(),
+ GlobalKernelRegistryTyped()->end());
+}
+
namespace kernel_factory {
OpKernelRegistrar::OpKernelRegistrar(const KernelDef* kernel_def,
@@ -460,7 +470,7 @@ OpKernelRegistrar::OpKernelRegistrar(const KernelDef* kernel_def,
const string key =
Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
kernel_def->label());
- GlobalKernelRegistry()->insert(
+ GlobalKernelRegistryTyped()->insert(
std::make_pair(key, KernelRegistration(*kernel_def, factory)));
delete kernel_def;
}
@@ -533,7 +543,7 @@ Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def,
string label; // Label defaults to empty if not found in NodeDef.
GetNodeAttr(node_def, "_kernel", &label);
const string key = Key(node_def.op(), device_type, label);
- auto regs = GlobalKernelRegistry()->equal_range(key);
+ auto regs = GlobalKernelRegistryTyped()->equal_range(key);
for (auto iter = regs.first; iter != regs.second; ++iter) {
// If there is a kernel registered for the op and device_type,
// check that the attrs match.
@@ -730,7 +740,7 @@ bool FindArgInOp(const string& arg_name,
Status ValidateKernelRegistrations(const OpRegistryInterface* op_registry) {
Status unused_status;
- for (const auto& key_registration : *GlobalKernelRegistry()) {
+ for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
const KernelDef& kernel_def(key_registration.second.def);
const OpDef* op_def = op_registry->LookUp(kernel_def.op(), &unused_status);
if (op_def == nullptr) {
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index a68a170fde..dedd600b05 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -966,6 +966,13 @@ typedef ::tensorflow::KernelDefBuilder Name;
+[](::tensorflow::OpKernelConstruction* context) \
-> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); })
+void* GlobalKernelRegistry();
+
+// Treats 'registry_ptr' as a pointer to KernelRegistry. For each kernel 'k'
+// registered with the current library's global kernel registry (obtained by
+// calling GlobalKernelRegistry()), inserts 'k' into registry_ptr.
+extern "C" void RegisterKernels(void* registry_ptr);
+
namespace kernel_factory {
class OpKernelRegistrar {
diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc
index 9a597a1042..8176b91a5e 100644
--- a/tensorflow/core/framework/rendezvous.cc
+++ b/tensorflow/core/framework/rendezvous.cc
@@ -43,18 +43,39 @@ string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation,
":", frame_iter.iter_id);
}
+// Return the prefix of "*s" up to the next occurrence of "delim", or
+// the whole remaining string if "delim" is not found. "*s" is advanced
+// past the string returned plus the delimiter (if found).
+static StringPiece ConsumeNextPart(StringPiece* s, char delim) {
+ for (int offset = 0; offset < s->size(); offset++) {
+ if ((*s)[offset] == delim) {
+ StringPiece result(s->data(), offset);
+ s->remove_prefix(offset + 1); // +1: remove delim, as well
+ return result;
+ }
+ }
+ // No delimiter found: return rest of string
+ StringPiece result(s->data(), s->size());
+ s->remove_prefix(s->size());
+ return result;
+}
+
/* static */
Status Rendezvous::ParseKey(const string& key, ParsedKey* out) {
- // TODO(zhifengc): This code is not fast enough.
- std::vector<string> parts = str_util::Split(key, ';');
- if (parts.size() == 5 &&
+ StringPiece s(key);
+ StringPiece parts[5];
+ for (int i = 0; i < 5; i++) {
+ parts[i] = ConsumeNextPart(&s, ';');
+ }
+ if (s.empty() && // Consumed the whole string
+ !parts[4].empty() && // Exactly five parts
DeviceNameUtils::ParseFullName(parts[0], &out->src) &&
- strings::StringToFp(parts[1], &out->src_incarnation) &&
+ strings::StringToFp(parts[1].ToString(), &out->src_incarnation) &&
DeviceNameUtils::ParseFullName(parts[2], &out->dst) &&
!parts[3].empty()) {
- out->src_device = parts[0];
- out->dst_device = parts[2];
- out->edge_name = parts[3];
+ out->src_device.assign(parts[0].data(), parts[0].size());
+ out->dst_device.assign(parts[2].data(), parts[2].size());
+ out->edge_name.assign(parts[3].data(), parts[3].size());
return Status::OK();
}
return errors::InvalidArgument("Invalid rendezvous key: ", key);
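
For reference, the five semicolon-separated parts the parser consumes match what CreateKey() above produces: src_device;src_incarnation;dst_device;edge_name;frame:iter. A sketch (device names, fingerprint, and edge name are illustrative):

Rendezvous::ParsedKey parsed;
Status s = Rendezvous::ParseKey(
    "/job:worker/replica:0/task:0/gpu:0;0000000000000001;"
    "/job:worker/replica:0/task:0/cpu:0;edge_1;0:0",
    &parsed);
// On success, parsed.src_device, parsed.dst_device, parsed.edge_name and
// the decoded src_incarnation are filled in.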
diff --git a/tensorflow/core/framework/tensor_reference.h b/tensorflow/core/framework/tensor_reference.h
index 88853130a0..e700bb4b6d 100644
--- a/tensorflow/core/framework/tensor_reference.h
+++ b/tensorflow/core/framework/tensor_reference.h
@@ -38,6 +38,17 @@ class TensorReference {
if (buf_) buf_->Unref();
}
+ // Return an estimate of the total bytes being kept alive by this reference.
+ size_t TotalBytes() const {
+ // We add 128 as a baseline to account for per-Tensor metadata
+ return 128 + (buf_ ? buf_->size() : 0);
+ }
+
+ // A constructor used only for tests
+ explicit TensorReference(TensorBuffer* test_buffer) : buf_(test_buffer) {
+ if (buf_) buf_->Ref();
+ }
+
private:
TensorBuffer* buf_;
};
diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc
index aaaf226bbd..31a470e4ea 100644
--- a/tensorflow/core/graph/algorithm.cc
+++ b/tensorflow/core/graph/algorithm.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <deque>
#include <vector>
+#include "tensorflow/core/platform/logging.h"
+
namespace tensorflow {
void DFS(const Graph& g, std::function<void(Node*)> enter,
@@ -78,14 +80,18 @@ void PruneForReverseReachability(Graph* g,
// nodes, and accumulating the visited nodes.
std::deque<const Node*> queue;
for (const Node* n : nodes) {
- queue.push_back(n);
+ if (visited.insert(n).second) {
+ VLOG(2) << "Reverse reach init: " << n->name();
+ queue.push_back(n);
+ }
}
while (!queue.empty()) {
const Node* n = queue.front();
queue.pop_front();
- if (visited.insert(n).second) {
- for (const Node* in : n->in_nodes()) {
+ for (const Node* in : n->in_nodes()) {
+ if (visited.insert(in).second) {
queue.push_back(in);
+ VLOG(2) << "Reverse reach : " << n->name() << " from " << in->name();
}
}
}
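
The change above moves the visited check from dequeue time to enqueue time, so a node can never sit in the queue twice. The invariant, in isolation:

std::deque<const Node*> queue;
std::unordered_set<const Node*> visited;
auto enqueue = [&](const Node* n) {
  // insert().second is true only the first time n is seen, so each node
  // enters the queue at most once.
  if (visited.insert(n).second) queue.push_back(n);
};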
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 4459d0b54b..e74033bd98 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -53,7 +53,7 @@ class GraphConstructor {
*status = errors::InvalidArgument(
"GraphDef version ", version, " is ", low ? "no longer" : "not yet",
" supported: TensorFlow ", TF_VERSION_STRING, " needs ",
- TF_GRAPH_DEF_VERSION_MAX, " <= version <= ", TF_GRAPH_DEF_VERSION_MIN,
+ TF_GRAPH_DEF_VERSION_MIN, " <= version <= ", TF_GRAPH_DEF_VERSION_MAX,
". ",
low ? "Please regenerate your graph." : "Please upgrade TensorFlow.");
return;
@@ -150,8 +150,8 @@ void GraphConstructor::BuildNodeIndex() {
SetNodeError(node_def, "Node name contains invalid characters");
return;
}
- if (!name_index_.insert(std::make_pair(StringPiece(node_def.name()),
- NodeInfo(n)))
+ if (!name_index_
+ .insert(std::make_pair(StringPiece(node_def.name()), NodeInfo(n)))
.second) {
SetNodeError(node_def, "Node name is not unique");
return;
@@ -346,8 +346,8 @@ void GraphConstructor::Convert() {
if (opts_.optimizer_do_cse) {
if (!back_edges.empty()) {
- LOG(WARNING) << "Not doing CSE. We need to figure out how to handle "
- << "loops in the CSE phase.";
+ VLOG(1) << "Not doing CSE. We need to figure out how to handle "
+ << "loops in the CSE phase.";
} else {
VLOG(1) << "Starting CSE: graph of " << CountNodes(g_) << " nodes";
OptimizeCSE(g_, opts_.cse_consider_function);
@@ -392,6 +392,9 @@ void CopyGraph(const Graph& src, Graph* dest) {
CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
}
+ // Copy GraphDef version
+ dest->set_version(src.version());
+
// Copy the nodes
std::unordered_map<Node*, Node*>
node_map; // "Node in src" -> "Node in *dest"
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 7706a3d0c6..657dfa4e90 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -180,16 +180,20 @@ TEST_F(GraphConstructorTest, VersionGraph) {
TEST_F(GraphConstructorTest, LowVersion) {
ExpectError(strings::StrCat("version: ", -1),
- R"(^GraphDef version -1 is no longer supported: TensorFlow \S+ )"
- R"(needs \d+ <= version <= \d+\. )"
- R"(Please regenerate your graph\.$)");
+ strings::StrCat(R"(^GraphDef version -1 is no longer supported: )"
+ R"(TensorFlow \S+ needs )",
+ TF_GRAPH_DEF_VERSION_MIN, " <= version <= ",
+ TF_GRAPH_DEF_VERSION_MAX,
+ R"(. Please regenerate your graph\.$)"));
}
TEST_F(GraphConstructorTest, HighVersion) {
ExpectError(strings::StrCat("version: ", TF_GRAPH_DEF_VERSION_MAX + 1),
- R"(^GraphDef version \d+ is not yet supported: TensorFlow \S+ )"
- R"(needs \d+ <= version <= \d+\. )"
- R"(Please upgrade TensorFlow\.$)");
+ strings::StrCat(R"(^GraphDef version \d+ is not yet supported: )"
+ R"(TensorFlow \S+ needs )",
+ TF_GRAPH_DEF_VERSION_MIN, " <= version <= ",
+ TF_GRAPH_DEF_VERSION_MAX,
+ R"(. Please upgrade TensorFlow\.$)"));
}
TEST_F(GraphConstructorTest, SimpleModel) {
@@ -231,5 +235,16 @@ TEST_F(GraphConstructorTest, Error_ControlEdgeBeforeRealInput) {
"Node 't2': Control dependencies must come after regular dependencies");
}
+TEST_F(GraphConstructorTest, CopyGraph) {
+ const int version = TF_GRAPH_DEF_VERSION - 1;
+
+ Graph src(OpRegistry::Global());
+ src.set_version(version);
+
+ Graph dst(OpRegistry::Global());
+ CopyGraph(src, &dst);
+ EXPECT_EQ(dst.version(), version);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/adjust_contrast_op.cc b/tensorflow/core/kernels/adjust_contrast_op.cc
index 18f7cb083d..b3fdba055c 100644
--- a/tensorflow/core/kernels/adjust_contrast_op.cc
+++ b/tensorflow/core/kernels/adjust_contrast_op.cc
@@ -38,7 +38,7 @@ template <typename Device, typename T>
class AdjustContrastOp : public OpKernel {
public:
explicit AdjustContrastOp(OpKernelConstruction* context) : OpKernel(context) {
- OP_DEPRECATED(context, 2);
+ OP_DEPRECATED(context, 2, "Use AdjustContrastv2 instead");
}
void Compute(OpKernelContext* context) override {
diff --git a/tensorflow/core/kernels/cwise_op_erf.cc b/tensorflow/core/kernels/cwise_op_erf.cc
new file mode 100644
index 0000000000..02f6b4b8d1
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_erf.cc
@@ -0,0 +1,23 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(UnaryOp, CPU, "Erf", functor::erf, float, double);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Erf", functor::erf, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_erfc.cc b/tensorflow/core/kernels/cwise_op_erfc.cc
new file mode 100644
index 0000000000..65862d4082
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_erfc.cc
@@ -0,0 +1,23 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(UnaryOp, CPU, "Erfc", functor::erfc, float, double);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Erfc", functor::erfc, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_erf.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_erf.cu.cc
new file mode 100644
index 0000000000..a1e31a1b2f
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_erf.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(erf, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_erfc.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_erfc.cu.cc
new file mode 100644
index 0000000000..260463c8bf
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_erfc.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(erfc, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_lgamma.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_lgamma.cu.cc
new file mode 100644
index 0000000000..8105ac1694
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_lgamma.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY2(lgamma, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_lgamma.cc b/tensorflow/core/kernels/cwise_op_lgamma.cc
new file mode 100644
index 0000000000..6985e5c6ba
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_lgamma.cc
@@ -0,0 +1,23 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(UnaryOp, CPU, "Lgamma", functor::lgamma, float, double);
+#if GOOGLE_CUDA
+REGISTER2(UnaryOp, GPU, "Lgamma", functor::lgamma, float, double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 7f42a4ca2b..1aec41fdc0 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -342,6 +342,15 @@ template <typename T>
struct tanh : base<T, Eigen::internal::scalar_tanh_op<T> > {};
template <typename T>
+struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T> > {};
+
+template <typename T>
+struct erf : base<T, Eigen::internal::scalar_erf_op<T> > {};
+
+template <typename T>
+struct erfc : base<T, Eigen::internal::scalar_erfc_op<T> > {};
+
+template <typename T>
struct sigmoid : base<T, Eigen::internal::scalar_sigmoid_op<T> > {};
template <typename T>
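Note: Lgamma, Erf, and Erfc all follow the stock cwise recipe: a functor wrapping an Eigen scalar op, CPU/GPU kernel registration via REGISTER2, and a GPU functor instantiation via DEFINE_UNARY2. A hedged sketch of the same recipe for one more unary op (Rsqrt is used purely as an illustration here; it is not part of this change):

    // 1) Functor in cwise_ops.h, wrapping the Eigen scalar op:
    template <typename T>
    struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T> > {};

    // 2) Kernel registration in a new cwise_op_rsqrt.cc:
    REGISTER2(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, double);
    #if GOOGLE_CUDA
    REGISTER2(UnaryOp, GPU, "Rsqrt", functor::rsqrt, float, double);
    #endif

    // 3) GPU instantiation in cwise_op_gpu_rsqrt.cu.cc:
    DEFINE_UNARY2(rsqrt, float, double);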
diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc
index 599b05525c..e17c0da061 100644
--- a/tensorflow/core/kernels/example_parsing_ops.cc
+++ b/tensorflow/core/kernels/example_parsing_ops.cc
@@ -469,9 +469,13 @@ class SingleSequenceExampleParserOp : public OpKernel {
ctx, ctx->GetAttr("Nfeature_list_dense", &num_feature_list_dense_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Ncontext_sparse", &num_context_sparse_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tcontext_dense", &context_dense_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_list_sparse_types",
+ &feature_list_sparse_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_list_dense_types",
&feature_list_dense_types_));
OP_REQUIRES_OK(
+ ctx, ctx->GetAttr("Nfeature_list_sparse", &num_feature_list_sparse_));
+ OP_REQUIRES_OK(
ctx, ctx->GetAttr("context_dense_shapes", &context_dense_shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("feature_list_dense_shapes",
&feature_list_dense_shapes_));
@@ -488,6 +492,11 @@ class SingleSequenceExampleParserOp : public OpKernel {
context_dense_shapes_.size(),
errors::InvalidArgument(
"len(context_dense_keys) != len(context_dense_shapes"));
+ OP_REQUIRES(
+ ctx, static_cast<size_t>(num_feature_list_sparse_) ==
+ feature_list_sparse_types_.size(),
+ errors::InvalidArgument(
+ "len(feature_list_sparse_keys) != len(feature_list_sparse_types"));
OP_REQUIRES(ctx, static_cast<size_t>(num_feature_list_dense_) ==
feature_list_dense_types_.size(),
errors::InvalidArgument("len(feature_list_dense_keys) != "
@@ -501,6 +510,9 @@ class SingleSequenceExampleParserOp : public OpKernel {
for (const DataType& type : feature_list_dense_types_) {
OP_REQUIRES_OK(ctx, CheckValidType(type));
}
+ for (const DataType& type : feature_list_sparse_types_) {
+ OP_REQUIRES_OK(ctx, CheckValidType(type));
+ }
}
void Compute(OpKernelContext* ctx) override {
@@ -510,6 +522,7 @@ class SingleSequenceExampleParserOp : public OpKernel {
OpInputList context_sparse_keys;
OpInputList context_dense_defaults;
OpInputList feature_list_dense_keys;
+ OpInputList feature_list_sparse_keys;
const Tensor* feature_list_dense_missing_assumed_empty;
OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name));
@@ -522,16 +535,20 @@ class SingleSequenceExampleParserOp : public OpKernel {
&feature_list_dense_keys));
OP_REQUIRES_OK(
ctx, ctx->input_list("context_sparse_keys", &context_sparse_keys));
+ OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_sparse_keys",
+ &feature_list_sparse_keys));
OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults",
&context_dense_defaults));
std::vector<string> context_dense_keys_t(num_context_dense_);
std::vector<string> context_sparse_keys_t(num_context_sparse_);
std::vector<string> feature_list_dense_keys_t(num_feature_list_dense_);
+ std::vector<string> feature_list_sparse_keys_t(num_feature_list_sparse_);
std::unordered_set<string> feature_list_dense_missing_assumed_empty_set;
CHECK_EQ(context_dense_keys.size(), num_context_dense_);
CHECK_EQ(context_sparse_keys.size(), num_context_sparse_);
CHECK_EQ(feature_list_dense_keys.size(), num_feature_list_dense_);
+ CHECK_EQ(feature_list_sparse_keys.size(), num_feature_list_sparse_);
for (int di = 0; di < num_context_dense_; ++di) {
OP_REQUIRES(ctx,
TensorShapeUtils::IsScalar(context_dense_keys[di].shape()),
@@ -560,6 +577,16 @@ class SingleSequenceExampleParserOp : public OpKernel {
feature_list_dense_keys_t[di] =
feature_list_dense_keys[di].scalar<string>()();
}
+ for (int di = 0; di < num_feature_list_sparse_; ++di) {
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(feature_list_sparse_keys[di].shape()),
+ errors::InvalidArgument(
+ "Expected feature_list_sparse_keys[", di,
+ "] to be a vector, got shape: ",
+ feature_list_sparse_keys[di].shape().ShortDebugString()));
+ feature_list_sparse_keys_t[di] =
+ feature_list_sparse_keys[di].scalar<string>()();
+ }
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(
feature_list_dense_missing_assumed_empty->shape()),
errors::InvalidArgument(
@@ -622,6 +649,9 @@ class SingleSequenceExampleParserOp : public OpKernel {
OpOutputList context_sparse_values;
OpOutputList context_sparse_shapes;
OpOutputList context_dense_values;
+ OpOutputList feature_list_sparse_indices;
+ OpOutputList feature_list_sparse_values;
+ OpOutputList feature_list_sparse_shapes;
OpOutputList feature_list_dense_values;
OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
@@ -632,6 +662,14 @@ class SingleSequenceExampleParserOp : public OpKernel {
ctx, ctx->output_list("context_sparse_shapes", &context_sparse_shapes));
OP_REQUIRES_OK(
ctx, ctx->output_list("context_dense_values", &context_dense_values));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices",
+ &feature_list_sparse_indices));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values",
+ &feature_list_sparse_values));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes",
+ &feature_list_sparse_shapes));
OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values",
&feature_list_dense_values));
@@ -784,16 +822,78 @@ class SingleSequenceExampleParserOp : public OpKernel {
feature_list_dense_values[d]));
}
}
+
+ // Feature List Sparse -----------------------------------------------------
+ for (int d = 0; d < num_feature_list_sparse_; ++d) {
+ const string& key = feature_list_sparse_keys_t[d];
+ const DataType& dtype = feature_list_sparse_types_[d];
+
+ const auto& feature_list_found = feature_list_dict.find(key);
+ bool feature_list_has_data = // Found key
+ (feature_list_found != feature_list_dict.end());
+
+ std::vector<Tensor> sparse_values_tmp;
+ int64 feature_list_size = 0;
+ if (feature_list_has_data) {
+ const FeatureList& fl = feature_list_found->second;
+ feature_list_size = fl.feature_size();
+ for (int64 t = 0; t < feature_list_size; ++t) {
+ const Feature& f = fl.feature(t);
+ bool types_match;
+ OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
+ OP_REQUIRES(
+ ctx, types_match,
+ errors::InvalidArgument(
+ "Name: ", name, ", Feature List: ", key, ", Index: ", t,
+ ". Data types don't match. ", "Expected type: ",
+ DataTypeString(dtype), " Feature is: ", f.DebugString()));
+ sparse_values_tmp.push_back(FeatureSparseCopy(t, key, dtype, f));
+ }
+ } else {
+ sparse_values_tmp.push_back(Tensor(dtype, TensorShape({0})));
+ }
+
+ int64 total_num_features = 0;
+ int64 max_num_features = 0;
+ for (int t = 0; t < feature_list_size; ++t) {
+ const Tensor& v = sparse_values_tmp[t];
+ const int64 num_elements = v.shape().num_elements();
+ total_num_features += num_elements;
+ max_num_features = std::max(max_num_features, num_elements);
+ }
+
+ TensorShape indices_shape({total_num_features, 2});
+ TensorShape values_shape({total_num_features});
+ Tensor* sp_indices_d = nullptr;
+ Tensor* sp_values_d = nullptr;
+ Tensor* sp_shape_d = nullptr;
+ OP_REQUIRES_OK(ctx, feature_list_sparse_indices.allocate(d, indices_shape, &sp_indices_d));
+ OP_REQUIRES_OK(ctx, feature_list_sparse_values.allocate(d, values_shape, &sp_values_d));
+ OP_REQUIRES_OK(ctx, feature_list_sparse_shapes.allocate(d, TensorShape({2}), &sp_shape_d));
+ auto shape_t = sp_shape_d->vec<int64>();
+ shape_t(0) = feature_list_size;
+ shape_t(1) = max_num_features;
+
+ int64 offset = 0;
+
+ for (int t = 0; t < feature_list_size; ++t) {
+ const int64 num_elements = CopyIntoSparseTensor(
+ sparse_values_tmp[t], t, offset, sp_indices_d, sp_values_d);
+ offset += num_elements;
+ }
+ }
}
protected:
int64 num_context_sparse_;
int64 num_context_dense_;
+ int64 num_feature_list_sparse_;
int64 num_feature_list_dense_;
std::vector<DataType> context_sparse_types_;
std::vector<DataType> context_dense_types_;
- std::vector<DataType> feature_list_dense_types_;
std::vector<TensorShape> context_dense_shapes_;
+ std::vector<DataType> feature_list_sparse_types_;
+ std::vector<DataType> feature_list_dense_types_;
std::vector<TensorShape> feature_list_dense_shapes_;
};
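Note: the sparse feature-list outputs above store one (time step, position-within-step) index row per value, with dense shape {num_steps, max_values_per_step}. A standalone illustration of that layout (plain C++, no TensorFlow types): per-step value counts {2, 0, 3} yield indices (0,0) (0,1) (2,0) (2,1) (2,2) and shape {3, 3}:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <utility>
    #include <vector>

    int main() {
      std::vector<int64_t> step_lengths = {2, 0, 3};  // values per time step
      std::vector<std::pair<int64_t, int64_t>> indices;
      int64_t max_len = 0;
      for (int64_t t = 0; t < (int64_t)step_lengths.size(); ++t) {
        for (int64_t k = 0; k < step_lengths[t]; ++k) indices.emplace_back(t, k);
        max_len = std::max(max_len, step_lengths[t]);
      }
      std::printf("total=%zu shape={%zu,%lld}\n", indices.size(),
                  step_lengths.size(), (long long)max_len);
      return 0;
    }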
diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc
index 470a7a3eb6..9153e5a31b 100644
--- a/tensorflow/core/kernels/queue_base.cc
+++ b/tensorflow/core/kernels/queue_base.cc
@@ -345,7 +345,14 @@ Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
HANDLE_TYPE(DT_INT16);
HANDLE_TYPE(DT_INT8);
HANDLE_TYPE(DT_STRING);
+ HANDLE_TYPE(DT_COMPLEX64);
HANDLE_TYPE(DT_INT64);
+ HANDLE_TYPE(DT_BOOL);
+ HANDLE_TYPE(DT_QINT8);
+ HANDLE_TYPE(DT_QUINT8);
+ HANDLE_TYPE(DT_QINT32);
+ HANDLE_TYPE(DT_QINT16);
+ HANDLE_TYPE(DT_QUINT16);
#undef HANDLE_TYPE
return errors::Unimplemented("Unhandled data type: ", parent.dtype());
}
@@ -365,7 +372,14 @@ Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
HANDLE_TYPE(DT_INT16);
HANDLE_TYPE(DT_INT8);
HANDLE_TYPE(DT_STRING);
+ HANDLE_TYPE(DT_COMPLEX64);
HANDLE_TYPE(DT_INT64);
+ HANDLE_TYPE(DT_BOOL);
+ HANDLE_TYPE(DT_QINT8);
+ HANDLE_TYPE(DT_QUINT8);
+ HANDLE_TYPE(DT_QINT32);
+ HANDLE_TYPE(DT_QINT16);
+ HANDLE_TYPE(DT_QUINT16);
#undef HANDLE_TYPE
return errors::Unimplemented("Unhandled data type: ", element.dtype());
}
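Note: both copy helpers dispatch on dtype through the HANDLE_TYPE macro, which expands to one typed branch per listed type and falls through to an Unimplemented error for anything else. A standalone sketch of the idiom (plain C++ stand-ins, not the TensorFlow code):

    #include <cstdio>

    enum DataType { DT_FLOAT, DT_BOOL, DT_QINT8, DT_STRING };

    template <DataType DT>
    void CopySliceTyped() { std::printf("copying dtype %d\n", (int)DT); }

    bool Dispatch(DataType dtype) {
    #define HANDLE_TYPE(DT)   \
      if (dtype == DT) {      \
        CopySliceTyped<DT>(); \
        return true;          \
      }
      HANDLE_TYPE(DT_FLOAT);
      HANDLE_TYPE(DT_BOOL);
      HANDLE_TYPE(DT_QINT8);
    #undef HANDLE_TYPE
      return false;  // unhandled type: would become errors::Unimplemented
    }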
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index e8db7106ef..44911c9d36 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -24,6 +24,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/reduction_ops.h"
+#include "tensorflow/core/kernels/transpose_op.h"
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -31,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/status.h"
#include "tensorflow/core/public/tensor.h"
@@ -76,7 +78,7 @@ class ReductionHelper {
Status Simplify(const Tensor& data, const Tensor& axis,
const bool keep_dims) {
// bitmap[i] indicates whether to reduce data along i-th axis.
- std::vector<bool> bitmap(data.dims(), false);
+ gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
auto axis_vec = axis.flat<int32>();
for (int64 i = 0; i < axis.NumElements(); ++i) {
const int32 index = axis_vec(i);
@@ -194,11 +196,43 @@ class ReductionHelper {
return data.shaped<T, N>(data_reshape_);
}
+ // Shape of shuffled input
+ const gtl::ArraySlice<int64> data_reshape() const { return data_reshape_; }
+
+ // Shape with all reduction dimensions at the end
+ TensorShape shuffled_shape() {
+ const int dims = data_reshape_.size();
+ TensorShape shape;
+ for (int i = reduce_first_axis_; i < dims; i += 2) {
+ shape.AddDim(data_reshape_[i]);
+ }
+ for (int i = !reduce_first_axis_; i < dims; i += 2) {
+ shape.AddDim(data_reshape_[i]);
+ }
+ return shape;
+ }
+
+ // Permutation of reduced dims needed to put reduction dimensions at the end
+ gtl::InlinedVector<int32, 8> permutation() {
+ const int dims = data_reshape_.size();
+ const int unreduced_dims = (dims + !reduce_first_axis_) / 2;
+ gtl::InlinedVector<int32, 8> perm(dims);
+ for (int i = 0; i < unreduced_dims; i++) {
+ perm[i] = 2 * i + reduce_first_axis_;
+ }
+ for (int i = unreduced_dims; i < dims; i++) {
+ perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_;
+ }
+ return perm;
+ }
+
private:
bool reduce_first_axis_; // True if need to reduce the 0-th dimension.
- std::vector<int64> data_reshape_; // Reshape the data before reduction.
- std::vector<int64> out_shape_; // The final output shape.
- std::vector<int64> out_reshape_; // Reshape the output for reduction.
+ gtl::InlinedVector<int64, 4>
+ data_reshape_; // Reshape the data before reduction.
+ gtl::InlinedVector<int64, 4> out_shape_; // The final output shape.
+ gtl::InlinedVector<int64, 4>
+ out_reshape_; // Reshape the output for reduction.
};
} // end namespace
@@ -252,6 +286,9 @@ class ReductionOp : public OpKernel {
const Device& d = ctx->eigen_device<Device>();
Reducer reducer;
+ if (tmp_out.NumElements() == 0) {
+ // Nothing to do, fall through to final reshaping.
+ } else
if ((helper.ndims() == 1) && helper.reduce_first_axis()) {
// Reduce to a scalar.
Functor::Reduce(d, helper.out<T, 0>(&tmp_out), helper.in<T, 1>(data),
@@ -274,15 +311,20 @@ class ReductionOp : public OpKernel {
Functor::Reduce(d, helper.out<T, 2>(&tmp_out), helper.in<T, 3>(data),
constants.kOne, reducer);
} else {
- // TODO(zhifengc): We can implement reduction for arbitrary rank
- // tensor and arbitrary reduction axes by iterating the reduction
- // multiple times. This may also be accomplished in the graph
- // construction.
- ctx->SetStatus(
- errors::Unimplemented("Reducing ", data.shape().ShortDebugString(),
- " axes [", axes.SummarizeValue(10), "] to ",
- tmp_out.shape().ShortDebugString()));
- return;
+ // If we don't hit one of the cases above, transpose the data so that
+ // all reduced dimensions are last and reuse the 2-D -> 1-D case.
+ Tensor shuffled;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(DataTypeToEnum<T>::value,
+ helper.shuffled_shape(), &shuffled));
+ TransposeTensor<Device, T>(d, data, helper.data_reshape(),
+ helper.permutation(), &shuffled);
+ const int64 unreduced = tmp_out.NumElements();
+ const int64 reduced = shuffled.NumElements() / unreduced;
+ const Tensor& const_shuffled = shuffled;
+ Functor::Reduce(d, tmp_out.flat<T>(),
+ const_shuffled.shaped<T, 2>({unreduced, reduced}),
+ constants.kOne, reducer);
}
// Set the real output using the contents of the reduction but the
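Note: after Simplify(), data_reshape_ alternates unreduced and reduced dimensions, so permutation() is a fixed interleave that moves every reduced dim to the end and shuffled_shape() is the permuted shape; the general case then collapses to a single 2-D row reduction. A standalone check of that arithmetic (plain C++, no TensorFlow types):

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<long long> data_reshape = {2, 3, 4, 5};  // {a, b, c, d}
      bool reduce_first_axis = false;  // reduce dims 1 and 3 (b and d)
      const int dims = (int)data_reshape.size();
      const int unreduced = (dims + !reduce_first_axis) / 2;
      std::vector<int> perm(dims);
      for (int i = 0; i < unreduced; i++) perm[i] = 2 * i + reduce_first_axis;
      for (int i = unreduced; i < dims; i++)
        perm[i] = 2 * (i - unreduced) + !reduce_first_axis;
      long long rows = 1, cols = 1;
      for (int i = 0; i < dims; i++) {
        std::printf("perm[%d]=%d ", i, perm[i]);  // prints: 0 2 1 3
        (i < unreduced ? rows : cols) *= data_reshape[perm[i]];
      }
      std::printf("\n2-D reduce view: {%lld, %lld}\n", rows, cols);  // {8, 15}
      return 0;
    }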
diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h
index 8f908109ed..f1260746bf 100644
--- a/tensorflow/core/kernels/reshape_op.h
+++ b/tensorflow/core/kernels/reshape_op.h
@@ -39,9 +39,6 @@ class ReshapeOp : public OpKernel {
errors::InvalidArgument("sizes input must be 1-D, not shape ",
sizes.shape().ShortDebugString()));
const int64 num_dims = sizes.NumElements();
- OP_REQUIRES(
- context, num_dims <= 8,
- errors::InvalidArgument(num_dims, " > max 8 output dims supported"));
// Compute the output shape. Determine product of specified
// dimensions, and find the index of the unspecified one.
diff --git a/tensorflow/core/kernels/sparse_to_dense_op.cc b/tensorflow/core/kernels/sparse_to_dense_op.cc
index 3de5132049..7759dbdc0f 100644
--- a/tensorflow/core/kernels/sparse_to_dense_op.cc
+++ b/tensorflow/core/kernels/sparse_to_dense_op.cc
@@ -41,7 +41,10 @@ namespace tensorflow {
template <typename T, typename Index>
class SparseToDense : public OpKernel {
public:
- explicit SparseToDense(OpKernelConstruction* context) : OpKernel(context) {}
+ explicit SparseToDense(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr("validate_indices", &validate_indices_));
+ }
void Compute(OpKernelContext* c) override {
// sparse_indices
@@ -111,17 +114,28 @@ class SparseToDense : public OpKernel {
sparse_values_b = sparse_values;
}
+ // Assume SparseTensor is lexicographically sorted.
gtl::InlinedVector<int64, 8> order(output->shape().dims());
- std::iota(order.begin(), order.end(), 0); // Assume order is correct
+ std::iota(order.begin(), order.end(), 0);
sparse::SparseTensor st(indices_shaped, sparse_values_b, output->shape(),
order);
+ if (validate_indices_) {
+ OP_REQUIRES(c, st.IndicesValid(),
+ errors::InvalidArgument("Indices are not valid: not "
+ "lexicographically sorted or "
+ "containing repeats."));
+ }
+
output->flat<T>().setConstant(default_value.scalar<T>()());
OP_REQUIRES(c, st.template ToDense<T>(output, false /* initialize */),
errors::InvalidArgument(
"Indices are not valid (out of bounds). Shape: ",
output->shape().DebugString()));
}
+
+ private:
+ bool validate_indices_;
};
#define REGISTER_KERNELS(type, index_type) \
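Note: validate_indices guards exactly the assumption stated above: index rows must be strictly increasing in lexicographic order (sorted, no repeats). A standalone sketch of the property being checked (not the actual SparseTensor::IndicesValid implementation):

    #include <vector>

    // Rows of the indices matrix must be strictly increasing; std::vector's
    // operator< compares lexicographically, which is the ordering required.
    bool IndicesSortedNoRepeats(
        const std::vector<std::vector<long long>>& indices) {
      for (size_t i = 1; i < indices.size(); ++i) {
        if (!(indices[i - 1] < indices[i])) return false;
      }
      return true;
    }
    // {{0,0},{0,1},{1,0}} -> true; {{0,1},{0,0}} or a repeated row -> false.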
diff --git a/tensorflow/core/kernels/summary_op.cc b/tensorflow/core/kernels/summary_op.cc
index 972889f878..4031e90857 100644
--- a/tensorflow/core/kernels/summary_op.cc
+++ b/tensorflow/core/kernels/summary_op.cc
@@ -61,15 +61,6 @@ class SummaryScalarOp : public OpKernel {
}
};
-REGISTER_KERNEL_BUILDER(Name("ScalarSummary")
- .Device(DEVICE_CPU)
- .TypeConstraint<float>("T"),
- SummaryScalarOp<float>);
-REGISTER_KERNEL_BUILDER(Name("ScalarSummary")
- .Device(DEVICE_CPU)
- .TypeConstraint<double>("T"),
- SummaryScalarOp<double>);
-
template <typename T>
class SummaryHistoOp : public OpKernel {
public:
@@ -108,6 +99,9 @@ class SummaryHistoOp : public OpKernel {
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER( \
+ Name("ScalarSummary").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ SummaryScalarOp<T>); \
+ REGISTER_KERNEL_BUILDER( \
Name("HistogramSummary").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
SummaryHistoOp<T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER)
diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc
index d4c6ecb6c7..971cef6ddb 100644
--- a/tensorflow/core/kernels/tile_ops.cc
+++ b/tensorflow/core/kernels/tile_ops.cc
@@ -65,14 +65,17 @@ class TileOp : public OpKernel {
TensorShape output_shape;
for (int i = 0; i < input_dims; ++i) {
OP_REQUIRES(
- context, multiples_array[i] > 0,
- errors::InvalidArgument("Expected multiples[", i, "] > 0, but got ",
+ context, multiples_array[i] >= 0,
+ errors::InvalidArgument("Expected multiples[", i, "] >= 0, but got ",
multiples_array[i]));
output_shape.AddDim(input.dim_size(i) * multiples_array[i]);
}
Tensor* result = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &result));
+ // If there's no output, there's nothing to do.
+ if (output_shape.num_elements() == 0) return;
+
#define HANDLE_DIM(DT, NDIM) \
if (context->input(0).dtype() == DT && input_dims == NDIM) { \
HandleCase<DT, NDIM>(context, multiples_array, result); \
@@ -180,7 +183,9 @@ HANDLE_CASE_DIM(GPUDevice, DT_INT64);
template <typename Device>
class TileGradientOp : public OpKernel {
public:
- explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {}
+ explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_DEPRECATED(context, 3, "TileGrad has been replaced with reduce_sum");
+ }
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index ad312b8e7b..7d9c2a90e5 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -97,8 +97,8 @@ void TransposeOp<Device, T>::Compute(OpKernelContext* context) {
perm.shape().DebugString()));
auto Vperm = perm.vec<int32>();
const int dims = input.dims();
- static const int kMinDims = 1;
- static const int kMaxDims = 8;
+ static const int kMinDims = 0;
+ static const int kMaxDims = 10;
OP_REQUIRES(context, kMinDims <= dims && dims <= kMaxDims,
errors::Unimplemented("Transposing a tensor of rank ", dims,
" is not implemented."));
@@ -125,20 +125,35 @@ void TransposeOp<Device, T>::Compute(OpKernelContext* context) {
str_util::Join(permutation, ","), "}."));
}
+ // 0-D and 1-D transposes do nothing
+ if (dims <= 1) {
+ context->set_output(0, input);
+ return;
+ }
+
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output));
+ TransposeTensor<Device, T>(context->eigen_device<Device>(), input,
+ input.shape().dim_sizes(), permutation, output);
+}
+
+template <typename Device, typename T>
+void TransposeTensor(const Device& device, const Tensor& input,
+ const gtl::ArraySlice<int64> input_shape,
+ gtl::ArraySlice<int32> permutation, Tensor* output) {
+ const int dims = input_shape.size();
+ CHECK(permutation.size() == dims);
if (input.NumElements() == 0) {
return;
}
switch (dims) {
-#define EXPAND_DIM(N) \
- case N: { \
- functor::TransposeFunctor<Device, T, N> func; \
- func(context->eigen_device<Device>(), output->tensor<T, N>(), \
- input.tensor<T, N>(), permutation.data()); \
- break; \
+#define EXPAND_DIM(N) \
+ case N: { \
+ functor::TransposeFunctor<Device, T, N> func; \
+ func(device, output->tensor<T, N>(), input.shaped<T, N>(input_shape), \
+ permutation.data()); \
+ break; \
}
- EXPAND_DIM(1);
EXPAND_DIM(2);
EXPAND_DIM(3);
EXPAND_DIM(4);
@@ -146,6 +161,8 @@ void TransposeOp<Device, T>::Compute(OpKernelContext* context) {
EXPAND_DIM(6);
EXPAND_DIM(7);
EXPAND_DIM(8);
+ EXPAND_DIM(9);
+ EXPAND_DIM(10);
default:
LOG(FATAL) << "Unexpected dims: " << dims;
}
@@ -179,13 +196,16 @@ struct TransposeFunctor<CPUDevice, T, NDIMS> {
} // namespace functor
-#define REGISTER(D, T) \
- template class TransposeOp<D##Device, T>; \
- REGISTER_KERNEL_BUILDER(Name("Transpose") \
- .Device(DEVICE_##D) \
- .TypeConstraint<T>("T") \
- .HostMemory("perm"), \
- TransposeOp<D##Device, T>)
+#define REGISTER(D, T) \
+ template class TransposeOp<D##Device, T>; \
+ REGISTER_KERNEL_BUILDER(Name("Transpose") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("perm"), \
+ TransposeOp<D##Device, T>); \
+ template void TransposeTensor<D##Device, T>( \
+ const D##Device&, const Tensor&, const gtl::ArraySlice<int64>, \
+ gtl::ArraySlice<int32>, Tensor*);
REGISTER(CPU, float);
REGISTER(CPU, double);
REGISTER(CPU, complex64);
@@ -195,6 +215,7 @@ REGISTER(CPU, int16);
REGISTER(CPU, int32);
REGISTER(CPU, int64);
REGISTER(CPU, string);
+REGISTER(CPU, bool);
#if GOOGLE_CUDA
REGISTER(GPU, uint8);
REGISTER(GPU, int8);
@@ -203,6 +224,8 @@ REGISTER(GPU, int32);
REGISTER(GPU, int64);
REGISTER(GPU, float);
REGISTER(GPU, double);
+REGISTER(GPU, complex64);
+REGISTER(GPU, bool);
#endif
#undef REGISTER
} // namespace tensorflow
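Note: factoring the body into TransposeTensor lets the reduction kernels shuffle a tensor directly on a device without constructing a Transpose op, and the op itself now forwards 0-D/1-D inputs untouched. A hedged usage sketch, assuming the TensorFlow headers and an output preallocated with the permuted shape:

    // Sketch: swap the two axes of a rank-2 float tensor on the CPU device.
    // Tensor in = ...;
    // Tensor out = ...;  // allocated with shape {in.dim_size(1), in.dim_size(0)}
    // TransposeTensor<CPUDevice, float>(ctx->eigen_device<CPUDevice>(), in,
    //                                   in.shape().dim_sizes(), {1, 0}, &out);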
diff --git a/tensorflow/core/kernels/transpose_op.h b/tensorflow/core/kernels/transpose_op.h
index f4d36fea54..15cd1b6488 100644
--- a/tensorflow/core/kernels/transpose_op.h
+++ b/tensorflow/core/kernels/transpose_op.h
@@ -29,6 +29,12 @@ class TransposeOp : public OpKernel {
void Compute(OpKernelContext* context) override;
};
+// Exposed for use in reduction ops
+template <typename Device, typename T>
+void TransposeTensor(const Device& device, const Tensor& input,
+ const gtl::ArraySlice<int64> input_shape,
+ gtl::ArraySlice<int32> permutation, Tensor* output);
+
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
diff --git a/tensorflow/core/kernels/transpose_op_gpu.cu.cc b/tensorflow/core/kernels/transpose_op_gpu.cu.cc
index c2f720a121..b8d664b95f 100644
--- a/tensorflow/core/kernels/transpose_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/transpose_op_gpu.cu.cc
@@ -17,8 +17,9 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/kernels/transpose_op_functor.h"
+#include "tensorflow/core/platform/port.h"
namespace tensorflow {
namespace functor {
@@ -34,14 +35,15 @@ struct TransposeFunctor<Eigen::GpuDevice, T, NDIMS> {
#define DEFINE(T, N) template struct TransposeFunctor<Eigen::GpuDevice, T, N>;
#define DEFINE_DIM(T) \
- DEFINE(T, 1); \
DEFINE(T, 2); \
DEFINE(T, 3); \
DEFINE(T, 4); \
DEFINE(T, 5); \
DEFINE(T, 6); \
DEFINE(T, 7); \
- DEFINE(T, 8);
+ DEFINE(T, 8); \
+ DEFINE(T, 9); \
+ DEFINE(T, 10);
DEFINE_DIM(uint8);
DEFINE_DIM(int8);
DEFINE_DIM(int16);
@@ -49,6 +51,8 @@ DEFINE_DIM(int32);
DEFINE_DIM(int64);
DEFINE_DIM(float);
DEFINE_DIM(double);
+DEFINE_DIM(complex64);
+DEFINE_DIM(bool);
#undef DEFINE_DIM
#undef DEFINE
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
index b4a42c45f8..50fc40e18d 100644
--- a/tensorflow/core/lib/core/errors.h
+++ b/tensorflow/core/lib/core/errors.h
@@ -103,19 +103,19 @@ using ::tensorflow::error::OK;
// }
// Declares an op deprecated, and illegal starting at GraphDef version VERSION
-#define OP_DEPRECATED(CTX, VERSION) \
+#define OP_DEPRECATED(CTX, VERSION, NOTE) \
if ((CTX)->graph_def_version() >= (VERSION)) { \
::tensorflow::Status _s(::tensorflow::errors::Unimplemented( \
"Op ", (CTX)->op_def().name(), \
" is not available in GraphDef version ", (CTX)->graph_def_version(), \
- ". It has been removed in version ", (VERSION), ".")); \
+ ". It has been removed in version ", (VERSION), ". ", (NOTE), ".")); \
VLOG(1) << _s; \
(CTX)->SetStatus(_s); \
return; \
} else { \
LOG(WARNING) << "Op is deprecated." \
<< " It will cease to work in GraphDef version " << (VERSION) \
- << "."; \
+ << ". " << (NOTE) << "."; \
}
#define OP_REQUIRES(CTX, EXP, STATUS) \
diff --git a/tensorflow/core/lib/io/inputbuffer.cc b/tensorflow/core/lib/io/inputbuffer.cc
index 55514c0242..722f1da48c 100644
--- a/tensorflow/core/lib/io/inputbuffer.cc
+++ b/tensorflow/core/lib/io/inputbuffer.cc
@@ -61,7 +61,10 @@ Status InputBuffer::ReadLine(string* result) {
// We don't append the '\n' to *result
return Status::OK();
}
- *result += c;
+ // We don't append '\r' to *result
+ if (c != '\r') {
+ *result += c;
+ }
}
if (errors::IsOutOfRange(s) && !result->empty()) {
return Status::OK();
diff --git a/tensorflow/core/lib/io/inputbuffer_test.cc b/tensorflow/core/lib/io/inputbuffer_test.cc
index 5e4888b727..d424336e06 100644
--- a/tensorflow/core/lib/io/inputbuffer_test.cc
+++ b/tensorflow/core/lib/io/inputbuffer_test.cc
@@ -116,6 +116,32 @@ TEST(InputBuffer, ReadLine_EmptyLines) {
}
}
+TEST(InputBuffer, ReadLine_CRLF) {
+ Env* env = Env::Default();
+ string fname = testing::TmpDir() + "/inputbuffer_test";
+ WriteStringToFile(env, fname, "line one\r\n\r\n\r\nline two\r\nline three");
+
+ for (auto buf_size : BufferSizes()) {
+ RandomAccessFile* file;
+ TF_CHECK_OK(env->NewRandomAccessFile(fname, &file));
+ string line;
+ io::InputBuffer in(file, buf_size);
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line one");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line two");
+ TF_CHECK_OK(in.ReadLine(&line));
+ EXPECT_EQ(line, "line three");
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ // A second call should also return end of file
+ EXPECT_TRUE(errors::IsOutOfRange(in.ReadLine(&line)));
+ }
+}
+
TEST(InputBuffer, ReadNBytes) {
Env* env = Env::Default();
string fname = testing::TmpDir() + "/inputbuffer_test";
diff --git a/tensorflow/core/lib/strings/regexp.h b/tensorflow/core/lib/strings/regexp.h
new file mode 100644
index 0000000000..aaf58f8139
--- /dev/null
+++ b/tensorflow/core/lib/strings/regexp.h
@@ -0,0 +1,33 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_LIB_STRINGS_REGEXP_H_
+#define TENSORFLOW_CORE_LIB_STRINGS_REGEXP_H_
+
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+
+// Conversion to/from the appropriate StringPiece type for use with RE2
+inline RegexpStringPiece ToRegexpStringPiece(tensorflow::StringPiece sp) {
+ return RegexpStringPiece(sp.data(), sp.size());
+}
+inline tensorflow::StringPiece FromRegexpStringPiece(RegexpStringPiece sp) {
+ return tensorflow::StringPiece(sp.data(), sp.size());
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_LIB_STRINGS_REGEXP_H_
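Note: these adapters exist because tensorflow::StringPiece and RE2's StringPiece are distinct types with the same data/size layout. A minimal usage sketch, assuming RE2 is linked:

    #include "re2/re2.h"
    #include "tensorflow/core/lib/strings/regexp.h"

    // Returns true iff 'sp' looks like a version tag such as "v1.2".
    bool LooksLikeVersion(tensorflow::StringPiece sp) {
      return RE2::FullMatch(tensorflow::ToRegexpStringPiece(sp),
                            "v\\d+\\.\\d+");
    }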
diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h
index cd972e2371..134a49e631 100644
--- a/tensorflow/core/lib/strings/str_util.h
+++ b/tensorflow/core/lib/strings/str_util.h
@@ -81,9 +81,7 @@ void TitlecaseString(string* s, StringPiece delimiters);
// Join functionality
template <typename T>
-string Join(const std::vector<T>& s, const char* sep);
-template <typename T>
-string Join(const gtl::ArraySlice<T>& s, const char* sep);
+string Join(const T& s, const char* sep);
struct AllowEmpty {
bool operator()(StringPiece sp) const { return true; }
@@ -110,31 +108,16 @@ bool SplitAndParseAsInts(StringPiece text, char delim,
// ------------------------------------------------------------------
// Implementation details below
-namespace internal {
template <typename T>
-string JoinHelper(typename gtl::ArraySlice<T>::const_iterator begin,
- typename gtl::ArraySlice<T>::const_iterator end,
- const char* sep) {
+string Join(const T& s, const char* sep) {
string result;
bool first = true;
- for (typename gtl::ArraySlice<T>::const_iterator it = begin; it != end;
- ++it) {
- tensorflow::strings::StrAppend(&result, (first ? "" : sep), *it);
+ for (const auto& x : s) {
+ tensorflow::strings::StrAppend(&result, (first ? "" : sep), x);
first = false;
}
return result;
}
-} // namespace internal
-
-template <typename T>
-string Join(const std::vector<T>& s, const char* sep) {
- return Join<T>(gtl::ArraySlice<T>(s), sep);
-}
-
-template <typename T>
-string Join(const gtl::ArraySlice<T>& s, const char* sep) {
- return internal::JoinHelper<T>(s.begin(), s.end(), sep);
-}
inline std::vector<string> Split(StringPiece text, char delim) {
return Split(text, delim, AllowEmpty());
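Note: with Join templated on the container itself, any iterable whose elements StrAppend accepts now works, not just vector and ArraySlice. A short usage sketch, assuming the TensorFlow headers:

    #include <set>
    #include <string>
    #include <vector>
    #include "tensorflow/core/lib/strings/str_util.h"

    void JoinDemo() {
      std::vector<int> v = {1, 2, 3};
      std::set<std::string> s = {"a", "b"};
      std::string csv = tensorflow::str_util::Join(v, ",");     // "1,2,3"
      std::string dashed = tensorflow::str_util::Join(s, "-");  // "a-b"
    }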
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index c9d782f1c5..0bad9feb7a 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -388,6 +388,7 @@ REGISTER_OP("HashTable")
.Attr("shared_name: string = ''")
.Attr("key_dtype: type")
.Attr("value_dtype: type")
+ .SetIsStateful()
.Doc(R"doc(
Creates a non-initialized hash table.
diff --git a/tensorflow/core/ops/function_ops.cc b/tensorflow/core/ops/function_ops.cc
new file mode 100644
index 0000000000..3842c025b3
--- /dev/null
+++ b/tensorflow/core/ops/function_ops.cc
@@ -0,0 +1,70 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/function.h"
+
+#include <unordered_set>
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+REGISTER_OP("_Arg")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("index: int >= 0")
+ .Doc(R"doc(
+A graph node which represents an argument to a function.
+
+output: The argument.
+index: This argument is the index-th argument of the function.
+)doc");
+
+REGISTER_OP("_Retval")
+ .Input("input: T")
+ .Attr("T: type")
+ .Attr("index: int >= 0")
+ .Doc(R"doc(
+A graph node which represents a return value of a function.
+
+input: The return value.
+index: This return value is the index-th return value of the function.
+)doc");
+
+REGISTER_OP("_ListToArray")
+ .Input("input: Tin")
+ .Output("output: N * T")
+ .Attr("Tin: list(type)")
+ .Attr("T: type")
+ .Attr("N: int >= 1")
+ .Doc(R"doc(
+Converts a list of tensors to an array of tensors.
+)doc");
+
+REGISTER_OP("_ArrayToList")
+ .Input("input: N * T")
+ .Output("output: out_types")
+ .Attr("T: type")
+ .Attr("N: int >= 1")
+ .Attr("out_types: list(type)")
+ .Doc(R"doc(
+Converts an array of tensors to a list of tensors.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 3f359334b0..88e2b34d6a 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -188,6 +188,24 @@ REGISTER_OP("Tanh")
Computes hyperbolic tangent of `x` element-wise.
)doc");
+REGISTER_OP("Lgamma")
+ .UNARY()
+ .Doc(R"doc(
+Computes the log of the absolute value of Gamma of `x` element-wise.
+)doc");
+
+REGISTER_OP("Erf")
+ .UNARY()
+ .Doc(R"doc(
+Computes the Gauss error function of `x` element-wise.
+)doc");
+
+REGISTER_OP("Erfc")
+ .UNARY()
+ .Doc(R"doc(
+Computes the complementary error function of `x` element-wise.
+)doc");
+
REGISTER_OP("Sigmoid")
.UNARY()
.Doc(R"doc(
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index b90d6b2ddc..56f70f9420 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -2413,6 +2413,56 @@ op {
is_commutative: true
}
op {
+ name: "Erf"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_COMPLEX64
+ type: DT_INT64
+ }
+ }
+ }
+ summary: "Computes the Gauss error function of `x` element-wise."
+}
+op {
+ name: "Erfc"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_COMPLEX64
+ type: DT_INT64
+ }
+ }
+ }
+ summary: "Computes the complementary error function of `x` element-wise."
+}
+op {
name: "Exit"
input_arg {
name: "data"
@@ -2949,6 +2999,7 @@ op {
}
summary: "Creates a non-initialized hash table."
description: "This op creates a hash table, specifying the type of its keys and values.\nBefore using the table you will have to initialize it. After initialization the\ntable will be immutable."
+ is_stateful: true
}
op {
name: "HistogramSummary"
@@ -3554,6 +3605,31 @@ op {
summary: "Returns the truth value of (x <= y) element-wise."
}
op {
+ name: "Lgamma"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_COMPLEX64
+ type: DT_INT64
+ }
+ }
+ }
+ summary: "Computes the log of the absolute value of Gamma of `x` element-wise."
+}
+op {
name: "LinSpace"
input_arg {
name: "start"
@@ -4731,6 +4807,12 @@ op {
number_attr: "Ncontext_dense"
}
input_arg {
+ name: "feature_list_sparse_keys"
+ description: "A list of Nfeature_list_sparse string Tensors\n(scalars). The keys expected in the FeatureLists associated with sparse\nvalues."
+ type: DT_STRING
+ number_attr: "Nfeature_list_sparse"
+ }
+ input_arg {
name: "feature_list_dense_keys"
description: "A list of Nfeature_list_dense string Tensors (scalars).\nThe keys expected in the SequenceExamples\' feature_lists associated\nwith lists of dense values."
type: DT_STRING
@@ -4765,27 +4847,62 @@ op {
type_list_attr: "Tcontext_dense"
}
output_arg {
+ name: "feature_list_sparse_indices"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
+ name: "feature_list_sparse_values"
+ type_list_attr: "feature_list_sparse_types"
+ }
+ output_arg {
+ name: "feature_list_sparse_shapes"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
name: "feature_list_dense_values"
type_list_attr: "feature_list_dense_types"
}
attr {
name: "Ncontext_sparse"
type: "int"
+ default_value {
+ i: 0
+ }
has_minimum: true
}
attr {
name: "Ncontext_dense"
type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Nfeature_list_sparse"
+ type: "int"
+ default_value {
+ i: 0
+ }
has_minimum: true
}
attr {
name: "Nfeature_list_dense"
type: "int"
+ default_value {
+ i: 0
+ }
has_minimum: true
}
attr {
name: "context_sparse_types"
type: "list(type)"
+ default_value {
+ list {
+ }
+ }
description: "A list of Ncontext_sparse types; the data types of data in\neach context Feature given in context_sparse_keys.\nCurrently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),\nDT_INT64 (Int64List), and DT_STRING (BytesList)."
has_minimum: true
allowed_values {
@@ -4799,6 +4916,10 @@ op {
attr {
name: "Tcontext_dense"
type: "list(type)"
+ default_value {
+ list {
+ }
+ }
has_minimum: true
allowed_values {
list {
@@ -4811,6 +4932,10 @@ op {
attr {
name: "feature_list_dense_types"
type: "list(type)"
+ default_value {
+ list {
+ }
+ }
has_minimum: true
allowed_values {
list {
@@ -4823,12 +4948,37 @@ op {
attr {
name: "context_dense_shapes"
type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
description: "A list of Ncontext_dense shapes; the shapes of data in\neach context Feature given in context_dense_keys.\nThe number of elements in the Feature corresponding to context_dense_key[j]\nmust always equal context_dense_shapes[j].NumEntries().\nThe shape of context_dense_values[j] will match context_dense_shapes[j]."
has_minimum: true
}
attr {
+ name: "feature_list_sparse_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ description: "A list of Nfeature_list_sparse types; the data types\nof data in each FeatureList given in feature_list_sparse_keys.\nCurrently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),\nDT_INT64 (Int64List), and DT_STRING (BytesList)."
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
name: "feature_list_dense_shapes"
type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
description: "A list of Nfeature_list_dense shapes; the shapes of\ndata in each FeatureList given in feature_list_dense_keys.\nThe shape of each Feature in the FeatureList corresponding to\nfeature_list_dense_key[j] must always equal\nfeature_list_dense_shapes[j].NumEntries()."
has_minimum: true
}
@@ -4986,6 +5136,39 @@ op {
description: "Reduces `input` along the dimensions given in `reduction_indices`. Unless\n`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in\n`reduction_indices`. If `keep_dims` is true, the reduced dimensions are\nretained with length 1."
}
op {
+ name: "PyFunc"
+ input_arg {
+ name: "input"
+ description: "List of Tensors that will provide input to the Op."
+ type_list_attr: "Tin"
+ }
+ output_arg {
+ name: "output"
+ description: "The outputs from the Op."
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "token"
+ type: "string"
+ description: "A token representing a registered python function in this address space."
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ description: "Data types of the inputs to the op."
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ description: "Data types of the outputs from the op.\nThe length of the list specifies the number of outputs."
+ has_minimum: true
+ minimum: 1
+ }
+ summary: "Invokes a python function to compute func(input)->output."
+}
+op {
name: "QueueClose"
input_arg {
name: "handle"
@@ -6354,12 +6537,12 @@ op {
name: "ScalarSummary"
input_arg {
name: "tags"
- description: "1-D. Tags for the summary."
+ description: "Tags for the summary."
type: DT_STRING
}
input_arg {
name: "values"
- description: "1-D, same size as `tags. Values for the summary."
+ description: "Same shape as `tags. Values for the summary."
type_attr: "T"
}
output_arg {
@@ -6374,6 +6557,11 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
}
}
}
@@ -7806,6 +7994,14 @@ op {
type_attr: "T"
}
attr {
+ name: "validate_indices"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ description: "If true, indices are checked to make sure they are sorted in\nlexicographic order and that there are no repeats."
+ }
+ attr {
name: "T"
type: "type"
}
@@ -7820,7 +8016,7 @@ op {
}
}
summary: "Converts a sparse representation into a dense tensor."
- description: "Builds an array `dense` with shape `output_shape` such that\n\n```prettyprint\n# If sparse_indices is scalar\ndense[i] = (i == sparse_indices ? sparse_values : default_value)\n\n# If sparse_indices is a vector, then for each i\ndense[sparse_indices[i]] = sparse_values[i]\n\n# If sparse_indices is an n by d matrix, then for each i in [0, n)\ndense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]\n```\n\nAll other values in `dense` are set to `default_value`. If `sparse_values` is a\nscalar, all sparse indices are set to this single value."
+ description: "Builds an array `dense` with shape `output_shape` such that\n\n```prettyprint\n# If sparse_indices is scalar\ndense[i] = (i == sparse_indices ? sparse_values : default_value)\n\n# If sparse_indices is a vector, then for each i\ndense[sparse_indices[i]] = sparse_values[i]\n\n# If sparse_indices is an n by d matrix, then for each i in [0, n)\ndense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]\n```\n\nAll other values in `dense` are set to `default_value`. If `sparse_values` is a\nscalar, all sparse indices are set to this single value.\n\nIndices should be sorted in lexicographic order, and indices must not\ncontain any repeats. If `validate_indices` is true, these properties\nare checked during execution."
}
op {
name: "Split"
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index eb40783206..150e36ffad 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -87,6 +87,7 @@ REGISTER_OP("ParseSingleSequenceExample")
.Input("feature_list_dense_missing_assumed_empty: string")
.Input("context_sparse_keys: Ncontext_sparse * string")
.Input("context_dense_keys: Ncontext_dense * string")
+ .Input("feature_list_sparse_keys: Nfeature_list_sparse * string")
.Input("feature_list_dense_keys: Nfeature_list_dense * string")
.Input("context_dense_defaults: Tcontext_dense")
.Input("debug_name: string")
@@ -94,16 +95,24 @@ REGISTER_OP("ParseSingleSequenceExample")
.Output("context_sparse_values: context_sparse_types")
.Output("context_sparse_shapes: Ncontext_sparse * int64")
.Output("context_dense_values: Tcontext_dense")
+ .Output("feature_list_sparse_indices: Nfeature_list_sparse * int64")
+ .Output("feature_list_sparse_values: feature_list_sparse_types")
+ .Output("feature_list_sparse_shapes: Nfeature_list_sparse * int64")
.Output("feature_list_dense_values: feature_list_dense_types")
- .Attr("Ncontext_sparse: int >= 0") // Infer from context_sparse_keys
- .Attr("Ncontext_dense: int >= 0") // Infer from context_dense_keys
- .Attr(
- "Nfeature_list_dense: int >= 0") // Infer from feature_list_dense_keys
- .Attr("context_sparse_types: list({float,int64,string}) >= 0")
- .Attr("Tcontext_dense: list({float,int64,string}) >= 0")
- .Attr("feature_list_dense_types: list({float,int64,string}) >= 0")
- .Attr("context_dense_shapes: list(shape) >= 0")
- .Attr("feature_list_dense_shapes: list(shape) >= 0")
+ // Infer from context_sparse_keys
+ .Attr("Ncontext_sparse: int >= 0 = 0")
+ // Infer from context_dense_keys
+ .Attr("Ncontext_dense: int >= 0 = 0")
+ // Infer from feature_list_sparse_keys
+ .Attr("Nfeature_list_sparse: int >= 0 = 0")
+ // Infer from feature_list_dense_keys
+ .Attr("Nfeature_list_dense: int >= 0 = 0")
+ .Attr("context_sparse_types: list({float,int64,string}) >= 0 = []")
+ .Attr("Tcontext_dense: list({float,int64,string}) >= 0 = []")
+ .Attr("feature_list_dense_types: list({float,int64,string}) >= 0 = []")
+ .Attr("context_dense_shapes: list(shape) >= 0 = []")
+ .Attr("feature_list_sparse_types: list({float,int64,string}) >= 0 = []")
+ .Attr("feature_list_dense_shapes: list(shape) >= 0 = []")
.Doc(R"doc(
Transforms a scalar brain.SequenceExample proto (as strings) into typed tensors.
@@ -148,6 +157,13 @@ context_sparse_types: A list of Ncontext_sparse types; the data types of data in
each context Feature given in context_sparse_keys.
Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
DT_INT64 (Int64List), and DT_STRING (BytesList).
+feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors
+ (scalars). The keys expected in the FeatureLists associated with sparse
+ values.
+feature_list_sparse_types: A list of Nfeature_list_sparse types; the data types
+ of data in each FeatureList given in feature_list_sparse_keys.
+ Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+ DT_INT64 (Int64List), and DT_STRING (BytesList).
)doc");
REGISTER_OP("DecodeCSV")
diff --git a/tensorflow/core/ops/script_ops.cc b/tensorflow/core/ops/script_ops.cc
new file mode 100644
index 0000000000..7b6d6d7c81
--- /dev/null
+++ b/tensorflow/core/ops/script_ops.cc
@@ -0,0 +1,37 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("PyFunc")
+ .Input("input: Tin")
+ .Output("output: Tout")
+ .Attr("token: string")
+ .Attr("Tin: list(type)")
+ .Attr("Tout: list(type)")
+ .Doc(R"doc(
+Invokes a python function to compute func(input)->output.
+
+token: A token representing a registered python function in this address space.
+input: List of Tensors that will provide input to the Op.
+output: The outputs from the Op.
+Tin: Data types of the inputs to the op.
+Tout: Data types of the outputs from the op.
+ The length of the list specifies the number of outputs.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
index 57913c8a30..ea8bf95ab3 100644
--- a/tensorflow/core/ops/sparse_ops.cc
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -114,8 +114,9 @@ REGISTER_OP("SparseToDense")
.Input("output_shape: Tindices")
.Input("sparse_values: T")
.Input("default_value: T")
- .Output("dense: T")
+ .Attr("validate_indices: bool = true")
.Attr("T: type")
+ .Output("dense: T")
.Attr("Tindices: {int32, int64}")
.Doc(R"doc(
Converts a sparse representation into a dense tensor.
@@ -136,6 +137,10 @@ dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]
All other values in `dense` are set to `default_value`. If `sparse_values` is a
scalar, all sparse indices are set to this single value.
+Indices should be sorted in lexicographic order, and indices must not
+contain any repeats. If `validate_indices` is true, these properties
+are checked during execution.
+
sparse_indices: 0-D, 1-D, or 2-D. `sparse_indices[i]` contains the complete
index where `sparse_values[i]` will be placed.
output_shape: 1-D. Shape of the dense output tensor.
@@ -143,6 +148,8 @@ sparse_values: 1-D. Values corresponding to each row of `sparse_indices`,
or a scalar value to be used for all sparse indices.
default_value: Scalar value to set for indices not specified in
`sparse_indices`.
+validate_indices: If true, indices are checked to make sure they are sorted in
+ lexicographic order and that there are no repeats.
dense: Dense output tensor of shape `output_shape`.
)doc");
diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc
index 63fa4a8b5c..33a7a614ab 100644
--- a/tensorflow/core/ops/summary_ops.cc
+++ b/tensorflow/core/ops/summary_ops.cc
@@ -24,15 +24,15 @@ REGISTER_OP("ScalarSummary")
.Input("tags: string")
.Input("values: T")
.Output("summary: string")
- .Attr("T: {float, double}")
+ .Attr("T: realnumbertype")
.Doc(R"doc(
Outputs a `Summary` protocol buffer with scalar values.
The input `tags` and `values` must have the same shape. The generated summary
has a summary value for each tag-value pair in `tags` and `values`.
-tags: 1-D. Tags for the summary.
-values: 1-D, same size as `tags. Values for the summary.
+tags: Tags for the summary.
+values: Same shape as `tags`. Values for the summary.
summary: Scalar. Serialized `Summary` protocol buffer.
)doc");
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
index 9497b48726..d24276c547 100644
--- a/tensorflow/core/platform/env.cc
+++ b/tensorflow/core/platform/env.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/public/env.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -30,32 +31,31 @@ Thread::~Thread() {}
EnvWrapper::~EnvWrapper() {}
Status ReadFileToString(Env* env, const string& fname, string* data) {
- data->clear();
+ uint64 file_size;
+ Status s = env->GetFileSize(fname, &file_size);
+ if (!s.ok()) {
+ return s;
+ }
RandomAccessFile* file;
- Status s = env->NewRandomAccessFile(fname, &file);
+ s = env->NewRandomAccessFile(fname, &file);
if (!s.ok()) {
return s;
}
- int64 offset = 0;
- static const int kBufferSize = 8192;
- char* space = new char[kBufferSize];
- while (true) {
- StringPiece fragment;
- s = file->Read(offset, kBufferSize, &fragment, space);
- if (!s.ok()) {
- if (errors::IsOutOfRange(s)) { // No more bytes, but not an error
- s = Status::OK();
- data->append(fragment.data(), fragment.size());
- }
- break;
- }
- offset += fragment.size();
- data->append(fragment.data(), fragment.size());
- if (fragment.empty()) {
- break;
- }
+ gtl::STLStringResizeUninitialized(data, file_size);
+ char* p = gtl::string_as_array(data);
+ StringPiece result;
+ s = file->Read(0, file_size, &result, p);
+ if (!s.ok()) {
+ data->clear();
+ } else if (result.size() != file_size) {
+ s = errors::Aborted("File ", fname, " changed while reading: ", file_size,
+ " vs. ", result.size());
+ data->clear();
+ } else if (result.data() == p) {
+ // Data is already in the correct location
+ } else {
+ memmove(p, result.data(), result.size());
}
- delete[] space;
delete file;
return s;
}
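The rewrite replaces the 8 KB read loop with a single read sized by `GetFileSize`, then compares the byte count so a file that changes mid-read is reported rather than silently truncated. A rough Python sketch of the same pattern (illustrative only, not the TensorFlow code):

```python
import os

def read_file_to_string(fname):
    # Size the buffer up front, read once, and verify the length.
    expected = os.path.getsize(fname)
    with open(fname, 'rb') as f:
        data = f.read(expected)
    if len(data) != expected:
        raise IOError('File %s changed while reading: %d vs. %d'
                      % (fname, expected, len(data)))
    return data
```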
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 678d427242..ca4cdc8d66 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -27,7 +27,8 @@ struct EnvTest {};
TEST(EnvTest, ReadFileToString) {
Env* env = Env::Default();
const string dir = testing::TmpDir();
- for (const int length : {0, 1, 1212, 2553, 4928, 8196, 9000}) {
+ for (const int length : {0, 1, 1212, 2553, 4928, 8196, 9000, (1 << 20) - 1,
+ 1 << 20, (1 << 20) + 1}) {
const string filename = io::JoinPath(dir, strings::StrCat("file", length));
// Write a file with the given length
diff --git a/tensorflow/core/platform/load_library.cc b/tensorflow/core/platform/load_library.cc
new file mode 100644
index 0000000000..aff2562d95
--- /dev/null
+++ b/tensorflow/core/platform/load_library.cc
@@ -0,0 +1,44 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <dlfcn.h>
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+namespace internal {
+
+Status LoadLibrary(const char* library_filename, void** handle) {
+ *handle = dlopen(library_filename, RTLD_NOW | RTLD_LOCAL);
+ if (!*handle) {
+ return errors::NotFound("Unable to find library ", library_filename);
+ }
+ return Status::OK();
+}
+
+Status GetSymbolFromLibrary(void* handle, const char* symbol_name,
+ void** symbol) {
+ *symbol = dlsym(handle, symbol_name);
+ if (!*symbol) {
+ return errors::NotFound("Unable to find symbol ", symbol_name,
+ " in library");
+ }
+ return Status::OK();
+}
+
+} // namespace internal
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/load_library.h b/tensorflow/core/platform/load_library.h
new file mode 100644
index 0000000000..eb546acc55
--- /dev/null
+++ b/tensorflow/core/platform/load_library.h
@@ -0,0 +1,33 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_PLATFORM_LOAD_LIBRARY_H_
+#define TENSORFLOW_PLATFORM_LOAD_LIBRARY_H_
+
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+namespace internal {
+
+Status LoadLibrary(const char* library_filename, void** handle);
+Status GetSymbolFromLibrary(void* handle, const char* symbol_name,
+ void** symbol);
+
+} // namespace internal
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_LOAD_LIBRARY_H_
diff --git a/tensorflow/core/platform/posix/env.cc b/tensorflow/core/platform/posix/env.cc
index 2c8daf98a5..164d11a81f 100644
--- a/tensorflow/core/platform/posix/env.cc
+++ b/tensorflow/core/platform/posix/env.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include <thread>
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/platform/load_library.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/env.h"
@@ -397,9 +398,20 @@ class PosixEnv : public Env {
// TODO(mrry): Replace with a non-blocking timer mechanism and threadpool.
CHECK(false) << "PosixEnv::SchedClosureAfter not implemented.";
}
+
+ Status LoadLibrary(const char* library_filename, void** handle) override {
+ return tensorflow::internal::LoadLibrary(library_filename, handle);
+ }
+
+ Status GetSymbolFromLibrary(void* handle, const char* symbol_name,
+ void** symbol) override {
+ return tensorflow::internal::GetSymbolFromLibrary(handle, symbol_name,
+ symbol);
+ }
};
} // namespace
+
#if defined(PLATFORM_POSIX) || defined(__ANDROID__)
Env* Env::Default() {
static Env* default_env = new PosixEnv;
diff --git a/tensorflow/core/platform/regexp.h b/tensorflow/core/platform/regexp.h
index 8432f47289..52fb475062 100644
--- a/tensorflow/core/platform/regexp.h
+++ b/tensorflow/core/platform/regexp.h
@@ -33,16 +33,4 @@ typedef re2::StringPiece RegexpStringPiece;
#endif
-namespace tensorflow {
-
-// Conversion to/from the appropriate StringPiece type for using in RE2
-inline RegexpStringPiece ToRegexpStringPiece(tensorflow::StringPiece sp) {
- return RegexpStringPiece(sp.data(), sp.size());
-}
-inline tensorflow::StringPiece FromRegexpStringPiece(RegexpStringPiece sp) {
- return tensorflow::StringPiece(sp.data(), sp.size());
-}
-
-} // namespace tensorflow
-
#endif // TENSORFLOW_PLATFORM_REGEXP_H_
diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h
index 23d70fc3eb..0ce7bd379d 100644
--- a/tensorflow/core/platform/tracing.h
+++ b/tensorflow/core/platform/tracing.h
@@ -152,6 +152,9 @@ class Tracing::Engine {
Engine() {}
virtual ~Engine();
+ // Returns true if Tracing is currently enabled.
+ virtual bool IsEnabled() const = 0;
+
// Represents an active annotation.
class Annotation {
public:
@@ -225,7 +228,7 @@ class Tracing::TraceMe {
inline Tracing::ScopedAnnotation::ScopedAnnotation(StringPiece name) {
auto e = Tracing::engine();
- if (e) {
+ if (e && e->IsEnabled()) {
annotation_.reset(e->PushAnnotation(name));
}
}
@@ -233,7 +236,7 @@ inline Tracing::ScopedAnnotation::ScopedAnnotation(StringPiece name) {
inline Tracing::ScopedAnnotation::ScopedAnnotation(StringPiece name_part1,
StringPiece name_part2) {
auto e = Tracing::engine();
- if (e) {
+ if (e && e->IsEnabled()) {
annotation_.reset(
e->PushAnnotation(strings::StrCat(name_part1, ":", name_part2)));
}
@@ -241,7 +244,7 @@ inline Tracing::ScopedAnnotation::ScopedAnnotation(StringPiece name_part1,
inline Tracing::TraceMe::TraceMe(StringPiece name) {
auto e = Tracing::engine();
- if (e) {
+ if (e && e->IsEnabled()) {
tracer_.reset(e->StartTracing(name));
}
}
diff --git a/tensorflow/core/public/env.h b/tensorflow/core/public/env.h
index ac34a02c89..e40fe8974f 100644
--- a/tensorflow/core/public/env.h
+++ b/tensorflow/core/public/env.h
@@ -145,6 +145,27 @@ class Env {
// NOTE(mrry): This closure must not block.
virtual void SchedClosureAfter(int micros, std::function<void()> closure) = 0;
+ // \brief Load a dynamic library.
+ //
+ // Pass "library_filename" to a platform-specific mechanism for dynamically
+ // loading a library. The rules for determining the exact location of the
+ // library are platform-specific and are not documented here.
+ //
+ // On success, returns a handle to the library in "*handle" and returns
+ // OK from the function.
+ // Otherwise returns nullptr in "*handle" and an error status from the
+ // function.
+ virtual Status LoadLibrary(const char* library_filename, void** handle) = 0;
+
+ // \brief Get a pointer to a symbol from a dynamic library.
+ //
+ // "handle" should be a pointer returned from a previous call to LoadLibrary.
+ // On success, store a pointer to the located symbol in "*symbol" and return
+ // OK from the function. Otherwise, return nullptr in "*symbol" and an error
+ // status from the function.
+ virtual Status GetSymbolFromLibrary(void* handle, const char* symbol_name,
+ void** symbol) = 0;
+
private:
/// No copying allowed
Env(const Env&);
@@ -251,6 +272,13 @@ class EnvWrapper : public Env {
void SchedClosureAfter(int micros, std::function<void()> closure) override {
target_->SchedClosureAfter(micros, closure);
}
+ Status LoadLibrary(const char* library_filename, void** handle) override {
+ return target_->LoadLibrary(library_filename, handle);
+ }
+ Status GetSymbolFromLibrary(void* handle, const char* symbol_name,
+ void** symbol) override {
+ return target_->GetSymbolFromLibrary(handle, symbol_name, symbol);
+ }
private:
Env* target_;
diff --git a/tensorflow/core/public/tensor_c_api.h b/tensorflow/core/public/tensor_c_api.h
index 5d90e80342..22219d1413 100644
--- a/tensorflow/core/public/tensor_c_api.h
+++ b/tensorflow/core/public/tensor_c_api.h
@@ -117,6 +117,21 @@ typedef enum {
// else an error code with an associated error message.
typedef struct TF_Status TF_Status;
+// --------------------------------------------------------------------------
+// TF_Buffer holds a pointer to a block of data and its associated length.
+// Typically, the data consists of a serialized protocol buffer, but other data
+// may also be held in a buffer.
+//
+// TF_Buffer itself does not do any memory management of the pointed-to block.
+typedef struct {
+ const void* data;
+ size_t length;
+} TF_Buffer;
+
+// --------------------------------------------------------------------------
+// TF_Library holds information about dynamically loaded TensorFlow plugins.
+typedef struct TF_Library TF_Library;
+
// Return a new status object.
extern TF_Status* TF_NewStatus();
@@ -253,6 +268,32 @@ extern void TF_Run(TF_Session*,
// Output status
TF_Status*);
+// --------------------------------------------------------------------------
+// Load plugins containing custom ops and kernels
+
+// Load the library specified by library_filename and register the ops and
+// kernels present in that library.
+//
+// Pass "library_filename" to a platform-specific mechanism for dynamically
+// loading a library. The rules for determining the exact location of the
+// library are platform-specific and are not documented here.
+// Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList" to be
+// defined in the library.
+//
+// On success, place OK in status and return the newly created library handle.
+// The caller owns the library handle.
+//
+// On failure, place an error status in status and return nullptr.
+extern TF_Library* TF_LoadLibrary(const char* library_filename,
+ TF_Status* status);
+
+// Get the OpList of OpDefs defined in the library pointed by lib_handle.
+//
+// Returns a TF_Buffer. The memory pointed to by the result is owned by
+// lib_handle. The data in the buffer will be the serialized OpList proto for
+// ops defined in the library.
+extern TF_Buffer TF_GetOpList(TF_Library* lib_handle);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 0884538d52..5863a3d782 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -38,7 +38,7 @@ limitations under the License.
// Supported GraphDef versions (see graph.proto).
#define TF_GRAPH_DEF_VERSION_MIN 0
-#define TF_GRAPH_DEF_VERSION_MAX 1
+#define TF_GRAPH_DEF_VERSION_MAX 3
#define TF_GRAPH_DEF_VERSION TF_GRAPH_DEF_VERSION_MAX
#endif // THIRD_PARTY_TENSORFLOW_CORE_PUBLIC_VERSION_H_
diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h
index 9681dc4c18..19aee104dd 100644
--- a/tensorflow/core/util/bcast.h
+++ b/tensorflow/core/util/bcast.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@@ -71,7 +72,7 @@ class BCast {
// element is the outer-most dimension and the last element is the
// inner-most dimension. Note that we do not use TensorShape since
// it's more convenient to manipulate Vec directly for this module.
- typedef std::vector<int64> Vec;
+ typedef gtl::InlinedVector<int64, 4> Vec;
BCast(const Vec& x, const Vec& y);
~BCast() {}
diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD
index 4fc90730d5..55eff77960 100644
--- a/tensorflow/examples/tutorials/mnist/BUILD
+++ b/tensorflow/examples/tutorials/mnist/BUILD
@@ -22,7 +22,7 @@ py_library(
name = "input_data",
srcs = ["input_data.py"],
srcs_version = "PY2AND3",
- visibility = ["//tensorflow:__subpackages__"],
+ visibility = ["//tensorflow:internal"],
deps = ["//tensorflow:tensorflow_py"],
)
diff --git a/tensorflow/examples/tutorials/mnist/input_data.py b/tensorflow/examples/tutorials/mnist/input_data.py
index ae3727c82e..bca3bcf008 100644
--- a/tensorflow/examples/tutorials/mnist/input_data.py
+++ b/tensorflow/examples/tutorials/mnist/input_data.py
@@ -21,9 +21,12 @@ from __future__ import print_function
import gzip
import os
+import tensorflow.python.platform
+
import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow as tf
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
@@ -91,9 +94,18 @@ def extract_labels(filename, one_hot=False):
class DataSet(object):
- def __init__(self, images, labels, fake_data=False, one_hot=False):
- """Construct a DataSet. one_hot arg is used only if fake_data is true."""
-
+ def __init__(self, images, labels, fake_data=False, one_hot=False,
+ dtype=tf.float32):
+ """Construct a DataSet.
+
+ one_hot arg is used only if fake_data is true. `dtype` can be either
+ `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
+ `[0, 1]`.
+ """
+ dtype = tf.as_dtype(dtype).base_dtype
+ if dtype not in (tf.uint8, tf.float32):
+ raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
+ dtype)
if fake_data:
self._num_examples = 10000
self.one_hot = one_hot
@@ -108,9 +120,10 @@ class DataSet(object):
assert images.shape[3] == 1
images = images.reshape(images.shape[0],
images.shape[1] * images.shape[2])
- # Convert from [0, 255] -> [0.0, 1.0].
- images = images.astype(numpy.float32)
- images = numpy.multiply(images, 1.0 / 255.0)
+ if dtype == tf.float32:
+ # Convert from [0, 255] -> [0.0, 1.0].
+ images = images.astype(numpy.float32)
+ images = numpy.multiply(images, 1.0 / 255.0)
self._images = images
self._labels = labels
self._epochs_completed = 0
@@ -160,15 +173,17 @@ class DataSet(object):
return self._images[start:end], self._labels[start:end]
-def read_data_sets(train_dir, fake_data=False, one_hot=False):
+def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):
class DataSets(object):
pass
data_sets = DataSets()
if fake_data:
- data_sets.train = DataSet([], [], fake_data=True, one_hot=one_hot)
- data_sets.validation = DataSet([], [], fake_data=True, one_hot=one_hot)
- data_sets.test = DataSet([], [], fake_data=True, one_hot=one_hot)
+ def fake():
+ return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
+ data_sets.train = fake()
+ data_sets.validation = fake()
+ data_sets.test = fake()
return data_sets
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
@@ -194,8 +209,9 @@ def read_data_sets(train_dir, fake_data=False, one_hot=False):
train_images = train_images[VALIDATION_SIZE:]
train_labels = train_labels[VALIDATION_SIZE:]
- data_sets.train = DataSet(train_images, train_labels)
- data_sets.validation = DataSet(validation_images, validation_labels)
- data_sets.test = DataSet(test_images, test_labels)
+ data_sets.train = DataSet(train_images, train_labels, dtype=dtype)
+ data_sets.validation = DataSet(validation_images, validation_labels,
+ dtype=dtype)
+ data_sets.test = DataSet(test_images, test_labels, dtype=dtype)
return data_sets
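With `dtype` threaded through `read_data_sets`, callers can keep the raw pixel bytes instead of rescaled floats. A usage sketch (the `/tmp/data` directory is a hypothetical path):

```python
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# dtype=tf.uint8 keeps images in [0, 255]; the default tf.float32
# rescales them into [0, 1].
data_sets = input_data.read_data_sets('/tmp/data', dtype=tf.uint8)
images, labels = data_sets.train.next_batch(100)
print(images.dtype, images.min(), images.max())  # uint8 0 255
```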
diff --git a/tensorflow/g3doc/api_docs/cc/ClassEnv.md b/tensorflow/g3doc/api_docs/cc/ClassEnv.md
index 38bb94ac63..ca50ca1563 100644
--- a/tensorflow/g3doc/api_docs/cc/ClassEnv.md
+++ b/tensorflow/g3doc/api_docs/cc/ClassEnv.md
@@ -36,6 +36,10 @@ All Env implementations are safe for concurrent access from multiple threads wit
* Sleeps/delays the thread for the prescribed number of micro-seconds.
* [`virtual Thread* tensorflow::Env::StartThread(const ThreadOptions &thread_options, const string &name, std::function< void()> fn) TF_MUST_USE_RESULT=0`](#virtual_Thread_tensorflow_Env_StartThread)
* Returns a new thread that is running fn() and is identified (for debugging/performance-analysis) by "name".
+* [`virtual void tensorflow::Env::SchedClosure(std::function< void()> closure)=0`](#virtual_void_tensorflow_Env_SchedClosure)
+* [`virtual void tensorflow::Env::SchedClosureAfter(int micros, std::function< void()> closure)=0`](#virtual_void_tensorflow_Env_SchedClosureAfter)
+* [`virtual Status tensorflow::Env::LoadLibrary(const char *library_filename, void **handle)=0`](#virtual_Status_tensorflow_Env_LoadLibrary)
+* [`virtual Status tensorflow::Env::GetSymbolFromLibrary(void *handle, const char *symbol_name, void **symbol)=0`](#virtual_Status_tensorflow_Env_GetSymbolFromLibrary)
* [`static Env* tensorflow::Env::Default()`](#static_Env_tensorflow_Env_Default)
* Returns a default environment suitable for the current operating system.
@@ -137,6 +141,30 @@ Returns a new thread that is running fn() and is identified (for debugging/perfo
Caller takes ownership of the result and must delete it eventually (the deletion will block until fn() stops running).
+#### `virtual void tensorflow::Env::SchedClosure(std::function< void()> closure)=0` {#virtual_void_tensorflow_Env_SchedClosure}
+
+
+
+
+
+#### `virtual void tensorflow::Env::SchedClosureAfter(int micros, std::function< void()> closure)=0` {#virtual_void_tensorflow_Env_SchedClosureAfter}
+
+
+
+
+
+#### `virtual Status tensorflow::Env::LoadLibrary(const char *library_filename, void **handle)=0` {#virtual_Status_tensorflow_Env_LoadLibrary}
+
+
+
+
+
+#### `virtual Status tensorflow::Env::GetSymbolFromLibrary(void *handle, const char *symbol_name, void **symbol)=0` {#virtual_Status_tensorflow_Env_GetSymbolFromLibrary}
+
+
+
+
+
#### `static Env* tensorflow::Env::Default()` {#static_Env_tensorflow_Env_Default}
Returns a default environment suitable for the current operating system.
diff --git a/tensorflow/g3doc/api_docs/cc/ClassEnvWrapper.md b/tensorflow/g3doc/api_docs/cc/ClassEnvWrapper.md
index 9ed2a97016..bdfb1af1d4 100644
--- a/tensorflow/g3doc/api_docs/cc/ClassEnvWrapper.md
+++ b/tensorflow/g3doc/api_docs/cc/ClassEnvWrapper.md
@@ -37,6 +37,10 @@ May be useful to clients who wish to override just part of the functionality of
* Sleeps/delays the thread for the prescribed number of micro-seconds.
* [`Thread* tensorflow::EnvWrapper::StartThread(const ThreadOptions &thread_options, const string &name, std::function< void()> fn) override`](#Thread_tensorflow_EnvWrapper_StartThread)
* Returns a new thread that is running fn() and is identified (for debugging/performance-analysis) by "name".
+* [`void tensorflow::EnvWrapper::SchedClosure(std::function< void()> closure) override`](#void_tensorflow_EnvWrapper_SchedClosure)
+* [`void tensorflow::EnvWrapper::SchedClosureAfter(int micros, std::function< void()> closure) override`](#void_tensorflow_EnvWrapper_SchedClosureAfter)
+* [`Status tensorflow::EnvWrapper::LoadLibrary(const char *library_filename, void **handle) override`](#Status_tensorflow_EnvWrapper_LoadLibrary)
+* [`Status tensorflow::EnvWrapper::GetSymbolFromLibrary(void *handle, const char *symbol_name, void **symbol) override`](#Status_tensorflow_EnvWrapper_GetSymbolFromLibrary)
##Member Details
@@ -141,3 +145,27 @@ Sleeps/delays the thread for the prescribed number of micro-seconds.
Returns a new thread that is running fn() and is identified (for debugging/performance-analysis) by "name".
Caller takes ownership of the result and must delete it eventually (the deletion will block until fn() stops running).
+
+#### `void tensorflow::EnvWrapper::SchedClosure(std::function< void()> closure) override` {#void_tensorflow_EnvWrapper_SchedClosure}
+
+
+
+
+
+#### `void tensorflow::EnvWrapper::SchedClosureAfter(int micros, std::function< void()> closure) override` {#void_tensorflow_EnvWrapper_SchedClosureAfter}
+
+
+
+
+
+#### `Status tensorflow::EnvWrapper::LoadLibrary(const char *library_filename, void **handle) override` {#Status_tensorflow_EnvWrapper_LoadLibrary}
+
+
+
+
+
+#### `Status tensorflow::EnvWrapper::GetSymbolFromLibrary(void *handle, const char *symbol_name, void **symbol) override` {#Status_tensorflow_EnvWrapper_GetSymbolFromLibrary}
+
+
+
+
diff --git a/tensorflow/g3doc/api_docs/cc/ClassSession.md b/tensorflow/g3doc/api_docs/cc/ClassSession.md
index a0fe3a4a30..ffe51ca310 100644
--- a/tensorflow/g3doc/api_docs/cc/ClassSession.md
+++ b/tensorflow/g3doc/api_docs/cc/ClassSession.md
@@ -33,7 +33,7 @@ if (output_tensor(0) > 0.5) { ... }
// Close the session to release the resources associated with
// this session.
-session->Close()
+session->Close();
```
diff --git a/tensorflow/g3doc/api_docs/cc/StructTF_Buffer.md b/tensorflow/g3doc/api_docs/cc/StructTF_Buffer.md
new file mode 100644
index 0000000000..e0e46a8794
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/cc/StructTF_Buffer.md
@@ -0,0 +1,24 @@
+# Struct `TF_Buffer`
+
+
+
+
+
+##Member Summary
+
+* [`const void* TF_Buffer::data`](#const_void_TF_Buffer_data)
+* [`size_t TF_Buffer::length`](#size_t_TF_Buffer_length)
+
+##Member Details
+
+#### `const void* TF_Buffer::data` {#const_void_TF_Buffer_data}
+
+
+
+
+
+#### `size_t TF_Buffer::length` {#size_t_TF_Buffer_length}
+
+
+
+
diff --git a/tensorflow/g3doc/api_docs/cc/index.md b/tensorflow/g3doc/api_docs/cc/index.md
index 2bb24375cb..97abde341e 100644
--- a/tensorflow/g3doc/api_docs/cc/index.md
+++ b/tensorflow/g3doc/api_docs/cc/index.md
@@ -46,6 +46,7 @@ write the graph to a file.
* [tensorflow::TensorShape](ClassTensorShape.md)
* [tensorflow::TensorShapeDim](StructTensorShapeDim.md)
* [tensorflow::TensorShapeUtils](ClassTensorShapeUtils.md)
+* [TF_Buffer](StructTF_Buffer.md)
## Thread
@@ -68,6 +69,7 @@ write the graph to a file.
<!-- ClassTensorShape.md -->
<!-- StructTensorShapeDim.md -->
<!-- ClassTensorShapeUtils.md -->
+<!-- StructTF_Buffer.md -->
<!-- ClassThread.md -->
<!-- StructThreadOptions.md -->
-->
diff --git a/tensorflow/g3doc/api_docs/index.md b/tensorflow/g3doc/api_docs/index.md
index 7e41a44d7f..f58624cf51 100644
--- a/tensorflow/g3doc/api_docs/index.md
+++ b/tensorflow/g3doc/api_docs/index.md
@@ -1,4 +1,4 @@
-# Overview
+# API Documentation
TensorFlow has APIs available in several languages both for constructing and
executing a TensorFlow graph. The Python API is at present the most complete
diff --git a/tensorflow/g3doc/api_docs/python/client.md b/tensorflow/g3doc/api_docs/python/client.md
index 4f243e58b3..4feeb7556a 100644
--- a/tensorflow/g3doc/api_docs/python/client.md
+++ b/tensorflow/g3doc/api_docs/python/client.md
@@ -277,7 +277,7 @@ with tf.Session():
- - -
-#### `tf.InteractiveSession.__init__(target='', graph=None)` {#InteractiveSession.__init__}
+#### `tf.InteractiveSession.__init__(target='', graph=None, config=None)` {#InteractiveSession.__init__}
Creates a new interactive TensorFlow session.
@@ -296,6 +296,7 @@ the session constructor.
Defaults to using an in-process engine. At present, no value
other than the empty string is supported.
* <b>`graph`</b>: (Optional.) The `Graph` to be launched (described above).
+* <b>`config`</b>: (Optional) `ConfigProto` proto used to configure the session.
- - -
diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md
index 6fd8e44dc1..9ee01c727d 100644
--- a/tensorflow/g3doc/api_docs/python/framework.md
+++ b/tensorflow/g3doc/api_docs/python/framework.md
@@ -1339,7 +1339,7 @@ Converts the given `type_value` to a `DType`.
Wrapper for `Graph.device()` using the default graph.
See
-[`Graph.name_scope()`](../../api_docs/python/framework.md#Graph.name_scope)
+[`Graph.device()`](../../api_docs/python/framework.md#Graph.device)
for more details.
##### Args:
@@ -1544,6 +1544,35 @@ protocol buffer, and extract individual objects in the `GraphDef` as
it refers to an unknown tensor).
+- - -
+
+### `tf.load_op_library(library_filename)` {#load_op_library}
+
+Loads a TensorFlow plugin, containing custom ops and kernels.
+
+Pass "library_filename" to a platform-specific mechanism for dynamically
+loading a library. The rules for determining the exact location of the
+library are platform-specific and are not documented here.
+Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList" to be
+defined in the library.
+
+##### Args:
+
+
+* <b>`library_filename`</b>: Path to the plugin.
+ Relative or absolute filesystem path to a dynamic library file.
+
+##### Returns:
+
+ A python module containing the Python wrappers for Ops defined in
+ the plugin.
+
+##### Raises:
+
+
+* <b>`RuntimeError`</b>: when unable to load the library or get the python wrappers.
+
+
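A usage sketch for the new function; `zero_out.so` and its op are hypothetical names for a plugin built per the adding-an-op guide:

```python
import tensorflow as tf

# The plugin must export the RegisterOps, RegisterKernels, and GetOpList
# symbols; the path and op name here are placeholders.
zero_out_module = tf.load_op_library('/tmp/zero_out.so')
result = zero_out_module.zero_out([1, 2, 3])
```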
## Graph collections
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index a2165c834d..efd38ed148 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -17,6 +17,7 @@
* [`Graph`](../../api_docs/python/framework.md#Graph)
* [`GraphKeys`](../../api_docs/python/framework.md#GraphKeys)
* [`import_graph_def`](../../api_docs/python/framework.md#import_graph_def)
+ * [`load_op_library`](../../api_docs/python/framework.md#load_op_library)
* [`name_scope`](../../api_docs/python/framework.md#name_scope)
* [`NoGradient`](../../api_docs/python/framework.md#NoGradient)
* [`op_scope`](../../api_docs/python/framework.md#op_scope)
@@ -124,6 +125,8 @@
* [`diag`](../../api_docs/python/math_ops.md#diag)
* [`div`](../../api_docs/python/math_ops.md#div)
* [`edit_distance`](../../api_docs/python/math_ops.md#edit_distance)
+ * [`erf`](../../api_docs/python/math_ops.md#erf)
+ * [`erfc`](../../api_docs/python/math_ops.md#erfc)
* [`exp`](../../api_docs/python/math_ops.md#exp)
* [`fft2d`](../../api_docs/python/math_ops.md#fft2d)
* [`floor`](../../api_docs/python/math_ops.md#floor)
@@ -132,6 +135,7 @@
* [`imag`](../../api_docs/python/math_ops.md#imag)
* [`inv`](../../api_docs/python/math_ops.md#inv)
* [`invert_permutation`](../../api_docs/python/math_ops.md#invert_permutation)
+ * [`lgamma`](../../api_docs/python/math_ops.md#lgamma)
* [`listdiff`](../../api_docs/python/math_ops.md#listdiff)
* [`log`](../../api_docs/python/math_ops.md#log)
* [`matmul`](../../api_docs/python/math_ops.md#matmul)
@@ -355,6 +359,7 @@
* [`gradients`](../../api_docs/python/train.md#gradients)
* [`histogram_summary`](../../api_docs/python/train.md#histogram_summary)
* [`image_summary`](../../api_docs/python/train.md#image_summary)
+ * [`LooperThread`](../../api_docs/python/train.md#LooperThread)
* [`merge_all_summaries`](../../api_docs/python/train.md#merge_all_summaries)
* [`merge_summary`](../../api_docs/python/train.md#merge_summary)
* [`MomentumOptimizer`](../../api_docs/python/train.md#MomentumOptimizer)
@@ -369,3 +374,6 @@
* [`write_graph`](../../api_docs/python/train.md#write_graph)
* [`zero_fraction`](../../api_docs/python/train.md#zero_fraction)
+* **[Wraps python functions](../../api_docs/python/script_ops.md)**:
+ * [`py_func`](../../api_docs/python/script_ops.md#py_func)
+
diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md
index 878bae23d4..66a9c20e04 100644
--- a/tensorflow/g3doc/api_docs/python/math_ops.md
+++ b/tensorflow/g3doc/api_docs/python/math_ops.md
@@ -527,6 +527,63 @@ Computes sin of x element-wise.
A `Tensor`. Has the same type as `x`.
+- - -
+
+### `tf.lgamma(x, name=None)` {#lgamma}
+
+Computes `ln(|gamma(x)|)` element-wise.
+
+##### Args:
+
+
+* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `int64`,
+ or `qint32`.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A Tensor with the same type as `x` if `x.dtype != qint32`; otherwise
+ the return type is `quint8`.
+
+
+- - -
+
+### `tf.erf(x, name=None)` {#erf}
+
+Computes the Gauss error function of `x` element-wise.
+
+##### Args:
+
+
+* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `int64`,
+ or `qint32`.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A Tensor with the same type as `x` if `x.dtype != qint32`; otherwise
+ the return type is `quint8`.
+
+
+- - -
+
+### `tf.erfc(x, name=None)` {#erfc}
+
+Computes the complementary error function of `x` element-wise.
+
+##### Args:
+
+
+* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `int64`,
+ or `qint32`.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A Tensor with the same type as `x` if `x.dtype != qint32`; otherwise
+ the return type is `quint8`.
+
+
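A quick sketch exercising the three newly documented element-wise ops:

```python
import tensorflow as tf

x = tf.constant([0.5, 1.0, 2.0])
with tf.Session() as sess:
    print(sess.run(tf.lgamma(x)))  # ln|gamma(x)| per element
    print(sess.run(tf.erf(x)))     # Gauss error function
    print(sess.run(tf.erfc(x)))    # complementary error function, 1 - erf(x)
```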
## Matrix Math Functions
diff --git a/tensorflow/g3doc/api_docs/python/script_ops.md b/tensorflow/g3doc/api_docs/python/script_ops.md
new file mode 100644
index 0000000000..aa1ff4c9c5
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/script_ops.md
@@ -0,0 +1,46 @@
+<!-- This file is machine generated: DO NOT EDIT! -->
+
+# Wraps python functions
+
+Note: Functions taking `Tensor` arguments can also take anything accepted by
+[`tf.convert_to_tensor`](../../api_docs/python/framework.md#convert_to_tensor).
+
+[TOC]
+
+## Script Language Operators.
+
+TensorFlow allows you to wrap python/numpy functions as
+TensorFlow operators.
+
+## Other Functions and Classes
+- - -
+
+### `tf.py_func(func, inp, Tout, name=None)` {#py_func}
+
+Wraps a python function and uses it as a tensorflow op.
+
+Given a python function `func` that takes numpy arrays as its
+inputs and returns numpy arrays as its outputs, e.g.,
+
+ def my_func(x):
+ return np.sinh(x)
+ inp = tf.placeholder(..., tf.float32)
+ y = py_func(my_func, [inp], [tf.float32])
+
+The above snippet constructs a tf graph that invokes numpy
+sinh(x) as an op in the graph.
+
+##### Args:
+
+
+* <b>`func`</b>: A python function.
+* <b>`inp`</b>: A list of `Tensor`.
+* <b>`Tout`</b>: A list of tensorflow data types indicating what `func`
+ returns.
+* <b>`name`</b>: A name for the operation (optional).
+
+##### Returns:
+
+ A list of `Tensor` which `func` computes.
+
+
diff --git a/tensorflow/g3doc/api_docs/python/sparse_ops.md b/tensorflow/g3doc/api_docs/python/sparse_ops.md
index cdcb5e2c0f..afaae8facf 100644
--- a/tensorflow/g3doc/api_docs/python/sparse_ops.md
+++ b/tensorflow/g3doc/api_docs/python/sparse_ops.md
@@ -157,7 +157,7 @@ Alias for field number 1
- - -
-### `tf.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0, name=None)` {#sparse_to_dense}
+### `tf.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0, validate_indices=True, name=None)` {#sparse_to_dense}
Converts a sparse representation into a dense tensor.
@@ -177,6 +177,10 @@ dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]
All other values in `dense` are set to `default_value`. If `sparse_values`
is a scalar, all sparse indices are set to this single value.
+Indices should be sorted in lexicographic order, and indices must not
+contain any repeats. If `validate_indices` is True, these properties
+are checked during execution.
+
##### Args:
@@ -189,6 +193,8 @@ is a scalar, all sparse indices are set to this single value.
`sparse_indices`, or a scalar value to be used for all sparse indices.
* <b>`default_value`</b>: A 0-D `Tensor` of the same type as `sparse_values`. Value
to set for indices not specified in `sparse_indices`. Defaults to zero.
+* <b>`validate_indices`</b>: A boolean value. If True, indices are checked to make
+ sure they are sorted in lexicographic order and that there are no repeats.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
@@ -199,7 +205,7 @@ is a scalar, all sparse indices are set to this single value.
- - -
-### `tf.sparse_tensor_to_dense(sp_input, default_value=0, name=None)` {#sparse_tensor_to_dense}
+### `tf.sparse_tensor_to_dense(sp_input, default_value=0, validate_indices=True, name=None)` {#sparse_tensor_to_dense}
Converts a `SparseTensor` into a dense tensor.
@@ -218,12 +224,17 @@ string tensor with values:
[x x x x x]
[c x x x x]]
+Indices must not contain repeats. This is checked
+only if `validate_indices` is True.
+
##### Args:
* <b>`sp_input`</b>: The input `SparseTensor`.
* <b>`default_value`</b>: Scalar value to set for indices not specified in
`sp_input`. Defaults to zero.
+* <b>`validate_indices`</b>: A boolean value. If `True`, indices are checked to make
+ sure they are sorted in lexicographic order and that there are no repeats.
* <b>`name`</b>: A name prefix for the returned tensors (optional).
##### Returns:
@@ -257,15 +268,18 @@ For example, if `sp_input.shape = [2, 3, 4]` with non-empty values:
[0, 0, 0]: 0
[0, 1, 0]: 10
[1, 0, 3]: 103
- [1, 1, 2]: 112
- [1, 1, 3]: 113
+ [1, 1, 2]: 150
+ [1, 1, 3]: 149
+ [1, 1, 4]: 150
[1, 2, 1]: 121
and `vocab_size = 200`, then the output will be a `[2, 3, 200]` dense bool
tensor with False everywhere except at positions
- (0, 0, 0), (0, 1, 10), (1, 0, 103), (1, 1, 112), (1, 1, 113), (1, 2, 121).
+ (0, 0, 0), (0, 1, 10), (1, 0, 103), (1, 1, 149), (1, 1, 150),
+ (1, 2, 121).
+Note that repeats are allowed in the input SparseTensor.
This op is useful for converting `SparseTensor`s into dense formats for
compatibility with ops that expect dense tensors.
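A small sketch of the dense-conversion path touched by this change, assuming the 0.6-era `tf.SparseTensor` constructor:

```python
import tensorflow as tf

sp = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=['a', 'c'],
                     shape=[2, 3])
# validate_indices=True (the default) checks ordering and repeats at run time.
dense = tf.sparse_tensor_to_dense(sp, default_value='x',
                                  validate_indices=True)

with tf.Session() as sess:
    print(sess.run(dense))  # [['a' 'x' 'x'] ['x' 'x' 'c']]
```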
diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md
index a1b7244c56..7a18323685 100644
--- a/tensorflow/g3doc/api_docs/python/state_ops.md
+++ b/tensorflow/g3doc/api_docs/python/state_ops.md
@@ -784,7 +784,7 @@ Sets the list of old checkpoint filenames.
##### Raises:
-* <b>`AssertionError`</b>: If the list of checkpoint filenames has already been set.
+* <b>`AssertionError`</b>: If last_checkpoints is not a list.
- - -
diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md
index 2a9687535c..4bfbd3007f 100644
--- a/tensorflow/g3doc/api_docs/python/train.md
+++ b/tensorflow/g3doc/api_docs/python/train.md
@@ -32,7 +32,7 @@ opt = GradientDescentOptimizer(learning_rate=0.1)
# Add Ops to the graph to minimize a cost by updating a list of variables.
# "cost" is a Tensor, and the list of variables contains tf.Variable
# objects.
-opt_op = opt.minimize(cost, <list of variables>)
+opt_op = opt.minimize(cost, var_list=<list of variables>)
```
In the training program you will just have to run the returned Op.
@@ -1471,8 +1471,8 @@ summary has a summary value for each tag-value pair in `tags` and `values`.
##### Args:
-* <b>`tags`</b>: A 1-D `string` `Tensor`. Tags for the summaries.
-* <b>`values`</b>: A 1-D `float32` or `float64` Tensor. Values for the summaries.
+* <b>`tags`</b>: A `string` `Tensor`. Tags for the summaries.
+* <b>`values`</b>: A real numeric Tensor. Values for the summaries.
* <b>`collections`</b>: Optional list of graph collections keys. The new summary op is
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
* <b>`name`</b>: A name for the operation (optional).
@@ -1874,3 +1874,216 @@ tf.train.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')
* <b>`as_text`</b>: If `True`, writes the graph as an ASCII proto.
+
+## Other Functions and Classes
+- - -
+
+### `class tf.train.LooperThread` {#LooperThread}
+
+A thread that runs code repeatedly, optionally on a timer.
+
+This thread class is intended to be used with a `Coordinator`. It repeatedly
+runs code specified either as `target` and `args` or by the `run_loop()`
+method.
+
+Before each run the thread checks if the coordinator has requested stop. In
+that case the looper thread terminates immediately.
+
+If the code being run raises an exception, that exception is reported to the
+coordinator and the thread terminates. The coordinator will then request all
+the other threads it coordinates to stop.
+
+You typically pass looper threads to the supervisor `Join()` method.
+- - -
+
+#### `tf.train.LooperThread.__init__(coord, timer_interval_secs, target=None, args=None)` {#LooperThread.__init__}
+
+Create a LooperThread.
+
+##### Args:
+
+
+* <b>`coord`</b>: a Coordinator.
+* <b>`timer_interval_secs`</b>: Time boundaries at which to call Run(), or None
+ if it should be called back to back.
+* <b>`target`</b>: Optional callable object that will be executed in the thread.
+* <b>`args`</b>: Optional arguments to pass to `target` when calling it.
+
+##### Raises:
+
+
+* <b>`ValueError`</b>: If one of the arguments is invalid.
+
+
+- - -
+
+#### `tf.train.LooperThread.daemon` {#LooperThread.daemon}
+
+A boolean value indicating whether this thread is a daemon thread (True) or not (False).
+
+This must be set before start() is called, otherwise RuntimeError is
+raised. Its initial value is inherited from the creating thread; the
+main thread is not a daemon thread and therefore all threads created in
+the main thread default to daemon = False.
+
+The entire Python program exits when no alive non-daemon threads are
+left.
+
+
+- - -
+
+#### `tf.train.LooperThread.getName()` {#LooperThread.getName}
+
+
+
+
+- - -
+
+#### `tf.train.LooperThread.ident` {#LooperThread.ident}
+
+Thread identifier of this thread or None if it has not been started.
+
+This is a nonzero integer. See the thread.get_ident() function. Thread
+identifiers may be recycled when a thread exits and another thread is
+created. The identifier is available even after the thread has exited.
+
+
+- - -
+
+#### `tf.train.LooperThread.isAlive()` {#LooperThread.isAlive}
+
+Return whether the thread is alive.
+
+This method returns True just before the run() method starts until just
+after the run() method terminates. The module function enumerate()
+returns a list of all alive threads.
+
+
+- - -
+
+#### `tf.train.LooperThread.isDaemon()` {#LooperThread.isDaemon}
+
+
+
+
+- - -
+
+#### `tf.train.LooperThread.is_alive()` {#LooperThread.is_alive}
+
+Return whether the thread is alive.
+
+This method returns True just before the run() method starts until just
+after the run() method terminates. The module function enumerate()
+returns a list of all alive threads.
+
+
+- - -
+
+#### `tf.train.LooperThread.join(timeout=None)` {#LooperThread.join}
+
+Wait until the thread terminates.
+
+This blocks the calling thread until the thread whose join() method is
+called terminates -- either normally or through an unhandled exception
+or until the optional timeout occurs.
+
+When the timeout argument is present and not None, it should be a
+floating point number specifying a timeout for the operation in seconds
+(or fractions thereof). As join() always returns None, you must call
+isAlive() after join() to decide whether a timeout happened -- if the
+thread is still alive, the join() call timed out.
+
+When the timeout argument is not present or None, the operation will
+block until the thread terminates.
+
+A thread can be join()ed many times.
+
+join() raises a RuntimeError if an attempt is made to join the current
+thread as that would cause a deadlock. It is also an error to join() a
+thread before it has been started and attempts to do so raises the same
+exception.
+
+
+- - -
+
+#### `tf.train.LooperThread.loop(coord, timer_interval_secs, target, args=None)` {#LooperThread.loop}
+
+Start a LooperThread that calls a function periodically.
+
+If `timer_interval_secs` is None the thread calls `target(args)`
+repeatedly. Otherwise `target(args)` is called every `timer_interval_secs`
+seconds. The thread terminates when a stop of the coordinator is
+requested.
+
+##### Args:
+
+
+* <b>`coord`</b>: A Coordinator.
+* <b>`timer_interval_secs`</b>: Number. Time boundaries at which to call `target`.
+* <b>`target`</b>: A callable object.
+* <b>`args`</b>: Optional arguments to pass to `target` when calling it.
+
+##### Returns:
+
+ The started thread.
+
+
+- - -
+
+#### `tf.train.LooperThread.name` {#LooperThread.name}
+
+A string used for identification purposes only.
+
+It has no semantics. Multiple threads may be given the same name. The
+initial name is set by the constructor.
+
+
+- - -
+
+#### `tf.train.LooperThread.run()` {#LooperThread.run}
+
+
+
+
+- - -
+
+#### `tf.train.LooperThread.run_loop()` {#LooperThread.run_loop}
+
+Called at 'timer_interval_secs' boundaries.
+
+
+- - -
+
+#### `tf.train.LooperThread.setDaemon(daemonic)` {#LooperThread.setDaemon}
+
+
+
+
+- - -
+
+#### `tf.train.LooperThread.setName(name)` {#LooperThread.setName}
+
+
+
+
+- - -
+
+#### `tf.train.LooperThread.start()` {#LooperThread.start}
+
+Start the thread's activity.
+
+It must be called at most once per thread object. It arranges for the
+object's run() method to be invoked in a separate thread of control.
+
+This method will raise a RuntimeError if called more than once on the
+same thread object.
+
+
+- - -
+
+#### `tf.train.LooperThread.start_loop()` {#LooperThread.start_loop}
+
+Called when the thread starts.
+
+
+
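A minimal sketch of the intended pairing with a `Coordinator`:

```python
import tensorflow as tf

coord = tf.train.Coordinator()

def heartbeat():
    # Called once per second until the coordinator requests a stop.
    print('still training...')

thread = tf.train.LooperThread.loop(coord, 1.0, target=heartbeat)
# ... run training for a while ...
coord.request_stop()
coord.join([thread])
```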
diff --git a/tensorflow/g3doc/extras/README.txt b/tensorflow/g3doc/extras/README.txt
index 2c9682d2fb..765809a762 100644
--- a/tensorflow/g3doc/extras/README.txt
+++ b/tensorflow/g3doc/extras/README.txt
@@ -1,2 +1,3 @@
This directory holds extra files we'd like to be able
-to link to and serve from within tensorflow.org
+to link to and serve from within tensorflow.org.
+They are excluded from versioning. \ No newline at end of file
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/index.md b/tensorflow/g3doc/how_tos/adding_an_op/index.md
index fe943fac6c..4b2e623f00 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/index.md
+++ b/tensorflow/g3doc/how_tos/adding_an_op/index.md
@@ -844,7 +844,8 @@ For more details, see
In general, changes to specifications must be backwards-compatible: changing the
specification of an Op must not break prior serialized `GraphDef` protocol
-buffers constructed from older specfications.
+buffers constructed from older specifications. The details of `GraphDef`
+compatibility are [described here](../../resources/versions.md#graphs).
There are several ways to preserve backwards-compatibility.
@@ -897,7 +898,8 @@ generated Python code may change in a way that isn't compatible with old
callers. The Python API may be kept compatible by careful changes in a
hand-written Python wrapper, by keeping the old signature except possibly adding
new optional arguments to the end. Generally incompatible changes may only be
-made when TensorFlow's changes major versions.
+made when TensorFlow changes major versions, and must conform to the
+[`GraphDef` version semantics](../../resources/versions.md#graphs).
## GPU Support {#mult-archs}
diff --git a/tensorflow/g3doc/how_tos/index.md b/tensorflow/g3doc/how_tos/index.md
index 748ecfd398..c9ab79aa2a 100644
--- a/tensorflow/g3doc/how_tos/index.md
+++ b/tensorflow/g3doc/how_tos/index.md
@@ -1,4 +1,4 @@
-# Overview
+# How-Tos
## Variables: Creation, Initializing, Saving, and Restoring
diff --git a/tensorflow/g3doc/resources/index.md b/tensorflow/g3doc/resources/index.md
index f53ae3a18c..d19093871d 100644
--- a/tensorflow/g3doc/resources/index.md
+++ b/tensorflow/g3doc/resources/index.md
@@ -12,7 +12,7 @@ implementation can be found in our white paper:
If you use TensorFlow in your research and would like to cite the TensorFlow
system, we suggest you cite the paper above.
-You can use this [BibTeX entry](../resources/bib.md). As the project progresses, we
+You can use this [BibTeX entry](bib.md). As the project progresses, we
may update the suggested citation with new papers.
Please only use the TensorFlow name and marks when accurately referencing this
@@ -55,3 +55,8 @@ https://github.com/tensorflow/tensorflow/issues) on GitHub.
If you need help with using TensorFlow, please do not use the issue
tracker for that. Instead, direct your questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).
+## Versioning
+
+TensorFlow uses [Semantic Versioning 2.0](http://semver.org). For details on
+the versioning of our public API and binary compatibility, see the [versioning
+document](versions.md).
diff --git a/tensorflow/g3doc/resources/leftnav_files b/tensorflow/g3doc/resources/leftnav_files
index 2e1940b5d4..b0df3a8368 100644
--- a/tensorflow/g3doc/resources/leftnav_files
+++ b/tensorflow/g3doc/resources/leftnav_files
@@ -3,3 +3,4 @@ uses.md
faq.md
glossary.md
dims_types.md
+versions.md
diff --git a/tensorflow/g3doc/resources/versions.md b/tensorflow/g3doc/resources/versions.md
new file mode 100644
index 0000000000..a16bd8f549
--- /dev/null
+++ b/tensorflow/g3doc/resources/versions.md
@@ -0,0 +1,143 @@
+# TensorFlow Version Semantics
+
+## Semantic Versioning 2.0
+
+Once we reach version 1.0, TensorFlow will follow Semantic Versioning 2.0
+(semver). For details, see <http://semver.org>.  Each release version of
+TensorFlow has the form `MAJOR.MINOR.PATCH`.  Changes to each number have
+the following meaning:
+
+* **MAJOR**:  Backwards incompatible changes.  Code and data that worked with
+ a previous major release will not necessarily work with a new release.
+ However, in some cases existing TensorFlow data (graphs, checkpoints, and
+ other protobufs) may be migratable to the newer release; see below for details
+ on data compatibility.
+
+* **MINOR**: Backwards compatible features, speed improvements, etc.  Code and
+ data that worked with a previous minor release *and* which depends only on the
+ public API will continue to work unchanged.  For details on what is and is
+ not the public API, see below.
+
+* **PATCH**: Backwards compatible bug fixes.
+
+Before 1.0, semver allows backwards incompatible changes at any time.  However,
+to support users now, we will use the format `0.MAJOR.MINOR` (shifted one step
+to the right).  Thus 0.5.0 to 0.6.0 may be backwards incompatible, but 0.6.0 to
+0.6.1 will include only backwards compatible features and bug fixes.
+
+At some point (especially as we approach 1.0) we will likely use prerelease
+versions such as X.Y.Z-alpha.1, but we do not yet have specific plans (beyond
+the restrictions of semver).
+
+
+## Public API
+
+Only the public API of TensorFlow is backwards compatible across minor and patch
+versions.  The public API consists of
+
+* The documented [C++ and Python APIs](../api_docs).
+
+* The following protocol buffer files:
+ [`attr_value`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/attr_value.proto),
+ [`config`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/config.proto),
+ [`event`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/event.proto),
+ [`graph`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto),
+ [`op_def`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/op_def.proto),
+ [`reader_base`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/reader_base.proto),
+ [`summary`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto),
+ [`tensor`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor.proto),
+ [`tensor_shape`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor_shape.proto),
+ and [`types`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto).
+
+The public C++ API is exposed through the header files in
+[`tensorflow/core/public`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/public).
+The public Python API is unfortunately **not** everything available through the
+tensorflow python module and its submodules, since we do not yet use `__all__`
+everywhere ([#421](https://github.com/tensorflow/tensorflow/issues/421)).
+ Please refer to the documentation to determine whether a given Python feature
+is part of the public API. For now, the protocol buffers are defined in
+[`tensorflow/core/framework/*.proto`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/framework)
+([#484](https://github.com/tensorflow/tensorflow/issues/484)).
+
+
+## Details That Are Not Public
+
+The following are specifically **not** part of the public API: they are allowed
+to change without notice across minor releases and even patch releases if bug
+fixes require it:
+
+* **Details of composite ops:**  Many public functions in Python expand to
+ several primitive ops in the graph, and these details will be part of any
+ graphs saved to disk as GraphDefs.  These details are allowed to change for
+ minor releases. In particular, regression tests that check for exact
+ matching between graphs are likely to break across minor releases, even though
+ the behavior of the graph should be unchanged and existing checkpoints will
+ still work.
+
+* **Floating point numerical details:** The specific floating point values
+ computed by ops may change at any time: users should rely only on approximate
+ accuracy and numerical stability, not on the specific bits computed.  Changes
+ to numerical formulas in minor and patch releases should result in comparable
+ or improved accuracy, with the caveat that in machine learning improved
+ accuracy of specific formulas may result in worse accuracy for the overall
+ system.
+
+* **Random numbers:** The specific random numbers computed by the [random
+ ops](../api_docs/python/constant_op.html#random-tensors) may change at any
+ time: users should rely only on approximately correct distributions and
+ statistical strength, not the specific bits computed.  However, we will make
+ changes to random bits rarely and ideally never for patch releases, and all
+ such intended changes will be documented.
+
+
+## Compatibility for Graphs and Checkpoints {#graphs}
+
+Many users of TensorFlow will be saving graphs and trained models to disk for
+later evaluation or more training, often changing versions of TensorFlow in the
+process.  First, following semver, any graph or checkpoint written out with one
+version of TensorFlow can be loaded and evaluated with a later version of
+TensorFlow with the same major release.  However, we will endeavour to preserve
+backwards compatibility even across major releases when possible, so that the
+serialized files are usable over long periods of time.
+
+There are two main classes of saved TensorFlow data: graphs and checkpoints.
+Graphs describe the data flow graphs of ops to be run during training and
+inference, and checkpoints contain the saved tensor values of variables in a
+graph.
+
+Graphs are serialized via the `GraphDef` protocol buffer.  To facilitate (rare)
+backwards incompatible changes to graphs, each `GraphDef` has an integer version
+separate from the TensorFlow version.  The semantics are:
+
+* Each version of TensorFlow supports an interval of `GraphDef` versions.  This
+ interval will be constant across patch releases, and will only grow across
+ minor releases.  Dropping support for a `GraphDef` version will only occur
+ for a major release of TensorFlow.
+
+* Newly created graphs use the newest `GraphDef` version.
+
+* If a given version of TensorFlow supports the `GraphDef` version of a graph,
+ it will load and evaluate with the same behavior as when it was written out
+ (except for floating point numerical details and random numbers), regardless
+ of the major version of TensorFlow.  In particular, all checkpoint files will
+ be compatible.
+
+* If the `GraphDef` upper bound is increased to X in a (minor) release, there
+ will be at least six months before the lower bound is increased to X.
+
+For example (numbers and versions hypothetical), TensorFlow 1.2 might support
+`GraphDef` versions 4 to 7.  TensorFlow 1.3 could add `GraphDef` version 8 and
+support versions 4 to 8.  At least six months later, TensorFlow 2.0.0 could drop
+support for versions 4 to 7, leaving version 8 only.
+
+Finally, when support for a `GraphDef` version is dropped, we will attempt to
+provide tools for automatically converting graphs to a newer supported
+`GraphDef` version.
+
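As a hedged sketch of what a loader-side check could look like from Python — both the `GraphDef.version` field and a `GRAPH_DEF_VERSION` constant mirroring the C++ `TF_GRAPH_DEF_VERSION` are assumptions about the proto and bindings at this point:

```python
from tensorflow.core.framework import graph_pb2

GRAPH_DEF_VERSION = 3  # assumed mirror of TF_GRAPH_DEF_VERSION in version.h

graph_def = graph_pb2.GraphDef()
with open('/tmp/model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# Reject graphs written by a newer TensorFlow than this binary supports.
if graph_def.version > GRAPH_DEF_VERSION:
    raise ValueError('GraphDef version %d exceeds supported maximum %d'
                     % (graph_def.version, GRAPH_DEF_VERSION))
```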
+
+## C++ API Compatibility
+
+Only patch releases will be binary compatible at the C++ level.  That is, minor
+releases are backwards compatible in terms of behavior but may require a
+recompile for downstream C++ code.  As always, backwards compatibility is only
+provided for the public C++ API.
diff --git a/tensorflow/g3doc/tutorials/index.md b/tensorflow/g3doc/tutorials/index.md
index 8fbdcfcd21..98e1d60fbc 100644
--- a/tensorflow/g3doc/tutorials/index.md
+++ b/tensorflow/g3doc/tutorials/index.md
@@ -1,4 +1,4 @@
-# Overview
+# Tutorials
## MNIST For ML Beginners
diff --git a/tensorflow/models/embedding/word2vec_kernels.cc b/tensorflow/models/embedding/word2vec_kernels.cc
index f579ce138c..58f5f15d5c 100644
--- a/tensorflow/models/embedding/word2vec_kernels.cc
+++ b/tensorflow/models/embedding/word2vec_kernels.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
-#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/lib/strings/regexp.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/guarded_philox_random.h"
diff --git a/tensorflow/models/image/cifar10/BUILD b/tensorflow/models/image/cifar10/BUILD
index 25dce65f28..87e11bab62 100644
--- a/tensorflow/models/image/cifar10/BUILD
+++ b/tensorflow/models/image/cifar10/BUILD
@@ -9,6 +9,7 @@ py_library(
name = "cifar10_input",
srcs = ["cifar10_input.py"],
srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow:tensorflow_py",
],
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py
index b9b89473e8..32234db496 100644
--- a/tensorflow/models/image/cifar10/cifar10.py
+++ b/tensorflow/models/image/cifar10/cifar10.py
@@ -43,11 +43,9 @@ import tarfile
import tensorflow.python.platform
from six.moves import urllib
-from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.models.image.cifar10 import cifar10_input
-from tensorflow.python.platform import gfile
FLAGS = tf.app.flags.FLAGS
@@ -57,15 +55,12 @@ tf.app.flags.DEFINE_integer('batch_size', 128,
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
"""Path to the CIFAR-10 data directory.""")
-# Process images of this size. Note that this differs from the original CIFAR
-# image size of 32 x 32. If one alters this number, then the entire model
-# architecture will change and any model would need to be retrained.
-IMAGE_SIZE = 24
-
# Global constants describing the CIFAR-10 data set.
-NUM_CLASSES = 10
-NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
-NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
+IMAGE_SIZE = cifar10_input.IMAGE_SIZE
+NUM_CLASSES = cifar10_input.NUM_CLASSES
+NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
+NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
+
# Constants describing the training process.
MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average.
@@ -139,91 +134,21 @@ def _variable_with_weight_decay(name, shape, stddev, wd):
return var
-def _generate_image_and_label_batch(image, label, min_queue_examples):
- """Construct a queued batch of images and labels.
-
- Args:
- image: 3-D Tensor of [IMAGE_SIZE, IMAGE_SIZE, 3] of type.float32.
- label: 1-D Tensor of type.int32
- min_queue_examples: int32, minimum number of samples to retain
- in the queue that provides of batches of examples.
-
- Returns:
- images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
- labels: Labels. 1D tensor of [batch_size] size.
- """
- # Create a queue that shuffles the examples, and then
- # read 'FLAGS.batch_size' images + labels from the example queue.
- num_preprocess_threads = 16
- images, label_batch = tf.train.shuffle_batch(
- [image, label],
- batch_size=FLAGS.batch_size,
- num_threads=num_preprocess_threads,
- capacity=min_queue_examples + 3 * FLAGS.batch_size,
- min_after_dequeue=min_queue_examples)
-
- # Display the training images in the visualizer.
- tf.image_summary('images', images)
-
- return images, tf.reshape(label_batch, [FLAGS.batch_size])
-
-
def distorted_inputs():
"""Construct distorted input for CIFAR training using the Reader ops.
- Raises:
- ValueError: if no data_dir
-
Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
- """
- filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin',
- 'data_batch_%d.bin' % i)
- for i in xrange(1, 6)]
- for f in filenames:
- if not gfile.Exists(f):
- raise ValueError('Failed to find file: ' + f)
-
- # Create a queue that produces the filenames to read.
- filename_queue = tf.train.string_input_producer(filenames)
- # Read examples from files in the filename queue.
- read_input = cifar10_input.read_cifar10(filename_queue)
- reshaped_image = tf.cast(read_input.uint8image, tf.float32)
-
- height = IMAGE_SIZE
- width = IMAGE_SIZE
-
- # Image processing for training the network. Note the many random
- # distortions applied to the image.
-
- # Randomly crop a [height, width] section of the image.
- distorted_image = tf.image.random_crop(reshaped_image, [height, width])
-
- # Randomly flip the image horizontally.
- distorted_image = tf.image.random_flip_left_right(distorted_image)
-
- # Because these operations are not commutative, consider randomizing
- # randomize the order their operation.
- distorted_image = tf.image.random_brightness(distorted_image,
- max_delta=63)
- distorted_image = tf.image.random_contrast(distorted_image,
- lower=0.2, upper=1.8)
-
- # Subtract off the mean and divide by the variance of the pixels.
- float_image = tf.image.per_image_whitening(distorted_image)
-
- # Ensure that the random shuffling has good mixing properties.
- min_fraction_of_examples_in_queue = 0.4
- min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
- min_fraction_of_examples_in_queue)
- print ('Filling queue with %d CIFAR images before starting to train. '
- 'This will take a few minutes.' % min_queue_examples)
-
- # Generate a batch of images and labels by building up a queue of examples.
- return _generate_image_and_label_batch(float_image, read_input.label,
- min_queue_examples)
+ Raises:
+ ValueError: If no data_dir
+ """
+ if not FLAGS.data_dir:
+ raise ValueError('Please supply a data_dir')
+ data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
+ return cifar10_input.distorted_inputs(data_dir=data_dir,
+ batch_size=FLAGS.batch_size)
def inputs(eval_data):
@@ -232,56 +157,18 @@ def inputs(eval_data):
Args:
eval_data: bool, indicating if one should use the train or eval data set.
- Raises:
- ValueError: if no data_dir
-
Returns:
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
+
+ Raises:
+ ValueError: If no data_dir
"""
if not FLAGS.data_dir:
raise ValueError('Please supply a data_dir')
-
- if not eval_data:
- filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin',
- 'data_batch_%d.bin' % i)
- for i in xrange(1, 6)]
- num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
- else:
- filenames = [os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin',
- 'test_batch.bin')]
- num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
-
- for f in filenames:
- if not gfile.Exists(f):
- raise ValueError('Failed to find file: ' + f)
-
- # Create a queue that produces the filenames to read.
- filename_queue = tf.train.string_input_producer(filenames)
-
- # Read examples from files in the filename queue.
- read_input = cifar10_input.read_cifar10(filename_queue)
- reshaped_image = tf.cast(read_input.uint8image, tf.float32)
-
- height = IMAGE_SIZE
- width = IMAGE_SIZE
-
- # Image processing for evaluation.
- # Crop the central [height, width] of the image.
- resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
- width, height)
-
- # Subtract off the mean and divide by the variance of the pixels.
- float_image = tf.image.per_image_whitening(resized_image)
-
- # Ensure that the random shuffling has good mixing properties.
- min_fraction_of_examples_in_queue = 0.4
- min_queue_examples = int(num_examples_per_epoch *
- min_fraction_of_examples_in_queue)
-
- # Generate a batch of images and labels by building up a queue of examples.
- return _generate_image_and_label_batch(float_image, read_input.label,
- min_queue_examples)
+ data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
+ return cifar10_input.inputs(eval_data=eval_data, data_dir=data_dir,
+ batch_size=FLAGS.batch_size)
def inference(images):
diff --git a/tensorflow/models/image/cifar10/cifar10_input.py b/tensorflow/models/image/cifar10/cifar10_input.py
index ac73c493a3..ffe8facd27 100644
--- a/tensorflow/models/image/cifar10/cifar10_input.py
+++ b/tensorflow/models/image/cifar10/cifar10_input.py
@@ -19,9 +19,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
import tensorflow.python.platform
+from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
+from tensorflow.python.platform import gfile
+
+# Process images of this size. Note that this differs from the original CIFAR
+# image size of 32 x 32. If one alters this number, then the entire model
+# architecture will change and any model would need to be retrained.
+IMAGE_SIZE = 24
+
+# Global constants describing the CIFAR-10 data set.
+NUM_CLASSES = 10
+NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
+NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
+
def read_cifar10(filename_queue):
"""Reads and parses examples from CIFAR10 data files.
@@ -82,3 +97,144 @@ def read_cifar10(filename_queue):
result.uint8image = tf.transpose(depth_major, [1, 2, 0])
return result
+
+
+def _generate_image_and_label_batch(image, label, min_queue_examples,
+ batch_size):
+ """Construct a queued batch of images and labels.
+
+ Args:
+    image: 3-D Tensor of [height, width, 3] of type float32.
+    label: 1-D Tensor of type int32.
+    min_queue_examples: int32, minimum number of samples to retain
+      in the queue that provides batches of examples.
+ batch_size: Number of images per batch.
+
+ Returns:
+ images: Images. 4D tensor of [batch_size, height, width, 3] size.
+ labels: Labels. 1D tensor of [batch_size] size.
+ """
+ # Create a queue that shuffles the examples, and then
+ # read 'batch_size' images + labels from the example queue.
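+  # The queue capacity leaves headroom of three batches above
+  # min_after_dequeue, so enqueuing threads can stay ahead of dequeues.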
+ num_preprocess_threads = 16
+ images, label_batch = tf.train.shuffle_batch(
+ [image, label],
+ batch_size=batch_size,
+ num_threads=num_preprocess_threads,
+ capacity=min_queue_examples + 3 * batch_size,
+ min_after_dequeue=min_queue_examples)
+
+ # Display the training images in the visualizer.
+ tf.image_summary('images', images)
+
+ return images, tf.reshape(label_batch, [batch_size])
+
+
+def distorted_inputs(data_dir, batch_size):
+ """Construct distorted input for CIFAR training using the Reader ops.
+
+ Args:
+ data_dir: Path to the CIFAR-10 data directory.
+ batch_size: Number of images per batch.
+
+ Returns:
+ images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
+ labels: Labels. 1D tensor of [batch_size] size.
+ """
+ filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
+ for i in xrange(1, 6)]
+ for f in filenames:
+ if not gfile.Exists(f):
+ raise ValueError('Failed to find file: ' + f)
+
+ # Create a queue that produces the filenames to read.
+ filename_queue = tf.train.string_input_producer(filenames)
+
+ # Read examples from files in the filename queue.
+ read_input = read_cifar10(filename_queue)
+ reshaped_image = tf.cast(read_input.uint8image, tf.float32)
+
+ height = IMAGE_SIZE
+ width = IMAGE_SIZE
+
+ # Image processing for training the network. Note the many random
+ # distortions applied to the image.
+
+ # Randomly crop a [height, width] section of the image.
+ distorted_image = tf.image.random_crop(reshaped_image, [height, width])
+
+ # Randomly flip the image horizontally.
+ distorted_image = tf.image.random_flip_left_right(distorted_image)
+
+  # Because these operations are not commutative, consider randomizing
+  # the order of their operation.
+ distorted_image = tf.image.random_brightness(distorted_image,
+ max_delta=63)
+ distorted_image = tf.image.random_contrast(distorted_image,
+ lower=0.2, upper=1.8)
+
+ # Subtract off the mean and divide by the variance of the pixels.
+ float_image = tf.image.per_image_whitening(distorted_image)
+
+ # Ensure that the random shuffling has good mixing properties.
+ min_fraction_of_examples_in_queue = 0.4
+ min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
+ min_fraction_of_examples_in_queue)
+  print('Filling queue with %d CIFAR images before starting to train. '
+        'This will take a few minutes.' % min_queue_examples)
+
+ # Generate a batch of images and labels by building up a queue of examples.
+ return _generate_image_and_label_batch(float_image, read_input.label,
+ min_queue_examples, batch_size)
+
+
+def inputs(eval_data, data_dir, batch_size):
+ """Construct input for CIFAR evaluation using the Reader ops.
+
+ Args:
+ eval_data: bool, indicating if one should use the train or eval data set.
+ data_dir: Path to the CIFAR-10 data directory.
+ batch_size: Number of images per batch.
+
+ Returns:
+ images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
+ labels: Labels. 1D tensor of [batch_size] size.
+ """
+ if not eval_data:
+ filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
+ for i in xrange(1, 6)]
+ num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
+ else:
+ filenames = [os.path.join(data_dir, 'test_batch.bin')]
+ num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
+
+ for f in filenames:
+ if not gfile.Exists(f):
+ raise ValueError('Failed to find file: ' + f)
+
+ # Create a queue that produces the filenames to read.
+ filename_queue = tf.train.string_input_producer(filenames)
+
+ # Read examples from files in the filename queue.
+ read_input = read_cifar10(filename_queue)
+ reshaped_image = tf.cast(read_input.uint8image, tf.float32)
+
+ height = IMAGE_SIZE
+ width = IMAGE_SIZE
+
+ # Image processing for evaluation.
+ # Crop the central [height, width] of the image.
+ resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
+ width, height)
+
+ # Subtract off the mean and divide by the variance of the pixels.
+ float_image = tf.image.per_image_whitening(resized_image)
+
+ # Ensure that the random shuffling has good mixing properties.
+ min_fraction_of_examples_in_queue = 0.4
+ min_queue_examples = int(num_examples_per_epoch *
+ min_fraction_of_examples_in_queue)
+
+ # Generate a batch of images and labels by building up a queue of examples.
+ return _generate_image_and_label_batch(float_image, read_input.label,
+ min_queue_examples, batch_size)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 62aa3ee0c5..ee2e769ede 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -69,6 +69,20 @@ py_tests(
)
cc_library(
+ name = "py_func_lib",
+ srcs = ["lib/core/py_func.cc"],
+ hdrs = [
+ "lib/core/py_func.h",
+ ],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/py/numpy:headers",
+ "//util/python:python_headers",
+ ],
+)
+
+cc_library(
name = "py_record_reader_lib",
srcs = [
"lib/io/py_record_reader.cc",
@@ -107,17 +121,28 @@ py_test(
)
cc_library(
- name = "python_op_gen_main",
+ name = "python_op_gen",
srcs = [
"framework/python_op_gen.cc",
"framework/python_op_gen.h",
- "framework/python_op_gen_main.cc",
],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_cc",
],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "python_op_gen_main",
+ srcs = [
+ "framework/python_op_gen_main.cc",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":python_op_gen",
+ ],
)
# What is needed for tf_gen_op_wrapper_py.
@@ -154,6 +179,7 @@ py_library(
"framework/importer.py",
"framework/random_seed.py",
"framework/tensor_util.py",
+ "framework/load_library.py",
# TODO(josh11b): Move this to the framework directory
"ops/common_shapes.py",
],
@@ -483,6 +509,9 @@ tf_gen_op_wrapper_py(
"MatMul",
"Sigmoid",
"Tanh",
+ "Lgamma",
+ "Erf",
+ "Erfc",
],
require_shape_functions = True,
)
@@ -531,6 +560,14 @@ tf_gen_op_wrapper_py(
)
tf_gen_op_wrapper_py(
+ name = "script_ops",
+ hidden = [
+ "PyFunc",
+ ],
+ require_shape_functions = True,
+)
+
+tf_gen_op_wrapper_py(
name = "state_ops",
hidden = [
"Variable",
@@ -631,6 +668,7 @@ py_library(
"ops/random_ops.py",
"ops/rnn.py",
"ops/rnn_cell.py",
+ "ops/script_ops.py",
"ops/seq2seq.py",
"ops/sparse_grad.py",
"ops/sparse_ops.py",
@@ -658,6 +696,7 @@ py_library(
":nn_ops",
":parsing_ops",
":random_ops",
+ ":script_ops",
":sparse_ops",
":string_ops",
":summary_ops",
@@ -710,8 +749,18 @@ tf_proto_library_py(
name = "protos_all",
srcs = glob(
["**/*.proto"],
- exclude = ["util/protobuf/compare_test.proto"],
+ exclude = [
+ "util/protobuf/compare_test.proto",
+ "training/saver.proto",
+ ],
),
+ deps = [":public_protos_py"],
+)
+
+tf_proto_library_py(
+ name = "public_protos",
+ srcs = ["training/saver.proto"],
+ visibility = ["//visibility:public"],
)
tf_proto_library_py(
@@ -785,6 +834,8 @@ tf_py_wrap_cc(
swig_includes = [
"client/events_writer.i",
"client/tf_session.i",
+ "framework/python_op_gen.i",
+ "lib/core/py_func.i",
"lib/core/status.i",
"lib/core/status_helper.i",
"lib/core/strings.i",
@@ -795,8 +846,10 @@ tf_py_wrap_cc(
"util/port.i",
],
deps = [
+ ":py_func_lib",
":py_record_reader_lib",
":py_record_writer_lib",
+ ":python_op_gen",
":tf_session_helper",
"//util/python:python_headers",
],
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 2afbea6e63..5c8dfc74a7 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -594,7 +594,7 @@ class InteractiveSession(BaseSession):
@@close
"""
- def __init__(self, target='', graph=None):
+ def __init__(self, target='', graph=None, config=None):
"""Creates a new interactive TensorFlow session.
If no `graph` argument is specified when constructing the session,
@@ -610,8 +610,9 @@ class InteractiveSession(BaseSession):
Defaults to using an in-process engine. At present, no value
other than the empty string is supported.
graph: (Optional.) The `Graph` to be launched (described above).
+      config: (Optional.) `ConfigProto` proto used to configure the session.
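+
+    For example (illustrative only; `ConfigProto` fields vary by build):
+
+    ```python
+    sess = tf.InteractiveSession(
+        config=tf.ConfigProto(log_device_placement=True))
+    ```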
"""
- super(InteractiveSession, self).__init__(target, graph)
+ super(InteractiveSession, self).__init__(target, graph, config)
self._default_session = self.as_default()
self._default_session.__enter__()
self._explicit_graph = graph
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 2d6a73eb9e..dcb64e8a9e 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -200,7 +200,13 @@ tensorflow::ImportNumpy();
// END TYPEMAPS FOR tensorflow::TF_Run_wrapper()
////////////////////////////////////////////////////////////////////////////////
-
+// Typemaps for TF_GetOpList.
+// The wrapped function TF_GetOpList returns a TF_Buffer pointer. This typemap
+// creates a Python string from the TF_Buffer and returns it.
+%typemap(out) TF_Buffer TF_GetOpList {
+ $result = PyString_FromStringAndSize(
+ reinterpret_cast<const char*>($1.data), $1.length);
+}
// Include the functions from tensor_c_api.h, except TF_Run.
%ignoreall
@@ -219,6 +225,9 @@ tensorflow::ImportNumpy();
%unignore TF_CloseSession;
%unignore TF_DeleteSession;
%unignore TF_ExtendGraph;
+%unignore TF_NewLibrary;
+%unignore TF_LoadLibrary;
+%unignore TF_GetOpList;
%include "tensorflow/core/public/tensor_c_api.h"
%ignoreall
diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py
index 84163ca1c0..9d4f2c54f9 100644
--- a/tensorflow/python/framework/framework_lib.py
+++ b/tensorflow/python/framework/framework_lib.py
@@ -36,6 +36,7 @@
@@convert_to_tensor_or_indexed_slices
@@get_default_graph
@@import_graph_def
+@@load_op_library
## Graph collections
@@ -89,3 +90,6 @@ from tensorflow.python.framework.tensor_shape import Dimension
from tensorflow.python.framework.tensor_shape import TensorShape
from tensorflow.python.framework.dtypes import *
+
+# Load a TensorFlow plugin
+from tensorflow.python.framework.load_library import *
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index a20f706f38..e998310a69 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -105,17 +105,19 @@ def all_libraries(module_to_name, members, documented):
"rnn", "state_saving_rnn", "bidirectional_rnn",
"dynamic_rnn", "seq2seq", "rnn_cell"],
prefix=PREFIX_TEXT),
- library('client', "Running Graphs", client_lib),
+ library("client", "Running Graphs", client_lib),
library("train", "Training", tf.train,
exclude_symbols=["Feature", "Features", "BytesList", "FloatList",
"Int64List", "Example", "InferenceExample",
"FeatureList", "FeatureLists",
"RankingExample", "SequenceExample"]),
+ library("script_ops", "Wraps python functions", prefix=PREFIX_TEXT)
]
_hidden_symbols = ["Event", "Summary", "xrange",
"HistogramProto", "ConfigProto", "NodeDef", "GraphDef",
- "GPUOptions", "SessionInterface", "BaseSession"]
+ "GPUOptions", "GraphOptions", "SessionInterface",
+ "BaseSession"]
def main(unused_argv):
if not FLAGS.out_dir:
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index efc977aeb5..14b990fffa 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -622,14 +622,16 @@ class ImportGraphDefTest(tf.test.TestCase):
def testVersionLow(self):
with tf.Graph().as_default():
pat = (r"^GraphDef version -1 is no longer supported: TensorFlow \S+ "
- r"needs \d+ <= version <= \d+. Please regenerate your graph.$")
+ r"needs %d <= version <= %d. Please regenerate your graph.$" %
+ (tf.GRAPH_DEF_VERSION_MIN, tf.GRAPH_DEF_VERSION_MAX))
with self.assertRaisesRegexp(ValueError, pat):
tf.import_graph_def(self._MakeGraphDef("", version=-1))
def testVersionHigh(self):
with tf.Graph().as_default():
pat = (r"^GraphDef version \d+ is not yet supported: TensorFlow \S+ "
- r"needs \d+ <= version <= \d+. Please upgrade TensorFlow.$")
+ r"needs %d <= version <= %d. Please upgrade TensorFlow.$" %
+ (tf.GRAPH_DEF_VERSION_MIN, tf.GRAPH_DEF_VERSION_MAX))
with self.assertRaisesRegexp(ValueError, pat):
tf.import_graph_def(self._MakeGraphDef("", version=1 << 30))
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
new file mode 100644
index 0000000000..9436638022
--- /dev/null
+++ b/tensorflow/python/framework/load_library.py
@@ -0,0 +1,74 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Function for loading TensorFlow plugins."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import hashlib
+import imp
+import sys
+
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.python import pywrap_tensorflow as py_tf
+from tensorflow.python.util import compat
+
+
+def load_op_library(library_filename):
+ """Loads a TensorFlow plugin, containing custom ops and kernels.
+
+ Pass "library_filename" to a platform-specific mechanism for dynamically
+ loading a library. The rules for determining the exact location of the
+ library are platform-specific and are not documented here.
+  Expects the symbols "RegisterOps", "RegisterKernels", and "GetOpList" to be
+ defined in the library.
+
+ Args:
+ library_filename: Path to the plugin.
+ Relative or absolute filesystem path to a dynamic library file.
+
+ Returns:
+    A Python module containing the Python wrappers for Ops defined in
+ the plugin.
+
+ Raises:
+ RuntimeError: when unable to load the library or get the python wrappers.
+ """
+ status = py_tf.TF_NewStatus()
+
+ lib_handle = py_tf.TF_LoadLibrary(library_filename, status)
+ try:
+ if py_tf.TF_GetCode(status) != 0:
+ raise RuntimeError(compat.as_text(py_tf.TF_Message(status)))
+ finally:
+ py_tf.TF_DeleteStatus(status)
+
+ op_list_str = py_tf.TF_GetOpList(lib_handle)
+ op_list = op_def_pb2.OpList()
+ op_list.ParseFromString(op_list_str)
+ wrappers = py_tf.GetPythonWrappers(op_list_str, len(op_list_str))
+
+ # Get a unique name for the module.
+ module_name = hashlib.md5(wrappers).hexdigest()
+ module = imp.new_module(module_name)
+ # pylint: disable=exec-used
+ exec(wrappers, module.__dict__)
+ # Stash away the library handle for making calls into the dynamic library.
+ module.LIB_HANDLE = lib_handle
+  # OP_LIST holds the OpDefs for the ops defined in the library.
+ module.OP_LIST = op_list
+ sys.modules[module_name] = module
+ return module
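+
+# A hypothetical usage sketch ("my_ops.so" is an assumed plugin filename):
+#
+#   my_ops = load_op_library("my_ops.so")
+#   # The returned module exposes the generated op wrappers, plus the
+#   # OP_LIST and LIB_HANDLE attributes stashed above.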
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index c511b2ea28..390a293c95 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -542,6 +542,35 @@ def convert_to_tensor(value, dtype=None, name=None, as_ref=False):
% (error_prefix, value, type(value)))
+def convert_n_to_tensor(values, dtype=None, name=None, as_ref=False):
+ """Converts `values` to a list of `Tensor` objects.
+
+ Args:
+ values: A list of objects that can be consumed by `tf.convert_to_tensor()`.
+ dtype: (Optional.) The required `DType` of the returned `Tensor` objects.
+    name: (Optional.) A name prefix to use when a new `Tensor` is
+ created, in which case element `i` will be given the name `name
+ + '_' + i`.
+ as_ref: True if the caller wants the results as ref tensors.
+
+ Returns:
+ A list of `Tensor` and/or `IndexedSlices` objects.
+
+ Raises:
+ TypeError: If no conversion function is registered for an element in
+ `values`.
+ RuntimeError: If a registered conversion function returns an invalid
+ value.
+ """
+ if not isinstance(values, collections.Sequence):
+ raise TypeError("values must be a list.")
+ ret = []
+ for i, value in enumerate(values):
+ n = None if name is None else "%s_%d" % (name, i)
+ ret.append(convert_to_tensor(value, dtype=dtype, name=n, as_ref=as_ref))
+ return ret
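+
+# A hypothetical usage sketch: with name="x" the loop above names element i
+# "x_%d" % i:
+#   xs = convert_n_to_tensor([1.0, 2.0], name="x")  # -> names "x_0", "x_1"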
+
+
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
as_ref=False):
"""Converts the given object to a `Tensor` or an `IndexedSlices`.
@@ -2218,7 +2247,7 @@ class Graph(object):
"""
try:
old_stack = self._name_stack
- if not name: # Both for name=None nad name="" we re-set to empty scope.
+ if not name: # Both for name=None and name="" we re-set to empty scope.
new_stack = (None, None)
elif name and name[-1] == "/":
new_stack = (name[:-1], name[:-1])
@@ -2734,7 +2763,7 @@ def device(dev):
"""Wrapper for `Graph.device()` using the default graph.
See
- [`Graph.name_scope()`](../../api_docs/python/framework.md#Graph.name_scope)
+ [`Graph.device()`](../../api_docs/python/framework.md#Graph.device)
for more details.
Args:
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
index ae28319a38..898c4acf18 100644
--- a/tensorflow/python/framework/python_op_gen.cc
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/python/framework/python_op_gen.h"
#include <stdio.h>
+#include <sstream>
#include <unordered_map>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
@@ -28,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
@@ -252,18 +254,19 @@ string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg,
}
}
-void PrintReturns(const OpDef& op_def,
- const std::vector<string>& output_type_string) {
+static string GetReturns(const OpDef& op_def,
+ const std::vector<string>& output_type_string) {
+ string result;
DCHECK_EQ(op_def.output_arg_size(), output_type_string.size());
const int num_outs = op_def.output_arg_size();
- printf("\n Returns:\n");
+ strings::Appendf(&result, "\n Returns:\n");
if (num_outs == 0) {
- printf(" The created Operation.\n");
+ strings::Appendf(&result, " The created Operation.\n");
} else {
if (num_outs == 1) {
StringPiece description = op_def.output_arg(0).description();
if (ConsumeEquals(&description)) { // Skip the generated type info.
- printf("%s", Indent(4, 4, description).c_str());
+ strings::Appendf(&result, "%s", Indent(4, 4, description).c_str());
} else {
// Special case of one output, don't use the name of the output unless
// there is no description.
@@ -282,7 +285,7 @@ void PrintReturns(const OpDef& op_def,
} else if (!description.empty()) {
AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
}
- printf("%s", Indent(4, 4, desc).c_str());
+ strings::Appendf(&result, "%s", Indent(4, 4, desc).c_str());
}
} else {
std::vector<string> out_names(num_outs);
@@ -293,8 +296,8 @@ void PrintReturns(const OpDef& op_def,
out_names[i] = strings::StrCat("output", i);
}
}
- printf(" A tuple of `Tensor` objects (%s).\n",
- str_util::Join(out_names, ", ").c_str());
+ strings::Appendf(&result, " A tuple of `Tensor` objects (%s).\n",
+ str_util::Join(out_names, ", ").c_str());
for (int i = 0; i < num_outs; ++i) {
string desc = strings::StrCat(out_names[i], ": ");
StringPiece description = op_def.output_arg(i).description();
@@ -317,10 +320,16 @@ void PrintReturns(const OpDef& op_def,
strings::StrAppend(&desc, type);
}
}
- printf("%s", Indent(4, 6, desc).c_str());
+ strings::Appendf(&result, "%s", Indent(4, 6, desc).c_str());
}
}
}
+ return result;
+}
+
+void PrintReturns(const OpDef& op_def,
+ const std::vector<string>& output_type_string) {
+ printf("%s", GetReturns(op_def, output_type_string).c_str());
}
string StringToPython(const string& str) {
@@ -400,8 +409,8 @@ string AttrValueToPython(const string& type, const AttrValue& value) {
}
}
-// Requires: ValidateOpDef(op_def).ok()
-void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
+static string GetPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
+ string result;
// Map from attr name to the first input arg it is inferred from.
std::unordered_map<string, string> inferred_attrs;
// This has all the input args followed by those attrs that don't have
@@ -472,7 +481,8 @@ void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
const string def_suffix =
strings::StrCat(parameters, has_args ? ", " : "", "name=None):");
- printf("%s\n", WordWrap(def_prefix, def_suffix, kRightMargin).c_str());
+ strings::Appendf(&result, "%s\n",
+ WordWrap(def_prefix, def_suffix, kRightMargin).c_str());
// Format the Op's descriptions so that it can be a Python docstring.
string comment;
@@ -485,10 +495,7 @@ void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
}
}
- printf(R"( r"""%s
- Args:
-)",
- comment.c_str());
+ strings::Appendf(&result, " r\"\"\"%s\n Args:\n", comment.c_str());
// Inputs
for (int i = 0; i < op_def.input_arg_size(); ++i) {
@@ -504,7 +511,7 @@ void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
if (!description.empty()) {
AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
}
- printf("%s", Indent(4, 6, desc).c_str());
+ strings::Appendf(&result, "%s", Indent(4, 6, desc).c_str());
}
// Attrs
@@ -569,10 +576,10 @@ void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
AppendWithinWidth(&desc, attr.description(),
kRightMargin - 4 /* indent */);
}
- printf("%s", Indent(4, 6, desc).c_str());
+ strings::Appendf(&result, "%s", Indent(4, 6, desc).c_str());
}
- printf(" name: A name for the operation (optional).\n");
+ strings::Appendf(&result, " name: A name for the operation (optional).\n");
std::vector<string> output_type_string;
output_type_string.reserve(op_def.output_arg_size());
@@ -580,7 +587,7 @@ void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
output_type_string.push_back(
ArgTypeName(op_def, op_def.output_arg(i), inferred_attrs, true));
}
- PrintReturns(op_def, output_type_string);
+ strings::StrAppend(&result, GetReturns(op_def, output_type_string));
string return_prefix = strings::StrCat(" return _op_def_lib.apply_op(");
string return_args = strings::StrCat("\"", op_def.name(), "\", ");
@@ -589,13 +596,12 @@ void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
}
strings::StrAppend(&return_args, "name=name)");
- printf(R"( """
-%s
-)",
- // Wrap the arguments, and indent to the (.
- WordWrap(return_prefix, return_args, kRightMargin).c_str());
+ strings::Appendf(&result, " \"\"\"\n%s\n",
+ // Wrap the arguments, and indent to the (.
+ WordWrap(return_prefix, return_args, kRightMargin).c_str());
- printf("\n\n");
+ strings::Appendf(&result, "\n\n");
+ return result;
}
void GenerateLowerCaseOpName(const string& str, string* result) {
@@ -616,11 +622,12 @@ void GenerateLowerCaseOpName(const string& str, string* result) {
} // namespace
-void PrintPythonOps(const OpList& ops, const string& hidden_ops,
+string GetPythonOps(const OpList& ops, const string& hidden_ops,
bool require_shapes) {
+ string result;
// Header
// TODO(josh11b): Mention the library for which wrappers are being generated.
- printf(R"("""Python wrappers around Brain.
+ strings::Appendf(&result, R"("""Python wrappers around Brain.
This file is MACHINE GENERATED! Do not edit.
"""
@@ -662,10 +669,12 @@ from tensorflow.python.ops import op_def_library
continue;
}
- PrintPythonOp(op_def, is_hidden, lower_case_name);
+ strings::StrAppend(&result,
+ GetPythonOp(op_def, is_hidden, lower_case_name));
if (!require_shapes) {
- printf("ops.RegisterShape(\"%s\")(None)\n", op_def.name().c_str());
+ strings::Appendf(&result, "ops.RegisterShape(\"%s\")(None)\n",
+ op_def.name().c_str());
}
auto added = out->Add();
@@ -673,7 +682,7 @@ from tensorflow.python.ops import op_def_library
RemoveDescriptionsFromOpDef(added);
}
- printf(R"(def _InitOpDefLibrary():
+ strings::Appendf(&result, R"(def _InitOpDefLibrary():
op_list = op_def_pb2.OpList()
text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list)
op_def_registry.register_op_list(op_list)
@@ -687,7 +696,26 @@ _InitOpDefLibrary.op_list_ascii = """%s"""
_op_def_lib = _InitOpDefLibrary()
)",
- cleaned_ops.DebugString().c_str());
+ cleaned_ops.DebugString().c_str());
+ return result;
+}
+
+void PrintPythonOps(const OpList& ops, const string& hidden_ops,
+ bool require_shapes) {
+ printf("%s", GetPythonOps(ops, hidden_ops, require_shapes).c_str());
+}
+
+string GetAllPythonOps(const char* hidden, bool require_shapes) {
+ OpList ops;
+ OpRegistry::Global()->Export(false, &ops);
+ return GetPythonOps(ops, hidden, require_shapes);
+}
+
+string GetPythonWrappers(const char* buf, size_t len) {
+ string op_list_str(buf, len);
+ OpList ops;
+ ops.ParseFromString(op_list_str);
+ return GetPythonOps(ops, "", false);
}
} // namespace tensorflow
diff --git a/tensorflow/python/framework/python_op_gen.h b/tensorflow/python/framework/python_op_gen.h
index b998f2247b..faac3caab5 100644
--- a/tensorflow/python/framework/python_op_gen.h
+++ b/tensorflow/python/framework/python_op_gen.h
@@ -22,10 +22,19 @@ limitations under the License.
namespace tensorflow {
-// Result is printed to stdout. hidden_ops should be a comma-separated
+// hidden_ops should be a comma-separated
// list of Op names that should get a leading _ in the output.
+// The Print* version prints the output to stdout; the Get* version returns
+// the output as a string.
void PrintPythonOps(const OpList& ops, const string& hidden_ops,
bool require_shapes);
+string GetPythonOps(const OpList& ops, const string& hidden_ops,
+ bool require_shapes);
+
+// Get the Python wrappers for a list of ops in an OpList.
+// buf should be a pointer to a buffer containing the binary encoded OpList
+// proto, and len should be the length of that buffer.
+string GetPythonWrappers(const char* buf, size_t len);
} // namespace tensorflow
diff --git a/tensorflow/python/framework/python_op_gen.i b/tensorflow/python/framework/python_op_gen.i
new file mode 100644
index 0000000000..08f53f101b
--- /dev/null
+++ b/tensorflow/python/framework/python_op_gen.i
@@ -0,0 +1,24 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%include "tensorflow/python/platform/base.i"
+
+%{
+#include "tensorflow/python/framework/python_op_gen.h"
+%}
+
+%ignoreall;
+%unignore tensorflow::GetPythonWrappers;
+%include "tensorflow/python/framework/python_op_gen.h"
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 2b65e483de..0f4eb744f1 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -171,6 +171,8 @@ class TensorFlowTestCase(googletest.TestCase):
A Session object that should be used as a context manager to surround
the graph building and execution code in a test case.
"""
+ if self.id().endswith(".test_session"):
+ self.skipTest("Not a test.")
def prepare_config(config):
if config is None:
config = config_pb2.ConfigProto()
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index 0ea573932b..ab0676d9ec 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -364,5 +364,14 @@ class ConcatOpTest(tf.test.TestCase):
err = tf.test.compute_gradient_error(xs, x_shapes, output, output_shape)
self.assertLess(err, 1e-11)
+ def testConcatTuple(self):
+ c1 = np.random.rand(4, 4)
+ c2 = np.random.rand(4, 4)
+ with self.test_session():
+ concat_list_t = tf.concat(0, [c1, c2])
+ concat_tuple_t = tf.concat(0, (c1, c2))
+ self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval())
+
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 58302a683d..6de4c905b1 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -330,6 +330,13 @@ class ControlFlowTest(tf.test.TestCase):
result = exit_i.eval()
self.assertAllEqual(10, result)
+ def testCondBool(self):
+ values = tf.constant(10)
+ fn1 = lambda: tf.add(values, 1)
+ fn2 = lambda: tf.sub(values, 1)
+ with self.assertRaisesRegexp(TypeError, "must not be a Python bool"):
+ _ = control_flow_ops.cond(False, fn1, fn2)
+
def testCondIndexedSlices(self):
with self.test_session():
values = tf.constant(10)
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index a823250d51..8f2720f1cf 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import math
+
import tensorflow.python.platform
import numpy as np
@@ -55,6 +57,11 @@ class UnaryOpTest(tf.test.TestCase):
tf_cpu = y.eval()
self.assertShapeEqual(np_ans, y)
self.assertAllClose(np_ans, tf_cpu)
+
+ # TODO(ebrevdo): add gradient for lgamma (digamma) and remove lgamma here.
+ if tf_func in (tf.lgamma,):
+ return # Return early
+
if x.dtype == np.float32:
s = list(np.shape(x))
jacob_t, jacob_n = tf.test.compute_gradient(inx,
@@ -94,6 +101,17 @@ class UnaryOpTest(tf.test.TestCase):
def _sigmoid(self, x):
return 1.0 / (1.0 + np.exp(-x))
+ def _replace_domain_error_with_inf(self, fn):
+ def func(x):
+ try:
+ return fn(x)
+      except ValueError as e:
+        if "domain error" in str(e):
+          return np.inf * np.ones_like(x)
+        else:
+          raise
+ return func
+
def testFloatBasic(self):
x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
y = (x + .5).astype(np.float32) # no zero
@@ -113,6 +131,12 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBoth(y, np.sign, tf.sign)
self._compareBoth(x, np.sin, tf.sin)
self._compareBoth(x, np.cos, tf.cos)
+ self._compareBoth(
+ x,
+ np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
+ tf.lgamma)
+ self._compareBoth(x, np.vectorize(math.erf), tf.erf)
+ self._compareBoth(x, np.vectorize(math.erfc), tf.erfc)
def testFloatTanhEdge(self):
x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index f02e16a4ae..db8d4ba5c4 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -1124,6 +1124,33 @@ class FIFOQueueTest(tf.test.TestCase):
thread.join()
self.assertAllEqual(elem, results)
+ def testDtypes(self):
+ with self.test_session() as sess:
+ dtypes = [tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8,
+ tf.int64, tf.bool, tf.complex64]
+ shape = (32, 4, 128)
+ q = tf.FIFOQueue(32, dtypes, [shape[1:]] * len(dtypes))
+
+ input_tuple = []
+ for dtype in dtypes:
+ np_dtype = dtype.as_numpy_dtype
+ np_array = np.random.randint(-10, 10, shape)
+ if dtype == tf.bool:
+ np_array = np_array > 0
+ elif dtype == tf.complex64:
+ np_array = np.sqrt(np_array.astype(np_dtype))
+ else:
+ np_array = np_array.astype(np_dtype)
+ input_tuple.append(np_array)
+
+ q.enqueue_many(input_tuple).run()
+
+ output_tuple_t = q.dequeue_many(32)
+ output_tuple = sess.run(output_tuple_t)
+
+ for (input_elem, output_elem) in zip(input_tuple, output_tuple):
+ self.assertAllEqual(input_elem, output_elem)
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/kernel_tests/gradient_checker_test.py b/tensorflow/python/kernel_tests/gradient_checker_test.py
index 2ded0375a8..bcaaa8cc4e 100644
--- a/tensorflow/python/kernel_tests/gradient_checker_test.py
+++ b/tensorflow/python/kernel_tests/gradient_checker_test.py
@@ -27,6 +27,7 @@ import tensorflow as tf
class GradientCheckerTest(tf.test.TestCase):
def testAddSimple(self):
+ np.random.seed(1) # Fix seed to avoid flakiness
with self.test_session(use_gpu=False):
# a test case for Add operation
size = (2, 3)
@@ -40,6 +41,7 @@ class GradientCheckerTest(tf.test.TestCase):
assert error < 1e-4
def testAddSimpleGPU(self):
+ np.random.seed(2) # Fix seed to avoid flakiness
with self.test_session(use_gpu=True):
# a test case for Add operation
size = (2, 3)
@@ -53,6 +55,7 @@ class GradientCheckerTest(tf.test.TestCase):
assert error < 1e-4
def testAddCustomized(self):
+ np.random.seed(3) # Fix seed to avoid flakiness
with self.test_session():
# a test case for Add operation
size = (2, 3)
@@ -74,6 +77,7 @@ class GradientCheckerTest(tf.test.TestCase):
assert error < 1e-10
def testGather(self):
+ np.random.seed(4) # Fix seed to avoid flakiness
with self.test_session():
p_shape = (4, 2)
p_size = 8
@@ -89,6 +93,7 @@ class GradientCheckerTest(tf.test.TestCase):
assert error < 1e-4
def testNestedGather(self):
+ np.random.seed(5) # Fix seed to avoid flakiness
with self.test_session():
p_shape = (8, 2)
p_size = 16
@@ -110,6 +115,9 @@ class GradientCheckerTest(tf.test.TestCase):
# Gradient checker for MNIST.
def BuildAndTestMiniMNIST(param_index, tag):
+ # Fix seed to avoid occasional flakiness
+ np.random.seed(6)
+
# Hyperparameters
batch = 3
inputs = 16
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 5a0ffce6b4..a470fb7274 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -642,6 +642,60 @@ class ParseSequenceExampleTest(tf.test.TestCase):
"feature_list_dense_defaults": {"d": None},
}, expected_feat_list_values=expected_feature_list_output)
+ def testSequenceExampleWithSparseAndDenseFeatureLists(self):
+ feature_list_dense_keys = ["a"]
+ feature_list_dense_types = [tf.int64]
+ feature_list_dense_shapes = [(2,)]
+
+ original = sequence_example(feature_lists=feature_lists({
+ "a": feature_list([
+ int64_feature([3, 4]),
+ int64_feature([1, 0])]),
+ "st_a": feature_list([
+ float_feature([3.0, 4.0]),
+ float_feature([5.0]),
+ float_feature([])]),
+ "st_b": feature_list([
+ bytes_feature([b"a"]),
+ bytes_feature([]),
+ bytes_feature([]),
+ bytes_feature([b"b", b"c"])])}))
+
+ serialized = original.SerializeToString()
+
+ expected_st_a = (
+ np.array([[0, 0], [0, 1], [1, 0]], dtype=np.int64), # indices
+ np.array([3.0, 4.0, 5.0], dtype=np.float32), # values
+ np.array([3, 2], dtype=np.int64)) # shape: num_time = 3, max_feat = 2
+
+ expected_st_b = (
+ np.array([[0, 0], [3, 0], [3, 1]], dtype=np.int64), # indices
+ np.array(["a", "b", "c"], dtype=np.str), # values
+ np.array([4, 2], dtype=np.int64)) # shape: num_time = 4, max_feat = 2
+
+ expected_st_c = (
+ np.empty((0, 2), dtype=np.int64), # indices
+ np.empty((0,), dtype=np.int64), # values
+ np.array([0, 0], dtype=np.int64)) # shape: num_time = 0, max_feat = 0
+
+ expected_feature_list_output = {
+ "a": np.array([[3, 4], [1, 0]], dtype=np.int64),
+ "st_a": expected_st_a,
+ "st_b": expected_st_b,
+ "st_c": expected_st_c,
+ }
+
+ self._test(
+ {
+ "debug_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "feature_list_dense_types": feature_list_dense_types,
+ "feature_list_dense_keys": feature_list_dense_keys,
+ "feature_list_dense_shapes": feature_list_dense_shapes,
+ "feature_list_sparse_keys": ["st_a", "st_b", "st_c"],
+ "feature_list_sparse_types": [tf.float32, tf.string, tf.int64]
+ }, expected_feat_list_values=expected_feature_list_output)
+
def testSequenceExampleListWithInconsistentDataFails(self):
feature_list_dense_types = [tf.int64]
feature_list_dense_shapes = [(2,)]
@@ -687,6 +741,29 @@ class ParseSequenceExampleTest(tf.test.TestCase):
expected_err_re=("Feature list: a, Index: 0. Data types don't match. "
"Expected type: int64"))
+ def testSequenceExampleListWithWrongSparseDataTypeFails(self):
+ feature_list_sparse_types = [tf.int64]
+
+ original = sequence_example(feature_lists=feature_lists({
+ "a": feature_list([
+ int64_feature([3, 4]),
+ int64_feature([1, 2, 3]),
+ float_feature([2])])
+ }))
+
+ serialized = original.SerializeToString()
+
+ self._test(
+ {
+ "debug_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "feature_list_sparse_types": feature_list_sparse_types,
+ "feature_list_sparse_keys": ["a"]
+ },
+ expected_err_re=(
+ "Name: in1, Feature List: a, Index: 2. Data types don't match. "
+ "Expected type: int64 Feature is: float_list"))
+
def testSequenceExampleListWithWrongShapeFails(self):
feature_list_dense_types = [tf.int64]
feature_list_dense_shapes = [(2,)]
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
new file mode 100644
index 0000000000..742402b8b7
--- /dev/null
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -0,0 +1,84 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for py_func op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.python.platform
+
+import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+from tensorflow.python.ops import script_ops
+
+
+class PyOpTest(tf.test.TestCase):
+
+ def testBasic(self):
+
+ def my_func(x, y):
+ return np.sinh(x) + np.cosh(y)
+
+ # scalar
+ with self.test_session():
+ x = tf.constant(1.0, tf.float32)
+ y = tf.constant(2.0, tf.float32)
+ z = tf.py_func(my_func, [x, y], [tf.float32])
+ self.assertEqual(z[0].eval(), my_func(1.0, 2.0).astype(np.float32))
+
+ # array
+ with self.test_session():
+ x = tf.constant([1.0, 2.0], tf.float64)
+ y = tf.constant([2.0, 3.0], tf.float64)
+ z = tf.py_func(my_func, [x, y], [tf.float64])
+ self.assertAllEqual(
+ z[0].eval(),
+ my_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))
+
+ # a bit exotic type (complex64)
+ with self.test_session():
+ x = tf.constant(1+2j, tf.complex64)
+ y = tf.constant(3+4j, tf.complex64)
+ z, = tf.py_func(my_func, [x, y], [tf.complex64])
+ self.assertAllClose(z.eval(), my_func(1+2j, 3+4j))
+
+    # a bit exotic function (rfft)
+ with self.test_session():
+ x = tf.constant([1., 2., 3., 4.], tf.float32)
+ def rfft(x):
+ return np.fft.rfft(x).astype(np.complex64)
+ y, = tf.py_func(rfft, [x], [tf.complex64])
+ self.assertAllClose(y.eval(), np.fft.rfft([1., 2., 3., 4.]))
+
+ def testLarge(self):
+ with self.test_session() as sess:
+ x = tf.zeros([1000000], dtype=np.float32)
+ y = tf.py_func(lambda x: x + 1, [x], [tf.float32])
+ z = tf.py_func(lambda x: x * 2, [x], [tf.float32])
+ for _ in xrange(100):
+ sess.run([y[0].op, z[0].op])
+
+ def testCleanup(self):
+ for _ in range(1000):
+ g = tf.Graph()
+ with g.as_default():
+ c = tf.constant([1.], tf.float32)
+ _ = tf.py_func(lambda x: x + 1, [c], [tf.float32])
+ self.assertTrue(script_ops._py_funcs.size() < 100)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index b1188d0672..2882182a03 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -235,7 +235,7 @@ class TextLineReaderTest(tf.test.TestCase):
def _LineText(self, f, l):
return tf.compat.as_bytes("%d: %d" % (f, l))
- def _CreateFiles(self):
+ def _CreateFiles(self, crlf=False):
filenames = []
for i in range(self._num_files):
fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
@@ -246,11 +246,10 @@ class TextLineReaderTest(tf.test.TestCase):
# Always include a newline after the record unless it is
# at the end of the file, in which case we include it sometimes.
if j + 1 != self._num_lines or i == 0:
- f.write(b"\n")
+ f.write(b"\r\n" if crlf else b"\n")
return filenames
- def testOneEpoch(self):
- files = self._CreateFiles()
+ def _testOneEpoch(self, files):
with self.test_session() as sess:
reader = tf.TextLineReader(name="test_reader")
queue = tf.FIFOQueue(99, [tf.string], shapes=())
@@ -268,6 +267,12 @@ class TextLineReaderTest(tf.test.TestCase):
"\\(requested 1, current size 0\\)"):
k, v = sess.run([key, value])
+ def testOneEpochLF(self):
+ self._testOneEpoch(self._CreateFiles(crlf=False))
+
+ def testOneEpochCRLF(self):
+ self._testOneEpoch(self._CreateFiles(crlf=True))
+
def testSkipHeaderLines(self):
files = self._CreateFiles()
with self.test_session() as sess:
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index 7ff3851da7..3b79ae341b 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -174,6 +174,28 @@ class SumReductionTest(tf.test.TestCase):
def testGradient4(self):
self._compareGradient([2, 3, 4, 2], [], None)
+ def testHighRank(self):
+ # Do a bunch of random high dimensional reductions
+ np.random.seed(42)
+ for _ in range(20):
+ rank = np.random.randint(4, 10 + 1)
+ axes, = np.nonzero(np.random.randint(2, size=rank))
+ shape = tuple(np.random.randint(1, 3 + 1, size=rank))
+ data = np.random.randint(1024, size=shape)
+ self._compareAll(data, axes)
+ # Check some particular axis patterns
+ for rank in 4, 7, 10:
+ shape = tuple(np.random.randint(1, 3 + 1, size=rank))
+ data = np.random.randint(1024, size=shape)
+ for axes in ([], np.arange(rank), np.arange(0, rank, 2),
+ np.arange(1, rank, 2)):
+ self._compareAll(data, axes)
+
+ def testExpand(self):
+ # Reduce an empty tensor to a nonempty tensor
+ x = np.zeros((5, 0))
+ self._compareAll(x, [1])
+
class MeanReductionTest(tf.test.TestCase):
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index 81be48990b..38ba890c74 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -227,15 +227,23 @@ class TileTest(tf.test.TestCase):
def testSimple(self):
with self.test_session():
- inp = np.random.rand(4, 1).astype("f")
- a = tf.constant([float(x) for x in inp.ravel(order="C")],
- shape=[4, 1], dtype=tf.float32)
+ inp = np.random.rand(4, 1).astype(np.float32)
+ a = tf.constant(inp)
tiled = tf.tile(a, [1, 4])
result = tiled.eval()
self.assertEqual(result.shape, (4, 4))
self.assertEqual([4, 4], tiled.get_shape())
self.assertTrue((result == np.tile(inp, (1, 4))).all())
+ def testEmpty(self):
+ with self.test_session():
+ inp = np.random.rand(2, 3).astype(np.float32)
+ a = tf.constant(inp)
+ tiled = tf.tile(a, [5, 0])
+ result = tiled.eval()
+ self.assertEqual(result.shape, (10, 0))
+ self.assertEqual([10, 0], tiled.get_shape())
+
def testTypes(self):
types_to_test = {
"bool": (tf.bool, bool),
diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index c6e91dcd71..d6ee16c8e2 100644
--- a/tensorflow/python/ops/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -19,7 +19,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+# pylint: disable=unused-import, g-bad-import-order
import tensorflow.python.platform
+# pylint: enable=unused-import, g-bad-import-order
import numpy as np
@@ -46,13 +48,17 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase):
constant_op.constant(shape, dtypes.int64))
def _SparseTensor_2x3x4(self, dtype):
+    # Includes two entries of the form [1, 1, x] with value 150.
ind = np.array([
[0, 0, 1],
- [0, 1, 0], [0, 1, 2],
+ [0, 1, 0],
+ [0, 1, 2],
[1, 0, 3],
- [1, 1, 1], [1, 1, 3],
+ [1, 1, 0],
+ [1, 1, 1],
+ [1, 1, 2],
[1, 2, 2]])
- val = np.array([1, 10, 12, 103, 111, 113, 122])
+ val = np.array([1, 10, 12, 103, 150, 149, 150, 122])
shape = np.array([2, 3, 4])
return ops.SparseTensor(
constant_op.constant(ind, dtypes.int64),
@@ -90,7 +96,8 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase):
expected_output = np.zeros((2, 3, 200), dtype=np.bool)
expected_trues = [(0, 0, 1), (0, 1, 10), (0, 1, 12),
- (1, 0, 103), (1, 1, 111), (1, 1, 113), (1, 2, 122)]
+ (1, 0, 103), (1, 1, 149), (1, 1, 150),
+ (1, 2, 122)]
for expected_true in expected_trues:
expected_output[expected_true] = True
diff --git a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
index ee9a697a0b..6ea1e6d8eb 100644
--- a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
+++ b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
@@ -25,9 +25,11 @@ import tensorflow as tf
def _SparseToDense(sparse_indices, output_size, sparse_values,
- default_value):
+ default_value, validate_indices=True):
return tf.sparse_to_dense(sparse_indices, output_size,
- sparse_values, default_value)
+ sparse_values,
+ default_value=default_value,
+ validate_indices=validate_indices)
class SparseToDenseTest(tf.test.TestCase):
@@ -107,10 +109,24 @@ class SparseToDenseTest(tf.test.TestCase):
def testBadDefault(self):
with self.test_session():
- dense = _SparseToDense([1, 3], [5], [1, 2], [1, 2])
+ dense = _SparseToDense([1, 3], [5], [1, 2], [0])
with self.assertRaisesOpError("default_value should be a scalar"):
dense.eval()
+ def testInvalidIndicesWithWithoutValidation(self):
+ with self.test_session():
+ dense = _SparseToDense(
+ sparse_indices=[[1], [1]], output_size=[5],
+ sparse_values=[-1.0, 1.0], default_value=0.0)
+ with self.assertRaisesOpError(
+ "not lexicographically sorted or containing repeats"):
+ dense.eval()
+ # Disable checks
+ dense_without_validation = _SparseToDense(
+ sparse_indices=[[1], [1]], output_size=[5],
+ sparse_values=[-1.0, 1.0], default_value=0.0, validate_indices=False)
+ dense_without_validation.eval()
+
def testShapeInferenceKnownShape(self):
with self.test_session(use_gpu=False):
indices = tf.placeholder(tf.int64)
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index c6af05ff22..c769987a47 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -186,8 +186,8 @@ class TransposeTest(tf.test.TestCase):
def testError(self):
with self.assertRaises(ValueError):
tf.transpose(np.arange(0., 30).reshape([2, 3, 5]), [[0, 1], [2, 3]])
- self._testError(np.arange(0., 2 ** 10).reshape([2] * 10),
- np.arange(10),
+ self._testError(np.arange(0., 2 ** 11).reshape([2] * 11),
+ np.arange(11),
"not implemented")
with self.assertRaises(IndexError):
tf.transpose(np.arange(0., 30).reshape([2, 3, 5]), [0, 1, 3])
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
new file mode 100644
index 0000000000..87c64014bf
--- /dev/null
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -0,0 +1,338 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/python/lib/core/py_func.h"
+
+#include <Python.h>
+#include "numpy/arrayobject.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace {
+
+static mutex mu;
+static bool initialized GUARDED_BY(mu) = false;
+static PyObject* py_trampoline GUARDED_BY(mu) = nullptr;
+
+// Returns the py_trampoline that is used to pass control to the
+// python runtime.
+PyObject* GetPyTrampoline() {
+ mutex_lock l(mu);
+ return py_trampoline;
+}
+
+// Module initialization (mainly importing numpy) if needed.
+void InitIfNeeded() {
+ mutex_lock l(mu);
+ if (!initialized) {
+ PyGILState_STATE py_threadstate;
+ py_threadstate = PyGILState_Ensure();
+ import_array();
+ PyGILState_Release(py_threadstate);
+ initialized = true;
+ }
+}
+
+// Returns a single-threaded threadpool used to execute the python
+// trampoline and the python function. It is single threaded because
+// the GIL is needed while running the trampoline.
+thread::ThreadPool* py_thread() {
+ static thread::ThreadPool* w =
+ new thread::ThreadPool(Env::Default(), "PyTrampoline", 1);
+ return w;
+}
+
+// Returns the corresponding numpy dtype in 'np' for tf data type
+// 'tf'. Returns an error if the type is not supported by this
+// module.
+Status TfDTypeToNpDType(const DataType& tf, int* np) {
+ switch (tf) {
+ case DT_FLOAT:
+ *np = NPY_FLOAT32;
+ break;
+ case DT_DOUBLE:
+ *np = NPY_FLOAT64;
+ break;
+ case DT_INT32:
+ *np = NPY_INT32;
+ break;
+ case DT_UINT8:
+ *np = NPY_UINT8;
+ break;
+ case DT_INT8:
+ *np = NPY_INT8;
+ break;
+ case DT_INT16:
+ *np = NPY_INT16;
+ break;
+ case DT_INT64:
+ *np = NPY_INT64;
+ break;
+ case DT_BOOL:
+ *np = NPY_BOOL;
+ break;
+ case DT_COMPLEX64:
+ *np = NPY_COMPLEX64;
+ break;
+ default:
+ return errors::Unimplemented("Unsupported tf type ", DataTypeString(tf));
+ }
+ return Status::OK();
+}
+
+// Creates a numpy array in 'ret' and copies the content of tensor 't'
+// into 'ret'.
+Status ConvertTensorToNdarray(const Tensor& t, PyObject** ret) {
+ int typenum;
+ TF_RETURN_IF_ERROR(TfDTypeToNpDType(t.dtype(), &typenum));
+ PyArray_Descr* descr = PyArray_DescrFromType(typenum);
+ CHECK(descr);
+ std::vector<npy_intp> dims;
+ for (int i = 0; i < t.dims(); ++i) {
+ dims.push_back(t.dim_size(i));
+ }
+ PyObject* obj = PyArray_Empty(dims.size(), dims.data(), descr, 0);
+ if (obj == nullptr) {
+ return errors::Internal("Failed to allocate np array: ",
+ t.shape().ShortDebugString());
+ }
+ PyArrayObject* np_array = reinterpret_cast<PyArrayObject*>(obj);
+ CHECK(DataTypeCanUseMemcpy(t.dtype()));
+ StringPiece p = t.tensor_data();
+ memcpy(np_array->data, p.data(), p.size());
+ *ret = PyArray_Return(np_array);
+ return Status::OK();
+}
+
+// A call to the registered python function.
+struct PyCall {
+ // Passed to python runtime to call the python function registered
+ // with this "token".
+ string token;
+
+  // Inputs and outputs of this function invocation.
+ std::vector<Tensor> ins;
+ std::vector<Tensor> out;
+};
+
+// Given the 'call', prepares the token and inputs as a python tuple
+// that is appropriate for calling the trampoline.
+Status MakeArgTuple(PyCall* call, PyObject** tuple) {
+ int64 n = call->ins.size();
+ PyObject* lst = PyList_New(n);
+ CHECK(lst);
+ for (int64 i = 0; i < n; ++i) {
+ const Tensor& t = call->ins[i];
+ PyObject* a;
+ Status s = ConvertTensorToNdarray(t, &a);
+ if (!s.ok()) {
+ Py_DECREF(lst);
+ return s;
+ }
+ PyList_SetItem(lst, i, a);
+ }
+ *tuple = Py_BuildValue("(sN)", call->token.c_str(), lst);
+ CHECK(*tuple);
+ return Status::OK();
+}
+
+// Returns the corresponding tf dtype in 'tf' for numpy data type
+// 'np'. Returns an error if the type is not supported by this
+// module.
+Status NpDTypeToTfDType(const int np, DataType* tf) {
+ switch (np) {
+ case NPY_FLOAT32:
+ *tf = DT_FLOAT;
+ break;
+ case NPY_FLOAT64:
+ *tf = DT_DOUBLE;
+ break;
+ case NPY_INT32:
+ *tf = DT_INT32;
+ break;
+ case NPY_UINT8:
+ *tf = DT_UINT8;
+ break;
+ case NPY_INT8:
+ *tf = DT_INT8;
+ break;
+ case NPY_INT16:
+ *tf = DT_INT16;
+ break;
+ case NPY_INT64:
+ *tf = DT_INT64;
+ break;
+ case NPY_BOOL:
+ *tf = DT_BOOL;
+ break;
+ case NPY_COMPLEX64:
+ *tf = DT_COMPLEX64;
+ break;
+ default:
+ return errors::Unimplemented("Unsupported numpy type ", np);
+ }
+ return Status::OK();
+}
+
+// Given a numpy ndarray object 'obj', creates a corresponding tf
+// Tensor in '*ret'.
+Status ConvertNdarrayToTensor(PyObject* obj, Tensor* ret) {
+ PyArrayObject* a = reinterpret_cast<PyArrayObject*>(obj);
+ DataType dtype;
+ TF_RETURN_IF_ERROR(NpDTypeToTfDType(PyArray_TYPE(a), &dtype));
+ CHECK(DataTypeCanUseMemcpy(dtype));
+ TensorShape shape;
+ for (int i = 0; i < PyArray_NDIM(a); ++i) {
+ shape.AddDim(PyArray_SHAPE(a)[i]);
+ }
+ Tensor t(dtype, shape);
+ StringPiece p = t.tensor_data();
+ memcpy(const_cast<char*>(p.data()), a->data, p.size());
+ *ret = t;
+ return Status::OK();
+}
+
+// Calls the registered py function through the trampoline.
+Status DoCallPyFunc(PyCall* call) {
+ PyObject* trampoline = GetPyTrampoline();
+ if (trampoline == nullptr) {
+ return errors::InvalidArgument(
+ "Missing py trampoline. Most likely, it is a link error.");
+ }
+ // Prepare the argument.
+ PyObject* args = nullptr;
+ TF_RETURN_IF_ERROR(MakeArgTuple(call, &args));
+ CHECK(args);
+
+ // Invokes the trampoline.
+ PyObject* result = PyEval_CallObject(trampoline, args);
+ Py_DECREF(args);
+ if (result == nullptr) {
+ return errors::Internal("Failed to run py callback ", call->token,
+ ": see error log.");
+ }
+
+  // Processes the return values and converts them to tf Tensors.
+ Status s;
+ if (PyList_Check(result)) {
+ // 'result' is a list.
+ call->out.clear();
+ for (int i = 0; i < PyList_Size(result); ++i) {
+ Tensor t;
+ s = ConvertNdarrayToTensor(PyList_GetItem(result, i), &t);
+ if (!s.ok()) {
+ break;
+ }
+ call->out.push_back(t);
+ }
+ } else if (PyArray_Check(result)) {
+ // 'result' is a single ndarray.
+ Tensor t;
+ s = ConvertNdarrayToTensor(result, &t);
+ if (s.ok()) {
+ call->out.push_back(t);
+ }
+ } else {
+    // 'result' is a plain python scalar. We convert it to a numpy
+ // scalar then convert it to a Tensor.
+ PyObject* scalar = PyArray_FromScalar(result, nullptr);
+ if (scalar == nullptr) {
+ s = errors::InvalidArgument(
+ call->token,
+ " returns a value which can't be converted into numpy scalar.");
+ } else {
+ Tensor t;
+ s = ConvertNdarrayToTensor(scalar, &t);
+ if (s.ok()) {
+ call->out.push_back(t);
+ }
+ Py_DECREF(scalar);
+ }
+ }
+ Py_DECREF(result);
+ return s;
+}
+
+// Calls the python function in a separate thread. Arranges to call
+// done() when the python function returns.
+void CallPyFunc(PyCall* call, std::function<void(Status)> done) {
+ InitIfNeeded();
+ py_thread()->Schedule([call, done]() {
+ PyGILState_STATE py_threadstate;
+ py_threadstate = PyGILState_Ensure();
+ Status s = DoCallPyFunc(call);
+ PyGILState_Release(py_threadstate);
+ done(s);
+ });
+}
+
+} // end namespace
+
+void InitializePyTrampoline(PyObject* trampoline) {
+ mutex_lock l(mu);
+ if (py_trampoline == nullptr) {
+ py_trampoline = trampoline;
+ Py_INCREF(py_trampoline);
+ } else {
+ LOG(WARNING) << "InitializeCallback should only be called once";
+ }
+}
+
+class PyFuncOp : public AsyncOpKernel {
+ public:
+ explicit PyFuncOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
+ }
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ PyCall* call = new PyCall;
+ call->token = token_;
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ call->ins.push_back(ctx->input(i));
+ }
+ CallPyFunc(call, [this, ctx, call, done](Status s) {
+ std::unique_ptr<PyCall> delete_me(call);
+ OP_REQUIRES_OK_ASYNC(ctx, s, done);
+ OP_REQUIRES_ASYNC(
+ ctx, call->out.size() == ctx->num_outputs(),
+ errors::InvalidArgument(token_, " returns ", call->out.size(),
+ " values, but expects to see ",
+ ctx->num_outputs(), " values."),
+ done);
+ for (int i = 0; i < call->out.size(); ++i) {
+ const auto& t = call->out[i];
+ OP_REQUIRES_ASYNC(
+ ctx, t.dtype() == output_type(i),
+ errors::InvalidArgument(i, "-th value returned by ", token_, " is ",
+ DataTypeString(t.dtype()), ", but expects ",
+ DataTypeString(output_type(i))),
+ done);
+ ctx->set_output(i, t);
+ }
+ done();
+ });
+ }
+
+ private:
+ string token_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(PyFuncOp);
+};
+REGISTER_KERNEL_BUILDER(Name("PyFunc").Device(DEVICE_CPU), PyFuncOp);
+
+} // end namespace tensorflow
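
For reference, the two switch statements above define the tf<->numpy dtype
pairs this module supports; a sketch of the same mapping in Python (the tf
dtype names are written as strings purely for illustration):

    import numpy as np

    _SUPPORTED_DTYPES = {
        "DT_FLOAT": np.float32,
        "DT_DOUBLE": np.float64,
        "DT_INT32": np.int32,
        "DT_UINT8": np.uint8,
        "DT_INT8": np.int8,
        "DT_INT16": np.int16,
        "DT_INT64": np.int64,
        "DT_BOOL": np.bool_,
        "DT_COMPLEX64": np.complex64,
    }
    # Any other dtype makes TfDTypeToNpDType/NpDTypeToTfDType return
    # errors::Unimplemented.
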
diff --git a/tensorflow/python/lib/core/py_func.h b/tensorflow/python/lib/core/py_func.h
new file mode 100644
index 0000000000..2de52fb492
--- /dev/null
+++ b/tensorflow/python/lib/core/py_func.h
@@ -0,0 +1,47 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_
+#define TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_
+
+#include "tensorflow/core/public/status.h"
+#include "tensorflow/core/public/tensor.h"
+
+#include <Python.h>
+
+namespace tensorflow {
+
+// Called by py code on initialization.
+//
+// "trampoline" must represent a python function which has the
+// following signature:
+// (string, list(ndarray)) -> ndarray | list(ndarray) | python scalar
+//
+// The trampoline takes two arguments, the first is a string token
+// used by the python frontend's dispatching logic; the second is a
+// list of numpy ndarrays.
+//
+// The trampoline can return a single numpy ndarray, a list of numpy
+// ndarrays, or simply a python scalar. The C++ runtime converts them,
+// if supported, back to Tensor objects.
+//
+// This is called by script_ops.py during its module initialization.
+//
+// TODO(zhifengc): Support distributed runtime.
+void InitializePyTrampoline(PyObject* trampoline);
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_FUNC_H_
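
A minimal Python sketch of a trampoline satisfying the signature documented
above; the registry dict here is an assumption, the real dispatcher is the
FuncRegistry class added in script_ops.py below:

    # token -> callable, populated by the python frontend.
    _registry = {}

    def _trampoline(token, ndarrays):
      # Dispatches on the string token; may return an ndarray, a list of
      # ndarrays, or a plain python scalar, per the contract above.
      return _registry[token](*ndarrays)
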
diff --git a/tensorflow/python/lib/core/py_func.i b/tensorflow/python/lib/core/py_func.i
new file mode 100644
index 0000000000..c85bbc1c55
--- /dev/null
+++ b/tensorflow/python/lib/core/py_func.i
@@ -0,0 +1,29 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%include "tensorflow/python/platform/base.i"
+
+%{
+#include "tensorflow/python/lib/core/py_func.h"
+%}
+
+%ignoreall
+
+%unignore tensorflow;
+%unignore tensorflow::InitializePyTrampoline;
+
+%include "tensorflow/python/lib/core/py_func.h"
+
+%unignoreall
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index e3526d32c4..d3b62f03c6 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -23,7 +23,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
-from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
@@ -231,7 +230,22 @@ ops.NoGradient("Size")
def _TileGrad(op, grad):
"""Sum reduces grad along the tiled dimensions."""
assert isinstance(grad, ops.Tensor)
- return [gen_array_ops._tile_grad(grad, op.inputs[1]), None]
+ input_shape = array_ops.shape(op.inputs[0])
+ # We interleave multiples and input_shape to get split_shape,
+ # reshape grad to split_shape, and reduce along all even
+ # dimensions (the tiled dimensions) to get the result
+ # with shape input_shape. For example
+ # input_shape = [20, 30, 40]
+ # multiples = [2, 3, 4]
+ # split_shape = [2, 20, 3, 30, 4, 40]
+ # axes = [0, 2, 4]
+ split_shape = array_ops.reshape(array_ops.transpose(
+ array_ops.pack([op.inputs[1], input_shape])), [-1])
+ axes = math_ops.range(0, array_ops.size(split_shape), 2)
+ input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
+ # Fix shape inference
+ input_grad.set_shape(op.inputs[0].get_shape())
+ return [input_grad, None]
ops.NoGradient("TileGrad")
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 613bdf49f0..0f36ed7e41 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -299,7 +299,7 @@ def concat(concat_dim, values, name="concat"):
Returns:
A `Tensor` resulting from concatenation of the input tensors.
"""
- if not isinstance(values, (list)):
+ if not isinstance(values, (list, tuple)):
values = [values]
# TODO(mrry): Change to return values?
if len(values) == 1: # Degenerate case of one tensor.
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index ce78458515..f3d17aa12d 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -69,6 +69,7 @@ from __future__ import print_function
import six
from six.moves import xrange # pylint: disable=redefined-builtin
+
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -82,6 +83,7 @@ from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_control_flow_ops import *
+from tensorflow.python.platform import logging
# We override the 'tuple' for a control flow op, so we keep python's
@@ -630,6 +632,8 @@ def cond(pred, fn1, fn2, name=None):
raise TypeError("fn2 must be callable.")
# Add the Switch to the graph.
+ if isinstance(pred, bool):
+ raise TypeError("pred must not be a Python bool")
p_2, p_1 = switch(pred, pred)
pivot_1 = array_ops.identity(p_1, name="switch_t")
pivot_2 = array_ops.identity(p_2, name="switch_f")
@@ -1172,7 +1176,7 @@ def with_dependencies(dependencies, output_tensor, name=None):
TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
"""
with ops.op_scope(dependencies + [output_tensor], name,
- "control_dependency") as name:
+ "control_dependency") as name:
with ops.device(output_tensor.device
or ops.get_default_graph().get_default_device()):
with ops.control_dependencies(dependencies):
@@ -1237,6 +1241,7 @@ def group(*inputs, **kwargs):
# 2-level tree. The root node is the returned NoOp node.
# deps contains 1 NoOp node for each device.
deps = []
+
def device_key(dev):
"""A sort key that allows None to be compared to strings."""
return "" if dev is None else dev
@@ -1244,6 +1249,7 @@ def group(*inputs, **kwargs):
deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
return _GroupControlDeps(None, deps, name=name)
+
def tuple(tensors, name=None, control_inputs=None):
"""Group tensors together.
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 5afc8a779b..33d4ac2a0b 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -308,6 +308,22 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
c_dense = math_ops.mul(c_sparse, 1.0)
self.assertAllClose(np_val, c_dense.eval())
+ def testIndexedSlicesToTensorList(self):
+ with self.test_session():
+ numpy_list = []
+ dense_list = []
+ sparse_list = []
+ for _ in range(3):
+ np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
+ c = constant_op.constant(np_val)
+ c_sparse = math_ops._as_indexed_slices(c)
+ numpy_list.append(np_val)
+ dense_list.append(c)
+ sparse_list.append(c_sparse)
+ packed_dense = array_ops.pack(dense_list)
+ packed_sparse = array_ops.pack(sparse_list)
+ self.assertAllClose(packed_dense.eval(), packed_sparse.eval())
+
def testInt64Indices(self):
with self.test_session():
np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 55b20a1d10..aa5a2edf86 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -223,6 +225,28 @@ def _TanhGrad(op, grad):
return grad * (1 - math_ops.square(y))
+@ops.RegisterGradient("Erf")
+def _ErfGrad(op, grad):
+ """Returns grad * 2/sqrt(pi) * exp(-x**2)."""
+ x = op.inputs[0]
+ two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype)
+ return grad * two_over_root_pi * math_ops.exp(-math_ops.square(x))
+
+
+@ops.RegisterGradient("Erfc")
+def _ErfcGrad(op, grad):
+ """Returns -grad * 2/sqrt(pi) * exp(-x**2)."""
+ x = op.inputs[0]
+ two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype)
+ return -grad * two_over_root_pi * math_ops.exp(-math_ops.square(x))
+
+
+@ops.RegisterGradient("Lgamma")
+def _LgammaGrad(op, grad): # pylint: disable=unused-argument
+ # TODO(ebrevdo): implement digamma
+ raise NotImplementedError("grad(Lgamma) == Digamma is not implemented")
+
+
@ops.RegisterGradient("Sigmoid")
def _SigmoidGrad(op, grad):
"""Returns grad * sigmoid(x) * (1 - sigmoid(x))."""
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index ec382da9b2..fa12adf8ce 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -49,6 +49,9 @@ mathematical functions to your graph.
@@minimum
@@cos
@@sin
+@@lgamma
+@@erf
+@@erfc
## Matrix Math Functions
@@ -1097,6 +1100,57 @@ def tanh(x, name=None):
return gen_math_ops._tanh(x, name=name)
+def lgamma(x, name=None):
+ """Computes `ln(|gamma(x)|)` element-wise.
+
+ Args:
+ x: A Tensor with type `float`, `double`, `int32`, `int64`,
+ or `qint32`.
+ name: A name for the operation (optional).
+
+ Returns:
+    A Tensor with the same type as `x` if `x.dtype != qint32`; otherwise
+    the return type is `quint8`.
+ """
+ with ops.op_scope([x], name, "Lgamma") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ return gen_math_ops._lgamma(x, name=name)
+
+
+def erf(x, name=None):
+ """Computes Gauss error function of `x` element-wise.
+
+ Args:
+ x: A Tensor with type `float`, `double`, `int32`, `int64`,
+ or `qint32`.
+ name: A name for the operation (optional).
+
+ Returns:
+    A Tensor with the same type as `x` if `x.dtype != qint32`; otherwise
+    the return type is `quint8`.
+ """
+ with ops.op_scope([x], name, "Erf") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ return gen_math_ops._erf(x, name=name)
+
+
+def erfc(x, name=None):
+ """Computes complementary error function of `x` element-wise.
+
+ Args:
+ x: A Tensor with type `float`, `double`, `int32`, `int64`,
+ or `qint32`.
+ name: A name for the operation (optional).
+
+ Returns:
+    A Tensor with the same type as `x` if `x.dtype != qint32`; otherwise
+    the return type is `quint8`.
+ """
+ with ops.op_scope([x], name, "Erfc") as name:
+ x = ops.convert_to_tensor(x, name="x")
+ return gen_math_ops._erfc(x, name=name)
+
+
ops.RegisterShape("Abs")(common_shapes.unchanged_shape)
ops.RegisterShape("Ceil")(common_shapes.unchanged_shape)
ops.RegisterShape("Conj")(common_shapes.unchanged_shape)
@@ -1119,6 +1173,9 @@ ops.RegisterShape("Sqrt")(common_shapes.unchanged_shape)
ops.RegisterShape("Square")(common_shapes.unchanged_shape)
ops.RegisterShape("Sigmoid")(common_shapes.unchanged_shape)
ops.RegisterShape("Tanh")(common_shapes.unchanged_shape)
+ops.RegisterShape("Lgamma")(common_shapes.unchanged_shape)
+ops.RegisterShape("Erf")(common_shapes.unchanged_shape)
+ops.RegisterShape("Erfc")(common_shapes.unchanged_shape)
ops.RegisterShape("Cast")(common_shapes.unchanged_shape)
ops.RegisterShape("ComplexAbs")(common_shapes.unchanged_shape)
ops.RegisterShape("FFT2D")(common_shapes.unchanged_shape)
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 72adf9e498..6fecea8666 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -686,7 +686,8 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
if sampled_logits.dtype != acc_weights.dtype:
acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
sampled_logits += sparse_ops.sparse_to_dense(
- sparse_indices, sampled_logits_shape, acc_weights, 0.0)
+ sparse_indices, sampled_logits_shape, acc_weights,
+ default_value=0.0, validate_indices=False)
if subtract_log_q:
# Subtract log of Q(l), prior probability that l appears in sampled.
diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py
index 149bfe712a..94d874f067 100644
--- a/tensorflow/python/ops/op_def_library.py
+++ b/tensorflow/python/ops/op_def_library.py
@@ -376,16 +376,14 @@ class OpDefLibrary(object):
try:
if not input_arg.is_ref and dtype:
dtype = dtypes.as_dtype(dtype).base_dtype
- values = ops.convert_n_to_tensor_or_indexed_slices(
- values, name=input_arg.name,
- dtype=dtype if dtype else None,
+ values = ops.convert_n_to_tensor(
+ values, name=input_arg.name, dtype=dtype if dtype else None,
as_ref=input_arg.is_ref)
except (TypeError, ValueError):
assert dtype is not None, "Should not fail if dtype is None"
assert input_arg.number_attr, "Should be number_attr case"
# What types does the conversion function think values have?
- values = ops.convert_n_to_tensor_or_indexed_slices(
- values, as_ref=input_arg.is_ref)
+ values = ops.convert_n_to_tensor(values, as_ref=input_arg.is_ref)
observed = ", ".join(v.dtype.base_dtype.name for v in values)
prefix = (
@@ -659,8 +657,7 @@ class OpDefLibrary(object):
input_types=input_types, attrs=attr_protos,
op_def=op_def)
outputs = op.outputs
- return _Restructure(ops.convert_n_to_tensor_or_indexed_slices(outputs),
- output_structure)
+ return _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
else:
return g.create_op(op_type_name, inputs, output_types, name=scope,
input_types=input_types, attrs=attr_protos,
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 69cd7fbc56..c9cbfb1d7d 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -404,6 +404,8 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name
context_dense_types=None,
context_dense_defaults=None,
context_dense_shapes=None,
+ feature_list_sparse_keys=None,
+ feature_list_sparse_types=None,
feature_list_dense_keys=None,
feature_list_dense_types=None,
feature_list_dense_shapes=None,
@@ -461,6 +463,12 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name
map will be treated as empty (zero length) if not found in the
`FeatureList` map.
+ The key `feature_list_sparse_keys[j]` is mapped to a `SparseTensor` of type
+ `feature_list_sparse_types[j]`. This `SparseTensor` represents a ragged
+  vector. Its indices are `[time, index]`, where `time` is the FeatureList
+  entry and `index` is the value's index in the list of values associated
+  with that time.
+
`debug_name` may contain a descriptive name for the corresponding serialized
proto. This may be useful for debugging purposes, but it has no effect on the
output. If not `None`, `debug_name` must be a scalar.
@@ -485,6 +493,12 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name
The shape of the data for each context_dense feature referenced by
`context_dense_keys`. Required for any input tensors identified by
`context_dense_keys` whose shapes are anything other than `[]` or `[1]`.
+ feature_list_sparse_keys: A list of string keys in the `SequenceExample`'s
+ feature_lists. The results for these keys will be returned as
+ `SparseTensor` objects.
+    feature_list_sparse_types: A list of `DTypes`, same length as
+      `feature_list_sparse_keys`.
+ Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+ and `tf.string` (`BytesList`) are supported.
feature_list_dense_keys: A list of string keys in the `SequenceExample`'s
features_lists. The results for these keys will be returned as `Tensor`s.
feature_list_dense_types: A list of `DTypes`, same length as
@@ -528,6 +542,10 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name
context_dense_shapes = (
[[]] * len(context_dense_keys)
if context_dense_shapes is None else context_dense_shapes)
+ feature_list_sparse_keys = (
+ [] if feature_list_sparse_keys is None else feature_list_sparse_keys)
+ feature_list_sparse_types = (
+ [] if feature_list_sparse_types is None else feature_list_sparse_types)
feature_list_dense_keys = (
[] if feature_list_dense_keys is None else feature_list_dense_keys)
feature_list_dense_types = (
@@ -545,6 +563,7 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name
num_context_dense = len(context_dense_keys)
num_feature_list_dense = len(feature_list_dense_keys)
num_context_sparse = len(context_sparse_keys)
+ num_feature_list_sparse = len(feature_list_sparse_keys)
if len(context_dense_shapes) != num_context_dense:
raise ValueError(
@@ -567,15 +586,28 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name
raise ValueError(
"len(context_sparse_types) != len(context_sparse_keys): %d vs. %d"
% (len(context_sparse_types), num_context_sparse))
- if num_context_dense + num_context_sparse + num_feature_list_dense == 0:
+ if len(feature_list_sparse_types) != num_feature_list_sparse:
+ raise ValueError(
+ "len(feature_list_sparse_types) != len(feature_list_sparse_keys): "
+ "%d vs. %d"
+ % (len(feature_list_sparse_types), num_feature_list_sparse))
+ if (num_context_dense + num_context_sparse
+ + num_feature_list_dense + num_feature_list_sparse) == 0:
raise ValueError(
"Must provide at least one context_sparse key, context_dense key, "
- "or feature_list_dense key")
+ ", feature_list_sparse key, or feature_list_dense key")
if not set(context_dense_keys).isdisjoint(set(context_sparse_keys)):
raise ValueError(
- "Context_Dense and context_sparse keys must not intersect; "
+ "context_dense and context_sparse keys must not intersect; "
"intersection: %s" %
set(context_dense_keys).intersection(set(context_sparse_keys)))
+ if not set(feature_list_dense_keys).isdisjoint(
+ set(feature_list_sparse_keys)):
+ raise ValueError(
+ "feature_list_dense and feature_list_sparse keys must not intersect; "
+ "intersection: %s" %
+ set(feature_list_dense_keys).intersection(
+ set(feature_list_sparse_keys)))
if not isinstance(feature_list_dense_defaults, dict):
raise TypeError("feature_list_dense_defaults must be a dict")
for k, v in feature_list_dense_defaults.items():
@@ -613,6 +645,8 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name
context_sparse_types=context_sparse_types,
context_dense_keys=context_dense_keys,
context_dense_shapes=context_dense_shapes,
+ feature_list_sparse_keys=feature_list_sparse_keys,
+ feature_list_sparse_types=feature_list_sparse_types,
feature_list_dense_keys=feature_list_dense_keys,
feature_list_dense_types=feature_list_dense_types,
feature_list_dense_shapes=feature_list_dense_shapes,
@@ -622,7 +656,8 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name
(context_sparse_indices, context_sparse_values,
context_sparse_shapes, context_dense_values,
- feature_list_dense_values) = outputs
+ feature_list_sparse_indices, feature_list_sparse_values,
+ feature_list_sparse_shapes, feature_list_dense_values) = outputs
context_sparse_tensors = [
ops.SparseTensor(ix, val, shape) for (ix, val, shape)
@@ -630,12 +665,18 @@ def parse_single_sequence_example(serialized, # pylint: disable=invalid-name
context_sparse_values,
context_sparse_shapes)]
+ feature_list_sparse_tensors = [
+ ops.SparseTensor(ix, val, shape) for (ix, val, shape)
+ in zip(feature_list_sparse_indices,
+ feature_list_sparse_values,
+ feature_list_sparse_shapes)]
+
context_output = dict(
zip(context_sparse_keys + context_dense_keys,
context_sparse_tensors + context_dense_values))
feature_list_output = dict(
- zip(feature_list_dense_keys,
- feature_list_dense_values))
+ zip(feature_list_sparse_keys + feature_list_dense_keys,
+ feature_list_sparse_tensors + feature_list_dense_values))
return (context_output, feature_list_output)
@@ -651,6 +692,7 @@ def _ParseSingleSequenceExampleShape(op):
num_context_dense = op.get_attr("Ncontext_dense")
num_feature_list_dense = op.get_attr("Nfeature_list_dense")
context_dense_shapes = op.get_attr("context_dense_shapes")
+ num_feature_list_sparse = op.get_attr("Nfeature_list_sparse")
feature_list_dense_shapes = op.get_attr("feature_list_dense_shapes")
context_sparse_index_shapes = [
tensor_shape.matrix(None, 1) for _ in range(num_context_sparse)]
@@ -661,6 +703,12 @@ def _ParseSingleSequenceExampleShape(op):
context_dense_shapes = [
tensor_shape.TensorShape(dense_shape)
for dense_shape in context_dense_shapes]
+ feature_list_sparse_index_shapes = [
+ tensor_shape.matrix(None, 2) for _ in range(num_feature_list_sparse)]
+ feature_list_sparse_value_shapes = [
+ tensor_shape.vector(None) for _ in range(num_feature_list_sparse)]
+ feature_list_sparse_shape_shapes = [
+ tensor_shape.vector(2) for _ in range(num_feature_list_sparse)]
feature_list_dense_shapes = [
tensor_shape.vector(None).concatenate(dense_shape)
for dense_shape in feature_list_dense_shapes]
@@ -668,7 +716,8 @@ def _ParseSingleSequenceExampleShape(op):
assert num_feature_list_dense == len(feature_list_dense_shapes)
return (context_sparse_index_shapes + context_sparse_value_shapes +
context_sparse_shape_shapes + context_dense_shapes +
- feature_list_dense_shapes)
+ feature_list_sparse_index_shapes + feature_list_sparse_value_shapes +
+ feature_list_sparse_shape_shapes + feature_list_dense_shapes)
ops.RegisterShape("StringToNumber")(
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
new file mode 100644
index 0000000000..caf376987f
--- /dev/null
+++ b/tensorflow/python/ops/script_ops.py
@@ -0,0 +1,135 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""## Script Language Operators.
+
+TensorFlow allows you to wrap python/numpy functions as
+TensorFlow operators.
+
+"""
+
+# pylint: disable=g-bad-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import common_shapes
+from tensorflow.python.ops import gen_script_ops
+
+
+class FuncRegistry(object):
+ """A helper class to keep track of registered py functions.
+
+  FuncRegistry keeps a map from unique tokens (string) to python
+  functions that take numpy arrays and output numpy arrays.
+ """
+
+ def __init__(self):
+ self._unique_id = 0
+ self._funcs = {}
+
+ def insert(self, func):
+ """Registers `func` and returns a unique token for this entry."""
+ token = self._next_unique_token()
+ self._funcs[token] = func
+ return token
+
+ def remove(self, token):
+ """Removes the registered function corresponding to `token`."""
+ self._funcs.pop(token, None)
+
+ def __call__(self, token, args):
+ """Calls the registered function for `token` with args."""
+    func = self._funcs.get(token, None)
+    if func is None:
+      raise ValueError("callback %s is not found" % token)
+ return func(*args)
+
+ def size(self):
+ """Returns how many functions are currently registered."""
+ return len(self._funcs)
+
+ def _next_unique_token(self):
+ """Returns a unique token."""
+ uid = self._unique_id
+ self._unique_id += 1
+ return "pyfunc_%d" % uid
+
+# Global registry for py functions.
+_py_funcs = FuncRegistry()
+
+pywrap_tensorflow.InitializePyTrampoline(_py_funcs)
+
+
+class CleanupFunc(object):
+ """A helper class to remove a registered function from _py_funcs."""
+
+ def __init__(self, token):
+ self._token = token
+
+ def __del__(self):
+ _py_funcs.remove(self._token)
+
+
+def py_func(func, inp, Tout, name=None):
+ """Wraps a python function and uses it as a tensorflow op.
+
+  Given a python function `func` that takes numpy arrays as its
+  inputs and returns numpy arrays as its outputs, e.g.,
+
+ def my_func(x):
+ return np.sinh(x)
+  inp = tf.placeholder(tf.float32, ...)
+ y = py_func(my_func, [inp], [tf.float32])
+
+  The above snippet constructs a tf graph which invokes numpy's
+  sinh(x) as an op in the graph.
+
+ Args:
+ func: A python function.
+ inp: A list of `Tensor`.
+ Tout: A list of tensorflow data types indicating what `func`
+ returns.
+ name: A name for the operation (optional).
+
+ Returns:
+ A list of `Tensor` which `func` computes.
+ """
+ token = _py_funcs.insert(func)
+  # We tie the registered function's lifetime to the current
+  # default graph; when the current graph is destroyed, we
+  # should remove its py funcs.
+ cleanup = CleanupFunc(token)
+ g = ops.get_default_graph()
+ # pylint: disable=protected-access
+ #
+  # TODO(zhifengc): Consider adding a Graph method to collect
+  # `cleanup` objects in one of its members.
+ if not hasattr(g, "_cleanup_py_funcs_used_in_graph"):
+ g._cleanup_py_funcs_used_in_graph = []
+
+ # When g is destroyed, elements in _cleanup_py_funcs_used_in_graph
+ # will be destroyed and their __del__ will remove the 'token' from
+ # the funcs registry.
+ g._cleanup_py_funcs_used_in_graph.append(cleanup)
+
+ return gen_script_ops._py_func(input=inp, token=token, Tout=Tout, name=name)
+ # pylint: enable=protected-access
+
+
+ops.RegisterShape("PyFunc")(common_shapes.unknown_shape)
+
+ops.NoGradient("PyFunc")
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 99e7e708f1..1a7ce8e40f 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -310,6 +310,7 @@ def sparse_to_dense(sparse_indices,
output_shape,
sparse_values,
default_value=0,
+ validate_indices=True,
name=None):
"""Converts a sparse representation into a dense tensor.
@@ -329,6 +330,10 @@ def sparse_to_dense(sparse_indices,
All other values in `dense` are set to `default_value`. If `sparse_values`
is a scalar, all sparse indices are set to this single value.
+  Indices should be sorted in lexicographic order and must not contain any
+  repeats. If `validate_indices` is True, these properties are checked
+  during execution.
+
Args:
sparse_indices: A 0-D, 1-D, or 2-D `Tensor` of type `int32` or `int64`.
`sparse_indices[i]` contains the complete index where `sparse_values[i]`
@@ -339,6 +344,8 @@ def sparse_to_dense(sparse_indices,
`sparse_indices`, or a scalar value to be used for all sparse indices.
default_value: A 0-D `Tensor` of the same type as `sparse_values`. Value
to set for indices not specified in `sparse_indices`. Defaults to zero.
+ validate_indices: A boolean value. If True, indices are checked to make
+ sure they are sorted in lexicographic order and that there are no repeats.
name: A name for the operation (optional).
Returns:
@@ -348,11 +355,15 @@ def sparse_to_dense(sparse_indices,
return gen_sparse_ops._sparse_to_dense(sparse_indices,
output_shape,
sparse_values,
- default_value,
+ default_value=default_value,
+ validate_indices=validate_indices,
name=name)
-def sparse_tensor_to_dense(sp_input, default_value=0, name=None):
+def sparse_tensor_to_dense(sp_input,
+ default_value=0,
+ validate_indices=True,
+ name=None):
"""Converts a `SparseTensor` into a dense tensor.
This op is a convenience wrapper around `sparse_to_dense` for `SparseTensor`s.
@@ -370,10 +381,15 @@ def sparse_tensor_to_dense(sp_input, default_value=0, name=None):
[x x x x x]
[c x x x x]]
+  Indices must not contain any repeats. This is only
+  checked if `validate_indices` is `True`.
+
Args:
sp_input: The input `SparseTensor`.
default_value: Scalar value to set for indices not specified in
`sp_input`. Defaults to zero.
+ validate_indices: A boolean value. If `True`, indices are checked to make
+ sure they are sorted in lexicographic order and that there are no repeats.
name: A name prefix for the returned tensors (optional).
Returns:
@@ -390,7 +406,8 @@ def sparse_tensor_to_dense(sp_input, default_value=0, name=None):
return sparse_to_dense(sp_input.indices,
sp_input.shape,
sp_input.values,
- default_value,
+ default_value=default_value,
+ validate_indices=validate_indices,
name=name)
@@ -410,15 +427,18 @@ def sparse_to_indicator(sp_input, vocab_size, name=None):
[0, 0, 0]: 0
[0, 1, 0]: 10
[1, 0, 3]: 103
- [1, 1, 2]: 112
- [1, 1, 3]: 113
+ [1, 1, 2]: 150
+ [1, 1, 3]: 149
+ [1, 1, 4]: 150
[1, 2, 1]: 121
and `vocab_size = 200`, then the output will be a `[2, 3, 200]` dense bool
tensor with False everywhere except at positions
- (0, 0, 0), (0, 1, 10), (1, 0, 103), (1, 1, 112), (1, 1, 113), (1, 2, 121).
+ (0, 0, 0), (0, 1, 10), (1, 0, 103), (1, 1, 149), (1, 1, 150),
+ (1, 2, 121).
+ Note that repeats are allowed in the input SparseTensor.
This op is useful for converting `SparseTensor`s into dense formats for
compatibility with ops that expect dense tensors.
@@ -460,7 +480,10 @@ def sparse_to_indicator(sp_input, vocab_size, name=None):
sp_new = ops.SparseTensor(new_indices, new_values, new_shape)
- return sparse_tensor_to_dense(sp_new, False, name=name)
+ # validate_indices may be False because we allow duplicates in new_indices:
+ # repeated indices are allowed when creating an indicator matrix.
+ return sparse_tensor_to_dense(
+ sp_new, default_value=False, validate_indices=False, name=name)
def sparse_retain(sp_input, to_retain):
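
A sketch of the new validate_indices switch, mirroring the test added to
sparse_to_dense_op_py_test.py above; repeated indices fail at run time unless
validation is disabled:

    import tensorflow as tf

    with tf.Session() as sess:
      # Default validation: evaluating this op raises, because index [1]
      # repeats ("not lexicographically sorted or containing repeats").
      strict = tf.sparse_to_dense([[1], [1]], [5], [-1.0, 1.0],
                                  default_value=0.0)
      # With validation off, the duplicate indices are accepted.
      relaxed = tf.sparse_to_dense([[1], [1]], [5], [-1.0, 1.0],
                                   default_value=0.0, validate_indices=False)
      print(sess.run(relaxed))
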
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 2075e3c913..e2180737df 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -42,6 +42,7 @@ from tensorflow.python.ops.math_ops import *
from tensorflow.python.ops.numerics import *
from tensorflow.python.ops.parsing_ops import *
from tensorflow.python.ops.random_ops import *
+from tensorflow.python.ops.script_ops import py_func
from tensorflow.python.ops.sparse_ops import *
from tensorflow.python.ops.state_ops import assign
from tensorflow.python.ops.state_ops import assign_add
diff --git a/tensorflow/python/ops/summary_ops.py b/tensorflow/python/ops/summary_ops.py
index 800ab7bc7e..dc6d2b17f2 100644
--- a/tensorflow/python/ops/summary_ops.py
+++ b/tensorflow/python/ops/summary_ops.py
@@ -165,8 +165,8 @@ def scalar_summary(tags, values, collections=None, name=None):
summary has a summary value for each tag-value pair in `tags` and `values`.
Args:
- tags: A 1-D `string` `Tensor`. Tags for the summaries.
- values: A 1-D `float32` or `float64` Tensor. Values for the summaries.
+ tags: A `string` `Tensor`. Tags for the summaries.
+ values: A real numeric Tensor. Values for the summaries.
collections: Optional list of graph collections keys. The new summary op is
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
name: A name for the operation (optional).
diff --git a/tensorflow/python/platform/default/_flags.py b/tensorflow/python/platform/default/_flags.py
index d7ae189c21..4e84623c79 100644
--- a/tensorflow/python/platform/default/_flags.py
+++ b/tensorflow/python/platform/default/_flags.py
@@ -32,7 +32,7 @@ class _FlagValues(object):
self.__dict__['__parsed'] = False
def _parse_flags(self):
- result = _global_parser.parse_args()
+ result, _ = _global_parser.parse_known_args()
for flag_name, val in vars(result).items():
self.__dict__['__flags'][flag_name] = val
self.__dict__['__parsed'] = True
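
The switch to parse_known_args above makes the flag parser tolerate argv
entries it does not define, which the updated flags_test below relies on; a
standalone argparse sketch of the difference:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--bool_a", action="store_true")
    # parse_known_args returns (namespace, leftover_argv) instead of erroring
    # out on flags it does not know about.
    result, unparsed = parser.parse_known_args(
        ["--bool_a", "--unknown_flag", "and_argument"])
    assert result.bool_a and unparsed == ["--unknown_flag", "and_argument"]
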
diff --git a/tensorflow/python/platform/default/flags_test.py b/tensorflow/python/platform/default/flags_test.py
index 3868576c2f..e6cd57d5a9 100644
--- a/tensorflow/python/platform/default/flags_test.py
+++ b/tensorflow/python/platform/default/flags_test.py
@@ -86,7 +86,8 @@ class FlagsTest(googletest.TestCase):
if __name__ == "__main__":
# Test command lines
sys.argv.extend(["--bool_a", "--nobool_negation", "--bool_c=True",
- "--bool_d=False", "--bool_e=gibberish"])
+ "--bool_d=False", "--bool_e=gibberish", "--unknown_flag",
+ "and_argument"])
# googletest.main() tries to interpret the above flags, so use the
# direct functions instead.
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index ce9770b3f4..65ea4d2e17 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -19,6 +19,7 @@ limitations under the License.
%include "tensorflow/python/util/port.i"
+%include "tensorflow/python/lib/core/py_func.i"
%include "tensorflow/python/lib/core/status.i"
%include "tensorflow/python/lib/core/status_helper.i"
@@ -27,3 +28,5 @@ limitations under the License.
%include "tensorflow/python/client/events_writer.i"
%include "tensorflow/python/client/tf_session.i"
+
+%include "tensorflow/python/framework/python_op_gen.i"
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py
index efd6f2a807..e7510fe325 100644
--- a/tensorflow/python/training/coordinator.py
+++ b/tensorflow/python/training/coordinator.py
@@ -256,3 +256,97 @@ class Coordinator(object):
elif stragglers:
raise RuntimeError("Coordinator stopped with threads still running: %s",
" ".join(stragglers))
+
+
+# Threads for the standard services.
+class LooperThread(threading.Thread):
+ """A thread that runs code repeatedly, optionally on a timer.
+
+ This thread class is intended to be used with a `Coordinator`. It repeatedly
+ runs code specified either as `target` and `args` or by the `run_loop()`
+ method.
+
+ Before each run the thread checks if the coordinator has requested stop. In
+ that case the looper thread terminates immediately.
+
+ If the code being run raises an exception, that exception is reported to the
+ coordinator and the thread terminates. The coordinator will then request all
+ the other threads it coordinates to stop.
+
+ You typically pass looper threads to the supervisor `Join()` method.
+ """
+
+ def __init__(self, coord, timer_interval_secs, target=None, args=None):
+ """Create a LooperThread.
+
+ Args:
+ coord: a Coordinator.
+      timer_interval_secs: Time boundaries at which to call `run_loop()`, or
+        None if it should be called back to back.
+ target: Optional callable object that will be executed in the thread.
+ args: Optional arguments to pass to `target` when calling it.
+
+ Raises:
+ ValueError: If one of the arguments is invalid.
+ """
+ if not isinstance(coord, Coordinator):
+ raise ValueError("'coord' argument must be a Coordinator: %s" % coord)
+ super(LooperThread, self).__init__()
+ self.daemon = True
+ self._coord = coord
+ self._timer_interval_secs = timer_interval_secs
+ self._target = target
+ if self._target:
+ if args is None:
+ self._args = ()
+ else:
+ self._args = args
+ elif args:
+ raise ValueError("'args' argument require that you also pass 'target'")
+
+ @staticmethod
+ def loop(coord, timer_interval_secs, target, args=None):
+ """Start a LooperThread that calls a function periodically.
+
+ If `timer_interval_secs` is None the thread calls `target(args)`
+ repeatedly. Otherwise `target(args)` is called every `timer_interval_secs`
+ seconds. The thread terminates when a stop of the coordinator is
+ requested.
+
+ Args:
+ coord: A Coordinator.
+ timer_interval_secs: Number. Time boundaries at which to call `target`.
+ target: A callable object.
+ args: Optional arguments to pass to `target` when calling it.
+
+ Returns:
+ The started thread.
+ """
+ looper = LooperThread(coord, timer_interval_secs, target=target, args=args)
+ looper.start()
+ return looper
+
+ # pylint: disable=broad-except
+ def run(self):
+ with self._coord.stop_on_exception():
+ self.start_loop()
+ if self._timer_interval_secs is None:
+ # Call back-to-back.
+ while not self._coord.should_stop():
+ self.run_loop()
+ else:
+ # Next time at which to call run_loop(), starts as 'now'.
+ next_timer_time = time.time()
+ while not self._coord.wait_for_stop(next_timer_time - time.time()):
+ next_timer_time += self._timer_interval_secs
+ self.run_loop()
+ # pylint: enable=broad-except
+
+ def start_loop(self):
+ """Called when the thread starts."""
+ pass
+
+ def run_loop(self):
+ """Called at 'timer_interval_secs' boundaries."""
+ if self._target:
+ self._target(*self._args)
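
A small usage sketch for the new LooperThread (exported as
tf.train.LooperThread in training.py below); the loop stops once the
coordinator requests a stop:

    import time
    import tensorflow as tf

    coord = tf.train.Coordinator()
    ticks = []
    looper = tf.train.LooperThread.loop(coord, 0.1, target=ticks.append,
                                        args=(1,))
    time.sleep(0.5)
    coord.request_stop()
    coord.join([looper])
    print(len(ticks))  # a handful of iterations ran before the stop
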
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 6b70ddae3e..2091687e7c 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -44,7 +44,7 @@ class Optimizer(object):
# Add Ops to the graph to minimize a cost by updating a list of variables.
# "cost" is a Tensor, and the list of variables contains tf.Variable
# objects.
- opt_op = opt.minimize(cost, <list of variables>)
+ opt_op = opt.minimize(cost, var_list=<list of variables>)
```
In the training program you will just have to run the returned Op.
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 126f520328..34db5c3cd3 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -813,9 +813,8 @@ class Saver(object):
last_checkpoints: A list of checkpoint filenames.
Raises:
- AssertionError: If the list of checkpoint filenames has already been set.
+ AssertionError: If last_checkpoints is not a list.
"""
- assert not self._last_checkpoints
assert isinstance(last_checkpoints, list)
# We use a timestamp of +inf so that this checkpoint will never be
# deleted. This is both safe and backwards compatible to a previous
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index 737c980c6c..b9f5f9a54b 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -139,6 +139,7 @@ from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
# Utility classes for training.
from tensorflow.python.training.coordinator import Coordinator
+from tensorflow.python.training.coordinator import LooperThread
from tensorflow.python.training.queue_runner import *
# For the module level doc.
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
index 6c94856772..c88dc88d29 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.cc
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -319,7 +319,7 @@ void PopContextAndCheckNowNull(CUcontext expected) {
CUcontext popped;
CHECK_EQ(CUDA_SUCCESS, dynload::cuCtxPopCurrent_v2(&popped));
CHECK_EQ(expected, popped);
- CHECK(nullptr == CurrentContext());
+ DCHECK(nullptr == CurrentContext());
VLOG(3) << "popped context " << expected
<< " and current context is now null";
}
@@ -395,7 +395,7 @@ ScopedActivateContext::ScopedActivateContext(CUcontext context,
ScopedActivateContext::~ScopedActivateContext() {
if (tls_in_multi_op_activation.get()) {
- CHECK_EQ(context_, CurrentContext());
+ DCHECK_EQ(context_, CurrentContext());
if (FLAGS_gpuexec_cuda_sync_around_driver_calls) {
auto res = dynload::cuCtxSynchronize();
if (res != CUDA_SUCCESS) {
@@ -470,7 +470,7 @@ static port::Status InternalInit() {
LOG(ERROR) << "injecting CUDA init error; initialization will fail";
} else if (internal::CachedDsoLoader::GetLibcudaDsoHandle().ok()) {
// We only call cuInit if we can dynload libcuda.
-
+
res = dynload::cuInit(0 /* = flags */);
}
@@ -570,7 +570,7 @@ bool DeviceOptionsToContextFlags(DeviceOptions device_options, int *flags) {
{
// TODO(leary) Need to see if NVIDIA can expunge the leakiness in their
// context creation: see http://b/13248943
-
+
res = dynload::cuCtxCreate_v2(context, flags, device);
}
if (res == CUDA_SUCCESS) {
@@ -737,7 +737,7 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
{
// TODO(leary) Need to see if NVIDIA can expunge the leakiness in their
// module loading: see http://b/13248943
-
+
res = dynload::cuModuleLoadDataEx(module, ptx_data, ARRAYSIZE(options),
options, option_values);
}
diff --git a/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts b/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts
index b3511ce7b6..c3e23c760b 100644
--- a/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts
+++ b/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts
@@ -25,23 +25,23 @@ module TF {
function router(route: string): ((tag: string, run: string) => string) {
return function(tag: string, run: string): string {
- return "/" + route + "?tag=" + encodeURIComponent(tag)
+ return "/data/" + route + "?tag=" + encodeURIComponent(tag)
+ "&run=" + encodeURIComponent(run);
};
}
export function runsUrl() {
- return "/runs";
+ return "/data/runs";
}
export var scalarsUrl = router("scalars");
export var histogramsUrl = router("histograms");
export var compressedHistogramsUrl = router("compressedHistograms");
export var imagesUrl = router("images");
export function individualImageUrl(query: string) {
- return "/individualImage?" + query;
+ return "/data/individualImage?" + query;
}
export function graphUrl(run: string) {
- return "/graph?run=" + encodeURIComponent(run);
+ return "/data/graph?run=" + encodeURIComponent(run);
}
}
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/common.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/common.ts
index 17b753a2f9..66ed9fc84a 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/common.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/common.ts
@@ -18,17 +18,19 @@ limitations under the License.
declare module graphlib {
interface GraphOptions {
- name: string;
+ name?: string;
/**
* Direction for rank nodes. Can be TB, BT, LR, or RL, where T = top,
* B = bottom, L = left, and R = right.
*/
- rankdir: string;
- type: string|number;
+ rankdir?: string;
+ type?: string|number;
/** Number of pixels between each rank in the layout. */
ranksep?: number;
/** Number of pixels that separate nodes horizontally in the layout. */
nodesep?: number;
+ /** Number of pixels that separate edges horizontally in the layout */
+ edgesep?: number;
}
export interface EdgeObject {
@@ -58,7 +60,10 @@ declare module graphlib {
edges(): EdgeObject[];
outEdges(name: string): E[];
inEdges(name: string): E[];
- /** Returns those nodes in the graph that have no in-edges. Takes O(|V|) time. */
+ /**
+ * Returns those nodes in the graph that have no in-edges.
+ * Takes O(|V|) time.
+ */
sources(): string[];
/**
* Remove the node with the id v in the graph or do nothing if
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts
index 41f00c54f4..a9b2cf1934 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts
@@ -19,7 +19,6 @@ module tf.graph {
/** Delimiter used in node names to denote namespaces. */
export const NAMESPACE_DELIM = "/";
-const FULL_GRAPH_NAME = "fullGraph";
export const ROOT_NAME = "__root__";
// Separator between the source and the destination name of the edge.
@@ -315,8 +314,8 @@ class OpNodeImpl implements OpNode {
* @param rawNode The raw node.
* @param normalizedInputs An array of normalized
* inputs that denote the incoming edges to the current node. Each input
- * contains the normalized name of the source node, whether it has a number
- * part and whether it is a control dependency.
+ * contains the normalized name of the source node, whether it has a
+ * number part and whether it is a control dependency.
*/
constructor(rawNode: tf.TFNode, normalizedInputs: NormalizedInput[]) {
this.op = rawNode.op;
@@ -340,8 +339,8 @@ export function createMetanode(name: string, opt = {}): Metanode {
}
/**
- * Joins the information from the stats file (memory, compute time) with the graph
- * information.
+ * Joins the information from the stats file (memory, compute time) with the
+ * graph information.
*/
export function joinStatsInfoWithGraph(graph: SlimGraph,
statsJson: TFStats): void {
@@ -894,7 +893,8 @@ export function hasSimilarDegreeSequence(graph1: graphlib.Graph<any, any>,
/**
* Returns the hierarchical path of the current node, based on the node's name.
- * For example, if the name is 'a/b/c', the returned path is ['a', 'a/b', 'a/b/c'].
+ * For example, if the name is 'a/b/c', the returned path is
+ * ['a', 'a/b', 'a/b/c'].
*/
export function getHierarchicalPath(name: string,
seriesNames?: { [name: string]: string }): string[] {
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
index 1c8e1b2e18..504e47f4d5 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts
@@ -21,8 +21,6 @@ limitations under the License.
*/
module tf.graph.hierarchy {
-const LOG_PREFIX_MSG = "Graph hierarchy: ";
-
/**
* Class used as output for getPredecessors and getSuccessors methods
*/
@@ -469,7 +467,8 @@ function addNodes(h: Hierarchy, graph: SlimGraph) {
}
parent = child;
}
- // Assuming node name is 'a/b/c', assign the OpNode as a child of the metanode 'a/b'.
+ // Assuming node name is 'a/b/c', assign the OpNode as a child of the
+ // metanode 'a/b'.
h.setNode(node.name, node);
node.parentNode = parent;
parent.metagraph.setNode(node.name, node);
@@ -567,7 +566,8 @@ function addEdges(h: Hierarchy, graph: SlimGraph,
* @param hierarchy
* @param threshold If the series has this many nodes or more, then group them
* into a series.
- * @return A dictionary from node name to series node name that contains the node
+ * @return A dictionary from node name to series node name that contains the
+ * node.
*/
function groupSeries(metanode: Metanode, hierarchy: Hierarchy,
seriesNames: { [name: string]: string }, threshold: number) {
@@ -589,9 +589,6 @@ function groupSeries(metanode: Metanode, hierarchy: Hierarchy,
if (nodeMemberNames.length < threshold) {
return;
}
- let firstMember = seriesNode.metagraph.node(nodeMemberNames[0]);
- let seriesType = firstMember.type;
-
hierarchy.setNode(seriesName, seriesNode); // add to the index
metagraph.setNode(seriesName, seriesNode);
_.each(nodeMemberNames, n => {
@@ -620,7 +617,8 @@ function groupSeries(metanode: Metanode, hierarchy: Hierarchy,
function clusterNodes(metagraph: graphlib.Graph<GroupNode|OpNode, Metaedge>):
{[clusterId: string]: string[]} {
let result: {[clusterId: string]: string[]} = {};
- return _.reduce(metagraph.nodes(), function(clusters: {[clusterId: string]: string[]}, n: string) {
+ return _.reduce(metagraph.nodes(),
+ (clusters: {[clusterId: string]: string[]}, n: string) => {
let child = metagraph.node(n);
if (child.type === NodeType.META) {
// skip metanodes
@@ -702,7 +700,8 @@ function detectSeries(clusters: {[clusterId: string]: string[]},
let seriesNodes = [seriesInfoArray[0]];
for (let index = 1; index < seriesInfoArray.length; index++) {
let nextNode = seriesInfoArray[index];
- if (nextNode.clusterId === seriesNodes[seriesNodes.length - 1].clusterId + 1) {
+ if (nextNode.clusterId === seriesNodes[seriesNodes.length - 1].clusterId
+ + 1) {
seriesNodes.push(nextNode);
continue;
}
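The loop above grows a series as long as the next node's clusterId is exactly
one greater than the last. A hedged standalone sketch of that run-splitting,
with hypothetical names:

    // Consecutive ids stay in one series; any gap starts a new one.
    function splitIntoRuns(ids: number[]): number[][] {
      let runs: number[][] = [];
      let current: number[] = [];
      for (let id of ids) {
        if (current.length === 0 || id === current[current.length - 1] + 1) {
          current.push(id);
        } else {
          runs.push(current);
          current = [id];
        }
      }
      if (current.length > 0) {
        runs.push(current);
      }
      return runs;
    }
    // splitIntoRuns([0, 1, 2, 5, 6]) -> [[0, 1, 2], [5, 6]]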
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts
index 5a0559627f..b003e33177 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts
@@ -33,14 +33,19 @@ export const PARAMS = {
*
* See https://github.com/cpettitt/dagre/wiki#configuring-the-layout
*/
- nodeSep: 110,
+ nodeSep: 5,
/**
* Dagre's ranksep param - number of pixels
* between each rank in the layout.
*
* See https://github.com/cpettitt/dagre/wiki#configuring-the-layout
*/
- rankSep: 25
+ rankSep: 25,
+ /**
+ * Dagre's edgesep param - number of pixels that separate
+ * edges horizontally in the layout.
+ */
+ edgeSep: 5,
},
/** Graph parameter for metanode. */
series: {
@@ -50,7 +55,7 @@ export const PARAMS = {
*
* See https://github.com/cpettitt/dagre/wiki#configuring-the-layout
*/
- nodeSep: 90,
+ nodeSep: 5,
/**
* Dagre's ranksep param - number of pixels
* between each rank in the layout.
@@ -58,6 +63,11 @@ export const PARAMS = {
* See https://github.com/cpettitt/dagre/wiki#configuring-the-layout
*/
rankSep: 25,
+ /**
+ * Dagre's edgesep param - number of pixels that separate
+ * edges horizontally in the layout.
+ */
+ edgeSep: 5
},
/**
* Padding is used to correctly position the graph SVG inside of its parent
@@ -166,6 +176,10 @@ export const PARAMS = {
}
},
annotations: {
+ /** Maximum possible width of the bounding box for in-annotations. */
+ inboxWidth: 50,
+ /** Maximum possible width of the bounding box for out-annotations. */
+ outboxWidth: 50,
/** X-space between the shape and each annotation-node. */
xOffset: 10,
/** Y-space between each annotation-node. */
@@ -202,7 +216,7 @@ export const PARAMS = {
};
/** Calculate layout for a scene of a group node. */
-export function scene(renderNodeInfo: render.RenderGroupNodeInformation)
+export function layoutScene(renderNodeInfo: render.RenderGroupNodeInfo)
: void {
// Update layout, size, and annotations of its children nodes and edges.
if (renderNodeInfo.node.isGroupNode) {
@@ -218,9 +232,32 @@ export function scene(renderNodeInfo: render.RenderGroupNodeInformation)
};
/**
+ * Updates the total width of an unexpanded node to include the size of its
+ * in- and out-annotations.
+ */
+function updateTotalWidthOfNode(renderInfo: render.RenderNodeInfo): void {
+ renderInfo.inboxWidth = renderInfo.inAnnotations.list.length > 0 ?
+ PARAMS.annotations.inboxWidth : 0;
+ renderInfo.outboxWidth = renderInfo.outAnnotations.list.length > 0 ?
+ PARAMS.annotations.outboxWidth : 0;
+ // Assign the width and height of the core box (the main shape of the node).
+ renderInfo.coreBox.width = renderInfo.width;
+ renderInfo.coreBox.height = renderInfo.height;
+ // TODO(jimbo): Account for font width rather than using a magic number.
+ let labelLength = renderInfo.node.name.length -
+ renderInfo.node.name.lastIndexOf(NAMESPACE_DELIM) - 1;
+ let charWidth = 3; // 3 pixels per character.
+ // Compute the total width of the node.
+ renderInfo.width = Math.max(renderInfo.coreBox.width +
+ renderInfo.inboxWidth + renderInfo.outboxWidth,
+ labelLength * charWidth);
+}
+
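A worked instance of the width rule above, with assumed numbers: for a node
named "ns/conv1" with coreBox.width = 30, one in-annotation (inboxWidth = 50)
and no out-annotations (outboxWidth = 0), labelLength is "conv1".length = 5,
so width = max(30 + 50 + 0, 5 * 3) = max(80, 15) = 80.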
+/**
* Update layout, size, and annotations of its children nodes and edges.
*/
-function layoutChildren(renderNodeInfo: render.RenderGroupNodeInformation)
+function layoutChildren(renderNodeInfo: render.RenderGroupNodeInfo)
: void {
let children = renderNodeInfo.coreGraph.nodes().map(n => {
return renderNodeInfo.coreGraph.node(n);
@@ -238,25 +275,25 @@ function layoutChildren(renderNodeInfo: render.RenderGroupNodeInformation)
break;
case NodeType.META:
if (!childNodeInfo.expanded) {
- // set fixed width and scalable height based on cardinality
+ // Set fixed width and scalable height based on cardinality
_.extend(childNodeInfo, PARAMS.nodeSize.meta);
childNodeInfo.height =
PARAMS.nodeSize.meta.height(childNodeInfo.node.cardinality);
} else {
let childGroupNodeInfo =
- <render.RenderGroupNodeInformation>childNodeInfo;
- scene(childGroupNodeInfo); // Recursively layout its subscene.
+ <render.RenderGroupNodeInfo>childNodeInfo;
+ layoutScene(childGroupNodeInfo); // Recursively layout its subscene.
}
break;
case NodeType.SERIES:
if (childNodeInfo.expanded) {
_.extend(childNodeInfo, PARAMS.nodeSize.series.expanded);
let childGroupNodeInfo =
- <render.RenderGroupNodeInformation>childNodeInfo;
- scene(childGroupNodeInfo); // Recursively layout its subscene.
+ <render.RenderGroupNodeInfo>childNodeInfo;
+ layoutScene(childGroupNodeInfo); // Recursively layout its subscene.
} else {
let childGroupNodeInfo =
- <render.RenderGroupNodeInformation>childNodeInfo;
+ <render.RenderGroupNodeInfo>childNodeInfo;
let seriesParams =
childGroupNodeInfo.node.hasNonControlEdges ?
PARAMS.nodeSize.series.vertical :
@@ -267,7 +304,11 @@ function layoutChildren(renderNodeInfo: render.RenderGroupNodeInformation)
default:
throw Error("Unrecognized node type: " + childNodeInfo.node.type);
}
-
+ // Compute the total width of unexpanded nodes. The width of expanded
+ // nodes has already been computed.
+ if (!childNodeInfo.expanded) {
+ updateTotalWidthOfNode(childNodeInfo);
+ }
// Layout each child's annotations
layoutAnnotation(childNodeInfo);
});
@@ -279,13 +320,14 @@ function layoutChildren(renderNodeInfo: render.RenderGroupNodeInformation)
* @param params layout parameters
* @return width and height of the core graph
*/
-function dagreLayout(graph: graphlib.Graph<any, any>, params)
- : {height: number, width: number} {
+function dagreLayout(
+ graph: graphlib.Graph<render.RenderNodeInfo, render.RenderMetaedgeInfo>,
+ params): {height: number, width: number} {
_.extend(graph.graph(), {
- nodeSep: params.nodeSep,
- rankSep: params.rankSep
- });
-
+ nodesep: params.nodeSep,
+ ranksep: params.rankSep,
+ edgesep: params.edgeSep
+ });
let bridgeNodeNames = [];
let nonBridgeNodeNames = [];
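Note that dagre expects the lowercase keys nodesep/ranksep/edgesep on the
graph label, which is why the camelCase PARAMS values are remapped above. A
minimal standalone sketch of a dagre invocation, assuming the dagre ~0.7
graphlib API:

    // For illustration only; the real code reuses the render graph.
    let g = new dagre.graphlib.Graph();
    g.setGraph({nodesep: 5, ranksep: 25, edgesep: 5});
    g.setDefaultEdgeLabel(() => ({}));
    g.setNode("a", {width: 50, height: 20});
    g.setNode("b", {width: 50, height: 20});
    g.setEdge("a", "b");
    dagre.layout(g);
    // dagre assigns positions: g.node("a").x / .y, g.edge("a", "b").points.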
@@ -307,11 +349,8 @@ function dagreLayout(graph: graphlib.Graph<any, any>, params)
height: 0,
};
}
-
dagre.layout(graph);
- let graphLabel = graph.graph();
-
// Calculate the true bounding box of the graph by iterating over nodes and
// edges rather than accepting dagre's word for it. In particular, we should
// ignore the extra-wide bridge nodes and bridge edges, and allow for
@@ -323,33 +362,65 @@ function dagreLayout(graph: graphlib.Graph<any, any>, params)
_.each(nonBridgeNodeNames, nodeName => {
let nodeInfo = graph.node(nodeName);
let w = 0.5 * nodeInfo.width;
- let x1 = nodeInfo.x - w - nodeInfo.inboxWidth;
- let x2 = nodeInfo.x + w + nodeInfo.outboxWidth;
+ let x1 = nodeInfo.x - w;
+ let x2 = nodeInfo.x + w;
minX = x1 < minX ? x1 : minX;
maxX = x2 > maxX ? x2 : maxX;
- let labelLength =
- nodeName.length - nodeName.lastIndexOf(NAMESPACE_DELIM);
- // TODO(jimbo): Account for font width rather than using a magic number.
- let charWidth = 3; // 3 pixels per character.
- let lw = 0.5 * labelLength * charWidth;
- let lx1 = nodeInfo.x - lw;
- let lx2 = nodeInfo.x + lw;
- minX = lx1 < minX ? lx1 : minX;
- maxX = lx2 > maxX ? lx2 : maxX;
// TODO(jimbo): Account for the height of labels above op nodes here.
- let h = 0.5 * nodeInfo.outerHeight;
+ let h = 0.5 * nodeInfo.height;
let y1 = nodeInfo.y - h;
let y2 = nodeInfo.y + h;
minY = y1 < minY ? y1 : minY;
maxY = y2 > maxY ? y2 : maxY;
});
_.each(graph.edges(), edgeObj => {
- let renderMetaedgeInfo = graph.edge(edgeObj);
- if (renderMetaedgeInfo.structural) {
+ let edgeInfo = graph.edge(edgeObj);
+ if (edgeInfo.structural) {
return; // Skip structural edges from min/max calculations.
}
- _.each(renderMetaedgeInfo.points,
- (point: { x: number, y: number }) => {
+
+ // Since the node size passed to dagre includes the in and out
+ // annotations, the endpoints of the edge produced by dagre may not
+ // point to the actual node shape (rectangle, ellipse). We correct the
+ // end-points by finding the intersection of a line between the
+ // next-to-last (next-to-first) point and the destination (source)
+ // rectangle.
+ let sourceNode = graph.node(edgeInfo.metaedge.v);
+ let destNode = graph.node(edgeInfo.metaedge.w);
+
+ // Straight 3-point edges are a special case, since they would become
+ // curved after our default correction. To keep them straight, we remove
+ // the middle point and correct the first and the last point to be the
+ // center of the source and destination node respectively.
+ if (edgeInfo.points.length === 3 && isStraightLine(edgeInfo.points)) {
+ if (sourceNode != null) {
+ let cxSource = sourceNode.expanded ?
+ sourceNode.x : computeCXPositionOfNodeShape(sourceNode);
+ edgeInfo.points[0].x = cxSource;
+ }
+ if (destNode != null) {
+ let cxDest = destNode.expanded ?
+ destNode.x : computeCXPositionOfNodeShape(destNode);
+ edgeInfo.points[2].x = cxDest;
+ }
+ // Remove the middle point so the edge doesn't curve.
+ edgeInfo.points = [edgeInfo.points[0], edgeInfo.points[2]];
+ }
+ // Correct the destination endpoint of the edge.
+ let nextToLastPoint = edgeInfo.points[edgeInfo.points.length - 2];
+ // The destination node might be null if this is a bridge edge.
+ if (destNode != null) {
+ edgeInfo.points[edgeInfo.points.length - 1] =
+ intersectPointAndNode(nextToLastPoint, destNode);
+ }
+ // Correct the source endpoint of the edge.
+ let secondPoint = edgeInfo.points[1];
+ // The source might be null if this is a bridge edge.
+ if (sourceNode != null) {
+ edgeInfo.points[0] = intersectPointAndNode(secondPoint, sourceNode);
+ }
+
+ _.each(edgeInfo.points, (point: render.Point) => {
minX = point.x < minX ? point.x : minX;
maxX = point.x > maxX ? point.x : maxX;
minY = point.y < minY ? point.y : minY;
@@ -365,8 +436,7 @@ function dagreLayout(graph: graphlib.Graph<any, any>, params)
nodeInfo.y -= minY;
});
_.each(graph.edges(), edgeObj => {
- _.each(graph.edge(edgeObj).points,
- (point: { x: number, y: number }) => {
+ _.each(graph.edge(edgeObj).points, (point: render.Point) => {
point.x -= minX;
point.y -= minY;
});
@@ -374,16 +444,15 @@ function dagreLayout(graph: graphlib.Graph<any, any>, params)
return {
width: maxX - minX,
- height: maxY - minY,
+ height: maxY - minY
};
}
-/** Layout a metanode. */
-function layoutMetanode(renderNodeInfo): void {
+/** Layout a metanode. Only called for an expanded node. */
+function layoutMetanode(renderNodeInfo: render.RenderGroupNodeInfo): void {
// First, copy params specific to meta nodes onto this render info object.
let params = PARAMS.subscene.meta;
- renderNodeInfo = _.extend(renderNodeInfo, params);
-
+ _.extend(renderNodeInfo, params);
// Invoke dagre.layout() on the core graph and record the bounding box
// dimensions.
_.extend(renderNodeInfo.coreBox,
@@ -392,70 +461,70 @@ function layoutMetanode(renderNodeInfo): void {
// Calculate the position of nodes in isolatedInExtract relative to the
// top-left corner of inExtractBox (the bounding box for all inExtract nodes)
// and calculate the size of the inExtractBox.
- let hasInExtract = renderNodeInfo.isolatedInExtract.length > 0;
-
- renderNodeInfo.inExtractBox.width = hasInExtract ?
- _(renderNodeInfo.isolatedInExtract).pluck("outerWidth").max() : 0;
+ let maxInExtractWidth = _.max(renderNodeInfo.isolatedInExtract,
+ renderNode => renderNode.width).width;
+ renderNodeInfo.inExtractBox.width = maxInExtractWidth != null ?
+ maxInExtractWidth : 0;
renderNodeInfo.inExtractBox.height =
- _.reduce(renderNodeInfo.isolatedInExtract, (height, child: any, i) => {
+ _.reduce(renderNodeInfo.isolatedInExtract, (height, child, i) => {
let yOffset = i > 0 ? params.extractYOffset : 0;
- // use outerWidth/Height here to avoid overlaps between extracts
- child.x = renderNodeInfo.inExtractBox.width / 2;
- child.y = height + yOffset + child.outerHeight / 2;
- return height + yOffset + child.outerHeight;
+ // use width/height here to avoid overlaps between extracts
+ child.x = 0;
+ child.y = height + yOffset + child.height / 2;
+ return height + yOffset + child.height;
}, 0);
// Calculate the position of nodes in isolatedOutExtract relative to the
// top-left corner of outExtractBox (the bounding box for all outExtract
// nodes) and calculate the size of the outExtractBox.
- let hasOutExtract = renderNodeInfo.isolatedOutExtract.length > 0;
- renderNodeInfo.outExtractBox.width = hasOutExtract ?
- _(renderNodeInfo.isolatedOutExtract).pluck("outerWidth").max() : 0;
+ let maxOutExtractWidth = _.max(renderNodeInfo.isolatedOutExtract,
+ renderNode => renderNode.width).width;
+ renderNodeInfo.outExtractBox.width = maxOutExtractWidth != null ?
+ maxOutExtractWidth : 0;
renderNodeInfo.outExtractBox.height =
- _.reduce(renderNodeInfo.isolatedOutExtract, (height, child: any, i) => {
+ _.reduce(renderNodeInfo.isolatedOutExtract, (height, child, i) => {
let yOffset = i > 0 ? params.extractYOffset : 0;
- // use outerWidth/Height here to avoid overlaps between extracts
- child.x = renderNodeInfo.outExtractBox.width / 2;
- child.y = height + yOffset + child.outerHeight / 2;
- return height + yOffset + child.outerHeight;
+ // use width/height here to avoid overlaps between extracts
+ child.x = 0;
+ child.y = height + yOffset + child.height / 2;
+ return height + yOffset + child.height;
}, 0);
+ // Add the in-extract and out-extract width to the core box width.
+ renderNodeInfo.coreBox.width += renderNodeInfo.inExtractBox.width +
+ renderNodeInfo.outExtractBox.width;
+ renderNodeInfo.coreBox.height =
+ params.labelHeight +
+ Math.max(
+ renderNodeInfo.inExtractBox.height,
+ renderNodeInfo.coreBox.height,
+ renderNodeInfo.outExtractBox.height
+ );
// Determine the whole metanode's width (from left to right).
- renderNodeInfo.width =
- params.paddingLeft + renderNodeInfo.coreBox.width + params.paddingRight +
- (hasInExtract ?
- renderNodeInfo.inExtractBox.width + params.extractXOffset : 0) +
- (hasOutExtract ?
- params.extractXOffset + renderNodeInfo.outExtractBox.width : 0);
-
- // TODO(jimbo): Remove labelHeight and instead incorporate into box sizes.
+ renderNodeInfo.width = renderNodeInfo.coreBox.width +
+ params.paddingLeft + params.paddingRight;
+
// Determine the whole metanode's height (from top to bottom).
renderNodeInfo.height =
- renderNodeInfo.labelHeight +
- params.paddingTop +
- Math.max(
- renderNodeInfo.inExtractBox.height,
- renderNodeInfo.coreBox.height,
- renderNodeInfo.outExtractBox.height
- ) +
- params.paddingBottom;
+ renderNodeInfo.paddingTop +
+ renderNodeInfo.coreBox.height +
+ renderNodeInfo.paddingBottom;
}
/**
* Calculate layout for series node's core graph. Only called for an expanded
* series.
*/
-function layoutSeriesNode(node: render.RenderGroupNodeInformation): void {
+function layoutSeriesNode(node: render.RenderGroupNodeInfo): void {
let graph = node.coreGraph;
let params = PARAMS.subscene.series;
_.extend(node, params);
// Layout the core.
- _.extend(node.coreBox,
- dagreLayout(node.coreGraph, PARAMS.graph.series));
+ _.extend(node.coreBox, dagreLayout(node.coreGraph, PARAMS.graph.series));
_.each(graph.nodes(), nodeName => {
graph.node(nodeName).excluded = false;
@@ -468,24 +537,16 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void {
/**
* Calculate layout for annotations of a given node.
- * This will modify positions of the the given node and its annotations.
+ * This will modify positions of the given node and its annotations.
*
* @see tf.graph.render.Node and tf.graph.render.Annotation
* for description of each property of each render node.
*
*/
- function layoutAnnotation(renderNodeInfo: render.RenderNodeInformation): void {
+function layoutAnnotation(renderNodeInfo: render.RenderNodeInfo): void {
// If the render node is an expanded metanode, then its annotations will not
// be visible and we should skip the annotation calculations.
if (renderNodeInfo.expanded) {
- _.extend(renderNodeInfo, {
- inboxWidth: 0,
- inboxHeight: 0,
- outboxWidth: 0,
- outboxHeight: 0,
- outerWidth: renderNodeInfo.width,
- outerHeight: renderNodeInfo.height
- });
return;
}
@@ -499,31 +560,20 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void {
_.each(outAnnotations, a => sizeAnnotation(a));
let params = PARAMS.annotations;
- renderNodeInfo.inboxWidth =
- inAnnotations.length > 0 ?
- (<any>_(inAnnotations).pluck("width").max()) +
- params.xOffset + params.labelWidth + params.labelOffset :
- 0;
-
- renderNodeInfo.outboxWidth =
- outAnnotations.length > 0 ?
- (<any>_(outAnnotations).pluck("width").max()) +
- params.xOffset + params.labelWidth + params.labelOffset :
- 0;
// Calculate annotation node position (a.dx, a.dy)
// and total height for in-annotations
// After this chunk of code:
// inboxHeight = sum of annotation heights + (annotations.length - 1) * yOffset
let inboxHeight = _.reduce(inAnnotations,
- (height, a: any, i) => {
+ (height, a, i) => {
let yOffset = i > 0 ? params.yOffset : 0;
- a.dx = -(renderNodeInfo.width + a.width) / 2 - params.xOffset;
+ a.dx = -(renderNodeInfo.coreBox.width + a.width) / 2 - params.xOffset;
a.dy = height + yOffset + a.height / 2;
return height + yOffset + a.height;
}, 0);
- _.each(inAnnotations, (a: any) => {
+ _.each(inAnnotations, a => {
a.dy -= inboxHeight / 2;
a.labelOffset = params.labelOffset;
@@ -535,14 +585,14 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void {
// outboxHeight = sum of annotation heights +
// (annotations.length - 1) * yOffset
let outboxHeight = _.reduce(outAnnotations,
- (height, a: any, i) => {
+ (height, a, i) => {
let yOffset = i > 0 ? params.yOffset : 0;
- a.dx = (renderNodeInfo.width + a.width) / 2 + params.xOffset;
+ a.dx = (renderNodeInfo.coreBox.width + a.width) / 2 + params.xOffset;
a.dy = height + yOffset + a.height / 2;
return height + yOffset + a.height;
}, 0);
- _.each(outAnnotations, (a: any) => {
+ _.each(outAnnotations, a => {
// adjust by (half of) the total height
// so dy is relative to the host node's center.
a.dy -= outboxHeight / 2;
@@ -563,7 +613,7 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void {
.range([-inTouchHeight, inTouchHeight]);
// Calculate annotation edge position
- _.each(inAnnotations, (a: any, i) => {
+ _.each(inAnnotations, (a, i) => {
a.points = [
// The annotation node end
{
@@ -573,7 +623,7 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void {
// The host node end
{
- dx: - renderNodeInfo.width / 2,
+ dx: - renderNodeInfo.coreBox.width / 2,
// only use the scale if there is more than one annotation,
// otherwise center it vertically
dy: inAnnotations.length > 1 ? inY(i) : 0
@@ -591,12 +641,12 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void {
.domain([0, outAnnotations.length - 1])
.range([-outTouchHeight, outTouchHeight]);
- _.each(outAnnotations, (a: any, i) => {
+ _.each(outAnnotations, (a, i) => {
// Add point from the border of the annotation node
a.points = [
// The host node end
{
- dx: renderNodeInfo.width / 2,
+ dx: renderNodeInfo.coreBox.width / 2,
// only use the scale if there is more than one annotation,
// otherwise center it vertically
dy: outAnnotations.length > 1 ? outY(i) : 0
@@ -609,9 +659,7 @@ function layoutSeriesNode(node: render.RenderGroupNodeInformation): void {
];
});
- renderNodeInfo.outerWidth = renderNodeInfo.width + renderNodeInfo.inboxWidth +
- renderNodeInfo.outboxWidth;
- renderNodeInfo.outerHeight =
+ renderNodeInfo.height =
Math.max(renderNodeInfo.height, inboxHeight, outboxHeight);
}
@@ -640,4 +688,75 @@ function sizeAnnotation(a: render.Annotation): void {
}
}
+/**
+ * Determines the center position of the node's shape. The position depends
+ * on whether the node has in- and out-annotations.
+ */
+export function computeCXPositionOfNodeShape(renderInfo: render.RenderNodeInfo):
+ number {
+ if (renderInfo.expanded) {
+ return renderInfo.x;
+ }
+ let dx = renderInfo.inAnnotations.list.length ? renderInfo.inboxWidth : 0;
+ return renderInfo.x - renderInfo.width / 2 + dx +
+ renderInfo.coreBox.width / 2;
+}
+
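A worked instance of the formula above, with assumed numbers: x = 100 and
width = 120 (coreBox.width = 40, inboxWidth = 50, outboxWidth = 30), with at
least one in-annotation so dx = 50. Then cx = 100 - 120/2 + 50 + 40/2 = 110,
i.e. the shape center sits to the right of the node center because the
in-annotation box is wider than the out box.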
+/** Returns the angle (in degrees) of the line between two points. */
+function angleBetweenTwoPoints(a: render.Point, b: render.Point): number {
+ let dx = b.x - a.x;
+ let dy = b.y - a.y;
+ return 180 * Math.atan(dy / dx) / Math.PI;
+}
+
+/**
+ * Returns whether the polyline through the specified points is straight.
+ */
+function isStraightLine(points: render.Point[]) {
+ let angle = angleBetweenTwoPoints(points[0], points[1]);
+ for (let i = 1; i < points.length - 1; i++) {
+ let newAngle = angleBetweenTwoPoints(points[i], points[i + 1]);
+ // Have a tolerance of 1 degree.
+ if (Math.abs(newAngle - angle) > 1) {
+ return false;
+ }
+ angle = newAngle;
+ }
+ return true;
+}
+
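For intuition, a worked example under the 1-degree tolerance: the points
(0, 0), (10, 1), (20, 2) all lie on y = x / 10, so both segment angles are
atan(0.1) = 5.71 degrees, the difference is 0, and the polyline counts as
straight.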
+/**
+ * Returns the intersection of a line between the provided point
+ * and the provided rectangle.
+ */
+function intersectPointAndNode(point: render.Point, node: render.RenderNodeInfo):
+ render.Point {
+ // cx and cy are the center of the rectangle.
+ let cx = node.expanded ?
+ node.x : computeCXPositionOfNodeShape(node);
+ let cy = node.y;
+ // Calculate the offset (dx, dy) of the point from the center.
+ let dx = point.x - cx;
+ let dy = point.y - cy;
+ let w = node.expanded ? node.width : node.coreBox.width;
+ let h = node.expanded ? node.height : node.coreBox.height;
+ let deltaX, deltaY;
+ if (Math.abs(dy) * w / 2 > Math.abs(dx) * h / 2) {
+ // The intersection is above or below the rectangle.
+ if (dy < 0) {
+ h = -h;
+ }
+ deltaX = dy === 0 ? 0 : h / 2 * dx / dy;
+ deltaY = h / 2;
+ } else {
+ // The intersection is left or right of the rectangle.
+ if (dx < 0) {
+ w = -w;
+ }
+ deltaX = w / 2;
+ deltaY = dx === 0 ? 0 : w / 2 * dy / dx;
+ }
+ return {x: cx + deltaX, y: cy + deltaY};
+}
+
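A worked instance of the intersection math above, with assumed numbers: a
shape centered at (0, 0) with coreBox 20x10 and an edge point at (30, 5).
Then dx = 30, dy = 5, and since |dy| * w/2 = 50 is less than |dx| * h/2 = 150,
the hit is on the right edge: deltaX = 10, deltaY = 10 * 5 / 30 = 1.67,
giving the intersection point (10, 1.67).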
} // close module
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts
index c99c61a849..956b39d986 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts
@@ -22,6 +22,8 @@ limitations under the License.
module tf.graph.render {
+export type Point = {x: number, y: number};
+
/**
* Color parameters for op nodes.
*/
@@ -149,18 +151,19 @@ export interface RenderGraphParams {
* Stores the rendering information, such as x and y coordinates,
* for each node in the graph.
*/
-export class RenderGraphInformation {
+export class RenderGraphInfo {
private hierarchy: hierarchy.Hierarchy;
- private index: {[nodeName: string]: RenderNodeInformation};
+ private index: {[nodeName: string]: RenderNodeInfo};
private params: RenderGraphParams;
private deviceColorMap: d3.scale.Ordinal<string, string>;
private memoryUsageScale: d3.scale.Linear<string, string>;
private computeTimeScale: d3.scale.Linear<string, string>;
// Since the rendering information for each node is constructed lazily,
- // upon node's expansion by the user, we keep a map between the node's name and
- // whether the rendering information was already constructed for that node.
+ // upon node's expansion by the user, we keep a map between the node's name
+ // and whether the rendering information was already constructed for that
+ // node.
private hasSubhierarchy: {[nodeName: string]: boolean};
- root: RenderGroupNodeInformation;
+ root: RenderGroupNodeInfo;
constructor(hierarchy: hierarchy.Hierarchy, params: RenderGraphParams) {
this.hierarchy = hierarchy;
@@ -185,7 +188,8 @@ export class RenderGraphInformation {
.range(params.minMaxColors);
// Find also the minimum and maximum compute time.
- let computeTimeExtent = d3.extent(topLevelGraph.nodes(), (nodeName, index) => {
+ let computeTimeExtent = d3.extent(topLevelGraph.nodes(),
+ (nodeName, index) => {
let node = topLevelGraph.node(nodeName);
// Some ops don't have stats at all.
if (node.stats != null) {
@@ -196,27 +200,28 @@ export class RenderGraphInformation {
.domain(computeTimeExtent)
.range(params.minMaxColors);
- // Maps node name to whether the rendering hierarchy was already constructed.
+ // Maps node name to whether the rendering hierarchy was already
+ // constructed.
this.hasSubhierarchy = {};
this.params = params;
- this.root = new RenderGroupNodeInformation(hierarchy.root);
+ this.root = new RenderGroupNodeInfo(hierarchy.root);
this.index[hierarchy.root.name] = this.root;
this.buildSubhierarchy(hierarchy.root.name);
this.root.expanded = true;
}
/**
- * Get a previously created RenderNodeInformation by its node name.
+ * Get a previously created RenderNodeInfo by its node name.
*/
- getRenderNodeByName(nodeName: string): RenderNodeInformation {
+ getRenderNodeByName(nodeName: string): RenderNodeInfo {
return this.index[nodeName];
}
/**
- * Get a previously created RenderNodeInformation for the specified node name,
+ * Get a previously created RenderNodeInfo for the specified node name,
* or create one if it hasn't been created yet.
*/
- getOrCreateRenderNodeByName(nodeName: string): RenderNodeInformation {
+ getOrCreateRenderNodeByName(nodeName: string): RenderNodeInfo {
// Polymer may invoke this with null.
if (!nodeName) {
return null;
@@ -228,8 +233,8 @@ export class RenderGraphInformation {
let node = this.hierarchy.node(nodeName);
let renderInfo = node.isGroupNode ?
- new RenderGroupNodeInformation(<GroupNode>node) :
- new RenderNodeInformation(node);
+ new RenderGroupNodeInfo(<GroupNode>node) :
+ new RenderNodeInfo(node);
this.index[nodeName] = renderInfo;
if (node.stats) {
@@ -291,8 +296,8 @@ export class RenderGraphInformation {
/**
* Returns true if the renderNode is an isolated node within its parent node.
*/
- isNodeAuxilliary(renderNode: RenderNodeInformation): boolean {
- let parentNode = <RenderGroupNodeInformation>this.getRenderNodeByName(
+ isNodeAuxilliary(renderNode: RenderNodeInfo): boolean {
+ let parentNode = <RenderGroupNodeInfo>this.getRenderNodeByName(
renderNode.node.parentNode.name);
let found = _.find(parentNode.isolatedInExtract, node => {
return node.node.name === renderNode.node.name;
@@ -322,7 +327,7 @@ export class RenderGraphInformation {
}
// At this point we know the rendering information is about a group node.
- let renderGroupNodeInfo = <RenderGroupNodeInformation> renderNodeInfo;
+ let renderGroupNodeInfo = <RenderGroupNodeInfo> renderNodeInfo;
let metagraph = renderGroupNodeInfo.node.metagraph;
let coreGraph = renderGroupNodeInfo.coreGraph;
@@ -339,16 +344,16 @@ export class RenderGraphInformation {
if (!childNode.isGroupNode) {
_.each((<OpNode>childNode).inEmbeddings, embedding => {
- let renderMetaedgeInfo = new RenderMetaedgeInformation(null);
+ let renderMetaedgeInfo = new RenderMetaedgeInfo(null);
addInAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo,
AnnotationType.CONSTANT, this.params);
- this.index[embedding.name] = new RenderNodeInformation(embedding);
+ this.index[embedding.name] = new RenderNodeInfo(embedding);
});
_.each((<OpNode>childNode).outEmbeddings, embedding => {
- let renderMetaedgeInfo = new RenderMetaedgeInformation(null);
+ let renderMetaedgeInfo = new RenderMetaedgeInfo(null);
addOutAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo,
AnnotationType.SUMMARY, this.params);
- this.index[embedding.name] = new RenderNodeInformation(embedding);
+ this.index[embedding.name] = new RenderNodeInfo(embedding);
});
}
@@ -357,7 +362,7 @@ export class RenderGraphInformation {
// Add render metaedge info for edges in the metagraph.
_.each(metagraph.edges(), edgeObj => {
let metaedge = metagraph.edge(edgeObj);
- let renderMetaedgeInfo = new RenderMetaedgeInformation(metaedge);
+ let renderMetaedgeInfo = new RenderMetaedgeInfo(metaedge);
coreGraph.setEdge(edgeObj.v, edgeObj.w, renderMetaedgeInfo);
});
@@ -376,7 +381,7 @@ export class RenderGraphInformation {
return;
}
let parentNodeInfo =
- <RenderGroupNodeInformation> this.index[parentNode.name];
+ <RenderGroupNodeInfo> this.index[parentNode.name];
// Utility function for computing the name of a bridge node.
let getBridgeNodeName = (inbound, ...rest) =>
@@ -439,7 +444,7 @@ export class RenderGraphInformation {
let isHighDegreeControlEdge = !bridgeMetaedge.numRegularEdges &&
otherCounts.control[otherName] > this.params.maxControlDegree;
- let [annotations, childAnnotations] =
+ let [, childAnnotations] =
inbound ?
[renderNodeInfo.inAnnotations, childRenderInfo.inAnnotations] :
[renderNodeInfo.outAnnotations, childRenderInfo.outAnnotations];
@@ -472,7 +477,7 @@ export class RenderGraphInformation {
inbound ?
{ v: targetName, w: nodeName } :
{ v: nodeName, w: targetName };
- return <RenderMetaedgeInformation>
+ return <RenderMetaedgeInfo>
parentNodeInfo.coreGraph.edge(adjoiningEdgeObj);
};
@@ -547,7 +552,7 @@ export class RenderGraphInformation {
childAnnotations.push(new Annotation(
otherNode,
otherRenderInfo,
- new RenderMetaedgeInformation(bridgeMetaedge),
+ new RenderMetaedgeInfo(bridgeMetaedge),
AnnotationType.SHORTCUT,
inbound), this.params);
return;
@@ -578,7 +583,7 @@ export class RenderGraphInformation {
inbound: inbound,
};
bridgeContainerInfo =
- new RenderNodeInformation(bridgeContainerNode);
+ new RenderNodeInfo(bridgeContainerNode);
this.index[bridgeContainerName] = bridgeContainerInfo;
coreGraph.setNode(bridgeContainerName, bridgeContainerInfo);
}
@@ -596,7 +601,7 @@ export class RenderGraphInformation {
// BridgeNode properties.
inbound: inbound,
};
- bridgeNodeRenderInfo = new RenderNodeInformation(bridgeNode);
+ bridgeNodeRenderInfo = new RenderNodeInfo(bridgeNode);
this.index[bridgeNodeName] = bridgeNodeRenderInfo;
coreGraph.setNode(bridgeNodeName, bridgeNodeRenderInfo);
@@ -607,7 +612,7 @@ export class RenderGraphInformation {
// Create and add a bridge render metaedge.
let bridgeRenderMetaedge =
- new RenderMetaedgeInformation(bridgeMetaedge);
+ new RenderMetaedgeInfo(bridgeMetaedge);
bridgeRenderMetaedge.adjoiningMetaedge = adjoiningMetaedge;
inbound ?
coreGraph.setEdge(bridgeNodeName, childName, bridgeRenderMetaedge) :
@@ -716,7 +721,7 @@ export class RenderGraphInformation {
// BridgeNode properties.
inbound: inbound,
};
- structuralRenderInfo = new RenderNodeInformation(bridgeNode);
+ structuralRenderInfo = new RenderNodeInfo(bridgeNode);
structuralRenderInfo.structural = true;
this.index[structuralNodeName] = structuralRenderInfo;
coreGraph.setNode(structuralNodeName, structuralRenderInfo);
@@ -725,7 +730,7 @@ export class RenderGraphInformation {
}
// Create the structural Metaedge and insert it.
- let structuralMetaedgeInfo = new RenderMetaedgeInformation(null);
+ let structuralMetaedgeInfo = new RenderMetaedgeInfo(null);
structuralMetaedgeInfo.structural = true;
structuralMetaedgeInfo.weight--; // Reduce weight for dagre layout.
inbound ?
@@ -748,8 +753,8 @@ export class RenderGraphInformation {
*/
export class Annotation {
node: Node;
- renderNodeInfo: RenderNodeInformation;
- renderMetaedgeInfo: RenderMetaedgeInformation;
+ renderNodeInfo: RenderNodeInfo;
+ renderMetaedgeInfo: RenderMetaedgeInfo;
annotationType: AnnotationType;
/**
* Center position of annotation relative to the host
@@ -791,8 +796,8 @@ export class Annotation {
* @param isIn True if it is an in-annotation. False if it is an
* out-annotation.
*/
- constructor(node: Node, renderNodeInfo: RenderNodeInformation,
- renderMetaedgeInfo: RenderMetaedgeInformation, type: AnnotationType,
+ constructor(node: Node, renderNodeInfo: RenderNodeInfo,
+ renderMetaedgeInfo: RenderMetaedgeInfo, type: AnnotationType,
isIn: boolean) {
this.node = node;
this.renderNodeInfo = renderNodeInfo;
@@ -813,7 +818,7 @@ export enum AnnotationType {SHORTCUT, CONSTANT, SUMMARY, ELLIPSIS};
/**
* Manages a list of annotations. Two will be used for each
- * RenderNodeInformation, one for in annotations and one for out annotations.
+ * RenderNodeInfo, one for in-annotations and one for out-annotations.
*/
export class AnnotationList {
/**
@@ -857,7 +862,7 @@ export class AnnotationList {
let ellipsisNode = new tf.graph.EllipsisNodeImpl(1);
this.list.push(new Annotation(ellipsisNode,
- new RenderNodeInformation(ellipsisNode), null,
+ new RenderNodeInfo(ellipsisNode), null,
AnnotationType.ELLIPSIS, annotation.isIn));
}
}
@@ -865,7 +870,7 @@ export class AnnotationList {
/**
* Contains rendering information about a node in the hierarchical graph.
*/
-export class RenderNodeInformation {
+export class RenderNodeInfo {
/** Reference to the original underlying Node from the hierarchical graph. */
node: Node;
/** Whether the node is expanded or not. */
@@ -875,7 +880,9 @@ export class RenderNodeInformation {
* shortcuts to high-degree nodes.
*/
inAnnotations: AnnotationList;
- /** List of rendering information about out-annotations (e.g. summary nodes) */
+ /**
+ * List of rendering information about out-annotations (e.g. summary nodes).
+ */
outAnnotations: AnnotationList;
// --- Params specified by layout --- //
@@ -884,10 +891,25 @@ export class RenderNodeInformation {
x: number;
/** Center y position */
y: number;
- /** Width of the node's shape */
+ /**
+ * Total width of the node's shape, including in- and out-annotations. This
+ * property is used by dagre to layout the graph.
+ */
width: number;
- /** Height of the node's shape. */
+ /**
+ * Total height of the node's shape, including in- and out-annotations. This
+ * property is used by dagre to layout the graph.
+ */
height: number;
+ /**
+ * Size of the main box of the node, excluding in- and out-annotations. This
+ * property is used to draw the rectangle/ellipse shape denoting the node.
+ */
+ coreBox: {
+ width: number,
+ height: number,
+ };
+
/** Width of the bounding box for all in-annotations. */
inboxWidth: number;
/** Width of the bounding box for all out-annotations. */
@@ -930,13 +952,9 @@ export class RenderNodeInformation {
paddingRight: number;
paddingBottom: number;
- /** Width of the whole node including its shape and its annotations */
- outerWidth: number;
- /** Height of the whole node including its shape and its annotations */
- outerHeight: number;
/**
- * Whether a node is extracted as source-like (having high out-degree or matching
- * predefined in-extract pattern.)
+ * Whether a node is extracted as source-like (having high out-degree or
+ * matching a predefined in-extract pattern).
*/
isInExtract: boolean;
/**
@@ -991,11 +1009,9 @@ export class RenderNodeInformation {
this.paddingLeft = 0;
this.paddingRight = 0;
this.paddingBottom = 0;
-
- this.outerWidth = 0;
- this.outerHeight = 0;
this.isInExtract = false;
this.isOutExtract = false;
+ this.coreBox = {width: 0, height: 0};
}
isInCore(): boolean {
@@ -1007,7 +1023,7 @@ export class RenderNodeInformation {
* Contains rendering information about a Metaedge from the underlying
* hierarchical graph. It may be from either a metagraph or a bridgegraph.
*/
-export class RenderMetaedgeInformation {
+export class RenderMetaedgeInfo {
/**
* Reference to the original underlying Metaedge from the hierarchical graph,
* if any. This will be null for the edges which connect OpNodes to their
@@ -1016,15 +1032,15 @@ export class RenderMetaedgeInformation {
metaedge: Metaedge;
/**
- * Reference to the adjoining RenderMeteaedgeInformation from the parent's
+ * Reference to the adjoining RenderMetaedgeInfo from the parent's
* coreGraph. This is used during layout to determine the point at which this
* edge should touch the node's bounding box. This property will be null for
* edges which terminate at a node on both ends (all non-bridge edges).
*/
- adjoiningMetaedge: RenderMetaedgeInformation;
+ adjoiningMetaedge: RenderMetaedgeInfo;
/**
- * Most of the time, a RenderMetaedgeInformation object represents a real
+ * Most of the time, a RenderMetaedgeInfo object represents a real
* edge between nodes in the underlying graph structure. But sometimes, an
* edge only exists for layout purposes. These structural edges are added
* during buildSubhierarchy() to force dagre.layout() to put bridge nodes
@@ -1044,12 +1060,12 @@ export class RenderMetaedgeInformation {
* X and Y coordinate pairs of the points in the path of the edge.
* @see tf.graph.node.subsceneAdjustPaths
*/
- points: any[];
+ points: Point[];
/**
* D3 selection of the group containing the path that displays this edge.
*/
- edgeGroup: d3.Selection<RenderMetaedgeInformation>;
+ edgeGroup: d3.Selection<RenderMetaedgeInfo>;
constructor(metaedge: Metaedge) {
this.metaedge = metaedge;
@@ -1059,23 +1075,24 @@ export class RenderMetaedgeInformation {
}
}
-function addInAnnotation(node: RenderNodeInformation, predecessor: Node,
- predecessorRenderInfo: RenderNodeInformation, edge: any,
- type: AnnotationType, params: RenderGraphParams): void {
+function addInAnnotation(node: RenderNodeInfo, predecessor: Node,
+ predecessorRenderInfo: RenderNodeInfo,
+ edge: RenderMetaedgeInfo, type: AnnotationType,
+ params: RenderGraphParams): void {
let annotation = new Annotation(predecessor, predecessorRenderInfo, edge,
type, true);
node.inAnnotations.push(annotation, params);
}
-function addOutAnnotation(node: RenderNodeInformation, successor: Node,
- successorRenderInfo: RenderNodeInformation, edge: any,
+function addOutAnnotation(node: RenderNodeInfo, successor: Node,
+ successorRenderInfo: RenderNodeInfo, edge: RenderMetaedgeInfo,
type: AnnotationType, params: RenderGraphParams): void {
let annotation = new Annotation(successor, successorRenderInfo, edge,
type, false);
node.outAnnotations.push(annotation, params);
}
-function setGraphDepth(graph: graphlib.Graph<RenderNodeInformation, any>,
+function setGraphDepth(graph: graphlib.Graph<RenderNodeInfo, any>,
depth: number) {
_.each(graph.nodes(), nodeName => {
let child = graph.node(nodeName);
@@ -1084,7 +1101,7 @@ function setGraphDepth(graph: graphlib.Graph<RenderNodeInformation, any>,
switch (child.node.type) {
case NodeType.META:
case NodeType.SERIES:
- setGroupNodeDepth(<RenderGroupNodeInformation>child, depth - 1);
+ setGroupNodeDepth(<RenderGroupNodeInfo>child, depth - 1);
break;
// Do nothing for leaf
}
@@ -1092,35 +1109,31 @@ function setGraphDepth(graph: graphlib.Graph<RenderNodeInformation, any>,
});
};
-export class RenderGroupNodeInformation extends RenderNodeInformation {
+export class RenderGroupNodeInfo extends RenderNodeInfo {
node: GroupNode;
/**
* The core graph is derived from the underlying node's metagraph, minus
* the extracted source-like and sink-like nodes.
*/
- coreGraph: graphlib.Graph<RenderNodeInformation, RenderMetaedgeInformation>;
- /** Size of the bounding box for a metanode's core graph. */
- coreBox: {
- width: number,
- height: number,
- };
+ coreGraph: graphlib.Graph<RenderNodeInfo, RenderMetaedgeInfo>;
/** Size of the bounding box for a metanode's isolated in-extract children. */
inExtractBox: {width: number, height: number};
- /** Size of the bounding box for a metanode's isolated out-extract children. */
+ /**
+ * Size of the bounding box for a metanode's isolated out-extract children.
+ */
outExtractBox: {width: number, height: number};
/** Array of isolated in-extract nodes. */
- isolatedInExtract: RenderNodeInformation[];
+ isolatedInExtract: RenderNodeInfo[];
/** Array of isolated out-extract nodes. */
- isolatedOutExtract: RenderNodeInformation[];
+ isolatedOutExtract: RenderNodeInfo[];
constructor(groupNode: GroupNode) {
super(groupNode);
let metagraph = groupNode.metagraph;
let gl = metagraph.graph();
this.coreGraph =
- createGraph<RenderNodeInformation, RenderMetaedgeInformation>(
+ createGraph<RenderNodeInfo, RenderMetaedgeInfo>(
gl.name, GraphType.CORE, { compound: true });
- this.coreBox = {width: 0, height: 0};
this.inExtractBox = {width: 0, height: 0};
this.outExtractBox = {width: 0, height: 0};
this.isolatedInExtract = [];
@@ -1128,7 +1141,7 @@ export class RenderGroupNodeInformation extends RenderNodeInformation {
}
}
-function setGroupNodeDepth(renderInfo: RenderGroupNodeInformation,
+function setGroupNodeDepth(renderInfo: RenderGroupNodeInfo,
depth: number): void {
if (renderInfo.coreGraph) {
setGraphDepth(renderInfo.coreGraph, depth);
@@ -1142,8 +1155,9 @@ function setGroupNodeDepth(renderInfo: RenderGroupNodeInformation,
* @param v Source name.
* @param w Sink name.
*/
-function createShortcut(graph: graphlib.Graph<RenderNodeInformation, {}>,
- v: string, w: string, params: RenderGraphParams) {
+function createShortcut(
+ graph: graphlib.Graph<RenderNodeInfo, RenderMetaedgeInfo>,
+ v: string, w: string, params: RenderGraphParams) {
let src = graph.node(v);
let sink = graph.node(w);
let edge = graph.edge(v, w);
@@ -1173,7 +1187,7 @@ function createShortcut(graph: graphlib.Graph<RenderNodeInformation, {}>,
* If detachAllEdgesForHighDegree or forceDetach is true, extract all of its
* edges. Otherwise, only extract all in-edges.
*/
-function makeOutExtract(renderNode: RenderGroupNodeInformation, n: string,
+function makeOutExtract(renderNode: RenderGroupNodeInfo, n: string,
params: RenderGraphParams, forceDetach?: boolean) {
let graph = renderNode.coreGraph;
let child = graph.node(n);
@@ -1204,7 +1218,7 @@ function makeOutExtract(renderNode: RenderGroupNodeInformation, n: string,
* If detachAllEdgesForHighDegree or forceDetach is true, extract all of its
* edges. Otherwise, only remove all out-edges.
*/
-export function makeInExtract(renderNode: RenderGroupNodeInformation, n: string,
+export function makeInExtract(renderNode: RenderGroupNodeInfo, n: string,
params: RenderGraphParams, forceDetach?: boolean) {
let graph = renderNode.coreGraph;
let child = graph.node(n);
@@ -1251,7 +1265,7 @@ function hasTypeIn(node: Node, types: string[]): boolean {
}
/** Move nodes that are specified to be excluded out of the core graph. */
-function extractSpeficiedNodes(renderNode: RenderGroupNodeInformation,
+function extractSpecifiedNodes(renderNode: RenderGroupNodeInfo,
params: RenderGraphParams) {
let graph = renderNode.coreGraph;
_.each(graph.nodes(), n => {
@@ -1268,7 +1282,7 @@ function extractSpeficiedNodes(renderNode: RenderGroupNodeInformation,
}
/** Remove edges from pre-defined out-extract patterns */
-function extractPredefinedSink(renderNode: RenderGroupNodeInformation,
+function extractPredefinedSink(renderNode: RenderGroupNodeInfo,
params: RenderGraphParams) {
let graph = renderNode.coreGraph;
_.each(graph.nodes(), n => {
@@ -1283,7 +1297,7 @@ function extractPredefinedSink(renderNode: RenderGroupNodeInformation,
}
/** Remove edges from pre-defined in-extract patterns */
-function extractPredefinedSource(renderNode: RenderGroupNodeInformation,
+function extractPredefinedSource(renderNode: RenderGroupNodeInfo,
params: RenderGraphParams) {
let graph = renderNode.coreGraph;
@@ -1299,7 +1313,7 @@ function extractPredefinedSource(renderNode: RenderGroupNodeInformation,
}
/** Extract from nodes with in-degree > maxInDegree */
-function extractHighInDegree(renderNode: RenderGroupNodeInformation,
+function extractHighInDegree(renderNode: RenderGroupNodeInfo,
params: RenderGraphParams) {
let graph = renderNode.coreGraph;
let maxInDegree = params.maxInDegree;
@@ -1313,7 +1327,8 @@ function extractHighInDegree(renderNode: RenderGroupNodeInformation,
// no regular edges, in which case use the number of control edges.
// This is done so that control edges don't affect whether nodes are extracted
// from the core graph, unless the node is only used for control.
- let numEdgesToCount = _.reduce(graph.predecessors(n), (numEdgesToCount, pred) => {
+ let numEdgesToCount = _.reduce(graph.predecessors(n),
+ (numEdgesToCount, pred) => {
let metaedge = graph.edge(pred, n).metaedge;
return numEdgesToCount + (metaedge.numRegularEdges ? 1 : 0);
}, 0);
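The reduce above counts a predecessor only if its metaedge carries at least
one regular (non-control) edge. A hedged standalone sketch of the same
counting, with hypothetical names:

    type MetaedgeCount = {numRegularEdges: number};
    // Control-only predecessors contribute 0 to the degree check.
    function countRegularInEdges(preds: MetaedgeCount[]): number {
      return preds.reduce(
          (count, metaedge) => count + (metaedge.numRegularEdges ? 1 : 0), 0);
    }
    // countRegularInEdges([{numRegularEdges: 2}, {numRegularEdges: 0}]) === 1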
@@ -1329,7 +1344,7 @@ function extractHighInDegree(renderNode: RenderGroupNodeInformation,
}
/** Extract nodes with out-degree > maxOutDegree */
-function extractHighOutDegree(renderNode: RenderGroupNodeInformation,
+function extractHighOutDegree(renderNode: RenderGroupNodeInfo,
params: RenderGraphParams) {
let graph = renderNode.coreGraph;
let maxOutDegree = params.maxOutDegree;
@@ -1343,7 +1358,8 @@ function extractHighOutDegree(renderNode: RenderGroupNodeInformation,
// no regular edges, in which case use the number of control edges.
// This is done so that control edges don't affect whether nodes are extracted
// from the core graph, unless the node is only used for control.
- let numEdgesToCount = _.reduce(graph.successors(n), (numEdgesToCount, succ) => {
+ let numEdgesToCount = _.reduce(graph.successors(n),
+ (numEdgesToCount, succ) => {
let metaedge = graph.edge(n, succ).metaedge;
return numEdgesToCount + (metaedge.numRegularEdges ? 1 : 0);
}, 0);
@@ -1359,7 +1375,7 @@ function extractHighOutDegree(renderNode: RenderGroupNodeInformation,
}
/** Remove control edges from nodes that have too many control edges */
-function removeControlEdges(renderNode: RenderGroupNodeInformation,
+function removeControlEdges(renderNode: RenderGroupNodeInfo,
params: RenderGraphParams) {
let graph = renderNode.coreGraph;
@@ -1408,10 +1424,10 @@ export function mapIndexToHue(id: number): number {
* @param {Object} params render Graph construction parameters. See
* <tf-graph-params>'s output
*/
-function extractHighDegrees(renderNode: RenderGroupNodeInformation,
+function extractHighDegrees(renderNode: RenderGroupNodeInfo,
params: RenderGraphParams) {
- extractSpeficiedNodes(renderNode, params);
+ extractSpecifiedNodes(renderNode, params);
if (params.outExtractTypes) {
extractPredefinedSink(renderNode, params);
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts
index d973f75fd3..425eea0408 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts
@@ -43,7 +43,7 @@ module tf.graph.scene.annotation {
* @return selection of appended objects
*/
export function buildGroup(container, annotationData: render.AnnotationList,
- d: render.RenderNodeInformation, sceneBehavior) {
+ d: render.RenderNodeInfo, sceneBehavior) {
// Select all children and join with data.
let annotationGroups = container.selectAll(function() {
// using d3's selector function
@@ -151,7 +151,7 @@ function addAnnotationLabel(aGroup, label, a, additionalClassNames,
.append("title").text(titleText);
}
-function addInteraction(selection, d: render.RenderNodeInformation,
+function addInteraction(selection, d: render.RenderNodeInfo,
annotation: tf.graph.render.Annotation, sceneBehavior) {
selection
.on("mouseover", a => {
@@ -190,8 +190,9 @@ function addInteraction(selection, d: render.RenderNodeInformation,
* @param a annotation node data.
* @param scene Polymer scene element.
*/
-function update(aGroup, d: render.RenderNodeInformation, a: render.Annotation,
+function update(aGroup, d: render.RenderNodeInfo, a: render.Annotation,
sceneBehavior) {
+ let cx = layout.computeCXPositionOfNodeShape(d);
// Annotations that point to embedded nodes (constants, summary)
// don't have render information attached, so we don't stylize these.
// Also we don't stylize ellipsis annotations (the string "... and X more").
@@ -208,7 +209,7 @@ function update(aGroup, d: render.RenderNodeInformation, a: render.Annotation,
// label position
aGroup.select("text." + Class.Annotation.LABEL).transition().attr({
- x: d.x + a.dx + (a.isIn ? -1 : 1) * (a.width / 2 + a.labelOffset),
+ x: cx + a.dx + (a.isIn ? -1 : 1) * (a.width / 2 + a.labelOffset),
y: d.y + a.dy
});
@@ -218,23 +219,23 @@ function update(aGroup, d: render.RenderNodeInformation, a: render.Annotation,
// centered with the node and horizontally centered between the arrow and the
// text label.
aGroup.select("use.summary").transition().attr({
- x: d.x + a.dx - 3,
+ x: cx + a.dx - 3,
y: d.y + a.dy - 6
});
// Node position (only one of the shape selection will be non-empty.)
scene.positionEllipse(aGroup.select("." + Class.Annotation.NODE + " ellipse"),
- d.x + a.dx, d.y + a.dy, a.width, a.height);
+ cx + a.dx, d.y + a.dy, a.width, a.height);
scene.positionRect(aGroup.select("." + Class.Annotation.NODE + " rect"),
- d.x + a.dx, d.y + a.dy, a.width, a.height);
+ cx + a.dx, d.y + a.dy, a.width, a.height);
scene.positionRect(aGroup.select("." + Class.Annotation.NODE + " use"),
- d.x + a.dx, d.y + a.dy, a.width, a.height);
+ cx + a.dx, d.y + a.dy, a.width, a.height);
// Edge position
aGroup.select("path." + Class.Annotation.EDGE).transition().attr("d", a => {
// map relative position to absolute position
let points = a.points.map(p => {
- return {x: p.dx + d.x, y: p.dy + d.y};
+ return {x: p.dx + cx, y: p.dy + d.y};
});
return edge.interpolate(points);
});
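Each annotation point is stored relative to its host node; the handler above
offsets x by the shape center cx rather than the node center d.x so that
annotations hug the visible shape. A sketch of that mapping, with
hypothetical names:

    type RelPoint = {dx: number, dy: number};
    // Map offsets relative to (cx, y) into absolute coordinates.
    function toAbsolute(points: RelPoint[], cx: number, y: number) {
      return points.map(p => ({x: p.dx + cx, y: p.dy + y}));
    }
    // toAbsolute([{dx: -5, dy: 2}], 110, 40) -> [{x: 105, y: 42}]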
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts
index d84bb8e2ca..5ae6244e00 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts
@@ -19,9 +19,9 @@ limitations under the License.
module tf.graph.scene.edge {
-let Scene = tf.graph.scene; // Aliased
+export type EdgeData = {v: string, w: string, label: render.RenderMetaedgeInfo};
-export function getEdgeKey(edgeObj) {
+export function getEdgeKey(edgeObj: EdgeData) {
return edgeObj.v + tf.graph.EDGE_KEY_DELIM + edgeObj.w;
}
@@ -45,8 +45,9 @@ export function getEdgeKey(edgeObj) {
* @return selection of the created nodeGroups
*/
export function buildGroup(sceneGroup,
- graph: graphlib.Graph<tf.graph.render.RenderNodeInformation,
- tf.graph.render.RenderMetaedgeInformation>, sceneBehavior) {
+ graph: graphlib.Graph<tf.graph.render.RenderNodeInfo,
+ tf.graph.render.RenderMetaedgeInfo>, sceneBehavior) {
+ let edges: EdgeData[] = [];
let edgeData = _.reduce(graph.edges(), (edges, edgeObj) => {
let edgeLabel = graph.edge(edgeObj);
edges.push({
@@ -55,11 +56,10 @@ export function buildGroup(sceneGroup,
label: edgeLabel
});
return edges;
- }, []);
+ }, edges);
let container = scene.selectOrCreateChild(sceneGroup, "g",
Class.Edge.CONTAINER);
- let containerNode = container.node();
// Select all children and join with data.
// (Note that all children of g.edges are g.edge)
@@ -76,7 +76,7 @@ export function buildGroup(sceneGroup,
.append("g")
.attr("class", Class.Edge.GROUP)
.attr("data-edge", getEdgeKey)
- .each(function(d) {
+ .each(function(d: EdgeData) {
let edgeGroup = d3.select(this);
d.label.edgeGroup = edgeGroup;
// index node group for quick highlighting
@@ -108,11 +108,11 @@ export function buildGroup(sceneGroup,
* For a given d3 selection and data object, create a path to represent the
* edge described in d.label.
*
- * If d.label is defined, it will be a RenderMetaedgeInformation instance. It
+ * If d.label is defined, it will be a RenderMetaedgeInfo instance. It
* will sometimes be undefined, for example for some Annotation edges for which
* there is no underlying Metaedge in the hierarchical graph.
*/
-export function appendEdge(edgeGroup, d, sceneBehavior, edgeClass?) {
+export function appendEdge(edgeGroup, d: EdgeData, sceneBehavior, edgeClass?) {
edgeClass = edgeClass || Class.Edge.LINE; // set default type
if (d.label && d.label.structural) {
@@ -123,11 +123,16 @@ export function appendEdge(edgeGroup, d, sceneBehavior, edgeClass?) {
.attr("class", edgeClass);
};
+export let interpolate = d3.svg.line<{x: number, y: number}>()
+ .interpolate("basis")
+ .x((d) => { return d.x; })
+ .y((d) => { return d.y; });
+
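The generator above is a d3 v3 path generator with basis interpolation: given
an array of points it returns an SVG path string. A minimal usage sketch,
assuming d3 v3's d3.svg.line API:

    // For illustration only.
    let pathData: string = interpolate([
      {x: 0, y: 0}, {x: 20, y: 10}, {x: 40, y: 0}
    ]);
    // e.g. someSelection.append("path").attr("d", pathData);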
/**
* Returns a tween interpolator for the endpoint of an edge path.
*/
-function getEdgePathInterpolator(d, i, a) {
- let renderMetaedgeInfo = d.label;
+function getEdgePathInterpolator(d: EdgeData, i: number, a: string) {
+ let renderMetaedgeInfo = <render.RenderMetaedgeInfo> d.label;
let adjoiningMetaedge = renderMetaedgeInfo.adjoiningMetaedge;
if (!adjoiningMetaedge) {
return d3.interpolate(a, interpolate(renderMetaedgeInfo.points));
@@ -162,11 +167,6 @@ function getEdgePathInterpolator(d, i, a) {
};
}
-export let interpolate = d3.svg.line()
- .interpolate("basis")
- .x((d: any) => { return d.x; })
- .y((d: any) => { return d.y; });
-
function position(d) {
d3.select(this).select("path." + Class.Edge.LINE)
.each(function(d) {
@@ -179,10 +179,9 @@ function position(d) {
* For a given d3 selection and data object, mark the edge as a control
* dependency if it contains only control edges.
*
- * d's label property will be a RenderMetaedgeInformation object.
+ * d's label property will be a RenderMetaedgeInfo object.
*/
-function stylize(edgeGroup, d, stylize) {
- let a;
+function stylize(edgeGroup, d: EdgeData, stylize) {
let metaedge = d.label.metaedge;
edgeGroup
.select("path." + Class.Edge.LINE)
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts
index c3cf3d684b..32566c99ef 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts
@@ -142,6 +142,10 @@ export class Minimap {
* was updated (e.g. when a node was expanded).
*/
update(): void {
+ // Nothing has been rendered into zoomG yet, so skip this update.
+ if (this.zoomG.childElementCount === 0) {
+ return;
+ }
let $download = d3.select("#graphdownload");
this.download = <HTMLLinkElement>$download.node();
$download.on("click", d => {
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts
index bb1d1fdcdc..c5781a2fd0 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts
@@ -65,7 +65,7 @@ module tf.graph.scene.node {
* @return selection of the created nodeGroups
*/
export function buildGroup(sceneGroup,
- nodeData: render.RenderNodeInformation[], sceneBehavior) {
+ nodeData: render.RenderNodeInfo[], sceneBehavior) {
let container = scene.selectOrCreateChild(sceneGroup, "g",
Class.Node.CONTAINER);
// Select all children and join with data.
@@ -76,7 +76,7 @@ export function buildGroup(sceneGroup,
// (It's not listed in the d3 wiki.)
return this.childNodes; // this here refers to container.node()
})
- .data(nodeData, (d: any) => {
+ .data(nodeData, (d) => {
// make sure that we don't have to swap shape type
return d.node.name + ":" + d.node.type;
});
@@ -124,7 +124,8 @@ export function buildGroup(sceneGroup,
addInteraction(shape, d, sceneBehavior);
// build subscene on the top
- subsceneBuild(nodeGroup, d, sceneBehavior);
+ subsceneBuild(nodeGroup, <render.RenderGroupNodeInfo> d,
+ sceneBehavior);
stylize(nodeGroup, d, sceneBehavior);
position(nodeGroup, d, sceneBehavior);
@@ -168,7 +169,7 @@ export function buildGroup(sceneGroup,
* not have a subscene.
*/
function subsceneBuild(nodeGroup,
- renderNodeInfo: render.RenderGroupNodeInformation, sceneBehavior) {
+ renderNodeInfo: render.RenderGroupNodeInfo, sceneBehavior) {
if (renderNodeInfo.node.isGroupNode) {
if (renderNodeInfo.expanded) {
// Recursively build the subscene.
@@ -184,7 +185,7 @@ function subsceneBuild(nodeGroup,
/**
* Translate the subscene of the given node group
*/
-function subscenePosition(nodeGroup, d: render.RenderNodeInformation) {
+function subscenePosition(nodeGroup, d: render.RenderNodeInfo) {
let x0 = d.x - d.width / 2.0 + d.paddingLeft;
let y0 = d.y - d.height / 2.0 + d.paddingTop;
@@ -199,7 +200,7 @@ function subscenePosition(nodeGroup, d: render.RenderNodeInformation) {
* @param d Info about the node being rendered.
* @param sceneBehavior parent scene module.
*/
-function addButton(selection, d: render.RenderNodeInformation, sceneBehavior) {
+function addButton(selection, d: render.RenderNodeInfo, sceneBehavior) {
let group = scene.selectOrCreateChild(
selection, "g", Class.Node.BUTTON_CONTAINER);
scene.selectOrCreateChild(group, "circle", Class.Node.BUTTON_CIRCLE);
@@ -224,7 +225,7 @@ function addButton(selection, d: render.RenderNodeInformation, sceneBehavior) {
* don't need interaction as their surrounding shape has interaction, and if
* given interaction would cause conflicts with the expand/collapse button.
*/
-function addInteraction(selection, d: render.RenderNodeInformation,
+function addInteraction(selection, d: render.RenderNodeInfo,
sceneBehavior, disableInteraction?: boolean) {
if (disableInteraction) {
selection.attr("pointer-events", "none");
@@ -282,7 +283,7 @@ export function getContextMenu(node: Node, sceneBehavior) {
* @param renderNodeInfo The render node information for the label.
* @param sceneBehavior parent scene module.
*/
-function labelBuild(nodeGroup, renderNodeInfo: render.RenderNodeInformation,
+function labelBuild(nodeGroup, renderNodeInfo: render.RenderNodeInfo,
sceneBehavior) {
let namePath = renderNodeInfo.node.name.split("/");
let text = namePath[namePath.length - 1];
@@ -320,14 +321,15 @@ function getLabelFontScale(sceneBehavior) {
}
return fontScale;
}
+
/**
* Set label position of a given node group
*/
-function labelPosition(nodeGroup, d: render.RenderNodeInformation,
+function labelPosition(nodeGroup, cx: number, cy: number,
yOffset: number) {
scene.selectChild(nodeGroup, "text", Class.Node.LABEL).transition()
- .attr("x", d.x)
- .attr("y", d.y + yOffset);
+ .attr("x", cx)
+ .attr("y", cy + yOffset);
};
/**
@@ -335,7 +337,7 @@ function labelPosition(nodeGroup, d: render.RenderNodeInformation,
* as the shape's data.
*
* @param nodeGroup
- * @param d RenderNodeInformation
+ * @param d Render node information.
* @param nodeClass class for the element.
* @param before Reference DOM node for insertion.
* @return Selection of the shape.
@@ -353,7 +355,7 @@ export function buildShape(nodeGroup, d, nodeClass: string, before?) {
case NodeType.SERIES:
// Choose the correct stamp to use to represent this series.
let stampType = "annotation";
- let groupNodeInfo = <render.RenderGroupNodeInformation>d;
+ let groupNodeInfo = <render.RenderGroupNodeInfo>d;
if (groupNodeInfo.coreGraph) {
stampType = groupNodeInfo.node.hasNonControlEdges
? "vertical" : "horizontal";
@@ -377,7 +379,7 @@ export function buildShape(nodeGroup, d, nodeClass: string, before?) {
return shapeGroup;
};
-export function nodeClass(d: render.RenderNodeInformation) {
+export function nodeClass(d: render.RenderNodeInfo) {
switch (d.node.type) {
case NodeType.OP:
return Class.OPNODE;
@@ -394,43 +396,43 @@ export function nodeClass(d: render.RenderNodeInformation) {
};
/** Modify node and its subscene and its label's positional attributes */
-function position(nodeGroup, d: render.RenderNodeInformation, sceneBehavior) {
+function position(nodeGroup, d: render.RenderNodeInfo, sceneBehavior) {
let shapeGroup = scene.selectChild(nodeGroup, "g", Class.Node.SHAPE);
+ let cx = layout.computeCXPositionOfNodeShape(d);
switch (d.node.type) {
case NodeType.OP: {
// position shape
let shape = scene.selectChild(shapeGroup, "ellipse");
- scene.positionEllipse(shape, d.x, d.y, d.width, d.height);
- labelPosition(nodeGroup, d, d.labelOffset);
+ scene.positionEllipse(shape, cx, d.y, d.coreBox.width, d.coreBox.height);
+ labelPosition(nodeGroup, cx, d.y, d.labelOffset);
break;
}
case NodeType.META: {
// position shape
let shape = scene.selectChild(shapeGroup, "rect");
- scene.positionRect(shape, d.x, d.y, d.width, d.height);
-
if (d.expanded) {
+ scene.positionRect(shape, d.x, d.y, d.width, d.height);
subscenePosition(nodeGroup, d);
-
// put label on top
- labelPosition(nodeGroup, d,
+ labelPosition(nodeGroup, cx, d.y,
- d.height / 2 + d.labelHeight / 2);
} else {
- labelPosition(nodeGroup, d, 0);
+ scene.positionRect(shape, cx, d.y, d.coreBox.width, d.coreBox.height);
+ labelPosition(nodeGroup, cx, d.y, 0);
}
break;
}
case NodeType.SERIES: {
let shape = scene.selectChild(shapeGroup, "use");
- scene.positionRect(shape, d.x, d.y, d.width, d.height);
if (d.expanded) {
- subscenePosition(nodeGroup, d);
-
+ scene.positionRect(shape, d.x, d.y, d.width, d.height);
+ subscenePosition(nodeGroup, d);
// put label on top
- labelPosition(nodeGroup, d,
+ labelPosition(nodeGroup, cx, d.y,
- d.height / 2 + d.labelHeight / 2);
} else {
- labelPosition(nodeGroup, d, d.labelOffset);
+ scene.positionRect(shape, cx, d.y, d.coreBox.width, d.coreBox.height);
+ labelPosition(nodeGroup, cx, d.y, d.labelOffset);
}
}
case NodeType.BRIDGE: {
@@ -455,7 +457,7 @@ export enum ColorBy { STRUCTURE, DEVICE, COMPUTE_TIME, MEMORY };
* option.
*/
export function getFillForNode(templateIndex, colorBy,
- renderInfo: render.RenderNodeInformation, isExpanded: boolean): string {
+ renderInfo: render.RenderNodeInfo, isExpanded: boolean): string {
let colorParams = tf.graph.render.MetanodeColors;
switch (colorBy) {
case ColorBy.STRUCTURE:
@@ -493,7 +495,7 @@ export function getFillForNode(templateIndex, colorBy,
linearGradient.selectAll("*").remove();
let cumulativeProportion = 0;
// For each device, create a stop using the proportion of that device.
- _.each(renderInfo.deviceColors, (d: any) => {
+ _.each(renderInfo.deviceColors, d => {
let color = d.color;
linearGradient.append("stop")
.attr("offset", cumulativeProportion)
@@ -522,7 +524,7 @@ export function getFillForNode(templateIndex, colorBy,
* Modify node style by toggling class and assign attributes (only for things
* that can't be done in css).
*/
-export function stylize(nodeGroup, renderInfo: render.RenderNodeInformation,
+export function stylize(nodeGroup, renderInfo: render.RenderNodeInfo,
sceneBehavior, nodeClass?) {
nodeClass = nodeClass || Class.Node.SHAPE;
let isHighlighted = sceneBehavior.isNodeHighlighted(renderInfo.node.name);
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts
index 6e97e904da..24c16e31ee 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts
@@ -118,8 +118,8 @@ export function fit(svg, zoomG, d3zoom, callback) {
* provided node.
*/
export function panToNode(nodeName: String, svg, zoomG, d3zoom): boolean {
- let node: any = d3.selectAll("[data-name='" + nodeName + "']."
- + Class.Node.GROUP)[0][0];
+ let node = <SVGAElement> d3.select("[data-name='" + nodeName + "']."
+ + Class.Node.GROUP).node();
if (!node) {
return false;
}
@@ -247,7 +247,7 @@ export function selectChild(container, tagName: string, className?: string) {
* @param sceneClass class attribute of the scene (default="scene").
*/
export function buildGroup(container,
- renderNode: render.RenderGroupNodeInformation,
+ renderNode: render.RenderGroupNodeInfo,
sceneBehavior,
sceneClass: string) {
sceneClass = sceneClass || Class.Scene.GROUP;
@@ -301,8 +301,7 @@ export function buildGroup(container,
// Fade in the scene group if it didn't already exist.
if (isNewSceneGroup) {
- sceneGroup.attr("opacity", 0)
- .transition().attr("opacity", 1);
+ sceneGroup.attr("opacity", 0).transition().attr("opacity", 1);
}
return sceneGroup;
@@ -315,7 +314,7 @@ export function buildGroup(container,
* @param sceneGroup
* @param renderNode render node of a metanode or series node.
*/
-function position(sceneGroup, renderNode: render.RenderGroupNodeInformation) {
+function position(sceneGroup, renderNode: render.RenderGroupNodeInfo) {
// Translate scenes down by the label height so that when showing graphs in
// expanded metanodes, the graphs are below the labels. Do not shift them
// down for series nodes as series nodes don't have labels inside of their
@@ -324,14 +323,13 @@ function position(sceneGroup, renderNode: render.RenderGroupNodeInformation) {
0 : layout.PARAMS.subscene.meta.labelHeight;
// core
- translate(selectChild(sceneGroup, "g", Class.Scene.CORE),
- 0, yTranslate);
+ translate(selectChild(sceneGroup, "g", Class.Scene.CORE), 0, yTranslate);
// in-extract
- let inExtractX = renderNode.coreBox.width === 0 ?
- 0 : renderNode.coreBox.width;
let hasInExtract = renderNode.isolatedInExtract.length > 0;
if (hasInExtract) {
+ let inExtractX = renderNode.coreBox.width -
+ renderNode.inExtractBox.width / 2 - renderNode.outExtractBox.width;
translate(selectChild(sceneGroup, "g", Class.Scene.INEXTRACT),
inExtractX, yTranslate);
}
@@ -339,8 +337,8 @@ function position(sceneGroup, renderNode: render.RenderGroupNodeInformation) {
// out-extract
let hasOutExtract = renderNode.isolatedOutExtract.length > 0;
if (hasOutExtract) {
- let outExtractX = inExtractX + renderNode.inExtractBox.width
- + renderNode.extractXOffset;
+ let outExtractX = renderNode.coreBox.width -
+ renderNode.outExtractBox.width / 2;
translate(selectChild(sceneGroup, "g", Class.Scene.OUTEXTRACT),
outExtractX, yTranslate);
}
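
A worked example of the new extract-column offsets, with illustrative box widths (the numbers are not from the source): both auxiliary columns now hang off the right edge of the core box, with the out-extract column outermost.

  const coreBox = {width: 300};
  const inExtractBox = {width: 40};
  const outExtractBox = {width: 60};

  // in-extract: 300 - 40/2 - 60 = 220
  const inExtractX = coreBox.width -
      inExtractBox.width / 2 - outExtractBox.width;
  // out-extract: 300 - 60/2 = 270
  const outExtractX = coreBox.width - outExtractBox.width / 2;
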
@@ -355,6 +353,10 @@ export function addGraphClickListener(graphGroup, sceneBehavior) {
/** Helper for adding transform: translate(x0, y0) */
export function translate(selection, x0: number, y0: number) {
+ // If it is already placed on the screen, make it a transition.
+ if (selection.attr("transform") != null) {
+ selection = selection.transition("position");
+ }
selection.attr("transform", "translate(" + x0 + "," + y0 + ")");
};
@@ -382,12 +384,16 @@ export function positionRect(rect, cx: number, cy: number, width: number,
* @param renderNode the render node of the group node to position
* the button on.
*/
-export function positionButton(button,
- renderNode: render.RenderNodeInformation) {
+export function positionButton(button, renderNode: render.RenderNodeInfo) {
+ let cx = layout.computeCXPositionOfNodeShape(renderNode);
// Position the button in the top-right corner of the group node,
// with space given the draw the button inside of the corner.
- let x = renderNode.x + renderNode.width / 2 - 6;
- let y = renderNode.y - renderNode.height / 2 + 6;
+ let width = renderNode.expanded ?
+ renderNode.width : renderNode.coreBox.width;
+ let height = renderNode.expanded ?
+ renderNode.height : renderNode.coreBox.height;
+ let x = cx + width / 2 - 6;
+ let y = renderNode.y - height / 2 + 6;
// For unexpanded series nodes, the button has special placement due
// to the unique visuals of this group node.
if (renderNode.node.type === NodeType.SERIES && !renderNode.expanded) {
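
A sketch of the corner math introduced above. `cx` is assumed to come from `layout.computeCXPositionOfNodeShape`; the interface is a simplified stand-in for `RenderNodeInfo`:

  interface NodeBoxInfo {
    expanded: boolean;
    y: number;
    width: number;
    height: number;
    coreBox: {width: number; height: number};
  }

  function buttonCorner(cx: number, d: NodeBoxInfo): {x: number, y: number} {
    // Collapsed nodes are drawn at coreBox size, expanded ones at full
    // size, so measure whichever box is actually on screen.
    let width = d.expanded ? d.width : d.coreBox.width;
    let height = d.expanded ? d.height : d.coreBox.height;
    // Inset by 6px so the button sits inside the top-right corner.
    return {x: cx + width / 2 - 6, y: d.y - height / 2 + 6};
  }
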
diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts
index 41fbbbb9ff..0423e1c863 100644
--- a/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts
+++ b/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts
@@ -57,13 +57,13 @@ export function detect(h, verifyTemplate): {[templateId: string]: string[]} {
* @return Unique string for a metanode based on depth, |V|, |E| and
* op type histogram.
*/
- function getSignature(metanode) {
+function getSignature(metanode) {
// depth=<number> |V|=<number> |E|=<number>
- let props = _.map({
- "depth": metanode.depth,
- "|V|": metanode.metagraph.nodes().length,
- "|E|": metanode.metagraph.edges().length
- }, function(v, k) { return k + "=" + v; }).join(" ");
+ let props = _.map({
+ "depth": metanode.depth,
+ "|V|": metanode.metagraph.nodes().length,
+ "|E|": metanode.metagraph.edges().length
+ }, function(v, k) { return k + "=" + v; }).join(" ");
// optype1=count1,optype2=count2
let ops = _.map(metanode.opHistogram, function(count, op) {
@@ -84,7 +84,8 @@ export function detect(h, verifyTemplate): {[templateId: string]: string[]} {
*/
function clusterSimilarSubgraphs(h: hierarchy.Hierarchy) {
/** a dict from metanode.signature() => Array of tf.graph.Groups */
- let hashDict = _(h.getNodeMap()).reduce(function(hash, node: OpNode|Metanode, name) {
+ let hashDict = _(h.getNodeMap()).reduce(
+ (hash, node: OpNode|Metanode, name) => {
if (node.type !== NodeType.META) {
return hash;
}
@@ -156,8 +157,8 @@ function groupTemplateAndAssignId(nnGroups, verifyTemplate) {
}, result);
}
-function sortNodes(names: string[], graph: graphlib.Graph<Metanode|OpNode, Metaedge>,
- prefix: string) {
+function sortNodes(names: string[],
+ graph: graphlib.Graph<Metanode|OpNode, Metaedge>, prefix: string) {
return _.sortByAll(names,
function(name) {
let node = graph.node(name);
@@ -181,7 +182,8 @@ function sortNodes(names: string[], graph: graphlib.Graph<Metanode|OpNode, Metae
});
}
-function isSimilarSubgraph(g1: graphlib.Graph<any, any>, g2: graphlib.Graph<any, any>) {
+function isSimilarSubgraph(g1: graphlib.Graph<any, any>,
+ g2: graphlib.Graph<any, any>) {
if (!tf.graph.hasSimilarDegreeSequence(g1, g2)) {
return false;
}
@@ -273,25 +275,27 @@ function isSimilarSubgraph(g1: graphlib.Graph<any, any>, g2: graphlib.Graph<any,
/**
* Returns if two nodes have identical structure.
*/
- function isSimilarNode(n1: OpNode|Metanode|SeriesNode, n2: OpNode|Metanode|SeriesNode): boolean {
+function isSimilarNode(n1: OpNode|Metanode|SeriesNode,
+ n2: OpNode|Metanode|SeriesNode): boolean {
if (n1.type === NodeType.META) {
// compare metanode
let metanode1 = <Metanode> n1;
let metanode2 = <Metanode> n2;
- return metanode1.templateId && metanode2.templateId && metanode1.templateId === metanode2.templateId;
+ return metanode1.templateId && metanode2.templateId &&
+ metanode1.templateId === metanode2.templateId;
} else if (n1.type === NodeType.OP && n2.type === NodeType.OP) {
// compare leaf node
return (<OpNode>n1).op === (<OpNode>n2).op;
} else if (n1.type === NodeType.SERIES && n2.type === NodeType.SERIES) {
// compare series node sizes and operations
// (only need to check one op as all op nodes are identical in series)
- let seriesnode1 = <SeriesNode> n1;
- let seriesnode2 = <SeriesNode> n2;
- let seriesnode1Count = seriesnode1.metagraph.nodeCount();
- return (seriesnode1Count === seriesnode2.metagraph.nodeCount() &&
+ let sn1 = <SeriesNode> n1;
+ let sn2 = <SeriesNode> n2;
+ let seriesnode1Count = sn1.metagraph.nodeCount();
+ return (seriesnode1Count === sn2.metagraph.nodeCount() &&
(seriesnode1Count === 0 ||
- ((<OpNode>seriesnode1.metagraph.node(seriesnode1.metagraph.nodes()[0])).op ===
- (<OpNode>seriesnode2.metagraph.node(seriesnode2.metagraph.nodes()[0])).op)));
+ ((<OpNode>sn1.metagraph.node(sn1.metagraph.nodes()[0])).op ===
+ (<OpNode>sn2.metagraph.node(sn2.metagraph.nodes()[0])).op)));
}
return false;
}
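
For reference, a self-contained sketch of the signature scheme that `getSignature` implements (the separator between the two parts is an assumption; the hunk does not show how they are joined):

  function signatureOf(depth: number, nodeCount: number, edgeCount: number,
                       opHistogram: {[op: string]: number}): string {
    // depth=<number> |V|=<number> |E|=<number>
    let props = "depth=" + depth + " |V|=" + nodeCount + " |E|=" + edgeCount;
    // optype1=count1,optype2=count2
    let ops = Object.keys(opHistogram)
        .map(op => op + "=" + opHistogram[op])
        .join(",");
    return props + " " + ops;
  }

  // signatureOf(2, 5, 4, {MatMul: 2, Add: 3})
  //   -> "depth=2 |V|=5 |E|=4 MatMul=2,Add=3"
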
diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html b/tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html
index 9204b392e3..765803a6a9 100644
--- a/tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html
+++ b/tensorflow/tensorboard/components/tf-graph/tf-graph-icon.html
@@ -83,7 +83,7 @@
* Render node information associated with this node. Optional. If
* specified, this is only used when computing the fill of the icon
* element.
- * @type {tf.graph.render.RenderNodeInformation}
+ * @type {tf.graph.render.RenderNodeInfo}
*/
renderInfo: {
type: Object,
diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html b/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html
index 996735679f..e9e6f6ce02 100644
--- a/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html
+++ b/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html
@@ -98,10 +98,7 @@ Polymer({
properties: {
renderHierarchy: Object,
name: String,
- colorBy: {
- type: String,
- observer: '_colorByChanged'
- },
+ colorBy: String,
/** @type {d3_zoom} d3 zoom object */
_zoom: Object,
highlightedNode: {
@@ -201,6 +198,7 @@ Polymer({
progress: Object
},
observers: [
+ '_colorByChanged(colorBy, renderHierarchy)',
'_buildAndFit(renderHierarchy)'
],
getNode: function(nodeName) {
@@ -234,7 +232,7 @@ Polymer({
this.templateIndex = renderHierarchy.hierarchy.getTemplateIndex();
tf.time('tf-graph-scene (layout):', function() {
// layout the scene for this meta / series node
- tf.graph.layout.scene(renderHierarchy.root, this);
+ tf.graph.layout.layoutScene(renderHierarchy.root, this);
}.bind(this));
tf.time('tf-graph-scene (build scene):', function() {
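
The observer change above replaces a per-property observer with a Polymer complex observer, which fires only once every listed property has a value. A minimal sketch under that assumption (Polymer 1.x API; the element name is hypothetical):

  Polymer({
    is: 'color-by-demo',
    properties: {
      colorBy: String,
      renderHierarchy: Object,
    },
    observers: [
      // Runs only after BOTH colorBy and renderHierarchy are set, so the
      // handler never sees a missing render hierarchy.
      '_colorByChanged(colorBy, renderHierarchy)',
    ],
    _colorByChanged: function(colorBy, renderHierarchy) {
      // recolor nodes here
    },
  });
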
diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph.html b/tensorflow/tensorboard/components/tf-graph/tf-graph.html
index ffb737f761..e35664ae7f 100644
--- a/tensorflow/tensorboard/components/tf-graph/tf-graph.html
+++ b/tensorflow/tensorboard/components/tf-graph/tf-graph.html
@@ -112,8 +112,8 @@ Polymer({
// and thus mistakenly pass non-metanode to this module.
return;
}
- var renderGraph = new tf.graph.render.RenderGraphInformation(
- graphHierarchy, params);
+ var renderGraph = new tf.graph.render.RenderGraphInfo(graphHierarchy,
+ params);
// Producing the 'color by' parameters to be consumed
// by the tf-graph-controls panel. It contains information about the
// min and max values and their respective colors, as well as list
diff --git a/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html
index 1d211f9b2a..d6748ff167 100644
--- a/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html
+++ b/tensorflow/tensorboard/components/tf-tensorboard/tf-tensorboard.html
@@ -20,11 +20,11 @@ allows the user to toggle between various dashboards.
<paper-toolbar id="toolbar">
<div id="toolbar-content">
<div class="toolbar-title">TensorBoard</div>
- <paper-tabs selected="0" noink class="tabs">
- <paper-tab on-click="chooseEvents">Events</paper-tab>
- <paper-tab on-click="chooseImages">Images</paper-tab>
- <paper-tab on-click="chooseGraphs">Graph</paper-tab>
- <paper-tab on-click="chooseHistograms">Histograms</paper-tab>
+ <paper-tabs selected="0" noink class="tabs" id="tabs">
+ <paper-tab data-mode="events" on-click="changeMode">Events</paper-tab>
+ <paper-tab data-mode="images" on-click="changeMode">Images</paper-tab>
+ <paper-tab data-mode="graphs" on-click="changeMode">Graph</paper-tab>
+ <paper-tab data-mode="histograms" on-click="changeMode">Histograms</paper-tab>
</paper-tabs>
</div>
</paper-toolbar>
@@ -100,17 +100,9 @@ allows the user to toggle between various dashboards.
value: "events",
},
},
- chooseEvents: function() {
- this.mode = "events";
- },
- chooseImages: function() {
- this.mode = "images";
- },
- chooseGraphs: function() {
- this.mode = "graphs";
- },
- chooseHistograms: function() {
- this.mode = "histograms";
+ changeMode: function(ev) {
+ var mode = ev.target.parentElement.getAttribute('data-mode');
+ this._changeMode(mode, true);
},
eventDashboard: function(mode) {
return mode === "events";
@@ -123,7 +115,47 @@ allows the user to toggle between various dashboards.
},
histogramDashboard: function(mode) {
return mode === "histograms";
- }
+ },
+ loadPreviousMode: function() {
+ this._changeMode(this._getModeAndPath().mode, false);
+ },
+ ready: function() {
+ this._changeMode(this._getModeAndPath().mode, true);
+
+ var tb = this;
+ window.addEventListener('popstate', function(){
+ tb.loadPreviousMode();
+ });
+ },
+ _changeMode: function(mode, isNewState) {
+ this.mode = mode;
+
+ // Change the selected tab
+ this.$.tabs.selected = this._tabs().indexOf(mode);
+
+ if (isNewState){
+ var basePath = this._getModeAndPath().path;
+ basePath += basePath[basePath.length - 1] == '/' ? '' : '/';
+ history.pushState(null, null, basePath + mode);
+ }
+ },
+ _getModeAndPath: function() {
+ // Returns a {mode: 'mode', path: 'basePathWithoutMode'}
+ // The mode is assumed to be at the end of the pathname.
+ var tokens = window.location.pathname.split('/');
+ var mode = tokens[tokens.length - 1];
+
+ if (_.contains(this._tabs(), mode)) {
+ return {mode: mode, path: tokens.slice(0, tokens.length-1).join('/')};
+ } else {
+ // Unrecognized modes turn into events
+ return {mode: 'events', path: tokens.join('/')};
+ }
+ },
+ _tabs: function() {
+ var elts = Array.prototype.slice.call(this.querySelectorAll('paper-tab'));
+ return elts.map(function(elt){ return elt.getAttribute('data-mode')});
+ },
});
</script>
</dom-module>
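
A Polymer-free sketch of the routing scheme the new handlers implement: the dashboard mode is the last path segment, `pushState` appends it on tab clicks, and `popstate` restores it on back/forward. Names are illustrative:

  const TABS = ['events', 'images', 'graphs', 'histograms'];

  function getModeAndPath(pathname: string): {mode: string, path: string} {
    let tokens = pathname.split('/');
    let last = tokens[tokens.length - 1];
    if (TABS.indexOf(last) !== -1) {
      return {mode: last, path: tokens.slice(0, -1).join('/')};
    }
    // Unrecognized trailing segments fall back to the events dashboard.
    return {mode: 'events', path: tokens.join('/')};
  }

  function changeMode(mode: string, isNewState: boolean): void {
    if (isNewState) {
      let base = getModeAndPath(window.location.pathname).path;
      base += base[base.length - 1] === '/' ? '' : '/';
      history.pushState(null, '', base + mode);
    }
    // ...update the selected tab here...
  }

  window.addEventListener('popstate', () => {
    // Restore without pushing a new history entry.
    changeMode(getModeAndPath(window.location.pathname).mode, false);
  });
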
diff --git a/tensorflow/tensorboard/dist/index.html b/tensorflow/tensorboard/dist/index.html
index e75a87a4f3..a72f6a62a6 100644
--- a/tensorflow/tensorboard/dist/index.html
+++ b/tensorflow/tensorboard/dist/index.html
@@ -33,6 +33,7 @@
<link rel="import" href="external/paper-styles/paper-styles.html">
<link rel="import" href="external/paper-toggle-button/paper-toggle-button.html">
<link rel="import" href="external/paper-toolbar/paper-toolbar.html">
+ <link rel="import" href="external/paper-tabs/paper-tabs.html">
<link rel="import" href="dist/tf-tensorboard.html">
diff --git a/tensorflow/tensorboard/dist/tf-tensorboard.html b/tensorflow/tensorboard/dist/tf-tensorboard.html
index 2aa1e46ca3..7d01603e1c 100644
--- a/tensorflow/tensorboard/dist/tf-tensorboard.html
+++ b/tensorflow/tensorboard/dist/tf-tensorboard.html
@@ -1,3 +1,4 @@
+// AUTOGENERATED FILE - DO NOT MODIFY
<html><head><meta charset="UTF-8">
@@ -11,6 +12,8 @@
--tb-orange-strong: #f3913e;
--tb-grey-darker: #e2e2e2;
--tb-grey-lighter: #f3f3f3;
+ --tb-ui-dark-accent: #757575;
+ --tb-ui-light-accent: #e0e0e0;
}
</style>
@@ -32,8 +35,21 @@
-<script>/// <reference path="../../../typings/tsd.d.ts" />
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="../../../typings/tsd.d.ts" />
var tf;
(function (tf) {
/**
@@ -112,7 +128,21 @@ var tf;
tf.escapeQuerySelector = escapeQuerySelector;
})(tf || (tf = {})); // close module tf
</script>
-<script>/// <reference path="../../../typings/tsd.d.ts" />
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="../../../typings/tsd.d.ts" />
/// <reference path="common.ts" />
var tf;
(function (tf) {
@@ -120,7 +150,6 @@ var tf;
(function (graph_1) {
/** Delimiter used in node names to denote namespaces. */
graph_1.NAMESPACE_DELIM = "/";
- var FULL_GRAPH_NAME = "fullGraph";
graph_1.ROOT_NAME = "__root__";
// Separator between the source and the destination name of the edge.
graph_1.EDGE_KEY_DELIM = "--";
@@ -145,6 +174,14 @@ var tf;
})(graph_1.NodeType || (graph_1.NodeType = {}));
var NodeType = graph_1.NodeType;
;
+ /** Indicates if a node is to be included in the main graph when rendered. */
+ (function (InclusionType) {
+ InclusionType[InclusionType["INCLUDE"] = 0] = "INCLUDE";
+ InclusionType[InclusionType["EXCLUDE"] = 1] = "EXCLUDE";
+ InclusionType[InclusionType["UNSPECIFIED"] = 2] = "UNSPECIFIED";
+ })(graph_1.InclusionType || (graph_1.InclusionType = {}));
+ var InclusionType = graph_1.InclusionType;
+ ;
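
A sketch of how the new enum is meant to drive the UI; this mirrors `getIncludeNodeButtonString`, which appears later in this file:

  enum InclusionType { INCLUDE, EXCLUDE, UNSPECIFIED }

  function includeButtonLabel(include: InclusionType): string {
    // Excluded nodes offer re-inclusion; included and unspecified nodes
    // offer removal.
    return include === InclusionType.EXCLUDE ?
        'Add to main graph' : 'Remove from main graph';
  }
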
/**
* A SlimGraph is inspired by graphlib.Graph, but having only the functionality
* that we need.
@@ -170,6 +207,7 @@ var tf;
this.parentNode = null;
this.stats = null;
this.setNumMoreNodes(numNodes);
+ this.include = InclusionType.UNSPECIFIED;
}
EllipsisNodeImpl.prototype.setNumMoreNodes = function (numNodes) {
this.numMoreNodes = numNodes;
@@ -190,8 +228,8 @@ var tf;
* @param rawNode The raw node.
* @param normalizedInputs An array of normalized
* inputs that denote the incoming edges to the current node. Each input
- * contains the normalized name of the source node, whether it has a number
- * part and whether it is a control dependency.
+ * contains the normalized name of the source node, whether it has a
+ * number part and whether it is a control dependency.
*/
function OpNodeImpl(rawNode, normalizedInputs) {
this.op = rawNode.op;
@@ -206,6 +244,7 @@ var tf;
this.inEmbeddings = [];
this.outEmbeddings = [];
this.parentNode = null;
+ this.include = InclusionType.UNSPECIFIED;
}
return OpNodeImpl;
})();
@@ -216,8 +255,8 @@ var tf;
}
graph_1.createMetanode = createMetanode;
/**
- * Joins the information from the stats file (memory, compute time) with the graph
- * information.
+ * Joins the information from the stats file (memory, compute time) with the
+ * graph information.
*/
function joinStatsInfoWithGraph(graph, statsJson) {
_.each(statsJson.devStats, function (stats) {
@@ -274,6 +313,7 @@ var tf;
};
return NodeStats;
})();
+ graph_1.NodeStats = NodeStats;
var MetanodeImpl = (function () {
/** A label object for meta-nodes in the graph hierarchy */
function MetanodeImpl(name, opt) {
@@ -304,6 +344,7 @@ var tf;
this.parentNode = null;
this.stats = new NodeStats(0, 0, null);
this.hasNonControlEdges = false;
+ this.include = InclusionType.UNSPECIFIED;
}
MetanodeImpl.prototype.getFirstChild = function () {
return this.metagraph.node(this.metagraph.nodes()[0]);
@@ -404,6 +445,7 @@ var tf;
this.deviceHistogram = {};
this.hasNonControlEdges = false;
this.stats = new NodeStats(0, 0, null);
+ this.include = InclusionType.UNSPECIFIED;
}
return SeriesNodeImpl;
})();
@@ -655,7 +697,8 @@ var tf;
;
/**
* Returns the hierarchical path of the current node, based on the node's name.
- * For example, if the name is 'a/b/c', the returned path is ['a', 'a/b', 'a/b/c'].
+ * For example, if the name is 'a/b/c', the returned path is
+ * ['a', 'a/b', 'a/b/c'].
*/
function getHierarchicalPath(name, seriesNames) {
var path = [];
@@ -679,10 +722,38 @@ var tf;
}
graph_1.getHierarchicalPath = getHierarchicalPath;
;
+ /**
+ * Returns the string for the node inclusion toggle button, depending
+ * on the node's current InclusionType.
+ */
+ function getIncludeNodeButtonString(include) {
+ if (include === tf.graph.InclusionType.EXCLUDE) {
+ return "Add to main graph";
+ }
+ else {
+ return "Remove from main graph";
+ }
+ }
+ graph_1.getIncludeNodeButtonString = getIncludeNodeButtonString;
+ ;
})(graph = tf.graph || (tf.graph = {}));
})(tf || (tf = {})); // close module tf.graph
</script>
-<script>/// <reference path="../../../typings/tsd.d.ts" />
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="../../../typings/tsd.d.ts" />
/// <reference path="common.ts" />
var tf;
(function (tf) {
@@ -871,7 +942,21 @@ var tf;
})(graph = tf.graph || (tf.graph = {}));
})(tf || (tf = {})); // Close module tf.graph.parser.
</script>
-<script>/// <reference path="graph.ts" />
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="graph.ts" />
/// <reference path="template.ts" />
/**
* Package for the Graph Hierarchy for TensorFlow graph.
@@ -882,7 +967,6 @@ var tf;
(function (graph_1) {
var hierarchy;
(function (hierarchy_1) {
- var LOG_PREFIX_MSG = "Graph hierarchy: ";
/**
* Class for the Graph Hierarchy for TensorFlow graph.
*/
@@ -1139,6 +1223,17 @@ var tf;
}
return ordering;
};
+ /**
+ * Returns a d3 Ordinal function that can be used to look up the index of
+ * a node based on its template id.
+ */
+ HierarchyImpl.prototype.getTemplateIndex = function () {
+ var templateNames = d3.keys(this.templates);
+ var templateIndex = d3.scale.ordinal()
+ .domain(templateNames)
+ .range(d3.range(0, templateNames.length));
+ return function (templateId) { return templateIndex(templateId); };
+ };
return HierarchyImpl;
})();
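
A d3-free sketch of what `getTemplateIndex` returns: a lookup from template id to a small integer index, suitable for keying an ordinal color scale. (The real implementation uses `d3.scale.ordinal`, as shown above.)

  function makeTemplateIndex(templates: {[templateId: string]: any}):
      (templateId: string) => number {
    let names = Object.keys(templates);
    let index: {[templateId: string]: number} = {};
    names.forEach((name, i) => { index[name] = i; });
    return (templateId: string) => index[templateId];
  }
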
/**
@@ -1192,8 +1287,8 @@ var tf;
}, tracker)
.then(function () {
return tf.runAsyncTask("Detect series", 20, function () {
- if (params.groupSeries) {
- groupSeries(h.root, h, seriesNames);
+ if (params.seriesNodeMinSize > 0) {
+ groupSeries(h.root, h, seriesNames, params.seriesNodeMinSize);
}
}, tracker);
})
@@ -1251,7 +1346,8 @@ var tf;
}
parent = child;
}
- // Assuming node name is 'a/b/c', assign the OpNode as a child of the metanode 'a/b'.
+ // Assuming node name is 'a/b/c', assign the OpNode as a child of the
+ // metanode 'a/b'.
h.setNode(node.name, node);
node.parentNode = parent;
parent.metagraph.setNode(node.name, node);
@@ -1333,14 +1429,17 @@ var tf;
*
* @param metanode
* @param hierarchy
- * @return A dictionary from node name to series node name that contains the node
+ * @param threshold If the series has this many nodes or more, then group them
+ * into a series.
+ * @return A dictionary from node name to series node name that contains the
+ * node.
*/
- function groupSeries(metanode, hierarchy, seriesNames) {
+ function groupSeries(metanode, hierarchy, seriesNames, threshold) {
var metagraph = metanode.metagraph;
_.each(metagraph.nodes(), function (n) {
var child = metagraph.node(n);
if (child.type === tf.graph.NodeType.META) {
- groupSeries(child, hierarchy, seriesNames);
+ groupSeries(child, hierarchy, seriesNames, threshold);
}
});
var clusters = clusterNodes(metagraph);
@@ -1349,8 +1448,9 @@ var tf;
// metagraph.
_.each(seriesDict, function (seriesNode, seriesName) {
var nodeMemberNames = seriesNode.metagraph.nodes();
- var firstMember = seriesNode.metagraph.node(nodeMemberNames[0]);
- var seriesType = firstMember.type;
+ if (nodeMemberNames.length < threshold) {
+ return;
+ }
hierarchy.setNode(seriesName, seriesNode); // add to the index
metagraph.setNode(seriesName, seriesNode);
_.each(nodeMemberNames, function (n) {
@@ -1453,7 +1553,8 @@ var tf;
var seriesNodes = [seriesInfoArray[0]];
for (var index = 1; index < seriesInfoArray.length; index++) {
var nextNode = seriesInfoArray[index];
- if (nextNode.clusterId === seriesNodes[seriesNodes.length - 1].clusterId + 1) {
+ if (nextNode.clusterId === seriesNodes[seriesNodes.length - 1].clusterId
+ + 1) {
seriesNodes.push(nextNode);
continue;
}
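
The threshold change above turns the old boolean `groupSeries` switch into a size cutoff. A minimal sketch of the check, with illustrative numbers:

  function shouldGroupSeries(memberNames: string[],
                             threshold: number): boolean {
    // Series shorter than the threshold stay as individual nodes.
    return memberNames.length >= threshold;
  }

  // With params.seriesNodeMinSize = 5, a run of 4 repeated ops is left
  // ungrouped, while a run of 5 collapses into one series node.
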
@@ -1489,14 +1590,28 @@ var tf;
})(graph = tf.graph || (tf.graph = {}));
})(tf || (tf = {})); // close module tf.graph.hierarchy
</script>
-<script>/// <reference path="graph.ts" />
-/// <reference path="hierarchy.ts" />
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
var __extends = (this && this.__extends) || function (d, b) {
for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p];
function __() { this.constructor = d; }
__.prototype = b.prototype;
d.prototype = new __();
};
+/// <reference path="graph.ts" />
+/// <reference path="hierarchy.ts" />
/**
* Package for the Render Hierarchy for TensorFlow graph.
*/
@@ -1507,10 +1622,22 @@ var tf;
var render;
(function (render) {
/**
+ * Color parameters for op nodes.
+ */
+ render.OpNodeColors = {
+ DEFAULT_FILL: "white",
+ DEFAULT_STROKE: "#b2b2b2"
+ };
+ /**
* Color parameters for node encoding.
* @type {Object}
*/
render.MetanodeColors = {
+ /**
+ * Default fill and stroke to use when no other information is available.
+ */
+ DEFAULT_FILL: "#d9d9d9",
+ DEFAULT_STROKE: "#a6a6a6",
SATURATION: 0.6,
LIGHTNESS: 0.85,
/**
@@ -1540,11 +1667,18 @@ var tf;
GRADIENT_OUTLINE: "#888"
};
/**
+ * Color parameters for op nodes.
+ */
+ render.SeriesNodeColors = {
+ DEFAULT_FILL: "white",
+ DEFAULT_STROKE: "#b2b2b2"
+ };
+ /**
* Stores the rendering information, such as x and y coordinates,
* for each node in the graph.
*/
- var RenderGraphInformation = (function () {
- function RenderGraphInformation(hierarchy, params) {
+ var RenderGraphInfo = (function () {
+ function RenderGraphInfo(hierarchy, params) {
this.hierarchy = hierarchy;
this.index = {};
this.deviceColorMap = d3.scale.ordinal()
@@ -1573,15 +1707,67 @@ var tf;
this.computeTimeScale = d3.scale.linear()
.domain(computeTimeExtent)
.range(params.minMaxColors);
- // Maps node name to whether the rendering hierarchy was already constructed.
+ // Maps node name to whether the rendering hierarchy was already
+ // constructed.
this.hasSubhierarchy = {};
this.params = params;
- this.root = new RenderGroupNodeInformation(hierarchy.root);
+ this.root = new RenderGroupNodeInfo(hierarchy.root);
this.index[hierarchy.root.name] = this.root;
this.buildSubhierarchy(hierarchy.root.name);
this.root.expanded = true;
}
- RenderGraphInformation.prototype.getRenderNodeByName = function (nodeName) {
+ /**
+ * Get a previously created RenderNodeInfo by its node name.
+ */
+ RenderGraphInfo.prototype.getRenderNodeByName = function (nodeName) {
+ return this.index[nodeName];
+ };
+ /**
+ * Get a previously created RenderNodeInfo for the specified node name,
+ * or create one if it hasn't been created yet.
+ */
+ RenderGraphInfo.prototype.getOrCreateRenderNodeByName = function (nodeName) {
+ var _this = this;
+ // Polymer may invoke this with null.
+ if (!nodeName) {
+ return null;
+ }
+ if (nodeName in this.index) {
+ return this.index[nodeName];
+ }
+ var node = this.hierarchy.node(nodeName);
+ var renderInfo = node.isGroupNode ?
+ new RenderGroupNodeInfo(node) :
+ new RenderNodeInfo(node);
+ this.index[nodeName] = renderInfo;
+ if (node.stats) {
+ renderInfo.memoryColor = this.memoryUsageScale(node.stats.totalBytes);
+ renderInfo.computeTimeColor =
+ this.computeTimeScale(node.stats.totalMicros);
+ }
+ if (node.isGroupNode) {
+ // Make a list of tuples (device, proportion), where proportion
+ // is the fraction of op nodes that have that device.
+ var pairs = _.pairs(node.deviceHistogram);
+ if (pairs.length > 0) {
+ // Compute the total # of devices.
+ var numDevices = _.sum(pairs, _.last);
+ renderInfo.deviceColors = _.map(pairs, function (pair) { return ({
+ color: _this.deviceColorMap(pair[0]),
+ // Normalize to a proportion of total # of devices.
+ proportion: pair[1] / numDevices
+ }); });
+ }
+ }
+ else {
+ var device = renderInfo.node.device;
+ if (device) {
+ renderInfo.deviceColors = [{
+ color: this.deviceColorMap(device),
+ proportion: 1.0
+ }];
+ }
+ }
return this.index[nodeName];
};
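
Two patterns worth noting in the new `getOrCreateRenderNodeByName`: it memoizes render infos in `this.index`, and it normalizes a group node's device histogram into (color, proportion) pairs. A simplified sketch of the normalization (types are stand-ins; note the proportion is a fraction of the total op count, not of the number of distinct devices):

  interface DeviceColor { color: string; proportion: number; }

  function deviceColorsOf(histogram: {[device: string]: number},
                          colorOf: (device: string) => string): DeviceColor[] {
    let devices = Object.keys(histogram);
    // Total number of op nodes across all devices in this group.
    let total = devices.reduce((sum, d) => sum + histogram[d], 0);
    return devices.map(d => ({
      color: colorOf(d),
      proportion: histogram[d] / total,
    }));
  }
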
/**
@@ -1590,7 +1776,7 @@ var tf;
* (highlight) a node that isn't drawn yet, by selecting (highlighting)
* its nearest ancestor that has been drawn.
*/
- RenderGraphInformation.prototype.getNearestVisibleAncestor = function (name) {
+ RenderGraphInfo.prototype.getNearestVisibleAncestor = function (name) {
var path = graph_1.getHierarchicalPath(name);
for (var i = 0; i < path.length; i++) {
var nodeName = path[i];
@@ -1603,10 +1789,26 @@ var tf;
return name;
};
// TODO(jimbo): Delete this and any code it touches (all deprecated).
- RenderGraphInformation.prototype.setDepth = function (depth) {
+ RenderGraphInfo.prototype.setDepth = function (depth) {
setGroupNodeDepth(this.root, +depth);
};
- RenderGraphInformation.prototype.buildSubhierarchy = function (nodeName) {
+ /**
+ * Returns true if the renderNode is an isolated node within its parent node.
+ */
+ RenderGraphInfo.prototype.isNodeAuxilliary = function (renderNode) {
+ var parentNode = this.getRenderNodeByName(renderNode.node.parentNode.name);
+ var found = _.find(parentNode.isolatedInExtract, function (node) {
+ return node.node.name === renderNode.node.name;
+ });
+ if (found) {
+ return true;
+ }
+ found = _.find(parentNode.isolatedOutExtract, function (node) {
+ return node.node.name === renderNode.node.name;
+ });
+ return !!found;
+ };
+ RenderGraphInfo.prototype.buildSubhierarchy = function (nodeName) {
var _this = this;
// Terminate if the rendering hierarchy was already constructed
// for this node.
@@ -1628,58 +1830,26 @@ var tf;
// extracted. Also, due to extraction, the coreGraph may contain disjoint
// groups between which there is no visible path (other than annotations).
_.each(metagraph.nodes(), function (childName) {
- var childNode = metagraph.node(childName);
- var childRenderInfo = childNode.isGroupNode ?
- new RenderGroupNodeInformation(childNode) :
- new RenderNodeInformation(childNode);
- _this.index[childName] = childRenderInfo;
+ var childRenderInfo = _this.getOrCreateRenderNodeByName(childName);
+ var childNode = childRenderInfo.node;
coreGraph.setNode(childName, childRenderInfo);
- if (childRenderInfo.node.stats != null) {
- childRenderInfo.memoryColor =
- _this.memoryUsageScale(childRenderInfo.node.stats.totalBytes);
- childRenderInfo.computeTimeColor =
- _this.computeTimeScale(childRenderInfo.node.stats.totalMicros);
- }
if (!childNode.isGroupNode) {
_.each(childNode.inEmbeddings, function (embedding) {
- var renderMetaedgeInfo = new RenderMetaedgeInformation(null);
+ var renderMetaedgeInfo = new RenderMetaedgeInfo(null);
addInAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo, AnnotationType.CONSTANT, _this.params);
- _this.index[embedding.name] = new RenderNodeInformation(embedding);
+ _this.index[embedding.name] = new RenderNodeInfo(embedding);
});
_.each(childNode.outEmbeddings, function (embedding) {
- var renderMetaedgeInfo = new RenderMetaedgeInformation(null);
+ var renderMetaedgeInfo = new RenderMetaedgeInfo(null);
addOutAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo, AnnotationType.SUMMARY, _this.params);
- _this.index[embedding.name] = new RenderNodeInformation(embedding);
+ _this.index[embedding.name] = new RenderNodeInfo(embedding);
});
- var device = childRenderInfo.node.device;
- if (device != null) {
- childRenderInfo.deviceColors = [{
- color: _this.deviceColorMap(device),
- proportion: 1.0
- }];
- }
- }
- else {
- // Make a list of tuples (device, proportion), where proportion
- // is the fraction of op nodes that have that device.
- var pairs = _.pairs(childNode.deviceHistogram);
- if (pairs.length > 0) {
- // Compute the total # of devices.
- var numDevices = _.sum(pairs, _.last);
- childRenderInfo.deviceColors = _.map(pairs, function (pair) {
- return {
- color: _this.deviceColorMap(pair[0]),
- // Normalize to a proportion of total # of devices.
- proportion: pair[1] / numDevices
- };
- });
- }
}
});
// Add render metaedge info for edges in the metagraph.
_.each(metagraph.edges(), function (edgeObj) {
var metaedge = metagraph.edge(edgeObj);
- var renderMetaedgeInfo = new RenderMetaedgeInformation(metaedge);
+ var renderMetaedgeInfo = new RenderMetaedgeInfo(metaedge);
coreGraph.setEdge(edgeObj.v, edgeObj.w, renderMetaedgeInfo);
});
if (this.params.enableExtraction &&
@@ -1755,7 +1925,7 @@ var tf;
otherCounts.control[otherName] > _this.params.maxControlDegree;
var _b = inbound ?
[renderNodeInfo.inAnnotations, childRenderInfo.inAnnotations] :
- [renderNodeInfo.outAnnotations, childRenderInfo.outAnnotations], annotations = _b[0], childAnnotations = _b[1];
+ [renderNodeInfo.outAnnotations, childRenderInfo.outAnnotations], childAnnotations = _b[1];
var isOtherHighDegree = inbound ?
otherCounts.out[otherName] > _this.params.maxOutDegree :
otherCounts.in[otherName] > _this.params.maxInDegree;
@@ -1843,7 +2013,7 @@ var tf;
// If we can't make a bridge path for any reason, then we add an
// annotation instead.
if (!canDrawBridgePath) {
- childAnnotations.push(new Annotation(otherNode, otherRenderInfo, new RenderMetaedgeInformation(bridgeMetaedge), AnnotationType.SHORTCUT, inbound), _this.params);
+ childAnnotations.push(new Annotation(otherNode, otherRenderInfo, new RenderMetaedgeInfo(bridgeMetaedge), AnnotationType.SHORTCUT, inbound), _this.params);
return;
}
// At this point, all conditions have been met for drawing a bridge path.
@@ -1864,11 +2034,12 @@ var tf;
cardinality: 0,
parentNode: null,
stats: null,
+ include: graph_1.InclusionType.UNSPECIFIED,
// BridgeNode properties.
inbound: inbound,
};
bridgeContainerInfo =
- new RenderNodeInformation(bridgeContainerNode);
+ new RenderNodeInfo(bridgeContainerNode);
_this.index[bridgeContainerName] = bridgeContainerInfo;
coreGraph.setNode(bridgeContainerName, bridgeContainerInfo);
}
@@ -1881,10 +2052,11 @@ var tf;
cardinality: 1,
parentNode: null,
stats: null,
+ include: graph_1.InclusionType.UNSPECIFIED,
// BridgeNode properties.
inbound: inbound,
};
- bridgeNodeRenderInfo = new RenderNodeInformation(bridgeNode);
+ bridgeNodeRenderInfo = new RenderNodeInfo(bridgeNode);
_this.index[bridgeNodeName] = bridgeNodeRenderInfo;
coreGraph.setNode(bridgeNodeName, bridgeNodeRenderInfo);
// Set bridgeNode to be a graphlib child of the container node.
@@ -1892,7 +2064,7 @@ var tf;
bridgeContainerInfo.node.cardinality++;
}
// Create and add a bridge render metaedge.
- var bridgeRenderMetaedge = new RenderMetaedgeInformation(bridgeMetaedge);
+ var bridgeRenderMetaedge = new RenderMetaedgeInfo(bridgeMetaedge);
bridgeRenderMetaedge.adjoiningMetaedge = adjoiningMetaedge;
inbound ?
coreGraph.setEdge(bridgeNodeName, childName, bridgeRenderMetaedge) :
@@ -1993,10 +2165,11 @@ var tf;
cardinality: 1,
parentNode: null,
stats: null,
+ include: graph_1.InclusionType.UNSPECIFIED,
// BridgeNode properties.
inbound: inbound,
};
- structuralRenderInfo = new RenderNodeInformation(bridgeNode);
+ structuralRenderInfo = new RenderNodeInfo(bridgeNode);
structuralRenderInfo.structural = true;
_this.index[structuralNodeName] = structuralRenderInfo;
coreGraph.setNode(structuralNodeName, structuralRenderInfo);
@@ -2004,7 +2177,7 @@ var tf;
coreGraph.setParent(structuralNodeName, bridgeContainerName);
}
// Create the structural Metaedge and insert it.
- var structuralMetaedgeInfo = new RenderMetaedgeInformation(null);
+ var structuralMetaedgeInfo = new RenderMetaedgeInfo(null);
structuralMetaedgeInfo.structural = true;
structuralMetaedgeInfo.weight--; // Reduce weight for dagre layout.
inbound ?
@@ -2013,9 +2186,9 @@ var tf;
});
});
};
- return RenderGraphInformation;
+ return RenderGraphInfo;
})();
- render.RenderGraphInformation = RenderGraphInformation;
+ render.RenderGraphInfo = RenderGraphInfo;
/**
* A class for rendering annotation object which contains label
* about the node embedded as annotation, type of annotation and the location
@@ -2066,7 +2239,7 @@ var tf;
;
/**
* Manages a list of annotations. Two will be used for each
- * RenderNodeInformation, one for in annotations and one for out annotations.
+ * RenderNodeInfo, one for in annotations and one for out annotations.
*/
var AnnotationList = (function () {
function AnnotationList() {
@@ -2093,7 +2266,7 @@ var tf;
return;
}
var ellipsisNode = new tf.graph.EllipsisNodeImpl(1);
- this.list.push(new Annotation(ellipsisNode, new RenderNodeInformation(ellipsisNode), null, AnnotationType.ELLIPSIS, annotation.isIn));
+ this.list.push(new Annotation(ellipsisNode, new RenderNodeInfo(ellipsisNode), null, AnnotationType.ELLIPSIS, annotation.isIn));
};
return AnnotationList;
})();
@@ -2101,8 +2274,8 @@ var tf;
/**
* Contains rendering information about a node in the hierarchical graph.
*/
- var RenderNodeInformation = (function () {
- function RenderNodeInformation(node) {
+ var RenderNodeInfo = (function () {
+ function RenderNodeInfo(node) {
this.node = node;
this.expanded = false;
this.inAnnotations = new AnnotationList();
@@ -2127,31 +2300,30 @@ var tf;
this.paddingLeft = 0;
this.paddingRight = 0;
this.paddingBottom = 0;
- this.outerWidth = 0;
- this.outerHeight = 0;
this.isInExtract = false;
this.isOutExtract = false;
+ this.coreBox = { width: 0, height: 0 };
}
- RenderNodeInformation.prototype.isInCore = function () {
+ RenderNodeInfo.prototype.isInCore = function () {
return !this.isInExtract && !this.isOutExtract;
};
- return RenderNodeInformation;
+ return RenderNodeInfo;
})();
- render.RenderNodeInformation = RenderNodeInformation;
+ render.RenderNodeInfo = RenderNodeInfo;
/**
* Contains rendering information about a Metaedge from the underlying
* hierarchical graph. It may be from either a metagraph or a bridgegraph.
*/
- var RenderMetaedgeInformation = (function () {
- function RenderMetaedgeInformation(metaedge) {
+ var RenderMetaedgeInfo = (function () {
+ function RenderMetaedgeInfo(metaedge) {
this.metaedge = metaedge;
this.adjoiningMetaedge = null;
this.structural = false;
this.weight = 1;
}
- return RenderMetaedgeInformation;
+ return RenderMetaedgeInfo;
})();
- render.RenderMetaedgeInformation = RenderMetaedgeInformation;
+ render.RenderMetaedgeInfo = RenderMetaedgeInfo;
function addInAnnotation(node, predecessor, predecessorRenderInfo, edge, type, params) {
var annotation = new Annotation(predecessor, predecessorRenderInfo, edge, type, true);
node.inAnnotations.push(annotation, params);
@@ -2175,23 +2347,22 @@ var tf;
});
}
;
- var RenderGroupNodeInformation = (function (_super) {
- __extends(RenderGroupNodeInformation, _super);
- function RenderGroupNodeInformation(groupNode) {
+ var RenderGroupNodeInfo = (function (_super) {
+ __extends(RenderGroupNodeInfo, _super);
+ function RenderGroupNodeInfo(groupNode) {
_super.call(this, groupNode);
var metagraph = groupNode.metagraph;
var gl = metagraph.graph();
this.coreGraph =
graph_1.createGraph(gl.name, graph_1.GraphType.CORE, { compound: true });
- this.coreBox = { width: 0, height: 0 };
this.inExtractBox = { width: 0, height: 0 };
this.outExtractBox = { width: 0, height: 0 };
this.isolatedInExtract = [];
this.isolatedOutExtract = [];
}
- return RenderGroupNodeInformation;
- })(RenderNodeInformation);
- render.RenderGroupNodeInformation = RenderGroupNodeInformation;
+ return RenderGroupNodeInfo;
+ })(RenderNodeInfo);
+ render.RenderGroupNodeInfo = RenderGroupNodeInfo;
function setGroupNodeDepth(renderInfo, depth) {
if (renderInfo.coreGraph) {
setGraphDepth(renderInfo.coreGraph, depth);
@@ -2208,6 +2379,15 @@ var tf;
var src = graph.node(v);
var sink = graph.node(w);
var edge = graph.edge(v, w);
+ // If either node is explicitly included in the main graph and
+ // neither node is explicitly excluded, keep the real edge and
+ // do not create the shortcut.
+ if ((src.node.include === graph_1.InclusionType.INCLUDE ||
+ sink.node.include === graph_1.InclusionType.INCLUDE) &&
+ src.node.include !== graph_1.InclusionType.EXCLUDE &&
+ sink.node.include !== graph_1.InclusionType.EXCLUDE) {
+ return;
+ }
// Add each annotation.
addOutAnnotation(src, sink.node, sink, edge, AnnotationType.SHORTCUT, params);
addInAnnotation(sink, src.node, src, edge, AnnotationType.SHORTCUT, params);
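
A condensed sketch of the new guard in `createShortcut`: the real edge is kept when at least one endpoint is explicitly pinned to the main graph and neither endpoint is explicitly excluded. (The enum is restated here so the sketch stands alone.)

  enum InclusionType { INCLUDE, EXCLUDE, UNSPECIFIED }

  function keepRealEdge(src: InclusionType, sink: InclusionType): boolean {
    return (src === InclusionType.INCLUDE ||
        sink === InclusionType.INCLUDE) &&
        src !== InclusionType.EXCLUDE && sink !== InclusionType.EXCLUDE;
  }

  // keepRealEdge(INCLUDE, UNSPECIFIED) -> true  (shortcut suppressed)
  // keepRealEdge(INCLUDE, EXCLUDE)     -> false (shortcut created as usual)
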
@@ -2218,48 +2398,55 @@ var tf;
* Remove edges from a node, and set its isOutExtract property to true,
* and remove the node and move it to isolatedOutExtract.
*
- * If detachAllEdgesForHighDegree is true, extract all of its edges.
- * Otherwise, only extract all in-edges.
+ * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its
+ * edges. Otherwise, only extract all in-edges.
*/
- function makeOutExtract(renderNode, n, params) {
+ function makeOutExtract(renderNode, n, params, forceDetach) {
var graph = renderNode.coreGraph;
- graph.node(n).isOutExtract = true;
+ var child = graph.node(n);
+ child.isOutExtract = true;
_.each(graph.predecessors(n), function (p, index) {
createShortcut(graph, p, n, params);
});
- if (params.detachAllEdgesForHighDegree) {
+ if (params.detachAllEdgesForHighDegree || forceDetach) {
_.each(graph.successors(n), function (s, index) {
createShortcut(graph, n, s, params);
});
}
- if (params.detachAllEdgesForHighDegree || graph.neighbors(n).length === 0) {
- renderNode.isolatedOutExtract.push(graph.node(n));
+ // Remove the node from the core graph if it no longer has neighbors.
+ if (graph.neighbors(n).length === 0) {
+ child.node.include = graph_1.InclusionType.EXCLUDE;
+ renderNode.isolatedOutExtract.push(child);
graph.removeNode(n);
}
}
/**
* Remove edges from a node, set its isInExtract property to true,
* and remove the node and move it to isolatedInExtract.
- * If detachAllEdgesForHighDegree is true, extract all of its edges.
- * Otherwise, only remove all out-edges.
+ *
+ * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its
+ * edges. Otherwise, only remove all out-edges.
*/
- function makeInExtract(renderNode, n, params) {
+ function makeInExtract(renderNode, n, params, forceDetach) {
var graph = renderNode.coreGraph;
- graph.node(n).isInExtract = true;
+ var child = graph.node(n);
+ child.isInExtract = true;
_.each(graph.successors(n), function (s, index) {
createShortcut(graph, n, s, params);
});
- if (params.detachAllEdgesForHighDegree) {
+ if (params.detachAllEdgesForHighDegree || forceDetach) {
_.each(graph.predecessors(n), function (p, index) {
createShortcut(graph, p, n, params);
});
}
- // Remove the node from the core graph if conditions are met.
- if (params.detachAllEdgesForHighDegree || graph.neighbors(n).length === 0) {
- renderNode.isolatedInExtract.push(graph.node(n));
+ // Remove the node from the core graph if it no longer has neighbors.
+ if (graph.neighbors(n).length === 0) {
+ child.node.include = graph_1.InclusionType.EXCLUDE;
+ renderNode.isolatedInExtract.push(child);
graph.removeNode(n);
}
}
+ render.makeInExtract = makeInExtract;
/**
* Check whether the node's type is a member of the given list of types.
*
@@ -2286,11 +2473,30 @@ var tf;
}
return false;
}
+ /** Move nodes that are specified to be excluded out of the core graph. */
+ function extractSpecifiedNodes(renderNode, params) {
+ var graph = renderNode.coreGraph;
+ _.each(graph.nodes(), function (n) {
+ var renderInfo = graph.node(n);
+ if (renderInfo.node.include === graph_1.InclusionType.EXCLUDE) {
+ if (renderNode.coreGraph.outEdges(n).length >
+ renderNode.coreGraph.inEdges(n).length) {
+ makeOutExtract(renderNode, n, params, true);
+ }
+ else {
+ makeInExtract(renderNode, n, params, true);
+ }
+ }
+ });
+ }
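
A sketch of the side decision in `extractSpecifiedNodes`: an explicitly excluded node is moved to the out-extract column when it has more out-edges than in-edges, and to the in-extract column otherwise. The graph interface is a minimal stand-in:

  interface CoreGraphLike {
    outEdges(n: string): {}[];
    inEdges(n: string): {}[];
  }

  function extractSide(graph: CoreGraphLike, n: string): 'out' | 'in' {
    return graph.outEdges(n).length > graph.inEdges(n).length ?
        'out' : 'in';
  }
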
/** Remove edges from pre-defined out-extract patterns */
function extractPredefinedSink(renderNode, params) {
var graph = renderNode.coreGraph;
_.each(graph.nodes(), function (n) {
var renderInfo = graph.node(n);
+ if (renderInfo.node.include !== graph_1.InclusionType.UNSPECIFIED) {
+ return;
+ }
if (hasTypeIn(renderInfo.node, params.outExtractTypes)) {
makeOutExtract(renderNode, n, params);
}
@@ -2301,6 +2507,9 @@ var tf;
var graph = renderNode.coreGraph;
_.each(graph.nodes(), function (n) {
var renderInfo = graph.node(n);
+ if (renderInfo.node.include !== graph_1.InclusionType.UNSPECIFIED) {
+ return;
+ }
if (hasTypeIn(renderInfo.node, params.inExtractTypes)) {
makeInExtract(renderNode, n, params);
}
@@ -2312,6 +2521,9 @@ var tf;
var maxInDegree = params.maxInDegree;
// detect first so degrees don't get affected by other removal
var highInDegreeNames = _.filter(graph.nodes(), function (n) {
+ if (graph.node(n).node.include !== graph_1.InclusionType.UNSPECIFIED) {
+ return false;
+ }
// Count the in-degree based on only regular edges, unless there are
// no regular edges, in which case use the number of control edges.
// This is done so that control edges don't affect whether nodes are extracted
@@ -2335,6 +2547,9 @@ var tf;
var maxOutDegree = params.maxOutDegree;
// detect first so degrees don't get affected by other removal
var highOutDegreeNames = _.filter(graph.nodes(), function (n) {
+ if (graph.node(n).node.include !== graph_1.InclusionType.UNSPECIFIED) {
+ return false;
+ }
// Count the out-degree based on only regular edges, unless there are
// no regular edges, in which case use the number of control edges.
// This is done so that control edges don't affect whether nodes are extracted
@@ -2400,6 +2615,7 @@ var tf;
* <tf-graph-params>'s output
*/
function extractHighDegrees(renderNode, params) {
+ extractSpecifiedNodes(renderNode, params);
if (params.outExtractTypes) {
extractPredefinedSink(renderNode, params);
}
@@ -2434,6 +2650,9 @@ var tf;
_.each(graph.nodes(), function (n) {
var child = graph.node(n);
var degree = graph.neighbors(n).length;
+ if (child.node.include !== graph_1.InclusionType.UNSPECIFIED) {
+ return;
+ }
if (degree === 0) {
var hasOutAnnotations = child.outAnnotations.list.length > 0;
var hasInAnnotations = child.inAnnotations.list.length > 0;
@@ -2441,23 +2660,27 @@ var tf;
// This case only happens if detachAllEdgesForHighDegree is false.
// (Otherwise all source-like nodes are isolated already.)
renderNode.isolatedInExtract.push(child);
+ child.node.include = graph_1.InclusionType.EXCLUDE;
graph.removeNode(n);
}
else if (child.isOutExtract) {
// This case only happens if detachAllEdgesForHighDegree is false.
// (Otherwise all sink-like nodes are isolated already.)
renderNode.isolatedOutExtract.push(child);
+ child.node.include = graph_1.InclusionType.EXCLUDE;
graph.removeNode(n);
}
else if (params.extractIsolatedNodesWithAnnotationsOnOneSide) {
if (hasOutAnnotations && !hasInAnnotations) {
child.isInExtract = true; // for ones with high out-annotations
renderNode.isolatedInExtract.push(child);
+ child.node.include = graph_1.InclusionType.EXCLUDE;
graph.removeNode(n);
}
else if (hasInAnnotations && !hasOutAnnotations) {
child.isOutExtract = true; // for ones with high in-annotations
renderNode.isolatedOutExtract.push(child);
+ child.node.include = graph_1.InclusionType.EXCLUDE;
graph.removeNode(n);
}
else {
@@ -2470,7 +2693,21 @@ var tf;
})(graph = tf.graph || (tf.graph = {}));
})(tf || (tf = {})); // close module tf.graph.render
</script>
-<script>/// <reference path="graph.ts" />
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="graph.ts" />
/// <reference path="hierarchy.ts" />
var tf;
(function (tf) {
@@ -2568,6 +2805,7 @@ var tf;
function groupTemplateAndAssignId(nnGroups, verifyTemplate) {
// For each metanode, compare its subgraph (starting from shallower groups)
// and assign template id.
+ var result = {};
return _.reduce(nnGroups, function (templates, nnGroupPair) {
var signature = nnGroupPair[0], nnGroup = nnGroupPair[1].nodes, clusters = [];
nnGroup.forEach(function (metanode) {
@@ -2597,7 +2835,7 @@ var tf;
};
});
return templates;
- }, {});
+ }, result);
}
function sortNodes(names, graph, prefix) {
return _.sortByAll(names, function (name) {
@@ -2697,7 +2935,8 @@ var tf;
// compare metanode
var metanode1 = n1;
var metanode2 = n2;
- return metanode1.templateId && metanode2.templateId && metanode1.templateId === metanode2.templateId;
+ return metanode1.templateId && metanode2.templateId &&
+ metanode1.templateId === metanode2.templateId;
}
else if (n1.type === graph_1.NodeType.OP && n2.type === graph_1.NodeType.OP) {
// compare leaf node
@@ -2706,13 +2945,13 @@ var tf;
else if (n1.type === graph_1.NodeType.SERIES && n2.type === graph_1.NodeType.SERIES) {
// compare series node sizes and operations
// (only need to check one op as all op nodes are identical in series)
- var seriesnode1 = n1;
- var seriesnode2 = n2;
- var seriesnode1Count = seriesnode1.metagraph.nodeCount();
- return (seriesnode1Count === seriesnode2.metagraph.nodeCount() &&
+ var sn1 = n1;
+ var sn2 = n2;
+ var seriesnode1Count = sn1.metagraph.nodeCount();
+ return (seriesnode1Count === sn2.metagraph.nodeCount() &&
(seriesnode1Count === 0 ||
- (seriesnode1.metagraph.node(seriesnode1.metagraph.nodes()[0]).op ===
- seriesnode2.metagraph.node(seriesnode2.metagraph.nodes()[0]).op)));
+ (sn1.metagraph.node(sn1.metagraph.nodes()[0]).op ===
+ sn2.metagraph.node(sn2.metagraph.nodes()[0]).op)));
}
return false;
}
@@ -2720,7 +2959,21 @@ var tf;
})(graph = tf.graph || (tf.graph = {}));
})(tf || (tf = {}));
</script>
-<script>/// <reference path="../graph.ts" />
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="../graph.ts" />
/// <reference path="edge.ts" />
/// <reference path="node.ts" />
/// <reference path="../layout.ts" />
@@ -2824,8 +3077,8 @@ var tf;
* provided node.
*/
function panToNode(nodeName, svg, zoomG, d3zoom) {
- var node = d3.selectAll("[data-name='" + nodeName + "']."
- + scene.Class.Node.GROUP)[0][0];
+ var node = d3.select("[data-name='" + nodeName + "']."
+ + scene.Class.Node.GROUP).node();
if (!node) {
return false;
}
@@ -2993,8 +3246,7 @@ var tf;
position(sceneGroup, renderNode);
// Fade in the scene group if it didn't already exist.
if (isNewSceneGroup) {
- sceneGroup.attr("opacity", 0)
- .transition().attr("opacity", 1);
+ sceneGroup.attr("opacity", 0).transition().attr("opacity", 1);
}
return sceneGroup;
}
@@ -3017,17 +3269,17 @@ var tf;
// core
translate(selectChild(sceneGroup, "g", scene.Class.Scene.CORE), 0, yTranslate);
// in-extract
- var inExtractX = renderNode.coreBox.width === 0 ?
- 0 : renderNode.coreBox.width;
var hasInExtract = renderNode.isolatedInExtract.length > 0;
if (hasInExtract) {
+ var inExtractX = renderNode.coreBox.width -
+ renderNode.inExtractBox.width / 2 - renderNode.outExtractBox.width;
translate(selectChild(sceneGroup, "g", scene.Class.Scene.INEXTRACT), inExtractX, yTranslate);
}
// out-extract
var hasOutExtract = renderNode.isolatedOutExtract.length > 0;
if (hasOutExtract) {
- var outExtractX = inExtractX + renderNode.inExtractBox.width
- + renderNode.extractXOffset;
+ var outExtractX = renderNode.coreBox.width -
+ renderNode.outExtractBox.width / 2;
translate(selectChild(sceneGroup, "g", scene.Class.Scene.OUTEXTRACT), outExtractX, yTranslate);
}
}
@@ -3042,6 +3294,10 @@ var tf;
;
/** Helper for adding transform: translate(x0, y0) */
function translate(selection, x0, y0) {
+ // If it is already placed on the screen, make it a transition.
+ if (selection.attr("transform") != null) {
+ selection = selection.transition("position");
+ }
selection.attr("transform", "translate(" + x0 + "," + y0 + ")");
}
scene.translate = translate;
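// Editor's note: a minimal usage sketch, not part of this commit. Because
// translate() only creates a transition when a transform already exists, the
// first placement snaps and later placements animate. Assumes d3 v3.5+ (named
// transitions, as used above); the selector is hypothetical.
function demoTranslateTwice() {
    var g = d3.select("svg").append("g");
    translate(g, 10, 20); // no prior transform: snaps into place
    translate(g, 50, 80); // transform exists: animates via transition("position")
}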
@@ -3071,10 +3327,15 @@ var tf;
* the button on.
*/
function positionButton(button, renderNode) {
+ var cx = graph.layout.computeCXPositionOfNodeShape(renderNode);
// Position the button in the top-right corner of the group node,
// with space given to draw the button inside of the corner.
- var x = renderNode.x + renderNode.width / 2 - 6;
- var y = renderNode.y - renderNode.height / 2 + 6;
+ var width = renderNode.expanded ?
+ renderNode.width : renderNode.coreBox.width;
+ var height = renderNode.expanded ?
+ renderNode.height : renderNode.coreBox.height;
+ var x = cx + width / 2 - 6;
+ var y = renderNode.y - height / 2 + 6;
// For unexpanded series nodes, the button has special placement due
// to the unique visuals of this group node.
if (renderNode.node.type === graph.NodeType.SERIES && !renderNode.expanded) {
@@ -3113,10 +3374,25 @@ var tf;
})(graph = tf.graph || (tf.graph = {}));
})(tf || (tf = {})); // close module
</script>
-<script>/// <reference path="../graph.ts" />
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="../graph.ts" />
/// <reference path="../render.ts" />
/// <reference path="scene.ts" />
/// <reference path="edge.ts" />
+/// <reference path="contextmenu.ts" />
var tf;
(function (tf) {
var graph;
@@ -3124,7 +3400,7 @@ var tf;
var scene;
(function (scene) {
var annotation;
- (function (annotation) {
+ (function (annotation_1) {
/**
* Populate a given annotation container group
*
@@ -3191,7 +3467,7 @@ var tf;
var aGroup = d3.select(this);
update(aGroup, d, a, sceneBehavior);
if (a.annotationType !== tf.graph.render.AnnotationType.ELLIPSIS) {
- addInteraction(aGroup, d, sceneBehavior);
+ addInteraction(aGroup, d, a, sceneBehavior);
}
});
annotationGroups.exit()
@@ -3203,7 +3479,7 @@ var tf;
.remove();
return annotationGroups;
}
- annotation.buildGroup = buildGroup;
+ annotation_1.buildGroup = buildGroup;
;
/**
* Maps an annotation enum to a class name used in css rules.
@@ -3214,11 +3490,10 @@ var tf;
}
function buildShape(aGroup, a, sceneBehavior) {
if (a.annotationType === tf.graph.render.AnnotationType.SUMMARY) {
- var image = scene.selectOrCreateChild(aGroup, "image");
- image.attr({
- "xlink:href": sceneBehavior.resolveUrl("../../lib/svg/summary-icon.svg"),
- "height": "12px",
- "width": "12px",
+ var summary = scene.selectOrCreateChild(aGroup, "use");
+ summary.attr({
+ "class": "summary",
+ "xlink:href": "#summary-icon",
"cursor": "pointer"
});
}
@@ -3247,7 +3522,7 @@ var tf;
.text(label)
.append("title").text(titleText);
}
- function addInteraction(selection, d, sceneBehavior) {
+ function addInteraction(selection, d, annotation, sceneBehavior) {
selection
.on("mouseover", function (a) {
sceneBehavior.fire("annotation-highlight", {
@@ -3270,6 +3545,10 @@ var tf;
hostName: d.node.name
});
});
+ if (annotation.annotationType !== tf.graph.render.AnnotationType.SUMMARY &&
+ annotation.annotationType !== tf.graph.render.AnnotationType.CONSTANT) {
+ selection.on("contextmenu", tf.graph.scene.contextmenu.getMenu(tf.graph.scene.node.getContextMenu(annotation.node, sceneBehavior)));
+ }
}
;
/**
@@ -3281,6 +3560,7 @@ var tf;
* @param scene Polymer scene element.
*/
function update(aGroup, d, a, sceneBehavior) {
+ var cx = graph.layout.computeCXPositionOfNodeShape(d);
// Annotations that point to embedded nodes (constants, summary)
// don't have render information attached, so we don't stylize these.
// Also we don't stylize ellipsis annotations (the string "... and X more").
@@ -3294,7 +3574,7 @@ var tf;
}
// label position
aGroup.select("text." + scene.Class.Annotation.LABEL).transition().attr({
- x: d.x + a.dx + (a.isIn ? -1 : 1) * (a.width / 2 + a.labelOffset),
+ x: cx + a.dx + (a.isIn ? -1 : 1) * (a.width / 2 + a.labelOffset),
y: d.y + a.dy
});
// Some annotations (such as summary) are represented using a 12x12 icon.
@@ -3302,19 +3582,19 @@ var tf;
// If there is such an icon, we adjust its location to be vertically
// centered with the node and horizontally centered between the arrow and
// the text label.
- aGroup.select("image").transition().attr({
- x: d.x + a.dx - 3,
+ aGroup.select("use.summary").transition().attr({
+ x: cx + a.dx - 3,
y: d.y + a.dy - 6
});
// Node position (only one of the shape selection will be non-empty.)
- scene.positionEllipse(aGroup.select("." + scene.Class.Annotation.NODE + " ellipse"), d.x + a.dx, d.y + a.dy, a.width, a.height);
- scene.positionRect(aGroup.select("." + scene.Class.Annotation.NODE + " rect"), d.x + a.dx, d.y + a.dy, a.width, a.height);
- scene.positionRect(aGroup.select("." + scene.Class.Annotation.NODE + " use"), d.x + a.dx, d.y + a.dy, a.width, a.height);
+ scene.positionEllipse(aGroup.select("." + scene.Class.Annotation.NODE + " ellipse"), cx + a.dx, d.y + a.dy, a.width, a.height);
+ scene.positionRect(aGroup.select("." + scene.Class.Annotation.NODE + " rect"), cx + a.dx, d.y + a.dy, a.width, a.height);
+ scene.positionRect(aGroup.select("." + scene.Class.Annotation.NODE + " use"), cx + a.dx, d.y + a.dy, a.width, a.height);
// Edge position
aGroup.select("path." + scene.Class.Annotation.EDGE).transition().attr("d", function (a) {
// map relative position to absolute position
var points = a.points.map(function (p) {
- return { x: p.dx + d.x, y: p.dy + d.y };
+ return { x: p.dx + cx, y: p.dy + d.y };
});
return scene.edge.interpolate(points);
});
@@ -3325,7 +3605,21 @@ var tf;
})(graph = tf.graph || (tf.graph = {}));
})(tf || (tf = {})); // close module
</script>
-<script>/// <reference path="../graph.ts" />
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="../graph.ts" />
/// <reference path="../render.ts" />
/// <reference path="scene.ts" />
var tf;
@@ -3336,7 +3630,6 @@ var tf;
(function (scene) {
var edge;
(function (edge) {
- var Scene = tf.graph.scene; // Aliased
function getEdgeKey(edgeObj) {
return edgeObj.v + tf.graph.EDGE_KEY_DELIM + edgeObj.w;
}
@@ -3361,6 +3654,7 @@ var tf;
* @return selection of the created nodeGroups
*/
function buildGroup(sceneGroup, graph, sceneBehavior) {
+ var edges = [];
var edgeData = _.reduce(graph.edges(), function (edges, edgeObj) {
var edgeLabel = graph.edge(edgeObj);
edges.push({
@@ -3369,9 +3663,8 @@ var tf;
label: edgeLabel
});
return edges;
- }, []);
+ }, edges);
var container = scene.selectOrCreateChild(sceneGroup, "g", scene.Class.Edge.CONTAINER);
- var containerNode = container.node();
// Select all children and join with data.
// (Note that all children of g.edges are g.edge)
var edgeGroups = container.selectAll(function () {
@@ -3416,7 +3709,7 @@ var tf;
* For a given d3 selection and data object, create a path to represent the
* edge described in d.label.
*
- * If d.label is defined, it will be a RenderMetaedgeInformation instance. It
+ * If d.label is defined, it will be a RenderMetaedgeInfo instance. It
* will sometimes be undefined, for example for some Annotation edges for which
* there is no underlying Metaedge in the hierarchical graph.
*/
@@ -3430,6 +3723,10 @@ var tf;
}
edge.appendEdge = appendEdge;
;
+ edge.interpolate = d3.svg.line()
+ .interpolate("basis")
+ .x(function (d) { return d.x; })
+ .y(function (d) { return d.y; });
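// Editor's note: a usage sketch, not part of this commit. edge.interpolate
// above is a d3 v3 d3.svg.line generator; with "basis" interpolation it
// returns an SVG path string ("M...C..." cubic segments) that passes through
// the first and last control points. The points below are hypothetical.
function demoEdgePath() {
    var points = [{ x: 0, y: 0 }, { x: 50, y: 10 }, { x: 100, y: 0 }];
    return edge.interpolate(points); // e.g. assigned to a <path> "d" attribute
}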
/**
* Returns a tween interpolator for the endpoint of an edge path.
*/
@@ -3462,10 +3759,6 @@ var tf;
return dPath;
};
}
- edge.interpolate = d3.svg.line()
- .interpolate("basis")
- .x(function (d) { return d.x; })
- .y(function (d) { return d.y; });
function position(d) {
d3.select(this).select("path." + scene.Class.Edge.LINE)
.each(function (d) {
@@ -3478,10 +3771,9 @@ var tf;
* For a given d3 selection and data object, mark the edge as a control
* dependency if it contains only control edges.
*
- * d's label property will be a RenderMetaedgeInformation object.
+ * d's label property will be a RenderMetaedgeInfo object.
*/
function stylize(edgeGroup, d, stylize) {
- var a;
var metaedge = d.label.metaedge;
edgeGroup
.select("path." + scene.Class.Edge.LINE)
@@ -3493,9 +3785,24 @@ var tf;
})(graph = tf.graph || (tf.graph = {}));
})(tf || (tf = {})); // close module
</script>
-<script>/// <reference path="../graph.ts" />
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="../graph.ts" />
/// <reference path="scene.ts" />
/// <reference path="annotation.ts" />
+/// <reference path="contextmenu.ts" />
var tf;
(function (tf) {
var graph;
@@ -3693,6 +4000,7 @@ var tf;
selection.attr("pointer-events", "none");
return;
}
+ var contextMenuFunction = tf.graph.scene.contextmenu.getMenu(getContextMenu(d.node, sceneBehavior));
selection.on("dblclick", function (d) {
sceneBehavior.fire("node-toggle-expand", { name: d.node.name });
})
@@ -3717,10 +4025,28 @@ var tf;
// a graph-select.
d3.event.stopPropagation();
sceneBehavior.fire("node-select", { name: d.node.name });
+ })
+ .on("contextmenu", function (d, i) {
+ sceneBehavior.fire("node-select", { name: d.node.name });
+ contextMenuFunction.call(d, i);
});
}
;
/**
+ * Returns the d3 context menu specification for the provided node.
+ */
+ function getContextMenu(node, sceneBehavior) {
+ return [{
+ title: function (d) {
+ return tf.graph.getIncludeNodeButtonString(node.include);
+ },
+ action: function (elm, d, i) {
+ sceneBehavior.fire("node-toggle-extract", { name: node.name });
+ }
+ }];
+ }
+ node_1.getContextMenu = getContextMenu;
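// Editor's note: a sketch, not part of this commit, of extending the menu
// spec returned above. Each entry pairs a title(d) renderer with an
// action(elm, d, i) callback; the "Expand node" item is hypothetical, though
// the "node-toggle-expand" event it fires is the one used by dblclick above.
function getDemoMenu(node, sceneBehavior) {
    return getContextMenu(node, sceneBehavior).concat([{
        title: function (d) { return "Expand node"; },
        action: function (elm, d, i) {
            sceneBehavior.fire("node-toggle-expand", { name: node.name });
        }
    }]);
}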
+ /**
* Append svg text for label and assign data.
* @param nodeGroup
* @param renderNodeInfo The render node information for the label.
@@ -3764,10 +4090,10 @@ var tf;
/**
* Set label position of a given node group
*/
- function labelPosition(nodeGroup, d, yOffset) {
+ function labelPosition(nodeGroup, cx, cy, yOffset) {
scene.selectChild(nodeGroup, "text", scene.Class.Node.LABEL).transition()
- .attr("x", d.x)
- .attr("y", d.y + yOffset);
+ .attr("x", cx)
+ .attr("y", cy + yOffset);
}
;
/**
@@ -3775,7 +4101,7 @@ var tf;
* as the shape's data.
*
* @param nodeGroup
- * @param d RenderNodeInformation
+ * @param d Render node information.
* @param nodeClass class for the element.
* @param before Reference DOM node for insertion.
* @return Selection of the shape.
@@ -3837,38 +4163,41 @@ var tf;
/** Modify node and its subscene and its label's positional attributes */
function position(nodeGroup, d, sceneBehavior) {
var shapeGroup = scene.selectChild(nodeGroup, "g", scene.Class.Node.SHAPE);
+ var cx = graph.layout.computeCXPositionOfNodeShape(d);
switch (d.node.type) {
case graph.NodeType.OP: {
// position shape
var shape = scene.selectChild(shapeGroup, "ellipse");
- scene.positionEllipse(shape, d.x, d.y, d.width, d.height);
- labelPosition(nodeGroup, d, d.labelOffset);
+ scene.positionEllipse(shape, cx, d.y, d.coreBox.width, d.coreBox.height);
+ labelPosition(nodeGroup, cx, d.y, d.labelOffset);
break;
}
case graph.NodeType.META: {
// position shape
var shape = scene.selectChild(shapeGroup, "rect");
- scene.positionRect(shape, d.x, d.y, d.width, d.height);
if (d.expanded) {
+ scene.positionRect(shape, d.x, d.y, d.width, d.height);
subscenePosition(nodeGroup, d);
// put label on top
- labelPosition(nodeGroup, d, -d.height / 2 + d.labelHeight / 2);
+ labelPosition(nodeGroup, cx, d.y, -d.height / 2 + d.labelHeight / 2);
}
else {
- labelPosition(nodeGroup, d, 0);
+ scene.positionRect(shape, cx, d.y, d.coreBox.width, d.coreBox.height);
+ labelPosition(nodeGroup, cx, d.y, 0);
}
break;
}
case graph.NodeType.SERIES: {
var shape = scene.selectChild(shapeGroup, "use");
- scene.positionRect(shape, d.x, d.y, d.width, d.height);
if (d.expanded) {
+ scene.positionRect(shape, d.x, d.y, d.width, d.height);
subscenePosition(nodeGroup, d);
// put label on top
- labelPosition(nodeGroup, d, -d.height / 2 + d.labelHeight / 2);
+ labelPosition(nodeGroup, cx, d.y, -d.height / 2 + d.labelHeight / 2);
}
else {
- labelPosition(nodeGroup, d, d.labelOffset);
+ scene.positionRect(shape, cx, d.y, d.coreBox.width, d.coreBox.height);
+ labelPosition(nodeGroup, cx, d.y, d.labelOffset);
}
}
case graph.NodeType.BRIDGE: {
@@ -3886,29 +4215,33 @@ var tf;
}
;
/** Enum specifying the options to color nodes by */
- var ColorBy = {
- STRUCTURE: 0,
- DEVICE: 1,
- COMPUTE_TIME: 2,
- MEMORY: 3
- };
+ (function (ColorBy) {
+ ColorBy[ColorBy["STRUCTURE"] = 0] = "STRUCTURE";
+ ColorBy[ColorBy["DEVICE"] = 1] = "DEVICE";
+ ColorBy[ColorBy["COMPUTE_TIME"] = 2] = "COMPUTE_TIME";
+ ColorBy[ColorBy["MEMORY"] = 3] = "MEMORY";
+ })(node_1.ColorBy || (node_1.ColorBy = {}));
+ var ColorBy = node_1.ColorBy;
+ ;
/**
* Returns the fill color for the node given its state and the "color by"
* option.
*/
- function getFillForNode(sceneBehavior, colorBy, renderInfo, isExpanded) {
+ function getFillForNode(templateIndex, colorBy, renderInfo, isExpanded) {
var colorParams = tf.graph.render.MetanodeColors;
switch (colorBy) {
case ColorBy.STRUCTURE:
if (renderInfo.node.type === tf.graph.NodeType.META) {
var tid = renderInfo.node.templateId;
- return tid === null ? colorParams.UNKNOWN : colorParams.STRUCTURE_PALETTE(sceneBehavior.templateIndex(tid), renderInfo.expanded);
+ return tid === null ?
+ colorParams.UNKNOWN :
+ colorParams.STRUCTURE_PALETTE(templateIndex(tid), isExpanded);
}
else if (renderInfo.node.type === tf.graph.NodeType.SERIES) {
// If expanded, we're showing the background rect, which we want to
// appear gray. Otherwise we're showing a stack of ellipses which we
// want to show white.
- return renderInfo.expanded ? colorParams.EXPANDED_COLOR : "white";
+ return isExpanded ? colorParams.EXPANDED_COLOR : "white";
}
else if (renderInfo.node.type === graph.NodeType.BRIDGE) {
return renderInfo.structural ? "#f0e" :
@@ -3958,6 +4291,7 @@ var tf;
throw new Error("Unknown case to color nodes by");
}
}
+ node_1.getFillForNode = getFillForNode;
/**
* Modify node style by toggling class and assign attributes (only for things
* that can't be done in css).
@@ -3975,29 +4309,111 @@ var tf;
// Main node always exists here and it will be reached before subscene,
// so d3 selection is fine here.
var node = nodeGroup.select("." + nodeClass + " ." + scene.Class.Node.COLOR_TARGET);
- var fillColor = getFillForNode(sceneBehavior, ColorBy[sceneBehavior.colorBy.toUpperCase()], renderInfo, isExpanded);
+ var fillColor = getFillForNode(sceneBehavior.templateIndex, ColorBy[sceneBehavior.colorBy.toUpperCase()], renderInfo, isExpanded);
node.style("fill", fillColor);
// Choose outline to be darker version of node color if the node is a single
// color and is not selected.
- if (isSelected) {
- node.style("stroke", null);
- }
- else {
- // If node is colored by a gradient, then use a dark gray outline.
- var outlineColor = fillColor.substring(0, 3) === "url" ?
- tf.graph.render.MetanodeColors.GRADIENT_OUTLINE :
- d3.rgb(fillColor).darker().toString();
- node.style("stroke", outlineColor);
- }
+ node.style("stroke", isSelected ? null : getStrokeForFill(fillColor));
}
node_1.stylize = stylize;
;
+ /**
+ * Given a node's fill color/gradient, determine the stroke for the node.
+ */
+ function getStrokeForFill(fill) {
+ // If node is colored by a gradient, then use a dark gray outline.
+ return fill.substring(0, 3) === "url" ?
+ tf.graph.render.MetanodeColors.GRADIENT_OUTLINE :
+ d3.rgb(fill).darker().toString();
+ }
+ node_1.getStrokeForFill = getStrokeForFill;
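// Editor's note: a usage sketch, not part of this commit; the inputs are
// hypothetical. Plain colors get a darker stroke of the same hue via d3.rgb;
// any fill starting with "url" (an SVG gradient reference) gets the fixed
// GRADIENT_OUTLINE color instead.
function demoStrokes() {
    return [
        getStrokeForFill("#0f9d58"),    // darker shade of the same green
        getStrokeForFill("url(#grad1)") // MetanodeColors.GRADIENT_OUTLINE
    ];
}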
})(node = scene.node || (scene.node = {}));
})(scene = graph.scene || (graph.scene = {}));
})(graph = tf.graph || (tf.graph = {}));
})(tf || (tf = {})); // close module
</script>
-<script>/// <reference path="graph.ts" />
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+var tf;
+(function (tf) {
+ var graph;
+ (function (graph) {
+ var scene;
+ (function (scene) {
+ var contextmenu;
+ (function (contextmenu) {
+ /**
+ * Returns the event listener, which can be used as an argument for the d3
+ * selection.on function. Renders the context menu that is to be displayed
+ * in response to the event.
+ */
+ function getMenu(menu) {
+ var menuSelection = d3.select(".context-menu");
+ // Close the menu when anything else is clicked.
+ d3.select("body").on("click.context", function () {
+ menuSelection.style("display", "none");
+ });
+ // Function called to populate the context menu.
+ return function (data, index) {
+ var _this = this;
+ // Position and display the menu.
+ var event = d3.event;
+ menuSelection.style({
+ "display": "block",
+ "left": (event.layerX + 1) + "px",
+ "top": (event.layerY + 1) + "px"
+ });
+ // Stop the event from propagating further.
+ event.preventDefault();
+ event.stopPropagation();
+ // Add provided items to the context menu.
+ menuSelection.html("");
+ var list = menuSelection.append("ul");
+ list.selectAll("li").data(menu).enter()
+ .append("li")
+ .html(function (d) {
+ return d.title(data);
+ })
+ .on("click", function (d, i) {
+ d.action(_this, data, index);
+ menuSelection.style("display", "none");
+ });
+ };
+ }
+ contextmenu.getMenu = getMenu;
+ ;
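// Editor's note: a wiring sketch, not part of this commit. getMenu() expects
// a ".context-menu" element to exist in the page (see the d3.select above)
// and attaches like any other d3 listener; this mirrors the annotation and
// node bindings earlier in this file.
function demoAttachMenu(selection, node, sceneBehavior) {
    var spec = tf.graph.scene.node.getContextMenu(node, sceneBehavior);
    selection.on("contextmenu", getMenu(spec));
}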
+ })(contextmenu = scene.contextmenu || (scene.contextmenu = {}));
+ })(scene = graph.scene || (graph.scene = {}));
+ })(graph = tf.graph || (tf.graph = {}));
+})(tf || (tf = {})); // close module
+</script>
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="graph.ts" />
/// <reference path="render.ts" />
var tf;
(function (tf) {
@@ -4020,14 +4436,19 @@ var tf;
*
* See https://github.com/cpettitt/dagre/wiki#configuring-the-layout
*/
- nodeSep: 110,
+ nodeSep: 5,
/**
* Dagre's ranksep param - number of pixels
* between each rank in the layout.
*
* See https://github.com/cpettitt/dagre/wiki#configuring-the-layout
*/
- rankSep: 25
+ rankSep: 25,
+ /**
+ * Dagre's edgesep param - number of pixels that separate
+ * edges horizontally in the layout.
+ */
+ edgeSep: 5,
},
/** Graph parameter for metanode. */
series: {
@@ -4037,7 +4458,7 @@ var tf;
*
* See https://github.com/cpettitt/dagre/wiki#configuring-the-layout
*/
- nodeSep: 90,
+ nodeSep: 5,
/**
* Dagre's ranksep param - number of pixels
* between each rank in the layout.
@@ -4045,6 +4466,11 @@ var tf;
* See https://github.com/cpettitt/dagre/wiki#configuring-the-layout
*/
rankSep: 25,
+ /**
+ * Dagre's edgesep param - number of pixels that separate
+ * edges horizontally in the layout.
+ */
+ edgeSep: 5
},
/**
* Padding is used to correctly position the graph SVG inside of its parent
@@ -4153,6 +4579,10 @@ var tf;
}
},
annotations: {
+ /** Maximum possible width of the bounding box for in annotations */
+ inboxWidth: 50,
+ /** Maximum possible width of the bounding box for out annotations */
+ outboxWidth: 50,
/** X-space between the shape and each annotation-node. */
xOffset: 10,
/** Y-space between each annotation-node. */
@@ -4188,7 +4618,7 @@ var tf;
}
};
/** Calculate layout for a scene of a group node. */
- function scene(renderNodeInfo) {
+ function layoutScene(renderNodeInfo) {
// Update layout, size, and annotations of its children nodes and edges.
if (renderNodeInfo.node.isGroupNode) {
layoutChildren(renderNodeInfo);
@@ -4201,9 +4631,29 @@ var tf;
layoutSeriesNode(renderNodeInfo);
}
}
- layout.scene = scene;
+ layout.layoutScene = layoutScene;
;
/**
+ * Updates the total width of an unexpanded node which includes the size of its
+ * in and out annotations.
+ */
+ function updateTotalWidthOfNode(renderInfo) {
+ renderInfo.inboxWidth = renderInfo.inAnnotations.list.length > 0 ?
+ layout.PARAMS.annotations.inboxWidth : 0;
+ renderInfo.outboxWidth = renderInfo.outAnnotations.list.length > 0 ?
+ layout.PARAMS.annotations.outboxWidth : 0;
+ // Assign the width of the core box (the main shape of the node).
+ renderInfo.coreBox.width = renderInfo.width;
+ renderInfo.coreBox.height = renderInfo.height;
+ // TODO(jimbo): Account for font width rather than using a magic number.
+ var labelLength = renderInfo.node.name.length -
+ renderInfo.node.name.lastIndexOf(graph_1.NAMESPACE_DELIM) - 1;
+ var charWidth = 3; // 3 pixels per character.
+ // Compute the total width of the node.
+ renderInfo.width = Math.max(renderInfo.coreBox.width +
+ renderInfo.inboxWidth + renderInfo.outboxWidth, labelLength * charWidth);
+ }
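// Editor's note: a worked example, not part of this commit, using constants
// from this file (annotation box width 50, charWidth 3, op core width 15).
// For an unexpanded op node "foo/bar" with annotations on both sides:
//   inboxWidth = 50, outboxWidth = 50, label "bar" -> 3 chars * 3 px = 9
//   width = max(15 + 50 + 50, 9) = 115
function demoTotalWidth(coreWidth, hasIn, hasOut, shortName) {
    var inbox = hasIn ? 50 : 0;   // PARAMS.annotations.inboxWidth
    var outbox = hasOut ? 50 : 0; // PARAMS.annotations.outboxWidth
    return Math.max(coreWidth + inbox + outbox, shortName.length * 3);
}
// demoTotalWidth(15, true, true, "bar") === 115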
+ /**
* Update layout, size, and annotations of its children nodes and edges.
*/
function layoutChildren(renderNodeInfo) {
@@ -4221,21 +4671,21 @@ var tf;
break;
case graph_1.NodeType.META:
if (!childNodeInfo.expanded) {
- // set fixed width and scalable height based on cardinality
+ // Set fixed width and scalable height based on cardinality
_.extend(childNodeInfo, layout.PARAMS.nodeSize.meta);
childNodeInfo.height =
layout.PARAMS.nodeSize.meta.height(childNodeInfo.node.cardinality);
}
else {
var childGroupNodeInfo = childNodeInfo;
- scene(childGroupNodeInfo); // Recursively layout its subscene.
+ layoutScene(childGroupNodeInfo); // Recursively layout its subscene.
}
break;
case graph_1.NodeType.SERIES:
if (childNodeInfo.expanded) {
_.extend(childNodeInfo, layout.PARAMS.nodeSize.series.expanded);
var childGroupNodeInfo = childNodeInfo;
- scene(childGroupNodeInfo); // Recursively layout its subscene.
+ layoutScene(childGroupNodeInfo); // Recursively layout its subscene.
}
else {
var childGroupNodeInfo = childNodeInfo;
@@ -4248,6 +4698,11 @@ var tf;
default:
throw Error("Unrecognized node type: " + childNodeInfo.node.type);
}
+ // Compute total width of un-expanded nodes. Width of expanded nodes
+ // has already been computed.
+ if (!childNodeInfo.expanded) {
+ updateTotalWidthOfNode(childNodeInfo);
+ }
// Layout each child's annotations
layoutAnnotation(childNodeInfo);
});
@@ -4260,8 +4715,9 @@ var tf;
*/
function dagreLayout(graph, params) {
_.extend(graph.graph(), {
- nodeSep: params.nodeSep,
- rankSep: params.rankSep
+ nodesep: params.nodeSep,
+ ranksep: params.rankSep,
+ edgesep: params.edgeSep
});
var bridgeNodeNames = [];
var nonBridgeNodeNames = [];
@@ -4284,7 +4740,6 @@ var tf;
};
}
dagre.layout(graph);
- var graphLabel = graph.graph();
// Calculate the true bounding box of the graph by iterating over nodes and
// edges rather than accepting dagre's word for it. In particular, we should
// ignore the extra-wide bridge nodes and bridge edges, and allow for
@@ -4296,31 +4751,62 @@ var tf;
_.each(nonBridgeNodeNames, function (nodeName) {
var nodeInfo = graph.node(nodeName);
var w = 0.5 * nodeInfo.width;
- var x1 = nodeInfo.x - w - nodeInfo.inboxWidth;
- var x2 = nodeInfo.x + w + nodeInfo.outboxWidth;
+ var x1 = nodeInfo.x - w;
+ var x2 = nodeInfo.x + w;
minX = x1 < minX ? x1 : minX;
maxX = x2 > maxX ? x2 : maxX;
- var labelLength = nodeName.length - nodeName.lastIndexOf(graph_1.NAMESPACE_DELIM);
- // TODO(jimbo): Account for font width rather than using a magic number.
- var charWidth = 3; // 3 pixels per character.
- var lw = 0.5 * labelLength * charWidth;
- var lx1 = nodeInfo.x - lw;
- var lx2 = nodeInfo.x + lw;
- minX = lx1 < minX ? lx1 : minX;
- maxX = lx2 > maxX ? lx2 : maxX;
// TODO(jimbo): Account for the height of labels above op nodes here.
- var h = 0.5 * nodeInfo.outerHeight;
+ var h = 0.5 * nodeInfo.height;
var y1 = nodeInfo.y - h;
var y2 = nodeInfo.y + h;
minY = y1 < minY ? y1 : minY;
maxY = y2 > maxY ? y2 : maxY;
});
_.each(graph.edges(), function (edgeObj) {
- var renderMetaedgeInfo = graph.edge(edgeObj);
- if (renderMetaedgeInfo.structural) {
+ var edgeInfo = graph.edge(edgeObj);
+ if (edgeInfo.structural) {
return; // Skip structural edges from min/max calculations.
}
- _.each(renderMetaedgeInfo.points, function (point) {
+ // Since the node size passed to dagre includes the in and out
+ // annotations, the endpoints of the edge produced by dagre may not
+ // point to the actual node shape (rectangle, ellipse). We correct the
+ // end-points by finding the intersection of a line between the
+ // next-to-last (next-to-first) point and the destination (source)
+ // rectangle.
+ var sourceNode = graph.node(edgeInfo.metaedge.v);
+ var destNode = graph.node(edgeInfo.metaedge.w);
+ // Straight 3-point edges are a special case, since they would become
+ // curved after our default correction. To keep them straight, we remove
+ // the mid point and correct the first and the last point to be the
+ // center of the source and destination node respectively.
+ if (edgeInfo.points.length === 3 && isStraightLine(edgeInfo.points)) {
+ if (sourceNode != null) {
+ var cxSource = sourceNode.expanded ?
+ sourceNode.x : computeCXPositionOfNodeShape(sourceNode);
+ edgeInfo.points[0].x = cxSource;
+ }
+ if (destNode != null) {
+ var cxDest = destNode.expanded ?
+ destNode.x : computeCXPositionOfNodeShape(destNode);
+ edgeInfo.points[2].x = cxDest;
+ }
+ // Remove the middle point so the edge doesn't curve.
+ edgeInfo.points = [edgeInfo.points[0], edgeInfo.points[1]];
+ }
+ // Correct the destination endpoint of the edge.
+ var nextToLastPoint = edgeInfo.points[edgeInfo.points.length - 2];
+ // The destination node might be null if this is a bridge edge.
+ if (destNode != null) {
+ edgeInfo.points[edgeInfo.points.length - 1] =
+ intersectPointAndNode(nextToLastPoint, destNode);
+ }
+ // Correct the source endpoint of the edge.
+ var secondPoint = edgeInfo.points[1];
+ // The source might be null if this is a bridge edge.
+ if (sourceNode != null) {
+ edgeInfo.points[0] = intersectPointAndNode(secondPoint, sourceNode);
+ }
+ _.each(edgeInfo.points, function (point) {
minX = point.x < minX ? point.x : minX;
maxX = point.x > maxX ? point.x : maxX;
minY = point.y < minY ? point.y : minY;
@@ -4342,59 +4828,59 @@ var tf;
});
return {
width: maxX - minX,
- height: maxY - minY,
+ height: maxY - minY
};
}
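// Editor's note: not part of this commit. The key rename above (nodeSep ->
// nodesep, etc.) matters because dagre reads lower-case option names from
// graph.graph(); the camelCase spellings were presumably being ignored. A
// minimal sketch, assuming the dagre and graphlib globals this page loads:
function demoDagreOptions() {
    var g = new graphlib.Graph();
    g.setGraph({ nodesep: 5, ranksep: 25, edgesep: 5 });
    g.setDefaultEdgeLabel(function () { return {}; });
    g.setNode("a", { width: 60, height: 30 });
    g.setNode("b", { width: 60, height: 30 });
    g.setEdge("a", "b");
    dagre.layout(g); // fills in node x/y and edge points in place
    return g.node("b");
}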
- /** Layout a metanode. */
+ /** Layout a metanode. Only called for an expanded node. */
function layoutMetanode(renderNodeInfo) {
// First, copy params specific to meta nodes onto this render info object.
var params = layout.PARAMS.subscene.meta;
- renderNodeInfo = _.extend(renderNodeInfo, params);
+ _.extend(renderNodeInfo, params);
// Invoke dagre.layout() on the core graph and record the bounding box
// dimensions.
_.extend(renderNodeInfo.coreBox, dagreLayout(renderNodeInfo.coreGraph, layout.PARAMS.graph.meta));
// Calculate the position of nodes in isolatedInExtract relative to the
// top-left corner of inExtractBox (the bounding box for all inExtract nodes)
// and calculate the size of the inExtractBox.
- var hasInExtract = renderNodeInfo.isolatedInExtract.length > 0;
- renderNodeInfo.inExtractBox.width = hasInExtract ?
- _(renderNodeInfo.isolatedInExtract).pluck("outerWidth").max() : 0;
+ var maxInExtractWidth = _.max(renderNodeInfo.isolatedInExtract, function (renderNode) { return renderNode.width; }).width;
+ renderNodeInfo.inExtractBox.width = maxInExtractWidth != null ?
+ maxInExtractWidth : 0;
renderNodeInfo.inExtractBox.height =
_.reduce(renderNodeInfo.isolatedInExtract, function (height, child, i) {
var yOffset = i > 0 ? params.extractYOffset : 0;
- // use outerWidth/Height here to avoid overlaps between extracts
- child.x = renderNodeInfo.inExtractBox.width / 2;
- child.y = height + yOffset + child.outerHeight / 2;
- return height + yOffset + child.outerHeight;
+ // use width/height here to avoid overlaps between extracts
+ child.x = 0;
+ child.y = height + yOffset + child.height / 2;
+ return height + yOffset + child.height;
}, 0);
// Calculate the position of nodes in isolatedOutExtract relative to the
// top-left corner of outExtractBox (the bounding box for all outExtract
// nodes) and calculate the size of the outExtractBox.
- var hasOutExtract = renderNodeInfo.isolatedOutExtract.length > 0;
- renderNodeInfo.outExtractBox.width = hasOutExtract ?
- _(renderNodeInfo.isolatedOutExtract).pluck("outerWidth").max() : 0;
+ var maxOutExtractWidth = _.max(renderNodeInfo.isolatedOutExtract, function (renderNode) { return renderNode.width; }).width;
+ renderNodeInfo.outExtractBox.width = maxOutExtractWidth != null ?
+ maxOutExtractWidth : 0;
renderNodeInfo.outExtractBox.height =
_.reduce(renderNodeInfo.isolatedOutExtract, function (height, child, i) {
var yOffset = i > 0 ? params.extractYOffset : 0;
- // use outerWidth/Height here to avoid overlaps between extracts
- child.x = renderNodeInfo.outExtractBox.width / 2;
- child.y = height + yOffset + child.outerHeight / 2;
- return height + yOffset + child.outerHeight;
+ // use width/height here to avoid overlaps between extracts
+ child.x = 0;
+ child.y = height + yOffset + child.height / 2;
+ return height + yOffset + child.height;
}, 0);
+ // Add the in-extract and out-extract width to the core box width.
+ renderNodeInfo.coreBox.width += renderNodeInfo.inExtractBox.width +
+ renderNodeInfo.outExtractBox.width;
+ renderNodeInfo.coreBox.height =
+ params.labelHeight +
+ Math.max(renderNodeInfo.inExtractBox.height, renderNodeInfo.coreBox.height, renderNodeInfo.outExtractBox.height);
// Determine the whole metanode's width (from left to right).
- renderNodeInfo.width =
- params.paddingLeft + renderNodeInfo.coreBox.width + params.paddingRight +
- (hasInExtract ?
- renderNodeInfo.inExtractBox.width + params.extractXOffset : 0) +
- (hasOutExtract ?
- params.extractXOffset + renderNodeInfo.outExtractBox.width : 0);
- // TODO(jimbo): Remove labelHeight and instead incorporate into box sizes.
+ renderNodeInfo.width = renderNodeInfo.coreBox.width +
+ params.paddingLeft + params.paddingRight;
// Determine the whole metanode's height (from top to bottom).
renderNodeInfo.height =
- renderNodeInfo.labelHeight +
- params.paddingTop +
- Math.max(renderNodeInfo.inExtractBox.height, renderNodeInfo.coreBox.height, renderNodeInfo.outExtractBox.height) +
- params.paddingBottom;
+ renderNodeInfo.paddingTop +
+ renderNodeInfo.coreBox.height +
+ renderNodeInfo.paddingBottom;
}
/**
* Calculate layout for series node's core graph. Only called for an expanded
@@ -4415,7 +4901,7 @@ var tf;
}
/**
* Calculate layout for annotations of a given node.
- * This will modify positions of the the given node and its annotations.
+ * This will modify positions of the given node and its annotations.
*
* @see tf.graph.render.Node and tf.graph.render.Annotation
* for description of each property of each render node.
@@ -4425,14 +4911,6 @@ var tf;
// If the render node is an expanded metanode, then its annotations will not
// be visible and we should skip the annotation calculations.
if (renderNodeInfo.expanded) {
- _.extend(renderNodeInfo, {
- inboxWidth: 0,
- inboxHeight: 0,
- outboxWidth: 0,
- outboxHeight: 0,
- outerWidth: renderNodeInfo.width,
- outerHeight: renderNodeInfo.height
- });
return;
}
var inAnnotations = renderNodeInfo.inAnnotations.list;
@@ -4442,23 +4920,13 @@ var tf;
// Calculate size for out-annotations
_.each(outAnnotations, function (a) { return sizeAnnotation(a); });
var params = layout.PARAMS.annotations;
- renderNodeInfo.inboxWidth =
- inAnnotations.length > 0 ?
- _(inAnnotations).pluck("width").max() +
- params.xOffset + params.labelWidth + params.labelOffset :
- 0;
- renderNodeInfo.outboxWidth =
- outAnnotations.length > 0 ?
- _(outAnnotations).pluck("width").max() +
- params.xOffset + params.labelWidth + params.labelOffset :
- 0;
// Calculate annotation node position (a.dx, a.dy)
// and total height for in-annotations
// After this chunk of code:
// inboxHeight = sum of annotation heights + (annotations.length - 1) * yOffset
var inboxHeight = _.reduce(inAnnotations, function (height, a, i) {
var yOffset = i > 0 ? params.yOffset : 0;
- a.dx = -(renderNodeInfo.width + a.width) / 2 - params.xOffset;
+ a.dx = -(renderNodeInfo.coreBox.width + a.width) / 2 - params.xOffset;
a.dy = height + yOffset + a.height / 2;
return height + yOffset + a.height;
}, 0);
@@ -4473,7 +4941,7 @@ var tf;
// (annotation.length - 1 * yOffset)
var outboxHeight = _.reduce(outAnnotations, function (height, a, i) {
var yOffset = i > 0 ? params.yOffset : 0;
- a.dx = (renderNodeInfo.width + a.width) / 2 + params.xOffset;
+ a.dx = (renderNodeInfo.coreBox.width + a.width) / 2 + params.xOffset;
a.dy = height + yOffset + a.height / 2;
return height + yOffset + a.height;
}, 0);
@@ -4500,7 +4968,7 @@ var tf;
},
// The host node end
{
- dx: -renderNodeInfo.width / 2,
+ dx: -renderNodeInfo.coreBox.width / 2,
// only use scale if there are more than one,
// otherwise center it vertically
dy: inAnnotations.length > 1 ? inY(i) : 0
@@ -4519,7 +4987,7 @@ var tf;
a.points = [
// The host node end
{
- dx: renderNodeInfo.width / 2,
+ dx: renderNodeInfo.coreBox.width / 2,
// only use scale if there are more than one,
// otherwise center it vertically
dy: outAnnotations.length > 1 ? outY(i) : 0
@@ -4531,9 +4999,7 @@ var tf;
}
];
});
- renderNodeInfo.outerWidth = renderNodeInfo.width + renderNodeInfo.inboxWidth +
- renderNodeInfo.outboxWidth;
- renderNodeInfo.outerHeight =
+ renderNodeInfo.height =
Math.max(renderNodeInfo.height, inboxHeight, outboxHeight);
}
/**
@@ -4563,11 +5029,92 @@ var tf;
break;
}
}
+ /**
+ * Determines the center position of the node's shape. The position depends
+ * on whether the node has in- and out-annotations.
+ */
+ function computeCXPositionOfNodeShape(renderInfo) {
+ if (renderInfo.expanded) {
+ return renderInfo.x;
+ }
+ var dx = renderInfo.inAnnotations.list.length ? renderInfo.inboxWidth : 0;
+ return renderInfo.x - renderInfo.width / 2 + dx +
+ renderInfo.coreBox.width / 2;
+ }
+ layout.computeCXPositionOfNodeShape = computeCXPositionOfNodeShape;
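// Editor's note: a worked example, not part of this commit. For an
// unexpanded node at x = 100 with total width 115 (core 15 + inbox 50 +
// outbox 50):
//   cx = 100 - 115/2 + 50 + 15/2 = 100 (shape centered between the boxes)
// With only out-annotations (total width 65, dx = 0):
//   cx = 100 - 65/2 + 0 + 15/2 = 75 (shape shifts left of x)
function demoCX(x, width, coreWidth, inboxWidth, hasInAnnotations) {
    var dx = hasInAnnotations ? inboxWidth : 0;
    return x - width / 2 + dx + coreWidth / 2;
}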
+ /** Returns the angle (in degrees) between two points. */
+ function angleBetweenTwoPoints(a, b) {
+ var dx = b.x - a.x;
+ var dy = b.y - a.y;
+ return 180 * Math.atan(dy / dx) / Math.PI;
+ }
+ /**
+ * Returns whether a line going through the specified points is straight.
+ */
+ function isStraightLine(points) {
+ var angle = angleBetweenTwoPoints(points[0], points[1]);
+ for (var i = 1; i < points.length - 1; i++) {
+ var newAngle = angleBetweenTwoPoints(points[i], points[i + 1]);
+ // Have a tolerance of 1 degree.
+ if (Math.abs(newAngle - angle) > 1) {
+ return false;
+ }
+ angle = newAngle;
+ }
+ return true;
+ }
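// Editor's note: a quick check of the collinearity test above, not part of
// this commit. Successive segment angles are compared with a 1-degree
// tolerance:
function demoStraightness() {
    return [
        isStraightLine([{ x: 0, y: 0 }, { x: 1, y: 1 }, { x: 2, y: 2 }]),  // true: both segments at 45 degrees
        isStraightLine([{ x: 0, y: 0 }, { x: 1, y: 1 }, { x: 2, y: 1.5 }]) // false: 45 vs ~27 degrees
    ];
}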
+ /**
+ * Returns the intersection of a line between the provided point
+ * and the provided rectangle.
+ */
+ function intersectPointAndNode(point, node) {
+ // cx and cy are the center of the rectangle.
+ var cx = node.expanded ?
+ node.x : computeCXPositionOfNodeShape(node);
+ var cy = node.y;
+ // Calculate the slope
+ var dx = point.x - cx;
+ var dy = point.y - cy;
+ var w = node.expanded ? node.width : node.coreBox.width;
+ var h = node.expanded ? node.height : node.coreBox.height;
+ var deltaX, deltaY;
+ if (Math.abs(dy) * w / 2 > Math.abs(dx) * h / 2) {
+ // The intersection is above or below the rectangle.
+ if (dy < 0) {
+ h = -h;
+ }
+ deltaX = dy === 0 ? 0 : h / 2 * dx / dy;
+ deltaY = h / 2;
+ }
+ else {
+ // The intersection is left or right of the rectangle.
+ if (dx < 0) {
+ w = -w;
+ }
+ deltaX = w / 2;
+ deltaY = dx === 0 ? 0 : w / 2 * dy / dx;
+ }
+ return { x: cx + deltaX, y: cy + deltaY };
+ }
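// Editor's note: a worked example, not part of this commit. For an expanded
// node centered at (0, 0) with width 20 and height 10:
function demoIntersect() {
    var node = { expanded: true, x: 0, y: 0, width: 20, height: 10 };
    return [
        intersectPointAndNode({ x: 30, y: 0 }, node), // {x: 10, y: 0}: midpoint of the right edge
        intersectPointAndNode({ x: 0, y: -40 }, node) // {x: 0, y: -5}: midpoint of the top edge
    ];
}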
})(layout = graph_1.layout || (graph_1.layout = {}));
})(graph = tf.graph || (tf.graph = {}));
})(tf || (tf = {})); // close module
</script>
-<script>var tf;
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+var tf;
(function (tf) {
/**
* Mapping from color palette name to color palette, which contains
@@ -4699,7 +5246,21 @@ var tf;
}, {});
})(tf || (tf = {}));
</script>
-<script>/// <reference path="../../../../typings/tsd.d.ts" />
+<script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="../../../../typings/tsd.d.ts" />
/// <reference path="../common.ts" />
var tf;
(function (tf) {
@@ -4759,6 +5320,9 @@ var tf;
this.canvas = $minimap.select("canvas.first").node();
this.canvasBuffer =
$minimap.select("canvas.second").node();
+ this.downloadCanvas =
+ $minimap.select("canvas.download").node();
+ d3.select(this.downloadCanvas).style("display", "none");
}
/**
* Updates the position and the size of the viewpoint rectangle.
@@ -4782,6 +5346,11 @@ var tf;
*/
Minimap.prototype.update = function () {
var _this = this;
+ var $download = d3.select("#graphdownload");
+ this.download = $download.node();
+ $download.on("click", function (d) {
+ _this.download.href = _this.downloadCanvas.toDataURL("image/png");
+ });
var $svg = d3.select(this.svg);
// Read all the style rules in the document and embed them into the svg.
// The svg needs to be self contained, i.e. all the style rules need to be
@@ -4815,7 +5384,8 @@ var tf;
// Get the size of the entire scene.
var sceneSize = this.zoomG.getBBox();
// Since we add padding, account for that here.
- sceneSize.height += this.labelPadding;
+ sceneSize.height += this.labelPadding * 2;
+ sceneSize.width += this.labelPadding * 2;
// Temporarily assign an explicit width/height to the main svg, since
// it doesn't have one (uses flex-box), but we need it for the canvas
// to work.
@@ -4837,6 +5407,10 @@ var tf;
// viewpoint rect.
d3.select(this.minimapSvg).attr(this.minimapSize);
d3.select(this.canvasBuffer).attr(this.minimapSize);
+ // Download canvas width and height are multiples of the style width and
+ // height in order to increase pixel density of the PNG for clarity.
+ d3.select(this.downloadCanvas).style({ width: sceneSize.width, height: sceneSize.height });
+ d3.select(this.downloadCanvas).attr({ width: sceneSize.width * 3, height: sceneSize.height * 3 });
if (this.translate != null && this.zoom != null) {
// Update the viewpoint rectangle shape since the aspect ratio of the
// map has changed.
@@ -4868,6 +5442,9 @@ var tf;
_a = [_this.canvasBuffer, _this.canvas], _this.canvas = _a[0], _this.canvasBuffer = _a[1];
var _a;
});
+ var downloadContext = _this.downloadCanvas.getContext("2d");
+ downloadContext.clearRect(0, 0, _this.downloadCanvas.width, _this.downloadCanvas.height);
+ downloadContext.drawImage(image, 0, 0, _this.downloadCanvas.width, _this.downloadCanvas.height);
};
image.src = "data:image/svg+xml;base64," + btoa(svgXml);
};
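// Editor's note: a standalone sketch, not part of this commit, of the
// download path wired up above: rasterize into an offscreen canvas at 3x the
// styled size for pixel density, then hand the PNG data URL to the download
// anchor. Uses only standard DOM/canvas APIs.
function demoCanvasDownload(image, styleWidth, styleHeight) {
    var canvas = document.createElement("canvas");
    canvas.width = styleWidth * 3;
    canvas.height = styleHeight * 3;
    // drawImage scales the rasterized SVG up to the full canvas size.
    canvas.getContext("2d").drawImage(image, 0, 0, canvas.width, canvas.height);
    document.getElementById("graphdownload").href =
        canvas.toDataURL("image/png");
}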
@@ -4922,576 +5499,6 @@ var tf;
})(scene = tf.scene || (tf.scene = {}));
})(tf || (tf = {})); // close module tf.scene
</script>
-<script>/// <reference path="graph.ts" />
-/// <reference path="render.ts" />
-var tf;
-(function (tf) {
- var graph;
- (function (graph_1) {
- var layout;
- (function (layout) {
- /** Set of parameters that define the look and feel of the graph. */
- layout.PARAMS = {
- animation: {
- /** Default duration for graph animations in ms. */
- duration: 250
- },
- graph: {
- /** Graph parameter for metanode. */
- meta: {
- /**
- * Dagre's nodesep param - number of pixels that
- * separate nodes horizontally in the layout.
- *
- * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout
- */
- nodeSep: 110,
- /**
- * Dagre's ranksep param - number of pixels
- * between each rank in the layout.
- *
- * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout
- */
- rankSep: 25
- },
- /** Graph parameter for metanode. */
- series: {
- /**
- * Dagre's nodesep param - number of pixels that
- * separate nodes horizontally in the layout.
- *
- * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout
- */
- nodeSep: 90,
- /**
- * Dagre's ranksep param - number of pixels
- * between each rank in the layout.
- *
- * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout
- */
- rankSep: 25,
- },
- /**
- * Padding is used to correctly position the graph SVG inside of its parent
- * element. The padding amounts are applied using an SVG transform of X and
- * Y coordinates.
- */
- padding: {
- paddingTop: 40,
- paddingLeft: 20
- }
- },
- subscene: {
- meta: {
- paddingTop: 10,
- paddingBottom: 10,
- paddingLeft: 10,
- paddingRight: 10,
- /**
- * Used to leave room for the label on top of the highest node in
- * the core graph.
- */
- labelHeight: 20,
- /** X-space between each extracted node and the core graph. */
- extractXOffset: 50,
- /** Y-space between each extracted node. */
- extractYOffset: 20
- },
- series: {
- paddingTop: 10,
- paddingBottom: 10,
- paddingLeft: 10,
- paddingRight: 10,
- labelHeight: 10
- }
- },
- nodeSize: {
- /** Size of meta nodes. */
- meta: {
- radius: 5,
- width: 60,
- /** A scale for the node's height based on number of nodes inside */
- height: d3.scale.linear().domain([1, 200]).range([15, 60]).clamp(true),
- /** The radius of the circle denoting the expand button. */
- expandButtonRadius: 3
- },
- /** Size of op nodes. */
- op: {
- width: 15,
- height: 6,
- radius: 3,
- labelOffset: -8
- },
- /** Size of series nodes. */
- series: {
- expanded: {
- // For expanded series nodes, width and height will be
- // computed to account for the subscene.
- radius: 10,
- labelOffset: 0,
- },
- vertical: {
- // When unexpanded, series whose underlying metagraphs contain
- // one or more non-control edges will show as a vertical stack
- // of ellipses.
- width: 16,
- height: 13,
- labelOffset: -13,
- },
- horizontal: {
- // When unexpanded, series whose underlying metagraphs contain
- // no non-control edges will show as a horizontal stack of
- // ellipses.
- width: 24,
- height: 8,
- radius: 10,
- labelOffset: -10,
- },
- },
- /** Size of bridge nodes. */
- bridge: {
- // NOTE: bridge nodes will normally be invisible, but they must
- // take up some space so that the layout step leaves room for
- // their edges.
- width: 20,
- height: 20,
- radius: 2,
- labelOffset: 0
- }
- },
- shortcutSize: {
- /** Size of shortcuts for op nodes */
- op: {
- width: 10,
- height: 4
- },
- /** Size of shortcuts for meta nodes */
- meta: {
- width: 12,
- height: 4,
- radius: 1
- },
- /** Size of shortcuts for series nodes */
- series: {
- width: 14,
- height: 4,
- }
- },
- annotations: {
- /** X-space between the shape and each annotation-node. */
- xOffset: 10,
- /** Y-space between each annotation-node. */
- yOffset: 3,
- /** X-space between each annotation-node and its label. */
- labelOffset: 2,
- /** Estimate max width for annotation label */
- labelWidth: 35
- },
- constant: {
- size: {
- width: 4,
- height: 4
- }
- },
- series: {
- /** Maximum number of repeated item for unexpanded series node. */
- maxStackCount: 3,
- /**
- * Positioning offset ratio for collapsed stack
- * of parallel series (series without edges between its members).
- */
- parallelStackOffsetRatio: 0.2,
- /**
- * Positioning offset ratio for collapsed stack
- * of tower series (series with edges between its members).
- */
- towerStackOffsetRatio: 0.5
- },
- minimap: {
- /** The maximum width/height the minimap can have. */
- size: 150
- }
- };
- /** Calculate layout for a scene of a group node. */
- function scene(renderNodeInfo) {
- // Update layout, size, and annotations of its children nodes and edges.
- if (renderNodeInfo.node.isGroupNode) {
- layoutChildren(renderNodeInfo);
- }
- // Update position of its children nodes and edges
- if (renderNodeInfo.node.type === graph_1.NodeType.META) {
- layoutMetanode(renderNodeInfo);
- }
- else if (renderNodeInfo.node.type === graph_1.NodeType.SERIES) {
- layoutSeriesNode(renderNodeInfo);
- }
- }
- layout.scene = scene;
- ;
- /**
- * Update layout, size, and annotations of its children nodes and edges.
- */
- function layoutChildren(renderNodeInfo) {
- var children = renderNodeInfo.coreGraph.nodes().map(function (n) {
- return renderNodeInfo.coreGraph.node(n);
- }).concat(renderNodeInfo.isolatedInExtract, renderNodeInfo.isolatedOutExtract);
- _.each(children, function (childNodeInfo) {
- // Set size of each child
- switch (childNodeInfo.node.type) {
- case graph_1.NodeType.OP:
- _.extend(childNodeInfo, layout.PARAMS.nodeSize.op);
- break;
- case graph_1.NodeType.BRIDGE:
- _.extend(childNodeInfo, layout.PARAMS.nodeSize.bridge);
- break;
- case graph_1.NodeType.META:
- if (!childNodeInfo.expanded) {
- // set fixed width and scalable height based on cardinality
- _.extend(childNodeInfo, layout.PARAMS.nodeSize.meta);
- childNodeInfo.height =
- layout.PARAMS.nodeSize.meta.height(childNodeInfo.node.cardinality);
- }
- else {
- var childGroupNodeInfo = childNodeInfo;
- scene(childGroupNodeInfo); // Recursively layout its subscene.
- }
- break;
- case graph_1.NodeType.SERIES:
- if (childNodeInfo.expanded) {
- _.extend(childNodeInfo, layout.PARAMS.nodeSize.series.expanded);
- var childGroupNodeInfo = childNodeInfo;
- scene(childGroupNodeInfo); // Recursively layout its subscene.
- }
- else {
- var childGroupNodeInfo = childNodeInfo;
- var seriesParams = childGroupNodeInfo.node.hasNonControlEdges ?
- layout.PARAMS.nodeSize.series.vertical :
- layout.PARAMS.nodeSize.series.horizontal;
- _.extend(childNodeInfo, seriesParams);
- }
- break;
- default:
- throw Error("Unrecognized node type: " + childNodeInfo.node.type);
- }
- // Layout each child's annotations
- layoutAnnotation(childNodeInfo);
- });
- }
- /**
- * Calculate layout for a graph using dagre
- * @param graph the graph to be laid out
- * @param params layout parameters
- * @return width and height of the core graph
- */
- function dagreLayout(graph, params) {
- _.extend(graph.graph(), {
- nodeSep: params.nodeSep,
- rankSep: params.rankSep
- });
- var bridgeNodeNames = [];
- var nonBridgeNodeNames = [];
- // Split out nodes into bridge and non-bridge nodes, and calculate the total
- // width we should use for bridge nodes.
- _.each(graph.nodes(), function (nodeName) {
- var nodeInfo = graph.node(nodeName);
- if (nodeInfo.node.type === graph_1.NodeType.BRIDGE) {
- bridgeNodeNames.push(nodeName);
- }
- else {
- nonBridgeNodeNames.push(nodeName);
- }
- });
- // If there are no non-bridge nodes, then the graph has zero size.
- if (!nonBridgeNodeNames.length) {
- return {
- width: 0,
- height: 0,
- };
- }
- dagre.layout(graph);
- var graphLabel = graph.graph();
- // Calculate the true bounding box of the graph by iterating over nodes and
- // edges rather than accepting dagre's word for it. In particular, we should
- // ignore the extra-wide bridge nodes and bridge edges, and allow for
- // annotation boxes and labels.
- var minX = Infinity;
- var minY = Infinity;
- var maxX = -Infinity;
- var maxY = -Infinity;
- _.each(nonBridgeNodeNames, function (nodeName) {
- var nodeInfo = graph.node(nodeName);
- var w = 0.5 * nodeInfo.width;
- var x1 = nodeInfo.x - w - nodeInfo.inboxWidth;
- var x2 = nodeInfo.x + w + nodeInfo.outboxWidth;
- minX = x1 < minX ? x1 : minX;
- maxX = x2 > maxX ? x2 : maxX;
- var labelLength = nodeName.length - nodeName.lastIndexOf(graph_1.NAMESPACE_DELIM);
- // TODO(jimbo): Account for font width rather than using a magic number.
- var charWidth = 3; // 3 pixels per character.
- var lw = 0.5 * labelLength * charWidth;
- var lx1 = nodeInfo.x - lw;
- var lx2 = nodeInfo.x + lw;
- minX = lx1 < minX ? lx1 : minX;
- maxX = lx2 > maxX ? lx2 : maxX;
- // TODO(jimbo): Account for the height of labels above op nodes here.
- var h = 0.5 * nodeInfo.outerHeight;
- var y1 = nodeInfo.y - h;
- var y2 = nodeInfo.y + h;
- minY = y1 < minY ? y1 : minY;
- maxY = y2 > maxY ? y2 : maxY;
- });
- _.each(graph.edges(), function (edgeObj) {
- var renderMetaedgeInfo = graph.edge(edgeObj);
- if (renderMetaedgeInfo.structural) {
- return; // Skip structural edges from min/max calculations.
- }
- _.each(renderMetaedgeInfo.points, function (point) {
- minX = point.x < minX ? point.x : minX;
- maxX = point.x > maxX ? point.x : maxX;
- minY = point.y < minY ? point.y : minY;
- maxY = point.y > maxY ? point.y : maxY;
- });
- });
- // Shift all nodes and edge points to account for the left-padding amount,
- // and the invisble bridge nodes.
- _.each(graph.nodes(), function (nodeName) {
- var nodeInfo = graph.node(nodeName);
- nodeInfo.x -= minX;
- nodeInfo.y -= minY;
- });
- _.each(graph.edges(), function (edgeObj) {
- _.each(graph.edge(edgeObj).points, function (point) {
- point.x -= minX;
- point.y -= minY;
- });
- });
- return {
- width: maxX - minX,
- height: maxY - minY,
- };
- }
- /** Layout a metanode. */
- function layoutMetanode(renderNodeInfo) {
- // First, copy params specific to meta nodes onto this render info object.
- var params = layout.PARAMS.subscene.meta;
- renderNodeInfo = _.extend(renderNodeInfo, params);
- // Invoke dagre.layout() on the core graph and record the bounding box
- // dimensions.
- _.extend(renderNodeInfo.coreBox, dagreLayout(renderNodeInfo.coreGraph, layout.PARAMS.graph.meta));
- // Calculate the position of nodes in isolatedInExtract relative to the
- // top-left corner of inExtractBox (the bounding box for all inExtract nodes)
- // and calculate the size of the inExtractBox.
- var hasInExtract = renderNodeInfo.isolatedInExtract.length > 0;
- renderNodeInfo.inExtractBox.width = hasInExtract ?
- _(renderNodeInfo.isolatedInExtract).pluck("outerWidth").max() : 0;
- renderNodeInfo.inExtractBox.height =
- _.reduce(renderNodeInfo.isolatedInExtract, function (height, child, i) {
- var yOffset = i > 0 ? params.extractYOffset : 0;
- // use outerWidth/Height here to avoid overlaps between extracts
- child.x = renderNodeInfo.inExtractBox.width / 2;
- child.y = height + yOffset + child.outerHeight / 2;
- return height + yOffset + child.outerHeight;
- }, 0);
- // Calculate the position of nodes in isolatedOutExtract relative to the
- // top-left corner of outExtractBox (the bounding box for all outExtract
- // nodes) and calculate the size of the outExtractBox.
- var hasOutExtract = renderNodeInfo.isolatedOutExtract.length > 0;
- renderNodeInfo.outExtractBox.width = hasOutExtract ?
- _(renderNodeInfo.isolatedOutExtract).pluck("outerWidth").max() : 0;
- renderNodeInfo.outExtractBox.height =
- _.reduce(renderNodeInfo.isolatedOutExtract, function (height, child, i) {
- var yOffset = i > 0 ? params.extractYOffset : 0;
- // use outerWidth/Height here to avoid overlaps between extracts
- child.x = renderNodeInfo.outExtractBox.width / 2;
- child.y = height + yOffset + child.outerHeight / 2;
- return height + yOffset + child.outerHeight;
- }, 0);
- // Determine the whole metanode's width (from left to right).
- renderNodeInfo.width =
- params.paddingLeft + renderNodeInfo.coreBox.width + params.paddingRight +
- (hasInExtract ?
- renderNodeInfo.inExtractBox.width + params.extractXOffset : 0) +
- (hasOutExtract ?
- params.extractXOffset + renderNodeInfo.outExtractBox.width : 0);
- // TODO(jimbo): Remove labelHeight and instead incorporate into box sizes.
- // Determine the whole metanode's height (from top to bottom).
- renderNodeInfo.height =
- renderNodeInfo.labelHeight +
- params.paddingTop +
- Math.max(renderNodeInfo.inExtractBox.height, renderNodeInfo.coreBox.height, renderNodeInfo.outExtractBox.height) +
- params.paddingBottom;
- }
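// Worked example (illustrative numbers, not from the original source): with
// paddingLeft = paddingRight = 20, coreBox.width = 400, an inExtractBox of
// width 100 at extractXOffset = 50, and no outExtract, the width above is
// 20 + 400 + 20 + (100 + 50) + 0 = 590. Height works the same way: with
// labelHeight = 20, paddingTop = 10, paddingBottom = 10 and box heights
// (300, 280, 0), it is 20 + 10 + max(300, 280, 0) + 10 = 340.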
- /**
- * Calculate layout for series node's core graph. Only called for an expanded
- * series.
- */
- function layoutSeriesNode(node) {
- var graph = node.coreGraph;
- var params = layout.PARAMS.subscene.series;
- _.extend(node, params);
- // Layout the core.
- _.extend(node.coreBox, dagreLayout(node.coreGraph, layout.PARAMS.graph.series));
- _.each(graph.nodes(), function (nodeName) {
- graph.node(nodeName).excluded = false;
- });
- // Series do not have in/outExtractBox so no need to include them here.
- node.width = node.coreBox.width + params.paddingLeft + params.paddingRight;
- node.height = node.coreBox.height + params.paddingTop + params.paddingBottom;
- }
- /**
- * Calculate layout for annotations of a given node.
- * This will modify the positions of the given node and its annotations.
- *
- * @see tf.graph.render.Node and tf.graph.render.Annotation
- * for description of each property of each render node.
- *
- */
- function layoutAnnotation(renderNodeInfo) {
- // If the render node is an expanded metanode, then its annotations will not
- // be visible and we should skip the annotation calculations.
- if (renderNodeInfo.expanded) {
- _.extend(renderNodeInfo, {
- inboxWidth: 0,
- inboxHeight: 0,
- outboxWidth: 0,
- outboxHeight: 0,
- outerWidth: renderNodeInfo.width,
- outerHeight: renderNodeInfo.height
- });
- return;
- }
- var inAnnotations = renderNodeInfo.inAnnotations.list;
- var outAnnotations = renderNodeInfo.outAnnotations.list;
- // Calculate size for in-annotations
- _.each(inAnnotations, function (a) { return sizeAnnotation(a); });
- // Calculate size for out-annotations
- _.each(outAnnotations, function (a) { return sizeAnnotation(a); });
- var params = layout.PARAMS.annotations;
- renderNodeInfo.inboxWidth =
- inAnnotations.length > 0 ?
- _(inAnnotations).pluck("width").max() +
- params.xOffset + params.labelWidth + params.labelOffset :
- 0;
- renderNodeInfo.outboxWidth =
- outAnnotations.length > 0 ?
- _(outAnnotations).pluck("width").max() +
- params.xOffset + params.labelWidth + params.labelOffset :
- 0;
- // Calculate annotation node position (a.dx, a.dy)
- // and total height for in-annotations
- // After this chunk of code:
- // inboxHeight = sum of annotation heights + (inAnnotations.length - 1) * yOffset
- var inboxHeight = _.reduce(inAnnotations, function (height, a, i) {
- var yOffset = i > 0 ? params.yOffset : 0;
- a.dx = -(renderNodeInfo.width + a.width) / 2 - params.xOffset;
- a.dy = height + yOffset + a.height / 2;
- return height + yOffset + a.height;
- }, 0);
- _.each(inAnnotations, function (a) {
- a.dy -= inboxHeight / 2;
- a.labelOffset = params.labelOffset;
- });
- // Calculate annotation node position (a.dx, a.dy)
- // and total height for out-annotations
- // After this chunk of code:
- // outboxHeight = sum of annotation heights +
- //     (outAnnotations.length - 1) * yOffset
- var outboxHeight = _.reduce(outAnnotations, function (height, a, i) {
- var yOffset = i > 0 ? params.yOffset : 0;
- a.dx = (renderNodeInfo.width + a.width) / 2 + params.xOffset;
- a.dy = height + yOffset + a.height / 2;
- return height + yOffset + a.height;
- }, 0);
- _.each(outAnnotations, function (a) {
- // adjust by half of the total height
- // so dy is relative to the host node's center.
- a.dy -= outboxHeight / 2;
- a.labelOffset = params.labelOffset;
- });
- // Create a scale for the touch points where the in-annotation edges
- // meet their host node.
- var inTouchHeight = Math.min(renderNodeInfo.height / 2 - renderNodeInfo.radius, inboxHeight / 2);
- inTouchHeight = inTouchHeight < 0 ? 0 : inTouchHeight;
- var inY = d3.scale.linear()
- .domain([0, inAnnotations.length - 1])
- .range([-inTouchHeight, inTouchHeight]);
- // Calculate annotation edge position
- _.each(inAnnotations, function (a, i) {
- a.points = [
- // The annotation node end
- {
- dx: a.dx + a.width / 2,
- dy: a.dy
- },
- // The host node end
- {
- dx: -renderNodeInfo.width / 2,
- // only use the scale if there is more than one annotation;
- // otherwise center it vertically
- dy: inAnnotations.length > 1 ? inY(i) : 0
- }
- ];
- });
- // Create a scale for the touch points where the out-annotation edges
- // meet their host node.
- var outTouchHeight = Math.min(renderNodeInfo.height / 2 - renderNodeInfo.radius, outboxHeight / 2);
- outTouchHeight = outTouchHeight < 0 ? 0 : outTouchHeight;
- var outY = d3.scale.linear()
- .domain([0, outAnnotations.length - 1])
- .range([-outTouchHeight, outTouchHeight]);
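// Worked example (illustrative numbers): with outTouchHeight = 30 and four
// out-annotations, the linear scale above yields evenly spaced touch points
// at -30, -10, 10 and 30 along the host node's right edge.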
- _.each(outAnnotations, function (a, i) {
- // Add point from the border of the annotation node
- a.points = [
- // The host node end
- {
- dx: renderNodeInfo.width / 2,
- // only use the scale if there is more than one annotation;
- // otherwise center it vertically
- dy: outAnnotations.length > 1 ? outY(i) : 0
- },
- // The annotation node end
- {
- dx: a.dx - a.width / 2,
- dy: a.dy
- }
- ];
- });
- renderNodeInfo.outerWidth = renderNodeInfo.width + renderNodeInfo.inboxWidth +
- renderNodeInfo.outboxWidth;
- renderNodeInfo.outerHeight =
- Math.max(renderNodeInfo.height, inboxHeight, outboxHeight);
- }
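// Standalone sketch (illustrative, not from the original file) of the
// stacking pattern used twice above: lay boxes out in a vertical column
// with a fixed gap, center each box on its own height, then recenter the
// whole column on the host's vertical midpoint.
function stackBoxes(boxes, gap) {
  var total = boxes.reduce(function (height, box, i) {
    var yOffset = i > 0 ? gap : 0;
    box.dy = height + yOffset + box.height / 2;
    return height + yOffset + box.height;
  }, 0);
  // total === sum of box heights + (boxes.length - 1) * gap
  boxes.forEach(function (box) { box.dy -= total / 2; });
  return total;
}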
- /**
- * Set size of an annotation node.
- */
- function sizeAnnotation(a) {
- switch (a.annotationType) {
- case graph_1.render.AnnotationType.CONSTANT:
- _.extend(a, layout.PARAMS.constant.size);
- break;
- case graph_1.render.AnnotationType.SHORTCUT:
- if (a.node.type === graph_1.NodeType.OP) {
- _.extend(a, layout.PARAMS.shortcutSize.op);
- }
- else if (a.node.type === graph_1.NodeType.META) {
- _.extend(a, layout.PARAMS.shortcutSize.meta);
- }
- else if (a.node.type === graph_1.NodeType.SERIES) {
- _.extend(a, layout.PARAMS.shortcutSize.series);
- }
- else {
- throw Error("Invalid node type: " + a.node.type);
- }
- break;
- case graph_1.render.AnnotationType.SUMMARY:
- _.extend(a, layout.PARAMS.constant.size);
- break;
- }
- }
- })(layout = graph_1.layout || (graph_1.layout = {}));
- })(graph = tf.graph || (tf.graph = {}));
-})(tf || (tf = {})); // close module
-</script>
@@ -5499,9 +5506,23 @@ var tf;
-</head><body><div hidden="" by-vulcanize=""><dom-module id="tf-data-coordinator" assetpath="../components/tf-event-dashboard/">
- <script>/// <reference path="../../typings/tsd.d.ts" />
-/// <reference path="../../bower_components/plottable/plottable.d.ts" />
+</head><body><div hidden="" by-vulcanize=""><dom-module id="tf-data-coordinator" assetpath="../tf-event-dashboard/">
+ <script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="../../typings/tsd.d.ts" />
+/// <reference path="../plottable/plottable.d.ts" />
var TF;
(function (TF) {
/* The DataCoordinator generates TF.Datasets for each run/tag combination,
@@ -5553,14 +5574,28 @@ var TF;
TF.DataCoordinator = DataCoordinator;
})(TF || (TF = {}));
</script>
- <script>/// <reference path="../../typings/tsd.d.ts" />
-/// <reference path="../../bower_components/plottable/plottable.d.ts" />
+ <script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
var __extends = (this && this.__extends) || function (d, b) {
for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p];
function __() { this.constructor = d; }
__.prototype = b.prototype;
d.prototype = new __();
};
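// Usage sketch (illustrative names, not part of the original file): the
// __extends helper above is the shim TypeScript emits for "class D extends
// B"; it copies static members and chains the prototypes.
(function () {
  function Animal() {}
  function Dog() { Animal.apply(this, arguments); }
  __extends(Dog, Animal);
  console.assert(new Dog() instanceof Animal); // prototype chain is wired up
})();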
+/// <reference path="../../typings/tsd.d.ts" />
+/// <reference path="../plottable/plottable.d.ts" />
var TF;
(function (TF) {
/* An extension of Plottable.Dataset that knows how to load data from a backend.
@@ -5615,7 +5650,7 @@ var TF;
});
</script>
</dom-module>
-<dom-module id="tf-tooltip-coordinator" assetpath="../components/tf-event-dashboard/">
+<dom-module id="tf-tooltip-coordinator" assetpath="../tf-event-dashboard/">
<script>
Polymer({
is: "tf-tooltip-coordinator",
@@ -5654,7 +5689,7 @@ var TF;
});
</script>
</dom-module>
-<dom-module id="scrollbar-style" assetpath="../components/tf-dashboard-common/">
+<dom-module id="scrollbar-style" assetpath="../tf-dashboard-common/">
<template>
<style>
.scrollbar::-webkit-scrollbar-track
@@ -5680,7 +5715,7 @@ var TF;
</style>
</template>
</dom-module>
-<dom-module id="run-color-style" assetpath="../components/tf-dashboard-common/">
+<dom-module id="run-color-style" assetpath="../tf-dashboard-common/">
<template>
<style>
[color-class="light-blue"] paper-checkbox {
@@ -5740,7 +5775,7 @@ var TF;
</style>
</template>
</dom-module>
-<dom-module id="tf-multi-checkbox" assetpath="../components/tf-multi-checkbox/">
+<dom-module id="tf-multi-checkbox" assetpath="../tf-multi-checkbox/">
<style include="scrollbar-style"></style>
<style include="run-color-style"></style>
@@ -5917,7 +5952,7 @@ var TF;
</script>
</dom-module>
-<dom-module id="tf-run-selector" assetpath="../components/tf-event-dashboard/">
+<dom-module id="tf-run-selector" assetpath="../tf-event-dashboard/">
<template>
<div id="top-text">
<template is="dom-if" if="[[xValue]]">
@@ -5927,12 +5962,15 @@ var TF;
</div>
</template>
<template is="dom-if" if="[[!xValue]]">
- <div id="tooltip-help" class="tooltip-container">
- Selected Runs:
- </div>
+ <h3 id="tooltip-help" class="tooltip-container">
+ Runs
+ </h3>
</template>
</div>
<tf-multi-checkbox names="[[runs]]" tooltips="[[tooltips]]" highlights="[[_arrayify(closestRun)]]" out-selected="{{outSelected}}" class-scale="[[classScale]]" hide-missing-tooltips=""></tf-multi-checkbox>
+ <paper-button class="x-button" id="toggle-all" on-tap="_toggleAll">
+ Toggle All Runs
+ </paper-button>
<style>
:host {
display: flex;
@@ -5944,7 +5982,6 @@ var TF;
width: 100%;
flex-grow: 0;
flex-shrink: 0;
- padding-left: 35px;
padding-right: 16px;
padding-bottom: 6px;
box-sizing: border-box;
@@ -5956,6 +5993,12 @@ var TF;
flex-shrink: 1;
height: 0px; /* hackhack So the flex-grow takes over and gives it space */
}
+ .x-button {
+ font-size: 13px;
+ background-color: var(--tb-ui-light-accent);
+ margin-top: 5px;
+ color: var(--tb-ui-dark-accent);
+ }
.x-tooltip {
display: flex;
flex-direction: row;
@@ -5967,6 +6010,16 @@ var TF;
.x-tooltip-value {
align-self: flex-end;
}
+ #tooltip-help {
+ color: var(--paper-grey-800);
+ margin: 0;
+ font-weight: normal;
+ font-size: 14px;
+ margin-bottom: 5px;
+ }
+ paper-button {
+ margin-left: 0;
+ }
</style>
</template>
<script>
@@ -5982,17 +6035,24 @@ var TF;
classScale: Object, // map from run name to color class (css)
closestRun: {type: String, value: null}, // which run has a value closest to mouse coordinate
},
+ _toggleAll: function() {
+ if (this.outSelected.length > 0) {
+ this.outSelected = [];
+ } else {
+ this.outSelected = this.runs.slice();
+ }
+ },
_arrayify: function(item) {
return [item];
},
});
</script>
</dom-module>
-<dom-module id="tf-x-type-selector" assetpath="../components/tf-event-dashboard/">
+<dom-module id="tf-x-type-selector" assetpath="../tf-event-dashboard/">
<template>
<div id="buttons">
- <p>X Type: </p>
- <paper-button class="x-button selected" id="step" on-tap="_select" raised="">
+ <h3>Horizontal Axis</h3>
+ <paper-button class="x-button selected" id="step" on-tap="_select">
step
</paper-button>
<paper-button class="x-button" id="relative" on-tap="_select">
@@ -6004,22 +6064,28 @@ var TF;
</div>
<style>
.x-button {
- width: 29%;
- font-size: 14px;
- background-color: var(--paper-grey-500);
- margin-top: 5px;
- color: white;
+ width: 30%;
+ font-size: 13px;
+ background: none;
+ margin-top: 10px;
+ color: var(--tb-ui-dark-accent);
+ }
+
+ .x-button:first-of-type {
+ margin-left: 0;
}
.x-button.selected {
- font-weight: bold;
- background-color: var(--tb-orange-strong) !important;
+ background-color: var(--tb-ui-dark-accent);
+ color: white !important;
}
- #buttons p {
- text-align: center;
- font-size: 12px;
+ #buttons h3 {
+ color: var(--paper-grey-800);
margin: 0;
+ font-weight: normal;
+ font-size: 14px;
+ margin-bottom: 5px;
}
</style>
</template>
@@ -6032,17 +6098,15 @@ var TF;
_select: function(e) {
var _this = this;
["step", "wall_time", "relative"].forEach(function(id) {
- _this.$[id].raised = false;
_this.$[id].classList.remove("selected");
});
- e.currentTarget.raised = true;
this._setOutXType(e.currentTarget.id);
e.currentTarget.classList.add("selected");
},
});
</script>
</dom-module>
-<dom-module id="tf-run-generator" assetpath="../components/tf-dashboard-common/">
+<dom-module id="tf-run-generator" assetpath="../tf-dashboard-common/">
<template>
<iron-ajax id="ajax" auto="" url="[[url]]" handle-as="json" debounce="300" on-response="_setResponse" verbose="true">
</iron-ajax>
@@ -6119,7 +6183,7 @@ var TF;
});
</script>
</dom-module>
-<dom-module id="tf-color-scale" assetpath="../components/tf-event-dashboard/">
+<dom-module id="tf-color-scale" assetpath="../tf-event-dashboard/">
<script>
(function() {
// TODO(danmane) - get Plottable team to make an API point for this
@@ -6176,9 +6240,23 @@ var TF;
})();
</script>
</dom-module>
-<dom-module id="tf-url-generator" assetpath="../components/tf-dashboard-common/">
- <script>/// <reference path="../../typings/tsd.d.ts" />
-/// <reference path="../../bower_components/plottable/plottable.d.ts" />
+<dom-module id="tf-url-generator" assetpath="../tf-dashboard-common/">
+ <script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="../../typings/tsd.d.ts" />
+/// <reference path="../plottable/plottable.d.ts" />
var TF;
(function (TF) {
var Urls;
@@ -6252,7 +6330,7 @@ var TF;
Polymer(polymerObject);
</script>
</dom-module>
-<dom-module id="tf-dashboard-layout" assetpath="../components/tf-dashboard-common/">
+<dom-module id="tf-dashboard-layout" assetpath="../tf-dashboard-common/">
<template>
<div id="sidebar">
<content select=".sidebar"></content>
@@ -6266,23 +6344,22 @@ var TF;
#sidebar {
width: inherit;
height: 100%;
- background-color: var(--tb-grey-darker);
- background-image: linear-gradient(to right, var(--tb-grey-lighter), var(--tb-grey-lighter));
text-overflow: ellipsis;
- padding-left: 10px;
- padding-right: 10px;
flex-grow: 0;
flex-shrink: 0;
}
#center {
- margin: 0 10px;
height: 100%;
overflow-y: scroll;
- padding-right: 12px;
flex-grow: 1;
flex-shrink: 1;
}
+
+ .tf-graph-dashboard #center {
+ background: white;
+ }
+
:host {
display: flex;
flex-direction: row;
@@ -6296,7 +6373,7 @@ var TF;
});
</script>
</dom-module>
-<dom-module id="dashboard-style" assetpath="../components/tf-dashboard-common/">
+<dom-module id="dashboard-style" assetpath="../tf-dashboard-common/">
<template>
<style>
.card {
@@ -6304,10 +6381,8 @@ var TF;
width: 300px;
display: flex;
flex-direction: column;
- margin: 5px 5px;
- padding: 5px;
- border: 1px solid var(--paper-grey-500);
- border-radius: 3px;
+ margin: 5px;
+ padding: 0 30px 30px 0;
-webkit-user-select: none;
-moz-user-select: none;
position: relative;
@@ -6316,9 +6391,8 @@ var TF;
.card .card-title {
flex-grow: 0;
flex-shrink: 0;
- margin-bottom: 2px;
+ margin-bottom: 10px;
font-size: 14px;
- font-weight: bold;
text-overflow: ellipsis;
overflow: hidden;
}
@@ -6347,7 +6421,7 @@ var TF;
.expand-button {
position: absolute;
left: 0px;
- bottom: 0px;
+ bottom: 20px;
color: #2196F3;
display: block;
}
@@ -6360,6 +6434,7 @@ var TF;
display: flex;
flex-direction: column;
height: 100%;
+ margin-right: 20px;
}
#categorizer {
@@ -6376,21 +6451,30 @@ var TF;
flex-grow: 1;
}
- #download-option {
- padding-left: 55px;
- color: var(--paper-grey-700);
- font-size: 14px;
+ .sidebar-section {
+ border-top: solid 1px rgba(0, 0, 0, 0.12);
+ padding: 20px 0px 20px 30px;
}
- #download-option paper-toggle-button {
- --paper-toggle-button-checked-button-color: var(--tb-orange-strong);
- --paper-toggle-button-checked-bar-color: var(--tb-orange-weak);
+ .sidebar-section:first-child {
+ border: none;
+ }
+ .sidebar-section:last-child {
+ flex-grow: 1;
+ display: flex;
}
+
+ paper-checkbox {
+ --paper-checkbox-checked-color: var(--tb-ui-dark-accent);
+ --paper-checkbox-unchecked-color: var(--tb-ui-dark-accent);
+ font-size: 14px;
+ }
+
</style>
</template>
</dom-module>
-<dom-module id="tf-downloader" assetpath="../components/tf-dashboard-common/">
+<dom-module id="tf-downloader" assetpath="../tf-dashboard-common/">
<template>
<paper-dropdown-menu no-label-float="true" label="run to download" selected-item-label="{{_run}}">
<paper-menu class="dropdown-content">
@@ -6460,40 +6544,46 @@ var TF;
});
</script>
</dom-module>
-<dom-module id="tf-regex-group" assetpath="../components/tf-regex-group/">
+<dom-module id="tf-regex-group" assetpath="../tf-regex-group/">
<template>
<div class="regex-list">
<template is="dom-repeat" items="{{rawRegexes}}">
<div class="regex-line">
- <paper-input id="text-input" class="regex-input" label="input new regex" no-label-float="" bind-value="{{item.regex}}" invalid="[[!item.valid]]" on-keyup="moveFocus"></paper-input>
- <paper-toggle-button class="active-button" checked="{{item.active}}" disabled="[[!item.valid]]"></paper-toggle-button>
-
- <paper-icon-button icon="delete" class="delete-button" aria-label="Delete Regex" tabindex="0" on-tap="deleteRegex"></paper-icon-button>
+ <paper-checkbox class="active-button" checked="{{item.active}}" disabled="[[!item.valid]]"></paper-checkbox>
+ <paper-input id="text-input" class="regex-input" label="Regex filter" no-label-float="" bind-value="{{item.regex}}" invalid="[[!item.valid]]" on-keyup="moveFocus"></paper-input>
+ <paper-icon-button icon="close" class="delete-button" aria-label="Delete Regex" tabindex="0" on-tap="deleteRegex"></paper-icon-button>
</div>
<style>
.regex-input {
- width: 210px;
+ width: 230px;
display: inline-block;
- padding-left: 8px;
- padding-right: 5px;
+ margin-left: -3px;
}
- .active-button {
- --paper-toggle-button-checked-button-color: var(--tb-orange-strong);
- --paper-toggle-button-checked-bar-color: var(--tb-orange-weak);
- border: none;
+ paper-checkbox {
+ --paper-checkbox-checked-color: var(--tb-ui-dark-accent);
+ --paper-checkbox-unchecked-color: var(--tb-ui-dark-accent);
}
.delete-button {
- color: var(--paper-pink-900);
- width: 24px;
- height: 24px;
+ color: var(--paper-grey-700);
+ width: 40px;
+ height: 40px;
+ margin-right: -10px;
}
+
.regex-list {
margin-bottom: 10px;
}
+
paper-input {
--paper-input-container-focus-color: var(--tb-orange-strong);
+ --paper-input-container-input: {
+ font-size: 14px;
+ };
+ --paper-input-container-label: {
+ font-size: 14px;
+ };
}
</style>
</template>
@@ -6566,38 +6656,44 @@ var TF;
});
</script>
</dom-module>
-<dom-module id="tf-categorizer" assetpath="../components/tf-categorizer/">
+<dom-module id="tf-categorizer" assetpath="../tf-categorizer/">
<template>
<div class="inputs">
<tf-regex-group id="regex-group" regexes="{{regexes}}"></tf-regex-group>
</div>
<div id="underscore-categorization">
- <span>Split On Underscores:</span>
- <paper-toggle-button checked="{{splitOnUnderscore}}"></paper-toggle-button>
+ <paper-checkbox checked$="{{splitOnUnderscore}}">Split on underscores</paper-checkbox>
</div>
<style>
:host {
display: block;
- padding-bottom: 5px;
- padding-top: 5px;
+ padding-bottom: 15px;
}
-
- .inputs {
- padding-left: 5px;
- }
-
- paper-toggle-button {
- --paper-toggle-button-checked-button-color: var(--tb-orange-strong);
- --paper-toggle-button-checked-bar-color: var(--tb-orange-weak);
+ paper-checkbox {
+ --paper-checkbox-checked-color: var(--paper-grey-600);
+ --paper-checkbox-unchecked-color: var(--paper-grey-600);
+ font-size: 14px;
}
#underscore-categorization {
- padding-left: 94px;
color: var(--paper-grey-700);
- font-size: 14px;
}
</style>
</template>
- <script>/// <reference path="../../typings/tsd.d.ts" />
+ <script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+/// <reference path="../../typings/tsd.d.ts" />
var Categorizer;
(function (Categorizer) {
/* Canonical TensorFlow ops are namespaced using forward slashes.
@@ -6738,7 +6834,7 @@ var Categorizer;
});
</script>
</dom-module>
-<dom-module id="tf-chart" assetpath="../components/tf-event-dashboard/">
+<dom-module id="tf-chart" assetpath="../tf-event-dashboard/">
<template>
<svg id="chartsvg"></svg>
<style>
@@ -6761,7 +6857,21 @@ var Categorizer;
}
</style>
</template>
- <script>var __extends = (this && this.__extends) || function (d, b) {
+ <script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+var __extends = (this && this.__extends) || function (d, b) {
for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p];
function __() { this.constructor = d; }
__.prototype = b.prototype;
@@ -6901,14 +7011,28 @@ var Plottable;
Plottable.DragZoomLayer = DragZoomLayer;
})(Plottable || (Plottable = {}));
</script>
- <script>/// <reference path="../../typings/tsd.d.ts" />
-/// <reference path="../../bower_components/plottable/plottable.d.ts" />
+ <script>/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
var __extends = (this && this.__extends) || function (d, b) {
for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p];
function __() { this.constructor = d; }
__.prototype = b.prototype;
d.prototype = new __();
};
+/// <reference path="../../typings/tsd.d.ts" />
+/// <reference path="../plottable/plottable.d.ts" />
var TF;
(function (TF) {
var Y_TOOLTIP_FORMATTER_PRECISION = 4;
@@ -7238,13 +7362,13 @@ var TF;
});
</script>
</dom-module>
-<dom-module id="tf-collapsable-pane" assetpath="../components/tf-collapsable-pane/">
+<dom-module id="tf-collapsable-pane" assetpath="../tf-collapsable-pane/">
<template>
<button class="heading" on-tap="togglePane" open-button$="[[opened]]">
<span class="name">[[name]]</span>
<span class="hackpadding"></span>
<span class="count">
- (<span>[[count]]</span>)
+ <span>[[count]]</span>
</span>
</button>
<iron-collapse opened="[[opened]]">
@@ -7255,47 +7379,63 @@ var TF;
</div>
</iron-collapse>
<style>
+ :host {
+ display: block;
+ margin: 0 5px 1px 10px;
+ }
+
+ :host:first-of-type {
+ margin-top: 20px;
+ }
+
+ :host:last-of-type {
+ margin-bottom: 20px;
+ }
+
.heading {
- margin-top: 10px;
- padding-left: 15px;
- background-color: #f3f3f3;
- border: 1px solid #dedede;
- border-radius: 5px;
- font-size: 18px;
+ background-color: white;
+ border-radius: 2px;
+ border: none;
cursor: pointer;
-webkit-tap-highlight-color: rgba(0,0,0,0);
width: 100%;
- height: 30px;
box-sizing: border-box;
- font-size: 16px;
+ font-size: 15px;
display: inline-flex;
flex-direction: row;
align-items: center;
justify-content: space-between;
line-height: 1;
- padding-top: 2px;
- padding-bottom: 2px;
+ box-shadow: 0 1px 5px rgba(0,0,0,0.2);
+ padding: 10px 15px;
}
.content {
padding: 15px;
border: 1px solid #dedede;
border-top: none;
- border-bottom-left-radius: 5px;
- border-bottom-right-radius: 5px;
+ border-bottom-left-radius: 2px;
+ border-bottom-right-radius: 2px;
+ background: white;
}
+
[open-button] {
border-bottom-left-radius: 0px !important;
border-bottom-right-radius: 0px !important;
}
+
.name {
flex-grow: 0;
}
+
.count {
flex-grow: 0;
float: right;
+ margin-right: 5px;
font-size: 12px;
+ color: var(--paper-grey-500);
}
+
.hackpadding {
/* An obnoxious hack, but I can't get justify-content: space-between to work */
flex-grow: 1;
@@ -7321,7 +7461,7 @@ var TF;
</script>
</dom-module>
-<dom-module id="warning-style" assetpath="../components/tf-dashboard-common/">
+<dom-module id="warning-style" assetpath="../tf-dashboard-common/">
<template>
<style>
.warning {
@@ -7331,7 +7471,7 @@ var TF;
</style>
</template>
</dom-module>
-<dom-module id="tf-event-dashboard" assetpath="../components/tf-event-dashboard/">
+<dom-module id="tf-event-dashboard" assetpath="../tf-event-dashboard/">
<template>
<div id="plumbing">
<tf-url-generator out-runs-url="{{runsUrl}}" out-scalars-url-generator="{{scalarsUrlGen}}" id="urlGenerator"></tf-url-generator>
@@ -7348,16 +7488,16 @@ var TF;
<tf-dashboard-layout>
<div class="sidebar">
- <tf-categorizer id="categorizer" tags="[[_visibleTags]]" categories="{{categories}}"></tf-categorizer>
- <span id="download-option">
- Show Data Download Links:
- <paper-toggle-button checked="{{_show_download_links}}"></paper-toggle-button>
- </span>
-
- <tf-x-type-selector id="xTypeSelector" out-x-type="{{xType}}"></tf-x-type-selector>
-
- <tf-run-selector id="runSelector" runs="[[_runs]]" class-scale="[[classScale]]" out-selected="{{selectedRuns}}" tooltips="[[tooltipMap]]" closest-run="[[closestRun]]" x-value="[[tooltipXValue]]" x-type="[[xType]]"></tf-run-selector>
-
+ <div class="sidebar-section">
+ <tf-categorizer id="categorizer" tags="[[_visibleTags]]" categories="{{categories}}"></tf-categorizer>
+ <paper-checkbox id="download-option" checked$="{{_show_download_links}}">Data download links</paper-checkbox>
+ </div>
+ <div class="sidebar-section">
+ <tf-x-type-selector id="xTypeSelector" out-x-type="{{xType}}"></tf-x-type-selector>
+ </div>
+ <div class="sidebar-section">
+ <tf-run-selector id="runSelector" runs="[[_runs]]" class-scale="[[classScale]]" out-selected="{{selectedRuns}}" tooltips="[[tooltipMap]]" closest-run="[[closestRun]]" x-value="[[tooltipXValue]]" x-type="[[xType]]"></tf-run-selector>
+ </div>
</div>
<div class="center">
<template is="dom-if" if="[[!categories.length]]">
@@ -7442,7 +7582,7 @@ var TF;
});
</script>
</dom-module>
-<dom-module id="tf-histogram-dashboard" assetpath="../components/tf-histogram-dashboard/">
+<dom-module id="tf-histogram-dashboard" assetpath="../tf-histogram-dashboard/">
<template>
<div id="plumbing">
<tf-url-generator out-runs-url="{{runsUrl}}" out-compressed-histograms-url-generator="{{compressedHistogramsUrlGen}}" id="urlGenerator"></tf-url-generator>
@@ -7458,13 +7598,15 @@ var TF;
<tf-dashboard-layout>
<div class="sidebar">
-
- <tf-categorizer id="categorizer" tags="[[_visibleTags]]" categories="{{categories}}"></tf-categorizer>
-
- <tf-x-type-selector id="xTypeSelector" out-x-type="{{xType}}"></tf-x-type-selector>
-
- <tf-run-selector id="runSelector" runs="[[_runs]]" class-scale="[[classScale]]" out-selected="{{selectedRuns}}" tooltips="[[tooltipMap]]" closest-run="[[closestRun]]" x-value="[[tooltipXValue]]" x-type="[[xType]]"></tf-run-selector>
-
+ <div class="sidebar-section">
+ <tf-categorizer id="categorizer" tags="[[_visibleTags]]" categories="{{categories}}"></tf-categorizer>
+ </div>
+ <div class="sidebar-section">
+ <tf-x-type-selector id="xTypeSelector" out-x-type="{{xType}}"></tf-x-type-selector>
+ </div>
+ <div class="sidebar-section">
+ <tf-run-selector id="runSelector" runs="[[_runs]]" class-scale="[[classScale]]" out-selected="{{selectedRuns}}" tooltips="[[tooltipMap]]" closest-run="[[closestRun]]" x-value="[[tooltipXValue]]" x-type="[[xType]]"></tf-run-selector>
+ </div>
</div>
<div class="center">
@@ -7563,7 +7705,7 @@ var TF;
});
</script>
</dom-module>
-<dom-module id="tf-image-loader" assetpath="../components/tf-image-dashboard/">
+<dom-module id="tf-image-loader" assetpath="../tf-image-dashboard/">
<style>
:host {
display: block;
@@ -7610,7 +7752,7 @@ var TF;
});
</script>
</dom-module>
-<dom-module id="tf-image-grid" assetpath="../components/tf-image-dashboard/">
+<dom-module id="tf-image-grid" assetpath="../tf-image-dashboard/">
<template>
<style include="scrollbar-style"></style>
<div id="fullContainer" class="container scrollbar">
@@ -7725,7 +7867,7 @@ var TF;
});
</script>
</dom-module>
-<dom-module id="tf-image-dashboard" assetpath="../components/tf-image-dashboard/">
+<dom-module id="tf-image-dashboard" assetpath="../tf-image-dashboard/">
<template>
<div id="plumbing">
<tf-url-generator out-runs-url="{{runsUrl}}" out-images-url-generator="{{imagesUrlGen}}" out-individual-image-url-generator="{{individualImageUrlGen}}" id="urlGenerator"></tf-url-generator>
@@ -7783,7 +7925,7 @@ var TF;
});
</script>
</dom-module>
-<dom-module id="tf-graph-loader" assetpath="../components/tf-graph-loader/">
+<dom-module id="tf-graph-loader" assetpath="../tf-graph-loader/">
</dom-module>
<script>
@@ -7906,7 +8048,9 @@ Polymer({
}
var hierarchyParams = {
verifyTemplate: true,
- groupSeries: true,
+ // If a set of numbered op nodes has at least this many nodes,
+ // group them into a series node.
+ seriesNodeMinSize: 5,
};
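// Illustrative sketch (helper and regex are assumptions, not TensorBoard's
// implementation): with seriesNodeMinSize = 5, only runs of at least five
// numbered ops sharing a prefix (e.g. "unit_0" .. "unit_4") are eligible
// to collapse into a series node; a run of four stays as individual nodes.
function groupBySeries(names, minSize) {
  var buckets = {};
  names.forEach(function (name) {
    var m = name.match(/^(.*)_\d+$/); // "unit_3" -> prefix "unit"
    if (m) { (buckets[m[1]] = buckets[m[1]] || []).push(name); }
  });
  return Object.keys(buckets).filter(function (prefix) {
    return buckets[prefix].length >= minSize;
  });
}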
var hierarchyTracker = tf.getSubtaskTracker(tracker, 50,
'Namespace hierarchy');
@@ -7950,7 +8094,7 @@ Polymer({
}
});
</script>
-<dom-module id="tf-graph-style" assetpath="../components/tf-graph/">
+<dom-module id="tf-graph-style" assetpath="../tf-graph/">
<template>
<style>
:host {
@@ -8289,7 +8433,7 @@ Polymer({
</style>
</template>
</dom-module>
-<dom-module id="tf-graph-minimap" assetpath="../components/tf-graph/">
+<dom-module id="tf-graph-minimap" assetpath="../tf-graph/">
<template>
<style>
:host {
@@ -8334,6 +8478,7 @@ svg {
<canvas class="first"></canvas>
<canvas class="second"></canvas>
+<canvas class="download"></canvas>
</template>
<script>
Polymer({
@@ -8356,7 +8501,7 @@ Polymer({
});
</script>
</dom-module>
-<dom-module id="tf-graph-scene" assetpath="../components/tf-graph/">
+<dom-module id="tf-graph-scene" assetpath="../tf-graph/">
<template>
<style include="tf-graph-style">
:host {
@@ -8419,6 +8564,9 @@ Polymer({
<use xlink:href="#op-node-annotation-stamp" x="7" y="2"></use>
<use xlink:href="#op-node-annotation-stamp" x="5" y="2"></use>
</g>
+ <svg id="summary-icon" fill="#848484" height="12" viewBox="0 0 24 24" width="12">
+ <path d="M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z"></path>
+ </svg>
<g id="linearGradients"></g>
</defs>
@@ -8433,7 +8581,7 @@ Polymer({
Polymer({
is: 'tf-graph-scene',
properties: {
- graphHierarchy: Object,
+ renderHierarchy: Object,
name: String,
colorBy: {
type: String,
@@ -8473,9 +8621,10 @@ Polymer({
/**
* @type {d3.scale.ordinal}
* Scale mapping from template name to a number between 0 and N-1
- * where N is the number of different template names.
+ * where N is the number of different template names. Used by
+ * tf.graph.scene.node when computing node color by structure.
*/
- templateIndex: Object,
+ templateIndex: Function,
/**
* @type {tf.scene.Minimap}
* A minimap object to notify for zoom events.
@@ -8537,16 +8686,17 @@ Polymer({
progress: Object
},
observers: [
- '_buildAndFit(graphHierarchy)'
+ '_buildAndFit(renderHierarchy)'
],
getNode: function(nodeName) {
- return this.graphHierarchy.getRenderNodeByName(nodeName);
+ return this.renderHierarchy.getRenderNodeByName(nodeName);
},
isNodeExpanded: function(node) {
return node.expanded;
},
setNodeExpanded: function(renderNode) {
- this._build(this.graphHierarchy);
+ this._build(this.renderHierarchy);
+ this._updateLabels(!this._zoomed);
},
/**
* Resets the state of the component. Called whenever the whole graph
@@ -8565,20 +8715,15 @@ Polymer({
.selectAll('*').remove();
},
/** Main method for building the scene */
- _build: function(graphHierarchy) {
- if (!graphHierarchy) { return; } //handle untruthy input
- var templateNames = d3.keys(graphHierarchy.hierarchy.templates);
-
- this.templateIndex = d3.scale.ordinal()
- .domain(templateNames)
- .range(d3.range(0, templateNames.length));
+ _build: function(renderHierarchy) {
+ this.templateIndex = renderHierarchy.hierarchy.getTemplateIndex();
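// Illustrative sketch (helper name is an assumption, not TensorBoard's
// API): a template index like the one assigned above maps each template
// name to a stable integer 0..N-1, which drives structural node coloring.
// e.g. makeTemplateIndex(['a', 'b', 'c'])('b') === 1
function makeTemplateIndex(templateNames) {
  var index = {};
  templateNames.forEach(function (name, i) { index[name] = i; });
  return function (name) { return index[name]; };
}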
tf.time('tf-graph-scene (layout):', function() {
// layout the scene for this meta / series node
- tf.graph.layout.scene(graphHierarchy.root, this);
+ tf.graph.layout.layoutScene(renderHierarchy.root, this);
}.bind(this));
tf.time('tf-graph-scene (build scene):', function() {
- tf.graph.scene.buildGroup(d3.select(this.$.root), graphHierarchy.root, this);
+ tf.graph.scene.buildGroup(d3.select(this.$.root), renderHierarchy.root, this);
tf.graph.scene.addGraphClickListener(this.$.svg, this);
}.bind(this));
// Update the minimap again when the graph is done animating.
@@ -8641,21 +8786,24 @@ Polymer({
tf.graph.layout.PARAMS.minimap.size,
tf.graph.layout.PARAMS.subscene.meta.labelHeight);
},
- _buildAndFit: function(graphHierarchy) {
+ _buildAndFit: function(renderHierarchy) {
this._resetState();
- this._build(graphHierarchy);
+ this._build(renderHierarchy);
// Fit to screen after the graph is done animating.
setTimeout(this.fit.bind(this), tf.graph.layout.PARAMS.animation.duration);
},
_updateLabels: function(showLabels) {
var titleStyle = this.getElementsByClassName('title')[0].style;
var auxTitleStyle = this.getElementsByClassName('auxTitle')[0].style;
- var core = this.getElementsByClassName(tf.graph.scene.Class.Scene.CORE)[0];
+ var core = d3.select("." + tf.graph.scene.Class.Scene.GROUP + ">." +
+ tf.graph.scene.Class.Scene.CORE)[0][0];
// Only show labels if the graph is fully loaded.
if (showLabels && core && this.progress && this.progress.value === 100) {
var aux =
- this.getElementsByClassName(tf.graph.scene.Class.Scene.INEXTRACT)[0] ||
- this.getElementsByClassName(tf.graph.scene.Class.Scene.OUTEXTRACT)[0];
+ d3.select("." + tf.graph.scene.Class.Scene.GROUP + ">." +
+ tf.graph.scene.Class.Scene.INEXTRACT)[0][0] ||
+ d3.select("." + tf.graph.scene.Class.Scene.GROUP + ">." +
+ tf.graph.scene.Class.Scene.OUTEXTRACT)[0][0];
var coreX = core.getCTM().e;
var auxX = aux ? aux.getCTM().e : null;
titleStyle.display = 'inline';
@@ -8760,7 +8908,7 @@ Polymer({
}
// Update the minimap to reflect the highlighted (selected) node.
this.minimap.update();
- var node = this.graphHierarchy.hierarchy.node(selectedNode);
+ var node = this.renderHierarchy.hierarchy.node(selectedNode);
var nodeParents = [];
// Create list of all metanode parents of the selected node.
while (node.parentNode != null
@@ -8771,8 +8919,8 @@ Polymer({
// Ensure each parent metanode is built and expanded.
var topParentNodeToBeExpanded;
_.forEachRight(nodeParents, function(parentName) {
- this.graphHierarchy.buildSubhierarchy(parentName);
- var renderNode = this.graphHierarchy.getRenderNodeByName(parentName);
+ this.renderHierarchy.buildSubhierarchy(parentName);
+ var renderNode = this.renderHierarchy.getRenderNodeByName(parentName);
if (renderNode.node.isGroupNode && !renderNode.expanded) {
renderNode.expanded = true;
if (!topParentNodeToBeExpanded) {
@@ -8812,7 +8960,7 @@ Polymer({
},
});
</script>
-<dom-module id="tf-graph-params" assetpath="../components/tf-graph/">
+<dom-module id="tf-graph-params" assetpath="../tf-graph/">
</dom-module>
<script>
Polymer({
@@ -8878,7 +9026,7 @@ Polymer({
*/
detachAllEdgesForHighDegree: {
type: Boolean,
- value: false
+ value: true
},
/**
@@ -8921,12 +9069,14 @@ Polymer({
}
});
</script>
-<dom-module id="tf-graph" assetpath="../components/tf-graph/">
+<dom-module id="tf-graph" assetpath="../tf-graph/">
<template>
<style>
.container {
width: 100%;
height: 100%;
+ background: white;
+ box-shadow: 0 1px 5px rgba(0,0,0,0.2);
}
.vertical {
@@ -8952,7 +9102,7 @@ paper-button {
<tf-graph-params id="graphParams"></tf-graph-params>
<div class="vertical">
<h2>[[title]]</h2>
- <tf-graph-scene id="scene" class="auto" graph-hierarchy="[[_renderHierarchy]]" highlighted-node="[[_getVisible(highlightedNode)]]" selected-node="[[selectedNode]]" color-by="[[colorBy]]" name="[[graphName]]" progress="[[progress]]"></tf-graph-scene>
+ <tf-graph-scene id="scene" class="auto" render-hierarchy="[[renderHierarchy]]" highlighted-node="[[_getVisible(highlightedNode)]]" selected-node="[[selectedNode]]" color-by="[[colorBy]]" name="[[graphName]]" progress="[[progress]]"></tf-graph-scene>
</div>
</div>
</template>
@@ -8985,6 +9135,11 @@ Polymer({
notify: true,
readOnly: true, // Produces and doesn't consume.
},
+ renderHierarchy: {
+ type: Object,
+ readOnly: true,
+ notify: true,
+ },
// internal properties
_graphParams: {
type: Object,
@@ -8996,27 +9151,24 @@ Polymer({
type: Number,
value: 1
},
- _renderHierarchy: {
- type: Object,
- readOnly: true,
- notify: true,
- computed: '_buildRenderHierarchy(graphHierarchy, _graphParams)'
- },
_allowGraphSelect: {
type: Boolean,
value: true
}
},
+ observers: [
+ '_buildRenderHierarchy(graphHierarchy, _graphParams)'
+ ],
_buildRenderHierarchy: function(graphHierarchy, params) {
- return tf.time('new tf.graph.render.Hierarchy', function() {
+ tf.time('new tf.graph.render.Hierarchy', function() {
if (graphHierarchy.root.type !== tf.graph.NodeType.META) {
// root must be a metanode, but sometimes Polymer's dom-if has not
// removed the tf-graph element in <tf-node-info> yet
// and thus mistakenly passes a non-metanode to this module.
return;
}
- var renderGraph = new tf.graph.render.RenderGraphInformation(
- graphHierarchy, params);
+ var renderGraph = new tf.graph.render.RenderGraphInfo(graphHierarchy,
+ params);
// Producing the 'color by' parameters to be consumed
// by the tf-graph-controls panel. It contains information about the
// min and max values and their respective colors, as well as list
@@ -9042,14 +9194,14 @@ Polymer({
};
})
});
- return renderGraph;
+ this._setRenderHierarchy(renderGraph);
}.bind(this));
},
_getVisible: function(name) {
if (!name) {
return name;
}
- return this._renderHierarchy.getNearestVisibleAncestor(name);
+ return this.renderHierarchy.getNearestVisibleAncestor(name);
},
listeners: {
'graph-select': '_graphSelected',
@@ -9060,6 +9212,7 @@ Polymer({
'node-select': '_nodeSelected',
'node-highlight': '_nodeHighlighted',
'node-unhighlight': '_nodeUnhighlighted',
+ 'node-toggle-extract': '_nodeToggleExtract',
// Annotations
@@ -9110,53 +9263,72 @@ Polymer({
},
_nodeToggleExpand: function(event) {
var nodeName = event.detail.name;
- var renderNode = this._renderHierarchy.getRenderNodeByName(nodeName);
+ var renderNode = this.renderHierarchy.getRenderNodeByName(nodeName);
// Op nodes are not expandable.
if (renderNode.node.type === tf.graph.NodeType.OP) {
return;
}
- this._renderHierarchy.buildSubhierarchy(nodeName);
+ this.renderHierarchy.buildSubhierarchy(nodeName);
renderNode.expanded = !renderNode.expanded;
this.querySelector('#scene').setNodeExpanded(renderNode);
// Also select the expanded node.
this._nodeSelected(event);
},
+ _nodeToggleExtract: function(event) {
+ // Toggle the include setting of the specified node appropriately.
+ var nodeName = event.detail.name;
+ var renderNode = this.renderHierarchy.getRenderNodeByName(nodeName);
+ if (renderNode.node.include == tf.graph.InclusionType.INCLUDE) {
+ renderNode.node.include = tf.graph.InclusionType.EXCLUDE;
+ } else if (renderNode.node.include == tf.graph.InclusionType.EXCLUDE) {
+ renderNode.node.include = tf.graph.InclusionType.INCLUDE;
+ } else {
+ renderNode.node.include =
+ this.renderHierarchy.isNodeAuxilliary(renderNode)
+ ? tf.graph.InclusionType.INCLUDE : tf.graph.InclusionType.EXCLUDE;
+ }
+
+ // Rebuild the render hierarchy.
+ this._buildRenderHierarchy(this.graphHierarchy, this._graphParams);
+ },
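// Toggle summary (illustrative, not from the original file): INCLUDE and
// EXCLUDE simply flip; an UNSPECIFIED node becomes INCLUDE when it had been
// auto-extracted into the auxiliary area (pulling it back into the main
// graph), and EXCLUDE otherwise.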
not: function(x) {
return !x;
}
});
</script>
-<dom-module id="tf-graph-icon" assetpath="../components/tf-graph/">
+<dom-module id="tf-graph-icon" assetpath="../tf-graph/">
<template>
<template is="dom-if" if="[[_isType(node, type, 'OP')]]">
<template is="dom-if" if="[[_isConst(node, const)]]">
<svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 10 10">
- <circle fill="white" stroke="#848484" cx="5" cy="5" r="3"></circle>
+ <circle cx="5" cy="5" r="3" fill$="[[_getFill(_computedFill, 'OP')]]" stroke$="[[_getStroke(_computedFill, 'OP')]]"></circle>
</svg>
</template>
<template is="dom-if" if="[[_isSummary(node, summary)]]">
- <img height$="[[height]]" src="[[resolveUrl('../../lib/svg/summary-icon.svg')]]">
+ <svg width$="[[height]]" height$="[[height]]" viewBox="0 0 12 12">
+ <use x="0" y="0" xlink:href="#summary-icon"></use>
+ </svg>
</template>
<template is="dom-if" if="[[_isRegularOp(node, const, summary)]]">
<svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 16 8">
- <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-node-stamp" fill="white" stroke="#ccc" x="8" y="4"></use>
+ <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-node-stamp" fill$="[[_getFill(_computedFill, 'OP')]]" stroke$="[[_getStroke(_computedFill, 'OP')]]" x="8" y="4"></use>
</svg>
</template>
</template>
<template is="dom-if" if="[[_isType(node, type, 'META')]]">
<svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 37 16">
- <rect x="1" y="1" fill="#d9d9d9" stroke="#ccc" stroke-width="2px" height="14" width="35" rx="5" ry="5"></rect>
+ <rect x="1" y="1" fill$="[[_getFill(_computedFill, 'META')]]" stroke$="[[_getStroke(_computedFill, 'META')]]" stroke-width="2px" height="14" width="35" rx="5" ry="5"></rect>
</svg>
</template>
<template is="dom-if" if="[[_isType(node, type, 'SERIES')]]">
<template is="dom-if" if="[[_isVertical(node, vertical)]]">
<svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 16 15">
- <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-series-vertical-stamp" fill="white" stroke="#ccc" x="0" y="2"></use>
+ <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-series-vertical-stamp" fill$="[[_getFill(_computedFill, 'SERIES')]]" stroke$="[[_getStroke(_computedFill, 'SERIES')]]" x="0" y="2"></use>
</svg>
</template>
<template is="dom-if" if="[[!_isVertical(node, vertical)]]">
<svg height$="[[height]]" preserveAspectRatio="xMinYMid meet" viewBox="0 0 24 10">
- <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-series-horizontal-stamp" fill="white" stroke="#ccc" x="0" y="1"></use>
+ <use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#op-series-horizontal-stamp" fill$="[[_getFill(_computedFill, 'SERIES')]]" stroke$="[[_getStroke(_computedFill, 'SERIES')]]" x="0" y="1"></use>
</svg>
</template>
</template>
@@ -9179,7 +9351,36 @@ Polymer({
value: null
},
- /** Type of node to draw. */
+ /**
+ * Render node information associated with this node. Optional. If
+ * specified, this is only used when computing the fill of the icon
+ * element.
+ * @type {tf.graph.render.RenderNodeInfo}
+ */
+ renderInfo: {
+ type: Object,
+ value: null
+ },
+
+ /**
+ * String indicating the type of coloring to use for this node, used
+ * only for determining the fill.
+ */
+ colorBy: {
+ type: Object,
+ value: "structural"
+ },
+
+ /**
+ * Function used by the structural coloring algorithm to determine which
+ * color to use based on the template ID of the node. Optional.
+ */
+ templateIndex: {
+ type: Function,
+ value: null
+ },
+
+ /** Type of node to draw (ignored if node is set). */
type: {
type: String,
value: null
@@ -9203,11 +9404,70 @@ Polymer({
value: false
},
+ /**
+ * Fill for the icon, optional. If fill is specified and node is not
+ * specified, then this value will override the default for the
+ * element. However, if node is specified, this value will be ignored.
+ */
+ fill: {
+ type: String,
+ value: null
+ },
+
/** Height of the SVG element in pixels, used for scaling. */
height: {
type: Number,
value: 20
+ },
+
+ /** The computed fill for the node. */
+ _computedFill: {
+ type: String,
+ computed:
+ "_getComputedFill(node, renderInfo, colorBy, templateIndex, fill)"
+ }
+
+ },
+
+ /**
+ * Get the computed fill value for the element.
+ */
+ _getComputedFill: function(inputNode, inputRenderInfo, inputColorBy,
+ inputTemplateIndex, inputFill) {
+ if (inputNode && inputRenderInfo &&
+ inputColorBy && inputTemplateIndex) {
+ var ns = tf.graph.scene.node;
+ var colorBy = ns.ColorBy[inputColorBy.toUpperCase()];
+ return ns.getFillForNode(inputTemplateIndex, colorBy,
+ inputRenderInfo, false);
}
+ return inputFill;
+ },
+
+ /**
+ * Get the fill value for the element, or if that's not possible, return
+ * the default fill value for the node type.
+ */
+ _getFill: function(inputComputedFill, inputNodeType) {
+ return inputComputedFill || ({
+ OP: tf.graph.render.OpNodeColors.DEFAULT_FILL,
+ META: tf.graph.render.MetanodeColors.DEFAULT_FILL,
+ SERIES: tf.graph.render.SeriesNodeColors.DEFAULT_FILL
+ })[inputNodeType];
+ },
+
+ /**
+ * Get the stroke value for the element, or if that's not possible,
+ * return the default stroke value for the node type.
+ */
+ _getStroke: function(inputComputedFill, inputNodeType) {
+ return inputComputedFill ?
+ tf.graph.scene.node.getStrokeForFill(inputComputedFill) :
+ ({
+ OP: tf.graph.render.OpNodeColors.DEFAULT_STROKE,
+ META: tf.graph.render.MetanodeColors.DEFAULT_STROKE,
+ SERIES: tf.graph.render.SeriesNodeColors.DEFAULT_STROKE
+ })[inputNodeType];
},
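// Usage sketch (illustrative values): with no computed fill, icons fall
// back to per-type defaults; with a computed fill, the stroke is derived
// from that fill.
//   _getFill(null, 'META')   -> tf.graph.render.MetanodeColors.DEFAULT_FILL
//   _getStroke('#fff', 'OP') -> tf.graph.scene.node.getStrokeForFill('#fff')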
/**
@@ -9267,7 +9527,7 @@ Polymer({
})();
</script>
</dom-module>
-<dom-module id="tf-node-list-item" assetpath="../components/tf-graph-info/">
+<dom-module id="tf-node-list-item" assetpath="../tf-graph-info/">
<style>
#list-item {
width: 100%;
@@ -9302,7 +9562,7 @@ Polymer({
</style>
<template>
<div id="list-item" on-mouseover="_nodeListener" on-mouseout="_nodeListener" on-click="_nodeListener">
- <tf-graph-icon class="node-icon" node="[[itemNode]]" height="12"></tf-graph-icon>
+ <tf-graph-icon class="node-icon" height="12" color-by="[[colorBy]]" color-by-params="[[colorByParams]]" node="[[itemNode]]" render-info="[[itemRenderInfo]]" template-index="[[templateIndex]]"></tf-graph-icon>
<span title$="[[name]]">[[name]]</span>
</div>
</template>
@@ -9323,11 +9583,19 @@ Polymer({
* @type {tf.graph.Node}
*/
itemNode: Object,
+ /**
+ * The render node information for the item node. Used by the graph
+ * icon in determining fill color.
+ */
+ itemRenderInfo: Object,
name: String,
itemType: {
type: String,
observer: '_itemTypeChanged'
- }
+ },
+ colorBy: String,
+ colorByParams: Object,
+ templateIndex: Function,
},
_itemTypeChanged: function() {
@@ -9351,7 +9619,7 @@ Polymer({
})();
</script>
</dom-module>
-<dom-module id="tf-node-info" assetpath="../components/tf-graph-info/">
+<dom-module id="tf-node-info" assetpath="../tf-graph-info/">
<style>
.sub-list-group {
padding: 8px 12px 0px;
@@ -9432,6 +9700,27 @@ Polymer({
max-width: 20px;
padding: 0;
}
+
+ .toggle-include-group {
+ padding-top: 4px;
+ }
+
+ .toggle-include {
+ margin: 5px 6px;
+ text-transform: none;
+ padding: 4px 6px;
+ font-size: 10pt;
+ background-color: #fafafa;
+ color: #666;
+ }
+
+ .toggle-include:hover {
+ background-color: var(--google-yellow-100);
+ }
+
+ .non-control-list-item {
+ padding-left: 10px;
+ }
</style>
<template>
<paper-item>
@@ -9442,7 +9731,7 @@ Polymer({
<div class="node-name">[[_getNodeName(nodeName)]]</div>
</div>
<div secondary="">
- <tf-graph-icon class="node-icon" node="[[_node]]"></tf-graph-icon>
+ <tf-graph-icon class="node-icon" node="[[_node]]" render-info="[[_getRenderInfo(nodeName, renderHierarchy)]]" color-by="[[colorBy]]" template-index="[[_templateIndex]]"></tf-graph-icon>
<template is="dom-if" if="{{_node.op}}">
<div class="subtitle">
Operation:
@@ -9486,7 +9775,7 @@ Polymer({
(<span>[[_totalPredecessors]]</span>)
<iron-list class="sub-list" id="inputsList" items="[[_predecessors.regular]]">
<template>
- <tf-node-list-item card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" name="[[item]]" item-type="predecessors">
+ <tf-node-list-item class="non-control-list-item" card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" item-render-info="[[_getRenderInfo(item, renderHierarchy)]]" name="[[item]]" item-type="predecessors" color-by="[[colorBy]]" template-index="[[_templateIndex]]">
</tf-node-list-item>
</template>
</iron-list>
@@ -9501,7 +9790,7 @@ Polymer({
<template is="dom-if" if="{{_openedControlPred}}" restamp="true">
<iron-list class="sub-list" items="[[_predecessors.control]]">
<template>
- <tf-node-list-item card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" name="[[item]]" item-type="predecessors">
+ <tf-node-list-item card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" item-render-info="[[_getRenderInfo(item, renderHierarchy)]]" name="[[item]]" item-type="predecessors" color-by="[[colorBy]]" template-index="[[_templateIndex]]">
</tf-node-list-item>
</template>
</iron-list>
@@ -9516,7 +9805,7 @@ Polymer({
(<span>[[_totalSuccessors]]</span>)
<iron-list class="sub-list" id="outputsList" items="[[_successors.regular]]">
<template>
- <tf-node-list-item card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" name="[[item]]" item-type="successor">
+ <tf-node-list-item class="non-control-list-item" card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" item-render-info="[[_getRenderInfo(item, renderHierarchy)]]" name="[[item]]" item-type="successor" color-by="[[colorBy]]" template-index="[[_templateIndex]]">
</tf-node-list-item>
</template>
</iron-list>
@@ -9531,7 +9820,7 @@ Polymer({
<template is="dom-if" if="{{_openedControlSucc}}" restamp="true">
<iron-list class="sub-list" items="[[_successors.control]]">
<template>
- <tf-node-list-item card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" name="[[item]]" item-type="successors">
+ <tf-node-list-item card-node="[[_node]]" item-node="[[_getNode(item, graphHierarchy)]]" item-render-info="[[_getRenderInfo(item, renderHierarchy)]]" name="[[item]]" item-type="successors" color-by="[[colorBy]]" template-index="[[_templateIndex]]">
</tf-node-list-item>
</template>
</iron-list>
@@ -9540,6 +9829,11 @@ Polymer({
</div>
</template>
</div>
+ <div class="toggle-include-group">
+ <paper-button raised="" class="toggle-include" on-click="_toggleInclude">
+ <span>[[_auxButtonText]]</span>
+ </paper-button>
+ </div>
</div>
</template>
</iron-collapse>
@@ -9553,11 +9847,23 @@ Polymer({
properties: {
nodeName: String,
graphHierarchy: Object,
+ renderHierarchy: Object,
+ /** What to color the nodes by (compute time, memory, device etc.) */
+ colorBy: String,
+ _templateIndex: {
+ type: Function,
+ computed: '_getTemplateIndex(graphHierarchy)'
+ },
_node: {
type: Object,
computed: '_getNode(nodeName, graphHierarchy)',
observer: '_resetState'
},
+ // The enum value of the include property of the selected node.
+ nodeInclude: {
+ type: Number,
+ observer: '_nodeIncludeStateChanged'
+ },
_attributes: {
type: Array,
computed: '_getAttributes(_node)'
@@ -9598,18 +9904,25 @@ Polymer({
type: Boolean,
value: false
},
+ _auxButtonText: String
},
expandNode: function() {
this.fire('_node.expand', this.node);
},
- _getNode: function(n, graphHierarchy) {
- return graphHierarchy.node(n);
+ _getTemplateIndex: function(graphHierarchy) {
+ return graphHierarchy.getTemplateIndex();
+ },
+ _getNode: function(nodeName, graphHierarchy) {
+ return graphHierarchy.node(nodeName);
},
_getNodeName: function(nodeName) {
// Insert a zero-width whitespace character before each slash so that
// long node names wrap cleanly at path boundaries.
return (nodeName || '').replace(/\//g, '\u200B/');
},
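// Example (illustrative): "layer1/weights/read" is rendered as
// "layer1\u200B/weights\u200B/read", giving the browser a legal break point
// before each slash without changing the visible text.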
+ _getRenderInfo: function(nodeName, renderHierarchy) {
+ return this.renderHierarchy.getOrCreateRenderNodeByName(nodeName);
+ },
_getAttributes: function(node) {
this.async(this._resizeList.bind(this, "#attributesList"));
return node && node.attr ? node.attr.map(function(entry) {
@@ -9658,12 +9971,22 @@ Polymer({
if (list) {
list.fire('iron-resize');
}
+ },
+ _toggleInclude: function() {
+ var graphElem = document.querySelector("#graph");
+ graphElem.fire("node-toggle-extract", { name: this.nodeName });
+ var graphBoardElem = document.querySelector("#graphboard");
+ graphBoardElem.fire("node-toggle-extract");
+ },
+ _nodeIncludeStateChanged: function(include, oldInclude) {
+ this.set("_auxButtonText",
+ tf.graph.getIncludeNodeButtonString(include));
}
});
})();
</script>
</dom-module>
-<dom-module id="tf-graph-info" assetpath="../components/tf-graph-info/">
+<dom-module id="tf-graph-info" assetpath="../tf-graph-info/">
<template>
<style>
:host {
@@ -9681,7 +10004,7 @@ h2 {
</style>
<template is="dom-if" if="{{selectedNode}}">
<paper-material elevation="1" class="card">
- <tf-node-info graph-hierarchy="[[graphHierarchy]]" flat-graph="[[graph]]" node-name="[[selectedNode]]" highlighted-node="{{highlightedNode}}">
+ <tf-node-info graph-hierarchy="[[graphHierarchy]]" render-hierarchy="[[renderHierarchy]]" flat-graph="[[graph]]" node-name="[[selectedNode]]" node-include="[[selectedNodeInclude]]" highlighted-node="{{highlightedNode}}" color-by="[[colorBy]]">
</tf-node-info>
</paper-material>
</template>
@@ -9695,6 +10018,8 @@ h2 {
title: String,
graphHierarchy: Object,
graph: Object,
+ renderHierarchy: Object,
+ colorBy: String,
// Two-ways
selectedNode: {
type: String,
@@ -9703,6 +10028,11 @@ h2 {
highlightedNode: {
type: String,
notify: true
+ },
+ // The enum value of the include property of the selected node.
+ selectedNodeInclude: {
+ type: Number,
+ notify: true
}
},
listeners: {
@@ -9723,7 +10053,7 @@ h2 {
})();
</script>
</dom-module>
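For readers less used to Polymer 1.x binding syntax, the data flow these new attributes set up is worth spelling out; the chain below is inferred from the bindings, not stated anywhere in the patch:

    // [[x]] is one-way (host -> child); {{x}} plus notify:true is two-way.
    // selectedNodeInclude therefore travels:
    //   tf-graph-board._selectedNodeInclude
    //     -> tf-graph-info.selectedNodeInclude   ({{...}} two-way, notify)
    //     -> tf-node-info.nodeInclude            ([[...]] one-way)
    // whose observer (_nodeIncludeStateChanged) refreshes the button label.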
-<dom-module id="tf-graph-board" assetpath="../components/tf-graph-board/">
+<dom-module id="tf-graph-board" assetpath="../tf-graph-board/">
<template>
<style>
:host {
@@ -9790,6 +10120,32 @@ paper-progress {
--paper-progress-height: 6px;
--paper-progress-active-color: #f3913e;
}
+
+.context-menu {
+ position: absolute;
+ display: none;
+ background-color: #e2e2e2;
+ border-radius: 2px;
+ font-size: 14px;
+ min-width: 150px;
+ border: 1px solid #d4d4d4;
+}
+
+/deep/ .context-menu ul {
+ list-style-type: none;
+ margin: 0;
+ padding: 0;
+ cursor: default;
+}
+
+/deep/ .context-menu ul li {
+ padding: 4px 16px;
+}
+
+/deep/ .context-menu ul li:hover {
+ background-color: #f3913e;
+ color: white;
+}
</style>
<template is="dom-if" if="[[_isNotComplete(progress)]]">
<div id="progress-bar">
@@ -9799,11 +10155,12 @@ paper-progress {
</template>
<div class$="[[_getContainerClass(progress)]]">
<div id="main">
- <tf-graph id="graph" graph-hierarchy="[[graphHierarchy]]" selected-node="{{_selectedNode}}" highlighted-node="{{_highlightedNode}}" color-by="[[colorBy]]" color-by-params="{{colorByParams}}" graph-name="[[graphName]]" progress="[[progress]]"></tf-graph>
+ <tf-graph id="graph" graph-hierarchy="[[graphHierarchy]]" render-hierarchy="{{_renderHierarchy}}" selected-node="{{_selectedNode}}" highlighted-node="{{_highlightedNode}}" color-by="[[colorBy]]" color-by-params="{{colorByParams}}" graph-name="[[graphName]]" progress="[[progress]]"></tf-graph>
</div>
<div id="info">
- <tf-graph-info id="graph-info" title="selected" graph-hierarchy="[[graphHierarchy]]" graph="[[graph]]" selected-node="{{_selectedNode}}" highlighted-node="{{_highlightedNode}}"></tf-graph-info>
+ <tf-graph-info id="graph-info" title="selected" graph-hierarchy="[[graphHierarchy]]" render-hierarchy="[[_renderHierarchy]]" graph="[[graph]]" selected-node="{{_selectedNode}}" selected-node-include="{{_selectedNodeInclude}}" highlighted-node="{{_highlightedNode}}" color-by="[[colorBy]]" color-by-params="[[colorByParams]]"></tf-graph-info>
</div>
+ <div class="context-menu"></div>
</div>
</template>
</dom-module>
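The new .context-menu rules style an absolutely positioned div (added at the bottom of the template above), and the /deep/ combinator lets those rules reach a ul injected across shadow boundaries. A hypothetical driver, written here only to show the markup shape the /deep/ ul li rules expect — the real show/hide logic ships elsewhere in this patch:

    // Hypothetical driver for the new .context-menu div (illustration only).
    function showContextMenu(menuEl, event, items) {
      // Build the plain <ul><li> structure the /deep/ rules style.
      menuEl.innerHTML = '<ul>' + items.map(function(item) {
        return '<li>' + item.title + '</li>';
      }).join('') + '</ul>';
      // .context-menu is position:absolute, so place it at the pointer.
      menuEl.style.left = event.clientX + 'px';
      menuEl.style.top = event.clientY + 'px';
      menuEl.style.display = 'block';
    }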
@@ -9825,14 +10182,24 @@ Polymer({
* for the progress bar and the displayed message.
*/
progress: Object,
+ colorBy: String,
colorByParams: {
type: Object,
notify: true,
},
// Private API: Data routing between child components.
_selectedNode: String,
+ // The enum value of the include property of the selected node.
+ _selectedNodeInclude: Number,
_highlightedNode: String,
+ _renderHierarchy: Object,
},
+ listeners: {
+ 'node-toggle-extract': '_nodeToggleExtract'
+ },
+ observers: [
+ '_updateNodeInclude(_selectedNode)'
+ ],
/** True if the progress is not complete yet (< 100 %). */
_isNotComplete: function(progress) {
return progress.value < 100;
@@ -9846,10 +10213,18 @@ Polymer({
result += ' loading';
}
return result;
+ },
+ _updateNodeInclude: function(nodeName) {
+ var node = this.graphHierarchy.node(nodeName);
+ this.set("_selectedNodeInclude",
+ node ? node.include : tf.graph.InclusionType.UNSPECIFIED);
+ },
+ _nodeToggleExtract: function() {
+ this._updateNodeInclude(this._selectedNode);
}
});
</script>
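The board now takes part in an event round trip: tf-node-info's toggle handler fires node-toggle-extract on #graph (so the scene redraws the node in or out of the auxiliary graph) and on #graphboard, whose listener above re-reads node.include. A minimal stand-alone sketch, assuming Polymer's fire() emits an ordinary CustomEvent:

    // Minimal sketch of the round trip; element ids match this patch.
    var graphEl = document.querySelector('#graph');       // redraws the node
    var boardEl = document.querySelector('#graphboard');  // refreshes button state
    boardEl.addEventListener('node-toggle-extract', function() {
      // Mirrors _nodeToggleExtract -> _updateNodeInclude above.
      console.log('re-reading node.include for the selected node');
    });
    boardEl.dispatchEvent(new CustomEvent('node-toggle-extract'));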
-<dom-module id="tf-graph-controls" assetpath="../components/tf-graph/">
+<dom-module id="tf-graph-controls" assetpath="../tf-graph/">
<template>
<style>
:host {
@@ -9899,7 +10274,7 @@ table td {
}
.allcontrols {
- padding: 10px;
+ padding: 30px;
}
.legend-holder {
@@ -9908,10 +10283,6 @@ table td {
padding-bottom: 10px;
}
-#fit {
- color: var(--paper-orange-500);
-}
-
paper-radio-button {
padding: 5px;
}
@@ -9979,7 +10350,7 @@ svg.icon {
padding: 0 0 0 55px;
}
-.fit-button-text {
+.button-text {
text-transform: none;
padding: 8px 18px 0 18px;
font-size: 14px
@@ -9992,10 +10363,11 @@ svg.icon {
margin-top: 4px;
}
-.fit-button {
+.iconbutton {
padding: 2px;
width: 30px;
height: 30px;
+ color: var(--paper-orange-500);
}
.hidden-input {
@@ -10011,12 +10383,20 @@ svg.icon {
</style>
<div class="allcontrols">
<div class="control-holder">
- <paper-icon-button id="fit" icon="aspect-ratio" class="fit-button" on-click="fit" alt="Fit to screen">
+ <paper-icon-button icon="aspect-ratio" class="iconbutton" on-click="fit" alt="Fit to screen">
</paper-icon-button>
- <paper-button class="fit-button-text" on-click="fit">Fit to screen
+ <paper-button class="button-text" on-click="fit">Fit to screen
</paper-button>
</div>
<div class="control-holder">
+ <paper-icon-button icon="file-download" class="iconbutton" on-click="download" alt="Download PNG">
+ </paper-icon-button>
+ <paper-button class="button-text" on-click="download">Download PNG
+ </paper-button>
+ <a href="#" id="graphdownload" class="title" download="graph.png">
+ </a>
+ </div>
+ <div class="control-holder">
<div class="title">Run</div>
<paper-dropdown-menu no-label-float="" no-animations="" noink="" class="run-dropdown">
<paper-menu id="select" class="dropdown-content" selected="{{selectedDataset}}">
@@ -10137,8 +10517,8 @@ svg.icon {
</tr>
<tr>
<td>
- <svg class="image-icon">
- <image id="summary-icon" width="24" height="24" x="0" y="0" class="image-icon"></image>
+ <svg class="image-icon" viewBox="0 0 12 12" width="24" height="24">
+ <use x="0" y="0" class="image-icon" xlink:href="#summary-icon"></use>
</svg>
</td>
<td>Summary</td>
@@ -10183,11 +10563,6 @@ svg.icon {
(function() { // Private scope.
Polymer({
is: 'tf-graph-controls',
- ready: function() {
- // Set the url to download the summary icon.
- d3.select(this.$['summary-icon'])
- .attr('xlink:href', this.resolveUrl('../../lib/svg/summary-icon.svg'));
- },
properties: {
// Public API.
hasStats: {
@@ -10207,6 +10582,7 @@ Polymer({
type: Number,
notify: true,
value: 0,
+ observer: '_selectedDatasetChanged'
},
selectedFile: {
type: Object,
@@ -10258,17 +10634,44 @@ Polymer({
endColor: params.endColor
};
},
+ download: function() {
+ this.$.graphdownload.click();
+ },
_updateFileInput: function(e) {
+ var file = e.target.files[0];
+ if (!file) {
+ return;
+ }
+ this._setDownloadFilename(file.name);
this.set('selectedFile', e);
},
_datasetsChanged: function(newDatasets, oldDatasets) {
if (oldDatasets != null || this.selected == null) {
// Select the first dataset by default.
this.set('selectedDataset', 0);
+ this._setDownloadFilename(this.datasets[this.selectedDataset].path);
+ }
+ },
+ _selectedDatasetChanged: function(newDataset, oldDataset) {
+ if (this.datasets) {
+ this._setDownloadFilename(this.datasets[newDataset].path);
}
},
_getFile: function() {
this.$.file.click();
+ },
+ _setDownloadFilename: function(graphPath) {
+ // Derive the PNG name for the graph: strip the file extension, then
+ // drop everything up to and including the last "/".
+ var dotIndex = graphPath.lastIndexOf('.');
+ // lastIndexOf returns -1 when absent, and -1 is truthy, so guard with
+ // >= 0 to avoid truncating extension-less paths.
+ if (dotIndex >= 0) {
+ graphPath = graphPath.substring(0, dotIndex);
+ }
+ var slashIndex = graphPath.lastIndexOf('/');
+ if (slashIndex >= 0) {
+ graphPath = graphPath.substring(slashIndex + 1);
+ }
+ this.$.graphdownload.setAttribute('download', graphPath + '.png');
}
});
@@ -10311,7 +10714,7 @@ function convertToHumanReadable(value, units, unitIndex) {
})(); // Closing private scope.
</script>
</dom-module>
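download() simply clicks the hidden #graphdownload anchor, and _setDownloadFilename keeps its download attribute in sync with the selected run. Nothing in this hunk shows how the anchor's href gets PNG bytes, so the serializer below is a guess at the usual SVG-to-canvas technique rather than the shipped code:

    // Assumed sketch: serialize an SVG, rasterize it, and hand the data URL
    // to the download anchor. setPngHref is a name invented here.
    function setPngHref(svgEl, anchorEl) {
      var xml = new XMLSerializer().serializeToString(svgEl);
      var img = new Image();
      img.onload = function() {
        var canvas = document.createElement('canvas');
        // Assumes the SVG declares explicit width/height attributes.
        canvas.width = img.width;
        canvas.height = img.height;
        canvas.getContext('2d').drawImage(img, 0, 0);
        anchorEl.href = canvas.toDataURL('image/png');
      };
      // encodeURIComponent round trip keeps btoa happy with non-Latin-1 text.
      img.src = 'data:image/svg+xml;base64,' +
          btoa(unescape(encodeURIComponent(xml)));
    }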
-<dom-module id="tf-graph-dashboard" assetpath="../components/tf-graph-dashboard/">
+<dom-module id="tf-graph-dashboard" assetpath="../tf-graph-dashboard/">
<template>
<div id="plumbing">
<tf-url-generator out-runs-url="{{_runsUrl}}" out-graph-url-generator="{{_graphUrlGen}}" id="urlGenerator"></tf-url-generator>
@@ -10348,6 +10751,7 @@ function convertToHumanReadable(value, units, unitIndex) {
}
.center {
+ position: relative;
height: 100%;
}
@@ -10386,15 +10790,13 @@ Polymer({
<paper-header-panel>
<paper-toolbar id="toolbar">
<div id="toolbar-content">
- <div class="toolbar-title">
- TensorBoard
- </div>
- <div class="right-buttons">
- <paper-button class="link-button" on-click="chooseEvents" active$="[[eventDashboard(mode)]]" noink="">Events</paper-button>
- <paper-button class="link-button" on-click="chooseImages" active$="[[imageDashboard(mode)]]" noink="">Images</paper-button>
- <paper-button class="link-button" on-click="chooseGraphs" active$="[[graphDashboard(mode)]]" noink="">Graph</paper-button>
- <paper-button class="link-button" on-click="chooseHistograms" active$="[[histogramDashboard(mode)]]" noink="">Histograms</paper-button>
- </div>
+ <div class="toolbar-title">TensorBoard</div>
+ <paper-tabs selected="0" noink="" class="tabs">
+ <paper-tab on-click="chooseEvents">Events</paper-tab>
+ <paper-tab on-click="chooseImages">Images</paper-tab>
+ <paper-tab on-click="chooseGraphs">Graph</paper-tab>
+ <paper-tab on-click="chooseHistograms">Histograms</paper-tab>
+ </paper-tabs>
</div>
</paper-toolbar>
<div id="content" class="fit">
@@ -10416,33 +10818,48 @@ Polymer({
</div>
</paper-header-panel>
<style>
+ :host {
+ height: 100%;
+ display: block;
+ background-color: var(--paper-grey-100);
+ }
+
#toolbar {
background-color: var(--tb-orange-strong);
- background-image: radial-gradient(ellipse, var(--tb-orange-weak), var(--tb-orange-strong));
+ -webkit-font-smoothing: antialiased;
}
+
#toolbar-content {
width: 100%;
+ height: 100%;
display: flex;
flex-direction: row;
justify-content: space-between;
align-items: center;
}
+
.toolbar-title {
- font-size: 30px;
+ font-size: 20px;
+ margin-left: 10px;
+ text-rendering: optimizeLegibility;
+ letter-spacing: -0.025em;
+ font-weight: 500;
}
+
#content {
height: 100%;
}
- .link-button {
- height: 30px;
- }
- [active] {
- font-weight: bold;
- }
- :host {
+
+ .tabs {
+ width: 400px;
+ text-transform: uppercase;
height: 100%;
- display: block;
}
+
+ paper-tabs {
+ --paper-tabs-selection-bar-color: white;
+ }
+
</style>
</template>
<script>
diff --git a/tensorflow/tensorboard/gulpfile.js b/tensorflow/tensorboard/gulpfile.js
index 61387e730b..867fc2f5ef 100644
--- a/tensorflow/tensorboard/gulpfile.js
+++ b/tensorflow/tensorboard/gulpfile.js
@@ -27,6 +27,7 @@ var gulpFilter = require('gulp-filter');
var vulcanize = require('gulp-vulcanize');
var minimist = require('minimist');
var replace = require('gulp-replace');
+var header = require('gulp-header');
var fs = require('fs');
var path = require('path');
var options = minimist(process.argv.slice(2), {
@@ -162,16 +163,9 @@ gulp.task('vulcanize', ['compile.all', 'tslint-strict'], function() {
// fixes https://github.com/Polymer/vulcanize/issues/273
.pipe(replace(linkRegex, ''))
.pipe(replace(scriptRegex, ''))
- .pipe(gulp.dest('dist'));
+ .pipe(header('// AUTOGENERATED FILE - DO NOT MODIFY \n'))
+ .pipe(gulp.dest('../opensource_only/tensorboard'));
- // Vulcanize TensorBoard with all external libraries inlined.
- gulp.src('components/index.html')
- .pipe(vulcanize({
- inlineScripts: true,
- inlineCss: true,
- stripComments: true,
- }))
- .pipe(gulp.dest('dist'));
gulp.src('app/tf-tensorboard-demo.html')
.pipe(vulcanize({
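The new gulp-header step does exactly one thing: it prepends a banner to every file in the stream before it is written out. A self-contained sketch of the same pipeline, with 'dist/*.html' and 'build' as illustrative paths rather than ones from this gulpfile:

    var gulp = require('gulp');
    var header = require('gulp-header');
    gulp.task('stamp', function() {
      // Prepend the banner to each file, then write the results.
      return gulp.src('dist/*.html')
          .pipe(header('// AUTOGENERATED FILE - DO NOT MODIFY \n'))
          .pipe(gulp.dest('build'));
    });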
diff --git a/tensorflow/tensorboard/package.json b/tensorflow/tensorboard/package.json
index 3b87007522..492fd36053 100644
--- a/tensorflow/tensorboard/package.json
+++ b/tensorflow/tensorboard/package.json
@@ -12,19 +12,21 @@
"license": "Apache-2.0",
"devDependencies": {
"gulp": "~3.9.0",
- "gulp-typescript": "~2.8.0",
- "tsd": "~0.6.3",
- "typescript": "~1.5.3",
- "gulp-cli": "~0.3.0",
- "gulp-util": "~3.0.6",
- "gulp-tslint": "~3.1.1-beta",
- "gulp-server-livereload": "~1.4.0",
+ "gulp-cli": "^1.1.0",
+ "gulp-filter": "~3.0.1",
+ "gulp-replace": "~0.5.4",
+ "gulp-server-livereload": "~1.5.4",
+ "gulp-tslint": "~4.2.2",
+ "gulp-typescript": "~2.10.0",
+ "gulp-util": "~3.0.7",
+ "gulp-vulcanize": "~6.1.0",
"merge2": "~0.3.6",
- "gulp-filter": "~3.0.0",
- "vulcanize": "~1.14.0",
- "gulp-vulcanize": "~6.0.1",
"minimist": "~1.2.0",
- "gulp-replace": "~0.5.4",
- "web-component-tester": "~3.3.30"
+ "tsd": "^0.6.5",
+ "tslint": "^3.2.1",
+ "typescript": "^1.6.2",
+ "vulcanize": "^1.14.0",
+ "web-component-tester": "~3.4.2",
+ "gulp-header": "~1.7.1"
}
}
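Several ranges switch from tilde to caret here, which widens what npm may install; the semantics are standard npm semver, not anything specific to this package:

    // "~1.5.3"  allows >=1.5.3 <1.6.0   (patch updates only)
    // "^1.6.2"  allows >=1.6.2 <2.0.0   (minor and patch updates)
    // e.g. "typescript": "^1.6.2" can now pick up 1.7.x releases.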
diff --git a/tensorflow/tensorboard/tensorboard_handler.py b/tensorflow/tensorboard/tensorboard_handler.py
index 8c4a9e7689..ae190fef7b 100644
--- a/tensorflow/tensorboard/tensorboard_handler.py
+++ b/tensorflow/tensorboard/tensorboard_handler.py
@@ -44,6 +44,8 @@ from tensorflow.python.platform import resource_loader
from tensorflow.python.summary import event_accumulator
from tensorflow.tensorboard import float_wrapper
+
+DATA_PREFIX = '/data'
RUNS_ROUTE = '/runs'
SCALARS_ROUTE = '/' + event_accumulator.SCALARS
IMAGES_ROUTE = '/' + event_accumulator.IMAGES
@@ -51,6 +53,7 @@ HISTOGRAMS_ROUTE = '/' + event_accumulator.HISTOGRAMS
COMPRESSED_HISTOGRAMS_ROUTE = '/' + event_accumulator.COMPRESSED_HISTOGRAMS
INDIVIDUAL_IMAGE_ROUTE = '/individualImage'
GRAPH_ROUTE = '/' + event_accumulator.GRAPH
+TAB_ROUTES = ['', '/events', '/images', '/graphs', '/histograms']
_IMGHDR_TO_MIMETYPE = {
'bmp': 'image/bmp',
@@ -373,32 +376,34 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
if clean_path.endswith('/'):
clean_path = clean_path[:-1]
- handlers = {
- SCALARS_ROUTE: self._serve_scalars,
- GRAPH_ROUTE: self._serve_graph,
- HISTOGRAMS_ROUTE: self._serve_histograms,
- COMPRESSED_HISTOGRAMS_ROUTE: self._serve_compressed_histograms,
- IMAGES_ROUTE: self._serve_images,
- INDIVIDUAL_IMAGE_ROUTE: self._serve_image,
- RUNS_ROUTE: self._serve_runs,
- '': self._serve_index,
+ data_handlers = {
+ DATA_PREFIX + SCALARS_ROUTE: self._serve_scalars,
+ DATA_PREFIX + GRAPH_ROUTE: self._serve_graph,
+ DATA_PREFIX + HISTOGRAMS_ROUTE: self._serve_histograms,
+ DATA_PREFIX + COMPRESSED_HISTOGRAMS_ROUTE:
+ self._serve_compressed_histograms,
+ DATA_PREFIX + IMAGES_ROUTE: self._serve_images,
+ DATA_PREFIX + INDIVIDUAL_IMAGE_ROUTE: self._serve_image,
+ DATA_PREFIX + RUNS_ROUTE: self._serve_runs,
'/app.js': self._serve_js
}
- if clean_path in handlers:
- query_params = urlparse.parse_qs(parsed_url.query)
- # parse_qs returns a list of values for each key; we're only interested in
- # the first.
- for key in query_params:
- value_count = len(query_params[key])
- if value_count != 1:
- self.send_error(
- 400,
- 'query parameter %s should have exactly one value, had %d' %
- (key, value_count))
- return
-
- query_params[key] = query_params[key][0]
- handlers[clean_path](query_params)
+ query_params = urlparse.parse_qs(parsed_url.query)
+ # parse_qs returns a list of values for each key; we're only interested in
+ # the first.
+ for key in query_params:
+ value_count = len(query_params[key])
+ if value_count != 1:
+ self.send_error(
+ 400,
+ 'query parameter %s should have exactly one value, had %d' %
+ (key, value_count))
+ return
+ query_params[key] = query_params[key][0]
+
+ if clean_path in data_handlers:
+ data_handlers[clean_path](query_params)
+ elif clean_path in TAB_ROUTES:
+ self._serve_index(query_params)
else:
self._serve_static_file(clean_path)
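The rewrite above does three things: namespaces every JSON endpoint under /data, flattens query parameters once up front instead of only for known routes, and lets the bare tab paths fall through to the index so deep links like /graphs load the app. Rendered in JavaScript for consistency with the other examples on this page — the shipped code is the Python above:

    // JavaScript rendering of the new routing order, for illustration only.
    var DATA_PREFIX = '/data';
    var TAB_ROUTES = ['', '/events', '/images', '/graphs', '/histograms'];
    function route(cleanPath, queryParams, dataHandlers, serveIndex, serveStatic) {
      if (Object.prototype.hasOwnProperty.call(dataHandlers, cleanPath)) {
        dataHandlers[cleanPath](queryParams);   // e.g. DATA_PREFIX + '/runs'
      } else if (TAB_ROUTES.indexOf(cleanPath) >= 0) {
        serveIndex(queryParams);                // deep link such as '/graphs'
      } else {
        serveStatic(cleanPath);                 // everything else is an asset
      }
    }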
diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky
index 942240bd82..908764d2f7 100644
--- a/third_party/eigen3/Eigen/Cholesky
+++ b/third_party/eigen3/Eigen/Cholesky
@@ -1 +1 @@
-#include "external/eigen_archive/eigen-eigen-ce5a455b34c0/Eigen/Cholesky"
+#include "external/eigen_archive/eigen-eigen-a0661a2bb165/Eigen/Cholesky"
diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core
index e9896a5fba..c78b7c95ee 100644
--- a/third_party/eigen3/Eigen/Core
+++ b/third_party/eigen3/Eigen/Core
@@ -1 +1 @@
-#include "external/eigen_archive/eigen-eigen-ce5a455b34c0/Eigen/Core"
+#include "external/eigen_archive/eigen-eigen-a0661a2bb165/Eigen/Core"
diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues
index 5db8b147c6..235e34cd5f 100644
--- a/third_party/eigen3/Eigen/Eigenvalues
+++ b/third_party/eigen3/Eigen/Eigenvalues
@@ -1 +1 @@
-#include "external/eigen_archive/eigen-eigen-ce5a455b34c0/Eigen/Eigenvalues"
+#include "external/eigen_archive/eigen-eigen-a0661a2bb165/Eigen/Eigenvalues"
diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU
index 25e4ebf4f5..cdf52403a3 100644
--- a/third_party/eigen3/Eigen/LU
+++ b/third_party/eigen3/Eigen/LU
@@ -1 +1 @@
-#include "external/eigen_archive/eigen-eigen-ce5a455b34c0/Eigen/LU"
+#include "external/eigen_archive/eigen-eigen-a0661a2bb165/Eigen/LU"
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
index 8f4bbd7ee9..72e6fa6663 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
@@ -1 +1 @@
-#include "external/eigen_archive/eigen-eigen-ce5a455b34c0/unsupported/Eigen/CXX11/Tensor"
+#include "external/eigen_archive/eigen-eigen-a0661a2bb165/unsupported/Eigen/CXX11/Tensor"