diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-28 15:58:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 16:02:53 -0700 |
commit | 3f4423fad57694bc8d7adc427d65e5a18c8592b2 (patch) | |
tree | ec552457d6fb768d87048f96e305d97f42bf9f7b | |
parent | f5086804c758812ec9ed67233c58e18236246299 (diff) |
Internal changes only.
PiperOrigin-RevId: 215009955
-rw-r--r-- | tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc | 42 |
1 files changed, 6 insertions, 36 deletions
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc index 1bd1a31e11..bc1a0c5284 100644 --- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc +++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc @@ -104,18 +104,9 @@ Status RegisterPerTableLoadOpsForAlgorithmBody( } } { - auto* table_id_attr = op_def->add_attr(); - table_id_attr->set_name("table_id"); - table_id_attr->set_type("int"); - table_id_attr->set_has_minimum(true); - table_id_attr->set_minimum(-1); - table_id_attr->mutable_default_value()->set_i(-1); - } - { auto* table_name_attr = op_def->add_attr(); table_name_attr->set_name("table_name"); table_name_attr->set_type("string"); - table_name_attr->mutable_default_value()->set_s(""); } { auto* num_shards_attr = op_def->add_attr(); @@ -147,11 +138,9 @@ parameters that are loaded from a checkpoint before a training loop is executed. %s table_name: Name of this table; must match a name in the - EmbeddingLayerConfiguration proto (overrides table_id). + EmbeddingLayerConfiguration proto. num_shards: Number of shards into which the embedding tables are divided. shard_id: Identifier of shard for this operation. -table_id: Index of this table in the EmbeddingLayerConfiguration proto - (deprecated). )doc", parameter_descriptions.c_str())); op_def->set_is_commutative(false); @@ -160,14 +149,10 @@ table_id: Index of this table in the EmbeddingLayerConfiguration proto auto shape_inference_function = [state_variable_specs, is_debug_op](shape_inference::InferenceContext* c) -> Status { - int table_id; - TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id)); string table_name; TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name)); - // Exactly one must be non-default. - if ((table_id >= 0) == (!table_name.empty())) { - return errors::InvalidArgument( - "exactly one of table_id or table_name must be non-default"); + if (table_name.empty()) { + return errors::InvalidArgument("table_name attribute must be set"); } int num_shards; TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards)); @@ -241,18 +226,9 @@ Status RegisterPerTableRetrieveOpsForAlgorithmBody( } } { - auto* table_id_attr = op_def->add_attr(); - table_id_attr->set_name("table_id"); - table_id_attr->set_type("int"); - table_id_attr->set_has_minimum(true); - table_id_attr->set_minimum(-1); - table_id_attr->mutable_default_value()->set_i(-1); - } - { auto* table_name_attr = op_def->add_attr(); table_name_attr->set_name("table_name"); table_name_attr->set_type("string"); - table_name_attr->mutable_default_value()->set_s(""); } { auto* num_shards_attr = op_def->add_attr(); @@ -283,11 +259,9 @@ the correct embedding table configuration. For example, this op is used to retrieve updated parameters before saving a checkpoint. %s table_name: Name of this table; must match a name in the - EmbeddingLayerConfiguration proto (overrides table_id). + EmbeddingLayerConfiguration proto. num_shards: Number of shards into which the embedding tables are divided. shard_id: Identifier of shard for this operation. -table_id: Index of this table in the EmbeddingLayerConfiguration proto - (deprecated). )doc", parameter_descriptions.c_str())); op_def->set_is_commutative(false); @@ -296,14 +270,10 @@ table_id: Index of this table in the EmbeddingLayerConfiguration proto auto shape_inference_function = [state_variable_specs, is_debug_op](shape_inference::InferenceContext* c) -> Status { - int table_id; - TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id)); string table_name; TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name)); - // Exactly one must be non-default. - if ((table_id >= 0) == (!table_name.empty())) { - return errors::InvalidArgument( - "exactly one of table_id or table_name must be non-default"); + if (table_name.empty()) { + return errors::InvalidArgument("table_name must be non-empty"); } int num_shards; TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards)); |