From e6bfaf47374b44bb688023904eac98576baf4cd4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 7 Feb 2017 16:43:47 -0800 Subject: Fix cmake build: Exclude graph tranfser libs from cmake Change: 146852544 --- tensorflow/contrib/cmake/tf_core_kernels.cmake | 4 ++-- tensorflow/contrib/cmake/tf_tests.cmake | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 1f31482048..5a92518fc4 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -84,15 +84,15 @@ file(GLOB_RECURSE tf_core_kernels_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/debug_ops.h" # stream_executor dependency "${tensorflow_source_dir}/tensorflow/core/kernels/debug_ops.cc" # stream_executor dependency + "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/*" + "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_execute*.cc" ) list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_exclude_srcs}) if(WIN32) file(GLOB_RECURSE tf_core_kernels_windows_exclude_srcs # not working on windows yet - "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/*" "${tensorflow_source_dir}/tensorflow/core/kernels/meta_support.*" - "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_execute_op*.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.h" "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.cc" ) diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index d92dab23ed..a8c1ba0f2b 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -237,6 +237,8 @@ if (tensorflow_BUILD_CC_TESTS) "${tensorflow_source_dir}/tensorflow/cc/framework/gradients_test.cc" "${tensorflow_source_dir}/tensorflow/core/distributed_runtime/call_options_test.cc" "${tensorflow_source_dir}/tensorflow/core/distributed_runtime/tensor_coding_test.cc" + "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/graph_transferer_test.cc" + "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/quantized_matmul_op_for_hexagon_test.cc" ) if (NOT tensorflow_ENABLE_GPU) @@ -279,8 +281,6 @@ if (tensorflow_BUILD_CC_TESTS) "${tensorflow_source_dir}/tensorflow/core/kernels/quantization_utils_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/quantize_down_and_shrink_range_op_test.cc" - "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/graph_transferer_test.cc" - "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/quantized_matmul_op_for_hexagon_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/debug_ops_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/quantized_activation_ops_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/quantized_bias_add_op_test.cc" -- cgit v1.2.3 From 780bc6b4d98665125c43685b20eeba6ad2804c0c Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Tue, 7 Feb 2017 16:45:06 -0800 Subject: Add support for variable major dimension in dense features in example parser c++ op. Full python support (including more comprehensive documentation) coming soon. Change: 146852707 --- tensorflow/core/kernels/example_parsing_ops.cc | 24 +- .../core/kernels/example_parsing_ops_test.cc | 54 ++-- tensorflow/core/ops/parsing_ops.cc | 13 +- tensorflow/core/ops/parsing_ops_test.cc | 42 ++- tensorflow/core/util/example_proto_fast_parsing.cc | 317 +++++++++++++++++---- tensorflow/core/util/example_proto_fast_parsing.h | 1 + tensorflow/core/util/example_proto_helper.h | 26 +- tensorflow/python/kernel_tests/parsing_ops_test.py | 98 +++++++ tensorflow/python/ops/parsing_ops.py | 48 +++- 9 files changed, 525 insertions(+), 98 deletions(-) diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc index d03f8fa33a..f4c4460fa4 100644 --- a/tensorflow/core/kernels/example_parsing_ops.cc +++ b/tensorflow/core/kernels/example_parsing_ops.cc @@ -92,7 +92,18 @@ class ExampleParserOp : public OpKernel { for (int d = 0; d < static_cast(attrs_.num_dense); ++d) { const Tensor& def_value = dense_defaults[d]; - if (def_value.NumElements() > 0) { + if (attrs_.variable_length[d]) { + OP_REQUIRES(ctx, def_value.NumElements() == 1, + errors::InvalidArgument( + "dense_shape[", d, "] is a variable length shape: ", + attrs_.dense_shapes[d].DebugString(), + ", therefore " + "def_value[", + d, + "] must contain a single element (" + "the padding element). But its shape is: ", + def_value.shape().DebugString())); + } else if (def_value.NumElements() > 0) { OP_REQUIRES(ctx, attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape()), errors::InvalidArgument( @@ -100,12 +111,12 @@ class ExampleParserOp : public OpKernel { "].shape() == ", def_value.shape().DebugString(), " is not compatible with dense_shapes_[", d, "] == ", attrs_.dense_shapes[d].DebugString())); - OP_REQUIRES(ctx, def_value.dtype() == attrs_.dense_types[d], - errors::InvalidArgument( - "dense_defaults[", d, "].dtype() == ", - DataTypeString(def_value.dtype()), " != dense_types_[", - d, "] == ", DataTypeString(attrs_.dense_types[d]))); } + OP_REQUIRES(ctx, def_value.dtype() == attrs_.dense_types[d], + errors::InvalidArgument( + "dense_defaults[", d, "].dtype() == ", + DataTypeString(def_value.dtype()), " != dense_types_[", d, + "] == ", DataTypeString(attrs_.dense_types[d]))); } example::Result result; @@ -114,6 +125,7 @@ class ExampleParserOp : public OpKernel { for (int d = 0; d < attrs_.num_dense; ++d) { config.dense.push_back({dense_keys_t[d], attrs_.dense_types[d], attrs_.dense_shapes[d], dense_defaults[d], + attrs_.variable_length[d], attrs_.elements_per_stride[d]}); } for (int d = 0; d < attrs_.num_sparse; ++d) { diff --git a/tensorflow/core/kernels/example_parsing_ops_test.cc b/tensorflow/core/kernels/example_parsing_ops_test.cc index 67ac477713..29dbfd3b1b 100644 --- a/tensorflow/core/kernels/example_parsing_ops_test.cc +++ b/tensorflow/core/kernels/example_parsing_ops_test.cc @@ -127,9 +127,11 @@ template <> ExampleTensorMap ExampleStore::serialized_example = ExampleStore::GetSerializedExamples(); -template +enum BenchmarkType { kDense, kSparse, kVarLenDense }; + +template struct BenchmarkOptions { - bool benchmark_dense = BenchmarkDense; + int benchmark_type = b_type; typedef S Store; typename S::Filler filler; }; @@ -145,19 +147,28 @@ static Graph* ParseExample(int batch_size, int num_keys, int feature_size) { std::vector dense_keys; std::vector dense_defaults; std::vector sparse_types; - std::vector dense_shapes; + std::vector dense_shapes; Options opt; for (int i = 0; i < num_keys; ++i) { Tensor key(DT_STRING, TensorShape()); key.scalar()() = strings::Printf("feature_%d", i); - if (opt.benchmark_dense) { - dense_keys.emplace_back(test::graph::Constant(g, key)); - dense_defaults.emplace_back(test::graph::Constant( - g, opt.filler.make_dense_default(feature_size))); - dense_shapes.push_back(TensorShape({feature_size})); - } else { - sparse_keys.emplace_back(test::graph::Constant(g, key)); - sparse_types.push_back(opt.filler.dtype); + switch (opt.benchmark_type) { + case kDense: + dense_keys.emplace_back(test::graph::Constant(g, key)); + dense_defaults.emplace_back(test::graph::Constant( + g, opt.filler.make_dense_default(feature_size))); + dense_shapes.push_back(PartialTensorShape({feature_size})); + break; + case kVarLenDense: + dense_keys.emplace_back(test::graph::Constant(g, key)); + dense_defaults.emplace_back( + test::graph::Constant(g, opt.filler.make_dense_default(1))); + dense_shapes.push_back(PartialTensorShape({-1})); + break; + case kSparse: + sparse_keys.emplace_back(test::graph::Constant(g, key)); + sparse_types.push_back(opt.filler.dtype); + break; } } @@ -176,12 +187,18 @@ static Graph* ParseExample(int batch_size, int num_keys, int feature_size) { } // Benchmark settings (Sparse, Dense) X (Bytes, Int64, Float) -typedef BenchmarkOptions, false> SparseString; -typedef BenchmarkOptions, true> DenseString; -typedef BenchmarkOptions, false> SparseInt64; -typedef BenchmarkOptions, true> DenseInt64; -typedef BenchmarkOptions, false> SparseFloat; -typedef BenchmarkOptions, true> DenseFloat; +typedef BenchmarkOptions, kSparse> SparseString; +typedef BenchmarkOptions, kDense> DenseString; +typedef BenchmarkOptions, kVarLenDense> + VarLenDenseString; +typedef BenchmarkOptions, kSparse> SparseInt64; +typedef BenchmarkOptions, kDense> DenseInt64; +typedef BenchmarkOptions, kVarLenDense> + VarLenDenseInt64; +typedef BenchmarkOptions, kSparse> SparseFloat; +typedef BenchmarkOptions, kDense> DenseFloat; +typedef BenchmarkOptions, kVarLenDense> + VarLenDenseFloat; // B == batch_size, K == num_keys. F == feature_size. // K must be one of 10, 100, 1000 @@ -205,9 +222,12 @@ typedef BenchmarkOptions, true> DenseFloat; BM_AllParseExample(SparseString); BM_AllParseExample(DenseString); +BM_AllParseExample(VarLenDenseString); BM_AllParseExample(SparseInt64); BM_AllParseExample(DenseInt64); +BM_AllParseExample(VarLenDenseInt64); BM_AllParseExample(SparseFloat); BM_AllParseExample(DenseFloat); +BM_AllParseExample(VarLenDenseFloat); } // end namespace tensorflow diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc index 4ca3f2e07e..b563656f39 100644 --- a/tensorflow/core/ops/parsing_ops.cc +++ b/tensorflow/core/ops/parsing_ops.cc @@ -113,7 +113,11 @@ dense_defaults: A list of Ndense Tensors (some may be empty). when the example's feature_map lacks dense_key[j]. If an empty Tensor is provided for dense_defaults[j], then the Feature dense_keys[j] is required. The input type is inferred from dense_defaults[j], even when it's empty. - If dense_defaults[j] is not empty, its shape must match dense_shapes[j]. + If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, + then the shape of dense_defaults[j] must match that of dense_shapes[j]. + If dense_shapes[j] has an undefined major dimension (variable strides dense + feature), dense_defaults[j] must contain a single element: + the padding element. dense_shapes: A list of Ndense shapes; the shapes of data in each Feature given in dense_keys. The number of elements in the Feature corresponding to dense_key[j] @@ -121,6 +125,13 @@ dense_shapes: A list of Ndense shapes; the shapes of data in each Feature If dense_shapes[j] == (D0, D1, ..., DN) then the shape of output Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN): The dense outputs are just the inputs row-stacked by batch. + This works for dense_shapes[j] = (-1, D1, ..., DN). In this case + the shape of the output Tensor dense_values[j] will be + (|serialized|, M, D1, .., DN), where M is the maximum number of blocks + of elements of length D1 * .... * DN, across all minibatch entries + in the input. Any minibatch entry with less than M blocks of elements of + length D1 * ... * DN will be padded with the corresponding default_value + scalar element along the second dimension. sparse_keys: A list of Nsparse string Tensors (scalars). The keys expected in the Examples' features associated with sparse values. sparse_types: A list of Nsparse types; the data types of data in each Feature diff --git a/tensorflow/core/ops/parsing_ops_test.cc b/tensorflow/core/ops/parsing_ops_test.cc index 6167c136b1..dc2aa19ee1 100644 --- a/tensorflow/core/ops/parsing_ops_test.cc +++ b/tensorflow/core/ops/parsing_ops_test.cc @@ -61,11 +61,19 @@ TEST(ParsingOpsTest, DecodeCSV_ShapeFn) { } static std::vector MakeDenseShapes(int size, - bool add_extra_shape) { + bool add_extra_shape, + int unknown_outer_dims) { std::vector shapes(size); for (int i = 0; i < size; ++i) { - // Make shapes be the sequence [1]; [1,2], [1,2,3]... - if (i > 0) shapes[i] = shapes[i - 1]; + // Make shapes be the sequence [?,1]; [?,1,2], [?,1,2,3]... + // where the number of prefixed ? depends on unknown_outer_dims. + if (i == 0) { + for (int d = 0; d < unknown_outer_dims; ++d) { + shapes[i].add_dim()->set_size(-1); + } + } else { + shapes[i] = shapes[i - 1]; + } shapes[i].add_dim()->set_size(i + 1); } if (add_extra_shape) { @@ -77,7 +85,8 @@ static std::vector MakeDenseShapes(int size, TEST(ParsingOpsTest, ParseExample_ShapeFn) { ShapeInferenceTestOp op("ParseExample"); auto set_outputs = [&op](int num_sparse, int num_dense, - bool add_extra_shape = false) { + bool add_extra_shape = false, + int unknown_outer_dims = 0) { using NodeOutList = std::vector; using DataTypeList = std::vector; NodeDefBuilder::NodeOut string_in{"a", 0, DT_STRING}; @@ -91,7 +100,8 @@ TEST(ParsingOpsTest, ParseExample_ShapeFn) { .Input(NodeOutList(num_dense, string_in)) .Attr("sparse_types", DataTypeList(num_sparse, DT_FLOAT)) .Attr("dense_types", DataTypeList(num_dense, DT_FLOAT)) - .Attr("dense_shapes", MakeDenseShapes(num_dense, add_extra_shape)) + .Attr("dense_shapes", MakeDenseShapes(num_dense, add_extra_shape, + unknown_outer_dims)) .Finalize(&op.node_def)); }; @@ -115,6 +125,24 @@ TEST(ParsingOpsTest, ParseExample_ShapeFn) { set_outputs(2, 3, true /* add_extra_shape */); INFER_ERROR("len(dense_keys) != len(dense_shapes)", op, "?;?;?;?;?;?;?;?;?;?"); + + // Allow variable strides + set_outputs(2, 3, false /* add_extra_shape */, 1 /* unknown_outer_dims */); + INFER_OK(op, "?;?;?;?;?;?;?;?;?;?", + ("[?,2];[?,2];[?];[?];[2];[2];" // sparse outputs + "[?,?,1];[?,?,1,2];[?,?,1,2,3]")); // dense outputs + INFER_OK(op, "[10];?;?;?;?;?;?;?;?;?", + ("[?,2];[?,2];[?];[?];[2];[2];" // sparse outputs + "[d0_0,?,1];[d0_0,?,1,2];[d0_0,?,1,2,3]")); // dense outputs + + set_outputs(2, 3, true /* add_extra_shape */, 1 /* unknown_outer_dims */); + INFER_ERROR("len(dense_keys) != len(dense_shapes)", op, + "?;?;?;?;?;?;?;?;?;?"); + + // Variable inner dimensions are not supported + set_outputs(2, 3, false /* add_extra_shape */, 2 /* unknown_outer_dims */); + INFER_ERROR("shapes[0] has unknown rank or unknown inner dimensions", op, + "?;?;?;?;?;?;?;?;?;?"); } TEST(ParsingOpsTest, ParseSingleSequenceExample_ShapeFn) { @@ -142,13 +170,13 @@ TEST(ParsingOpsTest, ParseSingleSequenceExample_ShapeFn) { .Attr("context_dense_types", DataTypeList(num_context_dense, DT_FLOAT)) .Attr("context_dense_shapes", - MakeDenseShapes(num_context_dense, add_extra_shape)) + MakeDenseShapes(num_context_dense, add_extra_shape, 0)) .Attr("feature_list_sparse_types", DataTypeList(num_feature_list_sparse, DT_FLOAT)) .Attr("feature_list_dense_types", DataTypeList(num_feature_list_dense, DT_FLOAT)) .Attr("feature_list_dense_shapes", - MakeDenseShapes(num_feature_list_dense, add_extra_shape)) + MakeDenseShapes(num_feature_list_dense, add_extra_shape, 0)) .Finalize(&op.node_def)); }; diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index e14f50551e..facb092dbc 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -424,6 +424,7 @@ Status FastParseSerializedExample( const size_t example_index, const Config& config, const PresizedCuckooMap>& config_index, SeededHasher hasher, std::vector* output_dense, + std::vector* output_varlen_dense, std::vector* output_sparse) { DCHECK(output_dense != nullptr); DCHECK(output_sparse != nullptr); @@ -463,9 +464,9 @@ Status FastParseSerializedExample( } auto example_error = [&](StringPiece suffix) { - return errors::InvalidArgument("Name: ", example_name, ", Key: ", - feature_name, ", Index: ", example_index, - ". ", suffix); + return errors::InvalidArgument("Name: ", example_name, + ", Key: ", feature_name, + ", Index: ", example_index, ". ", suffix); }; auto parse_error = [&] { @@ -494,54 +495,117 @@ Status FastParseSerializedExample( dense_feature_last_example[d] = example_index; if (example_dtype != config.dense[d].dtype) { - return example_error( - strings::StrCat("Data types don't match. Data type: ", - DataTypeString(example_dtype), "Expected type: ", - DataTypeString(config.dense[d].dtype))); + return example_error(strings::StrCat( + "Data types don't match. Data type: ", + DataTypeString(example_dtype), + "Expected type: ", DataTypeString(config.dense[d].dtype))); } - Tensor& out = (*output_dense)[d]; + if (!config.dense[d].variable_length) { + Tensor& out = (*output_dense)[d]; + + const std::size_t num_elements = config.dense[d].elements_per_stride; + const std::size_t offset = example_index * num_elements; + + auto shape_error = [&](size_t size, StringPiece type_str) { + return example_error(strings::StrCat( + "Number of ", type_str, + " values != expected. " + "Values size: ", + size, + " but output shape: ", config.dense[d].shape.DebugString())); + }; + + switch (config.dense[d].dtype) { + case DT_INT64: { + auto out_p = out.flat().data() + offset; + LimitedArraySlice slice(out_p, num_elements); + if (!feature.ParseInt64List(&slice)) return parse_error(); + if (slice.EndDistance() != 0) { + return shape_error(num_elements - slice.EndDistance(), "int64"); + } + break; + } + case DT_FLOAT: { + auto out_p = out.flat().data() + offset; + LimitedArraySlice slice(out_p, num_elements); + if (!feature.ParseFloatList(&slice)) return parse_error(); + if (slice.EndDistance() != 0) { + return shape_error(num_elements - slice.EndDistance(), "float"); + } + break; + } + case DT_STRING: { + auto out_p = out.flat().data() + offset; + LimitedArraySlice slice(out_p, num_elements); + if (!feature.ParseBytesList(&slice)) return parse_error(); + if (slice.EndDistance() != 0) { + return shape_error(num_elements - slice.EndDistance(), "bytes"); + } + break; + } + default: + CHECK(false) << "Should not happen."; + } + } else { // if variable length + SparseBuffer& out = (*output_varlen_dense)[d]; - const std::size_t num_elements = config.dense[d].elements_per_stride; - const std::size_t offset = example_index * num_elements; + const std::size_t num_elements = config.dense[d].elements_per_stride; - auto shape_error = [&](size_t size, StringPiece type_str) { - return example_error(strings::StrCat( - "Number of ", type_str, - " values != expected. " - "Values size: ", - size, " but output shape: ", config.dense[d].shape.DebugString())); - }; + if (example_dtype != DT_INVALID && + example_dtype != config.dense[d].dtype) { + return example_error(strings::StrCat( + "Data types don't match. ", + "Expected type: ", DataTypeString(config.dense[d].dtype))); + } - switch (config.dense[d].dtype) { - case DT_INT64: { - auto out_p = out.flat().data() + offset; - LimitedArraySlice slice(out_p, num_elements); - if (!feature.ParseInt64List(&slice)) return parse_error(); - if (slice.EndDistance() != 0) { - return shape_error(num_elements - slice.EndDistance(), "int64"); + auto shape_error = [&](size_t size, StringPiece type_str) { + return example_error(strings::StrCat( + "Number of ", type_str, + " values is not a multiple of stride length. Saw ", size, + " values but output shape is: ", + config.dense[d].shape.DebugString())); + }; + + switch (config.dense[d].dtype) { + case DT_INT64: { + if (example_dtype != DT_INVALID) { + if (!feature.ParseInt64List(&out.int64_list)) { + return parse_error(); + } + if (out.int64_list.size() % num_elements != 0) { + return shape_error(out.int64_list.size(), "int64"); + } + } + out.example_end_indices.push_back(out.int64_list.size()); + break; } - break; - } - case DT_FLOAT: { - auto out_p = out.flat().data() + offset; - LimitedArraySlice slice(out_p, num_elements); - if (!feature.ParseFloatList(&slice)) return parse_error(); - if (slice.EndDistance() != 0) { - return shape_error(num_elements - slice.EndDistance(), "float"); + case DT_FLOAT: { + if (example_dtype != DT_INVALID) { + if (!feature.ParseFloatList(&out.float_list)) { + return parse_error(); + } + if (out.float_list.size() % num_elements != 0) { + return shape_error(out.float_list.size(), "float"); + } + } + out.example_end_indices.push_back(out.float_list.size()); + break; } - break; - } - case DT_STRING: { - auto out_p = out.flat().data() + offset; - LimitedArraySlice slice(out_p, num_elements); - if (!feature.ParseBytesList(&slice)) return parse_error(); - if (slice.EndDistance() != 0) { - return shape_error(num_elements - slice.EndDistance(), "bytes"); + case DT_STRING: { + if (example_dtype != DT_INVALID) { + if (!feature.ParseBytesList(&out.bytes_list)) { + return parse_error(); + } + if (out.bytes_list.size() % num_elements != 0) { + return shape_error(out.bytes_list.size(), "bytes"); + } + } + out.example_end_indices.push_back(out.bytes_list.size()); + break; } - break; + default: + CHECK(false) << "Should not happen."; } - default: - CHECK(false) << "Should not happen."; } } else { // If feature was already visited, skip. @@ -563,9 +627,9 @@ Status FastParseSerializedExample( SparseBuffer& out = (*output_sparse)[d]; if (example_dtype != DT_INVALID && example_dtype != config.sparse[d].dtype) { - return example_error( - strings::StrCat("Data types don't match. ", "Expected type: ", - DataTypeString(config.sparse[d].dtype))); + return example_error(strings::StrCat( + "Data types don't match. ", + "Expected type: ", DataTypeString(config.sparse[d].dtype))); } switch (config.sparse[d].dtype) { @@ -602,8 +666,9 @@ Status FastParseSerializedExample( } } - // Handle missing dense features. + // Handle missing dense features for fixed strides. for (size_t d = 0; d < config.dense.size(); ++d) { + if (config.dense[d].variable_length) continue; if (dense_feature_last_example[d] == example_index) continue; if (config.dense[d].default_value.NumElements() == 0) { return errors::InvalidArgument( @@ -637,6 +702,16 @@ Status FastParseSerializedExample( } } + // Handle missing varlen dense features. + for (size_t d = 0; d < config.dense.size(); ++d) { + if (!config.dense[d].variable_length) continue; + if (dense_feature_last_example[d] == example_index) continue; + SparseBuffer& out = (*output_varlen_dense)[d]; + size_t prev_example_end_index = + out.example_end_indices.empty() ? 0 : out.example_end_indices.back(); + out.example_end_indices.push_back(prev_example_end_index); + } + // Handle missing sparse features. for (size_t d = 0; d < config.sparse.size(); ++d) { if (sparse_feature_last_example[d] == example_index) continue; @@ -661,6 +736,65 @@ Status CheckConfigDataType(DataType dtype) { } } +template +const SmallVector& GetListFromBuffer(const SparseBuffer& buffer); + +template <> +const SmallVector& GetListFromBuffer(const SparseBuffer& buffer) { + return buffer.int64_list; +} +template <> +const SmallVector& GetListFromBuffer(const SparseBuffer& buffer) { + return buffer.float_list; +} +template <> +const SmallVector& GetListFromBuffer( + const SparseBuffer& buffer) { + return buffer.bytes_list; +} + +template +void CopyOrMoveBlock(const T* b, const T* e, T* t) { + std::copy(b, e, t); +} +template <> +void CopyOrMoveBlock(const string* b, const string* e, string* t) { + std::move(b, e, t); +} + +template +void FillAndCopyVarLen( + const int d, const size_t num_elements, + const size_t num_elements_per_minibatch, const size_t data_stride_size, + const Config& config, + const std::vector>& varlen_dense_buffers, + Tensor* values) { + const Tensor& default_value = config.dense[d].default_value; + + // Copy-fill the tensors (creating the zero/fill-padding) + std::fill(values->flat().data(), values->flat().data() + num_elements, + default_value.flat()(0)); + + // Iterate over minibatch elements + for (size_t i = 0; i < varlen_dense_buffers.size(); ++i) { + const SparseBuffer& buffer = varlen_dense_buffers[i][d]; + const size_t offset = i * num_elements_per_minibatch; + const size_t stride_size = config.dense[d].elements_per_stride; + + // Copy values over. + auto& list = GetListFromBuffer(buffer); + auto list_ptr = list.begin(); + auto data = values->flat().data() + offset; + DCHECK(list.size() % stride_size == 0); + const size_t num_entries = list.size() / stride_size; + for (size_t j = 0; j < num_entries; ++j) { + CopyOrMoveBlock(list_ptr, list_ptr + stride_size, data); + list_ptr += stride_size; + data += data_stride_size; + } + } +} + } // namespace Status FastParseExample(const Config& config, @@ -701,14 +835,17 @@ Status FastParseExample(const Config& config, "Could not avoid collision. This should not happen."); } - // Allocate dense output (sparse have to be buffered). + // Allocate dense output for fixed length dense values + // (variable-length dense and sparse have to be buffered). + std::vector fixed_dense_values(config.dense.size()); for (size_t d = 0; d < config.dense.size(); ++d) { + if (config.dense[d].variable_length) continue; TensorShape out_shape; out_shape.AddDim(serialized.size()); for (const int64 dim : config.dense[d].shape.dim_sizes()) { out_shape.AddDim(dim); } - result->dense_values.emplace_back(config.dense[d].dtype, out_shape); + fixed_dense_values[d] = Tensor(config.dense[d].dtype, out_shape); } // This parameter affects performance in a big and data-dependent way. @@ -750,17 +887,19 @@ Status FastParseExample(const Config& config, // Do minibatches in parallel. std::vector> sparse_buffers(num_minibatches); + std::vector> varlen_dense_buffers(num_minibatches); std::vector status_of_minibatch(num_minibatches); auto ProcessMiniBatch = [&](size_t minibatch) { sparse_buffers[minibatch].resize(config.sparse.size()); + varlen_dense_buffers[minibatch].resize(config.dense.size()); size_t start = first_example_of_minibatch(minibatch); size_t end = first_example_of_minibatch(minibatch + 1); for (size_t e = start; e < end; ++e) { status_of_minibatch[minibatch] = FastParseSerializedExample( serialized[e], (example_names.size() > 0 ? example_names[e] : ""), e, - config, config_index, hasher, &result->dense_values, - &sparse_buffers[minibatch]); + config, config_index, hasher, &fixed_dense_values, + &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch]); if (!status_of_minibatch[minibatch].ok()) break; } }; @@ -771,8 +910,12 @@ Status FastParseExample(const Config& config, TF_RETURN_IF_ERROR(status); } + for (size_t d = 0; d < config.dense.size(); ++d) { + result->dense_values.push_back(std::move(fixed_dense_values[d])); + } + // Merge SparseBuffers from all minibatches for every config.sparse. - auto MergeMinibatches = [&](size_t d) { + auto MergeSparseMinibatches = [&](size_t d) { // Loop over minibatches size_t total_num_features = 0; size_t max_num_features = 0; @@ -849,8 +992,76 @@ Status FastParseExample(const Config& config, } }; + // Merge SparseBuffers from all minibatches for every config.dense having + // variable_length. + auto MergeDenseVarLenMinibatches = [&](size_t d) { + if (!config.dense[d].variable_length) return; + + // Loop over minibatches + size_t max_num_features = 0; + for (auto& dense_values_tmp : varlen_dense_buffers) { + std::vector& end_indices = + dense_values_tmp[d].example_end_indices; + max_num_features = std::max(max_num_features, end_indices[0]); + for (size_t i = 1; i < end_indices.size(); ++i) { + size_t example_size = end_indices[i] - end_indices[i - 1]; + max_num_features = std::max(max_num_features, example_size); + } + } + + const size_t stride_size = config.dense[d].elements_per_stride; + const size_t max_num_elements = max_num_features / stride_size; + TensorShape values_shape; + DCHECK(max_num_features % config.dense[d].elements_per_stride == 0); + const size_t batch_size = serialized.size(); + values_shape.AddDim(batch_size); + values_shape.AddDim(max_num_elements); + for (int i = 1; i < config.dense[d].shape.dims(); ++i) { + values_shape.AddDim(config.dense[d].shape.dim_size(i)); + } + Tensor values(config.dense[d].dtype, values_shape); + result->dense_values[d] = values; + const size_t num_elements = values.NumElements(); + + // Nothing to write, exit early. + if (num_elements == 0) return; + + const size_t num_elements_per_minibatch = num_elements / batch_size; + const size_t data_stride_size = + (max_num_elements == 0) + ? 0 + : (num_elements_per_minibatch / max_num_elements); + + switch (config.dense[d].dtype) { + case DT_INT64: { + FillAndCopyVarLen(d, num_elements, num_elements_per_minibatch, + data_stride_size, config, varlen_dense_buffers, + &values); + break; + } + case DT_FLOAT: { + FillAndCopyVarLen(d, num_elements, num_elements_per_minibatch, + data_stride_size, config, varlen_dense_buffers, + &values); + break; + } + case DT_STRING: { + FillAndCopyVarLen(d, num_elements, num_elements_per_minibatch, + data_stride_size, config, + varlen_dense_buffers, &values); + break; + } + default: + CHECK(false) << "Should not happen."; + } + }; + + for (size_t d = 0; d < config.dense.size(); ++d) { + MergeDenseVarLenMinibatches(d); + } + for (size_t d = 0; d < config.sparse.size(); ++d) { - MergeMinibatches(d); + MergeSparseMinibatches(d); } return Status::OK(); diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h index 4878199802..5f8b4af5fe 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.h +++ b/tensorflow/core/util/example_proto_fast_parsing.h @@ -48,6 +48,7 @@ struct FastParseExampleConfig { // Documentation is avaliable in: tensorflow/core/ops/parsing_ops.cc PartialTensorShape shape; Tensor default_value; + bool variable_length; std::size_t elements_per_stride; }; diff --git a/tensorflow/core/util/example_proto_helper.h b/tensorflow/core/util/example_proto_helper.h index 971d97266c..44838d2e54 100644 --- a/tensorflow/core/util/example_proto_helper.h +++ b/tensorflow/core/util/example_proto_helper.h @@ -161,13 +161,32 @@ class ParseSingleExampleAttrs { // Temporary check until we start allowing a variable length outer // dimension. for (int i = 0; i < dense_shapes.size(); ++i) { - if (!dense_shapes[i].IsFullyDefined()) { + bool shape_ok = true; + if (dense_shapes[i].dims() == -1) { + shape_ok = false; + } else { + for (int d = 1; d < dense_shapes[i].dims(); ++d) { + if (dense_shapes[i].dim_size(d) == -1) { + shape_ok = false; + } + } + } + if (!shape_ok) { return errors::InvalidArgument( "dense_shapes[", i, - "] is not fully defined: ", dense_shapes[i].DebugString()); + "] has unknown rank or unknown inner dimensions: ", + dense_shapes[i].DebugString()); } TensorShape dense_shape; - dense_shapes[i].AsTensorShape(&dense_shape); + if (dense_shapes[i].dims() > 0 && dense_shapes[i].dim_size(0) == -1) { + variable_length.push_back(true); + for (int d = 1; d < dense_shapes[i].dims(); ++d) { + dense_shape.AddDim(dense_shapes[i].dim_size(d)); + } + } else { + variable_length.push_back(false); + dense_shapes[i].AsTensorShape(&dense_shape); + } elements_per_stride.push_back(dense_shape.num_elements()); } return FinishInit(); @@ -178,6 +197,7 @@ class ParseSingleExampleAttrs { std::vector sparse_types; std::vector dense_types; std::vector dense_shapes; + std::vector variable_length; std::vector elements_per_stride; private: diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py index b66d271f3c..c19d5a9536 100644 --- a/tensorflow/python/kernel_tests/parsing_ops_test.py +++ b/tensorflow/python/kernel_tests/parsing_ops_test.py @@ -92,6 +92,7 @@ class ParseExampleTest(test.TestCase): expected_err[1]): out = parsing_ops.parse_example(**kwargs) sess.run(flatten_values_tensors_or_sparse(out.values())) + return else: # Returns dict w/ Tensors and SparseTensors. out = parsing_ops.parse_example(**kwargs) @@ -636,6 +637,103 @@ class ParseExampleTest(test.TestCase): } }, expected_output) + def testSerializedContainingVarLenDense(self): + aname = "a" + bname = "b" + cname = "c" + dname = "d" + example_names = ["in1", "in2", "in3", "in4"] + original = [ + example(features=features({ + cname: int64_feature([2]), + })), + example(features=features({ + aname: float_feature([1, 1]), + bname: bytes_feature([b"b0_str", b"b1_str"]), + })), + example(features=features({ + aname: float_feature([-1, -1, 2, 2]), + bname: bytes_feature([b"b1"]), + })), + example(features=features({ + aname: float_feature([]), + cname: int64_feature([3]), + })), + ] + + serialized = [m.SerializeToString() for m in original] + + expected_output = { + aname: + np.array( + [ + [0, 0, 0, 0], + [1, 1, 0, 0], + [-1, -1, 2, 2], + [0, 0, 0, 0], + ], + dtype=np.float32).reshape(4, 2, 2, 1), + bname: + np.array( + [["", ""], ["b0_str", "b1_str"], ["b1", ""], ["", ""]], + dtype=bytes).reshape(4, 2, 1, 1, 1), + cname: + np.array([2, 0, 0, 3], dtype=np.int64).reshape(4, 1), + dname: + np.empty(shape=(4, 0), dtype=bytes), + } + + self._test({ + "example_names": example_names, + "serialized": ops.convert_to_tensor(serialized), + "features": { + aname: + parsing_ops.FixedLenFeature((None, 2, 1), dtype=dtypes.float32), + bname: + parsing_ops.FixedLenFeature( + (None, 1, 1, 1), dtype=dtypes.string), + cname: + parsing_ops.FixedLenFeature((None,), dtype=dtypes.int64), + dname: + parsing_ops.FixedLenFeature((None,), dtype=dtypes.string), + } + }, expected_output) + + # Change number of required values so the inputs are not a + # multiple of this size. + self._test( + { + "example_names": example_names, + "serialized": ops.convert_to_tensor(serialized), + "features": { + aname: + parsing_ops.FixedLenFeature( + (None, 2, 1), dtype=dtypes.float32), + bname: + parsing_ops.FixedLenFeature( + (None, 2, 1, 1), dtype=dtypes.string), + } + }, + expected_err=( + errors_impl.OpError, "Name: in3, Key: b, Index: 2. " + "Number of bytes values is not a multiple of stride length.")) + + self._test( + { + "example_names": example_names, + "serialized": ops.convert_to_tensor(serialized), + "features": { + aname: + parsing_ops.FixedLenFeature( + (None, 2, 1), dtype=dtypes.float32, default_value=[]), + bname: + parsing_ops.FixedLenFeature( + (None, 2, 1, 1), dtype=dtypes.string), + } + }, + expected_err=(ValueError, + "Cannot reshape a tensor with 0 elements to shape")) + class ParseSingleExampleTest(test.TestCase): diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py index 079837bce3..77c7cd397a 100644 --- a/tensorflow/python/ops/parsing_ops.py +++ b/tensorflow/python/ops/parsing_ops.py @@ -476,8 +476,13 @@ def _parse_example_raw(serialized, The keys of the dict must match the dense_keys of the feature. dense_shapes: A list of tuples with the same length as `dense_keys`. The shape of the data for each dense feature referenced by `dense_keys`. - Required for any input tensors identified by `dense_keys` whose shapes are - anything other than `[]` or `[1]`. + Required for any input tensors identified by `dense_keys`. Must be + either fully defined, or may contain an unknown first dimension. + An unknown first dimension means the feature is treated as having + a variable number of blocks, and the output shape along this dimension + is considered unknown at graph build time. Padding is applied for + minibatch elements smaller than the maximum number of blocks for the + given feature along this dimension. name: A name for this operation (optional). Returns: @@ -516,21 +521,42 @@ def _parse_example_raw(serialized, "Dense and sparse keys must not intersect; intersection: %s" % set(dense_keys).intersection(set(sparse_keys))) + # Convert dense_shapes to TensorShape object. + dense_shapes = [tensor_shape.as_shape(shape) for shape in dense_shapes] + dense_defaults_vec = [] for i, key in enumerate(dense_keys): default_value = dense_defaults.get(key) - if default_value is None: - default_value = constant_op.constant([], dtype=dense_types[i]) - elif not isinstance(default_value, ops.Tensor): - key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key) - default_value = ops.convert_to_tensor( - default_value, dtype=dense_types[i], name=key_name) - default_value = array_ops.reshape(default_value, dense_shapes[i]) + dense_shape = dense_shapes[i] + if (dense_shape.ndims is not None and dense_shape.ndims > 0 and + dense_shape[0].value is None): + # Variable stride dense shape, the default value should be a + # scalar padding value + if default_value is None: + default_value = ops.convert_to_tensor( + "" if dense_types[i] == dtypes.string else 0, + dtype=dense_types[i]) + else: + # Reshape to a scalar to ensure user gets an error if they + # provide a tensor that's not intended to be a padding value + # (0 or 2+ elements). + key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key) + default_value = ops.convert_to_tensor( + default_value, dtype=dense_types[i], name=key_name) + default_value = array_ops.reshape(default_value, []) + else: + if default_value is None: + default_value = constant_op.constant([], dtype=dense_types[i]) + elif not isinstance(default_value, ops.Tensor): + key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key) + default_value = ops.convert_to_tensor( + default_value, dtype=dense_types[i], name=key_name) + default_value = array_ops.reshape(default_value, dense_shape) dense_defaults_vec.append(default_value) - dense_shapes = [tensor_shape.as_shape(shape).as_proto() - for shape in dense_shapes] + # Finally, convert dense_shapes to TensorShapeProto + dense_shapes = [shape.as_proto() for shape in dense_shapes] # pylint: disable=protected-access outputs = gen_parsing_ops._parse_example( -- cgit v1.2.3 From 7734b20cefe20da220a151e900f2235dc26fa2b1 Mon Sep 17 00:00:00 2001 From: Martin Wicke Date: Tue, 7 Feb 2017 16:57:53 -0800 Subject: Fix doc generator links. Change: 146854187 --- tensorflow/tools/docs/parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index 8fdf34f26d..5d70a52e5f 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -547,7 +547,7 @@ def _generate_markdown_for_module(full_name, duplicate_names, module, _CODE_URL_PREFIX = ( - 'https://www.tensorflow.org/code/') + 'https://www.tensorflow.org/code/tensorflow/') def generate_markdown(full_name, py_object, @@ -629,7 +629,7 @@ def generate_markdown(full_name, py_object, # Never include links outside this code base. if not path.startswith('..'): - markdown += '\n\nDefined in [`%s`](%s%s).\n\n' % ( + markdown += '\n\nDefined in [`tensorflow/%s`](%s%s).\n\n' % ( path, _CODE_URL_PREFIX, path) except TypeError: # getfile throws TypeError if py_object is a builtin. markdown += '\n\nThis is an alias for a Python built-in.' -- cgit v1.2.3 From b51bd5c84bc398810f678584aa8b0baee9902232 Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Tue, 7 Feb 2017 16:58:36 -0800 Subject: Change description of FakeQuantizeWithMinMaxVars to not imply that the shape of the input matters. Change: 146854267 --- tensorflow/core/ops/array_ops.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index a03f910c14..631440402d 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -4911,9 +4911,8 @@ REGISTER_OP("FakeQuantWithMinMaxVars") return Status::OK(); }) .Doc(R"doc( -Fake-quantize the 'inputs' tensor of type float and shape `[b, h, w, d]` via -global float scalars `min` and `max` to 'outputs' tensor of same shape as -`inputs`. +Fake-quantize the 'inputs' tensor of type float via global float scalars `min` +and `max` to 'outputs' tensor of same shape as `inputs`. [min; max] is the clamping range for the 'inputs' data. Op divides this range into 255 steps (total of 256 values), then replaces each 'inputs' value with the -- cgit v1.2.3 From 324b32fc0d41fd5b5c3bb753140c8ef8dd7d64a5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 7 Feb 2017 17:17:12 -0800 Subject: Update ops-related pbtxt files. Change: 146856336 --- tensorflow/core/ops/ops.pbtxt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 2e800a8887..8437527b0d 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -7109,8 +7109,8 @@ op { name: "outputs" type: DT_FLOAT } - summary: "Fake-quantize the \'inputs\' tensor of type float and shape `[b, h, w, d]` via" - description: "global float scalars `min` and `max` to \'outputs\' tensor of same shape as\n`inputs`.\n\n[min; max] is the clamping range for the \'inputs\' data. Op divides this range\ninto 255 steps (total of 256 values), then replaces each \'inputs\' value with the\nclosest of the quantized step values.\n\nThis operation has a gradient and thus allows for training `min` and `max` values." + summary: "Fake-quantize the \'inputs\' tensor of type float via global float scalars `min`" + description: "and `max` to \'outputs\' tensor of same shape as `inputs`.\n\n[min; max] is the clamping range for the \'inputs\' data. Op divides this range\ninto 255 steps (total of 256 values), then replaces each \'inputs\' value with the\nclosest of the quantized step values.\n\nThis operation has a gradient and thus allows for training `min` and `max` values." } op { name: "FakeQuantWithMinMaxVarsGradient" @@ -11793,7 +11793,7 @@ op { } input_arg { name: "dense_defaults" - description: "A list of Ndense Tensors (some may be empty).\ndense_defaults[j] provides default values\nwhen the example\'s feature_map lacks dense_key[j]. If an empty Tensor is\nprovided for dense_defaults[j], then the Feature dense_keys[j] is required.\nThe input type is inferred from dense_defaults[j], even when it\'s empty.\nIf dense_defaults[j] is not empty, its shape must match dense_shapes[j]." + description: "A list of Ndense Tensors (some may be empty).\ndense_defaults[j] provides default values\nwhen the example\'s feature_map lacks dense_key[j]. If an empty Tensor is\nprovided for dense_defaults[j], then the Feature dense_keys[j] is required.\nThe input type is inferred from dense_defaults[j], even when it\'s empty.\nIf dense_defaults[j] is not empty, and dense_shapes[j] is fully defined,\nthen the shape of dense_defaults[j] must match that of dense_shapes[j].\nIf dense_shapes[j] has an undefined major dimension (variable strides dense\nfeature), dense_defaults[j] must contain a single element:\nthe padding element." type_list_attr: "Tdense" } output_arg { @@ -11852,7 +11852,7 @@ op { attr { name: "dense_shapes" type: "list(shape)" - description: "A list of Ndense shapes; the shapes of data in each Feature\ngiven in dense_keys.\nThe number of elements in the Feature corresponding to dense_key[j]\nmust always equal dense_shapes[j].NumEntries().\nIf dense_shapes[j] == (D0, D1, ..., DN) then the shape of output\nTensor dense_values[j] will be (|serialized|, D0, D1, ..., DN):\nThe dense outputs are just the inputs row-stacked by batch." + description: "A list of Ndense shapes; the shapes of data in each Feature\ngiven in dense_keys.\nThe number of elements in the Feature corresponding to dense_key[j]\nmust always equal dense_shapes[j].NumEntries().\nIf dense_shapes[j] == (D0, D1, ..., DN) then the shape of output\nTensor dense_values[j] will be (|serialized|, D0, D1, ..., DN):\nThe dense outputs are just the inputs row-stacked by batch.\nThis works for dense_shapes[j] = (-1, D1, ..., DN). In this case\nthe shape of the output Tensor dense_values[j] will be\n(|serialized|, M, D1, .., DN), where M is the maximum number of blocks\nof elements of length D1 * .... * DN, across all minibatch entries\nin the input. Any minibatch entry with less than M blocks of elements of\nlength D1 * ... * DN will be padded with the corresponding default_value\nscalar element along the second dimension." has_minimum: true } summary: "Transforms a vector of brain.Example protos (as strings) into typed tensors." -- cgit v1.2.3 From 890bdcdf5a6be9458d841f834c74c2e661f54f8d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 7 Feb 2017 17:20:01 -0800 Subject: Update generated Python Op docs. Change: 146856620 --- tensorflow/g3doc/api_docs/python/array_ops.md | 5 ++--- .../functions_and_classes/shard3/tf.fake_quant_with_min_max_vars.md | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tensorflow/g3doc/api_docs/python/array_ops.md b/tensorflow/g3doc/api_docs/python/array_ops.md index 6f87d3b2d1..a945ecfba5 100644 --- a/tensorflow/g3doc/api_docs/python/array_ops.md +++ b/tensorflow/g3doc/api_docs/python/array_ops.md @@ -3016,10 +3016,9 @@ Compute gradients for a FakeQuantWithMinMaxArgs operation. ### `tf.fake_quant_with_min_max_vars(inputs, min, max, name=None)` {#fake_quant_with_min_max_vars} -Fake-quantize the 'inputs' tensor of type float and shape `[b, h, w, d]` via +Fake-quantize the 'inputs' tensor of type float via global float scalars `min` -global float scalars `min` and `max` to 'outputs' tensor of same shape as -`inputs`. +and `max` to 'outputs' tensor of same shape as `inputs`. [min; max] is the clamping range for the 'inputs' data. Op divides this range into 255 steps (total of 256 values), then replaces each 'inputs' value with the diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.fake_quant_with_min_max_vars.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.fake_quant_with_min_max_vars.md index d7815f0414..74ed0e0242 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.fake_quant_with_min_max_vars.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.fake_quant_with_min_max_vars.md @@ -1,9 +1,8 @@ ### `tf.fake_quant_with_min_max_vars(inputs, min, max, name=None)` {#fake_quant_with_min_max_vars} -Fake-quantize the 'inputs' tensor of type float and shape `[b, h, w, d]` via +Fake-quantize the 'inputs' tensor of type float via global float scalars `min` -global float scalars `min` and `max` to 'outputs' tensor of same shape as -`inputs`. +and `max` to 'outputs' tensor of same shape as `inputs`. [min; max] is the clamping range for the 'inputs' data. Op divides this range into 255 steps (total of 256 values), then replaces each 'inputs' value with the -- cgit v1.2.3 From 50e03a2676c1c97c943e2a9c64b337ecee5c3b1a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 7 Feb 2017 17:24:19 -0800 Subject: Go: Update generated wrapper functions for TensorFlow ops. Change: 146857053 --- tensorflow/go/op/wrappers.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index b3989e30d3..07716897cf 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -420,10 +420,9 @@ func DebugNumericSummary(scope *Scope, input tf.Output, optional ...DebugNumeric return op.Output(0) } -// Fake-quantize the 'inputs' tensor of type float and shape `[b, h, w, d]` via +// Fake-quantize the 'inputs' tensor of type float via global float scalars `min` // -// global float scalars `min` and `max` to 'outputs' tensor of same shape as -// `inputs`. +// and `max` to 'outputs' tensor of same shape as `inputs`. // // [min; max] is the clamping range for the 'inputs' data. Op divides this range // into 255 steps (total of 256 values), then replaces each 'inputs' value with the @@ -10064,7 +10063,11 @@ func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (o // when the example's feature_map lacks dense_key[j]. If an empty Tensor is // provided for dense_defaults[j], then the Feature dense_keys[j] is required. // The input type is inferred from dense_defaults[j], even when it's empty. -// If dense_defaults[j] is not empty, its shape must match dense_shapes[j]. +// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined, +// then the shape of dense_defaults[j] must match that of dense_shapes[j]. +// If dense_shapes[j] has an undefined major dimension (variable strides dense +// feature), dense_defaults[j] must contain a single element: +// the padding element. // sparse_types: A list of Nsparse types; the data types of data in each Feature // given in sparse_keys. // Currently the ParseExample supports DT_FLOAT (FloatList), @@ -10076,6 +10079,13 @@ func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (o // If dense_shapes[j] == (D0, D1, ..., DN) then the shape of output // Tensor dense_values[j] will be (|serialized|, D0, D1, ..., DN): // The dense outputs are just the inputs row-stacked by batch. +// This works for dense_shapes[j] = (-1, D1, ..., DN). In this case +// the shape of the output Tensor dense_values[j] will be +// (|serialized|, M, D1, .., DN), where M is the maximum number of blocks +// of elements of length D1 * .... * DN, across all minibatch entries +// in the input. Any minibatch entry with less than M blocks of elements of +// length D1 * ... * DN will be padded with the corresponding default_value +// scalar element along the second dimension. func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_keys []tf.Output, dense_keys []tf.Output, dense_defaults []tf.Output, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) { if scope.Err() != nil { return -- cgit v1.2.3 From 8418879238584c3b81aa9dd29c65234f95c6d9f2 Mon Sep 17 00:00:00 2001 From: Dandelion Mané Date: Tue, 7 Feb 2017 17:39:24 -0800 Subject: Migrate tf_color_scale to use webfiles build rules. Change: 146858476 --- .../tensorboard/components/tf_color_scale/BUILD | 63 +++++++++++++++ .../components/tf_color_scale/demo/BUILD | 26 ++++++ .../components/tf_color_scale/demo/index.html | 94 ++++++++++++++++++++++ .../components/tf_color_scale/tf-color-scale.html | 1 + 4 files changed, 184 insertions(+) create mode 100644 tensorflow/tensorboard/components/tf_color_scale/BUILD create mode 100644 tensorflow/tensorboard/components/tf_color_scale/demo/BUILD create mode 100644 tensorflow/tensorboard/components/tf_color_scale/demo/index.html diff --git a/tensorflow/tensorboard/components/tf_color_scale/BUILD b/tensorflow/tensorboard/components/tf_color_scale/BUILD new file mode 100644 index 0000000000..75bf812fe5 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_color_scale/BUILD @@ -0,0 +1,63 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule") +load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library") + +licenses(["notice"]) # Apache 2.0 + +# TODO(dandelion): Add webfiles support for the test code. + +webfiles( + name = "tf_color_scale", + srcs = [ + "tf-color-scale.html", + ":ts", + ], + path = "/tf-color-scale", + deps = [ + "//tensorflow/tensorboard/components/tf_imports:d3", + "@org_polymer", + ], +) + +tensorboard_typescript_genrule( + name = "ts", + srcs = [ + "colorScale.ts", + "palettes.ts", + ], + typings = ["@org_definitelytyped//:d3.d.ts"], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) + +################################################################################ +# MARKED FOR DELETION + +tensorboard_webcomponent_library( + name = "legacy", + srcs = [ + "tf-color-scale.html", + ":legacy_ts", + ], + destdir = "tf-color-scale", + deps = [ + "//tensorflow/tensorboard/components:tf_imports", + "//third_party/javascript/polymer/v1/polymer:lib", + ], +) + +tensorboard_ts_library( + name = "legacy_ts", + srcs = [ + "colorScale.ts", + "palettes.ts", + ], + deps = ["//tensorflow/tensorboard/components:common_deps"], +) diff --git a/tensorflow/tensorboard/components/tf_color_scale/demo/BUILD b/tensorflow/tensorboard/components/tf_color_scale/demo/BUILD new file mode 100644 index 0000000000..00b8a033b8 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_color_scale/demo/BUILD @@ -0,0 +1,26 @@ +package(default_visibility = ["//tensorflow:internal"]) + +load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles") + +licenses(["notice"]) # Apache 2.0 + +# bazel run //third_party/tensorflow/tensorboard/components/tf_color_scale/demo +webfiles( + name = "demo", + srcs = ["index.html"], + path = "/tf-color-scale/demo", + deps = [ + "//tensorflow/tensorboard/components/tf_color_scale", + "//tensorflow/tensorboard/components/tf_imports:d3", + "@org_polymer_iron_demo_helpers", + "@org_polymer_paper_button", + "@org_polymer_paper_styles", + "@org_polymer_webcomponentsjs", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["**"]), + tags = ["notsan"], +) diff --git a/tensorflow/tensorboard/components/tf_color_scale/demo/index.html b/tensorflow/tensorboard/components/tf_color_scale/demo/index.html new file mode 100644 index 0000000000..ad9edbda98 --- /dev/null +++ b/tensorflow/tensorboard/components/tf_color_scale/demo/index.html @@ -0,0 +1,94 @@ + + + + + +tf-color-scale demo + + + + + + + + + + + diff --git a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html b/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html index 743996f624..79bee6d957 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html +++ b/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html @@ -16,6 +16,7 @@ limitations under the License. --> +