diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-01 05:50:51 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-01 05:54:29 -0700 |
commit | b73c5f80926de3b724a92a57cf0bc49aa7de37bd (patch) | |
tree | cb61f503a1fd3ef914cd26377bba5a2b8d1180b5 | |
parent | 9a169bf3ba840af8ab3caae7ea1c69c682be3ab7 (diff) |
Automated rollback of commit 3f4423fad57694bc8d7adc427d65e5a18c8592b2
PiperOrigin-RevId: 215200418
-rw-r--r-- | tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc | 42 |
1 files changed, 36 insertions, 6 deletions
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc index 6b0730b40c..5c27d59f82 100644 --- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc +++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc @@ -104,9 +104,18 @@ Status RegisterPerTableLoadOpsForAlgorithmBody( } } { + auto* table_id_attr = op_def->add_attr(); + table_id_attr->set_name("table_id"); + table_id_attr->set_type("int"); + table_id_attr->set_has_minimum(true); + table_id_attr->set_minimum(-1); + table_id_attr->mutable_default_value()->set_i(-1); + } + { auto* table_name_attr = op_def->add_attr(); table_name_attr->set_name("table_name"); table_name_attr->set_type("string"); + table_name_attr->mutable_default_value()->set_s(""); } { auto* num_shards_attr = op_def->add_attr(); @@ -138,9 +147,11 @@ parameters that are loaded from a checkpoint before a training loop is executed. %s table_name: Name of this table; must match a name in the - EmbeddingLayerConfiguration proto. + EmbeddingLayerConfiguration proto (overrides table_id). num_shards: Number of shards into which the embedding tables are divided. shard_id: Identifier of shard for this operation. +table_id: Index of this table in the EmbeddingLayerConfiguration proto + (deprecated). )doc", parameter_descriptions.c_str())); op_def->set_is_commutative(false); @@ -149,10 +160,14 @@ shard_id: Identifier of shard for this operation. auto shape_inference_function = [state_variable_specs, is_debug_op](shape_inference::InferenceContext* c) -> Status { + int table_id; + TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id)); string table_name; TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name)); - if (table_name.empty()) { - return errors::InvalidArgument("table_name attribute must be set"); + // Exactly one must be non-default. + if ((table_id >= 0) == (!table_name.empty())) { + return errors::InvalidArgument( + "exactly one of table_id or table_name must be non-default"); } int num_shards; TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards)); @@ -226,9 +241,18 @@ Status RegisterPerTableRetrieveOpsForAlgorithmBody( } } { + auto* table_id_attr = op_def->add_attr(); + table_id_attr->set_name("table_id"); + table_id_attr->set_type("int"); + table_id_attr->set_has_minimum(true); + table_id_attr->set_minimum(-1); + table_id_attr->mutable_default_value()->set_i(-1); + } + { auto* table_name_attr = op_def->add_attr(); table_name_attr->set_name("table_name"); table_name_attr->set_type("string"); + table_name_attr->mutable_default_value()->set_s(""); } { auto* num_shards_attr = op_def->add_attr(); @@ -259,9 +283,11 @@ the correct embedding table configuration. For example, this op is used to retrieve updated parameters before saving a checkpoint. %s table_name: Name of this table; must match a name in the - EmbeddingLayerConfiguration proto. + EmbeddingLayerConfiguration proto (overrides table_id). num_shards: Number of shards into which the embedding tables are divided. shard_id: Identifier of shard for this operation. +table_id: Index of this table in the EmbeddingLayerConfiguration proto + (deprecated). )doc", parameter_descriptions.c_str())); op_def->set_is_commutative(false); @@ -270,10 +296,14 @@ shard_id: Identifier of shard for this operation. auto shape_inference_function = [state_variable_specs, is_debug_op](shape_inference::InferenceContext* c) -> Status { + int table_id; + TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id)); string table_name; TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name)); - if (table_name.empty()) { - return errors::InvalidArgument("table_name must be non-empty"); + // Exactly one must be non-default. + if ((table_id >= 0) == (!table_name.empty())) { + return errors::InvalidArgument( + "exactly one of table_id or table_name must be non-default"); } int num_shards; TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards)); |