aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tpu
diff options
context:
space:
mode:
authorGravatar Daryl Ng <darylng@google.com>2018-09-20 11:23:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 11:26:52 -0700
commit30756301bee0de2b1c16a74a710bd7bf29be468d (patch)
treeaaf5ee628804ff0257967f5d1fe1a0116447c166 /tensorflow/contrib/tpu
parentae59f459cd1e6bd2f2bdeb3b49cfedf0cdaf51a1 (diff)
Moving tpu_embedding_config.proto to tpu_embedding_configuration.proto, refactoring it, adding several new fields and an EmbeddingOutputLayout message to provide experimental support for controlling the embedding output.
PiperOrigin-RevId: 213849572
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--tensorflow/contrib/tpu/BUILD4
-rw-r--r--tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc22
-rw-r--r--tensorflow/contrib/tpu/proto/BUILD18
-rw-r--r--tensorflow/contrib/tpu/proto/tpu_embedding_config.proto66
-rw-r--r--tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto95
-rw-r--r--tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto75
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py2
7 files changed, 199 insertions, 83 deletions
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 298ffc1ded..87d00aca05 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -80,7 +80,7 @@ tf_gen_op_libs(
"tpu_embedding_ops",
],
deps = [
- "//tensorflow/contrib/tpu/proto:tpu_embedding_config_proto_cc",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
],
@@ -99,7 +99,7 @@ tf_custom_op_library(
"ops/tpu_embedding_ops.cc",
],
deps = [
- "//tensorflow/contrib/tpu/proto:tpu_embedding_config_proto_cc",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
"//tensorflow/core:lib_proto_parsing",
],
)
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
index 72d37f774c..18b98939b8 100644
--- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/tpu/proto/tpu_embedding_config.pb.h"
+#include "tensorflow/contrib/tpu/proto/tpu_embedding_configuration.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -88,12 +88,12 @@ Status GradientDescentShapes(shape_inference::InferenceContext *c) {
int table_id;
TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
- int64 num_tables = config.table_config_size();
+ int64 num_tables = config.table_descriptor_size();
if (table_id >= num_tables) {
return errors::InvalidArgument("Table id >= num_tables");
}
- int64 width = config.table_config(table_id).width();
- int64 num_rows = config.table_config(table_id).num_rows();
+ int64 width = config.table_descriptor(table_id).dimension();
+ int64 num_rows = config.table_descriptor(table_id).vocabulary_size();
TF_RETURN_IF_ERROR(c->set_output("parameters", {c->Matrix(num_rows, width)}));
return Status::OK();
@@ -160,12 +160,12 @@ Status AdagradShapes(shape_inference::InferenceContext *c) {
int table_id;
TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
- int64 num_tables = config.table_config_size();
+ int64 num_tables = config.table_descriptor_size();
if (table_id >= num_tables) {
return errors::InvalidArgument("Table id >= num_tables");
}
- int64 width = config.table_config(table_id).width();
- int64 num_rows = config.table_config(table_id).num_rows();
+ int64 width = config.table_descriptor(table_id).dimension();
+ int64 num_rows = config.table_descriptor(table_id).vocabulary_size();
TF_RETURN_IF_ERROR(c->set_output("parameters", {c->Matrix(num_rows, width)}));
TF_RETURN_IF_ERROR(
@@ -244,11 +244,11 @@ Status ActivationShapes(shape_inference::InferenceContext *c) {
if (!config.ParseFromString(config_string)) {
return errors::InvalidArgument("Malformed tpu_embedding_config.");
}
- int64 batch_size = config.batch_size();
- int64 num_tables = config.table_config_size();
+ int64 batch_size = config.batch_size_per_tensor_core();
+ int64 num_tables = config.table_descriptor_size();
for (int table_id = 0; table_id < num_tables; ++table_id) {
- int64 width = config.table_config(table_id).width();
- int64 num_features = config.table_config(table_id).num_features();
+ int64 width = config.table_descriptor(table_id).dimension();
+ int64 num_features = config.table_descriptor(table_id).vocabulary_size();
c->set_output(table_id, c->Matrix(batch_size * num_features, width));
}
return Status::OK();
diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD
index 598b73b438..c20cab844c 100644
--- a/tensorflow/contrib/tpu/proto/BUILD
+++ b/tensorflow/contrib/tpu/proto/BUILD
@@ -10,12 +10,15 @@ load(
)
tf_proto_library(
- name = "tpu_embedding_config_proto",
+ name = "tpu_embedding_configuration_proto",
srcs = [
- "tpu_embedding_config.proto",
+ "tpu_embedding_configuration.proto",
],
cc_api_version = 2,
- protodeps = [":optimization_parameters_proto"],
+ protodeps = [
+ ":tpu_embedding_output_layout_proto",
+ ":optimization_parameters_proto",
+ ],
visibility = ["//visibility:public"],
)
@@ -29,6 +32,15 @@ tf_proto_library(
)
tf_proto_library(
+ name = "tpu_embedding_output_layout_proto",
+ srcs = [
+ "tpu_embedding_output_layout.proto",
+ ],
+ cc_api_version = 2,
+ visibility = ["//visibility:public"],
+)
+
+tf_proto_library(
name = "topology_proto",
srcs = [
"topology.proto",
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
deleted file mode 100644
index 3476cc8953..0000000000
--- a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
+++ /dev/null
@@ -1,66 +0,0 @@
-syntax = "proto3";
-
-package tensorflow.tpu;
-
-import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";
-
-// The TPUEmbeddingConfiguration contains specification of TPU Embedding lookups
-// and gradient updates separate from the TF Graph.
-message TPUEmbeddingConfiguration {
- // model_mode specifies whether the model is to be run in training or
- // inference. In inference mode, gradient updates to embedding tables are not
- // performed.
- enum ModelMode {
- INVALID = 0;
- TRAINING = 1;
- INFERENCE = 2;
- }
-
- ModelMode model_mode = 1;
-
- // num_hosts is the number of host CPU systems in the training/inference job.
- // Each embedding table must be sharded into num_hosts separate Variables,
- // placed separately on the num_hosts CPU devices in the cluster. Sharding
- // will be performed equivalently to the 'div' sharding_strategy option of
- // embedding_lookup() and embedding_lookup_sparse().
- int32 num_hosts = 2;
-
- // The total number of TensorNodes. This is equal to num_hosts times the
- // number of TensorNodes attached to each host.
- int32 num_tensornodes = 3;
-
- // The number of training examples per TensorNode.
- int32 batch_size = 4;
-
- // Each Embedding
- message TPUEmbeddingTable {
- // Name of the embedding table. This will be used to name Variables in the
- // Tensorflow Graph.
- string name = 1;
-
- // Number of rows of the embedding table. The Variable created to hold the
- // learned embedding table values will have shape (num_rows, width).
- int32 num_rows = 3;
-
- // Width of the embedding table. The Variable created to hold the
- // learned embedding table values will have shape (num_rows, width).
- int32 width = 4;
-
- // Number of distinct embedding activation vectors per training example
- // produced by lookups into this table during model evaluation. For each
- // table, the Graph will receive an activations Tensor of shape
- // (batch_size * table.num_features, table.width).
- // For example, num_features = 1 produces equivalent behavior to a single
- // tf.nn.embedding_lookup() call. In the case of 'multivalent' embeddings,
- // (i.e. tf.nn.embedding_lookup_sparse()) which compute weighted averages of
- // embedding table rows, num_features is the number of vectors produced
- // after averaging. In sequence models num_features is typically equal
- // to the sequence length, since each sequence element must be represented
- // separately to the convolutional or recurrent network.
- int32 num_features = 5;
-
- OptimizationParameters optimization_parameters = 6;
- }
-
- repeated TPUEmbeddingTable table_config = 5;
-}
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto
new file mode 100644
index 0000000000..da19b135d7
--- /dev/null
+++ b/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto
@@ -0,0 +1,95 @@
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";
+import "tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto";
+
+message TPUEmbeddingConfiguration {
+ // Description of the various embedding tables.
+ message TableDescriptor {
+ // Name of the table.
+ string name = 1;
+ // Size of the vocabulary (i.e., number of rows) in the table.
+ int32 vocabulary_size = 2;
+ // The embedding dimension (i.e., the width of the embedding table).
+ int32 dimension = 3;
+ // Number of features mapped to this table.
+ int32 num_features = 4;
+ // Details of the learning algorithm used to update the embedding
+ // parameters.
+ OptimizationParameters optimization_parameters = 5;
+ }
+ repeated TableDescriptor table_descriptor = 1;
+
+ // Mode. Should the embedding layer program be run for inference (just forward
+ // pass), training (both forward and backward pass) or just the backward_pass.
+ enum Mode {
+ UNSPECIFIED = 0;
+ INFERENCE = 1;
+ TRAINING = 2;
+ BACKWARD_PASS_ONLY = 3;
+ }
+ Mode mode = 2;
+
+ // Number of samples in each batch of embedding layer activations sent to
+ // the TensorCore.
+ int32 batch_size_per_tensor_core = 3;
+
+ // Number of TPU hosts used for inference/training.
+ int32 num_hosts = 4;
+
+ // Number of TensorCore used for inference/training.
+ int32 num_tensor_cores = 5;
+
+ // Sharding strategy of the embedding tables among the hosts.
+ // If the sharding_strategy is "mod", each id is assigned to host
+ // "id % num_hosts". For instance, 13 ids are split across 5 hosts as:
+ // [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]].
+ // If the sharding_strategy is "div", ids are assigned to hosts in a
+ // contiguous manner. In this case, 13 ids are split across 5 hosts as:
+ // [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].
+ // In both the strategies, if the id space does not evenly divide the number
+ // of hosts, each of the first "table_descriptor.num_ids % num_hosts" hosts
+ // will be assigned one more id.
+ // This partitioning strategy exactly follows that in the embedding_lookup
+ // TensorFlow function at tensorflow/python/ops/embedding_ops.py.
+ enum ShardingStrategy {
+ DIV_DEFAULT = 0;
+ MOD = 1;
+ }
+ ShardingStrategy sharding_strategy = 6;
+
+ // This parameter determines if the execution of the sparse core will be
+ // pipelined with that of the TensorCore. This parameter only affects results
+ // when mode=TRAINING. If mode=INFERENCE or BACKWARD_PASS_ONLY, this parameter
+ // does not affect execution and hence, is a don't care value.
+ //
+ // false: The execution of the sparse core is not pipelined with that of the
+ // TensorCore. The forward pass of every step on the sparse core is executed
+ // only after the backward pass of the previous step is complete. And the
+ // backward pass on the sparse core is executed only after the embedding
+ // gradients have been computed on the TensorCore on every step. This ensures
+ // that the activations on every step observe the gradient updates from the
+ // previous step on both the sparse core and the TensorCore.
+ //
+ // true: The execution of the sparse core is pipelined with that of the
+ // TensorCore. The forward pass of every step on the sparse core can be
+ // executed after the forward pass of the previous step is complete without
+ // waiting for the backward pass. This improves the utilization of the sparse
+ // core allowing it to process step N+1 while the embedding gradients for step
+ // N are computed on the TensorCore. The backward pass of every step on the
+ // sparse core is executed directly after the forward pass for the next step
+ // is complete. The drawback is that embedding activations for step N+1 do not
+ // observe the embedding gradient updates from step N. This could affect model
+ // quality if step N and N+1 involve the same set of embedding IDs. However,
+ // since the embedding updates are sparse, this is generally not considered a
+ // problem.
+ bool pipeline_execution_with_tensor_core = 7;
+
+ // Extended output layout information; if not provided, a compatibility mode
+ // will use defaults that match the old layout. Providing a value for this
+ // field is EXPERIMENTAL and most ways of filling it will probably break. Do
+ // not set it unless you know what you are doing.
+ TPUEmbeddingOutputLayout output_layout = 8;
+}
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto
new file mode 100644
index 0000000000..aed30b2f22
--- /dev/null
+++ b/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto
@@ -0,0 +1,75 @@
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+// In the comments here, "layout" refers to the top-level EmbeddingOutputLayout
+// proto contained in the TPUEmbeddingConfiguration.
+
+// The embedding output consists of a list of tensors, each specified by an
+// EmbeddingOutputTensor proto within the EmbeddingOutputLayout (the "output"
+// field). Each table and feature lookup is then placed into some number of
+// particular positions within some output tensor (identified by "tensor_index"
+// within OutputLocation). The tree of table lookups, feature lookups, and
+// output locations is specified by the
+// "table(table_id).feature(feature_id).output_location" repeated fields within
+// EmbeddingOutputLayout.
+
+message TPUEmbeddingOutputLayout {
+ // Location of one copy of the feature's data.
+ message OutputLocation {
+ // Which output tensor this copy of the feature will go into. Must be
+ // between 0 and layout.output_size().
+ int32 tensor_index = 1;
+
+ // Offset in dimension 0 for this feature copy. Must be between 0 and
+ // layout.output(tensor_index).dim0_size_per_sample().
+ int32 dim0_offset = 2;
+
+ // Offset in dimension 1 for this feature copy. Must be between 0 and
+ // layout.output(tensor_index).dim1_size() - table width; repeated or
+ // partially/fully overlapping values are allowed and results in the same
+ // range will be summed (with the gradients replicated in the backward
+ // pass).
+ int32 dim1_offset = 3;
+ }
+
+ // Description of the output placement for one feature.
+ message FeatureDescriptor {
+ // Typically, only one copy of each feature is used, but multiple are
+ // allowed and the same data will be copied to all of them (with the
+ // gradients summed in the backward pass).
+ repeated OutputLocation output_location = 1;
+ }
+
+ // Description of the output placement for features of one table.
+ message TableDescriptor {
+ // Output locations for each feature loaded from this table.
+ repeated FeatureDescriptor feature = 1;
+ }
+ // Output locations for each feature of each table.
+ repeated TableDescriptor table = 1;
+
+ // Data layout and shape computation information for a single output tensor.
+ // Any unused locations in the tensor will be filled with zeros, and
+ // corresponding gradients will be ignored.
+
+ // Size and layout information for 2-D tensors.
+ message TwoDOutputTensor {
+ // Multiplier for output dimension 0 size; used to match legacy format that
+ // stacks features within a sample in dimension 0.
+ int32 dim0_size_per_sample = 2;
+
+ // The size (in dimension 1) of this output tensor.
+ int32 dim1_size = 1;
+ }
+
+ // Format information for a single output tensor.
+ message EmbeddingOutputTensor {
+ oneof output_format {
+ TwoDOutputTensor two_d = 4;
+ }
+ }
+
+ // Shape and layout information for each tensor.
+ repeated EmbeddingOutputTensor output = 2;
+}
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 593f1d909e..7815d81a5b 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -76,7 +76,7 @@ def initialize_system(embedding_config=None, job=None):
"""Initializes a distributed TPU system for use with TensorFlow.
Args:
- embedding_config: If not None, an `EmbeddingLayerConfiguration` proto
+ embedding_config: If not None, a `TPUEmbeddingConfiguration` proto
describing the desired configuration of the hardware embedding lookup
tables. If embedding_config is None, no hardware embeddings can be used.
job: The job (the XXX in TensorFlow device specification /job:XXX) that