aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-28 15:58:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 16:02:53 -0700
commit3f4423fad57694bc8d7adc427d65e5a18c8592b2 (patch)
treeec552457d6fb768d87048f96e305d97f42bf9f7b
parentf5086804c758812ec9ed67233c58e18236246299 (diff)
Internal changes only.
PiperOrigin-RevId: 215009955
-rw-r--r--tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc42
1 files changed, 6 insertions, 36 deletions
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
index 1bd1a31e11..bc1a0c5284 100644
--- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
@@ -104,18 +104,9 @@ Status RegisterPerTableLoadOpsForAlgorithmBody(
}
}
{
- auto* table_id_attr = op_def->add_attr();
- table_id_attr->set_name("table_id");
- table_id_attr->set_type("int");
- table_id_attr->set_has_minimum(true);
- table_id_attr->set_minimum(-1);
- table_id_attr->mutable_default_value()->set_i(-1);
- }
- {
auto* table_name_attr = op_def->add_attr();
table_name_attr->set_name("table_name");
table_name_attr->set_type("string");
- table_name_attr->mutable_default_value()->set_s("");
}
{
auto* num_shards_attr = op_def->add_attr();
@@ -147,11 +138,9 @@ parameters that are loaded from a checkpoint before a training loop is
executed.
%s
table_name: Name of this table; must match a name in the
- EmbeddingLayerConfiguration proto (overrides table_id).
+ EmbeddingLayerConfiguration proto.
num_shards: Number of shards into which the embedding tables are divided.
shard_id: Identifier of shard for this operation.
-table_id: Index of this table in the EmbeddingLayerConfiguration proto
- (deprecated).
)doc",
parameter_descriptions.c_str()));
op_def->set_is_commutative(false);
@@ -160,14 +149,10 @@ table_id: Index of this table in the EmbeddingLayerConfiguration proto
auto shape_inference_function =
[state_variable_specs,
is_debug_op](shape_inference::InferenceContext* c) -> Status {
- int table_id;
- TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
- // Exactly one must be non-default.
- if ((table_id >= 0) == (!table_name.empty())) {
- return errors::InvalidArgument(
- "exactly one of table_id or table_name must be non-default");
+ if (table_name.empty()) {
+ return errors::InvalidArgument("table_name attribute must be set");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
@@ -241,18 +226,9 @@ Status RegisterPerTableRetrieveOpsForAlgorithmBody(
}
}
{
- auto* table_id_attr = op_def->add_attr();
- table_id_attr->set_name("table_id");
- table_id_attr->set_type("int");
- table_id_attr->set_has_minimum(true);
- table_id_attr->set_minimum(-1);
- table_id_attr->mutable_default_value()->set_i(-1);
- }
- {
auto* table_name_attr = op_def->add_attr();
table_name_attr->set_name("table_name");
table_name_attr->set_type("string");
- table_name_attr->mutable_default_value()->set_s("");
}
{
auto* num_shards_attr = op_def->add_attr();
@@ -283,11 +259,9 @@ the correct embedding table configuration. For example, this op is
used to retrieve updated parameters before saving a checkpoint.
%s
table_name: Name of this table; must match a name in the
- EmbeddingLayerConfiguration proto (overrides table_id).
+ EmbeddingLayerConfiguration proto.
num_shards: Number of shards into which the embedding tables are divided.
shard_id: Identifier of shard for this operation.
-table_id: Index of this table in the EmbeddingLayerConfiguration proto
- (deprecated).
)doc",
parameter_descriptions.c_str()));
op_def->set_is_commutative(false);
@@ -296,14 +270,10 @@ table_id: Index of this table in the EmbeddingLayerConfiguration proto
auto shape_inference_function =
[state_variable_specs,
is_debug_op](shape_inference::InferenceContext* c) -> Status {
- int table_id;
- TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
- // Exactly one must be non-default.
- if ((table_id >= 0) == (!table_name.empty())) {
- return errors::InvalidArgument(
- "exactly one of table_id or table_name must be non-default");
+ if (table_name.empty()) {
+ return errors::InvalidArgument("table_name must be non-empty");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));