aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--RELEASE.md8
-rw-r--r--tensorflow/core/framework/op_kernel.h15
-rw-r--r--tensorflow/core/kernels/concat_op.cc12
-rw-r--r--tensorflow/core/kernels/constant_op.cc11
-rw-r--r--tensorflow/core/kernels/logging_ops.cc2
-rw-r--r--tensorflow/core/kernels/pad_op.cc7
-rw-r--r--tensorflow/core/kernels/random_op.cc2
-rw-r--r--tensorflow/core/kernels/reshape_op.h2
-rw-r--r--tensorflow/core/kernels/save_op.cc4
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.cc4
-rw-r--r--tensorflow/core/kernels/sequence_ops.cc6
-rw-r--r--tensorflow/core/kernels/slice_op.cc9
-rw-r--r--tensorflow/core/kernels/sparse_to_dense_op.cc2
-rw-r--r--tensorflow/core/kernels/summary_image_op.cc4
-rw-r--r--tensorflow/core/kernels/summary_op.cc19
-rw-r--r--tensorflow/core/kernels/tile_ops.cc4
-rw-r--r--tensorflow/core/kernels/training_ops.cc6
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/__init__.py5
-rw-r--r--tensorflow/python/framework/ops.py4
-rw-r--r--tensorflow/python/kernel_tests/learn_test.py225
-rw-r--r--tensorflow/python/ops/learn.py359
-rw-r--r--tensorflow/python/ops/op_def_library.py1
23 files changed, 670 insertions, 42 deletions
diff --git a/RELEASE.md b/RELEASE.md
index 799113f23f..358e2aef1c 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -11,6 +11,14 @@
safety is handled by `saturate_cast`, which makes sure over- and underflows
are handled before casting to data types with smaller ranges.
+## Bug fixes
+
+* The Python API will now properly set the `list` member of `AttrValue` in
+ constructed `GraphDef` messages for empty lists. The serialization of some
+ graphs will change, but the change is both forwards and backwards compatible.
+ It will break tests that compare a generated `GraphDef` to a golden serialized
+ `GraphDef`.
+
# Release 0.6.0
## Major Features and Improvements
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index e61ddd0e2e..a168e0632e 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -121,6 +121,19 @@ class OpKernel {
Status InputRange(const string& input_name, int* start, int* stop) const;
Status OutputRange(const string& output_name, int* start, int* stop) const;
+ // TODO(irving): At the moment, the following three functions forward to
+ // TensorShapeUtils, but they are about to become the only versions once we
+ // become scalar strict.
+ bool allow_legacy_scalars() const { return kAllowLegacyScalars; }
+
+ bool IsLegacyScalar(const TensorShape& shape) const {
+ return TensorShapeUtils::IsLegacyScalar(shape);
+ }
+
+ bool IsLegacyVector(const TensorShape& shape) const {
+ return TensorShapeUtils::IsLegacyVector(shape);
+ }
+
private:
const NodeDef def_;
const DataTypeVector input_types_;
@@ -455,6 +468,8 @@ class OpKernelContext {
Env* env() const { return params_.device->env(); }
+ const OpKernel& op_kernel() const { return *params_.op_kernel; }
+
// Input/output signature.
int num_inputs() const { return params_.inputs->size(); }
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index 4e2ddc2954..db4ae1f18e 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -45,7 +45,7 @@ class ConcatOp : public OpKernel {
const Tensor* concat_dim_tensor;
OP_REQUIRES_OK(c, c->input("concat_dim", &concat_dim_tensor));
OP_REQUIRES(
- c, TensorShapeUtils::IsLegacyScalar(concat_dim_tensor->shape()),
+ c, IsLegacyScalar(concat_dim_tensor->shape()),
errors::InvalidArgument(
"Concat dim tensor should be a scalar integer, but got shape ",
concat_dim_tensor->shape().DebugString()));
@@ -57,7 +57,7 @@ class ConcatOp : public OpKernel {
const TensorShape& input_shape = values[0].shape();
OP_REQUIRES(
c, (0 <= concat_dim && concat_dim < input_dims) ||
- (kAllowLegacyScalars && concat_dim == 0),
+ (allow_legacy_scalars() && concat_dim == 0),
errors::InvalidArgument(
"ConcatOp : Expected concatenating dimensions in the range [", 0,
", ", input_dims, "), but got ", concat_dim));
@@ -74,10 +74,10 @@ class ConcatOp : public OpKernel {
inputs_flat_dim0 *= input_shape.dim_size(d);
}
int output_concat_dim = 0;
- const bool input_is_scalar = TensorShapeUtils::IsLegacyScalar(input_shape);
+ const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) {
const auto in = values[i];
- const bool in_is_scalar = TensorShapeUtils::IsLegacyScalar(in.shape());
+ const bool in_is_scalar = IsLegacyScalar(in.shape());
OP_REQUIRES(
c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
errors::InvalidArgument(
@@ -100,12 +100,12 @@ class ConcatOp : public OpKernel {
inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
}
- // TODO(irving): Remove check once !kAllowLegacyScalars
+ // TODO(irving): Remove check once !allow_legacy_scalars().
output_concat_dim += in.dims() > 0 ? in.dim_size(concat_dim) : 1;
}
TensorShape output_shape(input_shape);
- // TODO(irving): Remove rank 0 case once !kAllowLegacyScalars
+ // TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
if (output_shape.dims() == 0) {
output_shape.AddDim(output_concat_dim);
} else {
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 440ff62230..16c20d55f9 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -143,11 +143,14 @@ class FillOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& Tdims = context->input(0);
- OP_REQUIRES(context, TensorShapeUtils::IsLegacyVector(Tdims.shape()),
- errors::InvalidArgument("dims must be a vector of int32."));
+ OP_REQUIRES(
+ context, IsLegacyVector(Tdims.shape()),
+ errors::InvalidArgument("dims must be a vector of int32, got shape ",
+ Tdims.shape().ShortDebugString()));
const Tensor& Tvalue = context->input(1);
- OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(Tvalue.shape()),
- errors::InvalidArgument("value must be a scalar."));
+ OP_REQUIRES(context, IsLegacyScalar(Tvalue.shape()),
+ errors::InvalidArgument("value must be a scalar, got shape ",
+ Tvalue.shape().ShortDebugString()));
auto dims = Tdims.flat<int32>();
for (int i = 0; i < dims.size(); i++) {
OP_REQUIRES(context, dims(i) >= 0,
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index 302a9531c7..9947666fb5 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -28,7 +28,7 @@ class AssertOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& cond = ctx->input(0);
- OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(cond.shape()),
+ OP_REQUIRES(ctx, IsLegacyScalar(cond.shape()),
errors::InvalidArgument("In[0] should be a scalar: ",
cond.shape().ShortDebugString()));
diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc
index 3ee611bd89..a913ba7ac0 100644
--- a/tensorflow/core/kernels/pad_op.cc
+++ b/tensorflow/core/kernels/pad_op.cc
@@ -59,7 +59,8 @@ class PadOp : public OpKernel {
errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
in1.shape().DebugString()));
const int fixed_dims =
- (kAllowLegacyScalars && dims == 0 && in1.dim_size(0) == 1) ? 1 : dims;
+ (allow_legacy_scalars() && dims == 0 && in1.dim_size(0) == 1) ? 1
+ : dims;
OP_REQUIRES(
context, fixed_dims == in1.dim_size(0),
errors::InvalidArgument(
@@ -76,7 +77,7 @@ class PadOp : public OpKernel {
errors::InvalidArgument("Paddings must be non-negative: ",
before_d, " ", after_d));
const int size_d =
- (kAllowLegacyScalars && d == in0.dims()) ? 1 : in0.dim_size(d);
+ (allow_legacy_scalars() && d == in0.dims()) ? 1 : in0.dim_size(d);
output_shape.AddDim(before_d + size_d + after_d);
}
Tensor* output = nullptr;
@@ -89,7 +90,7 @@ class PadOp : public OpKernel {
break;
case 1:
// TODO(irving): Once Pad doesn't need a scalar special case,
- // change flat to tensor. That is, once !kAllowLegacyScalars.
+ // change flat to tensor. That is, once !allow_legacy_scalars().
Operate<1>(context, in0.flat<T>(), paddings, output);
break;
case 2:
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index bb1566b4e8..082c177597 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -180,7 +180,7 @@ namespace {
static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape,
int index, Tensor** output) {
- if (!TensorShapeUtils::IsLegacyVector(shape.shape())) {
+ if (!ctx->op_kernel().IsLegacyVector(shape.shape())) {
return errors::InvalidArgument(
"shape must be a vector of {int32,int64}, got shape ",
shape.shape().ShortDebugString());
diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h
index f1260746bf..07ac7c03b8 100644
--- a/tensorflow/core/kernels/reshape_op.h
+++ b/tensorflow/core/kernels/reshape_op.h
@@ -35,7 +35,7 @@ class ReshapeOp : public OpKernel {
const Tensor& input = context->input(0);
const Tensor& sizes = context->input(1);
// Preliminary validation of sizes.
- OP_REQUIRES(context, TensorShapeUtils::IsLegacyVector(sizes.shape()),
+ OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
errors::InvalidArgument("sizes input must be 1-D, not shape ",
sizes.shape().ShortDebugString()));
const int64 num_dims = sizes.NumElements();
diff --git a/tensorflow/core/kernels/save_op.cc b/tensorflow/core/kernels/save_op.cc
index 4f331f2e2e..51cf86a852 100644
--- a/tensorflow/core/kernels/save_op.cc
+++ b/tensorflow/core/kernels/save_op.cc
@@ -55,7 +55,7 @@ class ShardedFilenameOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
static const char* input_names[3] = {"basename", "shard", "num_shards"};
for (int i = 0; i < ctx->num_inputs(); ++i) {
- OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()),
+ OP_REQUIRES(ctx, IsLegacyScalar(ctx->input(i).shape()),
errors::InvalidArgument(
input_names[i], " must be a scalar, got shape ",
ctx->input(i).shape().ShortDebugString()));
@@ -78,7 +78,7 @@ class ShardedFilespecOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
static const char* input_names[2] = {"basename", "num_shards"};
for (int i = 0; i < ctx->num_inputs(); ++i) {
- OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(ctx->input(i).shape()),
+ OP_REQUIRES(ctx, IsLegacyScalar(ctx->input(i).shape()),
errors::InvalidArgument(
input_names[i], " must be a scalar, got shape ",
ctx->input(i).shape().ShortDebugString()));
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index 7ddb7f474f..96cde639c1 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -184,7 +184,7 @@ class UnsortedSegmentSumOp : public OpKernel {
const Tensor& num_segments = context->input(2);
OP_REQUIRES(
- context, TensorShapeUtils::IsLegacyScalar(num_segments.shape()),
+ context, IsLegacyScalar(num_segments.shape()),
errors::InvalidArgument("num_segments should be a scalar, not shape ",
num_segments.shape().ShortDebugString()));
@@ -406,7 +406,7 @@ class SparseSegmentMeanGradOp : public OpKernel {
errors::InvalidArgument("indices should be a vector."));
OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
errors::InvalidArgument("segment_ids should be a vector."));
- OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(output_dim0.shape()),
+ OP_REQUIRES(context, IsLegacyScalar(output_dim0.shape()),
errors::InvalidArgument("output_dim0 should be a scalar."));
const int64 N = indices.NumElements();
diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc
index 86b5fe3d74..4a9dc2cc3c 100644
--- a/tensorflow/core/kernels/sequence_ops.cc
+++ b/tensorflow/core/kernels/sequence_ops.cc
@@ -34,13 +34,13 @@ class RangeOp : public OpKernel {
const Tensor& start_in = context->input(0);
const Tensor& limit_in = context->input(1);
const Tensor& delta_in = context->input(2);
- OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(start_in.shape()),
+ OP_REQUIRES(context, IsLegacyScalar(start_in.shape()),
errors::InvalidArgument("start must be a scalar, not shape ",
start_in.shape().ShortDebugString()));
- OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(limit_in.shape()),
+ OP_REQUIRES(context, IsLegacyScalar(limit_in.shape()),
errors::InvalidArgument("limit must be a scalar, not shape ",
limit_in.shape().ShortDebugString()));
- OP_REQUIRES(context, TensorShapeUtils::IsLegacyScalar(delta_in.shape()),
+ OP_REQUIRES(context, IsLegacyScalar(delta_in.shape()),
errors::InvalidArgument("delta must be a scalar, not shape ",
delta_in.shape().ShortDebugString()));
const int32 start = GetValue(start_in.scalar<T>()());
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 015b1f3465..063acd748d 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -69,14 +69,15 @@ static void SharedValidation(OpKernelContext* context,
const Tensor& size_tensor = context->input(2);
OP_REQUIRES(
- context, TensorShapeUtils::IsLegacyVector(begin_tensor.shape()) &&
- TensorShapeUtils::IsLegacyVector(size_tensor.shape()) &&
+ context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
+ context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
begin_tensor.NumElements() == input.dims() &&
size_tensor.NumElements() == input.dims(),
errors::InvalidArgument(
"Expected begin and size arguments to be 1-D tensors of size ",
- input.dims(), ", but got ", begin_tensor.NumElements(), " and ",
- size_tensor.NumElements(), " instead."));
+ input.dims(), ", but got shapes ",
+ begin_tensor.shape().ShortDebugString(), " and ",
+ size_tensor.shape().ShortDebugString(), " instead."));
const int input_dims = input.dims();
*begin = IntTensorToInt64Vec(begin_tensor);
diff --git a/tensorflow/core/kernels/sparse_to_dense_op.cc b/tensorflow/core/kernels/sparse_to_dense_op.cc
index 7759dbdc0f..0f57d26225 100644
--- a/tensorflow/core/kernels/sparse_to_dense_op.cc
+++ b/tensorflow/core/kernels/sparse_to_dense_op.cc
@@ -60,7 +60,7 @@ class SparseToDense : public OpKernel {
// output_shape
const Tensor& output_shape = c->input(1);
OP_REQUIRES(
- c, TensorShapeUtils::IsLegacyVector(output_shape.shape()),
+ c, IsLegacyVector(output_shape.shape()),
errors::InvalidArgument("output_shape should be a vector, got shape ",
output_shape.shape().ShortDebugString()));
OP_REQUIRES(c, output_shape.NumElements() == num_dims,
diff --git a/tensorflow/core/kernels/summary_image_op.cc b/tensorflow/core/kernels/summary_image_op.cc
index ceaa9967e4..e03b1fdf18 100644
--- a/tensorflow/core/kernels/summary_image_op.cc
+++ b/tensorflow/core/kernels/summary_image_op.cc
@@ -48,8 +48,8 @@ class SummaryImageOp : public OpKernel {
void Compute(OpKernelContext* c) override {
const Tensor& tags = c->input(0);
const Tensor& tensor = c->input(1);
- OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()),
- errors::InvalidArgument("Tags must have be a scalar"));
+ OP_REQUIRES(c, IsLegacyScalar(tags.shape()),
+ errors::InvalidArgument("Tags must be a scalar"));
OP_REQUIRES(c, tensor.dims() == 4 &&
(tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 ||
tensor.dim_size(3) == 4),
diff --git a/tensorflow/core/kernels/summary_op.cc b/tensorflow/core/kernels/summary_op.cc
index 4031e90857..7451ea5673 100644
--- a/tensorflow/core/kernels/summary_op.cc
+++ b/tensorflow/core/kernels/summary_op.cc
@@ -40,12 +40,12 @@ class SummaryScalarOp : public OpKernel {
const Tensor& tags = c->input(0);
const Tensor& values = c->input(1);
- OP_REQUIRES(c, tags.IsSameSize(values) ||
- (TensorShapeUtils::IsLegacyScalar(tags.shape()) &&
- TensorShapeUtils::IsLegacyScalar(values.shape())),
+ OP_REQUIRES(c, tags.IsSameSize(values) || (IsLegacyScalar(tags.shape()) &&
+ IsLegacyScalar(values.shape())),
errors::InvalidArgument("tags and values not the same shape: ",
tags.shape().ShortDebugString(), " != ",
- values.shape().ShortDebugString()));
+ values.shape().ShortDebugString(),
+ SingleTag(tags)));
auto Ttags = tags.flat<string>();
auto Tvalues = values.flat<T>();
Summary s;
@@ -59,6 +59,15 @@ class SummaryScalarOp : public OpKernel {
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
CHECK(s.SerializeToString(&summary_tensor->scalar<string>()()));
}
+
+ // If there's only one tag, include it in the error message
+ static string SingleTag(const Tensor& tags) {
+ if (tags.NumElements() == 1) {
+ return strings::StrCat(" (tag '", tags.flat<string>()(0), "')");
+ } else {
+ return "";
+ }
+ }
};
template <typename T>
@@ -72,7 +81,7 @@ class SummaryHistoOp : public OpKernel {
const Tensor& tags = c->input(0);
const Tensor& values = c->input(1);
const auto flat = values.flat<T>();
- OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()),
+ OP_REQUIRES(c, IsLegacyScalar(tags.shape()),
errors::InvalidArgument("tags must be scalar"));
// Build histogram of values in "values" tensor
histogram::Histogram histo;
diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc
index 71bb46293e..9fcee33ebd 100644
--- a/tensorflow/core/kernels/tile_ops.cc
+++ b/tensorflow/core/kernels/tile_ops.cc
@@ -46,7 +46,7 @@ class TileOp : public OpKernel {
const Tensor& multiples = context->input(1);
OP_REQUIRES(
- context, TensorShapeUtils::IsLegacyVector(multiples.shape()),
+ context, IsLegacyVector(multiples.shape()),
errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
multiples.shape().ShortDebugString()));
OP_REQUIRES(context, input.dims() == multiples.NumElements(),
@@ -192,7 +192,7 @@ class TileGradientOp : public OpKernel {
const Tensor& input = context->input(0);
const Tensor& multiples = context->input(1);
OP_REQUIRES(
- context, TensorShapeUtils::IsLegacyVector(multiples.shape()),
+ context, IsLegacyVector(multiples.shape()),
errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
multiples.shape().ShortDebugString()));
OP_REQUIRES(context, input.dims() == multiples.NumElements(),
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index c57895b460..cef6231d92 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -153,7 +153,7 @@ class ApplyGradientDescentOp : public OpKernel {
errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", def().input(0)));
const Tensor& alpha = ctx->input(1);
- OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(alpha.shape()),
+ OP_REQUIRES(ctx, IsLegacyScalar(alpha.shape()),
errors::InvalidArgument("alpha is not a scalar: ",
alpha.shape().DebugString()));
const Tensor& delta = ctx->input(2);
@@ -242,7 +242,7 @@ class ApplyAdagradOp : public OpKernel {
errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", def().input(1)));
const Tensor& lr = ctx->input(2);
- OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()),
+ OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
errors::InvalidArgument("lr is not a scalar: ",
lr.shape().DebugString()));
const Tensor& grad = ctx->input(3);
@@ -336,7 +336,7 @@ class SparseApplyAdagradOp : public OpKernel {
errors::InvalidArgument("var must be at least 1 dimensional"));
const Tensor& lr = ctx->input(2);
- OP_REQUIRES(ctx, TensorShapeUtils::IsLegacyScalar(lr.shape()),
+ OP_REQUIRES(ctx, IsLegacyScalar(lr.shape()),
errors::InvalidArgument("lr is not a scalar: ",
lr.shape().DebugString()));
const Tensor& grad = ctx->input(3);
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7fbffb5c02..92d9970043 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -683,6 +683,7 @@ py_library(
"ops/image_ops.py",
"ops/init_ops.py",
"ops/io_ops.py",
+ "ops/learn.py",
"ops/linalg_grad.py",
"ops/linalg_ops.py",
"ops/logging_ops.py",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 093163fc96..f343de453a 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -57,7 +57,8 @@ from tensorflow.python.client.client_lib import *
# Ops
from tensorflow.python.ops.standard_ops import *
-# Bring nn, image_ops, user_ops, compat as a subpackages
+# Bring learn, nn, image_ops, user_ops, compat as a subpackages
+from tensorflow.python.ops import learn
from tensorflow.python.ops import nn
from tensorflow.python.ops import image_ops as image
from tensorflow.python.user_ops import user_ops
@@ -77,7 +78,7 @@ from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
# Don't export modules except for the few we really want
-_whitelist = set([app, compat, errors, flags, image, logging, nn,
+_whitelist = set([app, compat, errors, flags, image, learn, logging, nn,
python_io, resource_loader, test, train, user_ops])
# TODO(b/25561952): tf.tensor_util is DEPRECATED. Please avoid.
_whitelist.update([tensor_util]) # pylint: disable=undefined-variable
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 0f907874b1..b4a8bfdce1 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -3159,6 +3159,8 @@ class GraphKeys(object):
keep moving averages. See
[`tf.moving_average_variables()`](../../api_docs/python/state_ops.md#moving_average_variables)
for more details.
+ * `REGULARIZATION_LOSSES`: regularization losses collected during graph
+ construction.
"""
# Key to collect Variable objects that must be saved and restored
@@ -3178,6 +3180,8 @@ class GraphKeys(object):
ASSET_FILEPATHS = "asset_filepaths"
# Key to collect Variable objects that keep moving averages.
MOVING_AVERAGE_VARIABLES = "moving_average_variables"
+ # Key to collected regularization losses at graph construction.
+ REGULARIZATION_LOSSES = "regularization_losses"
def add_to_collection(name, value):
diff --git a/tensorflow/python/kernel_tests/learn_test.py b/tensorflow/python/kernel_tests/learn_test.py
new file mode 100644
index 0000000000..fc3c75982e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/learn_test.py
@@ -0,0 +1,225 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for tf.learn."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+import tensorflow.python.platform # pylint: disable=unused-import,g-bad-import-order
+
+import numpy as np
+import six
+import tensorflow as tf
+
+from tensorflow.python.framework import tensor_util
+
+
+class FullyConnectedTest(tf.test.TestCase):
+
+ def setUp(self):
+ tf.test.TestCase.setUp(self)
+ tf.set_random_seed(1234)
+ self.input = tf.constant([[1., 2., 3.], [-4., 5., -6.]])
+ assert not tf.get_collection(tf.GraphKeys.SUMMARIES)
+
+ def assert_summary_scope(self, regexp):
+ for summary in tf.get_collection(tf.GraphKeys.SUMMARIES):
+ tag = tensor_util.ConstantValue(summary.op.inputs[0])
+ assert tag is not None, 'All summaries have constant tags'
+ tag = str(tag)
+ assert isinstance(tag[0], six.string_types), tag[0]
+ assert re.match(regexp, tag), "tag doesn't match %s: %s" % (regexp, tag)
+
+ def test_basic_use(self):
+ output = tf.learn.fully_connected(self.input, 8, activation_fn=tf.nn.relu)
+
+ with tf.Session() as sess:
+ with self.assertRaises(tf.errors.FailedPreconditionError):
+ sess.run(output)
+
+ tf.initialize_all_variables().run()
+ out_value = sess.run(output)
+
+ self.assertEqual(output.get_shape().as_list(), [2, 8])
+ self.assertTrue(np.all(out_value >= 0),
+ 'Relu should have capped all values.')
+
+ self.assertGreater(tf.get_collection(tf.GraphKeys.SUMMARIES), 0,
+ 'Some summaries should have been added.')
+ self.assertEqual(2,
+ len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
+ self.assertEqual(0,
+ len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))
+ self.assert_summary_scope('fully_connected')
+
+ def test_variable_reuse_with_scope(self):
+ with tf.variable_scope('test') as vs:
+ output1 = tf.learn.fully_connected(self.input,
+ 8,
+ activation_fn=tf.nn.relu)
+ output2 = tf.learn.fully_connected(self.input,
+ 8,
+ activation_fn=tf.nn.relu)
+
+ with tf.variable_scope(vs, reuse=True):
+ output3 = tf.learn.fully_connected(self.input,
+ 8,
+ activation_fn=tf.nn.relu)
+
+ with tf.Session() as sess:
+ tf.initialize_all_variables().run()
+ out_value1, out_value2, out_value3 = sess.run([output1, output2, output3])
+
+ self.assertFalse(np.allclose(out_value1, out_value2))
+ self.assertAllClose(out_value1, out_value3)
+
+ def test_variable_reuse_with_template(self):
+ tmpl1 = tf.make_template('test',
+ tf.learn.fully_connected,
+ num_output_nodes=8)
+ output1 = tmpl1(self.input)
+ output2 = tmpl1(self.input)
+
+ with tf.Session() as sess:
+ tf.initialize_all_variables().run()
+ out_value1, out_value2 = sess.run([output1, output2])
+ self.assertAllClose(out_value1, out_value2)
+ self.assert_summary_scope(r'test(_\d)?/fully_connected')
+
+ def test_custom_initializers(self):
+ output = tf.learn.fully_connected(self.input,
+ 2,
+ activation_fn=tf.nn.relu,
+ weight_init=tf.constant_initializer(2.0),
+ bias_init=tf.constant_initializer(1.0))
+
+ with tf.Session() as sess:
+ tf.initialize_all_variables().run()
+ out_value = sess.run(output)
+
+ self.assertAllClose(np.array([[13.0, 13.0], [0.0, 0.0]]), out_value)
+
+ def test_custom_collections(self):
+ tf.learn.fully_connected(self.input,
+ 2,
+ activation_fn=tf.nn.relu,
+ weight_collections=['unbiased'],
+ bias_collections=['biased'])
+
+ self.assertEquals(1, len(tf.get_collection('unbiased')))
+ self.assertEquals(1, len(tf.get_collection('biased')))
+
+ def test_all_custom_collections(self):
+ tf.learn.fully_connected(self.input,
+ 2,
+ activation_fn=tf.nn.relu,
+ weight_collections=['unbiased', 'all'],
+ bias_collections=['biased', 'all'])
+
+ self.assertEquals(1, len(tf.get_collection('unbiased')))
+ self.assertEquals(1, len(tf.get_collection('biased')))
+ self.assertEquals(2, len(tf.get_collection('all')))
+ self.assertEquals(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
+ tf.get_collection('all'))
+
+ def test_no_summaries(self):
+ tf.learn.fully_connected(self.input,
+ 2,
+ activation_fn=tf.nn.relu,
+ create_summaries=False)
+ self.assertEquals([], tf.get_collection(tf.GraphKeys.SUMMARIES))
+
+ def test_regularizer(self):
+ cnt = [0]
+ tensor = tf.constant(5.0)
+ def test_fn(_):
+ cnt[0] += 1
+ return tensor
+
+ tf.learn.fully_connected(self.input, 2, weight_regularizer=test_fn)
+
+ self.assertEqual([tensor],
+ tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
+ self.assertEqual(1, cnt[0])
+
+ def test_shape_enforcement(self):
+ place = tf.placeholder(tf.float32)
+ with self.assertRaises(ValueError):
+ tf.learn.fully_connected(place, 8)
+ tf.learn.fully_connected(place, 8, num_input_nodes=5) # No error
+
+ place.set_shape([None, None])
+ with self.assertRaises(ValueError):
+ tf.learn.fully_connected(place, 8)
+ tf.learn.fully_connected(place, 8, num_input_nodes=5) # No error
+
+ place.set_shape([None, 6])
+ tf.learn.fully_connected(place, 8) # No error
+ with self.assertRaises(ValueError):
+ tf.learn.fully_connected(place, 8, num_input_nodes=5)
+
+ place = tf.placeholder(tf.float32)
+ place.set_shape([2, 6, 5])
+ with self.assertRaises(ValueError):
+ tf.learn.fully_connected(place, 8)
+
+ def test_no_bias(self):
+ tf.learn.fully_connected(self.input, 2, bias_init=None)
+
+ self.assertEqual(1,
+ len(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)))
+
+
+class RegularizerTest(tf.test.TestCase):
+
+ def test_l1(self):
+ with self.assertRaises(ValueError):
+ tf.learn.l1_regularizer(2.)
+ with self.assertRaises(ValueError):
+ tf.learn.l1_regularizer(-1.)
+ with self.assertRaises(ValueError):
+ tf.learn.l1_regularizer(0)
+
+ self.assertIsNone(tf.learn.l1_regularizer(0.)(None))
+
+ values = np.array([1., -1., 4., 2.])
+ weights = tf.constant(values)
+ with tf.Session() as sess:
+ result = sess.run(tf.learn.l1_regularizer(.5)(weights))
+
+ self.assertAllClose(np.abs(values).sum() * .5, result)
+
+ def test_l2(self):
+ with self.assertRaises(ValueError):
+ tf.learn.l2_regularizer(2.)
+ with self.assertRaises(ValueError):
+ tf.learn.l2_regularizer(-1.)
+ with self.assertRaises(ValueError):
+ tf.learn.l2_regularizer(0)
+
+ self.assertIsNone(tf.learn.l2_regularizer(0.)(None))
+
+ values = np.array([1., -1., 4., 2.])
+ weights = tf.constant(values)
+ with tf.Session() as sess:
+ result = sess.run(tf.learn.l2_regularizer(.42)(weights))
+
+ self.assertAllClose(np.power(values, 2).sum() / 2.0 * .42, result)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/python/ops/learn.py b/tensorflow/python/ops/learn.py
new file mode 100644
index 0000000000..42be9ca205
--- /dev/null
+++ b/tensorflow/python/ops/learn.py
@@ -0,0 +1,359 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=g-short-docstring-punctuation
+"""## Higher level ops related to regularization and building layers.
+
+This package provides several ops that take care of creating variables that are
+used internally in a consistent way and provide the building blocks for many
+common machine learning algorithms.
+
+@@fully_connected
+
+## Regularizers
+
+Regularization can help prevent overfitting.
+These have the signature `fn(weights)`. The loss is typically added to
+`tf.GraphKeys.REGULARIZATION_LOSS`
+
+@@l1_regularizer
+@@l2_regularizer
+
+## Initializations
+
+This also includes a common initialization for connecting multiple layers.
+
+@@xavier_initializer
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numbers
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import logging
+
+
+__all__ = ['xavier_initializer', 'fully_connected', 'l1_regularizer',
+ 'l2_regularizer']
+
+
+def xavier_initializer(n_inputs, n_outputs, uniform=True):
+ """Set the parameter initialization using the method described in paper.
+
+ Xavier Glorot and Yoshua Bengio (2010):
+ Understanding the difficulty of training deep feedforward neural
+ networks. International conference on artificial intelligence and
+ statistics.
+
+ This method is designed to keep the scale of the gradients roughly the same
+ in all layers. In uniform distribution this ends up being the range:
+ `x = sqrt(6. / (in + out)); [-x, x]` and for normal distribution a standard
+ deviation of `sqrt(3. / (in + out))` is used.
+
+ Args:
+ n_inputs: The number of input nodes into each output.
+ n_outputs: The number of output nodes for each input.
+ uniform: If true use a uniform distribution, otherwise use a truncated
+ normal.
+
+ Returns:
+ An initializer.
+ """
+ if uniform:
+ # 6 was used in the paper.
+ init_range = math.sqrt(6.0 / (n_inputs + n_outputs))
+ return standard_ops.random_uniform_initializer(-init_range, init_range)
+ else:
+ # 3 gives us approximately the same limits as above since this repicks
+ # values greater than 2 standard deviations from the mean.
+ stddev = math.sqrt(3.0 / (n_inputs + n_outputs))
+ return standard_ops.truncated_normal_initializer(stddev=stddev)
+
+
+def _assert_summary_tag_unique(tag):
+ for summary in ops.get_collection(ops.GraphKeys.SUMMARIES):
+ old_tag = tensor_util.ConstantValue(summary.op.inputs[0])
+ if tag == str(old_tag):
+ raise ValueError('Conflict with summary tag: %s exists on summary %s %s' %
+ (tag, summary, old_tag))
+
+
+def _add_scalar_summary(tensor, tag=None):
+ """Add a summary operation for the tensor.
+
+ Args:
+ tensor: The tensor to summarize.
+ tag: The tag to use, if None then use tensor's op's name.
+
+ Returns:
+ The created histogram summary.
+
+ Raises:
+ ValueError: If the tag is already in use or the rank is not 0.
+ """
+ tensor.get_shape().assert_has_rank(0)
+ tag = tag or tensor.op.name
+ _assert_summary_tag_unique(tag)
+ return standard_ops.scalar_summary(tag, tensor, name='%s_summary' % tag)
+
+
+def _add_histogram_summary(tensor, tag=None):
+ """Add a summary operation for the histogram of a tensor.
+
+ Args:
+ tensor: The tensor to summarize.
+ tag: The tag to use, if None then use tensor's op's name.
+
+ Returns:
+ The created histogram summary.
+
+ Raises:
+ ValueError: If the tag is already in use.
+ """
+ # TODO(opensource): A global or scoped mechanism to disable summaries.
+ tag = tag or tensor.op.name
+ _assert_summary_tag_unique(tag)
+ return standard_ops.histogram_summary(tag, tensor, name='%s_summary' % tag)
+
+
+def _apply_activation_with_summaries(x, activation_fn):
+ """Returns activation_fn(x).
+
+ This applies the given activation and adds useful summaries specific to the
+ activation.
+
+ Args:
+ x: The tensor to apply activation to.
+ activation_fn: An activation function.
+ Returns:
+ A tensor with activation applied to x.
+ """
+ if activation_fn is None:
+ return x
+ y = activation_fn(x)
+ if activation_fn in (nn.relu, nn.softplus, nn.relu6):
+ # Using x for comparison to avoid floating point equality and/or epsilons.
+ _add_scalar_summary(
+ standard_ops.reduce_mean(standard_ops.to_float(standard_ops.less(
+ x, 0.0))), '%s/zeros' % y.op.name)
+ if activation_fn is nn.relu6:
+ _add_scalar_summary(
+ standard_ops.reduce_mean(standard_ops.to_float(standard_ops.greater(
+ x, 6.0))), '%s/sixes' % y.op.name)
+ if activation_fn is nn.l2_normalize:
+ _add_scalar_summary(
+ standard_ops.reduce_mean(standard_ops.sqrt(standard_ops.sum(
+ standard_ops.square(x), 1))), '%s/length' % y.op.name)
+ _add_histogram_summary(y, '%s/activations' % y.op.name)
+ return y
+
+
+def _apply_regularization(w, regularizer):
+ loss = regularizer(w)
+ if loss:
+ ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)
+
+
+def l1_regularizer(scale):
+ """Returns a function that can be used to apply L1 regularization to weights.
+
+ L1 regularization encourages sparsity.
+
+ Args:
+ scale: A scalar multiplier `Tensor`. 0.0 disables the regularizer.
+
+ Returns:
+ A function with signature `l1(weights, name=None)` that apply L1
+ regularization.
+
+ Raises:
+ ValueError: If scale is outside of the range [0.0, 1.0] or if scale is not a
+ float.
+ """
+ if isinstance(scale, numbers.Integral):
+ raise ValueError('scale cannot be an integer: %s' % scale)
+ if isinstance(scale, numbers.Real):
+ if scale < 0.:
+ raise ValueError('Setting a scale less than 0 on a regularizer: %g' %
+ scale)
+ if scale >= 1.:
+ raise ValueError('Setting a scale greater than 1 on a regularizer: %g' %
+ scale)
+ if scale == 0.:
+ logging.info('Scale of 0 disables regularizer.')
+ return lambda _, name=None: None
+ def l1(weights, name=None):
+ """Applies L1 regularization to weights."""
+ with ops.op_scope([weights], name, 'l1_regularizer') as scope:
+ my_scale = ops.convert_to_tensor(scale,
+ dtype=weights.dtype.base_dtype,
+ name='scale')
+ return standard_ops.mul(
+ my_scale,
+ standard_ops.reduce_sum(standard_ops.abs(weights)),
+ name=scope)
+ return l1
+
+
+def l2_regularizer(scale):
+ """Returns a function that can be used to apply L2 regularization to weights.
+
+ Small values of L2 can help prevent overfitting the training data.
+
+ Args:
+ scale: A scalar multiplier `Tensor`. 0.0 disables the regularizer.
+
+ Returns:
+ A function with signature `l2(weights, name=None)` that applies L2
+ regularization.
+
+ Raises:
+ ValueError: If scale is outside of the range [0.0, 1.0] or if scale is not a
+ float.
+ """
+ if isinstance(scale, numbers.Integral):
+ raise ValueError('scale cannot be an integer: %s' % (scale,))
+ if isinstance(scale, numbers.Real):
+ if scale < 0.:
+ raise ValueError('Setting a scale less than 0 on a regularizer: %g.' %
+ scale)
+ if scale >= 1.:
+ raise ValueError('Setting a scale greater than 1 on a regularizer: %g.' %
+ scale)
+ if scale == 0.:
+ logging.info('Scale of 0 disables regularizer.')
+ return lambda _, name=None: None
+ def l2(weights, name=None):
+ """Applies l2 regularization to weights."""
+ with ops.op_scope([weights], name, 'l2_regularizer') as scope:
+ my_scale = ops.convert_to_tensor(scale,
+ dtype=weights.dtype.base_dtype,
+ name='scale')
+ return standard_ops.mul(my_scale, nn.l2_loss(weights), name=scope)
+ return l2
+
+
+def fully_connected(x,
+ num_output_nodes,
+ activation_fn=None,
+ weight_init=None,
+ bias_init=standard_ops.constant_initializer(0.),
+ num_input_nodes=None,
+ name=None,
+ weight_collections=None,
+ bias_collections=None,
+ weight_regularizer=None,
+ create_summaries=True):
+ """Adds the parameters for a fully connected layer and returns the output.
+
+ A fully connected layer is generally defined as a matrix multiply:
+ \\\\(y = f(w * x + b)\\\\) where **f** is given by `activation_fn`
+
+ This op creates `w` and optionally `b` (disable with `bias_init=None`) and
+ adds various summaries that can be useful for visualizing learning or
+ diagnosing training problems. The variable creation is compatible with
+ `tf.variable_scope` and so can be reused with `tf.variable_scope` or
+ `tf.make_template`.
+
+ In almost all cases, the number of input nodes can be inferred from the shape
+ of `x`, but if it is unspecified or additional size checks are desired, then
+ `num_input_nodes` can be specified.
+
+ Most of the details of variable creation can be controlled by specifying the
+ initializers (`weight_init` and `bias_init`) and which collections to place
+ the created variables in (`weight_collections` and `bias_collections`).
+
+ A per layer regularization can be specified by setting `weight_regularizer`.
+ This is only applied to weights and not the bias.
+
+ Args:
+ x: The input tensor.
+ num_output_nodes: The size of the output.
+ activation_fn: A function that requires a single Tensor that is applied as a
+ non-linearity. If None is used, then this is a linear layer.
+ weight_init: An optional initialization. If not specified, uses Xavier
+ initialization (see `tf.learn.xavier_initializer`).
+ bias_init: An initializer for the bias, defaults to 0.
+ num_input_nodes: The number of input nodes.
+ name: The name for this operation is used to name operations and to find
+ variables. If specified it must be unique for this scope, otherwise a
+ unique name starting with "fully_connected" will be created. See
+ `tf.variable_op_scope` for details.
+ weight_collections: List of graph collections for just weights.
+ bias_collections: List of graph collections for just bias.
+ weight_regularizer: A regularizer like the result of
+ `tf.learn.l1_regularizer` or `tf.learn.l2_regularizer`.
+ create_summaries: Set to false to disable summaries.
+
+ Returns:
+ The result of applying a fully connected layer.
+
+ Raises:
+ ValueError: if `x` is not rank 2; or `x`'s second dimension is not known
+ and `num_input_nodes` is not specified.
+ """
+ with variable_scope.variable_op_scope([x], name, 'fully_connected') as vs:
+ # Check rank and if num_input_nodes is specified, make sure it matches.
+ x.get_shape().assert_is_compatible_with([None, num_input_nodes])
+
+ if not num_input_nodes:
+ if x.get_shape().dims is None or x.get_shape().dims[1].value is None:
+ raise ValueError(
+ 'If x has an unknown first dimension then num_input_nodes '
+ 'must be specified; shape: %s num_input_nodes: %s'
+ % (x.get_shape(), num_input_nodes))
+ else:
+ num_input_nodes = x.get_shape().dims[1].value
+
+ weight_init = weight_init or xavier_initializer(
+ num_input_nodes, num_output_nodes)
+
+ dtype = x.dtype
+ w = variable_scope.get_variable('weights',
+ shape=[num_input_nodes, num_output_nodes],
+ dtype=dtype,
+ initializer=weight_init,
+ collections=weight_collections)
+
+ if not vs.reuse and create_summaries:
+ _add_histogram_summary(w)
+
+ y = standard_ops.matmul(x, w)
+ # Regularization is only applied to the weights and not bias.
+ if weight_regularizer:
+ _apply_regularization(w, weight_regularizer)
+ if bias_init:
+ b = variable_scope.get_variable('bias',
+ shape=[num_output_nodes],
+ dtype=dtype,
+ initializer=bias_init,
+ collections=bias_collections)
+ if not vs.reuse and create_summaries:
+ _add_histogram_summary(b)
+
+ y = nn.bias_add(y, b)
+
+ if create_summaries:
+ return _apply_activation_with_summaries(y, activation_fn)
+ else:
+ return activation_fn(y)
diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py
index 94d874f067..a409b81bad 100644
--- a/tensorflow/python/ops/op_def_library.py
+++ b/tensorflow/python/ops/op_def_library.py
@@ -559,6 +559,7 @@ class OpDefLibrary(object):
"less than minimum %d." %
(key, op_type_name, len(value),
attr_def.minimum))
+ attr_value.list.SetInParent()
if attr_def.type == "string":
attr_value.s = _MakeStr(value, key)
if attr_def.HasField("allowed_values"):