aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-01 05:50:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 05:54:29 -0700
commitb73c5f80926de3b724a92a57cf0bc49aa7de37bd (patch)
treecb61f503a1fd3ef914cd26377bba5a2b8d1180b5
parent9a169bf3ba840af8ab3caae7ea1c69c682be3ab7 (diff)
Automated rollback of commit 3f4423fad57694bc8d7adc427d65e5a18c8592b2
PiperOrigin-RevId: 215200418
-rw-r--r--tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc42
1 files changed, 36 insertions, 6 deletions
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
index 6b0730b40c..5c27d59f82 100644
--- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
@@ -104,9 +104,18 @@ Status RegisterPerTableLoadOpsForAlgorithmBody(
}
}
{
+ auto* table_id_attr = op_def->add_attr();
+ table_id_attr->set_name("table_id");
+ table_id_attr->set_type("int");
+ table_id_attr->set_has_minimum(true);
+ table_id_attr->set_minimum(-1);
+ table_id_attr->mutable_default_value()->set_i(-1);
+ }
+ {
auto* table_name_attr = op_def->add_attr();
table_name_attr->set_name("table_name");
table_name_attr->set_type("string");
+ table_name_attr->mutable_default_value()->set_s("");
}
{
auto* num_shards_attr = op_def->add_attr();
@@ -138,9 +147,11 @@ parameters that are loaded from a checkpoint before a training loop is
executed.
%s
table_name: Name of this table; must match a name in the
- EmbeddingLayerConfiguration proto.
+ EmbeddingLayerConfiguration proto (overrides table_id).
num_shards: Number of shards into which the embedding tables are divided.
shard_id: Identifier of shard for this operation.
+table_id: Index of this table in the EmbeddingLayerConfiguration proto
+ (deprecated).
)doc",
parameter_descriptions.c_str()));
op_def->set_is_commutative(false);
@@ -149,10 +160,14 @@ shard_id: Identifier of shard for this operation.
auto shape_inference_function =
[state_variable_specs,
is_debug_op](shape_inference::InferenceContext* c) -> Status {
+ int table_id;
+ TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
- if (table_name.empty()) {
- return errors::InvalidArgument("table_name attribute must be set");
+ // Exactly one must be non-default.
+ if ((table_id >= 0) == (!table_name.empty())) {
+ return errors::InvalidArgument(
+ "exactly one of table_id or table_name must be non-default");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
@@ -226,9 +241,18 @@ Status RegisterPerTableRetrieveOpsForAlgorithmBody(
}
}
{
+ auto* table_id_attr = op_def->add_attr();
+ table_id_attr->set_name("table_id");
+ table_id_attr->set_type("int");
+ table_id_attr->set_has_minimum(true);
+ table_id_attr->set_minimum(-1);
+ table_id_attr->mutable_default_value()->set_i(-1);
+ }
+ {
auto* table_name_attr = op_def->add_attr();
table_name_attr->set_name("table_name");
table_name_attr->set_type("string");
+ table_name_attr->mutable_default_value()->set_s("");
}
{
auto* num_shards_attr = op_def->add_attr();
@@ -259,9 +283,11 @@ the correct embedding table configuration. For example, this op is
used to retrieve updated parameters before saving a checkpoint.
%s
table_name: Name of this table; must match a name in the
- EmbeddingLayerConfiguration proto.
+ EmbeddingLayerConfiguration proto (overrides table_id).
num_shards: Number of shards into which the embedding tables are divided.
shard_id: Identifier of shard for this operation.
+table_id: Index of this table in the EmbeddingLayerConfiguration proto
+ (deprecated).
)doc",
parameter_descriptions.c_str()));
op_def->set_is_commutative(false);
@@ -270,10 +296,14 @@ shard_id: Identifier of shard for this operation.
auto shape_inference_function =
[state_variable_specs,
is_debug_op](shape_inference::InferenceContext* c) -> Status {
+ int table_id;
+ TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
- if (table_name.empty()) {
- return errors::InvalidArgument("table_name must be non-empty");
+ // Exactly one must be non-default.
+ if ((table_id >= 0) == (!table_name.empty())) {
+ return errors::InvalidArgument(
+ "exactly one of table_id or table_name must be non-default");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));