aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2017-05-05 18:31:20 -0700
committerGravatar GitHub <noreply@github.com>2017-05-05 18:31:20 -0700
commit6bfbcf31dce9a59acfcad51d905894b082989012 (patch)
tree632f146b3ff887d6c7b0f2ac3aa1d4dc5203d175
parentcae8ed1ca54a9fd4f9cc64d08cadebce31fd4607 (diff)
parent97d550a72de8f0dbd43e289a50231930e3526c91 (diff)
Merge pull request #9706 from vrv/branch_155249446
Branch 155249446
-rw-r--r--.gitignore3
-rw-r--r--tensorflow/BUILD43
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization_test.cc13
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc3
-rw-r--r--tensorflow/contrib/factorization/python/ops/gmm_ops.py2
-rw-r--r--tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py334
-rw-r--r--tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py354
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc24
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.h95
-rw-r--r--tensorflow/contrib/image/ops/image_ops.cc2
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/image_ops_test.py65
-rw-r--r--tensorflow/contrib/image/python/ops/image_ops.py20
-rw-r--r--tensorflow/contrib/kernel_methods/README.md55
-rw-r--r--tensorflow/contrib/kernel_methods/g3doc/acc-vs-trn_time.pngbin0 -> 17992 bytes
-rw-r--r--tensorflow/contrib/kernel_methods/g3doc/acc_vs_outdim.pngbin0 -> 19835 bytes
-rw-r--r--tensorflow/contrib/kernel_methods/g3doc/kernel_mapping.jpgbin0 -> 7330 bytes
-rw-r--r--tensorflow/contrib/kernel_methods/g3doc/tutorial.md273
-rw-r--r--tensorflow/contrib/layers/python/layers/encoders.py2
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column.py314
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_ops.py140
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_ops_test.py51
-rw-r--r--tensorflow/contrib/layers/python/layers/initializers.py3
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py5
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py1
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/kmeans.py1
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py1
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py1
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment_test.py83
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py34
-rw-r--r--tensorflow/contrib/signal/BUILD4
-rw-r--r--tensorflow/contrib/signal/__init__.py2
-rw-r--r--tensorflow/contrib/signal/python/__init__.py1
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py20
-rw-r--r--tensorflow/contrib/signal/python/ops/__init__.py1
-rw-r--r--tensorflow/contrib/signal/python/ops/shape_ops.py46
-rw-r--r--tensorflow/contrib/testing/python/framework/fake_summary_writer.py2
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/BUILD23
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py445
-rw-r--r--tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py164
-rw-r--r--tensorflow/contrib/verbs/rdma.cc5
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc4
-rw-r--r--tensorflow/core/grappler/costs/BUILD37
-rw-r--r--tensorflow/core/grappler/costs/analytical_cost_estimator.cc128
-rw-r--r--tensorflow/core/grappler/costs/analytical_cost_estimator.h63
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc5
-rw-r--r--tensorflow/core/grappler/costs/utils.cc4
-rw-r--r--tensorflow/core/grappler/costs/virtual_placer.cc57
-rw-r--r--tensorflow/core/grappler/costs/virtual_placer.h45
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op.cc565
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op.h8
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op_test.cc6
-rw-r--r--tensorflow/core/kernels/cwise_op_atan2.cc2
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_atan2.cu.cc2
-rw-r--r--tensorflow/core/kernels/linalg_ops_common.cc6
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt25
-rw-r--r--tensorflow/core/ops/ops.pbtxt27
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.distributions.bijector.md33
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md33
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.distributions.md2
-rw-r--r--tensorflow/docs_src/api_guides/python/index.md2
-rw-r--r--tensorflow/docs_src/performance/benchmarks.md143
-rw-r--r--tensorflow/examples/image_retraining/retrain.py5
-rw-r--r--tensorflow/go/op/wrappers.go157
-rw-r--r--tensorflow/opensource_only/eigen.threadpool1
-rw-r--r--tensorflow/python/estimator/estimator_test.py4
-rw-r--r--tensorflow/python/feature_column/feature_column.py102
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py194
-rw-r--r--tensorflow/python/framework/function.py138
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py10
-rw-r--r--tensorflow/python/kernel_tests/tensor_priority_test.py37
-rw-r--r--tensorflow/python/layers/convolutional.py73
-rw-r--r--tensorflow/python/layers/convolutional_test.py17
-rw-r--r--tensorflow/python/ops/ctc_ops.py3
-rw-r--r--tensorflow/python/ops/init_ops.py1
-rw-r--r--tensorflow/python/ops/nn_ops.py12
-rw-r--r--tensorflow/python/saved_model/loader_impl.py87
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py21
-rw-r--r--tensorflow/python/tools/import_pb_to_tensorboard.py22
-rw-r--r--tensorflow/python/tools/saved_model_cli.py9
-rw-r--r--tensorflow/python/tools/saved_model_cli_test.py10
-rw-r--r--tensorflow/python/training/checkpoint_utils.py10
-rw-r--r--tensorflow/python/training/saver.py4
-rw-r--r--tensorflow/tensorboard/backend/application_test.py8
-rw-r--r--tensorflow/tensorboard/components/tf_audio_dashboard/BUILD63
-rw-r--r--tensorflow/tensorboard/components/tf_audio_dashboard/demo/BUILD26
-rw-r--r--tensorflow/tensorboard/components/tf_audio_dashboard/demo/data/BUILD17
-rw-r--r--tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-loader.html2
-rw-r--r--tensorflow/tensorboard/components/tf_backend/BUILD81
-rw-r--r--tensorflow/tensorboard/components/tf_backend_d3v4/BUILD45
-rw-r--r--tensorflow/tensorboard/components/tf_color_scale/BUILD65
-rw-r--r--tensorflow/tensorboard/components/tf_color_scale/demo/BUILD26
-rw-r--r--tensorflow/tensorboard/components/tf_color_scale_d3v4/BUILD72
-rw-r--r--tensorflow/tensorboard/components/tf_dashboard_common/BUILD102
-rw-r--r--tensorflow/tensorboard/components/tf_dashboard_common/demo/BUILD31
-rw-r--r--tensorflow/tensorboard/components/tf_dashboard_common/tf-chart-scaffold.html2
-rw-r--r--tensorflow/tensorboard/components/tf_dashboard_common_d3v4/BUILD114
-rw-r--r--tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-chart-scaffold.html2
-rw-r--r--tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD63
-rw-r--r--tensorflow/tensorboard/components/tf_distribution_dashboard/demo/BUILD26
-rw-r--r--tensorflow/tensorboard/components/tf_distribution_dashboard/demo/data/BUILD17
-rw-r--r--tensorflow/tensorboard/components/tf_globals/BUILD49
-rw-r--r--tensorflow/tensorboard/components/tf_globals_d3v4/BUILD16
-rw-r--r--tensorflow/tensorboard/components/tf_graph_common/BUILD65
-rw-r--r--tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html2
-rw-r--r--tensorflow/tensorboard/components/tf_graph_info/tf-graph-info.html2
-rw-r--r--tensorflow/tensorboard/components/tf_histogram_dashboard/BUILD62
-rw-r--r--tensorflow/tensorboard/components/tf_histogram_dashboard/demo/BUILD26
-rw-r--r--tensorflow/tensorboard/components/tf_histogram_dashboard/demo/data/BUILD17
-rw-r--r--tensorflow/tensorboard/components/tf_image_dashboard/BUILD59
-rw-r--r--tensorflow/tensorboard/components/tf_image_dashboard/demo/BUILD25
-rw-r--r--tensorflow/tensorboard/components/tf_image_dashboard/demo/data/BUILD17
-rw-r--r--tensorflow/tensorboard/components/tf_image_dashboard/tf-image-loader.html2
-rw-r--r--tensorflow/tensorboard/components/tf_imports/BUILD120
-rw-r--r--tensorflow/tensorboard/components/tf_imports_d3v4/BUILD78
-rw-r--r--tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD77
-rw-r--r--tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD26
-rw-r--r--tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/BUILD17
-rw-r--r--tensorflow/tensorboard/components/tf_storage/BUILD70
-rw-r--r--tensorflow/tensorboard/components/tf_storage_d3v4/BUILD35
-rw-r--r--tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html2
-rw-r--r--tensorflow/tensorboard/components/tf_text_dashboard/BUILD60
-rw-r--r--tensorflow/tensorboard/components/tf_text_dashboard/demo/BUILD25
-rw-r--r--tensorflow/tensorboard/components/tf_text_dashboard/demo/data/BUILD17
-rw-r--r--tensorflow/tensorboard/components/vz_data_summary/BUILD91
-rw-r--r--tensorflow/tensorboard/components/vz_data_summary/BUILD.OPENSOURCE34
-rw-r--r--tensorflow/tensorboard/components/vz_distribution_chart/BUILD69
-rw-r--r--tensorflow/tensorboard/components/vz_distribution_chart/demo/BUILD24
-rw-r--r--tensorflow/tensorboard/components/vz_distribution_chart_d3v4/BUILD66
-rw-r--r--tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD50
-rw-r--r--tensorflow/tensorboard/components/vz_histogram_timeseries/demo/BUILD25
-rw-r--r--tensorflow/tensorboard/components/vz_histogram_timeseries_d3v4/BUILD63
-rw-r--r--tensorflow/tensorboard/components/vz_line_chart/BUILD73
-rw-r--r--tensorflow/tensorboard/components/vz_line_chart/demo/BUILD24
-rw-r--r--tensorflow/tensorboard/components/vz_line_chart_d3v4/BUILD66
-rw-r--r--tensorflow/tensorboard/components/vz_projector/BUILD180
-rw-r--r--tensorflow/tensorboard/components/vz_projector/BUILD.OPENSOURCE19
-rw-r--r--tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html2
-rw-r--r--tensorflow/tensorboard/components/vz_projector_d3v4/BUILD179
-rw-r--r--tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-dashboard.html2
-rw-r--r--tensorflow/tensorboard/components/vz_sorting/BUILD50
-rw-r--r--tensorflow/tensorboard/components/vz_sorting/test/BUILD40
-rw-r--r--tensorflow/tensorboard/components/vz_sorting_d3v4/BUILD27
-rw-r--r--tensorflow/tensorboard/defs.bzl20
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt2
-rw-r--r--tensorflow/tools/quantization/quantize_graph.py3
-rw-r--r--tensorflow/workspace.bzl373
-rw-r--r--third_party/llvm/llvm.BUILD2
-rw-r--r--third_party/pprof.BUILD18
151 files changed, 6598 insertions, 1292 deletions
diff --git a/.gitignore b/.gitignore
index e3d6df93c1..bdcb067fc2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,11 +7,8 @@ node_modules
/bazel_pip
/third_party/eigen3/mkl_include
/third_party/mkl/*
-/third_party/py/numpy/numpy_include
/tools/python_bin_path.sh
/tools/git/gen
-/util/python/python_include
-/util/python/python_lib
/pip_test
/_python_build
*.pyc
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index e03ba669ee..64ab215758 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -263,6 +263,7 @@ filegroup(
"//tensorflow/contrib/seq2seq:all_files",
"//tensorflow/contrib/session_bundle:all_files",
"//tensorflow/contrib/session_bundle/example:all_files",
+ "//tensorflow/contrib/signal:all_files",
"//tensorflow/contrib/slim:all_files",
"//tensorflow/contrib/slim/python/slim/data:all_files",
"//tensorflow/contrib/slim/python/slim/nets:all_files",
@@ -326,6 +327,48 @@ filegroup(
"//tensorflow/tensorboard/backend:all_files",
"//tensorflow/tensorboard/backend/event_processing:all_files",
"//tensorflow/tensorboard/components:all_files",
+ "//tensorflow/tensorboard/components/tf_audio_dashboard:all_files",
+ "//tensorflow/tensorboard/components/tf_audio_dashboard/demo:all_files",
+ "//tensorflow/tensorboard/components/tf_backend:all_files",
+ "//tensorflow/tensorboard/components/tf_backend_d3v4:all_files",
+ "//tensorflow/tensorboard/components/tf_color_scale:all_files",
+ "//tensorflow/tensorboard/components/tf_color_scale/demo:all_files",
+ "//tensorflow/tensorboard/components/tf_color_scale_d3v4:all_files",
+ "//tensorflow/tensorboard/components/tf_dashboard_common:all_files",
+ "//tensorflow/tensorboard/components/tf_dashboard_common/demo:all_files",
+ "//tensorflow/tensorboard/components/tf_dashboard_common_d3v4:all_files",
+ "//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files",
+ "//tensorflow/tensorboard/components/tf_distribution_dashboard/demo:all_files",
+ "//tensorflow/tensorboard/components/tf_globals:all_files",
+ "//tensorflow/tensorboard/components/tf_globals_d3v4:all_files",
+ "//tensorflow/tensorboard/components/tf_graph_common:all_files",
+ "//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files",
+ "//tensorflow/tensorboard/components/tf_histogram_dashboard/demo:all_files",
+ "//tensorflow/tensorboard/components/tf_image_dashboard:all_files",
+ "//tensorflow/tensorboard/components/tf_image_dashboard/demo:all_files",
+ "//tensorflow/tensorboard/components/tf_imports:all_files",
+ "//tensorflow/tensorboard/components/tf_imports_d3v4:all_files",
+ "//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files",
+ "//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files",
+ "//tensorflow/tensorboard/components/tf_storage:all_files",
+ "//tensorflow/tensorboard/components/tf_storage_d3v4:all_files",
+ "//tensorflow/tensorboard/components/tf_text_dashboard:all_files",
+ "//tensorflow/tensorboard/components/tf_text_dashboard/demo:all_files",
+ "//tensorflow/tensorboard/components/vz_data_summary:all_files",
+ "//tensorflow/tensorboard/components/vz_distribution_chart:all_files",
+ "//tensorflow/tensorboard/components/vz_distribution_chart/demo:all_files",
+ "//tensorflow/tensorboard/components/vz_distribution_chart_d3v4:all_files",
+ "//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files",
+ "//tensorflow/tensorboard/components/vz_histogram_timeseries/demo:all_files",
+ "//tensorflow/tensorboard/components/vz_histogram_timeseries_d3v4:all_files",
+ "//tensorflow/tensorboard/components/vz_line_chart:all_files",
+ "//tensorflow/tensorboard/components/vz_line_chart/demo:all_files",
+ "//tensorflow/tensorboard/components/vz_line_chart_d3v4:all_files",
+ "//tensorflow/tensorboard/components/vz_projector:all_files",
+ "//tensorflow/tensorboard/components/vz_projector_d3v4:all_files",
+ "//tensorflow/tensorboard/components/vz_sorting:all_files",
+ "//tensorflow/tensorboard/components/vz_sorting/test:all_files",
+ "//tensorflow/tensorboard/components/vz_sorting_d3v4:all_files",
"//tensorflow/tensorboard/lib:all_files",
"//tensorflow/tensorboard/plugins:all_files",
"//tensorflow/tensorboard/plugins/projector:all_files",
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 82de1c835b..2a1d728bc8 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -28,11 +28,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-namespace op = xla::testing::opcode_matchers;
-
namespace xla {
namespace {
+namespace op = xla::testing::opcode_matchers;
+
+using ::testing::_;
+
class HloRematerializationTest : public HloTestBase {
protected:
// Creates and returns a computation which can benefit from
@@ -145,11 +147,9 @@ TEST_F(HloRematerializationTest, SingleComputation) {
// Find and save the original broadcast instruction which should be
// rematerialized.
const HloInstruction* slice = computation->root_instruction();
- ASSERT_EQ(HloOpcode::kSlice, slice->opcode());
+ ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _)));
const HloInstruction* concat = slice->operand(0);
- ASSERT_EQ(HloOpcode::kConcatenate, concat->opcode());
const HloInstruction* bcast = concat->operand(0);
- ASSERT_EQ(HloOpcode::kBroadcast, bcast->opcode());
SequentialHloOrdering::HloModuleSequence sequence;
// Computation requires 16KB without rematerialization, but uses only 12KB
@@ -165,8 +165,7 @@ TEST_F(HloRematerializationTest, SingleComputation) {
// The broadcast should have been rematerialized.
const HloInstruction* remat_bcast = concat->operand(0);
- EXPECT_EQ(HloOpcode::kBroadcast, remat_bcast->opcode());
- EXPECT_NE(bcast, remat_bcast);
+ EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(bcast)));
// The rematerialized broadcast should be immediate before the concat in the
// sequence.
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
index da07dea123..2b14eca5d1 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -68,9 +68,8 @@ void CleanNodeName(string* name) {
}
Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) {
- LOG(INFO) << "Adding computation " << computation.name();
+ VLOG(2) << "Adding computation " << computation.name();
for (auto embedded : computation.MakeEmbeddedComputationsList()) {
- LOG(INFO) << "Adding embedded computation " << embedded->name();
for (auto& instruction : embedded->instructions()) {
TF_RETURN_IF_ERROR(AddInstruction(instruction.get()));
}
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py
index fc5270078c..b092eab316 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py
@@ -85,7 +85,7 @@ def _init_clusters_random(data, num_clusters, random_seed):
maxval=math_ops.cast(num_data, dtypes.int64),
seed=random_seed,
dtype=dtypes.int64)
- indices = indices % math_ops.cast(num_data, dtypes.int64)
+ indices %= math_ops.cast(num_data, dtypes.int64)
clusters_init = embedding_lookup(data, indices, partition_strategy='div')
return clusters_init
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
index 6455e01894..280271a42d 100644
--- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -35,8 +35,8 @@ class GridRNNCellTest(test.TestCase):
def testGrid2BasicLSTMCell(self):
with self.test_session(use_gpu=False) as sess:
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.2)) as root_scope:
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.2)) as root_scope:
x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
@@ -51,21 +51,22 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[1].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res_g, res_s = sess.run(
- [g, s], {x: np.array([[1., 1., 1.]]),
- m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
- (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))})
+ res_g, res_s = sess.run([g, s], {
+ x:
+ np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+ (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
+ })
self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2))
self.assertEqual(res_s[1].c.shape, (1, 2))
self.assertEqual(res_s[1].h.shape, (1, 2))
- self.assertAllClose(res_g, ([[0.36617181, 0.36617181]], ))
- self.assertAllClose(res_s, (([[0.71053141, 0.71053141]],
- [[0.36617181, 0.36617181]]),
- ([[0.72320831, 0.80555487]],
- [[0.39102408, 0.42150158]])))
+ self.assertAllClose(res_g, ([[0.36617181, 0.36617181]],))
+ self.assertAllClose(
+ res_s, (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]),
+ ([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]])))
# emulate a loop through the input sequence,
# where we call cell() multiple times
@@ -78,22 +79,22 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s2[1].h.get_shape(), (1, 2))
res_g2, res_s2 = sess.run([g2, s2],
- {x: np.array([[2., 2., 2.]]), m: res_s})
+ {x: np.array([[2., 2., 2.]]),
+ m: res_s})
self.assertEqual(res_g2[0].shape, (1, 2))
self.assertEqual(res_s2[0].c.shape, (1, 2))
self.assertEqual(res_s2[0].h.shape, (1, 2))
self.assertEqual(res_s2[1].c.shape, (1, 2))
self.assertEqual(res_s2[1].h.shape, (1, 2))
self.assertAllClose(res_g2[0], [[0.58847463, 0.58847463]])
- self.assertAllClose(res_s2, (([[1.40469193, 1.40469193]],
- [[0.58847463, 0.58847463]]),
- ([[0.97726452, 1.04626071]],
- [[0.4927212, 0.51137757]])))
+ self.assertAllClose(
+ res_s2, (([[1.40469193, 1.40469193]], [[0.58847463, 0.58847463]]),
+ ([[0.97726452, 1.04626071]], [[0.4927212, 0.51137757]])))
def testGrid2BasicLSTMCellTied(self):
with self.test_session(use_gpu=False) as sess:
with variable_scope.variable_scope(
- 'root', initializer=init_ops.constant_initializer(0.2)):
+ 'root', initializer=init_ops.constant_initializer(0.2)):
x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
@@ -108,10 +109,12 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[1].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res_g, res_s = sess.run(
- [g, s], {x: np.array([[1., 1., 1.]]),
- m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
- (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))})
+ res_g, res_s = sess.run([g, s], {
+ x:
+ np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+ (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
+ })
self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2))
@@ -119,29 +122,27 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(res_s[1].h.shape, (1, 2))
self.assertAllClose(res_g[0], [[0.36617181, 0.36617181]])
- self.assertAllClose(res_s, (([[0.71053141, 0.71053141]],
- [[0.36617181, 0.36617181]]),
- ([[0.72320831, 0.80555487]],
- [[0.39102408, 0.42150158]])))
+ self.assertAllClose(
+ res_s, (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]),
+ ([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]])))
res_g, res_s = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res_s})
self.assertEqual(res_g[0].shape, (1, 2))
self.assertAllClose(res_g[0], [[0.36703536, 0.36703536]])
- self.assertAllClose(res_s, (([[0.71200621, 0.71200621]],
- [[0.36703536, 0.36703536]]),
- ([[0.80941606, 0.87550586]],
- [[0.40108523, 0.42199609]])))
+ self.assertAllClose(
+ res_s, (([[0.71200621, 0.71200621]], [[0.36703536, 0.36703536]]),
+ ([[0.80941606, 0.87550586]], [[0.40108523, 0.42199609]])))
def testGrid2BasicLSTMCellWithRelu(self):
with self.test_session(use_gpu=False) as sess:
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.2)):
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.2)):
x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
cell = grid_rnn_cell.Grid2BasicLSTMCell(
2, tied=False, non_recurrent_fn=nn_ops.relu)
- self.assertEqual(cell.state_size, ((2, 2), ))
+ self.assertEqual(cell.state_size, ((2, 2),))
g, s = cell(x, m)
self.assertEqual(g[0].get_shape(), (1, 2))
@@ -149,21 +150,22 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res_g, res_s = sess.run(
- [g, s], {x: np.array([[1., 1., 1.]]),
- m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), )})
+ res_g, res_s = sess.run([g, s], {
+ x: np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
+ })
self.assertEqual(res_g[0].shape, (1, 2))
self.assertAllClose(res_g[0], [[0.31667367, 0.31667367]])
self.assertAllClose(res_s, (([[0.29530135, 0.37520045]],
- [[0.17044567, 0.21292259]]), ))
+ [[0.17044567, 0.21292259]]),))
"""LSTMCell
"""
def testGrid2LSTMCell(self):
with self.test_session(use_gpu=False) as sess:
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.5)):
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
@@ -178,10 +180,12 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[1].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res_g, res_s = sess.run(
- [g, s], {x: np.array([[1., 1., 1.]]),
- m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
- (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))})
+ res_g, res_s = sess.run([g, s], {
+ x:
+ np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+ (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
+ })
self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2))
@@ -189,15 +193,14 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(res_s[1].h.shape, (1, 2))
self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]])
- self.assertAllClose(res_s, (([[2.41515064, 2.41515064]],
- [[0.95686918, 0.95686918]]),
- ([[1.38917875, 1.49043763]],
- [[0.83884692, 0.86036491]])))
+ self.assertAllClose(
+ res_s, (([[2.41515064, 2.41515064]], [[0.95686918, 0.95686918]]),
+ ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
def testGrid2LSTMCellTied(self):
with self.test_session(use_gpu=False) as sess:
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.5)):
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
@@ -212,10 +215,12 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[1].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res_g, res_s = sess.run(
- [g, s], {x: np.array([[1., 1., 1.]]),
- m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
- (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))})
+ res_g, res_s = sess.run([g, s], {
+ x:
+ np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+ (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
+ })
self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2))
@@ -223,15 +228,14 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(res_s[1].h.shape, (1, 2))
self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]])
- self.assertAllClose(res_s, (([[2.41515064, 2.41515064]],
- [[0.95686918, 0.95686918]]),
- ([[1.38917875, 1.49043763]],
- [[0.83884692, 0.86036491]])))
+ self.assertAllClose(
+ res_s, (([[2.41515064, 2.41515064]], [[0.95686918, 0.95686918]]),
+ ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
def testGrid2LSTMCellWithRelu(self):
with self.test_session() as sess:
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.5)):
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
cell = grid_rnn_cell.Grid2LSTMCell(
@@ -244,21 +248,22 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res_g, res_s = sess.run(
- [g, s], {x: np.array([[1., 1., 1.]]),
- m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), )})
+ res_g, res_s = sess.run([g, s], {
+ x: np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
+ })
self.assertEqual(res_g[0].shape, (1, 2))
self.assertAllClose(res_g[0], [[2.1831727, 2.1831727]])
self.assertAllClose(res_s, (([[0.92270052, 1.02325559]],
- [[0.66159075, 0.70475441]]), ))
+ [[0.66159075, 0.70475441]]),))
"""RNNCell
"""
def testGrid2BasicRNNCell(self):
with self.test_session() as sess:
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.5)):
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2]))
cell = grid_rnn_cell.Grid2BasicRNNCell(2)
@@ -270,26 +275,26 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[1].get_shape(), (2, 2))
sess.run([variables.global_variables_initializer()])
- res_g, res_s = sess.run(
- [g, s], {x: np.array([[1., 1.], [2., 2.]]),
- m: (np.array([[0.1, 0.1], [0.2, 0.2]]),
- np.array([[0.1, 0.1], [0.2, 0.2]]))})
+ res_g, res_s = sess.run([g, s], {
+ x:
+ np.array([[1., 1.], [2., 2.]]),
+ m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[0.1, 0.1],
+ [0.2, 0.2]]))
+ })
self.assertEqual(res_g[0].shape, (2, 2))
self.assertEqual(res_s[0].shape, (2, 2))
self.assertEqual(res_s[1].shape, (2, 2))
self.assertAllClose(res_g, ([[0.94685763, 0.94685763],
- [0.99480951, 0.99480951]], ))
- self.assertAllClose(res_s,
- ([[0.94685763, 0.94685763],
- [0.99480951, 0.99480951]],
- [[0.80049908, 0.80049908],
- [0.97574311, 0.97574311]]))
+ [0.99480951, 0.99480951]],))
+ self.assertAllClose(
+ res_s, ([[0.94685763, 0.94685763], [0.99480951, 0.99480951]],
+ [[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellTied(self):
with self.test_session() as sess:
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.5)):
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2]))
cell = grid_rnn_cell.Grid2BasicRNNCell(2, tied=True)
@@ -301,55 +306,55 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[1].get_shape(), (2, 2))
sess.run([variables.global_variables_initializer()])
- res_g, res_s = sess.run(
- [g, s], {x: np.array([[1., 1.], [2., 2.]]),
- m: (np.array([[0.1, 0.1], [0.2, 0.2]]),
- np.array([[0.1, 0.1], [0.2, 0.2]]))})
+ res_g, res_s = sess.run([g, s], {
+ x:
+ np.array([[1., 1.], [2., 2.]]),
+ m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[0.1, 0.1],
+ [0.2, 0.2]]))
+ })
self.assertEqual(res_g[0].shape, (2, 2))
self.assertEqual(res_s[0].shape, (2, 2))
self.assertEqual(res_s[1].shape, (2, 2))
self.assertAllClose(res_g, ([[0.94685763, 0.94685763],
- [0.99480951, 0.99480951]], ))
- self.assertAllClose(res_s,
- ([[0.94685763, 0.94685763],
- [0.99480951, 0.99480951]],
- [[0.80049908, 0.80049908],
- [0.97574311, 0.97574311]]))
+ [0.99480951, 0.99480951]],))
+ self.assertAllClose(
+ res_s, ([[0.94685763, 0.94685763], [0.99480951, 0.99480951]],
+ [[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellWithRelu(self):
with self.test_session() as sess:
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.5)):
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
- m = (array_ops.zeros([1, 2]), )
- cell = grid_rnn_cell.Grid2BasicRNNCell(
- 2, non_recurrent_fn=nn_ops.relu)
- self.assertEqual(cell.state_size, (2, ))
+ m = (array_ops.zeros([1, 2]),)
+ cell = grid_rnn_cell.Grid2BasicRNNCell(2, non_recurrent_fn=nn_ops.relu)
+ self.assertEqual(cell.state_size, (2,))
g, s = cell(x, m)
self.assertEqual(g[0].get_shape(), (1, 2))
self.assertEqual(s[0].get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res_g, res_s = sess.run([g, s], {x: np.array([[1., 1.]]),
- m: np.array([[0.1, 0.1]])})
+ res_g, res_s = sess.run(
+ [g, s], {x: np.array([[1., 1.]]),
+ m: np.array([[0.1, 0.1]])})
self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].shape, (1, 2))
- self.assertAllClose(res_g, ([[1.80049896, 1.80049896]], ))
- self.assertAllClose(res_s, ([[0.80049896, 0.80049896]], ))
+ self.assertAllClose(res_g, ([[1.80049896, 1.80049896]],))
+ self.assertAllClose(res_s, ([[0.80049896, 0.80049896]],))
"""1-LSTM
"""
def testGrid1LSTMCell(self):
with self.test_session() as sess:
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.5)) as root_scope:
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)) as root_scope:
x = array_ops.zeros([1, 3])
- m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), )
+ m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
cell = grid_rnn_cell.Grid1LSTMCell(2, use_peepholes=True)
- self.assertEqual(cell.state_size, ((2, 2), ))
+ self.assertEqual(cell.state_size, ((2, 2),))
g, s = cell(x, m)
self.assertEqual(g[0].get_shape(), (1, 2))
@@ -357,17 +362,17 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res_g, res_s = sess.run(
- [g, s], {x: np.array([[1., 1., 1.]]),
- m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), )})
+ res_g, res_s = sess.run([g, s], {
+ x: np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
+ })
self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2))
- self.assertAllClose(res_g, ([[0.91287315, 0.91287315]], ))
- self.assertAllClose(res_s,
- (([[2.26285243, 2.26285243]],
- [[0.91287315, 0.91287315]]), ))
+ self.assertAllClose(res_g, ([[0.91287315, 0.91287315]],))
+ self.assertAllClose(res_s, (([[2.26285243, 2.26285243]],
+ [[0.91287315, 0.91287315]]),))
root_scope.reuse_variables()
@@ -383,10 +388,9 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(res_s2[0].c.shape, (1, 2))
self.assertEqual(res_s2[0].h.shape, (1, 2))
- self.assertAllClose(res_g2, ([[0.9032144, 0.9032144]], ))
- self.assertAllClose(res_s2,
- (([[2.79966092, 2.79966092]],
- [[0.9032144, 0.9032144]]), ))
+ self.assertAllClose(res_g2, ([[0.9032144, 0.9032144]],))
+ self.assertAllClose(res_s2, (([[2.79966092, 2.79966092]],
+ [[0.9032144, 0.9032144]]),))
g3, s3 = cell(x2, m)
self.assertEqual(g3[0].get_shape(), (1, 2))
@@ -398,18 +402,17 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(res_g3[0].shape, (1, 2))
self.assertEqual(res_s3[0].c.shape, (1, 2))
self.assertEqual(res_s3[0].h.shape, (1, 2))
- self.assertAllClose(res_g3, ([[0.92727238, 0.92727238]], ))
- self.assertAllClose(res_s3,
- (([[3.3529923, 3.3529923]],
- [[0.92727238, 0.92727238]]), ))
+ self.assertAllClose(res_g3, ([[0.92727238, 0.92727238]],))
+ self.assertAllClose(res_s3, (([[3.3529923, 3.3529923]],
+ [[0.92727238, 0.92727238]]),))
"""3-LSTM
"""
def testGrid3LSTMCell(self):
with self.test_session() as sess:
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.5)):
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
(array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
@@ -427,11 +430,13 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[2].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res_g, res_s = sess.run(
- [g, s], {x: np.array([[1., 1., 1.]]),
- m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
- (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])),
- (np.array([[-0.1, -0.2]]), np.array([[-0.3, -0.4]])))})
+ res_g, res_s = sess.run([g, s], {
+ x:
+ np.array([[1., 1., 1.]]),
+ m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+ (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])), (np.array(
+ [[-0.1, -0.2]]), np.array([[-0.3, -0.4]])))
+ })
self.assertEqual(res_g[0].shape, (1, 2))
self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2))
@@ -440,21 +445,19 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(res_s[2].c.shape, (1, 2))
self.assertEqual(res_s[2].h.shape, (1, 2))
- self.assertAllClose(res_g, ([[0.96892911, 0.96892911]], ))
- self.assertAllClose(res_s, (([[2.45227885, 2.45227885]],
- [[0.96892911, 0.96892911]]),
- ([[1.33592629, 1.4373529]],
- [[0.80867189, 0.83247656]]),
- ([[0.7317788, 0.63205892]],
- [[0.56548983, 0.50446129]])))
+ self.assertAllClose(res_g, ([[0.96892911, 0.96892911]],))
+ self.assertAllClose(
+ res_s, (([[2.45227885, 2.45227885]], [[0.96892911, 0.96892911]]),
+ ([[1.33592629, 1.4373529]], [[0.80867189, 0.83247656]]),
+ ([[0.7317788, 0.63205892]], [[0.56548983, 0.50446129]])))
"""Edge cases
"""
def testGridRNNEdgeCasesLikeRelu(self):
with self.test_session() as sess:
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.5)):
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([3, 2])
m = ()
@@ -471,18 +474,18 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s, ())
sess.run([variables.global_variables_initializer()])
- res_g, res_s = sess.run(
- [g, s], {x: np.array([[1., -1.], [-2, 1], [2, -1]])})
+ res_g, res_s = sess.run([g, s],
+ {x: np.array([[1., -1.], [-2, 1], [2, -1]])})
self.assertEqual(res_g[0].shape, (3, 2))
self.assertEqual(res_s, ())
- self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]], ))
+ self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],))
def testGridRNNEdgeCasesNoOutput(self):
with self.test_session() as sess:
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.5)):
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
- m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), )
+ m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
# This cell produces no output
cell = grid_rnn_cell.GridRNNCell(
@@ -498,9 +501,10 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s[0].h.get_shape(), (1, 2))
sess.run([variables.global_variables_initializer()])
- res_g, res_s = sess.run(
- [g, s], {x: np.array([[1., 1.]]),
- m: ((np.array([[0.1, 0.1]]), np.array([[0.1, 0.1]])), )})
+ res_g, res_s = sess.run([g, s], {
+ x: np.array([[1., 1.]]),
+ m: ((np.array([[0.1, 0.1]]), np.array([[0.1, 0.1]])),)
+ })
self.assertEqual(res_g, ())
self.assertEqual(res_s[0].c.shape, (1, 2))
self.assertEqual(res_s[0].h.shape, (1, 2))
@@ -561,8 +565,9 @@ class GridRNNCellTest(test.TestCase):
cell = grid_rnn_cell.Grid2LSTMCell(
num_units=num_units, non_recurrent_fn=nn_ops.relu)
- inputs = max_length * [array_ops.placeholder(
- dtypes.float32, shape=(batch_size, input_size))]
+ inputs = max_length * [
+ array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
+ ]
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
@@ -600,8 +605,9 @@ class GridRNNCellTest(test.TestCase):
cell = grid_rnn_cell.Grid3LSTMCell(
num_units=num_units, non_recurrent_fn=nn_ops.relu)
- inputs = max_length * [array_ops.placeholder(
- dtypes.float32, shape=(batch_size, input_size))]
+ inputs = max_length * [
+ array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
+ ]
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
@@ -671,19 +677,17 @@ class GridRNNCellTest(test.TestCase):
self.assertTrue(np.all(np.isfinite(v)))
def testGrid2LSTMCellWithRNNAndDynamicBatchSize(self):
- """Test for #4296
- """
+ """Test for #4296."""
input_size = 5
max_length = 6 # unrolled up to this length
num_units = 2
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.5)):
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
cell = grid_rnn_cell.Grid2LSTMCell(num_units=num_units)
inputs = max_length * [
- array_ops.placeholder(
- dtypes.float32, shape=(None, input_size))
+ array_ops.placeholder(dtypes.float32, shape=(None, input_size))
]
outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
@@ -700,8 +704,7 @@ class GridRNNCellTest(test.TestCase):
sess.run(variables.global_variables_initializer())
input_value = np.ones((3, input_size))
- values = sess.run(outputs + [state],
- feed_dict={inputs[0]: input_value})
+ values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
for tp in values[:-1]:
for v in tp:
self.assertTrue(np.all(np.isfinite(v)))
@@ -710,18 +713,15 @@ class GridRNNCellTest(test.TestCase):
for v in st:
self.assertTrue(np.all(np.isfinite(v)))
-
def testGrid2LSTMCellLegacy(self):
- """Test for legacy case (when state_is_tuple=False)
- """
+ """Test for legacy case (when state_is_tuple=False)."""
with self.test_session() as sess:
- with variable_scope.variable_scope('root',
- initializer=init_ops.constant_initializer(0.5)):
+ with variable_scope.variable_scope(
+ 'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = array_ops.zeros([1, 8])
- cell = grid_rnn_cell.Grid2LSTMCell(2, use_peepholes=True,
- state_is_tuple=False,
- output_is_tuple=False)
+ cell = grid_rnn_cell.Grid2LSTMCell(
+ 2, use_peepholes=True, state_is_tuple=False, output_is_tuple=False)
self.assertEqual(cell.state_size, 8)
g, s = cell(x, m)
@@ -729,15 +729,17 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(s.get_shape(), (1, 8))
sess.run([variables.global_variables_initializer()])
- res = sess.run(
- [g, s], {x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+ res = sess.run([g, s], {
+ x: np.array([[1., 1., 1.]]),
+ m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
+ })
self.assertEqual(res[0].shape, (1, 2))
self.assertEqual(res[1].shape, (1, 8))
self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
- self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918,
- 0.95686918, 1.38917875, 1.49043763,
- 0.83884692, 0.86036491]])
+ self.assertAllClose(res[1], [[
+ 2.41515064, 2.41515064, 0.95686918, 0.95686918, 1.38917875,
+ 1.49043763, 0.83884692, 0.86036491
+ ]])
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
index 269ad0a384..252788140f 100644
--- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
+++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
@@ -102,16 +102,16 @@ class GridRNNCell(rnn.RNNCell):
output_is_tuple: If True, the output is a tuple of the outputs of the
recurrent dimensions. If False, they are concatenated along the
column axis. The later behavior will soon be deprecated.
-
+
Raises:
TypeError: if cell_fn does not return an RNNCell instance.
"""
if not state_is_tuple:
- logging.warning("%s: Using a concatenated state is slower and will "
- "soon be deprecated. Use state_is_tuple=True.", self)
+ logging.warning('%s: Using a concatenated state is slower and will '
+ 'soon be deprecated. Use state_is_tuple=True.', self)
if not output_is_tuple:
- logging.warning("%s: Using a concatenated output is slower and will"
- "soon be deprecated. Use output_is_tuple=True.", self)
+ logging.warning('%s: Using a concatenated output is slower and will'
+ 'soon be deprecated. Use output_is_tuple=True.', self)
if num_dims < 1:
raise ValueError('dims must be >= 1: {}'.format(num_dims))
@@ -126,9 +126,7 @@ class GridRNNCell(rnn.RNNCell):
if cell_fn is None:
my_cell_fn = functools.partial(
- rnn.LSTMCell,
- num_units=num_units,
- state_is_tuple=state_is_tuple)
+ rnn.LSTMCell, num_units=num_units, state_is_tuple=state_is_tuple)
else:
my_cell_fn = lambda: cell_fn(num_units)
if tied:
@@ -136,9 +134,8 @@ class GridRNNCell(rnn.RNNCell):
else:
self._cells = [my_cell_fn() for _ in range(num_dims)]
if not isinstance(self._cells[0], rnn.RNNCell):
- raise TypeError(
- 'cell_fn must return an RNNCell instance, saw: %s'
- % type(self._cells[0]))
+ raise TypeError('cell_fn must return an RNNCell instance, saw: %s' %
+ type(self._cells[0]))
if self._output_is_tuple:
self._output_size = tuple(self._cells[0].output_size
@@ -201,26 +198,36 @@ class GridRNNCell(rnn.RNNCell):
if self._output_is_tuple:
output = tuple(output_tensors)
else:
- if len(output_tensors) == 0:
- output = array_ops.zeros([0, 0], dtype)
- else:
+ if output_tensors:
output = array_ops.concat(output_tensors, 1)
+ else:
+ output = array_ops.zeros([0, 0], dtype)
if self._state_is_tuple:
states = tuple(new_state[i] for i in self._config.recurrents)
else:
# concat each state first, then flatten the whole thing
- state_tensors = [x for i in self._config.recurrents
- for x in new_state[i]]
- if len(state_tensors) == 0:
- states = array_ops.zeros([0, 0], dtype)
- else:
+ state_tensors = [
+ x for i in self._config.recurrents for x in new_state[i]
+ ]
+ if state_tensors:
states = array_ops.concat(state_tensors, 1)
+ else:
+ states = array_ops.zeros([0, 0], dtype)
return output, states
def _extract_states(self, state):
- """Extract the cell and previous output tensors from the given state
+ """Extract the cell and previous output tensors from the given state.
+
+ Args:
+ state: The RNN state.
+
+ Returns:
+ Tuple of the cell value, previous output, and cell_output_size.
+
+ Raises:
+ ValueError: If len(self._config.recurrents) != len(state).
"""
conf = self._config
@@ -238,8 +245,8 @@ class GridRNNCell(rnn.RNNCell):
if self._state_is_tuple:
if len(conf.recurrents) != len(state):
- raise ValueError("Expected state as a tuple of {} "
- "element".format(len(conf.recurrents)))
+ raise ValueError('Expected state as a tuple of {} '
+ 'element'.format(len(conf.recurrents)))
for recurrent_dim, recurrent_state in zip(conf.recurrents, state):
if cell_output_size > 0:
@@ -247,49 +254,62 @@ class GridRNNCell(rnn.RNNCell):
else:
m_prev[recurrent_dim] = recurrent_state
else:
- for recurrent_dim, start_idx in zip(conf.recurrents, range(
- 0, self.state_size, total_cell_state_size)):
+ for recurrent_dim, start_idx in zip(conf.recurrents,
+ range(0, self.state_size,
+ total_cell_state_size)):
if cell_output_size > 0:
c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
[-1, conf.num_units])
m_prev[recurrent_dim] = array_ops.slice(
- state, [0, start_idx + conf.num_units], [-1, cell_output_size])
+ state, [0, start_idx + conf.num_units], [-1, cell_output_size])
else:
m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
[-1, conf.num_units])
return c_prev, m_prev, cell_output_size
def _project_input(self, inputs, c_prev, m_prev, with_c):
- """Fills in c_prev and m_prev with projected input, for input dimensions
+ """Fills in c_prev and m_prev with projected input, for input dimensions.
+
+ Args:
+ inputs: inputs tensor
+ c_prev: cell value
+ m_prev: previous output
+ with_c: boolean; whether to include project_c.
+
+ Raises:
+ ValueError: if len(self._config.input) != len(inputs)
"""
conf = self._config
- if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0
- and len(conf.inputs) > 0):
+ if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0 and
+ conf.inputs):
if isinstance(inputs, tuple):
if len(conf.inputs) != len(inputs):
- raise ValueError("Expect inputs as a tuple of {} "
- "tensors".format(len(conf.inputs)))
+ raise ValueError('Expect inputs as a tuple of {} '
+ 'tensors'.format(len(conf.inputs)))
input_splits = inputs
else:
input_splits = array_ops.split(
- value=inputs, num_or_size_splits=len(conf.inputs), axis=1)
+ value=inputs, num_or_size_splits=len(conf.inputs), axis=1)
input_sz = input_splits[0].get_shape().with_rank(2)[1].value
for i, j in enumerate(conf.inputs):
input_project_m = vs.get_variable(
- 'project_m_{}'.format(j), [input_sz, conf.num_units],
- dtype=inputs.dtype)
+ 'project_m_{}'.format(j), [input_sz, conf.num_units],
+ dtype=inputs.dtype)
m_prev[j] = math_ops.matmul(input_splits[i], input_project_m)
if with_c:
input_project_c = vs.get_variable(
- 'project_c_{}'.format(j), [input_sz, conf.num_units],
- dtype=inputs.dtype)
+ 'project_c_{}'.format(j), [input_sz, conf.num_units],
+ dtype=inputs.dtype)
c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
def _cell_state_size(self):
- """Total size of the state of the inner cell used in this grid
+ """Total size of the state of the inner cell used in this grid.
+
+ Returns:
+ Total size of the state of the inner cell.
"""
state_sizes = self._cells[0].state_size
if isinstance(state_sizes, tuple):
@@ -306,10 +326,15 @@ class Grid1BasicRNNCell(GridRNNCell):
def __init__(self, num_units, state_is_tuple=True, output_is_tuple=True):
super(Grid1BasicRNNCell, self).__init__(
- num_units=num_units, num_dims=1,
- input_dims=0, output_dims=0, priority_dims=0, tied=False,
- cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
- state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
+ num_units=num_units,
+ num_dims=1,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=False,
+ cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid2BasicRNNCell(GridRNNCell):
@@ -322,38 +347,56 @@ class Grid2BasicRNNCell(GridRNNCell):
specified.
"""
- def __init__(self, num_units, tied=False, non_recurrent_fn=None,
- state_is_tuple=True, output_is_tuple=True):
+ def __init__(self,
+ num_units,
+ tied=False,
+ non_recurrent_fn=None,
+ state_is_tuple=True,
+ output_is_tuple=True):
super(Grid2BasicRNNCell, self).__init__(
- num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
- non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
- non_recurrent_fn=non_recurrent_fn,
- state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
+ num_units=num_units,
+ num_dims=2,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid1BasicLSTMCell(GridRNNCell):
- """1D BasicLSTM cell"""
+ """1D BasicLSTM cell."""
- def __init__(self, num_units, forget_bias=1,
- state_is_tuple=True, output_is_tuple=True):
+ def __init__(self,
+ num_units,
+ forget_bias=1,
+ state_is_tuple=True,
+ output_is_tuple=True):
+ def cell_fn(n):
+ return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
super(Grid1BasicLSTMCell, self).__init__(
- num_units=num_units, num_dims=1,
- input_dims=0, output_dims=0, priority_dims=0, tied=False,
- cell_fn=lambda n: rnn.BasicLSTMCell(
- num_units=n, forget_bias=forget_bias),
- state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
+ num_units=num_units,
+ num_dims=1,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=False,
+ cell_fn=cell_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid2BasicLSTMCell(GridRNNCell):
- """2D BasicLSTM cell
+ """2D BasicLSTM cell.
- This creates a 2D cell which receives input and gives output in the first
- dimension.
+ This creates a 2D cell which receives input and gives output in the first
+ dimension.
- The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
- specified.
+ The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
+ specified.
"""
def __init__(self,
@@ -363,36 +406,53 @@ class Grid2BasicLSTMCell(GridRNNCell):
forget_bias=1,
state_is_tuple=True,
output_is_tuple=True):
+ def cell_fn(n):
+ return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
super(Grid2BasicLSTMCell, self).__init__(
- num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
- non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n: rnn.BasicLSTMCell(
- num_units=n, forget_bias=forget_bias),
- non_recurrent_fn=non_recurrent_fn,
- state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
+ num_units=num_units,
+ num_dims=2,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=cell_fn,
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid1LSTMCell(GridRNNCell):
- """1D LSTM cell
+ """1D LSTM cell.
- This is different from Grid1BasicLSTMCell because it gives options to
- specify the forget bias and enabling peepholes
+ This is different from Grid1BasicLSTMCell because it gives options to
+ specify the forget bias and enabling peepholes.
"""
- def __init__(self, num_units, use_peepholes=False, forget_bias=1.0,
- state_is_tuple=True, output_is_tuple=True):
+ def __init__(self,
+ num_units,
+ use_peepholes=False,
+ forget_bias=1.0,
+ state_is_tuple=True,
+ output_is_tuple=True):
+
+ def cell_fn(n):
+ return rnn.LSTMCell(
+ num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
+
super(Grid1LSTMCell, self).__init__(
- num_units=num_units, num_dims=1,
- input_dims=0, output_dims=0, priority_dims=0,
- cell_fn=lambda n: rnn.LSTMCell(
- num_units=n, use_peepholes=use_peepholes,
- forget_bias=forget_bias),
- state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
+ num_units=num_units,
+ num_dims=1,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ cell_fn=cell_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid2LSTMCell(GridRNNCell):
- """2D LSTM cell
+ """2D LSTM cell.
This creates a 2D cell which receives input and gives output in the first
dimension.
@@ -408,19 +468,27 @@ class Grid2LSTMCell(GridRNNCell):
forget_bias=1.0,
state_is_tuple=True,
output_is_tuple=True):
+
+ def cell_fn(n):
+ return rnn.LSTMCell(
+ num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
+
super(Grid2LSTMCell, self).__init__(
- num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
- non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n: rnn.LSTMCell(
- num_units=n, forget_bias=forget_bias,
- use_peepholes=use_peepholes),
- non_recurrent_fn=non_recurrent_fn,
- state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
+ num_units=num_units,
+ num_dims=2,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=cell_fn,
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid3LSTMCell(GridRNNCell):
- """3D BasicLSTM cell
+ """3D BasicLSTM cell.
This creates a 2D cell which receives input and gives output in the first
dimension.
@@ -437,19 +505,27 @@ class Grid3LSTMCell(GridRNNCell):
forget_bias=1.0,
state_is_tuple=True,
output_is_tuple=True):
+
+ def cell_fn(n):
+ return rnn.LSTMCell(
+ num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
+
super(Grid3LSTMCell, self).__init__(
- num_units=num_units, num_dims=3,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
- non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n: rnn.LSTMCell(
- num_units=n, forget_bias=forget_bias,
- use_peepholes=use_peepholes),
- non_recurrent_fn=non_recurrent_fn,
- state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
+ num_units=num_units,
+ num_dims=3,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=cell_fn,
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
class Grid2GRUCell(GridRNNCell):
- """2D LSTM cell
+ """2D LSTM cell.
This creates a 2D cell which receives input and gives output in the first
dimension.
@@ -457,23 +533,31 @@ class Grid2GRUCell(GridRNNCell):
specified.
"""
- def __init__(self, num_units, tied=False, non_recurrent_fn=None,
- state_is_tuple=True, output_is_tuple=True):
+ def __init__(self,
+ num_units,
+ tied=False,
+ non_recurrent_fn=None,
+ state_is_tuple=True,
+ output_is_tuple=True):
super(Grid2GRUCell, self).__init__(
- num_units=num_units, num_dims=2,
- input_dims=0, output_dims=0, priority_dims=0, tied=tied,
- non_recurrent_dims=None if non_recurrent_fn is None else 0,
- cell_fn=lambda n: rnn.GRUCell(num_units=n),
- non_recurrent_fn=non_recurrent_fn,
- state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
+ num_units=num_units,
+ num_dims=2,
+ input_dims=0,
+ output_dims=0,
+ priority_dims=0,
+ tied=tied,
+ non_recurrent_dims=None if non_recurrent_fn is None else 0,
+ cell_fn=lambda n: rnn.GRUCell(num_units=n),
+ non_recurrent_fn=non_recurrent_fn,
+ state_is_tuple=state_is_tuple,
+ output_is_tuple=output_is_tuple)
-"""Helpers
-"""
+# Helpers
-_GridRNNDimension = namedtuple(
- '_GridRNNDimension',
- ['idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'])
+_GridRNNDimension = namedtuple('_GridRNNDimension', [
+ 'idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'
+])
_GridRNNConfig = namedtuple('_GridRNNConfig',
['num_dims', 'dims', 'inputs', 'outputs',
@@ -502,23 +586,23 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
rnn_dims = []
for i in range(num_dims):
rnn_dims.append(
- _GridRNNDimension(
- idx=i,
- is_input=(i in input_dims),
- is_output=(i in output_dims),
- is_priority=(i in priority_dims),
- non_recurrent_fn=non_recurrent_fn if i in non_recurrent_dims else
- None))
+ _GridRNNDimension(
+ idx=i,
+ is_input=(i in input_dims),
+ is_output=(i in output_dims),
+ is_priority=(i in priority_dims),
+ non_recurrent_fn=non_recurrent_fn
+ if i in non_recurrent_dims else None))
return _GridRNNConfig(
- num_dims=num_dims,
- dims=rnn_dims,
- inputs=input_dims,
- outputs=output_dims,
- recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims],
- priority=priority_dims,
- non_priority=[x for x in range(num_dims) if x not in priority_dims],
- tied=tied,
- num_units=num_units)
+ num_dims=num_dims,
+ dims=rnn_dims,
+ inputs=input_dims,
+ outputs=output_dims,
+ recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims],
+ priority=priority_dims,
+ non_priority=[x for x in range(num_dims) if x not in priority_dims],
+ tied=tied,
+ num_units=num_units)
def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
@@ -544,8 +628,8 @@ def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0],
m_prev[0].dtype)
- last_dim_output = (new_output[-1] if new_output[-1] is not None
- else m_prev[-1])
+ last_dim_output = (new_output[-1]
+ if new_output[-1] is not None else m_prev[-1])
for i in dim_indices:
d = conf.dims[i]
@@ -560,12 +644,12 @@ def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
vs.get_variable_scope().reuse_variables()
new_output[d.idx] = layers.fully_connected(
- linear_args,
- num_outputs=conf.num_units,
- activation_fn=d.non_recurrent_fn,
- weights_initializer=vs.get_variable_scope().initializer or
- layers.initializers.xavier_initializer,
- weights_regularizer=vs.get_variable_scope().regularizer)
+ linear_args,
+ num_outputs=conf.num_units,
+ activation_fn=d.non_recurrent_fn,
+ weights_initializer=(vs.get_variable_scope().initializer or
+ layers.initializers.xavier_initializer),
+ weights_regularizer=vs.get_variable_scope().regularizer)
else:
if c_prev[i] is not None:
cell_state = (c_prev[i], last_dim_output)
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index 8d50541771..8a97f07732 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -43,13 +43,29 @@ template class FillProjectiveTransform<CPUDevice, double>;
typedef Eigen::ThreadPoolDevice CPUDevice;
using functor::FillProjectiveTransform;
+using generator::INTERPOLATION_BILINEAR;
+using generator::INTERPOLATION_NEAREST;
+using generator::Interpolation;
using generator::ProjectiveGenerator;
template <typename Device, typename T>
class ImageProjectiveTransform : public OpKernel {
+ private:
+ Interpolation interpolation_;
+
public:
- explicit ImageProjectiveTransform(OpKernelConstruction* ctx)
- : OpKernel(ctx) {}
+ explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string interpolation_str;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
+ if (interpolation_str == "NEAREST") {
+ interpolation_ = INTERPOLATION_NEAREST;
+ } else if (interpolation_str == "BILINEAR") {
+ interpolation_ = INTERPOLATION_BILINEAR;
+ } else {
+ LOG(FATAL) << "Invalid interpolation " << interpolation_str
+ << ". Supported types: NEAREST, BILINEAR";
+ }
+ }
void Compute(OpKernelContext* ctx) override {
const Tensor& images_t = ctx->input(0);
@@ -68,8 +84,8 @@ class ImageProjectiveTransform : public OpKernel {
Tensor* output_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
auto output = output_t->tensor<T, 4>();
- const FillProjectiveTransform<Device, T> functor;
- functor(ctx->eigen_device<Device>(), &output, images, transform);
+ (FillProjectiveTransform<Device, T>(interpolation_))(
+ ctx->eigen_device<Device>(), &output, images, transform);
}
};
diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h
index 92b908a1c6..692e33fcf3 100644
--- a/tensorflow/contrib/image/kernels/image_ops.h
+++ b/tensorflow/contrib/image/kernels/image_ops.h
@@ -28,6 +28,8 @@ namespace tensorflow {
namespace generator {
+enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR };
+
using Eigen::array;
using Eigen::DenseIndex;
@@ -36,20 +38,19 @@ class ProjectiveGenerator {
private:
typename TTypes<T, 4>::ConstTensor input_;
typename TTypes<float>::ConstMatrix transforms_;
+ const Interpolation interpolation_;
public:
static const int kNumParameters = 8;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
- typename TTypes<float>::ConstMatrix transforms)
- : input_(input), transforms_(transforms) {}
+ typename TTypes<float>::ConstMatrix transforms,
+ const Interpolation interpolation)
+ : input_(input), transforms_(transforms), interpolation_(interpolation) {}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
operator()(const array<DenseIndex, 4>& coords) const {
- array<DenseIndex, 4> input_coords;
- input_coords[0] = coords[0];
-
const int64 output_y = coords[1];
const int64 output_x = coords[2];
const float* transform =
@@ -57,24 +58,73 @@ class ProjectiveGenerator {
? transforms_.data()
: &transforms_.data()[transforms_.dimension(1) * coords[0]];
float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
- const int64 input_x = std::round(
+ const float input_x =
(transform[0] * output_x + transform[1] * output_y + transform[2]) /
- projection);
- const int64 input_y = std::round(
+ projection;
+ const float input_y =
(transform[3] * output_x + transform[4] * output_y + transform[5]) /
- projection);
-
- if (!(0 <= input_y && input_y < input_.dimension(1) && 0 <= input_x &&
- input_x < input_.dimension(2))) {
- // TODO(ringwalt): Add a fill value input.
- return T(0);
+ projection;
+
+ // TODO(ringwalt): Add a fill value input.
+ static const T fill_value = T(0);
+ switch (interpolation_) {
+ case INTERPOLATION_NEAREST:
+ // Switch the order of x and y again for indexing into the image.
+ return nearest_interpolation(coords[0], input_y, input_x, coords[3],
+ fill_value);
+ case INTERPOLATION_BILINEAR:
+ return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
+ fill_value);
}
- input_coords[1] = input_y;
- input_coords[2] = input_x;
+ }
+
+ private:
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
+ nearest_interpolation(const DenseIndex batch, const float y, const float x,
+ const DenseIndex channel, const T fill_value) const {
+ return read_with_fill_value(batch, DenseIndex(std::round(y)),
+ DenseIndex(std::round(x)), channel, fill_value);
+ }
- input_coords[3] = coords[3];
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
+ bilinear_interpolation(const DenseIndex batch, const float y, const float x,
+ const DenseIndex channel, const T fill_value) const {
+ const float y_floor = std::floor(y);
+ const float x_floor = std::floor(x);
+ const float y_ceil = y_floor + 1;
+ const float x_ceil = x_floor + 1;
+ // f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
+ // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
+ const float value_yfloor =
+ (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor),
+ DenseIndex(x_floor), channel,
+ fill_value) +
+ (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor),
+ DenseIndex(x_ceil), channel,
+ fill_value);
+ // f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
+ // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
+ const float value_yceil =
+ (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil),
+ DenseIndex(x_floor), channel,
+ fill_value) +
+ (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil),
+ DenseIndex(x_ceil), channel,
+ fill_value);
+ // f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
+ // + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
+ return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);
+ }
- return input_(input_coords);
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value(
+ const DenseIndex batch, const DenseIndex y, const DenseIndex x,
+ const DenseIndex channel, const T fill_value) const {
+ // batch and channel must be correct, because they are passed unchanged from
+ // the input.
+ return (0 <= y && y < input_.dimension(1) && 0 <= x &&
+ x < input_.dimension(2))
+ ? input_(array<DenseIndex, 4>{batch, y, x, channel})
+ : fill_value;
}
};
@@ -85,6 +135,7 @@ class ProjectiveGenerator {
// some Eigen device code.
namespace functor {
+using generator::Interpolation;
using generator::ProjectiveGenerator;
template <typename Device, typename T>
@@ -92,15 +143,17 @@ struct FillProjectiveTransform {
typedef typename TTypes<T, 4>::Tensor OutputType;
typedef typename TTypes<T, 4>::ConstTensor InputType;
typedef typename TTypes<float, 2>::ConstTensor TransformsType;
+ const Interpolation interpolation_;
- FillProjectiveTransform() {}
+ FillProjectiveTransform(Interpolation interpolation)
+ : interpolation_(interpolation) {}
EIGEN_ALWAYS_INLINE
void operator()(const Device& device, OutputType* output,
const InputType& images,
const TransformsType& transform) const {
- ProjectiveGenerator<Device, T> generator(images, transform);
- output->device(device) = images.generate(generator);
+ output->device(device) = images.generate(
+ ProjectiveGenerator<Device, T>(images, transform, interpolation_));
}
};
diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc
index 18c16cf1bb..a6d3fa4b64 100644
--- a/tensorflow/contrib/image/ops/image_ops.cc
+++ b/tensorflow/contrib/image/ops/image_ops.cc
@@ -23,13 +23,13 @@ using shape_inference::InferenceContext;
// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
-// TODO(ringwalt): Add an "interpolation" argument with "none", "bilinear", etc.
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
// implement "same" and "valid" modes in the Python function.
REGISTER_OP("ImageProjectiveTransform")
.Input("images: dtype")
.Input("transforms: float32")
.Attr("dtype: {uint8, int32, int64, float32, float64}")
+ .Attr("interpolation: string")
.Output("transformed_images: dtype")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->input(0));
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index 6f72ef6adf..5e78b590df 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -25,8 +25,8 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
_DTYPES = set(
@@ -111,22 +111,71 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
[0, 1, 0, 1],
[0, 1, 1, 1]])
+ def test_bilinear(self):
+ with self.test_session():
+ image = constant_op.constant(
+ [[0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0],
+ [0, 1, 0, 1, 0],
+ [0, 1, 1, 1, 0],
+ [0, 0, 0, 0, 0]],
+ dtypes.float32)
+ # The following result matches:
+ # >>> scipy.ndimage.rotate(image, 45, order=1, reshape=False)
+ # which uses spline interpolation of order 1, equivalent to bilinear
+ # interpolation.
+ self.assertAllClose(
+ image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
+ [[0.000, 0.000, 0.343, 0.000, 0.000],
+ [0.000, 0.586, 0.914, 0.586, 0.000],
+ [0.343, 0.914, 0.000, 0.914, 0.343],
+ [0.000, 0.586, 0.914, 0.586, 0.000],
+ [0.000, 0.000, 0.343, 0.000, 0.000]],
+ atol=0.001)
+ self.assertAllClose(
+ image_ops.rotate(image, np.pi / 4.0, interpolation="NEAREST").eval(),
+ [[0, 0, 1, 0, 0],
+ [0, 1, 1, 1, 0],
+ [1, 1, 0, 1, 1],
+ [0, 1, 1, 1, 0],
+ [0, 0, 1, 0, 0]])
+
+ def test_bilinear_uint8(self):
+ with self.test_session():
+ image = constant_op.constant(
+ np.asarray(
+ [[0.0, 0.0, 0.0, 0.0, 0.0],
+ [0.0, 255, 255, 255, 0.0],
+ [0.0, 255, 0.0, 255, 0.0],
+ [0.0, 255, 255, 255, 0.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0]],
+ np.uint8),
+ dtypes.uint8)
+ # == np.rint((expected image above) * 255)
+ self.assertAllEqual(
+ image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
+ [[0.0, 0.0, 87., 0.0, 0.0],
+ [0.0, 149, 233, 149, 0.0],
+ [87., 233, 0.0, 233, 87.],
+ [0.0, 149, 233, 149, 0.0],
+ [0.0, 0.0, 87., 0.0, 0.0]])
def _test_grad(self, shape_to_test):
with self.test_session():
test_image_shape = shape_to_test
test_image = np.random.randn(*test_image_shape)
- test_image_tensor = constant_op.constant(test_image,
- shape=test_image_shape)
- test_transform = image_ops.angles_to_projective_transforms(np.pi / 2,
- 4,
- 4)
- test_transform_shape = test_transform.shape
+ test_image_tensor = constant_op.constant(
+ test_image, shape=test_image_shape)
+ test_transform = image_ops.angles_to_projective_transforms(
+ np.pi / 2, 4, 4)
output_shape = test_image_shape
output = image_ops.transform(test_image_tensor, test_transform)
left_err = gradient_checker.compute_gradient_error(
- test_image_tensor, test_image_shape, output, output_shape,
+ test_image_tensor,
+ test_image_shape,
+ output,
+ output_shape,
x_init_value=test_image)
self.assertLess(left_err, 1e-10)
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index 1a15db9c20..0d51d0dee1 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -24,8 +24,8 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import resource_loader
_image_ops_so = loader.load_op_library(
@@ -37,7 +37,7 @@ _IMAGE_DTYPES = set(
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
-def rotate(images, angles):
+def rotate(images, angles, interpolation="NEAREST"):
"""Rotate image(s) by the passed angle(s) in radians.
Args:
@@ -46,6 +46,7 @@ def rotate(images, angles):
(num_rows, num_columns) (HW).
angles: A scalar angle to rotate all images by, or (if images has rank 4)
a vector of length num_images, with an angle for each image in the batch.
+ interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
Returns:
Image(s) with the same type and shape as `images`, rotated by the given
@@ -70,7 +71,8 @@ def rotate(images, angles):
image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None]
output = transform(
images,
- angles_to_projective_transforms(angles, image_width, image_height))
+ angles_to_projective_transforms(angles, image_height, image_width),
+ interpolation=interpolation)
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
@@ -120,7 +122,7 @@ def angles_to_projective_transforms(angles, image_height, image_width):
axis=1)
-def transform(images, transforms):
+def transform(images, transforms, interpolation="NEAREST"):
"""Applies the given transform(s) to the image(s).
Args:
@@ -134,6 +136,7 @@ def transform(images, transforms):
`(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
the transform mapping input points to output points.
+ interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".
Returns:
Image(s) with the same type and shape as `images`, with the given
@@ -163,8 +166,8 @@ def transform(images, transforms):
transforms = transform_or_transforms
else:
raise TypeError("Transforms should have rank 1 or 2.")
- # pylint: disable=protected-access
- output = gen_image_ops.image_projective_transform(images, transforms)
+ output = gen_image_ops.image_projective_transform(
+ images, transforms, interpolation=interpolation.upper())
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
@@ -217,8 +220,10 @@ def _transform_matrices_to_flat(transform_matrices):
@ops.RegisterGradient("ImageProjectiveTransform")
def _image_projective_transform_grad(op, grad):
+ """Computes the gradient for ImageProjectiveTransform."""
images = op.inputs[0]
transforms = op.inputs[1]
+ interpolation = op.get_attr("interpolation")
image_or_images = ops.convert_to_tensor(images, name="images")
transform_or_transforms = ops.convert_to_tensor(
@@ -245,7 +250,8 @@ def _image_projective_transform_grad(op, grad):
transforms = _flat_transforms_to_matrices(transforms=transforms)
inverse = linalg_ops.matrix_inverse(transforms)
transforms = _transform_matrices_to_flat(inverse)
- output = gen_image_ops.image_projective_transform(grad, transforms)
+ output = gen_image_ops.image_projective_transform(
+ grad, transforms, interpolation=interpolation)
if len(image_or_images.get_shape()) == 2:
return [output[0, :, :, 0], None]
elif len(image_or_images.get_shape()) == 3:
diff --git a/tensorflow/contrib/kernel_methods/README.md b/tensorflow/contrib/kernel_methods/README.md
new file mode 100644
index 0000000000..1913800af0
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/README.md
@@ -0,0 +1,55 @@
+# TensorFlow contrib kernel_methods.
+
+This module contains operations and estimators that enable the use of primal
+(explicit) kernel methods in TensorFlow. See also the [tutorial](https://www.tensorflow.org/code/tensorflow/contrib/kernel_methods/g3doc/tutorial.md) on how to use this module to improve the quality of
+classification or regression tasks.
+
+## Kernel Mappers
+Implement explicit kernel mapping Ops over tensors. Kernel mappers add
+Tensor-In-Tensor-Out (TITO) Ops to the TensorFlow graph. They can be used in
+conjunction with other layers or ML models.
+
+Sample usage:
+
+```python
+kernel_mapper = tf.contrib.kernel_methods.SomeKernelMapper(...)
+out_tensor = kernel_mapper.map(in_tensor)
+... # code that consumes out_tensor.
+```
+
+Currently, there is a [RandomFourierFeatureMapper]
+(https://www.tensorflow.org/code/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py) implemented that maps dense
+input to dense output.
+
+## Kernel-based Estimators
+tf.contrib.learn Estimators that use kernel mappers internally to discover
+non-linearities in the data. These canned estimators map their input features
+using kernel mapper Ops and then apply linear models to the mapped
+features. Combining kernel mappers with linear models and different loss
+functions leads to a variety of models: linear and non-linear SVMs, linear
+regression (with and without kernels) and (multinomial) logistic regression
+(with and without kernels).
+
+Currently there is a [KernelLinearClassifier]
+(https://www.tensorflow.org/code/tensorflow/contrib/kernel_methods/python/kernel_estimators.py) implemented but more pre-packaged estimators
+are on the way.
+
+Sample usage:
+
+```python
+real_column_a = tf.contrib.layers.real_valued_column(name='real_column_a',...)
+sparse_column_b = tf.contrib.layers.sparse_column_with_hash_bucket(...)
+kernel_mappers = {real_column_a : [tf.contrib.kernel_methods.SomeKernelMapper(...)]}
+optimizer = ...
+
+kernel_classifier = tf.contrib.kernel_methods.KernelLinearClassifier(
+ feature_columns=[real_column_a, sparse_column_b],
+ model_dir=...,
+ optimizer=optimizer,
+ kernel_mappers=kernel_mappers)
+
+# Construct input_fns
+kernel_classifier.fit(...)
+kernel_classifier.evaluate(...)
+```
+
diff --git a/tensorflow/contrib/kernel_methods/g3doc/acc-vs-trn_time.png b/tensorflow/contrib/kernel_methods/g3doc/acc-vs-trn_time.png
new file mode 100644
index 0000000000..1028bb3901
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/g3doc/acc-vs-trn_time.png
Binary files differ
diff --git a/tensorflow/contrib/kernel_methods/g3doc/acc_vs_outdim.png b/tensorflow/contrib/kernel_methods/g3doc/acc_vs_outdim.png
new file mode 100644
index 0000000000..b3384e053b
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/g3doc/acc_vs_outdim.png
Binary files differ
diff --git a/tensorflow/contrib/kernel_methods/g3doc/kernel_mapping.jpg b/tensorflow/contrib/kernel_methods/g3doc/kernel_mapping.jpg
new file mode 100644
index 0000000000..fd69fac76e
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/g3doc/kernel_mapping.jpg
Binary files differ
diff --git a/tensorflow/contrib/kernel_methods/g3doc/tutorial.md b/tensorflow/contrib/kernel_methods/g3doc/tutorial.md
new file mode 100644
index 0000000000..51000080a7
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/g3doc/tutorial.md
@@ -0,0 +1,273 @@
+# Improving classification using explicit kernel methods
+
+In this tutorial, we demonstrate how combining (explicit) kernel methods with
+linear models can drastically increase the latters' quality of predictions
+without significantly increasing training and inference times. Currently,
+explicit kernel mappings are supported for dense features. Support for sparse
+features is in the works.
+
+We will use [tf.contrib.learn](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn) (TensorFlow's high-level Machine Learning API) Estimators for our ML models.
+tf.contrib.learn API reduces the boilerplate code one needs to write for
+configuring, training and evaluating models and will let us focus on the core
+ideas. If you are not familiar with this API, [tf.contrib.learn Quickstart](https://www.tensorflow.org/get_started/tflearn) is a good place to start. We
+will use MNIST, a widely-used dataset containing images of handwritten digits
+(between 0 and 9). The tutorial consists of the following steps:
+
+* Load and prepare MNIST data for classification.
+* Construct a simple linear model, train it and evaluate it on the eval data.
+* Replace the linear model with a kernelized linear model, re-train and
+re-evaluate.
+
+## Load and prepare MNIST data for classification
+The first step is to prepare the data to be fed to the ML models. The following
+utility command from tf.contrib.learn loads the MNIST dataset:
+
+```python
+data = tf.contrib.learn.datasets.mnist.load_mnist()
+```
+This loads the entire MNIST dataset (containing 70K samples) and splits it into
+train, validation and test data with 55K, 5K and 10K samples respectively. Each
+split contains one numpy array for images (with shape [sample_size, 784]) and
+one for labels (with shape [sample_size, 1]). In this tutorial, we only use the
+train and validation splits (to train and evaluate our models respectively).
+
+In order to feed data to a tf.contrib.learn Estimator, it is helpful to convert
+it to Tensors. For this, we will use an `input function` which adds Ops to the
+TensorFlow graph that, when executed, create mini-batches of Tensors to be used
+downstream. For more background on input functions, check
+[Building Input Functions with tf.contrib.learn](https://www.tensorflow.org/get_started/input_fn). In this example, we will use the `tf.train.shuffle_batch` Op which,
+besides converting numpy arrays to Tensors, allows us to specify the batch_size
+and whether to randomize the input every time the input_fn Ops are executed
+(randomization typically expedites convergence during training). The full code
+for loading and preparing the data is shown in the snippet below. In this
+example, we use mini-batches of size 256 for training and the entire sample (5K
+entries) for evaluation. Feel free to experiment with different batch sizes.
+
+```python
+import numpy as np
+import tensorflow as tf
+
+def get_input_fn(dataset_split, batch_size, capacity=10000, min_after_dequeue=3000):
+
+ def _input_fn():
+ images_batch, labels_batch = tf.train.shuffle_batch(
+ tensors=[dataset_split.images, dataset_split.labels.astype(np.int32)],
+ batch_size=batch_size,
+ capacity=capacity,
+ min_after_dequeue=min_after_dequeue,
+ enqueue_many=True,
+ num_threads=4)
+ features_map = {'images': images_batch}
+ return features_map, labels_batch
+
+ return _input_fn
+
+data = tf.contrib.learn.datasets.mnist.load_mnist()
+
+train_input_fn = get_input_fn(data.train, batch_size=256)
+eval_input_fn = get_input_fn(data.validation, batch_size=5000)
+
+```
+
+## Training a simple linear model
+We can now train a linear model over the MNIST dataset. We will use the [tf.contrib.learn.LinearClassifier](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/linear.py) estimator with 10 classes (representing the 10
+digits). The input features form a 784-dimensional (dense) vector which can be
+specified as follows:
+
+```python
+image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
+```
+
+The full code for constructing, training and evaluating a LinearClassifier
+estimator is shown below.
+
+```python
+import time
+
+# Specify the feature(s) to be used by the estimator.
+image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
+estimator = tf.contrib.learn.LinearClassifier(feature_columns=[image_column], n_classes=10)
+
+# Train.
+start = time.time()
+estimator.fit(input_fn=train_input_fn, steps=2000)
+end = time.time()
+print('Elapsed time: {} seconds'.format(end - start))
+
+# Evaluate and report metrics.
+eval_metrics = estimator.evaluate(input_fn=eval_input_fn, steps=1)
+print(eval_metrics)
+```
+On eval data, the loss (i.e., the value of the objective function being
+minimized during training) lies between **0.25 and 0.30** (depending on the
+parameters used) while the accuracy of the classifier is approximately **92.5%**
+(training is randomized so the exact loss and accuracy will vary). Also, the
+training time is around 25 seconds (this will also vary based on the machine you
+run the code on).
+
+In addition to experimenting with the (training) batch size and the number of
+training steps, there are a couple other parameters that can be tuned as well.
+For instance, you can change the optimization method used to minimize the loss
+by explicitly selecting another optimizer from the collection of
+[available optimizers]
+(https://www.tensorflow.org/code/tensorflow/python/training).
+As an example, the following code constructs a LinearClassifer estimator that
+uses the Follow-The-Regularized-Leader (FTRL) optimization strategy with a
+specific learning rate and l2-regularization.
+
+
+```python
+optimizer = tf.train.FtrlOptimizer(learning_rate=5.0, l2_regularization_strength=1.0)
+estimator = tf.contrib.learn.LinearClassifier(
+ feature_columns=[image_column], n_classes=10, optimizer=optimizer)
+```
+
+Regardless of the values of the parameters, the max accuracy a linear model can
+achieve on this dataset caps at around **93%**.
+
+## Using explicit kernel mappings with the linear model.
+The relatively high error (~7%) of the linear model over MNIST indicates that
+the input data is not linearly separable. We will use explicit kernel mappings
+to reduce the classification error.
+
+**Intuition:** The high-level idea is to use a non-linear map to transform the
+input space to another feature space (of possibly higher dimension) where the
+(transformed) features are (almost) linearly separable and then apply a linear
+model on the mapped features. This is shown in the following figure:
+
+<div style="text-align:center">
+<img src="./kernel_mapping.png">
+</div>
+
+**Technical details overview:** In this example we will use **Random Fourier
+Features** (introduced in the ["Random Features for Large-Scale Kernel Machines"]
+(https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf) paper by
+Rahimi and Recht) to map the input data. Random Fourier Features map a vector
+$$\mathbf{x} \in \mathbb{R}^d$$ to $$\mathbf{x'} \in \mathbb{R}^D$$ via the
+following mapping:
+
+$$
+RFFM(\cdot): \mathbb{R}^d \to \mathbb{R}^D, \quad
+RFFM(\mathbf{x}) = \cos(\mathbf{\Omega} \cdot \mathbf{x}+ \mathbf{b})
+$$
+
+where $$\mathbf{\Omega} \in \mathbb{R}^{D \times d}$$,
+$$\mathbf{x} \in \mathbb{R}^d,$$ $$\mathbf{b} \in \mathbb{R}^D$$ and cosine is
+applied elementwise.
+
+In this example, the entries of $$\mathbf{\Omega}$$ and $$\mathbf{b}$$ are
+sampled from distributions such that the mapping satisfies the following
+property:
+
+$$
+RFFM(\mathbf{x})^T \cdot RFFM(\mathbf{y}) \approx
+e^{-\frac{\|\mathbf{x} - \mathbf{y}\|^2}{2 \sigma^2}}
+$$
+
+The right-hand-side quantity of the expression above is known as the RBF (or
+Gaussian) kernel function. This function is one of the most-widely used kernel
+functions in Machine Learning and measures (implicitly) similarity in a
+different (much higher dimensional) space than the original one. See
+[Radial basis function kernel](https://en.wikipedia.org/wiki/Radial_basis_function_kernel)
+for more details.
+
+**Kernel Classifier:** `tf.contrib.kernel_methods.KernelLinearClassifier` is a
+pre-packaged `tf.contrib.learn` estimator that combines the power of explicit
+kernel mappings with linear models. Its API is very similar to that of the
+LinearClassifier with the additional ability to specify a list of explicit
+kernel mappings to be apply to each feature used by the classifier. The
+following code snippet demonstrates how to replace LinearClassifier with
+KernelLinearClassifier.
+
+
+```python
+# Specify the feature(s) to be used by the estimator. This is identical to the
+# code used for the LinearClassifier.
+image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
+optimizer = tf.train.FtrlOptimizer(
+ learning_rate=50.0, l2_regularization_strength=0.001)
+
+
+kernel_mapper = tf.contrib.kernel_methods.RandomFourierFeatureMapper(
+ input_dim=784, output_dim=2000, stddev=5.0, name='rffm')
+kernel_mappers = {image_column: [kernel_mapper]}
+estimator = tf.contrib.kernel_methods.KernelLinearClassifier(
+ n_classes=10, optimizer=optimizer, kernel_mappers=kernel_mappers)
+
+# Train.
+start = time.time()
+estimator.fit(input_fn=train_input_fn, steps=2000)
+end = time.time()
+print('Elapsed time: {} seconds'.format(end - start))
+
+# Evaluate and report metrics.
+eval_metrics = estimator.evaluate(input_fn=eval_input_fn, steps=1)
+print(eval_metrics)
+```
+The only additional parameter passed to `KernelLinearClassifier` is a dictionary
+from feature_columns to a list of kernel mappings to be applied to the
+corresponding feature column. In this example, the lines
+
+```python
+kernel_mapper = tf.contrib.kernel_methods.RandomFourierFeatureMapper(
+ input_dim=784, output_dim=2000, stddev=5.0, name='rffm')
+kernel_mappers = {image_column: [kernel_mapper]}
+estimator = tf.contrib.kernel_methods.KernelLinearClassifier(
+ n_classes=10, optimizer=optimizer, kernel_mappers=kernel_mappers)
+```
+instruct the classifier to first map the initial 784-dimensional images to
+2000-dimensional vectors using random Fourier features and then learn a linear
+model on the transformed vectors. Note that, besides the output dimension, there
+is one more parameter (stddev) involved. This parameter is the standard
+deviation ($$\sigma$$) of the approximated RBF kernel and controls the
+similarity measure used in classification. This parameter is typically
+determined via hyperparameter tuning.
+
+Running the code above yields a loss of approximately **0.10** while the
+accuracy is increased to approximately **97%** on eval data (an increase of 4%
+over the plain linear model). The training time hovers around 35 seconds. We can
+increase the accuracy even more, by increasing the output dimension of the
+mapping and tuning the standard deviation even more.
+
+**On the role of stddev:** The classification quality is very sensitive to the
+value of the stddev parameter used to define the similarity measure between the
+pairs of input features. The following table shows the accuracy of the
+classifier on the eval data for different values of stddev (for all experiments
+the output dimension was fixed to 3000). The optimal value is stddev=5.0. Notice
+how too small or too high stddev values can dramatically decrease the accuracy
+of the classification.
+
+stddev | eval accuracy
+:----- | :------------
+1.0 | 0.1362
+2.0 | 0.4764
+4.0 | 0.9654
+5.0 | 0.9766
+8.0 | 0.9714
+16.0 | 0.8878
+
+**On the role of the output dimension:** Intuitively, the larger the output
+dimension of the mapping, the closer the inner product of two mapped vectors
+approximates the kernel which typically translates to better classification
+accuracy. Another way to think about this is that the output dimension equals
+the number of weights of the linear model (the larger this dimension, the larger
+the "degrees of freedom" of the model). However, after a certain threshold,
+higher output dimensions increase the accuracy by very little (while still
+increasing the training time). This is shown in the following 2 Figures which
+depict the eval accuracy as a function of the output dimension and the training
+time respectively.
+
+![image](./acc_vs_outdim.png) ![image](./acc-vs-trn_time.png)
+
+
+## Explicit Kernel Mappings: summary and practical tips
+* Explicit kernel mappings combine the predictive power of non-linear models
+with the scalability of linear models.
+* Random Fourier Features can be particularly effective for datasets with dense
+features.
+* The parameters of the kernel mapping are often data-dependent. Model quality
+can be very sensitive to these parameters. Use hyperparameter tuning to find the
+optimal values.
+* If you have multiple numerical features, concatinate them into a single
+multi-dimensional one and apply the kernel mapping to the concatenated vector.
+
diff --git a/tensorflow/contrib/layers/python/layers/encoders.py b/tensorflow/contrib/layers/python/layers/encoders.py
index 3afdbb1827..89c9d37bd0 100644
--- a/tensorflow/contrib/layers/python/layers/encoders.py
+++ b/tensorflow/contrib/layers/python/layers/encoders.py
@@ -121,7 +121,7 @@ def embed_sequence(ids,
`Tensor` of `[batch_size, doc_length, embed_dim]` with embedded sequences.
Raises:
- ValueError: if `embed_dim` or `vocab_size` are not specified when
+ ValueError: if `embed_dim` or `vocab_size` are not specified when
`reuse` is `None` or `False`.
"""
if not (reuse or (vocab_size and embed_dim)):
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index 04fe2370d1..95a4b032b0 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -131,21 +131,27 @@ import math
import six
from tensorflow.contrib import lookup
+from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.framework import experimental
+from tensorflow.contrib.framework.python.ops import variables as contrib_variables
+from tensorflow.contrib.layers.python.layers import embedding_ops
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.layers.python.ops import bucketization_op
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops
from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
@@ -291,11 +297,13 @@ class _FeatureColumn(object):
# TODO(b/30410315): Support warm starting in all feature columns.
-class _SparseColumn(_FeatureColumn,
- collections.namedtuple("_SparseColumn",
- ["column_name", "is_integerized",
- "bucket_size", "lookup_config",
- "combiner", "dtype"])):
+class _SparseColumn(
+ _FeatureColumn,
+ fc_core._CategoricalColumn, # pylint: disable=protected-access
+ collections.namedtuple("_SparseColumn", [
+ "column_name", "is_integerized", "bucket_size", "lookup_config",
+ "combiner", "dtype"
+ ])):
"""Represents a sparse feature column also known as categorical features.
Instances of this class are immutable. A sparse column means features are
@@ -426,9 +434,8 @@ class _SparseColumn(_FeatureColumn,
initializer=init_ops.zeros_initializer(),
combiner=self.combiner)
- def _get_input_sparse_tensor(self, columns_to_tensors):
- """Looks up the input tensor for transformation and sparsify it if dense."""
- input_tensor = columns_to_tensors[self.name]
+ def _get_input_sparse_tensor(self, input_tensor):
+ """sparsify input_tensor if dense."""
if not isinstance(input_tensor, sparse_tensor_py.SparseTensor):
# To avoid making any assumptions about which values are to be ignored,
# we set ignore_value to -1 for numeric tensors to avoid excluding valid
@@ -455,18 +462,44 @@ class _SparseColumn(_FeatureColumn,
format(self.name, other_column.name))
return compatible
-
-class _SparseColumnIntegerized(_SparseColumn):
- """See `sparse_column_with_integerized_feature`."""
+ @abc.abstractmethod
+ def _do_transform(self, input_tensor):
+ pass
def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion."""
- input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
+ input_tensor = self._get_input_sparse_tensor(columns_to_tensors[self.name])
+ columns_to_tensors[self] = self._do_transform(input_tensor)
+
+ def _transform_feature(self, inputs):
+ input_tensor = self._get_input_sparse_tensor(inputs.get(self.name))
+ return self._do_transform(input_tensor)
+ @property
+ def _parse_example_config(self):
+ return self.config
+
+ @property
+ def _num_buckets(self):
+ return self.length
+
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ input_tensor = inputs.get(self)
+ return fc_core._CategoricalColumn.IdWeightPair( # pylint: disable=protected-access
+ self.id_tensor(input_tensor), self.weight_tensor(input_tensor))
+
+
+class _SparseColumnIntegerized(_SparseColumn):
+ """See `sparse_column_with_integerized_feature`."""
+
+ def _do_transform(self, input_tensor):
sparse_id_values = math_ops.mod(input_tensor.values, self.bucket_size,
name="mod")
- columns_to_tensors[self] = sparse_tensor_py.SparseTensor(
- input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
+ return sparse_tensor_py.SparseTensor(input_tensor.indices, sparse_id_values,
+ input_tensor.dense_shape)
def sparse_column_with_integerized_feature(column_name,
@@ -517,10 +550,7 @@ def sparse_column_with_integerized_feature(column_name,
class _SparseColumnHashed(_SparseColumn):
"""See `sparse_column_with_hash_bucket`."""
- def insert_transformed_feature(self, columns_to_tensors):
- """Handles sparse column to id conversion."""
- input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
-
+ def _do_transform(self, input_tensor):
if self.dtype.is_integer:
sparse_values = string_ops.as_string(input_tensor.values)
else:
@@ -528,8 +558,8 @@ class _SparseColumnHashed(_SparseColumn):
sparse_id_values = string_ops.string_to_hash_bucket_fast(
sparse_values, self.bucket_size, name="lookup")
- columns_to_tensors[self] = sparse_tensor_py.SparseTensor(
- input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
+ return sparse_tensor_py.SparseTensor(input_tensor.indices, sparse_id_values,
+ input_tensor.dense_shape)
def sparse_column_with_hash_bucket(column_name,
@@ -572,16 +602,13 @@ def sparse_column_with_hash_bucket(column_name,
class _SparseColumnKeys(_SparseColumn):
"""See `sparse_column_with_keys`."""
- def insert_transformed_feature(self, columns_to_tensors):
- """Handles sparse column to id conversion."""
- input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
-
+ def _do_transform(self, input_tensor):
table = lookup.index_table_from_tensor(
mapping=tuple(self.lookup_config.keys),
default_value=self.lookup_config.default_value,
dtype=self.dtype,
name="lookup")
- columns_to_tensors[self] = table.lookup(input_tensor)
+ return table.lookup(input_tensor)
def sparse_column_with_keys(
@@ -621,9 +648,7 @@ def sparse_column_with_keys(
class _SparseColumnVocabulary(_SparseColumn):
"""See `sparse_column_with_vocabulary_file`."""
- def insert_transformed_feature(self, columns_to_tensors):
- """Handles sparse column to id conversion."""
- st = self._get_input_sparse_tensor(columns_to_tensors)
+ def _do_transform(self, st):
if self.dtype.is_integer:
sparse_string_values = string_ops.as_string(st.values)
sparse_string_tensor = sparse_tensor_py.SparseTensor(st.indices,
@@ -638,7 +663,7 @@ class _SparseColumnVocabulary(_SparseColumn):
vocab_size=self.lookup_config.vocab_size,
default_value=self.lookup_config.default_value,
name=self.name + "_lookup")
- columns_to_tensors[self] = table.lookup(sparse_string_tensor)
+ return table.lookup(sparse_string_tensor)
def sparse_column_with_vocabulary_file(column_name,
@@ -694,9 +719,12 @@ def sparse_column_with_vocabulary_file(column_name,
dtype=dtype)
-class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
- "_WeightedSparseColumn",
- ["sparse_id_column", "weight_column_name", "dtype"])):
+class _WeightedSparseColumn(
+ _FeatureColumn,
+ fc_core._CategoricalColumn, # pylint: disable=protected-access
+ collections.namedtuple("_WeightedSparseColumn",
+ ["sparse_id_column", "weight_column_name",
+ "dtype"])):
"""See `weighted_sparse_column`."""
def __new__(cls, sparse_id_column, weight_column_name, dtype):
@@ -725,22 +753,6 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
"""Returns a string which will be used as a key when we do sorting."""
return "{}".format(self)
- def insert_transformed_feature(self, columns_to_tensors):
- """Inserts a tuple with the id and weight tensors."""
- if self.sparse_id_column not in columns_to_tensors:
- self.sparse_id_column.insert_transformed_feature(columns_to_tensors)
-
- weight_tensor = columns_to_tensors[self.weight_column_name]
- if not isinstance(weight_tensor, sparse_tensor_py.SparseTensor):
- # The weight tensor can be a regular Tensor. In such case, sparsify it.
- weight_tensor = contrib_sparse_ops.dense_to_sparse_tensor(weight_tensor)
- if not self.dtype.is_floating:
- weight_tensor = math_ops.to_float(weight_tensor)
- columns_to_tensors[self] = tuple([
- columns_to_tensors[self.sparse_id_column],
- weight_tensor
- ])
-
def id_tensor(self, input_tensor):
"""Returns the id tensor from the given transformed input_tensor."""
return input_tensor[0]
@@ -768,6 +780,43 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
initializer=init_ops.zeros_initializer(),
combiner=self.sparse_id_column.combiner)
+ def _do_transform(self, id_tensor, weight_tensor):
+ if not isinstance(weight_tensor, sparse_tensor_py.SparseTensor):
+ # The weight tensor can be a regular Tensor. In such case, sparsify it.
+ weight_tensor = contrib_sparse_ops.dense_to_sparse_tensor(weight_tensor)
+ if not self.dtype.is_floating:
+ weight_tensor = math_ops.to_float(weight_tensor)
+ return tuple([id_tensor, weight_tensor])
+
+ def insert_transformed_feature(self, columns_to_tensors):
+ """Inserts a tuple with the id and weight tensors."""
+ if self.sparse_id_column not in columns_to_tensors:
+ self.sparse_id_column.insert_transformed_feature(columns_to_tensors)
+
+ weight_tensor = columns_to_tensors[self.weight_column_name]
+ columns_to_tensors[self] = self._do_transform(
+ columns_to_tensors[self.sparse_id_column], weight_tensor)
+
+ def _transform_feature(self, inputs):
+ return self._do_transform(
+ inputs.get(self.sparse_id_column), inputs.get(self.weight_column_name))
+
+ @property
+ def _parse_example_config(self):
+ return self.config
+
+ @property
+ def _num_buckets(self):
+ return self.length
+
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ del weight_collections
+ del trainable
+ input_tensor = inputs.get(self)
+ return fc_core._CategoricalColumn.IdWeightPair( # pylint: disable=protected-access
+ self.id_tensor(input_tensor), self.weight_tensor(input_tensor))
+
def weighted_sparse_column(sparse_id_column,
weight_column_name,
@@ -815,9 +864,10 @@ def weighted_sparse_column(sparse_id_column,
return _WeightedSparseColumn(sparse_id_column, weight_column_name, dtype)
-class _OneHotColumn(_FeatureColumn,
- collections.namedtuple("_OneHotColumn",
- ["sparse_id_column"])):
+class _OneHotColumn(
+ _FeatureColumn,
+ fc_core._DenseColumn, # pylint: disable=protected-access
+ collections.namedtuple("_OneHotColumn", ["sparse_id_column"])):
"""Represents a one-hot column for use in deep networks.
Args:
@@ -897,12 +947,31 @@ class _OneHotColumn(_FeatureColumn,
return math_ops.reduce_sum(
one_hot_id_tensor, reduction_indices=[output_rank - 1])
+ @property
+ def _variable_shape(self):
+ return tensor_shape.TensorShape((self.length))
+
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ del weight_collections
+ del trainable
+ return inputs.get(self)
+
+ def _transform_feature(self, inputs):
+ return self._to_dnn_input_layer(inputs.get(self.sparse_id_column))
+
+ @property
+ def _parse_example_config(self):
+ return self.config
+
-class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
- "_EmbeddingColumn",
- ["sparse_id_column", "dimension", "combiner", "initializer",
- "ckpt_to_load_from", "tensor_name_in_ckpt", "shared_embedding_name",
- "shared_vocab_size", "max_norm", "trainable"])):
+class _EmbeddingColumn(
+ _FeatureColumn,
+ fc_core._DenseColumn, # pylint: disable=protected-access
+ collections.namedtuple("_EmbeddingColumn", [
+ "sparse_id_column", "dimension", "combiner", "initializer",
+ "ckpt_to_load_from", "tensor_name_in_ckpt", "shared_embedding_name",
+ "shared_vocab_size", "max_norm", "trainable"
+ ])):
"""Represents an embedding column.
Args:
@@ -1027,6 +1096,139 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
raise ValueError("Column {} is not supported in linear models. "
"Please use sparse_column.".format(self))
+ @property
+ def _variable_shape(self):
+ return tensor_shape.TensorShape((self.dimension))
+
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ return _embeddings_from_arguments(
+ self,
+ self._deep_embedding_lookup_arguments(inputs.get(self)),
+ weight_collections, trainable)
+
+ def _transform_feature(self, inputs):
+ return inputs.get(self.sparse_id_column)
+
+ @property
+ def _parse_example_config(self):
+ return self.config
+
+
+def _is_variable(v):
+ """Returns true if `v` is a variable."""
+ return isinstance(v, (variables.Variable,
+ resource_variable_ops.ResourceVariable))
+
+
+def _embeddings_from_arguments(column,
+ args,
+ weight_collections,
+ trainable,
+ output_rank=2):
+ """Returns embeddings for a column based on the computed arguments.
+
+ Args:
+ column: the column name.
+ args: the _DeepEmbeddingLookupArguments for this column.
+ weight_collections: collections to store weights in.
+ trainable: whether these embeddings should be trainable.
+ output_rank: the desired rank of the returned `Tensor`. Inner dimensions will
+ be combined to produce the desired rank.
+
+ Returns:
+ the embeddings.
+
+ Raises:
+ ValueError: if not possible to create.
+ """
+ # pylint: disable=protected-access
+ input_tensor = layers._inner_flatten(args.input_tensor, output_rank)
+ weight_tensor = None
+ if args.weight_tensor is not None:
+ weight_tensor = layers._inner_flatten(args.weight_tensor, output_rank)
+ # pylint: enable=protected-access
+
+ # This option is only enabled for scattered_embedding_column.
+ if args.hash_key:
+ embeddings = contrib_variables.model_variable(
+ name="weights",
+ shape=[args.vocab_size],
+ dtype=dtypes.float32,
+ initializer=args.initializer,
+ trainable=(trainable and args.trainable),
+ collections=weight_collections)
+
+ return embedding_ops.scattered_embedding_lookup_sparse(
+ embeddings,
+ input_tensor,
+ args.dimension,
+ hash_key=args.hash_key,
+ combiner=args.combiner,
+ name="lookup")
+
+ if args.shared_embedding_name is not None:
+ shared_embedding_collection_name = (
+ "SHARED_EMBEDDING_COLLECTION_" + args.shared_embedding_name.upper())
+ graph = ops.get_default_graph()
+ shared_embedding_collection = (
+ graph.get_collection_ref(shared_embedding_collection_name))
+ shape = [args.vocab_size, args.dimension]
+ if shared_embedding_collection:
+ if len(shared_embedding_collection) > 1:
+ raise ValueError(
+ "Collection %s can only contain one "
+ "(partitioned) variable." % shared_embedding_collection_name)
+ else:
+ embeddings = shared_embedding_collection[0]
+ if embeddings.get_shape() != shape:
+ raise ValueError(
+ "The embedding variable with name {} already "
+ "exists, but its shape does not match required "
+ "embedding shape here. Please make sure to use "
+ "different shared_embedding_name for different "
+ "shared embeddings.".format(args.shared_embedding_name))
+ else:
+ embeddings = contrib_variables.model_variable(
+ name=args.shared_embedding_name,
+ shape=shape,
+ dtype=dtypes.float32,
+ initializer=args.initializer,
+ trainable=(trainable and args.trainable),
+ collections=weight_collections)
+ graph.add_to_collection(shared_embedding_collection_name, embeddings)
+ else:
+ embeddings = contrib_variables.model_variable(
+ name="weights",
+ shape=[args.vocab_size, args.dimension],
+ dtype=dtypes.float32,
+ initializer=args.initializer,
+ trainable=(trainable and args.trainable),
+ collections=weight_collections)
+
+ if _is_variable(embeddings):
+ embeddings = [embeddings]
+ else:
+ embeddings = embeddings._get_variable_list() # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ _maybe_restore_from_checkpoint(column._checkpoint_path(), embeddings)
+ return embedding_ops.safe_embedding_lookup_sparse(
+ embeddings,
+ input_tensor,
+ sparse_weights=weight_tensor,
+ combiner=args.combiner,
+ name=column.name + "weights",
+ max_norm=args.max_norm)
+
+
+def _maybe_restore_from_checkpoint(checkpoint_path, variable):
+ if checkpoint_path is not None:
+ path, tensor_name = checkpoint_path
+ weights_to_restore = variable
+ if len(variable) == 1:
+ weights_to_restore = variable[0]
+ checkpoint_utils.init_from_checkpoint(path,
+ {tensor_name: weights_to_restore})
+
def one_hot_column(sparse_id_column):
"""Creates an `_OneHotColumn` for a one-hot or multi-hot repr in a DNN.
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py
index 31aca87002..d010ae6df1 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import functools
-from tensorflow.contrib.framework.python.framework import checkpoint_utils
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import embedding_ops
@@ -34,118 +33,12 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
-def _is_variable(v):
- """Returns true if `v` is a variable."""
- return isinstance(v, (variables.Variable,
- resource_variable_ops.ResourceVariable))
-
-
-def _embeddings_from_arguments(column,
- args,
- weight_collections,
- trainable,
- output_rank=2):
- """Returns embeddings for a column based on the computed arguments.
-
- Args:
- column: the column name.
- args: the _DeepEmbeddingLookupArguments for this column.
- weight_collections: collections to store weights in.
- trainable: whether these embeddings should be trainable.
- output_rank: the desired rank of the returned `Tensor`. Inner dimensions will
- be combined to produce the desired rank.
-
- Returns:
- the embeddings.
-
- Raises:
- ValueError: if not possible to create.
- """
- # pylint: disable=protected-access
- input_tensor = layers._inner_flatten(args.input_tensor, output_rank)
- weight_tensor = None
- if args.weight_tensor is not None:
- weight_tensor = layers._inner_flatten(args.weight_tensor, output_rank)
- # pylint: enable=protected-access
-
- # This option is only enabled for scattered_embedding_column.
- if args.hash_key:
- embeddings = contrib_variables.model_variable(
- name='weights',
- shape=[args.vocab_size],
- dtype=dtypes.float32,
- initializer=args.initializer,
- trainable=(trainable and args.trainable),
- collections=weight_collections)
-
- return embedding_ops.scattered_embedding_lookup_sparse(
- embeddings, input_tensor, args.dimension,
- hash_key=args.hash_key,
- combiner=args.combiner, name='lookup')
-
- if args.shared_embedding_name is not None:
- shared_embedding_collection_name = (
- 'SHARED_EMBEDDING_COLLECTION_' + args.shared_embedding_name.upper())
- graph = ops.get_default_graph()
- shared_embedding_collection = (
- graph.get_collection_ref(shared_embedding_collection_name))
- shape = [args.vocab_size, args.dimension]
- if shared_embedding_collection:
- if len(shared_embedding_collection) > 1:
- raise ValueError('Collection %s can only contain one '
- '(partitioned) variable.'
- % shared_embedding_collection_name)
- else:
- embeddings = shared_embedding_collection[0]
- if embeddings.get_shape() != shape:
- raise ValueError('The embedding variable with name {} already '
- 'exists, but its shape does not match required '
- 'embedding shape here. Please make sure to use '
- 'different shared_embedding_name for different '
- 'shared embeddings.'.format(
- args.shared_embedding_name))
- else:
- embeddings = contrib_variables.model_variable(
- name=args.shared_embedding_name,
- shape=shape,
- dtype=dtypes.float32,
- initializer=args.initializer,
- trainable=(trainable and args.trainable),
- collections=weight_collections)
- graph.add_to_collection(shared_embedding_collection_name, embeddings)
- else:
- embeddings = contrib_variables.model_variable(
- name='weights',
- shape=[args.vocab_size, args.dimension],
- dtype=dtypes.float32,
- initializer=args.initializer,
- trainable=(trainable and args.trainable),
- collections=weight_collections)
-
- if _is_variable(embeddings):
- embeddings = [embeddings]
- else:
- embeddings = embeddings._get_variable_list() # pylint: disable=protected-access
- # pylint: disable=protected-access
- _maybe_restore_from_checkpoint(
- column._checkpoint_path(), embeddings)
- return embedding_ops.safe_embedding_lookup_sparse(
- embeddings,
- input_tensor,
- sparse_weights=weight_tensor,
- combiner=args.combiner,
- name=column.name + 'weights',
- max_norm=args.max_norm)
-
-
def _maybe_reshape_input_tensor(tensor, column_name, output_rank):
"""Reshape the input tensor by the following rule.
@@ -232,12 +125,13 @@ def _input_from_feature_columns(columns_to_tensors,
# pylint: disable=protected-access
arguments = column._deep_embedding_lookup_arguments(
transformed_tensor)
- output_tensors.append(_embeddings_from_arguments(
- column,
- arguments,
- weight_collections,
- trainable,
- output_rank=output_rank))
+ output_tensors.append(
+ fc._embeddings_from_arguments( # pylint: disable=protected-access
+ column,
+ arguments,
+ weight_collections,
+ trainable,
+ output_rank=output_rank))
except NotImplementedError as ee:
try:
@@ -393,7 +287,7 @@ def _create_embedding_lookup(column,
initializer=embedding_lookup_arguments.initializer,
trainable=trainable,
collections=weight_collections)
- if _is_variable(variable):
+ if fc._is_variable(variable): # pylint: disable=protected-access
variable = [variable]
else:
variable = variable._get_variable_list() # pylint: disable=protected-access
@@ -406,16 +300,6 @@ def _create_embedding_lookup(column,
return variable, predictions
-def _maybe_restore_from_checkpoint(checkpoint_path, variable):
- if checkpoint_path is not None:
- path, tensor_name = checkpoint_path
- weights_to_restore = variable
- if len(variable) == 1:
- weights_to_restore = variable[0]
- checkpoint_utils.init_from_checkpoint(path,
- {tensor_name: weights_to_restore})
-
-
def _create_joint_embedding_lookup(columns_to_tensors,
embedding_lookup_arguments,
num_outputs,
@@ -451,7 +335,7 @@ def _create_joint_embedding_lookup(columns_to_tensors,
initializer=init_ops.zeros_initializer(),
trainable=trainable,
collections=weight_collections)
- if _is_variable(variable):
+ if fc._is_variable(variable): # pylint: disable=protected-access
variable = [variable]
else:
variable = variable._get_variable_list() # pylint: disable=protected-access
@@ -634,7 +518,7 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
predictions, shape=(-1, num_outputs)))
column_to_variable[column] = variable
_log_variable(variable)
- _maybe_restore_from_checkpoint(column._checkpoint_path(), variable)
+ fc._maybe_restore_from_checkpoint(column._checkpoint_path(), variable) # pylint: disable=protected-access
# pylint: enable=protected-access
predictions_no_bias = math_ops.add_n(output_tensors)
bias = contrib_variables.model_variable(
@@ -827,10 +711,10 @@ def parse_feature_columns_from_sequence_examples(
def _log_variable(variable):
if isinstance(variable, list):
for var in variable:
- if _is_variable(variable):
+ if fc._is_variable(variable): # pylint: disable=protected-access
logging.info('Created variable %s, with device=%s', var.name,
var.device)
- elif _is_variable(variable):
+ elif fc._is_variable(variable): # pylint: disable=protected-access
logging.info('Created variable %s, with device=%s', variable.name,
variable.device)
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index a09cc53571..0123921266 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -597,12 +597,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
"income":
constant_op.constant([[20.3, 10], [110.3, 0.4], [-3.0, 30.4]]),
}
- output = feature_column_ops.input_from_feature_columns(features, [
- one_hot_column, embedding_column, real_valued_column])
+ columns = [one_hot_column, embedding_column, real_valued_column]
+ output = feature_column_ops.input_from_feature_columns(features, columns)
+ output_core = fc_core.make_input_layer(features, columns)
with self.test_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(output.eval().shape, output_core.eval().shape)
def testRealValuedColumn(self):
real_valued = feature_column.real_valued_column("price")
@@ -712,11 +715,14 @@ class CreateInputLayersForDNNsTest(test.TestCase):
one_hot_column = feature_column.one_hot_column(weighted_ids_column)
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_column])
+ output_core = fc_core.make_input_layer(features, [one_hot_column])
with self.test_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
output.eval())
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(output.eval(), output_core.eval())
def testOneHotColumnFromSparseColumnWithKeysSucceedsForDNN(self):
ids_column = feature_column.sparse_column_with_keys(
@@ -729,12 +735,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features = {"ids": ids_tensor}
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse])
+ output_core = fc_core.make_input_layer(features, [one_hot_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
output.eval())
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(output.eval(), output_core.eval())
def testOneHotColumnFromMultivalentSparseColumnWithKeysSucceedsForDNN(self):
ids_column = feature_column.sparse_column_with_keys(
@@ -747,12 +756,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features = {"ids": ids_tensor}
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse])
+ output_core = fc_core.make_input_layer(features, [one_hot_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
output.eval())
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(output.eval(), output_core.eval())
def testOneHotColumnFromSparseColumnWithIntegerizedFeaturePassesForDNN(self):
ids_column = feature_column.sparse_column_with_integerized_feature(
@@ -767,10 +779,13 @@ class CreateInputLayersForDNNsTest(test.TestCase):
}
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse])
+ output_core = fc_core.make_input_layer(features, [one_hot_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
output.eval())
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(output.eval(), output_core.eval())
def testOneHotColumnFromSparseColumnWithHashBucketSucceedsForDNN(self):
hashed_sparse = feature_column.sparse_column_with_hash_bucket("feat", 10)
@@ -782,10 +797,13 @@ class CreateInputLayersForDNNsTest(test.TestCase):
one_hot_sparse = feature_column.one_hot_column(hashed_sparse)
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse])
+ output_core = fc_core.make_input_layer(features, [one_hot_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([3, 10], output.eval().shape)
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(output.eval(), output_core.eval())
def testEmbeddingColumnSucceedsForDNN(self):
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
@@ -797,9 +815,12 @@ class CreateInputLayersForDNNsTest(test.TestCase):
embeded_sparse = feature_column.embedding_column(hashed_sparse, 10)
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
+ output_core = fc_core.make_input_layer(features, [embeded_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(output.eval().shape, [4, 10])
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(output.eval().shape, output_core.eval().shape)
def testScatteredEmbeddingColumnSucceedsForDNN(self):
wire_tensor = sparse_tensor.SparseTensor(
@@ -838,12 +859,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
initializer=init_ops.constant_initializer(init_value))
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
+ output_core = fc_core.make_input_layer(features, [embeded_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
output_eval = output.eval()
self.assertAllEqual(output_eval.shape, [2, 10])
self.assertAllClose(output_eval, np.tile(init_value, [2, 10]))
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(output.eval(), output_core.eval())
def testEmbeddingColumnWithMultipleInitializersFails(self):
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
@@ -889,10 +913,14 @@ class CreateInputLayersForDNNsTest(test.TestCase):
embeded_sparse = feature_column.embedding_column(weighted_ids, 10)
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
+ output_core = fc_core.make_input_layer(features, [embeded_sparse])
+
with self.test_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(output.eval().shape, output_core.eval().shape)
def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self):
"""Same as the previous test, but with integer weights."""
@@ -1534,9 +1562,12 @@ class WeightedSumTest(test.TestCase):
features = {"wire": wire_tensor}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5)
+ logits_core = fc_core.make_linear_model(features, [hashed_sparse], units=5)
with self.test_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(logits.eval(), logits_core.eval())
def testSparseIntColumn(self):
"""Tests a sparse column with int values."""
@@ -1549,9 +1580,12 @@ class WeightedSumTest(test.TestCase):
features = {"wire": wire_tensor}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5)
+ logits_core = fc_core.make_linear_model(features, [hashed_sparse], units=5)
with self.test_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(logits.eval(), logits_core.eval())
def testSparseColumnWithDenseInputTensor(self):
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
@@ -1560,9 +1594,12 @@ class WeightedSumTest(test.TestCase):
features = {"wire": wire_tensor}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5)
+ logits_core = fc_core.make_linear_model(features, [hashed_sparse], units=5)
with self.test_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(logits.eval(), logits_core.eval())
def testWeightedSparseColumn(self):
ids = feature_column.sparse_column_with_keys("ids",
@@ -1579,10 +1616,13 @@ class WeightedSumTest(test.TestCase):
features = {"ids": ids_tensor, "weights": weights_tensor}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [weighted_ids], num_outputs=5)
+ logits_core = fc_core.make_linear_model(features, [weighted_ids], units=5)
with self.test_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(logits.eval(), logits_core.eval())
def testWeightedSparseColumnWithDenseInputTensor(self):
ids = feature_column.sparse_column_with_keys(
@@ -1594,11 +1634,14 @@ class WeightedSumTest(test.TestCase):
features = {"ids": ids_tensor, "weights": weights_tensor}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [weighted_ids], num_outputs=5)
+ logits_core = fc_core.make_linear_model(features, [weighted_ids], units=5)
with self.test_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
+ # Verify cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(logits.eval(), logits_core.eval())
def testCrossedColumn(self):
a = feature_column.sparse_column_with_hash_bucket(
@@ -1649,6 +1692,8 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [movies], num_outputs=1))
+ logits_core = fc_core.make_linear_model(features, [movies])
+
with self.test_session() as sess:
variables_lib.initialize_all_variables().run()
lookup_ops.tables_initializer().run()
@@ -1659,6 +1704,8 @@ class WeightedSumTest(test.TestCase):
# score for first example = 0.3 (matrix) + 0.1 (head-on) = 0.4
# score for second example = 0.5 (winter sleep)
self.assertAllClose(output.eval(), [[0.4], [0.5]])
+ # Cross compatibility: Core builder output should equal to contrib.
+ self.assertAllEqual(output.eval().shape, logits_core.eval().shape)
def testRealValuedColumnWithMultiDimensions(self):
real_valued = feature_column.real_valued_column("price", 2)
diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py
index 1926cbe7b3..271b3c01ff 100644
--- a/tensorflow/contrib/layers/python/layers/initializers.py
+++ b/tensorflow/contrib/layers/python/layers/initializers.py
@@ -36,7 +36,8 @@ def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32):
Xavier Glorot and Yoshua Bengio (2010):
[Understanding the difficulty of training deep feedforward neural
networks. International conference on artificial intelligence and
- statistics.](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.207.2059&rep=rep1&type=pdf)
+ statistics.](
+ http://www.jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)
This initializer is designed to keep the scale of the gradients roughly the
same in all layers. In uniform distribution this ends up being the range:
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index 15d2dbf3bf..53c71c6f3e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -102,9 +102,10 @@ def _linear_learning_rate(num_linear_feature_columns):
def _add_hidden_layer_summary(value, tag):
summary.scalar("%s/fraction_of_zero_values" % tag, nn.zero_fraction(value))
summary.histogram("%s/activation" % tag, value)
+
+
def _add_layer_summary(value, tag):
- summary.scalar("%s/fraction_of_zero_values" % tag,
- nn.zero_fraction(value))
+ summary.scalar("%s/fraction_of_zero_values" % tag, nn.zero_fraction(value))
summary.histogram("%s/activation" % tag, value)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
index cf80dafc37..d86ef8d477 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import layers
-from tensorflow.contrib.framework.python.framework import deprecated
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py
index cbbd9671b7..a473cf46d5 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.summary import summary
from tensorflow.python.ops.control_flow_ops import with_dependencies
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.session_run_hook import SessionRunArgs
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
index 64a97880c3..02acd70812 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py
@@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.contrib import layers
from tensorflow.contrib import rnn as rnn_cell
-from tensorflow.contrib.framework.python.framework import deprecated
from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import constants
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
index 69469b577d..886aab5b4f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
@@ -455,6 +455,7 @@ class LegacyConstructorTest(test.TestCase):
return {'inputs': inputs}, labels
return input_fn
+
# TODO(jtbates): move all tests below to a benchmark test.
class StateSavingRNNEstimatorLearningTest(test.TestCase):
"""Learning tests for state saving RNN Estimators."""
diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py
index 9ecfc73299..17feeb2736 100644
--- a/tensorflow/contrib/learn/python/learn/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -22,6 +22,7 @@ import os
import tempfile
import time
+from tensorflow.contrib.learn.python.learn import estimator as estimator_lib
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn import run_config
@@ -38,6 +39,7 @@ from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
+from tensorflow.python.util import tf_inspect
class SheepCounter(object):
@@ -119,6 +121,15 @@ class TestBaseEstimator(object):
compat.as_bytes(export_dir_base), compat.as_bytes('bogus_timestamp'))
+def _check_method_supports_args(method, kwargs):
+ """Checks that the given method supports the given args."""
+ supported_args = tuple(tf_inspect.getargspec(method).args)
+ for kwarg in kwargs:
+ if kwarg not in supported_args:
+ raise ValueError(
+ 'Argument `{}` is not supported in method {}.'.format(kwarg, method))
+
+
class TestEstimator(
TestBaseEstimator, evaluable.Evaluable, trainable.Trainable):
@@ -126,9 +137,12 @@ class TestEstimator(
super(TestEstimator, self).__init__(config, max_evals, eval_dict)
tf_logging.info('Create Estimator')
+ def evaluate(self, **kwargs):
+ _check_method_supports_args(evaluable.Evaluable.evaluate, kwargs)
+ return super(TestEstimator, self).evaluate(**kwargs)
+
def fit(self, **kwargs):
- if 'hooks' in kwargs:
- raise ValueError('`hooks` is defined in core Estimator')
+ _check_method_supports_args(trainable.Trainable.fit, kwargs)
if 'monitors' in kwargs:
self.monitors = kwargs['monitors']
return super(TestEstimator, self).train(**kwargs)
@@ -136,6 +150,13 @@ class TestEstimator(
def train(self, **kwargs):
raise ValueError('`train` is not defined in Estimator.')
+ def export_savedmodel(
+ self, export_dir_base, serving_input_fn, **kwargs):
+ _check_method_supports_args(
+ estimator_lib.Estimator.export_savedmodel, kwargs)
+ return super(TestEstimator, self).export_savedmodel(
+ export_dir_base, serving_input_fn, **kwargs)
+
class TestCoreEstimator(TestBaseEstimator, core_estimator.Estimator):
@@ -144,17 +165,22 @@ class TestCoreEstimator(TestBaseEstimator, core_estimator.Estimator):
tf_logging.info('Create Core Estimator')
def evaluate(self, **kwargs):
- if 'eval_metrics' in kwargs:
- raise ValueError('`eval_metrics` is not defined in core Estimator')
+ _check_method_supports_args(core_estimator.Estimator.evaluate, kwargs)
return super(TestCoreEstimator, self).evaluate(**kwargs)
def train(self, **kwargs):
- if 'monitors' in kwargs:
- raise ValueError('`monitors` is not defined in core Estimator')
+ _check_method_supports_args(core_estimator.Estimator.train, kwargs)
if 'hooks' in kwargs:
self.monitors = kwargs['hooks']
return super(TestCoreEstimator, self).train(**kwargs)
+ def export_savedmodel(
+ self, export_dir_base, serving_input_receiver_fn, **kwargs):
+ _check_method_supports_args(
+ core_estimator.Estimator.export_savedmodel, kwargs)
+ return super(TestCoreEstimator, self).export_savedmodel(
+ export_dir_base, serving_input_receiver_fn, **kwargs)
+
class _NoopHook(session_run_hook.SessionRunHook):
pass
@@ -184,6 +210,23 @@ class ExperimentTest(test.TestCase):
eval_input_fn='eval_input',
eval_metrics='eval_metrics')
+ def test_default_output_alternative_key_core_estimator(self):
+ est = TestCoreEstimator()
+ export_strategy = saved_model_export_utils.make_export_strategy(
+ est,
+ default_output_alternative_key='export_key',
+ exports_to_keep=None)
+ ex = experiment.Experiment(
+ est,
+ train_input_fn='train_input',
+ eval_input_fn='eval_input',
+ train_steps=100,
+ eval_steps=100,
+ export_strategies=export_strategy)
+ with self.assertRaisesRegexp(
+ ValueError, 'default_output_alternative_key is not supported'):
+ ex.train_and_evaluate()
+
def test_train(self):
for est in self._estimators_for_tests():
eval_metrics = 'eval_metrics' if not isinstance(
@@ -508,7 +551,9 @@ class ExperimentTest(test.TestCase):
eval_metrics = 'eval_metrics' if not isinstance(
est, core_estimator.Estimator) else None
export_strategy_1 = saved_model_export_utils.make_export_strategy(
- est, 'export_input_1', exports_to_keep=None)
+ est,
+ None if isinstance(est, core_estimator.Estimator) else 'export_1',
+ exports_to_keep=None)
ex = experiment.Experiment(
est,
@@ -531,9 +576,13 @@ class ExperimentTest(test.TestCase):
# After reset with list, the count should increase with the number of
# items.
export_strategy_2 = saved_model_export_utils.make_export_strategy(
- est, 'export_input_2', exports_to_keep=None)
+ est,
+ None if isinstance(est, core_estimator.Estimator) else 'export_2',
+ exports_to_keep=None)
export_strategy_3 = saved_model_export_utils.make_export_strategy(
- est, 'export_input_3', exports_to_keep=None)
+ est,
+ None if isinstance(est, core_estimator.Estimator) else 'export_3',
+ exports_to_keep=None)
old_es = ex.reset_export_strategies(
[export_strategy_2, export_strategy_3])
@@ -547,7 +596,9 @@ class ExperimentTest(test.TestCase):
est, core_estimator.Estimator) else None
noop_hook = _NoopHook()
export_strategy = saved_model_export_utils.make_export_strategy(
- est, 'export_input', exports_to_keep=None)
+ est,
+ None if isinstance(est, core_estimator.Estimator) else 'export_input',
+ exports_to_keep=None)
ex = experiment.Experiment(
est,
train_input_fn='train_input',
@@ -625,7 +676,9 @@ class ExperimentTest(test.TestCase):
est, core_estimator.Estimator) else None
noop_hook = _NoopHook()
export_strategy = saved_model_export_utils.make_export_strategy(
- est, 'export_input', exports_to_keep=None)
+ est,
+ None if isinstance(est, core_estimator.Estimator) else 'export_input',
+ exports_to_keep=None)
ex = experiment.Experiment(
est,
train_input_fn='train_input',
@@ -646,7 +699,9 @@ class ExperimentTest(test.TestCase):
eval_metrics = 'eval_metrics' if not isinstance(
est, core_estimator.Estimator) else None
export_strategy = saved_model_export_utils.make_export_strategy(
- est, 'export_input', exports_to_keep=None)
+ est,
+ None if isinstance(est, core_estimator.Estimator) else 'export_input',
+ exports_to_keep=None)
ex = experiment.Experiment(
est,
train_input_fn='train_input',
@@ -796,7 +851,9 @@ class ExperimentTest(test.TestCase):
def test_test(self):
for est in self._estimators_for_tests():
exp_strategy = saved_model_export_utils.make_export_strategy(
- est, 'export_input', exports_to_keep=None)
+ est,
+ None if isinstance(est, core_estimator.Estimator) else 'export_input',
+ exports_to_keep=None)
ex = experiment.Experiment(
est,
train_input_fn='train_input',
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
index 7ad3779314..fa314e69c7 100644
--- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
@@ -42,6 +42,7 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.utils import gc
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
+from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.platform import gfile
@@ -352,7 +353,8 @@ def make_export_strategy(serving_input_fn,
`InputFnOps`.
default_output_alternative_key: the name of the head to serve when an
incoming serving request does not explicitly request a specific head.
- Not needed for single-headed models.
+ Must be `None` if the estimator inherits from ${tf.estimator.Estimator}
+ or for single-headed models.
assets_extra: A dict specifying how to populate the assets.extra directory
within the exported SavedModel. Each key should give the destination
path (including the filename) relative to the assets.extra directory.
@@ -384,14 +386,30 @@ def make_export_strategy(serving_input_fn,
Returns:
The string path to the exported directory.
+
+ Raises:
+ ValueError: If `estimator` is a ${tf.estimator.Estimator} instance
+ and `default_output_alternative_key` was specified.
"""
- export_result = estimator.export_savedmodel(
- export_dir_base,
- serving_input_fn,
- default_output_alternative_key=default_output_alternative_key,
- assets_extra=assets_extra,
- as_text=as_text,
- checkpoint_path=checkpoint_path)
+ if isinstance(estimator, core_estimator.Estimator):
+ if default_output_alternative_key is not None:
+ raise ValueError(
+ 'default_output_alternative_key is not supported in core '
+ 'Estimator. Given: {}'.format(default_output_alternative_key))
+ export_result = estimator.export_savedmodel(
+ export_dir_base,
+ serving_input_fn,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ checkpoint_path=checkpoint_path)
+ else:
+ export_result = estimator.export_savedmodel(
+ export_dir_base,
+ serving_input_fn,
+ default_output_alternative_key=default_output_alternative_key,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ checkpoint_path=checkpoint_path)
garbage_collect_exports(export_dir_base, exports_to_keep)
return export_result
diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD
index ddcd0a03cf..5b65a6ae05 100644
--- a/tensorflow/contrib/signal/BUILD
+++ b/tensorflow/contrib/signal/BUILD
@@ -1,9 +1,9 @@
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-package(default_visibility = ["//tensorflow:__subpackages__"])
-
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
py_library(
diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py
index 33639aa82a..9f906dd28e 100644
--- a/tensorflow/contrib/signal/__init__.py
+++ b/tensorflow/contrib/signal/__init__.py
@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+"""##Signal ops.
-"""
@@frames
"""
diff --git a/tensorflow/contrib/signal/python/__init__.py b/tensorflow/contrib/signal/python/__init__.py
index d2b5550aa6..e672d1146c 100644
--- a/tensorflow/contrib/signal/python/__init__.py
+++ b/tensorflow/contrib/signal/python/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+"""Signal ops."""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
index ba6c67c770..e07942875f 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py
@@ -33,34 +33,34 @@ class FramesTest(test.TestCase):
with self.test_session():
tensor = constant_op.constant(np.arange(9152), dtypes.int32)
tensor = array_ops.expand_dims(tensor, 0)
-
+
result = shape_ops.frames(tensor, 512, 180)
result = result.eval()
-
+
expected = np.tile(np.arange(512), (49, 1))
expected += np.tile(np.arange(49) * 180, (512, 1)).T
-
+
expected = np.expand_dims(expected, axis=0)
expected = np.array(expected, dtype=np.int32)
-
+
self.assertAllEqual(expected, result)
-
+
def test_mapping_of_indices_with_padding(self):
with self.test_session():
tensor = constant_op.constant(np.arange(10000), dtypes.int32)
tensor = array_ops.expand_dims(tensor, 0)
-
+
result = shape_ops.frames(tensor, 512, 192)
result = result.eval()
-
+
expected = np.tile(np.arange(512), (51, 1))
expected += np.tile(np.arange(51) * 192, (512, 1)).T
-
+
expected[expected >= 10000] = 0
-
+
expected = np.expand_dims(expected, axis=0)
expected = np.array(expected, dtype=np.int32)
-
+
self.assertAllEqual(expected, result)
diff --git a/tensorflow/contrib/signal/python/ops/__init__.py b/tensorflow/contrib/signal/python/ops/__init__.py
index d2b5550aa6..e672d1146c 100644
--- a/tensorflow/contrib/signal/python/ops/__init__.py
+++ b/tensorflow/contrib/signal/python/ops/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+"""Signal ops."""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/contrib/signal/python/ops/shape_ops.py b/tensorflow/contrib/signal/python/ops/shape_ops.py
index e1d0f977c1..4914f19be7 100644
--- a/tensorflow/contrib/signal/python/ops/shape_ops.py
+++ b/tensorflow/contrib/signal/python/ops/shape_ops.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+"""General shape ops for frames."""
from __future__ import absolute_import
from __future__ import division
@@ -23,59 +24,64 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+
def frames(signal, frame_length, frame_step, name=None):
"""Frame a signal into overlapping frames.
+
May be used in front of spectral functions.
-
+
For example:
-
+
```python
pcm = tf.placeholder(tf.float32, [None, 9152])
frames = tf.contrib.signal.frames(pcm, 512, 180)
magspec = tf.abs(tf.spectral.rfft(frames, [512]))
image = tf.expand_dims(magspec, 3)
```
-
+
Args:
signal: A `Tensor` of shape `[batch_size, signal_length]`.
frame_length: An `int32` or `int64` `Tensor`. The length of each frame.
frame_step: An `int32` or `int64` `Tensor`. The step between frames.
name: A name for the operation (optional).
-
+
Returns:
A `Tensor` of frames with shape `[batch_size, num_frames, frame_length]`.
+
+ Raises:
+ ValueError: if signal does not have rank 2.
"""
with ops.name_scope(name, "frames", [signal, frame_length, frame_step]):
signal = ops.convert_to_tensor(signal, name="signal")
frame_length = ops.convert_to_tensor(frame_length, name="frame_length")
frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
-
+
signal_rank = signal.shape.ndims
-
+
if signal_rank != 2:
raise ValueError("expected signal to have rank 2 but was " + signal_rank)
-
+
signal_length = array_ops.shape(signal)[1]
-
+
num_frames = math_ops.ceil((signal_length - frame_length) / frame_step)
num_frames = 1 + math_ops.cast(num_frames, dtypes.int32)
-
+
pad_length = (num_frames - 1) * frame_step + frame_length
- pad_signal = array_ops.pad(
- signal, [[0, 0], [0, pad_length - signal_length]])
-
+ pad_signal = array_ops.pad(signal, [[0, 0], [0,
+ pad_length - signal_length]])
+
indices_frame = array_ops.expand_dims(math_ops.range(frame_length), 0)
indices_frames = array_ops.tile(indices_frame, [num_frames, 1])
-
+
indices_step = array_ops.expand_dims(
math_ops.range(num_frames) * frame_step, 1)
indices_steps = array_ops.tile(indices_step, [1, frame_length])
-
+
indices = indices_frames + indices_steps
-
- # TODO(Androbin): remove `transpose` when `gather` gets `axis` support
+
+ # TODO(androbin): remove `transpose` when `gather` gets `axis` support
pad_signal = array_ops.transpose(pad_signal)
- frames = array_ops.gather(pad_signal, indices)
- frames = array_ops.transpose(frames, perm=[2, 0, 1])
-
- return frames
+ signal_frames = array_ops.gather(pad_signal, indices)
+ signal_frames = array_ops.transpose(signal_frames, perm=[2, 0, 1])
+
+ return signal_frames
diff --git a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py
index c6fe518bbb..f2065c6662 100644
--- a/tensorflow/contrib/testing/python/framework/fake_summary_writer.py
+++ b/tensorflow/contrib/testing/python/framework/fake_summary_writer.py
@@ -127,6 +127,6 @@ class FakeSummaryWriter(object):
def reopen(self):
pass
-
+
def close(self):
pass
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
index 9c3b10b22c..818c2d2cbf 100644
--- a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD
@@ -97,6 +97,29 @@ py_test(
],
)
+py_library(
+ name = "pprof_profiler",
+ srcs = ["pprof_profiler.py"],
+ srcs_version = "PY2AND3",
+ deps = ["@pprof_profile_proto//:pprof_proto_py"],
+)
+
+py_test(
+ name = "pprof_profiler_test",
+ srcs = ["pprof_profiler_test.py"],
+ main = "pprof_profiler_test.py",
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"], # TODO(annarev): get it working with pip.
+ deps = [
+ ":pprof_profiler",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ "@pprof_profile_proto//:pprof_proto_py",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py
new file mode 100644
index 0000000000..c3fea915a3
--- /dev/null
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py
@@ -0,0 +1,445 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Profiler for TensorFlow models that outputs data in pprof format.
+
+See https://github.com/google/pprof/blob/master/proto/profile.proto for pprof
+profile format.
+The following needs to be set for profiler to work:
+ * trace_level needs to be set to FULL_TRACE
+ * run_metadata object should be passed in to session.run call
+
+Sample usage:
+ options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
+ run_metadata = tf.RunMetadata()
+
+ with tf.Session as sess:
+ ...
+ sess.run(computation, run_metadata=run_metadata, options=options)
+ pprof_profiler.profile(sess.graph, run_metadata, output_dir)
+
+
+ The code above would output a pprof profile to separate output_dir/.*.pb.gz
+ file for each device. These files can be passed to pprof for formatting.
+ For e.g.:
+ pprof -png --nodecount=100 --sample_index=1 output_dir/profile_output.pb.gz
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import defaultdict
+from collections import namedtuple
+import gzip
+import os
+import string
+import sys
+import time
+
+from pprof import profile_pb2
+
+
+if sys.version_info < (3,):
+ maketrans = string.maketrans
+else:
+ maketrans = str.maketrans
+
+
+ProfileDatum = namedtuple('ProfileDatum', [
+ 'node_exec_stats', 'op_type', 'traceback'])
+
+
+class StringTable(object):
+ """Keeps track of strings to add to string_table in pprof proto."""
+
+ def __init__(self):
+ # Pprof requires first entry in string_table to be ''.
+ self._string_table = ['']
+ self._string_to_index = {'': 0}
+
+ def index_of(self, value_str):
+ """Get index of value_str in the string table.
+
+ If value_str is not in the string table, we will add it at the end
+ and then return the new index.
+ Args:
+ value_str: (string) Value to lookup/add in/to the string table.
+
+ Returns:
+ Index of value_str in the string table.
+ """
+ if value_str is None:
+ value_str = ''
+ if value_str in self._string_to_index:
+ return self._string_to_index[value_str]
+ index = len(self._string_table)
+ self._string_table.append(value_str)
+ self._string_to_index[value_str] = index
+ return index
+
+ def next_index(self):
+ """Gets index that would be assigned to the next added string.
+
+ Returns:
+ Index of the next string if it was added.
+ """
+ return len(self._string_table)
+
+ def string_table(self):
+ """Returns a list of strings to store in pprof's string_table."""
+ return self._string_table
+
+
+class Functions(object):
+ """Keeps track of `Function` protos for pprof profile."""
+
+ def __init__(self, string_table):
+ """Constructor.
+
+ Args:
+ string_table: A `StringTable` object.
+ """
+ self._string_table = string_table
+ # Maps tuples in the form (file_path, function_name, start_line_number)
+ # to `Function` protos.
+ self._function_key_to_function = {}
+
+ def index_of(self, file_path, function_name, function_start_line):
+ """Returns index of the function, adding the function if needed.
+
+ Args:
+ file_path: (string) Path to file where the function is defined.
+ function_name: (string) Function name.
+ function_start_line: (integer) Start line number of function definition.
+
+ Returns:
+ Function index.
+ """
+ function_key = (file_path, function_name, function_start_line)
+ if function_key in self._function_key_to_function:
+ return self._function_key_to_function[function_key].id
+ else:
+ # Function indexes should start from 1
+ function_index = len(self._function_key_to_function) + 1
+ function = profile_pb2.Function()
+ function.id = function_index
+ function.name = self._string_table.index_of(function_name)
+ function.filename = self._string_table.index_of(file_path)
+ function.start_line = function_start_line
+ self._function_key_to_function[function_key] = function
+ return function_index
+
+ def function_protos(self):
+ """Returns list of `profile_pb2.Function` protos."""
+ return self._function_key_to_function.values()
+
+
+class Locations(object):
+ """Keeps track of `Location` protos for pprof profile.
+
+ `Locations` store information about function call locations.
+ """
+
+ def __init__(self, functions):
+ """Constructor.
+
+ Args:
+ functions: A `Functions` object.
+ """
+ self._functions = functions
+ # Maps tuples in the form (file_path, called_function_name, line_number)
+ # to `Location` protos.
+ self._location_key_to_location = {}
+
+ def index_of(
+ self, file_path, line_number, called_function_name, called_file_path,
+ called_function_start_line):
+ """Returns index of the location, adding the location if needed.
+
+ Args:
+ file_path: (string) Path to file that makes the call.
+ line_number: (integer) Call line number.
+ called_function_name: (string) Function name of the function called at
+ `file_path` and `line_number`.
+ called_file_path: (string) Path to file where the called function is
+ defined.
+ called_function_start_line: (integer) Start line number of called
+ function definition in `called_file_path` file.
+
+ Returns:
+ Index of location.
+ """
+ location_key = (file_path, called_function_name, line_number)
+ if location_key in self._location_key_to_location:
+ location = self._location_key_to_location[location_key]
+ return location.id
+ else:
+ # Location indexes should start from 1
+ location_index = len(self._location_key_to_location) + 1
+ location = profile_pb2.Location()
+ location.id = location_index
+ self._location_key_to_location[location_key] = location
+
+ line = location.line.add()
+ line.function_id = self._functions.index_of(
+ called_file_path, called_function_name, called_function_start_line)
+ line.line = line_number
+ return location_index
+
+ def location_protos(self):
+ """Returns list of `profile_pb2.Location` protos."""
+ return self._location_key_to_location.values()
+
+
+class Samples(object):
+ """Keeps track of `Sample` protos for pprof profile.
+
+ Samples store the following statistics in order:
+ count, all_time, op_time
+ """
+
+ def __init__(self, string_table):
+ """Constructor.
+
+ Args:
+ string_table: A `StringTable` object.
+ """
+ self._string_table = string_table
+ # TODO(annarev): figure out if location is unique for each node name.
+ # If not, also key this dictionary based on location ids.
+ self._node_name_to_sample = {}
+
+ def add(self, datum, location_ids):
+ """Adds a sample data point.
+
+ Args:
+ datum: `ProfileDatum` to add a sample for.
+ location_ids: List of numberic location ids for this
+ sample.
+ """
+ node_name = datum.node_exec_stats.node_name
+ if node_name in self._node_name_to_sample:
+ sample = self._node_name_to_sample[node_name]
+ sample.location_id.extend(location_ids)
+ else:
+ sample = profile_pb2.Sample()
+ # Sample stores 3 values: count, all_time, op_time
+ sample.value.extend([0, 0, 0])
+
+ label = sample.label.add()
+ label.key = self._string_table.index_of('node_name')
+ label.str = self._string_table.index_of(node_name)
+ label = sample.label.add()
+ label.key = self._string_table.index_of('op_type')
+ label.str = self._string_table.index_of(datum.op_type)
+ self._node_name_to_sample[node_name] = sample
+ sample.value[0] += 1
+ sample.value[1] += datum.node_exec_stats.all_end_rel_micros
+ sample.value[2] += (
+ datum.node_exec_stats.op_end_rel_micros -
+ datum.node_exec_stats.op_start_rel_micros)
+
+ def get_sample_protos(self):
+ """Returns list of `Sample` protos for pprof profile."""
+ return self._node_name_to_sample.values()
+
+
+class PprofProfiler(object):
+ """Creates profiles in pprof format."""
+
+ def __init__(self, graph, run_metadata):
+ """Constructor.
+
+ Args:
+ graph: A `Graph` instance.
+ run_metadata: A list of `RunMetadata` objects.
+ """
+ self._graph = graph
+ self._run_metadata = run_metadata
+ self._string_table = StringTable()
+ self._functions = Functions(self._string_table)
+ self._locations = Locations(self._functions)
+
+ def profile(self):
+ """Generates pprof profiles.
+
+ Returns:
+ Dictionary mapping from device name to proto in `profile_pb2.Profile`
+ format.
+ """
+ profiles = {}
+ data_generator_func = self._get_profile_data_generator()
+ for device_index, device_stats in enumerate(
+ self._run_metadata.step_stats.dev_stats):
+ # Create profile
+ pprof_proto = self._get_pprof_proto(data_generator_func(device_stats))
+ if not pprof_proto.sample:
+ print(
+ 'Not enough data to create profile for device %s. Did you pass '
+ 'RunMetadata to session.run call?' % device_stats.device)
+ continue
+ # Add device name comment
+ device_count = len(self._run_metadata.step_stats.dev_stats)
+ device_description = (
+ 'Device %d of %d: %s' %
+ (device_index + 1, device_count, device_stats.device))
+ device_description_str_index = self._string_table.next_index()
+ pprof_proto.string_table.append(device_description)
+ pprof_proto.comment.append(device_description_str_index)
+ profiles[device_stats.device] = pprof_proto
+ return profiles
+
+ def _get_pprof_proto(self, profile_datum_generator):
+ """Returns profile data in pprof proto format.
+
+ Args:
+ profile_datum_generator: Generator outputting `ProfileDatum` objects.
+
+ Returns:
+ A proto in pprof format.
+ """
+ pprof_profile = profile_pb2.Profile()
+ samples = Samples(self._string_table)
+
+ for datum in profile_datum_generator:
+ if not datum.traceback:
+ continue
+
+ stack_frame = datum.traceback[-1]
+ after_apply_op = False
+ location_ids = []
+
+ # We add locations from stack trace in bottom-up order.
+ for stack_frame_index in reversed(range(len(datum.traceback) - 1)):
+ prev_stack_frame = stack_frame
+ stack_frame = datum.traceback[stack_frame_index]
+
+ # Call at current frame calls function at previous frame.
+ prev_file_path = prev_stack_frame[0]
+ prev_function = prev_stack_frame[2]
+ prev_function_start_line = prev_stack_frame[4]
+ curr_file_path = stack_frame[0]
+ curr_line_number = stack_frame[1]
+
+ # Skip all calls up to apply_op since they are the same for all ops.
+ if not after_apply_op:
+ if prev_function == 'apply_op':
+ after_apply_op = True
+ continue
+ location_index = self._locations.index_of(
+ curr_file_path, curr_line_number,
+ prev_function, prev_file_path, prev_function_start_line)
+ location_ids.append(location_index)
+ samples.add(datum, location_ids)
+
+ sample_type_description = 'count'
+ sample_type = pprof_profile.sample_type.add()
+ sample_type.type = self._string_table.index_of(sample_type_description)
+ sample_type.unit = self._string_table.index_of('count')
+ sample_type_description = 'all_time'
+ sample_type = pprof_profile.sample_type.add()
+ sample_type.type = self._string_table.index_of(sample_type_description)
+ sample_type.unit = self._string_table.index_of('nanoseconds')
+ sample_type_description = 'op_time'
+ sample_type = pprof_profile.sample_type.add()
+ sample_type.type = self._string_table.index_of(sample_type_description)
+ sample_type.unit = self._string_table.index_of('nanoseconds')
+
+ pprof_profile.string_table.extend(self._string_table.string_table())
+ pprof_profile.sample.extend(samples.get_sample_protos())
+ pprof_profile.function.extend(self._functions.function_protos())
+ pprof_profile.location.extend(self._locations.location_protos())
+ return pprof_profile
+
+ def _get_profile_data_generator(self):
+ """Get function that generates `ProfileDatum` objects.
+
+ Returns:
+ A function that generates `ProfileDatum` objects.
+ """
+ node_to_traceback = defaultdict(list)
+ node_to_op_type = defaultdict(str)
+ for op in self._graph.get_operations():
+ node_to_traceback[op.name] = op.traceback_with_start_lines
+ node_to_op_type[op.name] = op.type
+
+ def profile_data_generator(device_step_stats):
+ for node_stats in device_step_stats.node_stats:
+ if node_stats.node_name == '_SOURCE' or node_stats.node_name == '_SINK':
+ continue
+ yield ProfileDatum(
+ node_stats,
+ node_to_op_type[node_stats.node_name],
+ node_to_traceback[node_stats.node_name])
+
+ return profile_data_generator
+
+
+def get_profiles(graph, run_metadata):
+ """Generate profiles in pprof format.
+
+ See https://github.com/google/pprof/blob/master/proto/profile.proto
+ for pprof proto format.
+
+ Args:
+ graph: A `Graph` object.
+ run_metadata: A `RunMetadata` proto.
+
+ Returns:
+ A dictionary mapping from device name to pprof proto for that device.
+ """
+ return PprofProfiler(graph, run_metadata).profile()
+
+
+def profile(graph, run_metadata, output_dir=None):
+ """Generate profiles in pprof format.
+
+ See https://github.com/google/pprof/blob/master/proto/profile.proto
+ for pprof proto format.
+
+ Args:
+ graph: A `Graph` object.
+ run_metadata: A `RunMetadata` proto.
+ output_dir: (string) Directory to output pprof profile to.
+ Profile files for each device will be stored in compressed
+ serialized proto format. If output_dir is None, profile protos
+ will be printed to stdout instead.
+
+ Returns:
+ List of output files created by this profile call.
+ (Note: this list will be empty if output_dir is None)
+ """
+ profiles = get_profiles(graph, run_metadata)
+ output_file_template = None
+ if output_dir:
+ if not os.path.isdir(output_dir):
+ os.makedirs(output_dir)
+ time_suffix = time.strftime('%Y%m%d%H%M%S')
+ output_file_template = os.path.join(
+ output_dir, '%s_' + time_suffix + '.pb.gz')
+
+ profile_files = []
+ for device, pprof_proto in profiles.items():
+ if output_file_template is None:
+ print('No output directory specified, printing to stdout instead.')
+ print(pprof_proto)
+ else:
+ device_name = str(device).strip('/').translate(
+ maketrans('/:', '__'))
+ profile_file = output_file_template % device_name
+ profile_files.append(profile_file)
+ with gzip.open(profile_file, 'w') as output_file:
+ print('Writing profile to %s...' % profile_file)
+ output_file.write(pprof_proto.SerializeToString())
+ return profile_files
diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py
new file mode 100644
index 0000000000..13d3fb41ac
--- /dev/null
+++ b/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py
@@ -0,0 +1,164 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for pprof_profiler."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+
+from pprof import profile_pb2
+from tensorflow.contrib.tfprof.python.tools.tfprof import pprof_profiler
+from tensorflow.core.framework import step_stats_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class PprofProfilerTest(test.TestCase):
+
+ def testDataEmpty(self):
+ output_dir = test.get_temp_dir()
+ run_metadata = config_pb2.RunMetadata()
+ graph = test.mock.MagicMock()
+ graph.get_operations.return_value = []
+
+ profiles = pprof_profiler.get_profiles(graph, run_metadata)
+ self.assertEquals(0, len(profiles))
+ profile_files = pprof_profiler.profile(
+ graph, run_metadata, output_dir)
+ self.assertEquals(0, len(profile_files))
+
+ def testRunMetadataEmpty(self):
+ output_dir = test.get_temp_dir()
+ run_metadata = config_pb2.RunMetadata()
+ graph = test.mock.MagicMock()
+ op1 = test.mock.MagicMock()
+ op1.name = 'Add/123'
+ op1.traceback = [('a/b/file1', 10, 'some_var')]
+ op1.type = 'add'
+ graph.get_operations.return_value = [op1]
+
+ profiles = pprof_profiler.get_profiles(graph, run_metadata)
+ self.assertEquals(0, len(profiles))
+ profile_files = pprof_profiler.profile(
+ graph, run_metadata, output_dir)
+ self.assertEquals(0, len(profile_files))
+
+ def testValidProfile(self):
+ output_dir = test.get_temp_dir()
+ run_metadata = config_pb2.RunMetadata()
+
+ node1 = step_stats_pb2.NodeExecStats(
+ node_name='Add/123',
+ op_start_rel_micros=3,
+ op_end_rel_micros=5,
+ all_end_rel_micros=4)
+
+ run_metadata = config_pb2.RunMetadata()
+ device1 = run_metadata.step_stats.dev_stats.add()
+ device1.device = 'deviceA'
+ device1.node_stats.extend([node1])
+
+ graph = test.mock.MagicMock()
+ op1 = test.mock.MagicMock()
+ op1.name = 'Add/123'
+ op1.traceback = [
+ ('a/b/file1', 10, 'apply_op', 'abc'), ('a/c/file2', 12, 'my_op', 'def')]
+ op1.type = 'add'
+ graph.get_operations.return_value = [op1]
+
+ expected_proto = """sample_type {
+ type: 5
+ unit: 5
+}
+sample_type {
+ type: 6
+ unit: 7
+}
+sample_type {
+ type: 8
+ unit: 7
+}
+sample {
+ value: 1
+ value: 4
+ value: 2
+ label {
+ key: 1
+ str: 2
+ }
+ label {
+ key: 3
+ str: 4
+ }
+}
+string_table: ""
+string_table: "node_name"
+string_table: "Add/123"
+string_table: "op_type"
+string_table: "add"
+string_table: "count"
+string_table: "all_time"
+string_table: "nanoseconds"
+string_table: "op_time"
+string_table: "Device 1 of 1: deviceA"
+comment: 9
+"""
+ # Test with protos
+ profiles = pprof_profiler.get_profiles(graph, run_metadata)
+ self.assertEquals(1, len(profiles))
+ self.assertTrue('deviceA' in profiles)
+ self.assertEquals(expected_proto, str(profiles['deviceA']))
+ # Test with files
+ profile_files = pprof_profiler.profile(
+ graph, run_metadata, output_dir)
+ self.assertEquals(1, len(profile_files))
+ with gzip.open(profile_files[0]) as profile_file:
+ profile_contents = profile_file.read()
+ profile = profile_pb2.Profile()
+ profile.ParseFromString(profile_contents)
+ self.assertEquals(expected_proto, str(profile))
+
+ def testProfileWithWhileLoop(self):
+ options = config_pb2.RunOptions()
+ options.trace_level = config_pb2.RunOptions.FULL_TRACE
+ run_metadata = config_pb2.RunMetadata()
+
+ num_iters = 5
+ with self.test_session() as sess:
+ i = constant_op.constant(0)
+ c = lambda i: math_ops.less(i, num_iters)
+ b = lambda i: math_ops.add(i, 1)
+ r = control_flow_ops.while_loop(c, b, [i])
+ sess.run(r, options=options, run_metadata=run_metadata)
+ profiles = pprof_profiler.get_profiles(sess.graph, run_metadata)
+ self.assertEquals(1, len(profiles))
+ profile = next(iter(profiles.values()))
+ add_samples = [] # Samples for the while/Add node
+ for sample in profile.sample:
+ if profile.string_table[sample.label[0].str] == 'while/Add':
+ add_samples.append(sample)
+ # Values for same nodes are aggregated.
+ self.assertEquals(1, len(add_samples))
+ # Value of "count" should be equal to number of iterations.
+ self.assertEquals(num_iters, add_samples[0].value[0])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index de33e5ed83..05df05de35 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -272,7 +272,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
self_.qpn = qp_->qp_num;
self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
union ibv_gid gid;
- CHECK(!ibv_query_gid(adapter_->context_, (uint8_t) 1, 0, &gid)) << "Query gid";
+ CHECK(!ibv_query_gid(adapter_->context_, (uint8_t)1, 0, &gid))
+ << "Query gid";
self_.snp = gid.global.subnet_prefix;
self_.iid = gid.global.interface_id;
}
@@ -479,7 +480,7 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.dest_qp_num = remoteAddr.qpn;
attr.rq_psn = remoteAddr.psn;
attr.max_dest_rd_atomic = 1;
- attr.min_rnr_timer = 12;
+ attr.min_rnr_timer = 12;
attr.ah_attr.is_global = 1;
attr.ah_attr.grh.dgid.global.subnet_prefix = remoteAddr.snp;
attr.ah_attr.grh.dgid.global.interface_id = remoteAddr.iid;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
index 0b0e8ec1de..ba206890ce 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
@@ -248,8 +248,8 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
tdata.size(), do_nothing);
slices[1] = ::grpc::Slice(s1, ::grpc::Slice::STEAL_REF);
- gpr_slice s2 = gpr_slice_new(const_cast<TensorBuffer*>(buf),
- 0, unref_tensorbuffer);
+ gpr_slice s2 =
+ gpr_slice_new(const_cast<TensorBuffer*>(buf), 0, unref_tensorbuffer);
slices[2] = ::grpc::Slice(s2, ::grpc::Slice::STEAL_REF);
num_slices += 2;
}
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 372092f42a..8e7209d0d4 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -136,6 +136,22 @@ cc_library(
)
cc_library(
+ name = "virtual_placer",
+ srcs = ["virtual_placer.cc"],
+ hdrs = ["virtual_placer.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":op_performance_data_cc",
+ ":utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_lite",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:devices",
+ "//tensorflow/core/grappler/clusters:cluster",
+ ],
+)
+
+cc_library(
name = "virtual_scheduler",
srcs = ["virtual_scheduler.cc"],
hdrs = ["virtual_scheduler.h"],
@@ -194,3 +210,24 @@ cc_test(
"//tensorflow/core:test_main",
],
)
+
+cc_library(
+ name = "analytical_cost_estimator",
+ srcs = ["analytical_cost_estimator.cc"],
+ hdrs = ["analytical_cost_estimator.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cost_estimator",
+ ":graph_properties",
+ ":op_level_cost_estimator",
+ ":op_performance_data_cc",
+ ":utils",
+ ":virtual_placer",
+ ":virtual_scheduler",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:grappler_item",
+ ],
+)
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
new file mode 100644
index 0000000000..29d55ca591
--- /dev/null
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
@@ -0,0 +1,128 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
+
+#include <limits>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
+#include "tensorflow/core/grappler/costs/utils.h"
+#include "tensorflow/core/grappler/costs/virtual_placer.h"
+#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace grappler {
+
+AnalyticalCostEstimator::AnalyticalCostEstimator(Cluster* cluster,
+ bool use_static_shapes)
+ : cluster_(cluster), use_static_shapes_(use_static_shapes) {}
+
+Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
+ item_ = item;
+ return Status::OK();
+}
+
+Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
+ CostGraphDef* cost_graph,
+ Costs* costs) const {
+ GrapplerItem item = item_;
+ item.graph = optimized_graph;
+ GraphProperties properties(item);
+ Status status;
+ if (use_static_shapes_) {
+ status = properties.InferStatically();
+ } else {
+ status = properties.InferDynamically(cluster_);
+ }
+
+ if (!status.ok()) {
+ costs->execution_time = Costs::Duration::max();
+ return status;
+ }
+
+ std::unordered_map<string, CostGraphDef::Node*> name_to_cost;
+ if (cost_graph) {
+ for (auto& node : *cost_graph->mutable_node()) {
+ name_to_cost[node.name()] = &node;
+ }
+ }
+ std::vector<string> inaccurate_nodes;
+ VirtualScheduler scheduler(optimized_graph, item_.fetch);
+ VirtualPlacer placer(cluster_);
+ Costs node_costs;
+ do {
+ const NodeDef* node = scheduler.GetCurrNode();
+ std::vector<OpInfo::TensorProperties> inputs =
+ properties.GetInputProperties(node->name());
+
+ OpInfo::DeviceProperties device = placer.get_device(*node);
+ OpInfo op_info;
+ op_info.set_op(node->op());
+ *op_info.mutable_attr() = node->attr();
+ for (auto& input : inputs) {
+ op_info.add_inputs()->Swap(&input);
+ }
+ op_info.mutable_device()->Swap(&device);
+
+ node_costs = node_estimator_.PredictCosts(op_info);
+ if (node_costs.inaccurate) {
+ inaccurate_nodes.push_back(node->name());
+ }
+ if (cost_graph) {
+ auto it = name_to_cost.find(node->name());
+ CostGraphDef::Node* cost_node;
+ if (it != name_to_cost.end()) {
+ cost_node = it->second;
+ } else {
+ cost_node = cost_graph->add_node();
+ cost_node->set_name(node->name());
+ }
+ string device_name = properties.GetDeviceName(node->name());
+ cost_node->set_device(device_name);
+ cost_node->set_compute_cost(
+ node_costs.execution_time.asMicroSeconds().count());
+ cost_node->set_compute_time(
+ node_costs.compute_time.asMicroSeconds().count());
+ cost_node->set_memory_time(
+ node_costs.memory_time.asMicroSeconds().count());
+ std::vector<OpInfo::TensorProperties> outputs =
+ properties.GetOutputProperties(node->name());
+ for (const auto& output : outputs) {
+ auto output_info = cost_node->add_output_info();
+ output_info->set_dtype(output.dtype());
+ auto shape = output_info->mutable_shape();
+ *shape = output.shape();
+ }
+ }
+ } while (scheduler.MarkCurrNodeExecuted(node_costs));
+
+ *costs = scheduler.Summary();
+ VLOG(1) << inaccurate_nodes.size() << " out of "
+ << optimized_graph.node_size()
+ << " nodes have inaccurate time estimation";
+ for (const auto& node : inaccurate_nodes) {
+ VLOG(2) << "Node with inaccurate time estimation: " << node;
+ }
+ return Status::OK();
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.h b/tensorflow/core/grappler/costs/analytical_cost_estimator.h
new file mode 100644
index 0000000000..03e7faa4ff
--- /dev/null
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.h
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_
+
+#include "tensorflow/core/grappler/costs/cost_estimator.h"
+#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+class CostGraphDef;
+class GraphDef;
+} // namespace tensorflow
+
+namespace tensorflow {
+namespace grappler {
+
+class Cluster;
+struct GrapplerItem;
+
+// Estimate the cost of running a Grappler item based on the theoretical
+// performance of the hardware that will run the model.
+class AnalyticalCostEstimator : public CostEstimator {
+ public:
+ // Does not take ownership of cluster.
+ explicit AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes);
+ ~AnalyticalCostEstimator() override {}
+
+ // Initalizes the estimator for the specified grappler item.
+ // This implementation always returns OK.
+ Status Initialize(const GrapplerItem& item) override;
+
+ // Predict the performance of each node of the optimized graph and annotate
+ // the CostGraphDef with the corresponding estimates. Also returns the
+ // expected latency for the whole graph.
+ Status PredictCosts(const GraphDef& optimized_graph, CostGraphDef* cost_graph,
+ Costs* overall_latency) const override;
+
+ private:
+ Cluster* cluster_; // Not owned.
+ GrapplerItem item_;
+ OpLevelCostEstimator node_estimator_;
+ bool use_static_shapes_;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index baed7a8899..8549bfa118 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -80,7 +80,8 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
const OpInfo::DeviceProperties local_cpu = GetLocalCPUInfo();
// Check if vector instructions are available, and refine performance
// prediction based on this.
- gflops = local_cpu.num_cores() * local_cpu.frequency();
+ // Frequencies are stored in MHz in the DeviceProperties.
+ gflops = local_cpu.num_cores() * local_cpu.frequency() * 1e-3;
if (bandwidth < 0) {
if (local_cpu.bandwidth() > 0) {
bandwidth = local_cpu.bandwidth() / 1e6;
@@ -105,7 +106,7 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
// Pascal.
cores_per_multiprocessor = 64;
}
- gflops = local_gpu.num_cores() * local_gpu.frequency() *
+ gflops = local_gpu.num_cores() * local_gpu.frequency() * 1e-3 *
cores_per_multiprocessor * kOpsPerMac;
if (bandwidth < 0) {
CHECK(local_gpu.bandwidth() > 0);
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index 0852cb4fd3..9447e56a7a 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -147,7 +147,7 @@ OpInfo::DeviceProperties GetLocalCPUInfo() {
// Combine cpu family and model into the model string.
device.set_model(
strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum()));
- device.set_frequency(port::NominalCPUFrequency() * 1e-9);
+ device.set_frequency(port::NominalCPUFrequency() * 1e-6);
device.set_num_cores(port::NumSchedulableCPUs());
device.set_l1_cache_size(Eigen::l1CacheSize());
device.set_l2_cache_size(Eigen::l2CacheSize());
@@ -175,7 +175,7 @@ OpInfo::DeviceProperties GetLocalGPUInfo(int gpu_id) {
if (error == cudaSuccess) {
device.set_vendor("NVidia");
device.set_model(properties.name);
- device.set_frequency(properties.clockRate / 1000);
+ device.set_frequency(properties.clockRate * 1e-3);
device.set_num_cores(properties.multiProcessorCount);
device.set_num_registers(properties.regsPerMultiprocessor);
// For compute capability less than 5, l1 cache size is configurable to
diff --git a/tensorflow/core/grappler/costs/virtual_placer.cc b/tensorflow/core/grappler/costs/virtual_placer.cc
new file mode 100644
index 0000000000..adc640aaa4
--- /dev/null
+++ b/tensorflow/core/grappler/costs/virtual_placer.cc
@@ -0,0 +1,57 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/costs/virtual_placer.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/costs/utils.h"
+#include "tensorflow/core/grappler/devices.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+VirtualPlacer::VirtualPlacer(Cluster* cluster) : has_gpu_(false) {
+ devices_["CPU"] = GetLocalCPUInfo();
+ if (GetNumAvailableGPUs() > 0) {
+ has_gpu_ = true;
+ devices_["GPU"] = GetLocalGPUInfo(0);
+ }
+ unknown_device_.set_type("UNKNOWN");
+}
+
+const OpInfo::DeviceProperties& VirtualPlacer::get_device(
+ const NodeDef& node) const {
+ string device_type;
+ DeviceNameUtils::ParsedName parsed;
+ if (!node.device().empty() &&
+ DeviceNameUtils::ParseFullName(node.device(), &parsed)) {
+ device_type = parsed.type;
+ } else {
+ if (has_gpu_) {
+ device_type = "GPU";
+ } else {
+ device_type = "CPU";
+ }
+ }
+ auto it = devices_.find(device_type);
+ if (it == devices_.end()) {
+ return unknown_device_;
+ }
+ return it->second;
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/virtual_placer.h b/tensorflow/core/grappler/costs/virtual_placer.h
new file mode 100644
index 0000000000..812e94bf59
--- /dev/null
+++ b/tensorflow/core/grappler/costs/virtual_placer.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_
+#define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_
+
+#include <unordered_map>
+#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+class NodeDef;
+
+namespace grappler {
+class Cluster;
+
+// The virtual placer emulates the behavior of the TF placer.
+class VirtualPlacer {
+ public:
+ VirtualPlacer(Cluster* cluster);
+
+ const OpInfo::DeviceProperties& get_device(const NodeDef& node) const;
+
+ private:
+ std::unordered_map<string, OpInfo::DeviceProperties> devices_;
+ bool has_gpu_;
+ OpInfo::DeviceProperties unknown_device_;
+};
+
+} // namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_
diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc
index 746fe63e2a..6a748d3462 100644
--- a/tensorflow/core/kernels/crop_and_resize_op.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op.cc
@@ -19,6 +19,9 @@ limitations under the License.
#include "tensorflow/core/kernels/crop_and_resize_op.h"
+#include <functional>
+#include <string>
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -26,10 +29,13 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@@ -37,41 +43,67 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+using Callback = std::function<void()>;
+
+namespace {
-static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
- const Tensor& boxes,
- const Tensor& box_ind,
- int* num_boxes) {
- if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) {
+static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
+ const Tensor& box_index,
+ int* num_boxes) {
+ if (boxes.NumElements() == 0 && box_index.NumElements() == 0) {
*num_boxes = 0;
- return;
+ return Status::OK();
}
// The shape of 'boxes' is [num_boxes, 4].
- OP_REQUIRES(context, boxes.dims() == 2,
- errors::InvalidArgument("boxes must be 2-D",
- boxes.shape().DebugString()));
+ if (boxes.dims() != 2) {
+ return errors::InvalidArgument("boxes must be 2-D",
+ boxes.shape().DebugString());
+ }
*num_boxes = boxes.dim_size(0);
- OP_REQUIRES(context, boxes.dim_size(1) == 4,
- errors::InvalidArgument("boxes must have 4 columns"));
-
- // The shape of 'box_ind' is [num_boxes].
- OP_REQUIRES(context, box_ind.dims() == 1,
- errors::InvalidArgument("box_ind must be 1-D",
- box_ind.shape().DebugString()));
- OP_REQUIRES(context, box_ind.dim_size(0) == *num_boxes,
- errors::InvalidArgument("box_ind has incompatible shape"));
+ if (boxes.dim_size(1) != 4) {
+ return errors::InvalidArgument("boxes must have 4 columns");
+ }
+ // The shape of 'box_index' is [num_boxes].
+ if (box_index.dims() != 1) {
+ return errors::InvalidArgument("box_index must be 1-D",
+ box_index.shape().DebugString());
+ }
+ if (box_index.dim_size(0) != *num_boxes) {
+ return errors::InvalidArgument("box_index has incompatible shape");
+ }
+ return Status::OK();
}
-// Verifies that all values in box_ind are in [0, batch).
+// Conditionally calls the compute callback if all values in box_index are in
+// [0, batch_size) then calls done.
template <typename Device>
-inline void CheckValidBoxInd(
- OpKernelContext* context,
- typename TTypes<int32, 1>::ConstTensor box_ind_data, int batch);
+inline void RunIfBoxIndexIsValid(
+ OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
+ int batch_size, Callback compute, Callback done);
+
+// Specialization of CheckValidBoxIndex for a CPUDevice.
+template <>
+inline void RunIfBoxIndexIsValid<CPUDevice>(
+ OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
+ int batch_size, Callback compute, Callback done) {
+ const int num_boxes = box_index.dimension(0);
+ for (int b = 0; b < num_boxes; ++b) {
+ OP_REQUIRES_ASYNC(
+ context, FastBoundsCheck(box_index(b), batch_size),
+ errors::OutOfRange("box_index has values outside [0, batch_size)"),
+ done);
+ }
+ compute();
+ done();
+}
+
+} // namespace
template <typename Device, typename T>
-class CropAndResizeOp : public OpKernel {
+class CropAndResizeOp : public AsyncOpKernel {
public:
- explicit CropAndResizeOp(OpKernelConstruction* context) : OpKernel(context) {
+ explicit CropAndResizeOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {
string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear",
@@ -80,69 +112,77 @@ class CropAndResizeOp : public OpKernel {
&extrapolation_value_));
}
- void Compute(OpKernelContext* context) override {
- // The shape of 'image' is [batch, image_height, image_width, channels].
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ // The shape of 'image' is [batch_size, image_height, image_width,
+ // channels].
const Tensor& image = context->input(0);
- OP_REQUIRES(context, image.dims() == 4,
- errors::InvalidArgument("input image must be 4-D",
- image.shape().DebugString()));
-
- const int batch = image.dim_size(0);
- const int image_height = image.dim_size(1);
- const int image_width = image.dim_size(2);
- const int depth = image.dim_size(3);
- OP_REQUIRES(context, image_height > 0 && image_width > 0,
- errors::InvalidArgument("image dimensions must be positive"));
-
// The shape of 'boxes' is [num_boxes, 4].
const Tensor& boxes = context->input(1);
-
- // The shape of 'box_ind' is [num_boxes].
- const Tensor& box_ind = context->input(2);
-
- int num_boxes = 0;
- ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
-
+ // The shape of 'box_index' is [num_boxes].
+ const Tensor& box_index = context->input(2);
// The shape of 'crop_size' is [2].
const Tensor& crop_size = context->input(3);
- OP_REQUIRES(context, crop_size.dims() == 1,
- errors::InvalidArgument("crop_size must be 1-D",
- crop_size.shape().DebugString()));
- OP_REQUIRES(context, crop_size.dim_size(0) == 2,
- errors::InvalidArgument("crop_size must have two elements",
- crop_size.shape().DebugString()));
-
+ // Validate inputs dimensions.
+ OP_REQUIRES_ASYNC(context, image.dims() == 4,
+ errors::InvalidArgument("input image must be 4-D",
+ image.shape().DebugString()),
+ done);
+ const int batch_size = image.dim_size(0);
+ const int image_height = image.dim_size(1);
+ const int image_width = image.dim_size(2);
+ const int depth = image.dim_size(3);
+ OP_REQUIRES_ASYNC(
+ context, image_height > 0 && image_width > 0,
+ errors::InvalidArgument("image dimensions must be positive"), done);
+ int num_boxes = 0;
+ OP_REQUIRES_OK_ASYNC(
+ context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
+
+ OP_REQUIRES_ASYNC(context, crop_size.dims() == 1,
+ errors::InvalidArgument("crop_size must be 1-D",
+ crop_size.shape().DebugString()),
+ done);
+ OP_REQUIRES_ASYNC(
+ context, crop_size.dim_size(0) == 2,
+ errors::InvalidArgument("crop_size must have two elements",
+ crop_size.shape().DebugString()),
+ done);
+
+ // Copy and validate crop sizes.
auto crop_size_vec = crop_size.vec<int32>();
const int crop_height = internal::SubtleMustCopy(crop_size_vec(0));
const int crop_width = internal::SubtleMustCopy(crop_size_vec(1));
- OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
- errors::InvalidArgument("crop dimensions must be positive"));
+ OP_REQUIRES_ASYNC(
+ context, crop_height > 0 && crop_width > 0,
+ errors::InvalidArgument("crop dimensions must be positive"), done);
// Allocate output tensor.
Tensor* output = nullptr;
- OP_REQUIRES_OK(
+ OP_REQUIRES_OK_ASYNC(
context,
context->allocate_output(
0, TensorShape({num_boxes, crop_height, crop_width, depth}),
- &output));
-
- typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>();
- typename TTypes<float, 2>::ConstTensor boxes_data =
- boxes.tensor<float, 2>();
- typename TTypes<int32, 1>::ConstTensor box_ind_data =
- box_ind.tensor<int32, 1>();
- typename TTypes<float, 4>::Tensor crops_data = output->tensor<float, 4>();
-
- CheckValidBoxInd<Device>(context, box_ind_data, batch);
-
- bool status = functor::CropAndResize<Device, T>()(
- context->eigen_device<Device>(), image_data, boxes_data, box_ind_data,
- extrapolation_value_, crops_data);
- if (!status) {
- context->SetStatus(
- errors::Internal("Failed launch CropAndResizeKernel."));
- }
+ &output),
+ done);
+
+ auto compute_callback = [this, context, output]() {
+ const Tensor& image = context->input(0);
+ const Tensor& boxes = context->input(1);
+ const Tensor& box_index = context->input(2);
+ const bool status = functor::CropAndResize<Device, T>()(
+ context->eigen_device<Device>(), image.tensor<T, 4>(),
+ boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
+ extrapolation_value_, output->tensor<float, 4>());
+ if (!status) {
+ context->SetStatus(
+ errors::Internal("Failed launch CropAndResizeKernel."));
+ }
+ };
+
+ RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
+ batch_size, std::move(compute_callback),
+ std::move(done));
}
private:
@@ -155,10 +195,10 @@ template <typename T>
struct CropAndResize<CPUDevice, T> {
bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
- typename TTypes<int32, 1>::ConstTensor box_ind,
+ typename TTypes<int32, 1>::ConstTensor box_index,
float extrapolation_value,
typename TTypes<float, 4>::Tensor crops) {
- const int batch = image.dimension(0);
+ const int batch_size = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@@ -173,8 +213,8 @@ struct CropAndResize<CPUDevice, T> {
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
- const int32 b_in = box_ind(b);
- if (b_in < 0 || b_in >= batch) {
+ const int32 b_in = box_index(b);
+ if (!FastBoundsCheck(b_in, batch_size)) {
continue;
}
@@ -235,89 +275,94 @@ struct CropAndResize<CPUDevice, T> {
return true;
}
};
+
} // namespace functor
template <typename Device, typename T>
-class CropAndResizeGradImageOp : public OpKernel {
+class CropAndResizeGradImageOp : public AsyncOpKernel {
public:
explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
- : OpKernel(context) {
+ : AsyncOpKernel(context) {
string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear",
errors::InvalidArgument("method must be 'bilinear'", method));
}
- void Compute(OpKernelContext* context) override {
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
-
- OP_REQUIRES(context, grads.dims() == 4,
- errors::InvalidArgument("grads image must be 4-D",
- grads.shape().DebugString()));
- const int crop_height = grads.dim_size(1);
- const int crop_width = grads.dim_size(2);
- OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
- errors::InvalidArgument("grads dimensions must be positive"));
-
// The shape of 'boxes' is [num_boxes, 4].
const Tensor& boxes = context->input(1);
-
- // The shape of 'box_ind' is [num_boxes].
- const Tensor& box_ind = context->input(2);
-
- int num_boxes = 0;
- ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
-
- OP_REQUIRES(
- context, grads.dim_size(0) == num_boxes,
- errors::InvalidArgument("boxes and grads have incompatible shape"));
-
+ // The shape of 'box_index' is [num_boxes].
+ const Tensor& box_index = context->input(2);
// The shape of 'image_size' is [4].
const Tensor& image_size = context->input(3);
- OP_REQUIRES(context, image_size.dims() == 1,
- errors::InvalidArgument("image_size must be 1-D",
- image_size.shape().DebugString()));
- OP_REQUIRES(context, image_size.dim_size(0) == 4,
- errors::InvalidArgument("image_size must have 4 elements",
- image_size.shape().DebugString()));
+ // Validate input shapes.
+ OP_REQUIRES_ASYNC(context, grads.dims() == 4,
+ errors::InvalidArgument("grads image must be 4-D",
+ grads.shape().DebugString()),
+ done);
+ const int crop_height = grads.dim_size(1);
+ const int crop_width = grads.dim_size(2);
+ OP_REQUIRES_ASYNC(
+ context, crop_height > 0 && crop_width > 0,
+ errors::InvalidArgument("grads dimensions must be positive"), done);
+ int num_boxes = 0;
+ OP_REQUIRES_OK_ASYNC(
+ context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
+ OP_REQUIRES_ASYNC(
+ context, grads.dim_size(0) == num_boxes,
+ errors::InvalidArgument("boxes and grads have incompatible shape"),
+ done);
+
+ OP_REQUIRES_ASYNC(context, image_size.dims() == 1,
+ errors::InvalidArgument("image_size must be 1-D",
+ image_size.shape().DebugString()),
+ done);
+ OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4,
+ errors::InvalidArgument("image_size must have 4 elements",
+ image_size.shape().DebugString()),
+ done);
auto image_size_vec = image_size.vec<int32>();
- const int batch = internal::SubtleMustCopy(image_size_vec(0));
+ const int batch_size = internal::SubtleMustCopy(image_size_vec(0));
const int image_height = internal::SubtleMustCopy(image_size_vec(1));
const int image_width = internal::SubtleMustCopy(image_size_vec(2));
const int depth = internal::SubtleMustCopy(image_size_vec(3));
-
- OP_REQUIRES(context, image_height > 0 && image_width > 0,
- errors::InvalidArgument("image dimensions must be positive"));
- OP_REQUIRES(
+ OP_REQUIRES_ASYNC(
+ context, image_height > 0 && image_width > 0,
+ errors::InvalidArgument("image dimensions must be positive"), done);
+ OP_REQUIRES_ASYNC(
context, grads.dim_size(3) == depth,
- errors::InvalidArgument("image_size and grads are incompatible"));
+ errors::InvalidArgument("image_size and grads are incompatible"), done);
// Allocate output tensor.
Tensor* output = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output(
- 0, TensorShape({batch, image_height, image_width, depth}),
- &output));
-
- typename TTypes<float, 4>::ConstTensor grads_data =
- grads.tensor<float, 4>();
- typename TTypes<float, 2>::ConstTensor boxes_data =
- boxes.tensor<float, 2>();
- typename TTypes<int32, 1>::ConstTensor box_ind_data =
- box_ind.tensor<int32, 1>();
- typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>();
-
- CheckValidBoxInd<Device>(context, box_ind_data, batch);
-
- bool status = functor::CropAndResizeBackpropImage<Device, T>()(
- context->eigen_device<Device>(), grads_data, boxes_data, box_ind_data,
- output_data);
- if (!status) {
- context->SetStatus(
- errors::Internal("Failed launch CropAndResizeBackpropImageKernel."));
- }
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_output(
+ 0, TensorShape({batch_size, image_height, image_width, depth}),
+ &output),
+ done);
+
+ auto compute_callback = [context, output]() {
+ const Tensor& grads = context->input(0);
+ const Tensor& boxes = context->input(1);
+ const Tensor& box_index = context->input(2);
+ const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
+ context->eigen_device<Device>(), grads.tensor<float, 4>(),
+ boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
+ output->tensor<T, 4>());
+ if (!status) {
+ context->SetStatus(errors::Internal(
+ "Failed launch CropAndResizeBackpropImage kernel."));
+ }
+ };
+
+ RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
+ batch_size, std::move(compute_callback),
+ std::move(done));
}
};
@@ -328,9 +373,9 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
bool operator()(const CPUDevice& d,
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<float, 2>::ConstTensor boxes,
- typename TTypes<int32, 1>::ConstTensor box_ind,
+ typename TTypes<int32, 1>::ConstTensor box_index,
typename TTypes<T, 4>::Tensor grads_image) {
- const int batch = grads_image.dimension(0);
+ const int batch_size = grads_image.dimension(0);
const int image_height = grads_image.dimension(1);
const int image_width = grads_image.dimension(2);
@@ -347,8 +392,8 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
- const int32 b_in = box_ind(b);
- if (b_in < 0 || b_in >= batch) {
+ const int32 b_in = box_index(b);
+ if (!FastBoundsCheck(b_in, batch_size)) {
continue;
}
@@ -399,83 +444,90 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
return true;
}
};
+
} // namespace functor
template <typename Device, typename T>
-class CropAndResizeGradBoxesOp : public OpKernel {
+class CropAndResizeGradBoxesOp : public AsyncOpKernel {
public:
explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context)
- : OpKernel(context) {
+ : AsyncOpKernel(context) {
string method;
OP_REQUIRES_OK(context, context->GetAttr("method", &method));
OP_REQUIRES(context, method == "bilinear",
errors::InvalidArgument("method must be 'bilinear'", method));
}
- void Compute(OpKernelContext* context) override {
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
// The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
const Tensor& grads = context->input(0);
+ // The shape of 'boxes' is [num_boxes, 4].
+ const Tensor& boxes = context->input(2);
+ // The shape of 'box_index' is [num_boxes].
+ const Tensor& box_index = context->input(3);
+ // The shape of 'image' is [batch_size, image_height, image_width, depth].
+ const Tensor& image = context->input(1);
- OP_REQUIRES(context, grads.dims() == 4,
- errors::InvalidArgument("grads image must be 4-D",
- grads.shape().DebugString()));
-
+ // Validate input shapes.
+ OP_REQUIRES_ASYNC(context, grads.dims() == 4,
+ errors::InvalidArgument("grads image must be 4-D",
+ grads.shape().DebugString()),
+ done);
const int crop_height = grads.dim_size(1);
const int crop_width = grads.dim_size(2);
const int depth = grads.dim_size(3);
- OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
- errors::InvalidArgument("grads dimensions must be positive"));
-
- // The shape of 'image' is [batch, image_height, image_width, depth].
- const Tensor& image = context->input(1);
- OP_REQUIRES(context, image.dims() == 4,
- errors::InvalidArgument("input image must be 4-D",
- image.shape().DebugString()));
-
- const int batch = image.dim_size(0);
+ OP_REQUIRES_ASYNC(
+ context, crop_height > 0 && crop_width > 0,
+ errors::InvalidArgument("grads dimensions must be positive"), done);
+
+ OP_REQUIRES_ASYNC(context, image.dims() == 4,
+ errors::InvalidArgument("input image must be 4-D",
+ image.shape().DebugString()),
+ done);
+ const int batch_size = image.dim_size(0);
const int image_height = image.dim_size(1);
const int image_width = image.dim_size(2);
- OP_REQUIRES(context, image_height > 0 && image_width > 0,
- errors::InvalidArgument("image dimensions must be positive"));
- OP_REQUIRES(context, image.dim_size(3) == depth,
- errors::InvalidArgument("image, grads depth differ"));
-
- // The shape of 'boxes' is [num_boxes, 4].
- const Tensor& boxes = context->input(2);
-
- // The shape of 'box_ind' is [num_boxes].
- const Tensor& box_ind = context->input(3);
+ OP_REQUIRES_ASYNC(
+ context, image_height > 0 && image_width > 0,
+ errors::InvalidArgument("image dimensions must be positive"), done);
+ OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth,
+ errors::InvalidArgument("image, grads depth differ"),
+ done);
int num_boxes = 0;
- ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
+ OP_REQUIRES_OK_ASYNC(
+ context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
- OP_REQUIRES(
+ OP_REQUIRES_ASYNC(
context, grads.dim_size(0) == num_boxes,
- errors::InvalidArgument("boxes and grads have incompatible shape"));
+ errors::InvalidArgument("boxes and grads have incompatible shape"),
+ done);
// Allocate output tensor.
Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- 0, TensorShape({num_boxes, 4}), &output));
-
- typename TTypes<float, 4>::ConstTensor grads_data =
- grads.tensor<float, 4>();
- typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>();
- typename TTypes<float, 2>::ConstTensor boxes_data =
- boxes.tensor<float, 2>();
- typename TTypes<int32, 1>::ConstTensor box_ind_data =
- box_ind.tensor<int32, 1>();
- typename TTypes<float, 2>::Tensor output_data = output->tensor<float, 2>();
-
- CheckValidBoxInd<Device>(context, box_ind_data, batch);
-
- bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
- context->eigen_device<Device>(), grads_data, image_data, boxes_data,
- box_ind_data, output_data);
- if (!status) {
- context->SetStatus(
- errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel."));
- }
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_output(0, TensorShape({num_boxes, 4}), &output),
+ done);
+
+ auto compute_callback = [context, output]() {
+ const Tensor& grads = context->input(0);
+ const Tensor& image = context->input(1);
+ const Tensor& boxes = context->input(2);
+ const Tensor& box_index = context->input(3);
+ const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
+ context->eigen_device<Device>(), grads.tensor<float, 4>(),
+ image.tensor<T, 4>(), boxes.tensor<float, 2>(),
+ box_index.tensor<int32, 1>(), output->tensor<float, 2>());
+ if (!status) {
+ context->SetStatus(errors::Internal(
+ "Failed launch CropAndResizeBackpropBoxes kernel."));
+ }
+ };
+
+ RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
+ batch_size, std::move(compute_callback),
+ std::move(done));
}
};
@@ -487,9 +539,9 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
typename TTypes<float, 4>::ConstTensor grads,
typename TTypes<T, 4>::ConstTensor image,
typename TTypes<float, 2>::ConstTensor boxes,
- typename TTypes<int32, 1>::ConstTensor box_ind,
+ typename TTypes<int32, 1>::ConstTensor box_index,
typename TTypes<float, 2>::Tensor grads_boxes) {
- const int batch = image.dimension(0);
+ const int batch_size = image.dimension(0);
const int image_height = image.dimension(1);
const int image_width = image.dimension(2);
@@ -506,8 +558,8 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
const float y2 = boxes(b, 2);
const float x2 = boxes(b, 3);
- const int32 b_in = box_ind(b);
- if (b_in < 0 || b_in >= batch) {
+ const int32 b_in = box_index(b);
+ if (!FastBoundsCheck(b_in, batch_size)) {
continue;
}
@@ -589,30 +641,19 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
return true;
}
};
-} // namespace functor
-// Specialization of CheckValidBoxInd for a CPUDevice.
-template <>
-inline void CheckValidBoxInd<CPUDevice>(
- OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
- int batch) {
- const int num_boxes = box_ind.dimension(0);
- for (int b = 0; b < num_boxes; ++b) {
- OP_REQUIRES(context, box_ind(b) >= 0 && box_ind(b) < batch,
- errors::OutOfRange("box_ind has values outside [0, batch)"));
- }
-}
+} // namespace functor
-#define REGISTER_KERNEL(T) \
- REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .HostMemory("crop_size"), \
- CropAndResizeOp<CPUDevice, T>); \
- \
- REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T"), \
+#define REGISTER_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("crop_size"), \
+ CropAndResizeOp<CPUDevice, T>); \
+ \
+ REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
CropAndResizeGradBoxesOp<CPUDevice, T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
@@ -634,50 +675,86 @@ TF_CALL_double(REGISTER_KERNEL);
#if GOOGLE_CUDA
-// Forward declaration of the CheckValidBoxIndHelper specialization for GPU.
+// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU.
namespace functor {
template <>
-void CheckValidBoxIndHelper<GPUDevice>::operator()(
- const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_ind,
- int batch, typename TTypes<bool, 0>::Tensor isvalid);
-extern template struct CheckValidBoxIndHelper<GPUDevice>;
+void CheckValidBoxIndexHelper<GPUDevice>::operator()(
+ const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index,
+ int batch_size, typename TTypes<bool, 0>::Tensor isvalid);
+extern template struct CheckValidBoxIndexHelper<GPUDevice>;
} // namespace functor
-// Specialization of CheckValidBoxInd for a GPUDevice.
+namespace {
+
+// Specialization of CheckValidBoxIndex for a GPUDevice.
template <>
-inline void CheckValidBoxInd<GPUDevice>(
- OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
- int batch) {
- const int num_boxes = box_ind.dimension(0);
+inline void RunIfBoxIndexIsValid<GPUDevice>(
+ OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
+ int batch_size, Callback compute, Callback done) {
+ const int num_boxes = box_index.dimension(0);
if (num_boxes == 0) {
+ compute();
+ done();
return;
}
- Tensor isvalid_tensor;
- OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<bool>::value,
- TensorShape({}), &isvalid_tensor));
- typename TTypes<bool, 0>::Tensor isvalid = isvalid_tensor.tensor<bool, 0>();
+ Tensor isvalid_dev_tensor;
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
+ &isvalid_dev_tensor),
+ done);
+ typename TTypes<bool, 0>::Tensor isvalid_dev =
+ isvalid_dev_tensor.tensor<bool, 0>();
- functor::CheckValidBoxIndHelper<GPUDevice>()(
- context->eigen_device<GPUDevice>(), box_ind, batch, isvalid);
+ // Run the actual box check on the device.
+ functor::CheckValidBoxIndexHelper<GPUDevice>()(
+ context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev);
+ // Copy the result back to the host.
auto* stream = context->op_device_context()->stream();
- OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
-
- bool isvalid_host = false;
- perftools::gputools::DeviceMemoryBase isvalid_gpu(isvalid.data(),
- sizeof(bool));
- stream->ThenMemcpy(&isvalid_host, isvalid_gpu, sizeof(bool));
- stream->BlockHostUntilDone();
-
- OP_REQUIRES(context, stream->ok(),
- errors::Internal("cudaMemcpy from device to host failed"));
-
- OP_REQUIRES(context, isvalid_host,
- errors::OutOfRange("box_ind has values outside [0, batch)"));
+ OP_REQUIRES_ASYNC(context, stream,
+ errors::Internal("No GPU stream available."), done);
+ Tensor isvalid_host_tensor;
+ // Use pinned host memory on the host to avoid unnecessary
+ // synchronization.
+ AllocatorAttributes alloc_attr;
+ alloc_attr.set_on_host(true);
+ alloc_attr.set_gpu_compatible(true);
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
+ &isvalid_host_tensor, alloc_attr),
+ done);
+ perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(),
+ sizeof(bool));
+ const bool status =
+ stream
+ ->ThenMemcpy(
+ isvalid_host_tensor.scalar<bool>().data() /* destination */,
+ wrapped /* source */, sizeof(bool))
+ .ok();
+ OP_REQUIRES_ASYNC(
+ context, status,
+ errors::Internal("Failed to launch copy of isvalid from device to host."),
+ done);
+
+ auto wrapped_callback = [context, isvalid_host_tensor, compute, done]() {
+ const bool isvalid = isvalid_host_tensor.scalar<bool>()();
+ OP_REQUIRES_ASYNC(
+ context, isvalid,
+ errors::OutOfRange("box_index has values outside [0, batch_size)"),
+ done);
+ compute();
+ done();
+ };
+
+ context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
+ stream, wrapped_callback);
}
+} // namespace
+
#define REGISTER_KERNEL(T) \
REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
.Device(DEVICE_GPU) \
diff --git a/tensorflow/core/kernels/crop_and_resize_op.h b/tensorflow/core/kernels/crop_and_resize_op.h
index 22df1bdd56..460dbad22b 100644
--- a/tensorflow/core/kernels/crop_and_resize_op.h
+++ b/tensorflow/core/kernels/crop_and_resize_op.h
@@ -53,12 +53,12 @@ struct CropAndResizeBackpropBoxes {
};
template <typename Device>
-struct CheckValidBoxIndHelper {
- // Checks if all values in box_ind are in [0, batch).
+struct CheckValidBoxIndexHelper {
+ // Checks if all values in box_index are in [0, batch).
void operator()(const Device& d,
- typename TTypes<int32, 1>::ConstTensor box_ind, int batch,
+ typename TTypes<int32, 1>::ConstTensor box_index, int batch,
typename TTypes<bool, 0>::Tensor isvalid) {
- isvalid.device(d) = ((box_ind >= 0) && (box_ind < batch)).all();
+ isvalid.device(d) = ((box_index >= 0) && (box_index < batch)).all();
}
};
diff --git a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
index 254475db46..c1235fda89 100644
--- a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
@@ -440,7 +440,7 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS
-template struct CheckValidBoxIndHelper<GPUDevice>;
+template struct CheckValidBoxIndexHelper<GPUDevice>;
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/crop_and_resize_op_test.cc b/tensorflow/core/kernels/crop_and_resize_op_test.cc
index 3a7f180598..d6139dae96 100644
--- a/tensorflow/core/kernels/crop_and_resize_op_test.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op_test.cc
@@ -251,7 +251,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
Status s = RunOpKernel();
ASSERT_FALSE(s.ok());
EXPECT_TRUE(
- StringPiece(s.ToString()).contains("box_ind has incompatible shape"))
+ StringPiece(s.ToString()).contains("box_index has incompatible shape"))
<< s;
}
@@ -264,8 +264,10 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
Status s = RunOpKernel();
ASSERT_FALSE(s.ok());
EXPECT_TRUE(StringPiece(s.ToString())
- .contains("box_ind has values outside [0, batch)"))
+ .contains("box_index has values outside [0, batch_size)"))
<< s;
}
+// TODO(zhengxq, rmlarsen): Add a benchmark.
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_atan2.cc b/tensorflow/core/kernels/cwise_op_atan2.cc
index 5232737857..68f67c444e 100644
--- a/tensorflow/core/kernels/cwise_op_atan2.cc
+++ b/tensorflow/core/kernels/cwise_op_atan2.cc
@@ -20,4 +20,4 @@ REGISTER2(BinaryOp, CPU, "Atan2", functor::atan2, float, double);
#if GOOGLE_CUDA
REGISTER2(BinaryOp, GPU, "Atan2", functor::atan2, float, double);
#endif
-} // namespace tensorflow \ No newline at end of file
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_atan2.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_atan2.cu.cc
index 0f327eaf6c..137e14ef84 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_atan2.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_atan2.cu.cc
@@ -23,4 +23,4 @@ DEFINE_BINARY2(atan2, float, double);
} // namespace functor
} // namespace tensorflow
-#endif // GOOGLE_CUDA \ No newline at end of file
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/linalg_ops_common.cc b/tensorflow/core/kernels/linalg_ops_common.cc
index 2a27476b09..155d9d1084 100644
--- a/tensorflow/core/kernels/linalg_ops_common.cc
+++ b/tensorflow/core/kernels/linalg_ops_common.cc
@@ -155,7 +155,8 @@ void LinearAlgebraOp<Scalar>::AnalyzeInputs(OpKernelContext* context,
const int col_dimension = input_rank - 1;
const int64 num_rows = in.dim_size(row_dimension);
const int64 num_cols = in.dim_size(col_dimension);
- input_matrix_shapes->emplace_back(std::initializer_list<int64>({num_rows, num_cols}));
+ input_matrix_shapes->emplace_back(
+ std::initializer_list<int64>({num_rows, num_cols}));
inputs->emplace_back(&in);
}
// Have the derived class validate that the inputs are as expected.
@@ -233,8 +234,7 @@ void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
matrix_inputs.emplace_back(
inputs[i]->flat<Scalar>().data() +
matrix_index * input_matrix_shapes[i].num_elements(),
- input_matrix_shapes[i].dim_size(0),
- input_matrix_shapes[i].dim_size(1));
+ input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1));
}
MatrixMaps matrix_outputs;
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 8856a0faaf..bc40715acc 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -1717,6 +1717,31 @@ op {
}
}
op {
+ name: "Atan2"
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "AudioSpectrogram"
input_arg {
name: "input"
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index b452ab7ef5..dd0a339609 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -1905,6 +1905,33 @@ op {
summary: "Computes atan of x element-wise."
}
op {
+ name: "Atan2"
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ summary: "Computes arctangent of `y/x` element-wise, respecting signs of the arguments."
+ description: "This is the angle \\( \\theta \\in [-\\pi, \\pi] \\) such that\n\\[ x = r \\cos(\\theta) \\]\nand\n\\[ y = r \\sin(\\theta) \\]\nwhere \\(r = \\sqrt(x^2 + y^2) \\)."
+}
+op {
name: "AudioSpectrogram"
input_arg {
name: "input"
diff --git a/tensorflow/docs_src/api_guides/python/contrib.distributions.bijector.md b/tensorflow/docs_src/api_guides/python/contrib.distributions.bijector.md
deleted file mode 100644
index 16a47bfd8b..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.distributions.bijector.md
+++ /dev/null
@@ -1,33 +0,0 @@
-# Random variable transformations (contrib)
-[TOC]
-
-Bijector Ops.
-
-An API for invertible, differentiable transformations of random variables.
-
-## Background
-
-Differentiable, bijective transformations of continuous random variables alter
-the calculations made in the cumulative/probability distribution functions and
-sample function. This module provides a standard interface for making these
-manipulations.
-
-For more details and examples, see the `Bijector` docstring.
-
-To apply a `Bijector`, use `distributions.TransformedDistribution`.
-
-## Bijectors
-
-* @{tf.contrib.distributions.bijector.Affine}
-* @{tf.contrib.distributions.bijector.AffineLinearOperator}
-* @{tf.contrib.distributions.bijector.Bijector}
-* @{tf.contrib.distributions.bijector.Chain}
-* @{tf.contrib.distributions.bijector.CholeskyOuterProduct}
-* @{tf.contrib.distributions.bijector.Exp}
-* @{tf.contrib.distributions.bijector.Identity}
-* @{tf.contrib.distributions.bijector.Inline}
-* @{tf.contrib.distributions.bijector.Invert}
-* @{tf.contrib.distributions.bijector.PowerTransform}
-* @{tf.contrib.distributions.bijector.SigmoidCentered}
-* @{tf.contrib.distributions.bijector.SoftmaxCentered}
-* @{tf.contrib.distributions.bijector.Softplus}
diff --git a/tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md b/tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md
new file mode 100644
index 0000000000..0ce187b329
--- /dev/null
+++ b/tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md
@@ -0,0 +1,33 @@
+# Random variable transformations (contrib)
+[TOC]
+
+Bijector Ops.
+
+An API for invertible, differentiable transformations of random variables.
+
+## Background
+
+Differentiable, bijective transformations of continuous random variables alter
+the calculations made in the cumulative/probability distribution functions and
+sample function. This module provides a standard interface for making these
+manipulations.
+
+For more details and examples, see the `Bijector` docstring.
+
+To apply a `Bijector`, use `distributions.TransformedDistribution`.
+
+## Bijectors
+
+* @{tf.contrib.distributions.bijectors.Affine}
+* @{tf.contrib.distributions.bijectors.AffineLinearOperator}
+* @{tf.contrib.distributions.bijectors.Bijector}
+* @{tf.contrib.distributions.bijectors.Chain}
+* @{tf.contrib.distributions.bijectors.CholeskyOuterProduct}
+* @{tf.contrib.distributions.bijectors.Exp}
+* @{tf.contrib.distributions.bijectors.Identity}
+* @{tf.contrib.distributions.bijectors.Inline}
+* @{tf.contrib.distributions.bijectors.Invert}
+* @{tf.contrib.distributions.bijectors.PowerTransform}
+* @{tf.contrib.distributions.bijectors.SigmoidCentered}
+* @{tf.contrib.distributions.bijectors.SoftmaxCentered}
+* @{tf.contrib.distributions.bijectors.Softplus}
diff --git a/tensorflow/docs_src/api_guides/python/contrib.distributions.md b/tensorflow/docs_src/api_guides/python/contrib.distributions.md
index 2b43e1281d..7a3d509b75 100644
--- a/tensorflow/docs_src/api_guides/python/contrib.distributions.md
+++ b/tensorflow/docs_src/api_guides/python/contrib.distributions.md
@@ -76,7 +76,7 @@ representing the posterior or posterior predictive.
## Kullback-Leibler Divergence
-* @{tf.contrib.distributions.kl}
+* @{tf.contrib.distributions.kl_divergence}
* @{tf.contrib.distributions.RegisterKL}
## Utilities
diff --git a/tensorflow/docs_src/api_guides/python/index.md b/tensorflow/docs_src/api_guides/python/index.md
index 177f19bc80..19d50926d8 100644
--- a/tensorflow/docs_src/api_guides/python/index.md
+++ b/tensorflow/docs_src/api_guides/python/index.md
@@ -40,7 +40,7 @@
* [Losses (contrib)](contrib.losses.md)
* [Metrics (contrib)](contrib.metrics.md)
* [Optimization (contrib)](contrib.opt.md)
-* [Random variable transformations (contrib)](contrib.distributions.bijector.md)
+* [Random variable transformations (contrib)](contrib.distributions.bijectors.md)
* [RNN and Cells (contrib)](contrib.rnn.md)
* [Seq2seq Library (contrib)](contrib.seq2seq.md)
* [Staging (contrib)](contrib.staging.md)
diff --git a/tensorflow/docs_src/performance/benchmarks.md b/tensorflow/docs_src/performance/benchmarks.md
index 19d37794ab..6bbc98ac0d 100644
--- a/tensorflow/docs_src/performance/benchmarks.md
+++ b/tensorflow/docs_src/performance/benchmarks.md
@@ -80,10 +80,12 @@ section.
* **OS:** Ubuntu 16.04 LTS with tests run via Docker
* **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e
+* **Benchmark GitHub hash:** 9165a70
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
//tensorflow/tools/pip_package:build_pip_package`
* **Disk:** Local SSD
* **DataSet:** ImageNet
+* **Test Date:** May 2017
Batch size and optimizer used for each model are listed in the table below. In
addition to the batch sizes listed in the table, InceptionV3, ResNet-50,
@@ -120,19 +122,19 @@ VGG16 | replicated (with NCCL) | n/a
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
- 1 | 142 | 238 | 95.6 | 2987 | 154
- 2 | 284 | 479 | 187 | 5658 | 295
- 4 | 569 | 948 | 374 | 10509 | 584
- 8 | 1131 | 1886 | 744 | 17822 | 1081
+1 | 142 | 219 | 91.8 | 2987 | 154
+2 | 284 | 422 | 181 | 5658 | 295
+4 | 569 | 852 | 356 | 10509 | 584
+8 | 1131 | 1734 | 716 | 17822 | 1081
**Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
- 1 | 142 | 239 | 95.5 | 2890 | 154
- 2 | 278 | 468 | 187 | 4448 | 284
- 4 | 551 | 938 | 373 | 7105 | 534
- 8 | 1079 | 1802 | 721 | N/A | 898
+1 | 142 | 218 | 91.4 | 2890 | 154
+2 | 278 | 425 | 179 | 4448 | 284
+4 | 551 | 853 | 359 | 7105 | 534
+8 | 1079 | 1630 | 708 | N/A | 898
Training AlexNet with real data on 8 GPUs was excluded from the graph and table
above due to it maxing out the input pipeline.
@@ -145,19 +147,19 @@ The results below are all with a batch size of 32.
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
---- | ----------- | --------- | ---------- | -----
- 1 | 128 | 210 | 85.3 | 144
- 2 | 259 | 412 | 166 | 281
- 4 | 520 | 827 | 330 | 549
- 8 | 995 | 1623 | 643 | 820
+1 | 128 | 195 | 82.7 | 144
+2 | 259 | 368 | 160 | 281
+4 | 520 | 768 | 317 | 549
+8 | 995 | 1485 | 632 | 820
**Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
---- | ----------- | --------- | ---------- | -----
- 1 | 130 | 208 | 85.0 | 144
- 2 | 257 | 403 | 163 | 253
- 4 | 507 | 814 | 325 | 457
- 8 | 966 | 1525 | 641 | 690
+1 | 130 | 193 | 82.4 | 144
+2 | 257 | 369 | 159 | 253
+4 | 507 | 760 | 317 | 457
+8 | 966 | 1410 | 609 | 690
## Details for Google Compute Engine (NVIDIA® Tesla® K80)
@@ -168,11 +170,12 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
* **OS:** Ubuntu 16.04 LTS
* **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e
+* **Benchmark GitHub hash:** 9165a70
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
//tensorflow/tools/pip_package:build_pip_package`
* **Disk:** 1.7 TB Shared SSD persistent disk (800 MB/s)
* **DataSet:** ImageNet
-* **Test Date:** April 2017
+* **Test Date:** May 2017
Batch size and optimizer used for each model are listed in the table below. In
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
@@ -198,19 +201,19 @@ The configuration used for each model was `variable_update` equal to
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
- 1 | 30.5 | 56.8 | 20.8 | 656 | 35.4
- 2 | 57.8 | 107 | 39.1 | 1209 | 64.8
- 4 | 116 | 212 | 77.2 | 2328 | 120
- 8 | 227 | 419 | 151 | 4640 | 234
+1 | 30.5 | 51.9 | 20.0 | 656 | 35.4
+2 | 57.8 | 99.0 | 38.2 | 1209 | 64.8
+4 | 116 | 195 | 75.8 | 2328 | 120
+8 | 227 | 387 | 148 | 4640 | 234
**Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
- 1 | 30.6 | 56.7 | 20.7 | 639 | 34.2
- 2 | 58.4 | 107 | 39.0 | 1136 | 62.9
- 4 | 115 | 211 | 77.3 | 2067 | 118
- 8 | 225 | 422 | 151 | 4056 | 230
+1 | 30.6 | 51.2 | 20.0 | 639 | 34.2
+2 | 58.4 | 98.8 | 38.3 | 1136 | 62.9
+4 | 115 | 194 | 75.4 | 2067 | 118
+8 | 225 | 381 | 148 | 4056 | 230
### Other Results
@@ -218,19 +221,19 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
-1 | 29.3 | 53.9
-2 | 55.0 | 101
-4 | 109 | 200
-8 | 216 | 398
+1 | 29.3 | 49.5
+2 | 55.0 | 95.4
+4 | 109 | 183
+8 | 216 | 362
**Training real data**
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
- 1 | 29.5 | 53.6
- 2 | 55.4 | 102
- 4 | 110 | 201
- 8 | 216 | 387
+1 | 29.5 | 49.3
+2 | 55.4 | 95.3
+4 | 110 | 186
+8 | 216 | 359
## Details for Amazon EC2 (NVIDIA® Tesla® K80)
@@ -241,12 +244,13 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
* **OS:** Ubuntu 16.04 LTS
* **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e
+* **Benchmark GitHub hash:** 9165a70
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
//tensorflow/tools/pip_package:build_pip_package`
* **Disk:** 1TB Amazon EFS (burst 100 MiB/sec for 12 hours, continuous 50
MiB/sec)
* **DataSet:** ImageNet
-* **Test Date:** April 2017
+* **Test Date:** May 2017
Batch size and optimizer used for each model are listed in the table below. In
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
@@ -279,19 +283,19 @@ VGG16 | parameter_server | gpu
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
- 1 | 30.8 | 56.3 | 20.9 | 684 | 36.3
- 2 | 58.7 | 108 | 39.3 | 1244 | 69.4
- 4 | 117 | 217 | 79.1 | 2479 | 141
- 8 | 230 | 419 | 156 | 4853 | 260
+1 | 30.8 | 51.5 | 19.7 | 684 | 36.3
+2 | 58.7 | 98.0 | 37.6 | 1244 | 69.4
+4 | 117 | 195 | 74.9 | 2479 | 141
+8 | 230 | 384 | 149 | 4853 | 260
**Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
- 1 | 30.5 | 56.0 | 20.6 | 674 | 36.3
- 2 | 59.0 | 107 | 39.0 | 1227 | 67.5
- 4 | 118 | 205 | 77.9 | 2201 | 136
- 8 | 228 | 405 | 152 | N/A | 242
+1 | 30.5 | 51.3 | 19.7 | 674 | 36.3
+2 | 59.0 | 94.9 | 38.2 | 1227 | 67.5
+4 | 118 | 188 | 75.2 | 2201 | 136
+8 | 228 | 373 | 149 | N/A | 242
Training AlexNet with real data on 8 GPUs was excluded from the graph and table
above due to our EFS setup not providing enough throughput.
@@ -302,19 +306,19 @@ above due to our EFS setup not providing enough throughput.
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
-1 | 29.9 | 53.5
-2 | 57.5 | 101
-4 | 114 | 202
-8 | 216 | 380
+1 | 29.9 | 49.0
+2 | 57.5 | 94.1
+4 | 114 | 184
+8 | 216 | 355
**Training real data**
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
-1 | 30.0 | 53.6
-2 | 57.5 | 102
-4 | 113 | 202
-8 | 212 | 379
+1 | 30.0 | 49.1
+2 | 57.5 | 95.1
+4 | 113 | 185
+8 | 212 | 353
## Details for Amazon EC2 Distributed (NVIDIA® Tesla® K80)
@@ -325,11 +329,12 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
* **OS:** Ubuntu 16.04 LTS
* **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e
+* **Benchmark GitHub hash:** 9165a70
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
//tensorflow/tools/pip_package:build_pip_package`
* **Disk:** 1.0 TB EFS (burst 100 MB/sec for 12 hours, continuous 50 MB/sec)
* **DataSet:** ImageNet
-* **Test Date:** April 2017
+* **Test Date:** May 2017
The batch size and optimizer used for the tests are listed in the table. In
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
@@ -343,11 +348,11 @@ Optimizer | sgd | sgd | sgd
Configuration used for each model.
-Model | variable_update | local_parameter_device
------------ | ---------------------- | ----------------------
-InceptionV3 | distributed_replicated | n/a
-ResNet-50 | distributed_replicated | n/a
-ResNet-152 | distributed_replicated | n/a
+Model | variable_update | local_parameter_device | cross_replica_sync
+----------- | ---------------------- | ---------------------- | ------------------
+InceptionV3 | distributed_replicated | n/a | True
+ResNet-50 | distributed_replicated | n/a | True
+ResNet-152 | distributed_replicated | n/a | True
To simplify server setup, EC2 instances (p2.8xlarge) running worker servers also
ran parameter servers. Equal numbers of parameter servers and work servers were
@@ -371,11 +376,11 @@ used with the following exceptions:
GPUs | InceptionV3 | ResNet-50 | ResNet-152
---- | ----------- | --------- | ----------
-1 | 29.7 | 55.0 | 19.8
-8 | 229 | 410 | 150
-16 | 459 | 825 | 300
-32 | 902 | 1468 | 575
-64 | 1783 | 3051 | 1004
+1 | 29.7 | 52.4 | 19.4
+8 | 229 | 378 | 146
+16 | 459 | 751 | 291
+32 | 902 | 1388 | 565
+64 | 1783 | 2744 | 981
### Other Results
@@ -387,23 +392,23 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
-1 | 29.2 | 53.0
-8 | 219 | 363
-16 | 427 | 719
-32 | 820 | 1265
-64 | 1608 | 2623
-
+1 | 29.2 | 48.4
+8 | 219 | 333
+16 | 427 | 667
+32 | 820 | 1180
+64 | 1608 | 2315
## Methodology
-This [script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
+This
+[script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
was run on the various platforms to generate the above results.
@{$performance_models$High-Performance Models} details techniques in the script
along with examples of how to execute the script.
In order to create results that are as repeatable as possible, each test was run
5 times and then the times were averaged together. GPUs are run in their default
-state on the given platform. For NVIDIA® Tesla® K80 this means leaving on [GPU
+state on the given platform. For NVIDIA® Tesla® K80 this means leaving on [GPU
Boost](https://devblogs.nvidia.com/parallelforall/increase-performance-gpu-boost-k80-autoboost/).
For each test, 10 warmup steps are done and then the next 100 steps are
averaged.
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 44b29a57be..6c1b40b442 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -370,9 +370,8 @@ def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
tf.logging.fatal('File does not exist %s', image_path)
image_data = gfile.FastGFile(image_path, 'rb').read()
try:
- bottleneck_values = run_bottleneck_on_image(sess, image_data,
- jpeg_data_tensor,
- bottleneck_tensor)
+ bottleneck_values = run_bottleneck_on_image(
+ sess, image_data, jpeg_data_tensor, bottleneck_tensor)
except:
raise RuntimeError('Error during processing file %s' % image_path)
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index eb4789a182..e57c197fa4 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -5583,6 +5583,74 @@ func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) {
return op.Output(0)
}
+// Store the input tensor in the state of the current session.
+//
+// Arguments:
+// value: The tensor to be stored.
+//
+// Returns The handle for the tensor stored in the session state, represented
+// as a ResourceHandle object.
+func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "GetSessionHandleV2",
+ Input: []tf.Input{
+ value,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Adjust the hue of one or more images.
+//
+// `images` is a tensor of at least 3 dimensions. The last dimension is
+// interpretted as channels, and must be three.
+//
+// The input image is considered in the RGB colorspace. Conceptually, the RGB
+// colors are first mapped into HSV. A delta is then applied all the hue values,
+// and then remapped back to RGB colorspace.
+//
+// Arguments:
+// images: Images to adjust. At least 3-D.
+// delta: A float delta to add to the hue.
+//
+// Returns The hue-adjusted image or images.
+func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "AdjustHue",
+ Input: []tf.Input{
+ images, delta,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Restore a Reader to its initial clean state.
+//
+// Arguments:
+// reader_handle: Handle to a Reader.
+//
+// Returns the created operation.
+func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ReaderResetV2",
+ Input: []tf.Input{
+ reader_handle,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
// Computes softmax cross entropy cost and gradients to backpropagate.
//
// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept
@@ -19039,6 +19107,27 @@ func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
return op.Output(0)
}
+// Computes arctangent of `y/x` element-wise, respecting signs of the arguments.
+//
+// This is the angle \( \theta \in [-\pi, \pi] \) such that
+// \[ x = r \cos(\theta) \]
+// and
+// \[ y = r \sin(\theta) \]
+// where \(r = \sqrt(x^2 + y^2) \).
+func Atan2(scope *Scope, y tf.Output, x tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Atan2",
+ Input: []tf.Input{
+ y, x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Compute the regularized incomplete beta integral \\(I_x(a, b)\\).
//
// The regularized incomplete beta integral is defined as:
@@ -21627,71 +21716,3 @@ func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.O
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1)
}
-
-// Store the input tensor in the state of the current session.
-//
-// Arguments:
-// value: The tensor to be stored.
-//
-// Returns The handle for the tensor stored in the session state, represented
-// as a ResourceHandle object.
-func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "GetSessionHandleV2",
- Input: []tf.Input{
- value,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Adjust the hue of one or more images.
-//
-// `images` is a tensor of at least 3 dimensions. The last dimension is
-// interpretted as channels, and must be three.
-//
-// The input image is considered in the RGB colorspace. Conceptually, the RGB
-// colors are first mapped into HSV. A delta is then applied all the hue values,
-// and then remapped back to RGB colorspace.
-//
-// Arguments:
-// images: Images to adjust. At least 3-D.
-// delta: A float delta to add to the hue.
-//
-// Returns The hue-adjusted image or images.
-func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "AdjustHue",
- Input: []tf.Input{
- images, delta,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Restore a Reader to its initial clean state.
-//
-// Arguments:
-// reader_handle: Handle to a Reader.
-//
-// Returns the created operation.
-func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ReaderResetV2",
- Input: []tf.Input{
- reader_handle,
- },
- }
- return scope.AddOperation(opspec)
-}
diff --git a/tensorflow/opensource_only/eigen.threadpool b/tensorflow/opensource_only/eigen.threadpool
deleted file mode 100644
index d2639af4d9..0000000000
--- a/tensorflow/opensource_only/eigen.threadpool
+++ /dev/null
@@ -1 +0,0 @@
-#include "unsupported/Eigen/CXX11/ThreadPool"
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 3b7e3b1c90..5d9913d734 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -436,13 +436,13 @@ class EstimatorTrainTest(test.TestCase):
model_dir=model_dir1,
model_fn=model_fn_global_step_incrementer)
est1.train(dummy_input_fn, steps=5)
-
+
# We have to clear the cache before we can rename the directory,
# otherwise open file handles will prevent the delete on Windows.
writer_cache.FileWriterCache.clear()
model_dir2 = os.path.join(tmpdir, 'model_dir2')
os.renames(model_dir1, model_dir2)
-
+
est2 = estimator.Estimator(
model_dir=model_dir2,
model_fn=model_fn_global_step_incrementer)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index ffdf8868e2..f8855f259e 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -129,6 +129,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -656,6 +657,44 @@ def categorical_column_with_vocabulary_list(
default_value=default_value)
+def categorical_column_with_identity(key, num_buckets, default_value=None):
+ """A `_CategoricalColumn` that returns identity values.
+
+ Use this when your inputs are integers in the range `[0, num_buckets)`. Values
+ outside this range will result in `default_value` if specified, otherwise it
+ will fail.
+
+ Inputs can be either `Tensor` or `SparseTensor`.
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
+ default_value: If `None`, this column's graph operations will fail for
+ out-of-range inputs. Otherwise, this value must be in the range
+ `[0, num_buckets)`, and will replace inputs in that range.
+
+ Returns:
+ A `_CategoricalColumn` that returns identity values.
+
+ Raises:
+ ValueError: if `num_buckets` is less than one.
+ ValueError: if `default_value` is not in range `[0, num_buckets)`.
+ """
+ if num_buckets < 1:
+ raise ValueError(
+ 'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
+ if (default_value is not None) and (
+ (default_value < 0) or (default_value >= num_buckets)):
+ raise ValueError(
+ 'default_value {} not in range [0, {}), column_name {}'.format(
+ default_value, num_buckets, key))
+ return _IdentityCategoricalColumn(
+ key=key, num_buckets=num_buckets, default_value=default_value)
+
+
class _FeatureColumn(object):
"""Represents a feature column abstraction.
@@ -1384,6 +1423,69 @@ class _VocabularyListCategoricalColumn(
return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
+class _IdentityCategoricalColumn(
+ _CategoricalColumn,
+ collections.namedtuple('_IdentityCategoricalColumn', (
+ 'key', 'num_buckets', 'default_value'
+ ))):
+
+ """See `categorical_column_with_identity`."""
+
+ @property
+ def name(self):
+ return self.key
+
+ @property
+ def _parse_example_config(self):
+ return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
+
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input(inputs.get(self.key))
+
+ if not input_tensor.dtype.is_integer:
+ raise ValueError(
+ 'Invalid input, not integer. key: {} dtype: {}'.format(
+ self.key, input_tensor.dtype))
+
+ values = math_ops.to_int64(input_tensor.values, name='values')
+ num_buckets = math_ops.to_int64(self.num_buckets, name='num_buckets')
+ zero = math_ops.to_int64(0, name='zero')
+ if self.default_value is None:
+ # Fail if values are out-of-range.
+ assert_less = check_ops.assert_less(
+ values, num_buckets, data=(values, num_buckets),
+ name='assert_less_than_num_buckets')
+ assert_greater = check_ops.assert_greater_equal(
+ values, zero, data=(values,),
+ name='assert_greater_or_equal_0')
+ with ops.control_dependencies((assert_less, assert_greater)):
+ values = array_ops.identity(values)
+ else:
+ # Assign default for out-of-range values.
+ values = array_ops.where(
+ math_ops.logical_or(
+ values < zero, values >= num_buckets, name='out_of_range'),
+ array_ops.fill(
+ dims=array_ops.shape(values),
+ value=math_ops.to_int64(self.default_value),
+ name='default_values'),
+ values)
+
+ return sparse_tensor_lib.SparseTensor(
+ indices=input_tensor.indices,
+ values=values,
+ dense_shape=input_tensor.dense_shape)
+
+ @property
+ def _num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.num_buckets
+
+ def _get_sparse_tensors(
+ self, inputs, weight_collections=None, trainable=None):
+ return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
+
# TODO(zakaria): Move this to embedding_ops and make it public.
def _safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 59aa39411f..5201811831 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variable_scope
@@ -1828,5 +1829,198 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
self.assertAllClose(((3.,), (1.,)), predictions.eval())
+class IdentityCategoricalColumnTest(test.TestCase):
+
+ def test_constructor(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ self.assertEqual('aaa', column.name)
+ # pylint: disable=protected-access
+ self.assertEqual(3, column._num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, column._parse_example_config)
+ # pylint: enable=protected-access
+
+ def test_deep_copy(self):
+ """Tests deepcopy of categorical_column_with_hash_bucket."""
+ original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', column.name)
+ # pylint: disable=protected-access
+ self.assertEqual(3, column._num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, column._parse_example_config)
+ # pylint: enable=protected-access
+
+ def test_invalid_num_buckets_zero(self):
+ with self.assertRaisesRegexp(ValueError, 'num_buckets 0 < 1'):
+ fc.categorical_column_with_identity(key='aaa', num_buckets=0)
+
+ def test_invalid_num_buckets_negative(self):
+ with self.assertRaisesRegexp(ValueError, 'num_buckets -1 < 1'):
+ fc.categorical_column_with_identity(key='aaa', num_buckets=-1)
+
+ def test_invalid_default_value_too_small(self):
+ with self.assertRaisesRegexp(ValueError, 'default_value -1 not in range'):
+ fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3, default_value=-1)
+
+ def test_invalid_default_value_too_big(self):
+ with self.assertRaisesRegexp(ValueError, 'default_value 3 not in range'):
+ fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3, default_value=3)
+
+ def test_invalid_input_dtype(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'):
+ # pylint: disable=protected-access
+ column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+
+ def test_get_sparse_tensors(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((0, 1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_dense_input(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
+ 'aaa': ((0, -1), (1, 0))
+ }))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((0, 1, 0), dtype=np.int64),
+ dense_shape=(2, 2)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_with_inputs_too_small(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, -1, 0),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ with self.assertRaisesRegexp(
+ errors.OpError, 'assert_greater_or_equal_0'):
+ id_weight_pair.id_tensor.eval()
+
+ def test_get_sparse_tensors_with_inputs_too_big(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 99, 0),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ with self.assertRaisesRegexp(
+ errors.OpError, 'assert_less_than_num_buckets'):
+ id_weight_pair.id_tensor.eval()
+
+ def test_get_sparse_tensors_with_default_value(self):
+ column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=4, default_value=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, -1, 99),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((1, 3, 3), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self):
+ column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=4, default_value=3)
+ input_indices = array_ops.placeholder(dtype=dtypes.int64)
+ input_values = array_ops.placeholder(dtype=dtypes.int32)
+ input_shape = array_ops.placeholder(dtype=dtypes.int64)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=input_indices,
+ values=input_values,
+ dense_shape=input_shape)
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=np.array(((0, 0), (1, 0), (1, 1)), dtype=np.int64),
+ values=np.array((1, 3, 3), dtype=np.int64),
+ dense_shape=np.array((2, 2), dtype=np.int64)),
+ id_weight_pair.id_tensor.eval(feed_dict={
+ input_indices: ((0, 0), (1, 0), (1, 1)),
+ input_values: (1, -1, 99),
+ input_shape: (2, 2),
+ }))
+
+ def test_make_linear_model(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ self.assertEqual(3, column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc.make_linear_model({
+ column.name: sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] = 1
+ # weight_var[2] + weight_var[1] = 3+2 = 5
+ self.assertAllClose(((1.,), (5.,)), predictions.eval())
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 2a1389b91f..ac8aee2c83 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -113,18 +113,19 @@ def _add_op_node(op, func, input_dict):
node_def = func.node_def[-1]
for i in range(len(node_def.input)):
if not node_def.input[i].startswith("^"):
- assert node_def.input[i] in input_dict, (
- "%s missing from %s" % (node_def.input[i], input_dict.items()))
+ assert node_def.input[i] in input_dict, ("%s missing from %s" %
+ (node_def.input[i],
+ input_dict.items()))
node_def.input[i] = input_dict[node_def.input[i]]
-def _graph_to_function_def(graph, inputs, outputs, out_names=None):
+def _graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
"""Returns `graph` as a `FunctionDef` protocol buffer.
This method creates a [`FunctionDef`](
https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
- protocol buffer that contains all the ops present in the graph. The
- graph effectively becomes the body of the function.
+ protocol buffer that contains all the ops in `operations`. The
+ operations become the body of the function.
The arguments `inputs` and `outputs` will be listed as the inputs
and outputs tensors of the function. They must be lists of
@@ -132,6 +133,8 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
Args:
graph: Graph.
+ operations: the operations to put in the function. Must be a subset of
+ the operations in the graph.
inputs: List of tensors. Inputs to the function.
outputs: List of tensors. Outputs of the function.
out_names: Optional list of string names for the outputs.
@@ -145,12 +148,12 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
func = function_pb2.FunctionDef()
func.signature.name = "_"
used_names = set()
- func.signature.input_arg.extend([_tensor_to_argdef(i, used_names=used_names)
- for i in inputs])
+ func.signature.input_arg.extend(
+ [_tensor_to_argdef(i, used_names=used_names) for i in inputs])
if out_names is None:
used_names = set()
- func.signature.output_arg.extend([
- _tensor_to_argdef(o, used_names=used_names) for o in outputs])
+ func.signature.output_arg.extend(
+ [_tensor_to_argdef(o, used_names=used_names) for o in outputs])
elif len(outputs) != len(out_names):
raise ValueError(
"Length of out_names (%d) does not match number of outputs (%d): %s" %
@@ -159,12 +162,12 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
raise ValueError(
"Must not have duplicates in out_names: %s" % ", ".join(out_names))
else:
- func.signature.output_arg.extend([
- _tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
+ func.signature.output_arg.extend(
+ [_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
func_arg_placeholders = set([i.name for i in inputs])
input_dict = _create_input_dict(graph, func_arg_placeholders)
- for op in graph.get_operations():
+ for op in operations:
if _is_in_placeholders(op, func_arg_placeholders):
continue
_add_op_node(op, func, input_dict)
@@ -295,17 +298,18 @@ class _FuncGraph(ops.Graph):
self.extra_args = []
self.extra_vars = []
- def getvar(self,
- getter,
- name,
- shape=None,
- dtype=None,
- initializer=None,
- reuse=None,
- trainable=True,
- collections=None, # pylint: disable=redefined-outer-name
- use_resource=None,
- **kwargs):
+ def getvar(
+ self,
+ getter,
+ name,
+ shape=None,
+ dtype=None,
+ initializer=None,
+ reuse=None,
+ trainable=True,
+ collections=None, # pylint: disable=redefined-outer-name
+ use_resource=None,
+ **kwargs):
"""A custom variable getter."""
# Here, we switch the default graph to the outer graph and ask the
# variable scope in which the function is defined to give us the
@@ -538,20 +542,23 @@ class _DefinedFunction(object):
# Build the FunctionDef
self._definition = _graph_to_function_def(
- temp_graph, inputs, outputs, out_names=self._out_names)
+ temp_graph,
+ temp_graph.get_operations(),
+ inputs,
+ outputs,
+ out_names=self._out_names)
# Extra kwargs are treated as attrs on the function def.
sig_pre_func_name = self._func_name or _get_func_name(self._func)
- kwargs_attr = _parse_kwargs_as_attrs(
- sig_pre_func_name, **self._extra_kwargs)
+ kwargs_attr = _parse_kwargs_as_attrs(sig_pre_func_name,
+ **self._extra_kwargs)
for k in kwargs_attr:
self._definition.attr[k].CopyFrom(kwargs_attr[k])
# Hash the definition and its dependencies.
self._hash_str = self._create_hash_str(
self._definition.signature.input_arg,
- self._definition.signature.output_arg,
- self._definition.node_def)
+ self._definition.signature.output_arg, self._definition.node_def)
# Finally, we decide the function name to use. If not specified,
# make up something which is almost certainly unique (but deterministic).
@@ -658,8 +665,8 @@ def _from_definition(fdef, grad_func=None):
# have access to such a callable here).
func = None
argnames = [arg.name for arg in fdef.signature.input_arg]
- input_types = tuple(dtypes.as_dtype(arg.type)
- for arg in fdef.signature.input_arg)
+ input_types = tuple(
+ dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
func_name = fdef.signature.name
# Note: FunctionDefs do not include python gradient functions, so if the
# original _DefinedFunction included one it will not be reflected here.
@@ -675,8 +682,7 @@ def _from_definition(fdef, grad_func=None):
result._extra_inputs = []
result._hash_str = result._create_hash_str(
result._definition.signature.input_arg,
- result._definition.signature.output_arg,
- result._definition.node_def)
+ result._definition.signature.output_arg, result._definition.node_def)
# pylint: enable=protected-access
return result
@@ -696,7 +702,8 @@ def _from_library(lib):
Raises:
ValueError: `lib` is invalid
"""
- if not lib.function and not lib.gradient: return []
+ if not lib.function and not lib.gradient:
+ return []
# function name -> FunctionDef proto
funcs = {fdef.signature.name: fdef for fdef in lib.function}
@@ -720,8 +727,9 @@ def _from_library(lib):
grad_to_funcs[gdef.gradient_func].append(gdef.function_name)
# Start with functions without gradients
- ready = [fdef for fdef in lib.function
- if func_to_grad[fdef.signature.name] is None]
+ ready = [
+ fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
+ ]
if not ready:
raise ValueError("FunctionDefLibrary contains cyclic gradient functions!\n"
+ str(lib))
@@ -733,7 +741,8 @@ def _from_library(lib):
name = fdef.signature.name
grad = initialized.get(func_to_grad[name])
- if func_to_grad[name]: assert grad
+ if func_to_grad[name]:
+ assert grad
defined_func = _from_definition(fdef, grad_func=grad)
initialized[name] = defined_func
@@ -835,10 +844,15 @@ class _OverloadedFunction(object):
name = self._func_name
if name is not None:
name = "_".join([name, key])
- defined = _DefinedFunction(self._func, self._argnames, input_types, name,
- None, self._python_grad_func,
- out_names=self._out_names,
- **self._extra_kwargs)
+ defined = _DefinedFunction(
+ self._func,
+ self._argnames,
+ input_types,
+ name,
+ None,
+ self._python_grad_func,
+ out_names=self._out_names,
+ **self._extra_kwargs)
_ = defined.name # Fully instantiate the function definition.
if self._grad_func:
# If _grad_func is given, it is another
@@ -849,8 +863,8 @@ class _OverloadedFunction(object):
for _ in defined.definition.signature.output_arg
]
# pylint: disable=protected-access
- defined._grad_func = self._grad_func.instantiate(input_types +
- output_types)
+ defined._grad_func = self._grad_func.instantiate(
+ input_types + output_types)
# pylint: enable=protected-access
self._overload[key] = defined
return defined
@@ -981,22 +995,36 @@ class Defun(object):
raise ValueError(
"The function has fewer arguments than the number of specified "
"input types.")
- return _DefinedFunction(func, argnames, self._input_types,
- self._func_name, self._grad_func,
- self._python_grad_func,
- out_names=self._out_names, **self._extra_kwargs)
+ return _DefinedFunction(
+ func,
+ argnames,
+ self._input_types,
+ self._func_name,
+ self._grad_func,
+ self._python_grad_func,
+ out_names=self._out_names,
+ **self._extra_kwargs)
# 'func' expects no arguments and input types is an empty list.
if min_args == 0 and max_args == 0:
- return _DefinedFunction(func, [], [], self._func_name, self._grad_func,
- self._python_grad_func,
- out_names=self._out_names, **self._extra_kwargs)
+ return _DefinedFunction(
+ func, [], [],
+ self._func_name,
+ self._grad_func,
+ self._python_grad_func,
+ out_names=self._out_names,
+ **self._extra_kwargs)
# Input types are unknown. It's an overloaded function and hence
# its definition needs to be deferred until it's called.
- return _OverloadedFunction(func, argnames, self._func_name, self._grad_func,
- self._python_grad_func,
- out_names=self._out_names, **self._extra_kwargs)
+ return _OverloadedFunction(
+ func,
+ argnames,
+ self._func_name,
+ self._grad_func,
+ self._python_grad_func,
+ out_names=self._out_names,
+ **self._extra_kwargs)
class Declare(object):
@@ -1039,8 +1067,10 @@ class Declare(object):
names = [n for n, t in args]
if len(names) != len(set(names)):
raise ValueError("Expected names to all be unique: %s" % str(names))
- return [op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n)
- for n, t in args]
+ return [
+ op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n)
+ for n, t in args
+ ]
self._sig.input_arg.extend(_to_argdef_list(inputs))
self._sig.output_arg.extend(_to_argdef_list(outputs))
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 2f2f1da090..0846470abc 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -1106,16 +1106,18 @@ class BinaryOpTest(test.TestCase):
def testAtan2SpecialValues(self):
x1l, x2l = zip((+0.0, +0.0), (+0.0, -0.0), (-0.0, +0.0), (-0.0, -0.0),
- (1.2345, float('inf')), (1.2345, -float('inf')),
- (-4.321, float('inf')), (-4.125, -float('inf')),
- (float('inf'), float('inf')), (float('inf'), -float('inf')),
- (-float('inf'), float('inf')), (-float('inf'), -float('inf')))
+ (1.2345, float("inf")), (1.2345, -float("inf")),
+ (-4.321, float("inf")), (-4.125, -float("inf")),
+ (float("inf"), float("inf")), (float("inf"), -float("inf")),
+ (-float("inf"), float("inf")), (-float("inf"),
+ -float("inf")))
for dtype in np.float32, np.float64:
x1 = np.array(x1l).astype(dtype)
x2 = np.array(x2l).astype(dtype)
self._compareCpu(x1, x2, np.arctan2, math_ops.atan2)
self._compareGpu(x1, x2, np.arctan2, math_ops.atan2)
+
class ComparisonOpTest(test.TestCase):
def _compareScalar(self, func, x, y, dtype):
diff --git a/tensorflow/python/kernel_tests/tensor_priority_test.py b/tensorflow/python/kernel_tests/tensor_priority_test.py
index b6674c3aa5..574538a837 100644
--- a/tensorflow/python/kernel_tests/tensor_priority_test.py
+++ b/tensorflow/python/kernel_tests/tensor_priority_test.py
@@ -19,58 +19,65 @@ from __future__ import print_function
import numpy as np
-from tensorflow import Tensor
-from tensorflow import register_tensor_conversion_function
from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test as test_lib
class TensorPriorityTest(test_lib.TestCase):
def testSupportedRhsWithoutDelegation(self):
+
class NumpyArraySubclass(np.ndarray):
pass
- supported_rhs_without_delegation = (
- 3,
- 3.0,
- [1.0, 2.0],
- np.array([1.0, 2.0]),
- NumpyArraySubclass(shape=(1,2), buffer=np.array([1.0, 2.0])),
- ops.convert_to_tensor([[1.0, 2.0]]))
+
+ supported_rhs_without_delegation = (3, 3.0, [1.0, 2.0], np.array(
+ [1.0, 2.0]), NumpyArraySubclass(
+ shape=(1, 2), buffer=np.array([1.0, 2.0])),
+ ops.convert_to_tensor([[1.0, 2.0]]))
for rhs in supported_rhs_without_delegation:
tensor = ops.convert_to_tensor([[10.0, 20.0]])
res = tensor + rhs
- self.assertIsInstance(res, Tensor)
+ self.assertIsInstance(res, ops.Tensor)
def testUnsupportedRhsWithoutDelegation(self):
+
class WithoutReverseAdd(object):
pass
+
tensor = ops.convert_to_tensor([[10.0, 20.0]])
rhs = WithoutReverseAdd()
with self.assertRaisesWithPredicateMatch(
TypeError, lambda e: "Expected float" in str(e)):
- res = tensor + rhs
+ # pylint: disable=pointless-statement
+ tensor + rhs
def testUnsupportedRhsWithDelegation(self):
+
class WithReverseAdd(object):
+
def __radd__(self, lhs):
return "Works!"
+
tensor = ops.convert_to_tensor([[10.0, 20.0]])
rhs = WithReverseAdd()
res = tensor + rhs
self.assertEqual(res, "Works!")
def testFullDelegationControlUsingRegistry(self):
+
class NumpyArraySubclass(np.ndarray):
+
def __radd__(self, lhs):
return "Works!"
+
def raise_to_delegate(value, dtype=None, name=None, as_ref=False):
+ del value, dtype, name, as_ref # Unused.
raise TypeError
- register_tensor_conversion_function(NumpyArraySubclass, raise_to_delegate,
- priority=0)
+
+ ops.register_tensor_conversion_function(
+ NumpyArraySubclass, raise_to_delegate, priority=0)
tensor = ops.convert_to_tensor([[10.0, 20.0]])
- rhs = NumpyArraySubclass(shape=(1,2), buffer=np.array([1.0, 2.0]))
+ rhs = NumpyArraySubclass(shape=(1, 2), buffer=np.array([1.0, 2.0]))
res = tensor + rhs
self.assertEqual(res, "Works!")
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index 1dc07525df..938161f426 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -1109,10 +1109,10 @@ class Conv2DTranspose(Conv2D):
# Infer the static output shape:
out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = self.filters
- out_shape[h_axis] = utils.get_deconv_dim(
- out_shape[h_axis], stride_h, kernel_h, self.padding)
- out_shape[w_axis] = utils.get_deconv_dim(
- out_shape[w_axis], stride_w, kernel_w, self.padding)
+ out_shape[h_axis] = utils.get_deconv_dim(out_shape[h_axis], stride_h,
+ kernel_h, self.padding)
+ out_shape[w_axis] = utils.get_deconv_dim(out_shape[w_axis], stride_w,
+ kernel_w, self.padding)
outputs.set_shape(out_shape)
if self.bias:
@@ -1240,7 +1240,8 @@ class Conv3DTranspose(Conv3D):
name: A string, the name of the layer.
"""
- def __init__(self, filters,
+ def __init__(self,
+ filters,
kernel_size,
strides=(1, 1, 1),
padding='valid',
@@ -1269,12 +1270,13 @@ class Conv3DTranspose(Conv3D):
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
trainable=trainable,
- name=name, **kwargs)
+ name=name,
+ **kwargs)
def build(self, input_shape):
if len(input_shape) != 5:
- raise ValueError('Inputs should have rank 5, ' +
- 'received input shape:', str(input_shape))
+ raise ValueError('Inputs should have rank 5, received input shape:',
+ str(input_shape))
if self.data_format == 'channels_first':
channel_axis = 1
else:
@@ -1285,22 +1287,23 @@ class Conv3DTranspose(Conv3D):
input_dim = input_shape[channel_axis]
kernel_shape = self.kernel_size + (self.filters, input_dim)
- self.kernel = self.add_variable('kernel',
- shape=kernel_shape,
- initializer=self.kernel_initializer,
- regularizer=self.kernel_regularizer,
- trainable=True,
- dtype=self.dtype)
+ self.kernel = self.add_variable(
+ 'kernel',
+ shape=kernel_shape,
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ trainable=True,
+ dtype=self.dtype)
if self.use_bias:
- self.bias = self.add_variable('bias',
- shape=(self.filters,),
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- trainable=True,
- dtype=self.dtype)
+ self.bias = self.add_variable(
+ 'bias',
+ shape=(self.filters,),
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ trainable=True,
+ dtype=self.dtype)
else:
self.bias = None
- self.built = True
def call(self, inputs):
inputs_shape = array_ops.shape(inputs)
@@ -1343,26 +1346,26 @@ class Conv3DTranspose(Conv3D):
# Infer the static output shape:
out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = self.filters
- out_shape[d_axis] = utils.get_deconv_dim(
- out_shape[d_axis], stride_d, kernel_d, self.padding)
- out_shape[h_axis] = utils.get_deconv_dim(
- out_shape[h_axis], stride_h, kernel_h, self.padding)
- out_shape[w_axis] = utils.get_deconv_dim(
- out_shape[w_axis], stride_w, kernel_w, self.padding)
+ out_shape[d_axis] = utils.get_deconv_dim(out_shape[d_axis], stride_d,
+ kernel_d, self.padding)
+ out_shape[h_axis] = utils.get_deconv_dim(out_shape[h_axis], stride_h,
+ kernel_h, self.padding)
+ out_shape[w_axis] = utils.get_deconv_dim(out_shape[w_axis], stride_w,
+ kernel_w, self.padding)
outputs.set_shape(out_shape)
if self.bias:
outputs_shape = outputs.shape.as_list()
if self.data_format == 'channels_first':
- outputs_4d = array_ops.reshape(outputs,
- [outputs_shape[0], outputs_shape[1],
- outputs_shape[2] * outputs_shape[3],
- outputs_shape[4]])
+ outputs_4d = array_ops.reshape(outputs, [
+ outputs_shape[0], outputs_shape[1],
+ outputs_shape[2] * outputs_shape[3], outputs_shape[4]
+ ])
else:
- outputs_4d = array_ops.reshape(outputs,
- [outputs_shape[0],
- outputs_shape[1] * outputs_shape[2],
- outputs_shape[3], outputs_shape[4]])
+ outputs_4d = array_ops.reshape(outputs, [
+ outputs_shape[0], outputs_shape[1] * outputs_shape[2],
+ outputs_shape[3], outputs_shape[4]
+ ])
outputs_4d = nn.bias_add(
outputs_4d,
self.bias,
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index 635cc24714..42a2d77534 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -715,8 +715,8 @@ class Conv3DTransposeTest(test.TestCase):
layer = conv_layers.Conv3DTranspose(
32, volumes.get_shape()[1:4], padding='same')
output = layer.apply(volumes)
- self.assertListEqual(output.get_shape().as_list(), [5, depth, height,
- width, 32])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, depth, height, width, 32])
def testCreateConv3DTransposeWithStrides(self):
depth, height, width = 4, 6, 8
@@ -729,8 +729,7 @@ class Conv3DTransposeTest(test.TestCase):
[5, depth * 2, height * 2, width * 2, 4])
# Test strides integer.
- layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], strides=2,
- padding='same')
+ layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], strides=2, padding='same')
output = layer.apply(volumes)
self.assertListEqual(output.get_shape().as_list(),
[5, depth * 2, height * 2, width * 2, 4])
@@ -779,14 +778,14 @@ class Conv3DTransposeTest(test.TestCase):
volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
self.assertEqual(len(variables.trainable_variables()), 2)
- conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1', reuse=True)
+ conv_layers.conv3d_transpose(
+ volumes, 4, [3, 3, 3], name='deconv1', reuse=True)
self.assertEqual(len(variables.trainable_variables()), 2)
def testFunctionalConv3DTransposeReuseFromScope(self):
with variable_scope.variable_scope('scope'):
depth, height, width = 5, 7, 9
- volumes = random_ops.random_uniform((5, depth, height, width, 32),
- seed=1)
+ volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
self.assertEqual(len(variables.trainable_variables()), 2)
with variable_scope.variable_scope('scope', reuse=True):
@@ -798,8 +797,8 @@ class Conv3DTransposeTest(test.TestCase):
with variable_scope.variable_scope(
'scope', initializer=init_ops.ones_initializer()):
depth, height, width = 5, 7, 9
- volumes = random_ops.random_uniform((5, depth, height, width, 32),
- seed=1)
+ volumes = random_ops.random_uniform(
+ (5, depth, height, width, 32), seed=1)
conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
weights = variables.trainable_variables()
# Check the names of weights in order.
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index 4cfe991784..69edaa2c40 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -205,7 +205,8 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
`decoded.shape`: Shape vector, size `(2)`.
The shape values are: `[batch_size, max_decoded_length]`
neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
- sequence found, the negative of the sum of the greatest logit at each timeframe.
+ sequence found, the negative of the sum of the greatest logit at each
+ timeframe.
"""
outputs = gen_ctc_ops._ctc_greedy_decoder(
inputs, sequence_length, merge_repeated=merge_repeated)
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 4d922946ff..1e2f999995 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -39,6 +39,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 247d35923e..61fda3a798 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -964,8 +964,12 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
ValueError: If input/output depth does not match `filters`' shape, or if
padding is other than `'VALID'` or `'SAME'`.
"""
- return convolution(input=value, filter=filters, padding=padding,
- dilation_rate=np.broadcast_to(rate, (2, )), name=name)
+ return convolution(
+ input=value,
+ filter=filters,
+ padding=padding,
+ dilation_rate=np.broadcast_to(rate, (2,)),
+ name=name)
def conv2d_transpose(value,
@@ -1231,8 +1235,8 @@ def conv3d_transpose(value,
axis = 1 if data_format == "NCDHW" else 4
if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[4]):
raise ValueError("input channels does not match filter's input channels, "
- "{} != {}".format(value.get_shape()[axis], filter.get_shape(
- )[4]))
+ "{} != {}".format(value.get_shape()[axis],
+ filter.get_shape()[4]))
output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(5)):
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index a9d999dad3..3252652174 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -195,46 +195,47 @@ def load(sess, tags, export_dir, **saver_kwargs):
Raises:
RuntimeError: MetaGraphDef associated with the tags cannot be found.
"""
- # Build the SavedModel protocol buffer and find the requested meta graph def.
- saved_model = _parse_saved_model(export_dir)
- found_match = False
- for meta_graph_def in saved_model.meta_graphs:
- if set(meta_graph_def.meta_info_def.tags) == set(tags):
- meta_graph_def_to_load = meta_graph_def
- found_match = True
- break
-
- if not found_match:
- raise RuntimeError("MetaGraphDef associated with tags " + str(tags).strip(
- "[]") + " could not be found in SavedModel")
-
- # Build a saver by importing the meta graph def to load.
- saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
-
- if saver:
- # Build the checkpoint path where the variables are located.
- variables_path = os.path.join(
- compat.as_bytes(export_dir),
- compat.as_bytes(constants.VARIABLES_DIRECTORY),
- compat.as_bytes(constants.VARIABLES_FILENAME))
-
- # Restore the variables using the built saver in the provided session.
- saver.restore(sess, variables_path)
- else:
- tf_logging.info("The specified SavedModel has no variables; no "
- "checkpoints were restored.")
-
- # Get asset tensors, if any.
- asset_tensors_dictionary = _get_asset_tensors(export_dir,
- meta_graph_def_to_load)
-
- main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load)
- if main_op_tensor is not None:
- sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
- else:
- legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)
- if legacy_init_op_tensor is not None:
- sess.run(fetches=[legacy_init_op_tensor],
- feed_dict=asset_tensors_dictionary)
-
- return meta_graph_def_to_load
+ with sess.graph.as_default():
+ # Build the SavedModel protocol buffer and find requested meta graph def.
+ saved_model = _parse_saved_model(export_dir)
+ found_match = False
+ for meta_graph_def in saved_model.meta_graphs:
+ if set(meta_graph_def.meta_info_def.tags) == set(tags):
+ meta_graph_def_to_load = meta_graph_def
+ found_match = True
+ break
+
+ if not found_match:
+ raise RuntimeError("MetaGraphDef associated with tags " + str(tags).strip(
+ "[]") + " could not be found in SavedModel")
+
+ # Build a saver by importing the meta graph def to load.
+ saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
+
+ if saver:
+ # Build the checkpoint path where the variables are located.
+ variables_path = os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes(constants.VARIABLES_DIRECTORY),
+ compat.as_bytes(constants.VARIABLES_FILENAME))
+
+ # Restore the variables using the built saver in the provided session.
+ saver.restore(sess, variables_path)
+ else:
+ tf_logging.info("The specified SavedModel has no variables; no "
+ "checkpoints were restored.")
+
+ # Get asset tensors, if any.
+ asset_tensors_dictionary = _get_asset_tensors(export_dir,
+ meta_graph_def_to_load)
+
+ main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load)
+ if main_op_tensor is not None:
+ sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
+ else:
+ legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)
+ if legacy_init_op_tensor is not None:
+ sess.run(
+ fetches=[legacy_init_op_tensor], feed_dict=asset_tensors_dictionary)
+
+ return meta_graph_def_to_load
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index a81f744175..fcd6bc3954 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -151,6 +151,27 @@ class SavedModelTest(test.TestCase):
constants.SAVED_MODEL_FILENAME_PBTXT):
loader.load(sess, ["foo"], export_dir)
+ def testVerifySessionGraphUsage(self):
+ export_dir = os.path.join(test.get_temp_dir(),
+ "test_verify_session_graph_usage")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+ builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ # Build a session and supply it to the load operation.
+ sess = session.Session(graph=ops.Graph())
+ loader.load(sess, [tag_constants.TRAINING], export_dir)
+
+ # Check the variable within the scope of the session and its graph.
+ with sess:
+ self.assertEqual(
+ 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
+
def testSequence(self):
export_dir = os.path.join(test.get_temp_dir(), "test_sequence")
builder = saved_model_builder.SavedModelBuilder(export_dir)
diff --git a/tensorflow/python/tools/import_pb_to_tensorboard.py b/tensorflow/python/tools/import_pb_to_tensorboard.py
index 1c65998756..caeb04a24b 100644
--- a/tensorflow/python/tools/import_pb_to_tensorboard.py
+++ b/tensorflow/python/tools/import_pb_to_tensorboard.py
@@ -12,33 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ================================
+"""Imports a protobuf model as a graph in Tensorboard."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import tensorflow as tf
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.summary import summary
def import_to_tensorboard(model_dir, log_dir):
"""View an imported protobuf model (`.pb` file) as a graph in Tensorboard.
-
+
Args:
model_dir: The location of the protobuf (`pb`) model to visualize
log_dir: The location for the Tensorboard log to begin visualisation from.
-
+
Usage:
Call this function with your model location and desired log directory.
Launch Tensorboard by pointing it to the log directory.
View your imported `.pb` model as a graph.
"""
- with tf.Session(graph=tf.Graph()) as sess:
- with tf.gfile.FastGFile(model_dir, 'rb') as f:
- graph_def = tf.GraphDef()
+ with session.Session(graph=ops.Graph()) as sess:
+ with gfile.FastGFile(model_dir, "rb") as f:
+ graph_def = graph_pb2.GraphDef()
graph_def.ParseFromString(f.read())
- g_in = tf.import_graph_def(graph_def)
+ importer.import_graph_def(graph_def)
- pb_visual_writer = tf.summary.FileWriter(log_dir)
+ pb_visual_writer = summary.FileWriter(log_dir)
pb_visual_writer.add_graph(sess.graph)
print("Model Imported. Visualize by running: "
"> tensorboard --logdir={}".format(log_dir))
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index ac5be246a0..e1be305505 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -504,7 +504,14 @@ def run(args):
Args:
args: A namespace parsed from command line.
+
+ Raises:
+ AttributeError: An error when neither --inputs nor --input_exprs is passed
+ to run command.
"""
+ if not args.inputs and not args.input_exprs:
+ raise AttributeError(
+ 'At least one of --inputs and --input_exprs must be required')
tensor_key_feed_dict = load_inputs_from_input_arg_string(
args.inputs, args.input_exprs)
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
@@ -629,8 +636,6 @@ def create_parser():
def main():
parser = create_parser()
args = parser.parse_args()
- if not args.inputs and not args.input_exprs:
- args.error('At least one of --inputs and --input_exprs is required')
args.func(args)
diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py
index 1c7a44b3eb..8f79c888eb 100644
--- a/tensorflow/python/tools/saved_model_cli_test.py
+++ b/tensorflow/python/tools/saved_model_cli_test.py
@@ -409,6 +409,16 @@ Method name is: tensorflow/serving/predict"""
with self.assertRaises(RuntimeError):
saved_model_cli.run(args)
+ def testRunCommandInputNotGivenError(self):
+ self.parser = saved_model_cli.create_parser()
+ base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
+ args = self.parser.parse_args([
+ 'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
+ 'serving_default'
+ ])
+ with self.assertRaises(AttributeError):
+ saved_model_cli.run(args)
+
def testRunCommandWithDebuggerEnabled(self):
self.parser = saved_model_cli.create_parser()
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index ffa39b7793..d52cf9a436 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -210,9 +210,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
else:
var_name = ",".join([v.name for v in var])
_set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
- logging.info("Initialize variable %s from checkpoint %s with %s" % (
- var_name, ckpt_dir_or_file, tensor_name_in_ckpt
- ))
+ logging.info("Initialize variable %s from checkpoint %s with %s",
+ var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
else:
scopes = ""
# TODO(vihanjain): Support list of 'current_var_or_name' here.
@@ -250,9 +249,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
if var is None:
var = _collect_partitioned_variable(var_name, store_vars)
_set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
- logging.info("Initialize variable %s from checkpoint %s with %s" % (
- var_name, ckpt_dir_or_file, full_tensor_name
- ))
+ logging.info("Initialize variable %s from checkpoint %s with %s",
+ var_name, ckpt_dir_or_file, full_tensor_name)
def _get_checkpoint_filename(ckpt_dir_or_file):
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 3ec6692f2c..a65ab79495 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -935,11 +935,11 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
except errors.OpError as e:
# It's ok if the file cannot be read
- logging.warning("%s: %s" % (type(e).__name__, e))
+ logging.warning("%s: %s", type(e).__name__, e)
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
return None
except text_format.ParseError as e:
- logging.warning("%s: %s" % (type(e).__name__, e))
+ logging.warning("%s: %s", type(e).__name__, e)
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
return None
finally:
diff --git a/tensorflow/tensorboard/backend/application_test.py b/tensorflow/tensorboard/backend/application_test.py
index 6a9f7e6079..4ea627def7 100644
--- a/tensorflow/tensorboard/backend/application_test.py
+++ b/tensorflow/tensorboard/backend/application_test.py
@@ -230,13 +230,15 @@ class TensorboardServerTest(test.TestCase):
def testScalars(self):
"""Test the format of /data/scalars."""
data = self._getJson('/data/scalars?run=run1&tag=simple_values')
- self.assertEqual(len(data),self._SCALAR_COUNT)
+ self.assertEqual(len(data), self._SCALAR_COUNT)
def testScalarsCsv(self):
"""Test the csv format of /data/scalars."""
- data = self._get('/data/scalars?run=run1&tag=simple_values&format=csv').read()
+ data = self._get(
+ '/data/scalars?run=run1&tag=simple_values&format=csv').read()
line_count = data.count('\n')
- self.assertEqual(line_count,self._SCALAR_COUNT + 1) # include 1 more line for header
+ self.assertEqual(line_count,
+ self._SCALAR_COUNT + 1) # include 1 more line for header
def testHistograms(self):
"""Test the format of /data/histograms."""
diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD b/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD
new file mode 100644
index 0000000000..004bece78b
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD
@@ -0,0 +1,63 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "tf_audio_dashboard",
+ srcs = [
+ "tf-audio-dashboard.html",
+ "tf-audio-grid.html",
+ "tf-audio-loader.html",
+ ],
+ path = "/tf-audio-dashboard",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_backend",
+ "//tensorflow/tensorboard/components/tf_dashboard_common",
+ "//tensorflow/tensorboard/components/tf_imports:lodash",
+ "@org_polymer",
+ "@org_polymer_paper_icon_button",
+ "@org_polymer_paper_slider",
+ "@org_polymer_paper_spinner",
+ "@org_polymer_paper_styles",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "tf-audio-dashboard.html",
+ "tf-audio-grid.html",
+ "tf-audio-loader.html",
+ ],
+ destdir = "tf-audio-dashboard",
+ deps = [
+ "//tensorflow/tensorboard/components:tf_imports",
+ "//tensorflow/tensorboard/components/tf_backend:legacy",
+ "//tensorflow/tensorboard/components/tf_dashboard_common:legacy",
+ "//third_party/javascript/polymer/v1/paper-icon-button:lib",
+ "//third_party/javascript/polymer/v1/paper-styles:lib",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+# This is needed: components/BUILD seeks a legacy_ts rule in this package.
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = [],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+ deps = ["//tensorflow/tensorboard/components:common_deps"],
+)
diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_audio_dashboard/demo/BUILD
new file mode 100644
index 0000000000..383ea8d1b6
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_audio_dashboard/demo/BUILD
@@ -0,0 +1,26 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+# bazel run //third_party/tensorflow/tensorboard/components/tf_audio_dashboard/demo
+webfiles(
+ name = "demo",
+ srcs = ["index.html"],
+ path = "/tf-audio-dashboard/demo",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_audio_dashboard",
+ "//tensorflow/tensorboard/components/tf_audio_dashboard/demo/data",
+ "//tensorflow/tensorboard/components/tf_imports:d3",
+ "@org_polymer_iron_demo_helpers",
+ "@org_polymer_paper_styles",
+ "@org_polymer_webcomponentsjs",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/demo/data/BUILD b/tensorflow/tensorboard/components/tf_audio_dashboard/demo/data/BUILD
new file mode 100644
index 0000000000..c3824a923d
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_audio_dashboard/demo/data/BUILD
@@ -0,0 +1,17 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "data",
+ srcs = glob(["*"]),
+ path = "/tf-audio-dashboard/demo/data",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-loader.html b/tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-loader.html
index ed3b5efa07..71539537d0 100644
--- a/tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-loader.html
+++ b/tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-loader.html
@@ -107,6 +107,8 @@ future for loading older clips.
</template>
</template>
<script>
+ "use strict";
+
Polymer({
is: "tf-audio-loader",
properties: {
diff --git a/tensorflow/tensorboard/components/tf_backend/BUILD b/tensorflow/tensorboard/components/tf_backend/BUILD
new file mode 100644
index 0000000000..66fc429c60
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_backend/BUILD
@@ -0,0 +1,81 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+# TODO(dandelion): Add webfiles support for the test code.
+
+webfiles(
+ name = "tf_backend",
+ srcs = [
+ "tf-backend.html",
+ ":ts",
+ ],
+ path = "/tf-backend",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_imports:d3",
+ "//tensorflow/tensorboard/components/tf_imports:lodash",
+ "//tensorflow/tensorboard/components/vz_sorting",
+ "@org_polymer",
+ ],
+)
+
+tensorboard_typescript_genrule(
+ name = "ts",
+ srcs = [
+ "backend.ts",
+ "behavior.ts",
+ "requestManager.ts",
+ "router.ts",
+ "urlPathHelpers.ts",
+ ],
+ typings = [
+ "@org_definitelytyped//:d3.d.ts",
+ "@org_definitelytyped//:lodash.d.ts",
+ "//tensorflow/tensorboard/components/vz_sorting:ts_typings",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "tf-backend.html",
+ ":legacy_ts",
+ ],
+ visibility = ["//visibility:public"],
+ destdir = "tf-backend",
+ deps = [
+ "//tensorflow/tensorboard/components:tf_imports",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = [
+ "backend.ts",
+ "behavior.ts",
+ "requestManager.ts",
+ "router.ts",
+ "urlPathHelpers.ts",
+ ],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+ deps = [
+ "//tensorflow/tensorboard/components:common_deps",
+ "//tensorflow/tensorboard/components/vz_sorting:legacy_ts",
+ ],
+)
diff --git a/tensorflow/tensorboard/components/tf_backend_d3v4/BUILD b/tensorflow/tensorboard/components/tf_backend_d3v4/BUILD
new file mode 100644
index 0000000000..e54d0e222c
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_backend_d3v4/BUILD
@@ -0,0 +1,45 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+
+licenses(["notice"]) # Apache 2.0
+
+tensorboard_ts_library(
+ name = "ts",
+ srcs = [
+ "backend.ts",
+ "behavior.ts",
+ "requestManager.ts",
+ "router.ts",
+ "urlPathHelpers.ts",
+ ],
+ deps = [
+ "//tensorflow/tensorboard/components/vz_sorting_d3v4:ts",
+ "//third_party/javascript/node_modules/typescript:es2015.promise",
+ "//third_party/javascript/typings/chai",
+ "//third_party/javascript/typings/d3_v4:bundle",
+ "//third_party/javascript/typings/lodash",
+ "//third_party/javascript/typings/mocha",
+ "//third_party/javascript/typings/polymer:polymer_without_externs",
+ "//third_party/javascript/typings/sinon",
+ ],
+)
+
+# TODO(dandelion): Add runners for these tests
+tensorboard_ts_library(
+ name = "tests",
+ srcs = [
+ "backendTests.ts",
+ "behaviorTests.ts",
+ "requestManagerTests.ts",
+ ],
+ deps = [
+ ":ts",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_color_scale/BUILD b/tensorflow/tensorboard/components/tf_color_scale/BUILD
new file mode 100644
index 0000000000..52102577b4
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_color_scale/BUILD
@@ -0,0 +1,65 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+# TODO(dandelion): Add webfiles support for the test code.
+
+webfiles(
+ name = "tf_color_scale",
+ srcs = [
+ "tf-color-scale.html",
+ ":ts",
+ ],
+ path = "/tf-color-scale",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_imports:d3",
+ "@org_polymer",
+ ],
+)
+
+tensorboard_typescript_genrule(
+ name = "ts",
+ srcs = [
+ "colorScale.ts",
+ "palettes.ts",
+ ],
+ typings = ["@org_definitelytyped//:d3.d.ts"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "tf-color-scale.html",
+ ":legacy_ts",
+ ],
+ destdir = "tf-color-scale",
+ deps = [
+ "//tensorflow/tensorboard/components:tf_imports",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = [
+ "colorScale.ts",
+ "palettes.ts",
+ ],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+ deps = ["//tensorflow/tensorboard/components:common_deps"],
+)
diff --git a/tensorflow/tensorboard/components/tf_color_scale/demo/BUILD b/tensorflow/tensorboard/components/tf_color_scale/demo/BUILD
new file mode 100644
index 0000000000..00b8a033b8
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_color_scale/demo/BUILD
@@ -0,0 +1,26 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+# bazel run //third_party/tensorflow/tensorboard/components/tf_color_scale/demo
+webfiles(
+ name = "demo",
+ srcs = ["index.html"],
+ path = "/tf-color-scale/demo",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_color_scale",
+ "//tensorflow/tensorboard/components/tf_imports:d3",
+ "@org_polymer_iron_demo_helpers",
+ "@org_polymer_paper_button",
+ "@org_polymer_paper_styles",
+ "@org_polymer_webcomponentsjs",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_color_scale_d3v4/BUILD b/tensorflow/tensorboard/components/tf_color_scale_d3v4/BUILD
new file mode 100644
index 0000000000..fd7d394a36
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_color_scale_d3v4/BUILD
@@ -0,0 +1,72 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load(
+ "//tensorflow/tensorboard:defs.bzl",
+ "tensorboard_ts_development_sources",
+ "tensorboard_ts_devserver",
+ "tensorboard_ts_library",
+ "tensorboard_webcomponent_library",
+)
+
+licenses(["notice"]) # Apache 2.0
+
+# TODO(dandelion): Add runner for the test code.
+
+tensorboard_webcomponent_library(
+ name = "tf_color_scale",
+ srcs = ["tf-color-scale.html"],
+ ts_lib_deps = [":ts"],
+ deps = [
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "ts",
+ srcs = [
+ "colorScale.ts",
+ "palettes.ts",
+ ],
+ deps = ["//tensorflow/tensorboard/components:common_deps_d3v4"],
+)
+
+tensorboard_ts_library(
+ name = "tests",
+ srcs = ["colorScaleTests.ts"],
+ deps = [
+ ":ts",
+ "//tensorflow/tensorboard/components:common_deps_d3v4",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+tensorboard_webcomponent_library(
+ name = "demo",
+ srcs = ["demo.html"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tf_color_scale",
+ "//third_party/javascript/polymer/v1/iron-demo-helpers:lib",
+ "//third_party/javascript/polymer/v1/paper-button:lib",
+ "//third_party/javascript/polymer/v1/paper-styles:lib",
+ ],
+)
+
+tensorboard_ts_devserver(
+ name = "devserver",
+ manifest = ":dev_sources",
+ serving_path = "/demo_out/bundle.js",
+ static_files = [":demo"],
+)
+
+tensorboard_ts_development_sources(
+ name = "dev_sources",
+ deps = [
+ ":ts",
+ ],
+)
diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/BUILD b/tensorflow/tensorboard/components/tf_dashboard_common/BUILD
new file mode 100644
index 0000000000..880e6bd712
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_dashboard_common/BUILD
@@ -0,0 +1,102 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "tf_dashboard_common",
+ srcs = glob(["*.html"]) + [
+ ":ts",
+ ],
+ path = "/tf-dashboard-common",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_imports:lodash",
+ "//tensorflow/tensorboard/components/tf_imports:plottable",
+ "//tensorflow/tensorboard/components/tf_storage",
+ "//tensorflow/tensorboard/components/vz_sorting",
+ "@org_polymer",
+ "@org_polymer_iron_ajax",
+ "@org_polymer_iron_collapse",
+ "@org_polymer_iron_icons",
+ "@org_polymer_paper_button",
+ "@org_polymer_paper_checkbox",
+ "@org_polymer_paper_dialog",
+ "@org_polymer_paper_dropdown_menu",
+ "@org_polymer_paper_icon_button",
+ "@org_polymer_paper_input",
+ "@org_polymer_paper_item",
+ "@org_polymer_paper_menu",
+ "@org_polymer_paper_slider",
+ "@org_polymer_paper_spinner",
+ "@org_polymer_paper_styles",
+ "@org_polymer_paper_toggle_button",
+ ],
+)
+
+tensorboard_typescript_genrule(
+ name = "ts",
+ srcs = [
+ "categorizer.ts",
+ "dashboard-behavior.ts",
+ "reload-behavior.ts",
+ ],
+ typings = [
+ "@org_definitelytyped//:d3.d.ts",
+ "//tensorflow/tensorboard/components/vz_sorting:ts_typings",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = glob(["*.html"]) + [":legacy_ts"],
+ destdir = "tf-dashboard-common",
+ deps = [
+ "//tensorflow/tensorboard/components:tf_imports",
+ "//tensorflow/tensorboard/components/tf_storage:legacy",
+ "//tensorflow/tensorboard/components/vz_sorting:legacy",
+ "//third_party/javascript/polymer/v1/iron-ajax:lib",
+ "//third_party/javascript/polymer/v1/iron-collapse:lib",
+ "//third_party/javascript/polymer/v1/iron-icons:lib",
+ "//third_party/javascript/polymer/v1/paper-button:lib",
+ "//third_party/javascript/polymer/v1/paper-checkbox:lib",
+ "//third_party/javascript/polymer/v1/paper-dialog:lib",
+ "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib",
+ "//third_party/javascript/polymer/v1/paper-icon-button:lib",
+ "//third_party/javascript/polymer/v1/paper-input:lib",
+ "//third_party/javascript/polymer/v1/paper-item:lib",
+ "//third_party/javascript/polymer/v1/paper-menu:lib",
+ "//third_party/javascript/polymer/v1/paper-slider:lib",
+ "//third_party/javascript/polymer/v1/paper-spinner:lib",
+ "//third_party/javascript/polymer/v1/paper-styles:lib",
+ "//third_party/javascript/polymer/v1/paper-toggle-button:lib",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = [
+ "categorizer.ts",
+ "dashboard-behavior.ts",
+ "reload-behavior.ts",
+ ],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+ deps = [
+ "//tensorflow/tensorboard/components:common_deps",
+ "//tensorflow/tensorboard/components/vz_sorting:legacy_ts",
+ ],
+)
diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/demo/BUILD b/tensorflow/tensorboard/components/tf_dashboard_common/demo/BUILD
new file mode 100644
index 0000000000..05cfe34e72
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_dashboard_common/demo/BUILD
@@ -0,0 +1,31 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+# bazel run //third_party/tensorflow/tensorboard/components/tf_dashboard_common/demo
+webfiles(
+ name = "demo",
+ srcs = [
+ "tf-categorizer-demo.html",
+ "tf-collapsable-pane-demo.html",
+ "tf-multi-checkbox-demo.html",
+ "tf-regex-group-demo.html",
+ ],
+ path = "/tf-dashboard-common/demo",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_color_scale",
+ "//tensorflow/tensorboard/components/tf_dashboard_common",
+ "//tensorflow/tensorboard/components/tf_imports:d3",
+ "@org_polymer_iron_flex_layout",
+ "@org_polymer_paper_styles",
+ "@org_polymer_webcomponentsjs",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-chart-scaffold.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-chart-scaffold.html
index e2530d5971..83b141cb98 100644
--- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-chart-scaffold.html
+++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-chart-scaffold.html
@@ -57,6 +57,8 @@ plugin is requred to implement two functions:
</style>
</template>
<script>
+ "use strict";
+
Polymer({
is: "tf-chart-scaffold",
properties: {
diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/BUILD b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/BUILD
new file mode 100644
index 0000000000..85f06dabb5
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/BUILD
@@ -0,0 +1,114 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load(
+ "//tensorflow/tensorboard:defs.bzl",
+ "tensorboard_ts_development_sources",
+ "tensorboard_ts_devserver",
+ "tensorboard_ts_library",
+ "tensorboard_webcomponent_library",
+)
+
+licenses(["notice"]) # Apache 2.0
+
+tensorboard_webcomponent_library(
+ name = "tf_dashboard_common",
+ srcs = [
+ "dashboard-style.html",
+ "run-color-style.html",
+ "scrollbar-style.html",
+ "tensorboard-color.html",
+ "tf-categorizer.html",
+ "tf-collapsable-pane.html",
+ "tf-dashboard.html",
+ "tf-dashboard-layout.html",
+ "tf-downloader.html",
+ "tf-multi-checkbox.html",
+ "tf-no-data-warning.html",
+ "tf-option-selector.html",
+ "tf-panes-helper.html",
+ "tf-regex-group.html",
+ "tf-run-selector.html",
+ "tf-sidebar-helper.html",
+ ],
+ ts_lib_deps = [":ts"],
+ deps = [
+ "//third_party/javascript/plottable/v3:lib",
+ "//third_party/javascript/polymer/v1/iron-ajax:lib",
+ "//third_party/javascript/polymer/v1/iron-collapse:lib",
+ "//third_party/javascript/polymer/v1/iron-icons:lib",
+ "//third_party/javascript/polymer/v1/paper-button:lib",
+ "//third_party/javascript/polymer/v1/paper-checkbox:lib",
+ "//third_party/javascript/polymer/v1/paper-dialog:lib",
+ "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib",
+ "//third_party/javascript/polymer/v1/paper-icon-button:lib",
+ "//third_party/javascript/polymer/v1/paper-input:lib",
+ "//third_party/javascript/polymer/v1/paper-item:lib",
+ "//third_party/javascript/polymer/v1/paper-menu:lib",
+ "//third_party/javascript/polymer/v1/paper-slider:lib",
+ "//third_party/javascript/polymer/v1/paper-spinner:lib",
+ "//third_party/javascript/polymer/v1/paper-styles:lib",
+ "//third_party/javascript/polymer/v1/paper-toggle-button:lib",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "ts",
+ srcs = [
+ "dashboard-behavior.ts",
+ "reload-behavior.ts",
+ "tf-categorizer.ts",
+ "tf-multi-checkbox.ts",
+ "tf-regex-group.ts",
+ ],
+ deps = [
+ "//tensorflow/tensorboard/components:common_deps_d3v4",
+ "//tensorflow/tensorboard/components/tf_storage_d3v4:ts",
+ "//tensorflow/tensorboard/components/vz_sorting_d3v4:ts",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "tests",
+ srcs = ["tf-categorizer-tests.ts"],
+ deps = [
+ ":ts",
+ "//tensorflow/tensorboard/components:common_deps_d3v4",
+ ],
+)
+
+tensorboard_webcomponent_library(
+ name = "demo",
+ srcs = [
+ "tf-categorizer-demo.html",
+ "tf-collapsable-pane-demo.html",
+ "tf-multi-checkbox-demo.html",
+ "tf-regex-group-demo.html",
+ ],
+ deps = [
+ ":tf_dashboard_common",
+ "//tensorflow/tensorboard/components/tf_color_scale_d3v4:tf_color_scale",
+ "//third_party/javascript/polymer/v1/iron-demo-helpers:lib",
+ "//third_party/javascript/polymer/v1/paper-styles:lib",
+ ],
+)
+
+tensorboard_ts_devserver(
+ name = "devserver",
+ manifest = ":dev_sources",
+ serving_path = "/demo_out/bundle.js",
+ static_files = [":demo"],
+)
+
+tensorboard_ts_development_sources(
+ name = "dev_sources",
+ deps = [
+ ":ts",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-chart-scaffold.html b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-chart-scaffold.html
index 402f909287..fd5320be39 100644
--- a/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-chart-scaffold.html
+++ b/tensorflow/tensorboard/components/tf_dashboard_common_d3v4/tf-chart-scaffold.html
@@ -55,6 +55,8 @@ plugin is requred to implement two functions:
</style>
</template>
<script>
+ "use strict";
+
Polymer({
is: "tf-chart-scaffold",
properties: {
diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD b/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD
new file mode 100644
index 0000000000..e71bbdea2b
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD
@@ -0,0 +1,63 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "tf_distribution_dashboard",
+ srcs = [
+ "tf-distribution-dashboard.html",
+ ],
+ path = "/tf-distribution-dashboard",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_backend",
+ "//tensorflow/tensorboard/components/tf_color_scale",
+ "//tensorflow/tensorboard/components/tf_dashboard_common",
+ "//tensorflow/tensorboard/components/tf_imports:lodash",
+ "//tensorflow/tensorboard/components/vz_distribution_chart",
+ "@org_polymer",
+ "@org_polymer_iron_collapse",
+ "@org_polymer_paper_icon_button",
+ "@org_polymer_paper_styles",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "tf-distribution-dashboard.html",
+ ":legacy_ts",
+ ],
+ destdir = "tf-distribution-dashboard",
+ deps = [
+ "//tensorflow/tensorboard/components:tf_imports",
+ "//tensorflow/tensorboard/components/tf_backend:legacy",
+ "//tensorflow/tensorboard/components/tf_dashboard_common:legacy",
+ "//tensorflow/tensorboard/components/vz_distribution_chart:legacy",
+ "//third_party/javascript/polymer/v1/iron-collapse:lib",
+ "//third_party/javascript/polymer/v1/paper-icon-button:lib",
+ "//third_party/javascript/polymer/v1/paper-styles:lib",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+# This is needed: components/BUILD seeks a legacy_ts rule in this package.
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = [],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+ deps = ["//tensorflow/tensorboard/components:common_deps"],
+)
diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_distribution_dashboard/demo/BUILD
new file mode 100644
index 0000000000..238937c0c2
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_distribution_dashboard/demo/BUILD
@@ -0,0 +1,26 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+# bazel run //third_party/tensorflow/tensorboard/components/tf_distribution_dashboard/demo
+webfiles(
+ name = "demo",
+ srcs = ["index.html"],
+ path = "/tf-distribution-dashboard/demo",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_distribution_dashboard",
+ "//tensorflow/tensorboard/components/tf_distribution_dashboard/demo/data",
+ "//tensorflow/tensorboard/components/tf_imports:d3",
+ "@org_polymer_iron_demo_helpers",
+ "@org_polymer_paper_styles",
+ "@org_polymer_webcomponentsjs",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/demo/data/BUILD b/tensorflow/tensorboard/components/tf_distribution_dashboard/demo/data/BUILD
new file mode 100644
index 0000000000..589c1980e4
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_distribution_dashboard/demo/data/BUILD
@@ -0,0 +1,17 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "data",
+ srcs = glob(["*"]),
+ path = "/tf-distribution-dashboard/demo/data",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_globals/BUILD b/tensorflow/tensorboard/components/tf_globals/BUILD
new file mode 100644
index 0000000000..2e5e4b0515
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_globals/BUILD
@@ -0,0 +1,49 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+# TODO(dandelion): Add webfiles support for the test code.
+
+webfiles(
+ name = "tf_globals",
+ srcs = [
+ "tf-globals.html",
+ ":ts",
+ ],
+ path = "/tf-globals",
+)
+
+tensorboard_typescript_genrule(
+ name = "ts",
+ srcs = ["globals.ts"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "tf-globals.html",
+ ":legacy_ts",
+ ],
+ destdir = "tf-globals",
+)
+
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = ["globals.ts"],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+)
diff --git a/tensorflow/tensorboard/components/tf_globals_d3v4/BUILD b/tensorflow/tensorboard/components/tf_globals_d3v4/BUILD
new file mode 100644
index 0000000000..202e72aaa6
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_globals_d3v4/BUILD
@@ -0,0 +1,16 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+
+licenses(["notice"]) # Apache 2.0
+
+tensorboard_ts_library(
+ name = "ts",
+ srcs = ["globals.ts"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_graph_common/BUILD b/tensorflow/tensorboard/components/tf_graph_common/BUILD
new file mode 100644
index 0000000000..772abbbd31
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_graph_common/BUILD
@@ -0,0 +1,65 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "tf_graph_common",
+ srcs = [
+ "tf-graph-common.html",
+ ":ts",
+ ],
+ path = "/tf-graph-common",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_imports:d3",
+ "//tensorflow/tensorboard/components/tf_imports:dagre",
+ "//tensorflow/tensorboard/components/tf_imports:graphlib",
+ "//tensorflow/tensorboard/components/tf_imports:lodash",
+ "@org_polymer",
+ ],
+)
+
+tensorboard_typescript_genrule(
+ name = "ts",
+ srcs = glob(["*.ts"]),
+ typings = [
+ "@org_definitelytyped//:d3.d.ts",
+ "@org_definitelytyped//:lodash.d.ts",
+ "@org_definitelytyped//:polymer.d.ts",
+ "@org_definitelytyped//:webcomponents.js.d.ts",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "tf-graph-common.html",
+ ":legacy_ts",
+ ],
+ destdir = "tf-graph-common",
+ deps = [
+ "//tensorflow/tensorboard/components:tf_imports",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = glob(["*.ts"]),
+ deps_mgmt = "off",
+ runtime = "nodejs",
+ deps = ["//tensorflow/tensorboard/components:common_deps"],
+)
diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html b/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html
index 573c3dfa60..d2b069f320 100644
--- a/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html
+++ b/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html
@@ -103,6 +103,8 @@ out-hierarchy-params="{{_hierarchyParams}}"
</dom-module>
<script>
+"use strict";
+
(function() {
TF.Dashboard.TfGraphDashboard = Polymer({
is: 'tf-graph-dashboard',
diff --git a/tensorflow/tensorboard/components/tf_graph_info/tf-graph-info.html b/tensorflow/tensorboard/components/tf_graph_info/tf-graph-info.html
index 45347fb1de..b33e1e00d0 100644
--- a/tensorflow/tensorboard/components/tf_graph_info/tf-graph-info.html
+++ b/tensorflow/tensorboard/components/tf_graph_info/tf-graph-info.html
@@ -169,6 +169,8 @@ h2 {
</template>
</template>
<script>
+"use strict";
+
(function() {
Polymer({
is: 'tf-graph-info',
diff --git a/tensorflow/tensorboard/components/tf_histogram_dashboard/BUILD b/tensorflow/tensorboard/components/tf_histogram_dashboard/BUILD
new file mode 100644
index 0000000000..c74eafa658
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_histogram_dashboard/BUILD
@@ -0,0 +1,62 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "tf_histogram_dashboard",
+ srcs = [
+ "tf-histogram-dashboard.html",
+ ],
+ path = "/tf-histogram-dashboard",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_backend",
+ "//tensorflow/tensorboard/components/tf_color_scale",
+ "//tensorflow/tensorboard/components/tf_dashboard_common",
+ "//tensorflow/tensorboard/components/tf_imports:lodash",
+ "//tensorflow/tensorboard/components/vz_histogram_timeseries",
+ "@org_polymer",
+ "@org_polymer_iron_collapse",
+ "@org_polymer_paper_icon_button",
+ "@org_polymer_paper_styles",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "tf-histogram-dashboard.html",
+ ":legacy_ts",
+ ],
+ destdir = "tf-histogram-dashboard",
+ deps = [
+ "//tensorflow/tensorboard/components:tf_imports",
+ "//tensorflow/tensorboard/components/tf_backend:legacy",
+ "//tensorflow/tensorboard/components/tf_dashboard_common:legacy",
+ "//tensorflow/tensorboard/components/vz_histogram_timeseries:legacy",
+ "//third_party/javascript/polymer/v1/iron-collapse:lib",
+ "//third_party/javascript/polymer/v1/paper-icon-button:lib",
+ "//third_party/javascript/polymer/v1/paper-styles:lib",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+# This is needed: components/BUILD seeks a legacy_ts rule in this package.
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = [],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+)
diff --git a/tensorflow/tensorboard/components/tf_histogram_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_histogram_dashboard/demo/BUILD
new file mode 100644
index 0000000000..8350084874
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_histogram_dashboard/demo/BUILD
@@ -0,0 +1,26 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+# bazel run //third_party/tensorflow/tensorboard/components/tf_histogram_dashboard/demo
+webfiles(
+ name = "demo",
+ srcs = ["index.html"],
+ path = "/tf-histogram-dashboard/demo",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_histogram_dashboard",
+ "//tensorflow/tensorboard/components/tf_histogram_dashboard/demo/data",
+ "//tensorflow/tensorboard/components/tf_imports:d3",
+ "@org_polymer_iron_demo_helpers",
+ "@org_polymer_paper_styles",
+ "@org_polymer_webcomponentsjs",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_histogram_dashboard/demo/data/BUILD b/tensorflow/tensorboard/components/tf_histogram_dashboard/demo/data/BUILD
new file mode 100644
index 0000000000..d396efab73
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_histogram_dashboard/demo/data/BUILD
@@ -0,0 +1,17 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "data",
+ srcs = glob(["*"]),
+ path = "/tf-histogram-dashboard/demo/data",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/BUILD b/tensorflow/tensorboard/components/tf_image_dashboard/BUILD
new file mode 100644
index 0000000000..b69fe6ba22
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_image_dashboard/BUILD
@@ -0,0 +1,59 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "tf_image_dashboard",
+ srcs = [
+ "tf-image-dashboard.html",
+ "tf-image-loader.html",
+ ],
+ path = "/tf-image-dashboard",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_backend",
+ "//tensorflow/tensorboard/components/tf_color_scale",
+ "//tensorflow/tensorboard/components/tf_dashboard_common",
+ "//tensorflow/tensorboard/components/tf_imports:d3",
+ "//tensorflow/tensorboard/components/tf_imports:lodash",
+ "@org_polymer",
+ "@org_polymer_paper_dialog",
+ "@org_polymer_paper_icon_button",
+ "@org_polymer_paper_slider",
+ "@org_polymer_paper_spinner",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "tf-image-dashboard.html",
+ "tf-image-loader.html",
+ ":legacy_ts",
+ ],
+ destdir = "tf-image-dashboard",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_backend:legacy",
+ "//tensorflow/tensorboard/components/tf_dashboard_common:legacy",
+ ],
+)
+
+# This is needed: components/BUILD seeks a legacy_ts rule in this package.
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = [],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+)
diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_image_dashboard/demo/BUILD
new file mode 100644
index 0000000000..3a42342ca0
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_image_dashboard/demo/BUILD
@@ -0,0 +1,25 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+# bazel run //third_party/tensorflow/tensorboard/components/tf_image_dashboard/demo
+webfiles(
+ name = "demo",
+ srcs = ["index.html"],
+ path = "/tf-image-dashboard/demo",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_image_dashboard",
+ "//tensorflow/tensorboard/components/tf_image_dashboard/demo/data",
+ "@org_polymer_iron_demo_helpers",
+ "@org_polymer_paper_styles",
+ "@org_polymer_webcomponentsjs",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/demo/data/BUILD b/tensorflow/tensorboard/components/tf_image_dashboard/demo/data/BUILD
new file mode 100644
index 0000000000..a613ac66c7
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_image_dashboard/demo/data/BUILD
@@ -0,0 +1,17 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "data",
+ srcs = glob(["*"]),
+ path = "/tf-image-dashboard/demo/data",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-loader.html b/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-loader.html
index f667520fb5..d9ba013dce 100644
--- a/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-loader.html
+++ b/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-loader.html
@@ -108,6 +108,8 @@ future for loading older images.
</style>
</template>
<script>
+ "use strict";
+
Polymer({
is: "tf-image-loader",
properties: {
diff --git a/tensorflow/tensorboard/components/tf_imports/BUILD b/tensorflow/tensorboard/components/tf_imports/BUILD
new file mode 100644
index 0000000000..b41a6bd446
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_imports/BUILD
@@ -0,0 +1,120 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "d3",
+ srcs = [
+ "d3.html",
+ "@org_d3js",
+ ],
+ path = "/tf-imports",
+)
+
+webfiles(
+ name = "lodash",
+ srcs = [
+ "lodash.html",
+ "@com_lodash",
+ ],
+ path = "/tf-imports",
+)
+
+webfiles(
+ name = "graphlib",
+ srcs = [
+ "graphlib.html",
+ "@io_github_cpettitt_graphlib",
+ ],
+ path = "/tf-imports",
+ deps = [":lodash"],
+)
+
+webfiles(
+ name = "dagre",
+ srcs = [
+ "dagre.html",
+ "@io_github_cpettitt_dagre",
+ ],
+ path = "/tf-imports",
+ deps = [
+ ":graphlib",
+ ":lodash",
+ ],
+)
+
+webfiles(
+ name = "plottable",
+ srcs = [
+ "plottable.html",
+ "@com_palantir_plottable//:plottable.css",
+ "@com_palantir_plottable//:plottable.js",
+ ],
+ path = "/tf-imports",
+ deps = [":d3"],
+)
+
+# Generate single TypeScript typings file for d3.js with no ES6 imports.
+#
+# The DefinitelyTyped definition of d3 v4 was written under the assumption that
+# we want to use d3 in a modularized way. We don't want to do that because its
+# import statements use NodeJS namespaces, and the Web Compiler only supports
+# W3C, ECMA, and IETF standards.
+genrule(
+ name = "d3v4_typings",
+ srcs = [
+ # please maintain a reverse topological order
+ "@org_definitelytyped_types_d3_path//:index.d.ts",
+ "@org_definitelytyped_types_d3_time//:index.d.ts",
+ "@org_definitelytyped_types_d3_dsv//:index.d.ts",
+ "@org_definitelytyped_types_d3_color//:index.d.ts",
+ "@org_definitelytyped_types_d3_selection//:index.d.ts",
+ # d3-transition defines stuff in d3-selection wat?
+ # "@org_definitelytyped_types_d3_transition//:index.d.ts",
+ "@org_definitelytyped_types_d3_shape//:index.d.ts",
+ "@org_definitelytyped_types_d3_scale//:index.d.ts",
+ "@org_definitelytyped_types_d3_request//:index.d.ts",
+ "@org_definitelytyped_types_d3_interpolate//:index.d.ts",
+ "@org_definitelytyped_types_d3_drag//:index.d.ts",
+ "@org_definitelytyped_types_d3_brush//:index.d.ts",
+ "@org_definitelytyped_types_d3_axis//:index.d.ts",
+ "@org_definitelytyped_types_d3_zoom//:index.d.ts",
+ "@org_definitelytyped_types_d3_array//:index.d.ts",
+ "@org_definitelytyped_types_d3_chord//:index.d.ts",
+ "@org_definitelytyped_types_d3_collection//:index.d.ts",
+ "@org_definitelytyped_types_d3_dispatch//:index.d.ts",
+ "@org_definitelytyped_types_d3_ease//:index.d.ts",
+ "@org_definitelytyped_types_d3_force//:index.d.ts",
+ "@org_definitelytyped_types_d3_format//:index.d.ts",
+ "@org_definitelytyped_types_d3_hierarchy//:index.d.ts",
+ "@org_definitelytyped_types_d3_polygon//:index.d.ts",
+ "@org_definitelytyped_types_d3_quadtree//:index.d.ts",
+ "@org_definitelytyped_types_d3_queue//:index.d.ts",
+ "@org_definitelytyped_types_d3_random//:index.d.ts",
+ "@org_definitelytyped_types_d3_timer//:index.d.ts",
+ "@org_definitelytyped_types_d3_voronoi//:index.d.ts",
+ ],
+ outs = ["d3v4.d.ts"],
+ cmd = "\n".join([
+ "(",
+ " echo 'declare namespace d3 {'",
+ " for f in $(SRCS); do",
+ " echo",
+ " echo /////////////////////////////////////////////",
+ " echo // $$f",
+ " echo /////////////////////////////////////////////",
+ " echo",
+ " sed '/^import /d' $$f",
+ " done",
+ " echo '}'",
+ ") >$@",
+ ]),
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_imports_d3v4/BUILD b/tensorflow/tensorboard/components/tf_imports_d3v4/BUILD
new file mode 100644
index 0000000000..23ecee759c
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_imports_d3v4/BUILD
@@ -0,0 +1,78 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "d3",
+ srcs = [
+ "d3.html",
+ "@org_d3js_v4",
+ ],
+ path = "/tf-imports-d3v4",
+)
+
+webfiles(
+ name = "lodash",
+ srcs = [
+ "lodash.html",
+ "@com_lodash",
+ ],
+ path = "/tf-imports-d3v4",
+)
+
+webfiles(
+ name = "graphlib",
+ srcs = [
+ "graphlib.html",
+ "@io_github_cpettitt_graphlib",
+ ],
+ path = "/tf-imports-d3v4",
+ deps = [":lodash"],
+)
+
+webfiles(
+ name = "dagre",
+ srcs = [
+ "dagre.html",
+ "@io_github_cpettitt_dagre",
+ ],
+ path = "/tf-imports-d3v4",
+ deps = [
+ ":graphlib",
+ ":lodash",
+ ],
+)
+
+webfiles(
+ name = "plottable",
+ srcs = [
+ "plottable.html",
+ # TODO(jart): Replace when Plottable v3 is fixed.
+ "plottable.js",
+ "plottable.css",
+ # "//third_party/javascript/plottable/v3:plottable.css",
+ # "//third_party/javascript/plottable/v3:plottable.js",
+ ],
+ path = "/tf-imports-d3v4",
+ deps = [":d3"],
+)
+
+genrule(
+ name = "TEMPORARY_fake_plottable_css",
+ outs = ["plottable.css"],
+ cmd = "echo >$@",
+)
+
+genrule(
+ name = "TEMPORARY_fake_plottable_js",
+ outs = ["plottable.js"],
+ cmd = "echo >$@",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD b/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD
new file mode 100644
index 0000000000..436d0c262d
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD
@@ -0,0 +1,77 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "tf_scalar_dashboard",
+ srcs = [
+ "tf-scalar-dashboard.html",
+ "tf-smoothing-input.html",
+ ],
+ path = "/tf-scalar-dashboard",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_backend",
+ "//tensorflow/tensorboard/components/tf_color_scale",
+ "//tensorflow/tensorboard/components/tf_dashboard_common",
+ "//tensorflow/tensorboard/components/tf_imports:lodash",
+ "//tensorflow/tensorboard/components/vz_line_chart",
+ "@org_polymer",
+ "@org_polymer_iron_collapse",
+ "@org_polymer_paper_checkbox",
+ "@org_polymer_paper_dropdown_menu",
+ "@org_polymer_paper_icon_button",
+ "@org_polymer_paper_input",
+ "@org_polymer_paper_item",
+ "@org_polymer_paper_menu",
+ "@org_polymer_paper_slider",
+ "@org_polymer_paper_styles",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "tf-scalar-dashboard.html",
+ "tf-smoothing-input.html",
+ ],
+ destdir = "tf-scalar-dashboard",
+ deps = [
+ "//tensorflow/tensorboard/components:tf_imports",
+ "//tensorflow/tensorboard/components/tf_backend:legacy",
+ "//tensorflow/tensorboard/components/tf_color_scale:legacy",
+ "//tensorflow/tensorboard/components/tf_dashboard_common:legacy",
+ "//tensorflow/tensorboard/components/vz_line_chart:legacy",
+ "//third_party/javascript/polymer/v1/iron-collapse:lib",
+ "//third_party/javascript/polymer/v1/paper-checkbox:lib",
+ "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib",
+ "//third_party/javascript/polymer/v1/paper-icon-button:lib",
+ "//third_party/javascript/polymer/v1/paper-input:lib",
+ "//third_party/javascript/polymer/v1/paper-item:lib",
+ "//third_party/javascript/polymer/v1/paper-menu:lib",
+ "//third_party/javascript/polymer/v1/paper-slider:lib",
+ "//third_party/javascript/polymer/v1/paper-styles:lib",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+# This is needed: components/BUILD seeks a legacy_ts rule in this package.
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = [],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+ deps = ["//tensorflow/tensorboard/components:common_deps"],
+)
diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD
new file mode 100644
index 0000000000..218fda3fdb
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD
@@ -0,0 +1,26 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+# bazel run //third_party/tensorflow/tensorboard/components/tf_scalar_dashboard/demo
+webfiles(
+ name = "demo",
+ srcs = ["index.html"],
+ path = "/tf-scalar-dashboard/demo",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_imports:d3",
+ "//tensorflow/tensorboard/components/tf_scalar_dashboard",
+ "//tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data",
+ "@org_polymer_iron_demo_helpers",
+ "@org_polymer_paper_styles",
+ "@org_polymer_webcomponentsjs",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/BUILD b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/BUILD
new file mode 100644
index 0000000000..7f39d27f60
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/BUILD
@@ -0,0 +1,17 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "data",
+ srcs = glob(["*"]),
+ path = "/tf-scalar-dashboard/demo/data",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_storage/BUILD b/tensorflow/tensorboard/components/tf_storage/BUILD
new file mode 100644
index 0000000000..e796340e03
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_storage/BUILD
@@ -0,0 +1,70 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+# TODO(dandelion): Add webfiles support for the test code.
+
+webfiles(
+ name = "tf_storage",
+ srcs = [
+ "tf-storage.html",
+ ":ts",
+ ],
+ path = "/tf-storage",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_globals",
+ "//tensorflow/tensorboard/components/tf_imports:lodash",
+ "@org_polymer",
+ ],
+)
+
+tensorboard_typescript_genrule(
+ name = "ts",
+ srcs = [
+ "storage.ts",
+ ],
+ typings = [
+ "@org_definitelytyped//:lodash.d.ts",
+ "//tensorflow/tensorboard/components/tf_globals:ts_typings",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "tf-storage.html",
+ ":legacy_ts",
+ ],
+ visibility = ["//visibility:public"],
+ destdir = "tf-storage",
+ deps = [
+ "//tensorflow/tensorboard/components:tf_imports",
+ "//tensorflow/tensorboard/components/tf_globals:legacy",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = ["storage.ts"],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+ deps = [
+ "//tensorflow/tensorboard/components:common_deps",
+ "//tensorflow/tensorboard/components/tf_globals:legacy_ts",
+ ],
+)
diff --git a/tensorflow/tensorboard/components/tf_storage_d3v4/BUILD b/tensorflow/tensorboard/components/tf_storage_d3v4/BUILD
new file mode 100644
index 0000000000..877f026319
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_storage_d3v4/BUILD
@@ -0,0 +1,35 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+
+licenses(["notice"]) # Apache 2.0
+
+# TODO(dandelion): Add test runner.
+
+tensorboard_ts_library(
+ name = "ts",
+ srcs = ["storage.ts"],
+ deps = [
+ "//tensorflow/tensorboard/components/tf_globals_d3v4:ts",
+ "//third_party/javascript/lodash:lodash-module",
+ "//third_party/javascript/typings/lodash",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "tests",
+ srcs = ["storageTests.ts"],
+ deps = [
+ ":ts",
+ "//tensorflow/tensorboard/components/tf_globals_d3v4:ts",
+ "//third_party/javascript/typings/chai",
+ "//third_party/javascript/typings/mocha",
+ "//third_party/javascript/typings/sinon",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html
index b5b2e2d5a8..7440263f88 100644
--- a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html
+++ b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html
@@ -210,6 +210,8 @@ allows the user to toggle between various dashboards.
</style>
</template>
<script>
+ "use strict";
+
Polymer({
is: "tf-tensorboard",
behaviors: [TF.TensorBoard.AutoReloadBehavior],
diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/BUILD b/tensorflow/tensorboard/components/tf_text_dashboard/BUILD
new file mode 100644
index 0000000000..71e16db691
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_text_dashboard/BUILD
@@ -0,0 +1,60 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "tf_text_dashboard",
+ srcs = [
+ "tf-text-dashboard.html",
+ "tf-text-loader.html",
+ ],
+ path = "/tf-text-dashboard",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_backend",
+ "//tensorflow/tensorboard/components/tf_color_scale",
+ "//tensorflow/tensorboard/components/tf_dashboard_common",
+ "//tensorflow/tensorboard/components/tf_imports:d3",
+ "//tensorflow/tensorboard/components/tf_imports:lodash",
+ "@org_polymer",
+ "@org_polymer_paper_dialog",
+ "@org_polymer_paper_icon_button",
+ "@org_polymer_paper_material",
+ "@org_polymer_paper_slider",
+ "@org_polymer_paper_spinner",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "tf-text-dashboard.html",
+ "tf-text-loader.html",
+ ],
+ destdir = "tf-text-dashboard",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_backend:legacy",
+ "//tensorflow/tensorboard/components/tf_dashboard_common:legacy",
+ "//third_party/javascript/polymer/v1/paper-material:lib",
+ ],
+)
+
+# This is needed: components/BUILD seeks a legacy_ts rule in this package.
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = [],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+)
diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_text_dashboard/demo/BUILD
new file mode 100644
index 0000000000..6cd6702e4b
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_text_dashboard/demo/BUILD
@@ -0,0 +1,25 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+# bazel run //third_party/tensorflow/tensorboard/components/tf_text_dashboard/demo
+webfiles(
+ name = "demo",
+ srcs = ["index.html"],
+ path = "/tf-text-dashboard/demo",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_text_dashboard",
+ "//tensorflow/tensorboard/components/tf_text_dashboard/demo/data",
+ "@org_polymer_iron_demo_helpers",
+ "@org_polymer_paper_styles",
+ "@org_polymer_webcomponentsjs",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/BUILD b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/BUILD
new file mode 100644
index 0000000000..8adf661396
--- /dev/null
+++ b/tensorflow/tensorboard/components/tf_text_dashboard/demo/data/BUILD
@@ -0,0 +1,17 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "data",
+ srcs = glob(["*"]),
+ path = "/tf-text-dashboard/demo/data/plugin/text",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/vz_data_summary/BUILD b/tensorflow/tensorboard/components/vz_data_summary/BUILD
new file mode 100644
index 0000000000..a4ba0c089c
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_data_summary/BUILD
@@ -0,0 +1,91 @@
+package(default_visibility = ["//visibility:public"])
+
+load(
+ "//tensorflow/tensorboard:defs.bzl",
+ "tensorboard_ts_config",
+ "tensorboard_ts_declaration",
+ "tensorboard_ts_development_sources",
+ "tensorboard_ts_devserver",
+ "tensorboard_ts_library",
+ "tensorboard_webcomponent_library",
+)
+
+licenses(["notice"]) # Apache 2.0
+
+tensorboard_webcomponent_library(
+ name = "lib",
+ srcs = ["vz-data-summary.html"],
+ ts_lib_deps = [":ts_lib"],
+ destdir = "vz-data-summary",
+ deps = [
+ "//learning/vis/vz_elements:common",
+ "//third_party/javascript/d3/v3:lib",
+ "//third_party/javascript/polymer/v1/iron-demo-helpers:lib",
+ "//third_party/javascript/polymer/v1/iron-resizable-behavior:lib",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "ts_lib",
+ srcs = ["vz-data-summary.ts"],
+ externs_list = [":externs"],
+ deps = [
+ ":typings",
+ "//third_party/javascript/typings/polymer:polymer_without_externs",
+ ],
+)
+
+tensorboard_ts_declaration(
+ name = "typings",
+ srcs = ["typings.d.ts"],
+)
+
+# This build rule is used to run the demo.
+tensorboard_ts_devserver(
+ name = "dev_server",
+ manifest = ":dev_sources",
+ serving_path = "/demo_lib_out/vz-data-summary/vz-data-summary.js",
+ static_files = [":demo_lib"],
+ deps = [":tsconfig"],
+)
+
+tensorboard_webcomponent_library(
+ name = "demo_lib",
+ srcs = ["demo.html"],
+ destdir = "vz-data-summary",
+ deps = [
+ ":lib",
+ "//third_party/javascript/d3/v3:lib",
+ "//third_party/javascript/polymer/v1/iron-demo-helpers:lib",
+ "//third_party/javascript/polymer/v1/iron-resizable-behavior:lib",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "demo_ts_lib",
+ srcs = ["demo.ts"],
+ externs_list = [":externs"],
+ deps = [
+ ":ts_lib",
+ "//third_party/javascript/typings/d3",
+ ],
+)
+
+tensorboard_ts_development_sources(
+ name = "dev_sources",
+ deps = [":demo_ts_lib"],
+)
+
+tensorboard_ts_config(
+ name = "tsconfig",
+ deps = [":ts_lib"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/tensorboard/components/vz_data_summary/BUILD.OPENSOURCE b/tensorflow/tensorboard/components/vz_data_summary/BUILD.OPENSOURCE
deleted file mode 100644
index 9743d70d94..0000000000
--- a/tensorflow/tensorboard/components/vz_data_summary/BUILD.OPENSOURCE
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# =============================================================================
-
-# Description:
-# Package for the data-summary vz-element.
-package(default_visibility = ["//tensorflow:internal"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tensorboard/components/vz_distribution_chart/BUILD b/tensorflow/tensorboard/components/vz_distribution_chart/BUILD
new file mode 100644
index 0000000000..5ce98b1d7c
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_distribution_chart/BUILD
@@ -0,0 +1,69 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "vz_distribution_chart",
+ srcs = [
+ "vz-distribution-chart.html",
+ ":ts",
+ ],
+ path = "/vz-distribution-chart",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_imports:lodash",
+ "//tensorflow/tensorboard/components/tf_imports:plottable",
+ "//tensorflow/tensorboard/components/vz_line_chart",
+ "@org_polymer",
+ ],
+)
+
+tensorboard_typescript_genrule(
+ name = "ts",
+ srcs = ["vz-distribution-chart.ts"],
+ typings = [
+ "@org_definitelytyped//:d3.d.ts",
+ "@com_palantir_plottable//:plottable.d.ts",
+ "@org_definitelytyped//:lodash.d.ts",
+ "//tensorflow/tensorboard/components/vz_line_chart:ts_typings",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "vz-distribution-chart.html",
+ ":legacy_ts",
+ ],
+ visibility = ["//visibility:public"],
+ destdir = "vz-distribution-chart",
+ deps = [
+ "//tensorflow/tensorboard/components:tf_imports",
+ "//tensorflow/tensorboard/components/vz_sorting:legacy",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = ["vz-distribution-chart.ts"],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+ deps = [
+ "//tensorflow/tensorboard/components:common_deps",
+ "//tensorflow/tensorboard/components/vz_line_chart:legacy_ts",
+ ],
+)
diff --git a/tensorflow/tensorboard/components/vz_distribution_chart/demo/BUILD b/tensorflow/tensorboard/components/vz_distribution_chart/demo/BUILD
new file mode 100644
index 0000000000..77f05aa2d4
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_distribution_chart/demo/BUILD
@@ -0,0 +1,24 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+# bazel run //third_party/tensorflow/tensorboard/components/vz_distribution_chart/demo
+webfiles(
+ name = "demo",
+ srcs = ["index.html"],
+ path = "/vz-distribution-chart/demo",
+ deps = [
+ "//tensorflow/tensorboard/components/vz_distribution_chart",
+ "@org_polymer_iron_demo_helpers",
+ "@org_polymer_paper_styles",
+ "@org_polymer_webcomponentsjs",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/vz_distribution_chart_d3v4/BUILD b/tensorflow/tensorboard/components/vz_distribution_chart_d3v4/BUILD
new file mode 100644
index 0000000000..384fc6f44c
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_distribution_chart_d3v4/BUILD
@@ -0,0 +1,66 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load(
+ "//tensorflow/tensorboard:defs.bzl",
+ "tensorboard_ts_development_sources",
+ "tensorboard_ts_devserver",
+ "tensorboard_ts_library",
+ "tensorboard_webcomponent_library",
+)
+
+licenses(["notice"]) # Apache 2.0
+
+tensorboard_ts_library(
+ name = "ts",
+ srcs = ["vz-distribution-chart.ts"],
+ deps = [
+ "//tensorflow/tensorboard/components:common_deps_d3v4",
+ "//tensorflow/tensorboard/components/vz_line_chart_d3v4:ts",
+ "//third_party/javascript/lodash:lodash-module",
+ "//third_party/javascript/plottable/v3:bundle",
+ ],
+)
+
+tensorboard_webcomponent_library(
+ name = "vz_distribution_chart",
+ srcs = [
+ "demo.html",
+ "vz-distribution-chart.html",
+ ],
+ ts_lib_deps = [":ts"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/tensorboard/components/vz_sorting:legacy",
+ "//third_party/javascript/plottable/v3:lib",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+tensorboard_webcomponent_library(
+ name = "demo",
+ srcs = ["demo.html"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":vz_distribution_chart",
+ "//third_party/javascript/polymer/v1/iron-demo-helpers:lib",
+ "//third_party/javascript/polymer/v1/paper-styles:lib",
+ ],
+)
+
+tensorboard_ts_devserver(
+ name = "devserver",
+ manifest = ":dev_sources",
+ serving_path = "/demo_out/bundle.js",
+ static_files = [":demo"],
+)
+
+tensorboard_ts_development_sources(
+ name = "dev_sources",
+ deps = [":ts"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD b/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD
new file mode 100644
index 0000000000..31032a6b23
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD
@@ -0,0 +1,50 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "vz_histogram_timeseries",
+ srcs = [
+ "vz-histogram-timeseries.html",
+ ],
+ path = "/vz-histogram-timeseries",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_imports:d3",
+ "@org_polymer",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "index.html",
+ "vz-histogram-timeseries.html",
+ ":legacy_ts",
+ ],
+ visibility = ["//visibility:public"],
+ destdir = "vz-histogram-timeseries",
+ deps = [
+ "//tensorflow/tensorboard/components:tf_imports",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+# This is needed: components/BUILD seeks a legacy_ts rule in this package.
+tensorboard_ts_library(
+ name = "legacy_ts",
+ deps_mgmt = "off",
+ runtime = "nodejs",
+)
diff --git a/tensorflow/tensorboard/components/vz_histogram_timeseries/demo/BUILD b/tensorflow/tensorboard/components/vz_histogram_timeseries/demo/BUILD
new file mode 100644
index 0000000000..894de95be6
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_histogram_timeseries/demo/BUILD
@@ -0,0 +1,25 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+# bazel run //third_party/tensorflow/tensorboard/components/vz_histogram_timeseries/demo
+webfiles(
+ name = "demo",
+ srcs = ["index.html"],
+ path = "/vz-histogram-timeseries/demo",
+ deps = [
+ "//tensorflow/tensorboard/components/vz_histogram_timeseries",
+ "@org_polymer_iron_demo_helpers",
+ "@org_polymer_paper_button",
+ "@org_polymer_paper_styles",
+ "@org_polymer_webcomponentsjs",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/vz_histogram_timeseries_d3v4/BUILD b/tensorflow/tensorboard/components/vz_histogram_timeseries_d3v4/BUILD
new file mode 100644
index 0000000000..02c0640742
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_histogram_timeseries_d3v4/BUILD
@@ -0,0 +1,63 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "vz_histogram_timeseries",
+ srcs = [
+ "vz-histogram-timeseries.html",
+ ],
+ path = "/vz-histogram-timeseries",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_imports_d3v4:d3",
+ "@org_polymer",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+# bazel run //third_party/tensorflow/tensorboard/components/vz_histogram_timeseries/demo
+webfiles(
+ name = "demo",
+ srcs = ["demo.html"],
+ path = "/vz-histogram-timeseries",
+ deps = [
+ ":vz_histogram_timeseries",
+ "@org_polymer_iron_demo_helpers",
+ "@org_polymer_paper_button",
+ "@org_polymer_paper_styles",
+ "@org_polymer_webcomponentsjs",
+ ],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "vz-histogram-timeseries.html",
+ ":legacy_ts",
+ ],
+ visibility = ["//visibility:public"],
+ destdir = "vz-histogram-timeseries",
+ deps = [
+ "//tensorflow/tensorboard/components:tf_imports",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+# This is needed: components/BUILD seeks a legacy_ts rule in this package.
+tensorboard_ts_library(
+ name = "legacy_ts",
+ deps_mgmt = "off",
+ runtime = "nodejs",
+)
diff --git a/tensorflow/tensorboard/components/vz_line_chart/BUILD b/tensorflow/tensorboard/components/vz_line_chart/BUILD
new file mode 100644
index 0000000000..7b13abd721
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_line_chart/BUILD
@@ -0,0 +1,73 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "vz_line_chart",
+ srcs = [
+ "vz-line-chart.html",
+ ":ts",
+ ],
+ path = "/vz-line-chart",
+ deps = [
+ "//tensorflow/tensorboard/components/tf_imports:lodash",
+ "//tensorflow/tensorboard/components/tf_imports:plottable",
+ "@org_polymer",
+ ],
+)
+
+tensorboard_typescript_genrule(
+ name = "ts",
+ srcs = [
+ "dragZoomInteraction.ts",
+ "vz-chart-helpers.ts",
+ "vz-line-chart.ts",
+ ],
+ typings = [
+ "@org_definitelytyped//:d3.d.ts",
+ "@com_palantir_plottable//:plottable.d.ts",
+ "@org_definitelytyped//:lodash.d.ts",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "index.html",
+ "vz-line-chart.html",
+ ":legacy_ts",
+ ],
+ visibility = ["//visibility:public"],
+ destdir = "vz-line-chart",
+ deps = [
+ "//tensorflow/tensorboard/components:tf_imports",
+ "//tensorflow/tensorboard/components/vz_sorting:legacy",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = [
+ "dragZoomInteraction.ts",
+ "vz-chart-helpers.ts",
+ "vz-line-chart.ts",
+ ],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+ deps = ["//tensorflow/tensorboard/components:common_deps"],
+)
diff --git a/tensorflow/tensorboard/components/vz_line_chart/demo/BUILD b/tensorflow/tensorboard/components/vz_line_chart/demo/BUILD
new file mode 100644
index 0000000000..84699b67b6
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_line_chart/demo/BUILD
@@ -0,0 +1,24 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+
+licenses(["notice"]) # Apache 2.0
+
+# bazel run //third_party/tensorflow/tensorboard/components/vz_line_chart/demo
+webfiles(
+ name = "demo",
+ srcs = ["index.html"],
+ path = "/vz-line-chart/demo",
+ deps = [
+ "//tensorflow/tensorboard/components/vz_line_chart",
+ "@org_polymer_iron_demo_helpers",
+ "@org_polymer_paper_styles",
+ "@org_polymer_webcomponentsjs",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/vz_line_chart_d3v4/BUILD b/tensorflow/tensorboard/components/vz_line_chart_d3v4/BUILD
new file mode 100644
index 0000000000..fd460b3608
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_line_chart_d3v4/BUILD
@@ -0,0 +1,66 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load(
+ "//tensorflow/tensorboard:defs.bzl",
+ "tensorboard_ts_development_sources",
+ "tensorboard_ts_devserver",
+ "tensorboard_ts_library",
+ "tensorboard_webcomponent_library",
+)
+
+licenses(["notice"]) # Apache 2.0
+
+tensorboard_ts_library(
+ name = "ts",
+ srcs = [
+ "dragZoomInteraction.ts",
+ "vz-chart-helpers.ts",
+ "vz-line-chart.ts",
+ ],
+ deps = [
+ "//tensorflow/tensorboard/components:common_deps_d3v4",
+ "//third_party/javascript/lodash:lodash-module",
+ "//third_party/javascript/plottable/v3:bundle",
+ ],
+)
+
+tensorboard_webcomponent_library(
+ name = "vz_line_chart",
+ srcs = ["vz-line-chart.html"],
+ ts_lib_deps = [":ts"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/tensorboard/components/vz_sorting:legacy",
+ "//third_party/javascript/plottable/v3:lib",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+ ],
+)
+
+tensorboard_webcomponent_library(
+ name = "demo",
+ srcs = ["demo.html"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":vz_line_chart",
+ "//third_party/javascript/polymer/v1/iron-demo-helpers:lib",
+ "//third_party/javascript/polymer/v1/paper-styles:lib",
+ ],
+)
+
+tensorboard_ts_devserver(
+ name = "devserver",
+ manifest = ":dev_sources",
+ serving_path = "/demo_out/bundle.js",
+ static_files = [":demo"],
+)
+
+tensorboard_ts_development_sources(
+ name = "dev_sources",
+ deps = [":ts"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/components/vz_projector/BUILD b/tensorflow/tensorboard/components/vz_projector/BUILD
new file mode 100644
index 0000000000..0ea94b9a22
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector/BUILD
@@ -0,0 +1,180 @@
+package(default_visibility = [":projector_group"])
+
+load(
+ "//tensorflow/tensorboard:defs.bzl",
+ "tensorboard_karma_web_test_suite",
+ "tensorboard_ts_config",
+ "tensorboard_ts_declaration",
+ "tensorboard_ts_development_sources",
+ "tensorboard_ts_library",
+ "tensorboard_webcomponent_library",
+)
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+# Visibility group for all clients of projector.
+package_group(
+ name = "projector_group",
+ packages = [
+ "//apps/labs/towerbridge/...",
+ "//experimental/bigpicture/projector/...",
+ "//java/com/google/apps/labs/towerbridge/...",
+ "//learning/vis/projector/...",
+ "//tensorflow/tensorboard/...",
+ ],
+)
+
+tensorboard_ts_declaration(
+ name = "external",
+ srcs = ["external.d.ts"],
+)
+
+tensorboard_ts_library(
+ name = "ts_lib",
+ srcs = glob(
+ ["*.ts"],
+ exclude = [
+ "*.d.ts",
+ "*_test.ts",
+ "bh_tsne.ts",
+ "sptree.ts",
+ ],
+ ),
+ runtime_deps = [
+ "//third_party/javascript/d3/v3:d3",
+ "//third_party/javascript/numericjs",
+ "//third_party/javascript/threejs/r77:threejs",
+ "//third_party/javascript/threejs/r77/examples/js/controls:orbitcontrols",
+ "//third_party/javascript/weblas",
+ ],
+ deps = [
+ ":external",
+ ":tsne_ts_lib",
+ "//third_party/javascript/node_modules/typescript:es2015.promise",
+ "//third_party/javascript/typings/d3",
+ "//third_party/javascript/typings/polymer:polymer_without_externs",
+ "//third_party/javascript/typings/threejs:three",
+ "//third_party/javascript/typings/webcomponents_js",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "tsne_ts_lib",
+ srcs = [
+ "bh_tsne.ts",
+ "sptree.ts",
+ ],
+)
+
+_PROJECTOR_LIB_DEPS = [
+ "//third_party/javascript/polymer/v1/iron-collapse:lib",
+ "//third_party/javascript/polymer/v1/iron-icons:lib",
+ "//third_party/javascript/polymer/v1/paper-button:lib",
+ "//third_party/javascript/polymer/v1/paper-checkbox:lib",
+ "//third_party/javascript/polymer/v1/paper-dialog:lib",
+ "//third_party/javascript/polymer/v1/paper-dialog-scrollable:lib",
+ "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib",
+ "//third_party/javascript/polymer/v1/paper-icon-button:lib",
+ "//third_party/javascript/polymer/v1/paper-input:lib",
+ "//third_party/javascript/polymer/v1/paper-item:lib",
+ "//third_party/javascript/polymer/v1/paper-listbox:lib",
+ "//third_party/javascript/polymer/v1/paper-slider:lib",
+ "//third_party/javascript/polymer/v1/paper-spinner:lib",
+ "//third_party/javascript/polymer/v1/paper-toast:lib",
+ "//third_party/javascript/polymer/v1/paper-toggle-button:lib",
+ "//third_party/javascript/polymer/v1/paper-tooltip:lib",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+]
+
+_PROJECTOR_DESTDIR = "vz-projector"
+
+_PROJECTOR_LIB_TS_LIB_DEPS = [
+ ":ts_lib",
+ ":tsne_ts_lib",
+]
+
+# Standalone embedding projector demos should depend on this target. We
+# exclude the HTML file for the dashboard itself. Demos do not need that
+# HTML file. This was introduced because standalone demos as of today
+# have an additional Closure pass that uses a compilation configuration
+# stricter than that of TensorBoard.
+tensorboard_webcomponent_library(
+ name = "lib",
+ srcs = glob(
+ ["*.html"],
+ exclude = ["vz-projector-dashboard.html"],
+ ),
+ ts_lib_deps = _PROJECTOR_LIB_TS_LIB_DEPS,
+ destdir = _PROJECTOR_DESTDIR,
+ deps = _PROJECTOR_LIB_DEPS,
+)
+
+# TensorBoard, however, should depend on this target, which includes
+# the HTML file for the dashboard.
+tensorboard_webcomponent_library(
+ name = "lib_for_tensorboard",
+ srcs = glob(["*.html"]),
+ ts_lib_deps = _PROJECTOR_LIB_TS_LIB_DEPS,
+ destdir = _PROJECTOR_DESTDIR,
+ deps = _PROJECTOR_LIB_DEPS,
+)
+
+### Tests ###
+
+tensorboard_ts_library(
+ name = "ts_test",
+ testonly = 1,
+ srcs = glob(["*_test.ts"]),
+ runtime_deps = [
+ "//third_party/javascript/polymer/v1/polymer:lib_all_js",
+ ],
+ deps = [
+ ":ts_lib",
+ ":tsne_ts_lib",
+ "//third_party/javascript/typings/chai",
+ "//third_party/javascript/typings/jasmine:jasmine_without_externs",
+ "//third_party/javascript/typings/mocha",
+ ],
+)
+
+tensorboard_ts_development_sources(
+ name = "dev_sources_for_test",
+ testonly = 1,
+ runtime_deps = [
+ "//third_party/javascript/chai",
+ "//third_party/javascript/mocha",
+ ],
+ deps = [
+ ":ts_test",
+ ],
+)
+
+# To run locally, run :all_tests_local
+tensorboard_karma_web_test_suite(
+ name = "all_tests",
+ size = "medium",
+ browsers = ["//testing/web/browsers:chrome-linux"],
+ manifest = ":dev_sources_for_test",
+)
+
+# Generate a TypeScript IDE project by running this target.
+tensorboard_ts_config(
+ name = "tsconfig",
+ deps = [
+ ":ts_lib",
+ ":ts_test",
+ ":tsne_ts_lib",
+ ],
+)
diff --git a/tensorflow/tensorboard/components/vz_projector/BUILD.OPENSOURCE b/tensorflow/tensorboard/components/vz_projector/BUILD.OPENSOURCE
deleted file mode 100644
index 8c222be10e..0000000000
--- a/tensorflow/tensorboard/components/vz_projector/BUILD.OPENSOURCE
+++ /dev/null
@@ -1,19 +0,0 @@
-# Description:
-# Package for the Embedding Projector component.
-package(default_visibility = ["//tensorflow:internal"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html
index 3857113ac0..55c15da5ed 100644
--- a/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html
+++ b/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html
@@ -37,6 +37,8 @@ limitations under the License.
</template>
</template>
<script>
+"use strict";
+
(function() {
TF.Dashboard.VzProjectorDashboard = Polymer({
is: 'vz-projector-dashboard',
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/BUILD b/tensorflow/tensorboard/components/vz_projector_d3v4/BUILD
new file mode 100644
index 0000000000..d9703d080d
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/BUILD
@@ -0,0 +1,179 @@
+package(default_visibility = [":projector_group"])
+
+load(
+ "//tensorflow/tensorboard:defs.bzl",
+ "tensorboard_karma_web_test_suite",
+ "tensorboard_ts_config",
+ "tensorboard_ts_declaration",
+ "tensorboard_ts_development_sources",
+ "tensorboard_ts_library",
+ "tensorboard_webcomponent_library",
+)
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**"],
+ exclude = [
+ "OWNERS",
+ "tsconfig.json",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+# Visibility group for all clients of projector.
+package_group(
+ name = "projector_group",
+ packages = [
+ "//apps/labs/towerbridge/...",
+ "//experimental/bigpicture/projector/...",
+ "//java/com/google/apps/labs/towerbridge/...",
+ "//learning/vis/projector/...",
+ "//tensorflow/tensorboard/...",
+ ],
+)
+
+tensorboard_ts_declaration(
+ name = "external",
+ srcs = ["external.d.ts"],
+)
+
+tensorboard_ts_library(
+ name = "ts_lib",
+ srcs = glob(
+ ["*.ts"],
+ exclude = [
+ "*.d.ts",
+ "*_test.ts",
+ "bh_tsne.ts",
+ "sptree.ts",
+ ],
+ ),
+ runtime_deps = [
+ "//third_party/javascript/numericjs",
+ "//third_party/javascript/threejs/r77:threejs",
+ "//third_party/javascript/threejs/r77/examples/js/controls:orbitcontrols",
+ "//third_party/javascript/weblas",
+ ],
+ deps = [
+ ":external",
+ ":tsne_ts_lib",
+ "//third_party/javascript/node_modules/typescript:es2015.promise",
+ "//third_party/javascript/typings/d3_v4:bundle",
+ "//third_party/javascript/typings/polymer:polymer_without_externs",
+ "//third_party/javascript/typings/threejs:three",
+ "//third_party/javascript/typings/webcomponents_js",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "tsne_ts_lib",
+ srcs = [
+ "bh_tsne.ts",
+ "sptree.ts",
+ ],
+)
+
+_PROJECTOR_LIB_DEPS = [
+ "//third_party/javascript/polymer/v1/iron-collapse:lib",
+ "//third_party/javascript/polymer/v1/iron-icons:lib",
+ "//third_party/javascript/polymer/v1/paper-button:lib",
+ "//third_party/javascript/polymer/v1/paper-checkbox:lib",
+ "//third_party/javascript/polymer/v1/paper-dialog:lib",
+ "//third_party/javascript/polymer/v1/paper-dialog-scrollable:lib",
+ "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib",
+ "//third_party/javascript/polymer/v1/paper-icon-button:lib",
+ "//third_party/javascript/polymer/v1/paper-input:lib",
+ "//third_party/javascript/polymer/v1/paper-item:lib",
+ "//third_party/javascript/polymer/v1/paper-listbox:lib",
+ "//third_party/javascript/polymer/v1/paper-slider:lib",
+ "//third_party/javascript/polymer/v1/paper-spinner:lib",
+ "//third_party/javascript/polymer/v1/paper-toast:lib",
+ "//third_party/javascript/polymer/v1/paper-toggle-button:lib",
+ "//third_party/javascript/polymer/v1/paper-tooltip:lib",
+ "//third_party/javascript/polymer/v1/polymer:lib",
+]
+
+_PROJECTOR_DESTDIR = "vz-projector"
+
+_PROJECTOR_LIB_TS_LIB_DEPS = [
+ ":ts_lib",
+ ":tsne_ts_lib",
+]
+
+# Standalone embedding projector demos should depend on this target. We
+# exclude the HTML file for the dashboard itself. Demos do not need that
+# HTML file. This was introduced because standalone demos as of today
+# have an additional Closure pass that uses a compilation configuration
+# stricter than that of TensorBoard.
+tensorboard_webcomponent_library(
+ name = "lib",
+ srcs = glob(
+ ["*.html"],
+ exclude = ["vz-projector-dashboard.html"],
+ ),
+ ts_lib_deps = _PROJECTOR_LIB_TS_LIB_DEPS,
+ destdir = _PROJECTOR_DESTDIR,
+ deps = _PROJECTOR_LIB_DEPS,
+)
+
+# TensorBoard, however, should depend on this target, which includes
+# the HTML file for the dashboard.
+tensorboard_webcomponent_library(
+ name = "lib_for_tensorboard",
+ srcs = glob(["*.html"]),
+ ts_lib_deps = _PROJECTOR_LIB_TS_LIB_DEPS,
+ destdir = _PROJECTOR_DESTDIR,
+ deps = _PROJECTOR_LIB_DEPS,
+)
+
+### Tests ###
+
+tensorboard_ts_library(
+ name = "ts_test",
+ testonly = 1,
+ srcs = glob(["*_test.ts"]),
+ runtime_deps = [
+ "//third_party/javascript/polymer/v1/polymer:lib_all_js",
+ ],
+ deps = [
+ ":ts_lib",
+ ":tsne_ts_lib",
+ "//third_party/javascript/typings/chai",
+ "//third_party/javascript/typings/jasmine:jasmine_without_externs",
+ "//third_party/javascript/typings/mocha",
+ ],
+)
+
+tensorboard_ts_development_sources(
+ name = "dev_sources_for_test",
+ testonly = 1,
+ runtime_deps = [
+ "//third_party/javascript/chai",
+ "//third_party/javascript/mocha",
+ ],
+ deps = [
+ ":ts_test",
+ ],
+)
+
+# To run locally, run :all_tests_local
+tensorboard_karma_web_test_suite(
+ name = "all_tests",
+ size = "medium",
+ browsers = ["//testing/web/browsers:chrome-linux"],
+ manifest = ":dev_sources_for_test",
+)
+
+# Generate a TypeScript IDE project by running this target.
+tensorboard_ts_config(
+ name = "tsconfig",
+ deps = [
+ ":ts_lib",
+ ":ts_test",
+ ":tsne_ts_lib",
+ ],
+)
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-dashboard.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-dashboard.html
index 3857113ac0..55c15da5ed 100644
--- a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-dashboard.html
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-dashboard.html
@@ -37,6 +37,8 @@ limitations under the License.
</template>
</template>
<script>
+"use strict";
+
(function() {
TF.Dashboard.VzProjectorDashboard = Polymer({
is: 'vz-projector-dashboard',
diff --git a/tensorflow/tensorboard/components/vz_sorting/BUILD b/tensorflow/tensorboard/components/vz_sorting/BUILD
new file mode 100644
index 0000000000..8efedb3639
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_sorting/BUILD
@@ -0,0 +1,50 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
+
+licenses(["notice"]) # Apache 2.0
+
+webfiles(
+ name = "vz_sorting",
+ srcs = [
+ "vz-sorting.html",
+ ":ts",
+ ],
+ path = "/vz-sorting",
+ visibility = ["//visibility:public"],
+)
+
+tensorboard_typescript_genrule(
+ name = "ts",
+ srcs = ["sorting.ts"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_webcomponent_library(
+ name = "legacy",
+ srcs = [
+ "vz-sorting.html",
+ ":legacy_ts",
+ ],
+ visibility = ["//visibility:public"],
+ destdir = "vz-sorting",
+)
+
+tensorboard_ts_library(
+ name = "legacy_ts",
+ srcs = ["sorting.ts"],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+ deps = ["//tensorflow/tensorboard/components:common_deps"],
+)
diff --git a/tensorflow/tensorboard/components/vz_sorting/test/BUILD b/tensorflow/tensorboard/components/vz_sorting/test/BUILD
new file mode 100644
index 0000000000..d1cf5a596a
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_sorting/test/BUILD
@@ -0,0 +1,40 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_wct_test_suite")
+
+licenses(["notice"]) # Apache 2.0
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
+
+################################################################################
+# MARKED FOR DELETION
+
+tensorboard_wct_test_suite(
+ name = "legacy_test",
+ size = "medium",
+ srcs = [
+ "index.html",
+ ":legacy_ts",
+ ],
+ deps = [
+ "//tensorflow/tensorboard/components/vz_sorting:legacy",
+ "//third_party/javascript/polymer/v1/webcomponentsjs:lib",
+ ],
+)
+
+tensorboard_ts_library(
+ name = "legacy_ts",
+ testonly = 1,
+ srcs = ["sortingTests.ts"],
+ deps_mgmt = "off",
+ runtime = "nodejs",
+ deps = [
+ "//tensorflow/tensorboard/components:common_deps",
+ "//tensorflow/tensorboard/components/vz_sorting:legacy_ts",
+ ],
+)
diff --git a/tensorflow/tensorboard/components/vz_sorting_d3v4/BUILD b/tensorflow/tensorboard/components/vz_sorting_d3v4/BUILD
new file mode 100644
index 0000000000..6e6ce525b4
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_sorting_d3v4/BUILD
@@ -0,0 +1,27 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
+
+licenses(["notice"]) # Apache 2.0
+
+tensorboard_ts_library(
+ name = "ts",
+ srcs = ["sorting.ts"],
+)
+
+tensorboard_ts_library(
+ name = "tests",
+ srcs = ["sortingTests.ts"],
+ deps = [
+ ":ts",
+ "//third_party/javascript/typings/chai",
+ "//third_party/javascript/typings/mocha",
+ "//third_party/javascript/typings/sinon",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(["**"]),
+ tags = ["notsan"],
+)
diff --git a/tensorflow/tensorboard/defs.bzl b/tensorflow/tensorboard/defs.bzl
index 7bb5f961c9..5d88baa5be 100644
--- a/tensorflow/tensorboard/defs.bzl
+++ b/tensorflow/tensorboard/defs.bzl
@@ -60,6 +60,26 @@ def tensorboard_typescript_genrule(name, srcs, typings=[], **kwargs):
**kwargs
)
+def tensorboard_karma_web_test_suite(**kwargs):
+ """Rules referencing this will be deleted from the codebase soon."""
+ pass
+
+def tensorboard_ts_config(**kwargs):
+ """Rules referencing this will be deleted from the codebase soon."""
+ pass
+
+def tensorboard_ts_declaration(**kwargs):
+ """Rules referencing this will be deleted from the codebase soon."""
+ pass
+
+def tensorboard_ts_development_sources(**kwargs):
+ """Rules referencing this will be deleted from the codebase soon."""
+ pass
+
+def tensorboard_ts_devserver(**kwargs):
+ """Rules referencing this will be deleted from the codebase soon."""
+ pass
+
def tensorboard_ts_library(**kwargs):
"""Rules referencing this will be deleted from the codebase soon."""
pass
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
index 6ca38e259b..78b10c44a2 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt
@@ -33,6 +33,10 @@ tf_module {
argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1, 1)\', \'valid\', \'channels_last\', \'(1, 1, 1)\', \'None\', \'True\', \'None\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
}
member_method {
+ name: "conv3d_transpose"
+ argspec: "args=[\'inputs\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'(1, 1, 1)\', \'valid\', \'channels_last\', \'None\', \'True\', \'None\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+ }
+ member_method {
name: "dense"
argspec: "args=[\'inputs\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'trainable\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'<tensorflow.python.ops.init_ops.Zeros object instance>\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
index 192ceac2dd..3a448798b2 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
@@ -70,7 +70,7 @@ tf_module {
}
member_method {
name: "conv3d_transpose"
- argspec: "args=[\'value\', \'filter\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'SAME\', \'None\', \'None\'], "
+ argspec: "args=[\'value\', \'filter\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'SAME\', \'NDHWC\', \'None\'], "
}
member_method {
name: "convolution"
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index f96f9460cc..046d82c2d5 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -638,7 +638,7 @@ tf_module {
}
member_method {
name: "atan2"
- argspec: "args=[\'x1\', \'x2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'y\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "batch_to_space"
diff --git a/tensorflow/tools/quantization/quantize_graph.py b/tensorflow/tools/quantization/quantize_graph.py
index 90f1ab4d66..a0cfc352d4 100644
--- a/tensorflow/tools/quantization/quantize_graph.py
+++ b/tensorflow/tools/quantization/quantize_graph.py
@@ -453,7 +453,8 @@ class GraphRewriter(object):
def round_nodes_recursively(self, current_node):
"""The entry point for simple rounding quantization."""
- if (current_node.name in self.already_visited) and self.already_visited[current_node.name]:
+ if (current_node.name in self.already_visited
+ ) and self.already_visited[current_node.name]:
return
self.already_visited[current_node.name] = True
for input_node_name in current_node.input:
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 3831a481ba..5a34862aae 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -83,6 +83,27 @@ temp_workaround_http_archive = repository_rule(
},
)
+def _http_files_with_build_impl(repo_ctx):
+ repo_ctx.template("BUILD", repo_ctx.attr.build_file, {
+ "%prefix%": ".." if _repos_are_siblings() else "external",
+ "%ws%": repo_ctx.attr.repository
+ }, False)
+ for output, urls in repo_ctx.attr.file_urls.items():
+ repo_ctx.download(urls, output,
+ repo_ctx.attr.sha256, executable=False)
+
+# Downloads a set of files and adds a BUILD file.
+http_files_with_build = repository_rule(
+ implementation = _http_files_with_build_impl,
+ attrs = {
+ "build_file": attr.label(),
+ "repository": attr.string(),
+ # Map from output file to URLs to download that file from.
+ "file_urls": attr.string_list_dict(default = {}),
+ "sha256": attr.string(default = ""),
+ },
+)
+
# Executes specified command with arguments and calls 'fail' if it exited with
# non-zero code
@@ -498,11 +519,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
temp_workaround_http_archive(
name = "llvm",
urls = [
- "http://bazel-mirror.storage.googleapis.com/github.com/llvm-mirror/llvm/archive/8a1f075c93565dd665a10ac38490f644b2c02037.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/8a1f075c93565dd665a10ac38490f644b2c02037.tar.gz",
+ "http://bazel-mirror.storage.googleapis.com/github.com/llvm-mirror/llvm/archive/13790c8735a78a029dec92d80f5633418d9ffdd6.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/13790c8735a78a029dec92d80f5633418d9ffdd6.tar.gz",
],
- sha256 = "d9ebd0b49544f3b20ee2a412aac18ed8899b8eef376343a6ba8e179563cbfd86",
- strip_prefix = "llvm-8a1f075c93565dd665a10ac38490f644b2c02037",
+ sha256 = "da4fc7147f1e2706977822934d1b245dcb6248930f8089129362ada14f6119dd",
+ strip_prefix = "llvm-13790c8735a78a029dec92d80f5633418d9ffdd6",
build_file = str(Label("//third_party/llvm:llvm.BUILD")),
repository = tf_repo_name,
)
@@ -636,6 +657,17 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
repository = tf_repo_name,
)
+ http_files_with_build(
+ name = "pprof_profile_proto",
+ file_urls = {
+ "pprof/profile.proto":
+ ["https://raw.githubusercontent.com/google/pprof/master/proto/profile.proto"],
+ "pprof/LICENSE":
+ ["https://raw.githubusercontent.com/google/pprof/master/LICENSE"]},
+ build_file = str(Label("//third_party:pprof.BUILD")),
+ repository = tf_repo_name,
+ )
+
##############################################################################
# TensorBoard Build Tools
@@ -816,6 +848,27 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
},
)
+ # TODO: Delete previous rule and rename this one org_d3js
+ filegroup_external(
+ name = "org_d3js_v4",
+ # no @license header
+ licenses = ["notice"], # BSD-3-Clause
+ sha256_urls_extract = {
+ "b5fac5b296bc196e6aa7b59f9e33986fc44d23d59a0e211705187be9e35b943d": [
+ "http://bazel-mirror.storage.googleapis.com/github.com/d3/d3/releases/download/v4.8.0/d3.zip",
+ "https://github.com/d3/d3/releases/download/v4.8.0/d3.zip",
+ ],
+ },
+ # TODO(jart): Use srcs=["d3.js"] instead of this once supported.
+ generated_rule_name = "all_files",
+ extra_build_file_content = "\n".join([
+ "filegroup(",
+ " name = \"org_d3js_v4\",",
+ " srcs = [\"d3.js\"],",
+ ")",
+ ]),
+ )
+
filegroup_external(
name = "org_definitelytyped",
licenses = ["notice"], # MIT
@@ -825,8 +878,8 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
"https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/ebc69904eb78f94030d5d517b42db20867f679c0/chai/chai.d.ts",
],
"177293828c7a206bf2a7f725753d51396d38668311aa37c96445f91bbf8128a7": [
- "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/6e2f2280ef16ef277049d0ce8583af167d586c59/d3/d3.d.ts",
- "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/6e2f2280ef16ef277049d0ce8583af167d586c59/d3/d3.d.ts",
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/6e2f2280ef16ef277049d0ce8583af167d586c59/d3/d3.d.ts", # v3
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/6e2f2280ef16ef277049d0ce8583af167d586c59/d3/d3.d.ts", # v3
],
"e4cd3d5de0eb3bc7b1063b50d336764a0ac82a658b39b5cf90511f489ffdee60": [
"http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/efd40e67ff323f7147651bdbef03c03ead7b1675/lodash/lodash.d.ts",
@@ -848,6 +901,314 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
)
filegroup_external(
+ name = "org_definitelytyped_types_d3_array",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "61e7abb7b1f01fbcb0cab8cf39003392f422566209edd681fbd070eaa84ca000": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-array/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-array/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_axis",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "95f75c8dcc89850b2e72581d96a7b5f46ea4ac852f828893f141f14a597421f9": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-axis/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-axis/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_brush",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "a2738e693ce8a8640c2d29001e77582c9c361fd23bda44db471629866b60ada7": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-brush/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-brush/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_chord",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "c54d24756eb6d744b31e538ad9bab3a75f6d54e2288b29cc72338d4a057d3e83": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-chord/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-chord/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_collection",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "f987667167b1d2970911247e325eb1c37ca0823646f81ccec837ae59039822f7": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-collection/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-collection/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_color",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "9580c81f38ddcce7be0ac9bd3d0d083adebc34e17441709f90b9e4dcd1c19a56": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-color/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-color/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_dispatch",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "169f80b4cceca8e2e9ed384d81a5db0624cc01a26451dfb5a7e0cec6ea9cfb06": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dispatch/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dispatch/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_drag",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "08d35d139dde58c2722be98d718d01204fd6167d310f09b379e832f3c741489d": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-drag/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-drag/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_dsv",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "62594d00cf9e4bb895339c8e56f64330e202a5eb2a0fa580a1f6e6336f2c93ce": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dsv/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dsv/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_ease",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "d1cf8f99b7bf758c2ba3c0a4ce553e151d4d9b4cf45a6e8bd0edec7ce90f725b": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-ease/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-ease/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_force",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "288421e2008668d2076a4684657dd3d29b992832ef02c552981eb94a91042553": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-force/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-force/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_format",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "b42cb17e580c1fd0b64d478f7bd80ca806efaefda24426a833cf1f30a7275bca": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-format/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-format/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_hierarchy",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "a5683f5835d8716c6b89c075235078438cfab5897023ed720bfa492e244e969e": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-hierarchy/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-hierarchy/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_interpolate",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "590a71b741323ac3139b333ec8b743e24717fdd5b32bcff48ee521162a9dfe1c": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-interpolate/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-interpolate/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_path",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "96f35ba041bcaa265e2b373ee675177410d44d31c980e4f7fbeefd4bcba15b00": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-path/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-path/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_polygon",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "ce453451e8105cac6a4f4a4263ca2142ebb4bf442e342f470a81da691f220fcb": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-polygon/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-polygon/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_quadtree",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "238e278f1be5d6985a19800800cffee80f81199f71d848e3bbc288d1791a6f90": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-quadtree/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-quadtree/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_queue",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "e6ae19aad83495475653578de64fb9d6bf9764eda6c84d70f7935ec84bcc482e": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-queue/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-queue/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_random",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "d31b92ed86c23ec0a4776f99fa81ff033c95b96c8304d8aa9baf3b94af779aa8": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-random/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-random/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_request",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "44bb7b07d977028e6567540a3303b06fc9b33fb0960bc75c520e0733c840d89f": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-request/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-request/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_scale",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "02ce7c644ba34bd1abb84da2e832f248b048b6a23812be4365bd837f186c9f1f": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-scale/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-scale/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_selection",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "699043ddb28dfa5e46d87bc6a24cfc6d604237f298259d3fb3c7066e05e8c86e": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-selection/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-selection/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_shape",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "62668a7aaaf6232762b544f9f89c0f557ca7cfb0cd343a358dda7ecbe26f5739": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-shape/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-shape/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_time",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "0502490ce682fd9265fb1d5d693ce6cd82e3b05e5f5ee3433731266ecb03d5fc": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-time/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-time/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_timer",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "6f191f9aea704aa64b1defa40dfdff1447a6e6bb815feff1660f894500a9c94d": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-timer/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-timer/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_transition",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "a0a7c0c9bfb5c7d6d9d22a8d16b4484b66d13f2ed226954037546cb3da4098ba": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-transition/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-transition/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_voronoi",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "c6bd5f229f915151d0ef678fe50b1aa6a62334ea0a8c6fc0effbac9f7032efc7": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-voronoi/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-voronoi/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
+ name = "org_definitelytyped_types_d3_zoom",
+ licenses = ["notice"], # MIT
+ sha256_urls = {
+ "a25dc17fbd304cf7a0e5e7bbb8339c930d464eb40c4d6e5f839ce9c0191f4110": [
+ "http://bazel-mirror.storage.googleapis.com/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-zoom/index.d.ts",
+ "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-zoom/index.d.ts",
+ ],
+ },
+ )
+
+ filegroup_external(
name = "org_threejs",
# no @license header
licenses = ["notice"], # MIT
diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD
index 306ac51767..920162e1d7 100644
--- a/third_party/llvm/llvm.BUILD
+++ b/third_party/llvm/llvm.BUILD
@@ -366,6 +366,7 @@ llvm_target_list = [
("-gen-asm-matcher", "lib/Target/ARM/ARMGenAsmMatcher.inc"),
("-gen-dag-isel", "lib/Target/ARM/ARMGenDAGISel.inc"),
("-gen-fast-isel", "lib/Target/ARM/ARMGenFastISel.inc"),
+ ("-gen-global-isel", "lib/Target/ARM/ARMGenGlobalISel.inc"),
("-gen-callingconv", "lib/Target/ARM/ARMGenCallingConv.inc"),
("-gen-subtarget", "lib/Target/ARM/ARMGenSubtargetInfo.inc"),
("-gen-disassembler", "lib/Target/ARM/ARMGenDisassemblerTables.inc"),
@@ -1116,6 +1117,7 @@ cc_library(
]),
deps = [
":analysis",
+ ":bit_reader",
":bit_writer",
":config",
":core",
diff --git a/third_party/pprof.BUILD b/third_party/pprof.BUILD
new file mode 100644
index 0000000000..6f27bfc819
--- /dev/null
+++ b/third_party/pprof.BUILD
@@ -0,0 +1,18 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # MIT
+
+load("@protobuf//:protobuf.bzl", "py_proto_library")
+
+exports_files(["pprof/LICENSE"])
+
+py_proto_library(
+ name = "pprof_proto_py",
+ srcs = ["pprof/profile.proto"],
+ default_runtime = "@protobuf//:protobuf_python",
+ protoc = "@protobuf//:protoc",
+ srcs_version = "PY2AND3",
+ deps = ["@protobuf//:protobuf_python"],
+)