Diffstat (limited to 'tensorflow/core')
 tensorflow/core/BUILD                                  |   7
 tensorflow/core/client/tensor_c_api.cc                 |   2
 tensorflow/core/distributed_runtime/master_session.cc  |   6
 tensorflow/core/framework/op_kernel.cc                 |   4
 tensorflow/core/framework/op_kernel.h                  |   4
 tensorflow/core/framework/shape_inference.cc           |   6
 tensorflow/core/framework/shape_inference.h            |  23
 tensorflow/core/framework/shape_inference_test.cc      |  61
 tensorflow/core/framework/shape_inference_testutil.cc  |   9
 tensorflow/core/framework/shape_inference_testutil.h   |  11
 tensorflow/core/framework/tensor.cc                    |  13
 tensorflow/core/framework/tensor.h                     |   5
 tensorflow/core/framework/unique_tensor_references.cc  |   2
 tensorflow/core/kernels/BUILD                          |   1
 tensorflow/core/kernels/barrier_ops.cc                 |   5
 tensorflow/core/kernels/conv_grad_ops.cc               |   6
 tensorflow/core/kernels/conv_grad_ops_3d.cc            |  22
 tensorflow/core/kernels/conv_ops.cc                    |   9
 tensorflow/core/kernels/conv_ops_3d.cc                 |   9
 tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc     |   2
 tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc        |   2
 tensorflow/core/kernels/cwise_op_sigmoid.cc            |   9
 tensorflow/core/kernels/cwise_op_tanh.cc               |   8
 tensorflow/core/kernels/cwise_ops_common.h             |  30
 tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h   |  71
 tensorflow/core/kernels/cwise_ops_gradients.h          | 107
 tensorflow/core/kernels/sparse_xent_op.cc              |  52
 tensorflow/core/kernels/tensor_array.h                 |   4
 tensorflow/core/lib/gtl/inlined_vector_test.cc         |   2
 tensorflow/core/ops/array_ops.cc                       |  85
 tensorflow/core/ops/array_ops_test.cc                  | 137
 tensorflow/core/ops/compat/ops_history.v0.pbtxt        |  56
 tensorflow/core/ops/math_ops.cc                        |  21
 tensorflow/core/ops/ops.pbtxt                          |  60
 tensorflow/core/platform/default/tracing.cc            |  17
 tensorflow/core/platform/posix/tracing.cc              |  40
 36 files changed, 814 insertions(+), 94 deletions(-)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index b684522eb6..b2a928867f 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1812,10 +1812,13 @@ tf_cc_test(
],
)
-tf_cc_test(
- name = "ops/math_ops_test",
+tf_cc_tests(
size = "small",
linkstatic = tf_kernel_tests_linkstatic(),
+ tests = [
+ "ops/array_ops_test.cc",
+ "ops/math_ops_test.cc",
+ ],
deps = [
":core",
":core_cpu",
diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/core/client/tensor_c_api.cc
index a7f1a4fa78..9959280029 100644
--- a/tensorflow/core/client/tensor_c_api.cc
+++ b/tensorflow/core/client/tensor_c_api.cc
@@ -475,7 +475,7 @@ void TF_Run_Helper(TF_Session* s, const char* handle,
// Store results in c_outputs[]
for (int i = 0; i < noutputs; i++) {
const Tensor& src = outputs[i];
- if (!src.IsInitialized()) {
+ if (!src.IsInitialized() || src.NumElements() == 0) {
c_outputs[i] = tensorflow::EmptyTensor(
static_cast<TF_DataType>(src.dtype()), src.shape());
continue;
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 870970b7ca..46dd7913d3 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -746,6 +746,12 @@ void MasterSession::UpdateLastAccessTime() {
}
Status MasterSession::Create(GraphDef* graph_def) {
+ if (session_opts_.config.graph_options().place_pruned_graph()) {
+ // TODO(b/29900832): Fix this or remove the option.
+ return errors::Unimplemented(
+ "MasterSession does not support the place_pruned_graph option.");
+ }
+
// flib_def_ keeps a copy of graph_def->library() and serves as the
// OpRegistryInterface used by the SimpleGraphExecutionState to construct the
// pre-partitioned graphs during DoRunWithLocalExecution().
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 6bfc55df41..12df379e8f 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -159,7 +159,7 @@ Status OpKernelConstruction::allocate_temp(DataType type,
attr.allocation_will_be_logged = true;
Tensor new_temp(allocator_, type, shape, attr);
- if (!new_temp.IsInitialized() && shape.num_elements() > 0) {
+ if (!new_temp.IsInitialized()) {
return errors::ResourceExhausted(
"OOM when allocating temporary tensor with shape", shape.DebugString());
}
@@ -447,7 +447,7 @@ Status OpKernelContext::allocate_tensor(
logged_attr.allocation_will_be_logged = true;
Tensor new_tensor(a, type, shape, logged_attr);
- if (!new_tensor.IsInitialized() && shape.num_elements() > 0) {
+ if (!new_tensor.IsInitialized()) {
return errors::ResourceExhausted("OOM when allocating tensor with shape",
shape.DebugString());
}
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 0092c6286f..a6cc323cea 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -199,7 +199,9 @@ class PersistentTensor {
// The check for initialization does not need to access the
// underlying tensor buffer.
- bool IsInitialized() { return tensor_.IsInitialized(); }
+ bool IsInitialized() const { return tensor_.IsInitialized(); }
+
+ int64 NumElements() const { return tensor_.NumElements(); }
private:
Tensor tensor_;
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 2df57d6cab..bd8e6ea309 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -25,9 +25,9 @@ constexpr int32 InferenceContext::kUnknownRank;
constexpr int64 InferenceContext::kUnknownDim;
InferenceContext::InferenceContext(
- const std::vector<string>& input_shapes, int num_outputs,
- const std::vector<const Tensor*>& input_tensors)
- : input_tensors_(input_tensors) {
+ const NodeDef* node_def, const std::vector<string>& input_shapes,
+ int num_outputs, const std::vector<const Tensor*>& input_tensors)
+ : input_tensors_(input_tensors), node_def_(*CHECK_NOTNULL(node_def)) {
for (const string& spec : input_shapes) {
if (spec == "?") {
inputs_.push_back(CreateUnknownShape());
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index bb6a66dc53..6385177bc1 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -17,6 +17,8 @@ limitations under the License.
#include <vector>
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -80,7 +82,10 @@ class InferenceContext {
// the same Dimension*.
//
// <input_tensors> is NULL-padded to be the same size as <input_shapes>.
- InferenceContext(const std::vector<string>& input_shapes, int num_outputs,
+ //
+ // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext.
+ InferenceContext(const NodeDef* node_def,
+ const std::vector<string>& input_shapes, int num_outputs,
const std::vector<const Tensor*>& input_tensors = {});
~InferenceContext();
@@ -162,6 +167,12 @@ class InferenceContext {
const Dimension* CreateDim(int64 value);
const Dimension* CreateUnknownDim();
+ // Look up the attr for the NodeDef being evaluated with name attr_name and
+ // set *value to its value. If no attr with attr_name is found in def(), or
+ // the attr does not have a matching type, a non-ok status will be returned.
+ template <class T>
+ Status GetAttr(StringPiece attr_name, T* value) const;
+
private:
Status ReturnUnknownShape(const Shape** out) {
*out = CreateUnknownShape();
@@ -181,9 +192,14 @@ class InferenceContext {
std::vector<const Tensor*> input_tensors_;
std::vector<const Shape*> outputs_;
+ const NodeDef& node_def_;
+
TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext);
};
+// -----------------------------------------------------------------------------
+// Template and inline method implementations, please ignore
+
inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {}
inline Dimension::Dimension(int64 value) : value_(value) {}
@@ -191,6 +207,11 @@ inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {}
inline Shape::Shape(const std::vector<const Dimension*> dims)
: rank_(dims.size()), dims_(dims) {}
+template <class T>
+Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const {
+ return GetNodeAttr(node_def_, attr_name, value);
+}
+
} // namespace shape_inference
} // namespace tensorflow
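
Note: the new GetAttr simply forwards to GetNodeAttr on the NodeDef held by the
context, so shape functions can read attrs directly. A minimal hypothetical
shape function using it (the attr name "my_rank" and the function itself are
invented for this sketch):

Status ExampleShapeFn(shape_inference::InferenceContext* c) {
  int32 rank;
  // Looks up the "my_rank" attr on the NodeDef stored in the context.
  TF_RETURN_IF_ERROR(c->GetAttr("my_rank", &rank));
  std::vector<const shape_inference::Dimension*> dims;
  for (int32 i = 0; i < rank; ++i) dims.push_back(c->CreateUnknownDim());
  c->set_output(0, c->CreateShape(dims));
  return Status::OK();
}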
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index e4ca7645b2..e52d1c5a2d 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/test.h"
@@ -21,7 +23,8 @@ namespace tensorflow {
namespace shape_inference {
TEST(ShapeInferenceTest, RankAndDimInspection) {
- InferenceContext c({"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"?", "[1,?,3]", "[]"}, 2 /* num_outputs */);
EXPECT_EQ(3, c.num_inputs());
EXPECT_EQ(2, c.num_outputs());
@@ -54,7 +57,8 @@ TEST(ShapeInferenceTest, RankAndDimInspection) {
}
TEST(ShapeInferenceTest, WithRank) {
- InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -91,7 +95,8 @@ TEST(ShapeInferenceTest, WithRank) {
}
TEST(ShapeInferenceTest, WithRankAtLeast) {
- InferenceContext c({"?", "[1,?,3]"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"?", "[1,?,3]"}, 2 /* num_outputs */);
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -125,7 +130,8 @@ TEST(ShapeInferenceTest, WithRankAtLeast) {
}
TEST(ShapeInferenceTest, WithValue) {
- InferenceContext c({"[1,?]"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"[1,?]"}, 2 /* num_outputs */);
auto d0 = c.Dim(c.input(0), 0);
auto d1 = c.Dim(c.input(0), 1);
@@ -163,7 +169,8 @@ TEST(ShapeInferenceTest, WithValue) {
}
TEST(ShapeInferenceTest, MergeDim) {
- InferenceContext c({"[2,?,2,1,?]"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"[2,?,2,1,?]"}, 2 /* num_outputs */);
auto d2 = c.Dim(c.input(0), 0);
auto d_unknown = c.Dim(c.input(0), 1);
@@ -202,7 +209,9 @@ TEST(ShapeInferenceTest, MergeDim) {
}
TEST(ShapeInferenceTest, MergeShape) {
- InferenceContext c({"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
+ NodeDef def;
+ InferenceContext c(&def,
+ {"?", "[1,2]", "[?,2]", "[1,?]", "[1,3]", "?", "[1]"},
2 /* num_outputs */);
auto s_unknown = c.input(0);
@@ -260,7 +269,8 @@ TEST(ShapeInferenceTest, MergeShape) {
}
TEST(ShapeInferenceTest, Subshape) {
- InferenceContext c({"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"[1,2,3,?,5]", "?"}, 2 /* num_outputs */);
const Shape* unknown = c.input(1);
const Shape* out;
@@ -297,7 +307,8 @@ TEST(ShapeInferenceTest, Subshape) {
}
TEST(ShapeInferenceTest, Concatenate) {
- InferenceContext c({"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"[1,?,3]", "[4,5]", "?"}, 2 /* num_outputs */);
auto in0 = c.input(0);
auto in1 = c.input(1);
@@ -322,7 +333,8 @@ TEST(ShapeInferenceTest, Concatenate) {
}
TEST(ShapeInferenceTest, CreateShape) {
- InferenceContext c({"[1,2,3,?,5]"}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {"[1,2,3,?,5]"}, 2 /* num_outputs */);
std::vector<const Dimension*> dims;
auto in0 = c.input(0);
@@ -341,7 +353,8 @@ TEST(ShapeInferenceTest, CreateShape) {
}
TEST(ShapeInferenceTest, CreateUnknownShape) {
- InferenceContext c({}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {}, 2 /* num_outputs */);
auto u0 = c.CreateUnknownShape();
auto u1 = c.CreateUnknownShape();
@@ -352,7 +365,8 @@ TEST(ShapeInferenceTest, CreateUnknownShape) {
TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
auto create = [](Tensor* t) {
- InferenceContext c({"?"}, 0 /* num_outputs */, {t});
+ NodeDef def;
+ InferenceContext c(&def, {"?"}, 0 /* num_outputs */, {t});
const Shape* out;
Status s = c.CreateShapeFromShapeTensor(0, &out);
if (s.ok()) {
@@ -386,7 +400,8 @@ TEST(ShapeInferenceTest, CreateShapeFromShapeTensor) {
}
TEST(ShapeInferenceTest, CreateDim) {
- InferenceContext c({}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {}, 2 /* num_outputs */);
auto* d0 = c.CreateDim(1);
auto* d1 = c.CreateDim(1);
@@ -398,7 +413,8 @@ TEST(ShapeInferenceTest, CreateDim) {
}
TEST(ShapeInferenceTest, CreateUnknownDim) {
- InferenceContext c({}, 2 /* num_outputs */);
+ NodeDef def;
+ InferenceContext c(&def, {}, 2 /* num_outputs */);
auto* d0 = c.CreateUnknownDim();
auto* d1 = c.CreateUnknownDim();
@@ -410,12 +426,29 @@ TEST(ShapeInferenceTest, CreateUnknownDim) {
TEST(ShapeInferenceTest, InputTensors) {
const Tensor t1 = tensorflow::test::AsTensor<float>({10});
const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
- InferenceContext c({"[1]", "[2]", "[3]"}, 2 /* num_outputs */, {&t1, &t2});
+ NodeDef def;
+ InferenceContext c(&def, {"[1]", "[2]", "[3]"}, 2 /* num_outputs */,
+ {&t1, &t2});
EXPECT_TRUE(c.input_tensor(0) == &t1);
EXPECT_TRUE(c.input_tensor(1) == &t2);
EXPECT_TRUE(c.input_tensor(2) == nullptr);
}
+TEST(ShapeInferenceTest, GetAttr) {
+ OpRegistrationData op_reg_data;
+ CHECK(OpDefBuilder("dummy").Attr("foo:string").Finalize(&op_reg_data).ok());
+ NodeDef def;
+ CHECK(NodeDefBuilder("dummy", &op_reg_data.op_def)
+ .Attr("foo", "bar")
+ .Finalize(&def)
+ .ok());
+
+ InferenceContext c(&def, {}, 2 /* num_outputs */);
+ string value;
+ EXPECT_TRUE(c.GetAttr("foo", &value).ok());
+ EXPECT_EQ("bar", value);
+}
+
} // namespace shape_inference
} // namespace tensorflow
diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc
index 9b56014edb..f771e47764 100644
--- a/tensorflow/core/framework/shape_inference_testutil.cc
+++ b/tensorflow/core/framework/shape_inference_testutil.cc
@@ -29,13 +29,18 @@ using shape_inference::Shape;
using errors::Unknown;
Status InferShapes(const string& op_name, const string& ins,
- const string& expected_outs) {
+ const string& expected_outs, const NodeDef* node_def) {
const OpRegistrationData* op_reg_data;
TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op_name, &op_reg_data));
const int num_outputs = op_reg_data->op_def.output_arg_size();
std::vector<string> ins_v = str_util::Split(ins, ';');
- shape_inference::InferenceContext c(ins_v, num_outputs);
+ std::unique_ptr<const NodeDef> new_node_def;
+ if (node_def == nullptr) {
+ new_node_def.reset(new NodeDef);
+ node_def = new_node_def.get();
+ }
+ shape_inference::InferenceContext c(node_def, ins_v, num_outputs);
TF_RETURN_IF_ERROR(op_reg_data->shape_inference_fn(&c));
std::unordered_map<const Dimension*, std::pair<int, int>>
diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h
index f2581247d9..221ec875fb 100644
--- a/tensorflow/core/framework/shape_inference_testutil.h
+++ b/tensorflow/core/framework/shape_inference_testutil.h
@@ -23,6 +23,8 @@ limitations under the License.
namespace tensorflow {
+class NodeDef;
+
// Run shape inference for <op_name>, given inputs specified by <ins>
// and returns an error if the inferred shape does not match expected_outs.
//
@@ -45,11 +47,16 @@ namespace tensorflow {
// <expected_outs> can be "e"; this is used to indicate that shape inference
// should have failed.
Status InferShapes(const string& op_name, const string& ins,
- const string& expected_outs);
+ const string& expected_outs,
+ const NodeDef* node_def = nullptr);
#define INFER_OK(op, i, o) EXPECT_EQ("", InferShapes(op, i, o).error_message())
#define INFER_ERROR(s, op, i) \
- EXPECT_EQ(s, InferShapes(op, i, "x").error_message())
+ EXPECT_EQ(s, InferShapes(op, i, "e").error_message())
+#define INFER_OK_WITH_DEF(op, nd, i, o) \
+ EXPECT_EQ("", InferShapes(op, i, o, nd).error_message())
+#define INFER_ERROR_WITH_DEF(s, op, nd, i) \
+ EXPECT_EQ(s, InferShapes(op, i, "e", nd).error_message())
} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 7b85ff9c36..de15d82269 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -416,7 +416,8 @@ Tensor::Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf)
}
bool Tensor::IsInitialized() const {
- return buf_ != nullptr && buf_->data() != nullptr;
+ return (buf_ != nullptr && buf_->data() != nullptr) ||
+ shape_.num_elements() == 0;
}
void Tensor::CheckType(DataType expected_dtype) const {
@@ -507,7 +508,7 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape)
if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
CASES(type, buf_ = new Buffer<T>(a, shape.num_elements()));
}
- if (IsInitialized() && LogMemory::IsEnabled()) {
+ if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
LogMemory::RecordTensorAllocation("Unknown", LogMemory::UNKNOWN_STEP_ID,
*this);
}
@@ -521,8 +522,8 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape,
if (shape_.num_elements() > 0 || a->ShouldAllocateEmptyTensors()) {
CASES(type, buf_ = new Buffer<T>(a, shape.num_elements(), allocation_attr));
}
- if (!allocation_attr.allocation_will_be_logged && IsInitialized() &&
- LogMemory::IsEnabled()) {
+ if (!allocation_attr.allocation_will_be_logged && buf_ != nullptr &&
+ buf_->data() != nullptr && LogMemory::IsEnabled()) {
LogMemory::RecordTensorAllocation("Unknown (with attributes)",
LogMemory::UNKNOWN_STEP_ID, *this);
}
@@ -617,7 +618,7 @@ bool Tensor::FromProto(Allocator* a, const TensorProto& proto) {
buf_ = p;
// TODO(misard) add tracking of which kernels and steps are calling
// FromProto.
- if (IsInitialized() && LogMemory::IsEnabled()) {
+ if (buf_ != nullptr && buf_->data() != nullptr && LogMemory::IsEnabled()) {
LogMemory::RecordTensorAllocation("Unknown (from Proto)",
LogMemory::UNKNOWN_STEP_ID, *this);
}
@@ -765,7 +766,7 @@ string Tensor::DebugString() const {
void Tensor::FillDescription(TensorDescription* description) const {
description->set_dtype(dtype());
shape().AsProto(description->mutable_shape());
- if (IsInitialized()) {
+ if (buf_ != nullptr && buf_->data() != nullptr) {
buf_->FillAllocationDescription(
description->mutable_allocation_description());
}
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index dd2d9a4c86..48fbd38e0c 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -120,7 +120,10 @@ class Tensor {
// underlying refcounted storage
size_t BufferHash() const;
- /// Has this Tensor been initialized?
+ /// \brief Has this Tensor been initialized?
+ ///
+ /// Zero-element Tensors are always considered initialized, even if they
+ /// have never been assigned to and do not have any memory allocated.
bool IsInitialized() const;
/// Returns the estimated memory usage of this tensor.
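
Note: under the revised contract, IsInitialized() answers "is this tensor
usable?" rather than "does a buffer exist?". A small sketch (illustration
only, not part of the patch) of the behavior callers now rely on:

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/logging.h"

void ZeroElementExample() {
  // A zero-element tensor may own no buffer, yet reports initialized.
  tensorflow::Tensor empty(tensorflow::DT_FLOAT, tensorflow::TensorShape({0}));
  CHECK(empty.IsInitialized());
  // Buffer-sensitive callers (e.g. UniqueTensorReferences::Add above) now
  // pair the check with NumElements() to filter out buffer-less tensors.
  CHECK_EQ(empty.NumElements(), 0);
}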
diff --git a/tensorflow/core/framework/unique_tensor_references.cc b/tensorflow/core/framework/unique_tensor_references.cc
index 2ac6431c54..ab33d9ede6 100644
--- a/tensorflow/core/framework/unique_tensor_references.cc
+++ b/tensorflow/core/framework/unique_tensor_references.cc
@@ -33,7 +33,7 @@ UniqueTensorReferences::~UniqueTensorReferences() {
void UniqueTensorReferences::Add(const Tensor& tensor) {
DCHECK(!frozen_);
// Do nothing if the tensor has a null buffer.
- if (tensor.IsInitialized()) {
+ if (tensor.IsInitialized() && tensor.NumElements() > 0) {
if (referenced_tensors_set_ != nullptr) {
// There are enough tensors that we are using a hash set to
// de-duplicate.
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 5cf48bfab5..142f63c6b4 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1753,6 +1753,7 @@ filegroup(
"cwise_ops.h",
"cwise_ops_common.cc",
"cwise_ops_common.h",
+ "cwise_ops_gradients.h",
"dense_update_ops.cc",
"dense_update_ops.h",
"example_parsing_ops.cc",
diff --git a/tensorflow/core/kernels/barrier_ops.cc b/tensorflow/core/kernels/barrier_ops.cc
index 533b03db0e..e66b9d4168 100644
--- a/tensorflow/core/kernels/barrier_ops.cc
+++ b/tensorflow/core/kernels/barrier_ops.cc
@@ -354,7 +354,8 @@ class Barrier : public ResourceBase {
element.push_back(PersistentTensor(uninitialized));
}
}
- if (element[1 + component_index].IsInitialized()) {
+ const PersistentTensor& component = element[1 + component_index];
+ if (component.IsInitialized() && component.NumElements() > 0) {
return errors::InvalidArgument("Key ", keys_vec(i),
" already has a value for component ",
component_index, " in barrier ", name());
@@ -374,7 +375,7 @@ class Barrier : public ResourceBase {
// ready queue.
bool is_complete = true;
for (int j = 0; is_complete && j < element.size(); ++j) {
- is_complete = element[j].IsInitialized();
+ is_complete = element[j].IsInitialized() && element[j].NumElements() > 0;
}
if (is_complete) {
// Add tuple to the ready queue. A queue tuple has the index
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index 508ffc0402..487daa7c2d 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -1024,6 +1024,9 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
compatible_input_shape = input_shape;
}
+ CHECK(padding_rows >= 0 && padding_cols >= 0)
+ << "Negative row or col paddings: (" << padding_rows << ", "
+ << padding_cols << ")";
perftools::gputools::dnn::BatchDescriptor input_desc;
input_desc.set_count(dims.batch_size)
.set_height(GetTensorDim(compatible_input_shape, data_format_, 'H'))
@@ -1382,6 +1385,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
compatible_input = input;
}
+ CHECK(padding_rows >= 0 && padding_cols >= 0)
+ << "Negative row or col paddings: (" << padding_rows << ", "
+ << padding_cols << ")";
perftools::gputools::dnn::BatchDescriptor input_desc;
input_desc.set_count(dims.batch_size)
.set_height(GetTensorDim(compatible_input, data_format_, 'H'))
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index d0c6865951..62e60d018b 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -438,10 +438,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
if (padding_ == Padding::SAME) {
padding_planes =
(output_planes - 1) * strides[0] + filter_size[0] - input_size[0];
- padding_cols =
- (output_cols - 1) * strides[2] + filter_size[2] - input_size[2];
- padding_rows =
- (output_rows - 1) * strides[1] + filter_size[1] - input_size[1];
+ padding_cols = std::max<int>(
+ 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
+ padding_rows = std::max<int>(
+ 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
}
const bool rows_odd = (padding_rows % 2 != 0);
const bool cols_odd = (padding_cols % 2 != 0);
@@ -462,6 +462,9 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
input_size[2]};
}
+ CHECK(padding_rows >= 0 && padding_cols >= 0)
+ << "Negative row or col paddings: (" << padding_rows << ", "
+ << padding_cols << ")";
perftools::gputools::dnn::BatchDescriptor input_desc(3);
input_desc.set_count(batch)
.set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
@@ -659,10 +662,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
if (padding_ == Padding::SAME) {
padding_planes =
(output_planes - 1) * strides[0] + filter_size[0] - input_size[0];
- padding_cols =
- (output_cols - 1) * strides[2] + filter_size[2] - input_size[2];
- padding_rows =
- (output_rows - 1) * strides[1] + filter_size[1] - input_size[1];
+ padding_cols = std::max<int>(
+ 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
+ padding_rows = std::max<int>(
+ 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
}
bool rows_odd = (padding_rows % 2 != 0);
bool cols_odd = (padding_cols % 2 != 0);
@@ -686,6 +689,9 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
compatible_input = input;
}
+ CHECK(padding_rows >= 0 && padding_cols >= 0)
+ << "Negative row or col paddings: (" << padding_rows << ", "
+ << padding_cols << ")";
perftools::gputools::dnn::BatchDescriptor input_desc(3);
input_desc.set_count(batch)
.set_spatial_dim(DimIndex::X, compatible_input.dim_size(3))
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index ede9a77ed0..e0aff98854 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -334,8 +334,10 @@ class LaunchConvOp<GPUDevice, T> {
// We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
// and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
// we pad more on the right and bottom than on the top and left.
- padding_rows = (out_rows - 1) * row_stride + patch_rows - in_rows;
- padding_cols = (out_cols - 1) * col_stride + patch_cols - in_cols;
+ padding_rows =
+ std::max<int>(0, (out_rows - 1) * row_stride + patch_rows - in_rows);
+ padding_cols =
+ std::max<int>(0, (out_cols - 1) * col_stride + patch_cols - in_cols);
const bool rows_odd = (padding_rows % 2 != 0);
const bool cols_odd = (padding_cols % 2 != 0);
if (rows_odd || cols_odd) {
@@ -375,6 +377,9 @@ class LaunchConvOp<GPUDevice, T> {
input = transformed_input;
}
+ CHECK(padding_rows >= 0 && padding_cols >= 0)
+ << "Negative row or col paddings: (" << padding_rows << ", "
+ << padding_cols << ")";
perftools::gputools::dnn::BatchDescriptor input_desc;
input_desc.set_count(in_batch)
.set_feature_map_count(in_depths)
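
Note: for SAME padding, out_rows = ceil(in_rows / row_stride), so the raw
expression (out_rows - 1) * row_stride + patch_rows - in_rows can dip below
zero when the stride steps past the last input row; the std::max clamp and the
CHECK guard that case. A worked example under those definitions (illustrative
standalone code, not part of the patch):

#include <algorithm>
#include <cassert>

int main() {
  const int in_rows = 7, patch_rows = 2, row_stride = 4;
  const int out_rows = (in_rows + row_stride - 1) / row_stride;        // 2
  const int raw = (out_rows - 1) * row_stride + patch_rows - in_rows;  // -1
  assert(raw == -1);              // negative without the clamp
  assert(std::max(0, raw) == 0);  // clamped to zero, as in the patch
  return 0;
}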
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 697b3f6267..e236edfc0d 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -160,8 +160,10 @@ struct LaunchConvOp<GPUDevice, T> {
if (padding == Padding::SAME) {
pad_planes = (out_planes - 1) * strides[0] + filter_planes - in_planes;
- pad_rows = (out_rows - 1) * strides[1] + filter_rows - in_rows;
- pad_cols = (out_cols - 1) * strides[2] + filter_cols - in_cols;
+ pad_rows = std::max<int64>(
+ 0, (out_rows - 1) * strides[1] + filter_rows - in_rows);
+ pad_cols = std::max<int64>(
+ 0, (out_cols - 1) * strides[2] + filter_cols - in_cols);
}
// NOTE: This only works in NHWC.
@@ -239,6 +241,9 @@ struct LaunchConvOp<GPUDevice, T> {
transformed_input.tensor<T, 5>());
input = transformed_input;
+ CHECK(pad_rows >= 0 && pad_cols >= 0) << "Negative row or col paddings: ("
+ << pad_rows << ", " << pad_cols
+ << ")";
perftools::gputools::dnn::BatchDescriptor input_desc(3);
input_desc.set_count(in_batch)
.set_feature_map_count(in_depth)
diff --git a/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc
index a7ac9baca0..b59d22310e 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_sigmoid.cu.cc
@@ -16,10 +16,12 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h"
namespace tensorflow {
namespace functor {
DEFINE_UNARY3(sigmoid, Eigen::half, float, double);
+DEFINE_SIMPLE_BINARY3(sigmoid_grad, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc
index 1678086c35..66ee3c193e 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_tanh.cu.cc
@@ -16,10 +16,12 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h"
namespace tensorflow {
namespace functor {
DEFINE_UNARY3(tanh, Eigen::half, float, double);
+DEFINE_SIMPLE_BINARY3(tanh_grad, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_sigmoid.cc b/tensorflow/core/kernels/cwise_op_sigmoid.cc
index 9d8a849bd3..cc1f9b8f03 100644
--- a/tensorflow/core/kernels/cwise_op_sigmoid.cc
+++ b/tensorflow/core/kernels/cwise_op_sigmoid.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
namespace tensorflow {
REGISTER5(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double,
@@ -22,4 +23,12 @@ REGISTER5(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double,
REGISTER3(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, Eigen::half,
double);
#endif
+
+REGISTER5(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, float,
+ Eigen::half, double, complex64, complex128);
+#if GOOGLE_CUDA
+REGISTER3(SimpleBinaryOp, GPU, "SigmoidGrad", functor::sigmoid_grad, float,
+ Eigen::half, double);
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_tanh.cc b/tensorflow/core/kernels/cwise_op_tanh.cc
index 6604d71d14..a4c4aad053 100644
--- a/tensorflow/core/kernels/cwise_op_tanh.cc
+++ b/tensorflow/core/kernels/cwise_op_tanh.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
namespace tensorflow {
REGISTER5(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double,
@@ -21,4 +22,11 @@ REGISTER5(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double,
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "Tanh", functor::tanh, float, Eigen::half, double);
#endif
+
+REGISTER5(SimpleBinaryOp, CPU, "TanhGrad", functor::tanh_grad, float,
+ Eigen::half, double, complex64, complex128);
+#if GOOGLE_CUDA
+REGISTER3(SimpleBinaryOp, GPU, "TanhGrad", functor::tanh_grad, float,
+ Eigen::half, double);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index 02a82c00bf..6ccbe46c7f 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -21,6 +21,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/cwise_ops.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -130,6 +131,35 @@ class BinaryOp : public BinaryOpShared {
}
};
+// Basic coefficient-wise binary operations that are known not to require
+// any broadcasting. This is the case, for example, for the gradients of
+// unary operations.
+// Device: E.g., CPUDevice, GPUDevice.
+// Functor: defined above. E.g., functor::tanh_grad.
+template <typename Device, typename Functor>
+class SimpleBinaryOp : public OpKernel {
+ public:
+ typedef typename Functor::in_type Tin; // Input scalar data type.
+ typedef typename Functor::out_type Tout; // Output scalar data type.
+
+ explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& in0 = ctx->input(0);
+ const Tensor& in1 = ctx->input(1);
+
+ Tensor* out;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out));
+ auto out_flat = out->flat<Tout>();
+ auto in0_flat = in0.flat<Tin>();
+ auto in1_flat = in1.flat<Tin>();
+ const Device& eigen_device = ctx->eigen_device<Device>();
+
+ functor::SimpleBinaryFunctor<Device, Functor>()(eigen_device, out_flat,
+ in0_flat, in1_flat);
+ }
+};
+
// Coefficient-wise unary operations:
// Device: E.g., CPUDevice, GPUDevice.
// Functor: defined in cwise_functors.h. E.g., functor::sqrt.
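
Note: SimpleBinaryOp is registered with the same macros as the other
elementwise kernels; roughly, one instantiation of the REGISTER5 call added in
cwise_op_tanh.cc expands to something like this sketch:

REGISTER_KERNEL_BUILDER(
    Name("TanhGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    SimpleBinaryOp<CPUDevice, functor::tanh_grad<float>>);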
diff --git a/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
new file mode 100644
index 0000000000..4394770708
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
@@ -0,0 +1,71 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if !GOOGLE_CUDA
+#error This file must only be included when building with Cuda support
+#endif
+
+#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+#define TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+
+#define EIGEN_USE_GPU
+
+#include <complex>
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/cwise_ops.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/core/platform/logging.h"
+namespace tensorflow {
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+typedef std::complex<float> complex64;
+typedef std::complex<double> complex128;
+
+// Partial specialization of SimpleBinaryFunctor<Device=GPUDevice, Functor>.
+template <typename Functor>
+struct SimpleBinaryFunctor<GPUDevice, Functor> {
+ void operator()(const GPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in1,
+ typename Functor::tin_type in2) {
+ To32Bit(out).device(d) =
+ To32Bit(in1).binaryExpr(in2, typename Functor::func());
+ }
+};
+
+// Macros to explicitly instantiate kernels on GPU for multiple types
+// (T0, T1, etc.) for SimpleBinaryFunctor (e.g., functor::tanh_grad).
+#define DEFINE_SIMPLE_BINARY1(F, T) \
+ template struct SimpleBinaryFunctor<GPUDevice, F<T> >
+#define DEFINE_SIMPLE_BINARY2(F, T0, T1) \
+ DEFINE_SIMPLE_BINARY1(F, T0); \
+ DEFINE_SIMPLE_BINARY1(F, T1)
+#define DEFINE_SIMPLE_BINARY3(F, T0, T1, T2) \
+ DEFINE_SIMPLE_BINARY2(F, T0, T1); \
+ DEFINE_SIMPLE_BINARY1(F, T2)
+#define DEFINE_SIMPLE_BINARY4(F, T0, T1, T2, T3) \
+ DEFINE_SIMPLE_BINARY2(F, T0, T1); \
+ DEFINE_SIMPLE_BINARY2(F, T2, T3)
+#define DEFINE_SIMPLE_BINARY5(F, T0, T1, T2, T3, T4) \
+ DEFINE_SIMPLE_BINARY2(F, T0, T1); \
+ DEFINE_SIMPLE_BINARY3(F, T2, T3, T4)
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h
new file mode 100644
index 0000000000..a59f157281
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_ops_gradients.h
@@ -0,0 +1,107 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
+#define TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
+
+#define EIGEN_USE_THREADS
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace Eigen {
+namespace internal {
+
+// Gradient for the tanh function
+template <typename T>
+struct scalar_tanh_gradient_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_tanh_gradient_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
+ operator()(const T& output, const T& output_gradient) const {
+ return output_gradient * (T(1) - output * output);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
+ packetOp(const Packet& output, const Packet& output_gradient) const {
+ return pmul(output_gradient,
+ psub(pset1<Packet>(T(1)), pmul(output, output)));
+ }
+};
+template <typename T>
+struct functor_traits<scalar_tanh_gradient_op<T>> {
+ enum {
+ Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
+ PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
+ };
+};
+
+// Gradient for the sigmoid function
+template <typename T>
+struct scalar_sigmoid_gradient_op {
+ EIGEN_EMPTY_STRUCT_CTOR(scalar_sigmoid_gradient_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T
+ operator()(const T& output, const T& output_gradient) const {
+ return output_gradient * output * (T(1) - output);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet
+ packetOp(const Packet& output, const Packet& output_gradient) const {
+ return pmul(output_gradient,
+ pmul(output, psub(pset1<Packet>(T(1)), output)));
+ }
+};
+template <typename T>
+struct functor_traits<scalar_sigmoid_gradient_op<T>> {
+ enum {
+ Cost = NumTraits<T>::AddCost + 2 * NumTraits<T>::MulCost,
+ PacketAccess = packet_traits<T>::HasSub && packet_traits<T>::HasMul,
+ };
+};
+
+} // end namespace internal
+} // end namespace Eigen
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename Device, typename Functor>
+struct SimpleBinaryFunctor {
+ void operator()(const Device& d, typename Functor::tout_type out,
+ typename Functor::tin_type in0,
+ typename Functor::tin_type in1);
+};
+
+// Partial specialization of SimpleBinaryFunctor for CPU devices
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Functor>
+struct SimpleBinaryFunctor<CPUDevice, Functor> {
+ void operator()(const CPUDevice& d, typename Functor::tout_type out,
+ typename Functor::tin_type in0,
+ typename Functor::tin_type in1) {
+ out.device(d) = in0.binaryExpr(in1, typename Functor::func());
+ }
+};
+
+template <typename T>
+struct tanh_grad : base<T, Eigen::internal::scalar_tanh_gradient_op<T>> {};
+
+template <typename T>
+struct sigmoid_grad : base<T, Eigen::internal::scalar_sigmoid_gradient_op<T>> {
+};
+
+} // end namespace functor
+
+} // end namespace tensorflow
+#endif // TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
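
Note: both functors take the forward output y and the incoming gradient dy,
using the identities tanh'(x) = 1 - y*y and sigmoid'(x) = y * (1 - y). A
standalone numeric sanity check of those identities (illustration only):

#include <cassert>
#include <cmath>

int main() {
  const double x = 0.5, dy = 1.0;
  const double yt = std::tanh(x);
  const double tanh_grad = dy * (1.0 - yt * yt);
  // 1 - tanh(x)^2 == sech(x)^2 == 1 / cosh(x)^2
  assert(std::abs(tanh_grad - 1.0 / (std::cosh(x) * std::cosh(x))) < 1e-12);

  const double ys = 1.0 / (1.0 + std::exp(-x));
  const double sigmoid_grad = dy * ys * (1.0 - ys);
  assert(sigmoid_grad > 0.0 && sigmoid_grad <= 0.25);  // maximum is at x == 0
  return 0;
}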
diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc
index 34411c9bbb..48124d20af 100644
--- a/tensorflow/core/kernels/sparse_xent_op.cc
+++ b/tensorflow/core/kernels/sparse_xent_op.cc
@@ -35,38 +35,42 @@ class SparseSoftmaxXentWithLogitsOp : public OpKernel {
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
- const Tensor& logits_in = context->input(0);
- const Tensor& labels_in = context->input(1);
- OP_REQUIRES(context, logits_in.shape().dim_size(0) == labels_in.NumElements(),
+ const Tensor& logits = context->input(0);
+ const Tensor& labels = context->input(1);
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits.shape()),
+ errors::InvalidArgument("logits must be 2-D, but got shape ",
+ logits.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(labels.shape()),
+ errors::InvalidArgument("labels must be 1-D, but got shape ",
+ labels.shape().DebugString()));
+ OP_REQUIRES(context, logits.dim_size(0) == labels.dim_size(0),
errors::InvalidArgument(
- "logits first dimension must match labels size. logits shape=",
- logits_in.shape().DebugString(), " labels shape=",
- labels_in.shape().DebugString()));
- OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
- errors::InvalidArgument("logits must be 2-dimensional"));
- // As we already tested that both inputs have the same shape no need to
- // check that "labels" is a matrix too.
-
- // loss is 1-D (one per example), and size is batch_size.
+ "logits and labels must have the same first dimension, "
+ "got logits shape ",
+ logits.shape().DebugString(), " and labels shape ",
+ labels.shape().DebugString()));
+ OP_REQUIRES(context, logits.dim_size(1) > 0,
+ errors::InvalidArgument(
+ "Must have at least one class, but got logits shape ",
+ logits.shape().DebugString()));
Tensor scratch;
- OP_REQUIRES_OK(
- context, context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({logits_in.dim_size(0)}),
- &scratch));
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+ labels.shape(), &scratch));
Tensor* loss_out = nullptr;
OP_REQUIRES_OK(context,
- context->allocate_output(
- 0, TensorShape({logits_in.dim_size(0)}), &loss_out));
+ context->allocate_output(0, labels.shape(), &loss_out));
Tensor* back_out = nullptr;
OP_REQUIRES_OK(context,
- context->allocate_output(1, logits_in.shape(), &back_out));
-
- functor::SparseXentFunctor<Device, T, Index> functor;
- functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
- labels_in.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
- back_out->matrix<T>());
+ context->allocate_output(1, logits.shape(), &back_out));
+
+ if (logits.dim_size(0) > 0) {
+ functor::SparseXentFunctor<Device, T, Index> functor;
+ functor(context->eigen_device<Device>(), logits.matrix<T>(),
+ labels.vec<Index>(), scratch.vec<T>(), loss_out->vec<T>(),
+ back_out->matrix<T>());
+ }
}
};
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index 03b2c3b68b..1456ec2844 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -441,7 +441,7 @@ Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx,
" but the new input shape is ", value_t->shape().DebugString(), ".");
}
- if (!t.tensor.IsInitialized()) {
+ if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
// If existing_t == nullptr but written == true, then what was stored
// was just a shape, which just means zeros. So all we must do in this
// case is copy the reference over and return early.
@@ -502,7 +502,7 @@ Status TensorArray::LockedRead(OpKernelContext* ctx, const int32 index,
"clear_after_read = false?).");
}
- if (!t.tensor.IsInitialized()) {
+ if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
// We stored just a shape, but no value. This means create and
// return zeros of the appropriate shape.
Tensor* tensor_t;
diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc
index 88ec1069c5..7ce1a1d395 100644
--- a/tensorflow/core/lib/gtl/inlined_vector_test.cc
+++ b/tensorflow/core/lib/gtl/inlined_vector_test.cc
@@ -285,6 +285,7 @@ TEST(RefCountedVec, InsertConstructorDestructor) {
for (int pos = 0; pos <= len; pos++) {
SCOPED_TRACE(pos);
std::vector<int> counts(len, 0);
+ int inserted_count = 0;
RefCountedVec v;
for (int i = 0; i < len; ++i) {
SCOPED_TRACE(i);
@@ -295,7 +296,6 @@ TEST(RefCountedVec, InsertConstructorDestructor) {
EXPECT_EQ(1, elem);
}
- int inserted_count = 0;
RefCounted insert_element(9999, &inserted_count);
EXPECT_EQ(1, inserted_count);
v.insert(v.begin() + pos, insert_element);
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index dc96588f73..4ef3a48221 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -14,17 +14,67 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
namespace tensorflow {
+typedef shape_inference::Dimension Dimension;
+typedef shape_inference::InferenceContext InferenceContext;
+typedef shape_inference::Shape Shape;
+
+namespace {
+
+Status GetAxisForPackAndUnpack(InferenceContext* c, int32 rank_after_pack,
+ int32* axis) {
+ TF_RETURN_IF_ERROR(c->GetAttr("axis", axis));
+ if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) {
+ return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [",
+ -1 * rank_after_pack, ",", rank_after_pack,
+ ")");
+ }
+ if (*axis < 0) *axis = (rank_after_pack + *axis);
+ return Status::OK();
+}
+
+} // namespace
+
REGISTER_OP("Pack")
.Input("values: N * T")
.Output("output: T")
.Attr("N: int >= 1")
.Attr("T: type")
.Attr("axis: int = 0")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ // Validate shapes of all inputs are compatible
+ const Shape* cur = c->input(c->num_inputs() - 1);
+ for (int i = c->num_inputs() - 2; i >= 0; --i) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
+ "From merging shape ", i,
+ " with other shapes.");
+ }
+ if (!c->RankKnown(cur)) {
+ c->set_output(0, c->CreateUnknownShape());
+ return Status::OK();
+ }
+    // Determine the axis that will be added, converting a negative
+    // axis to its positive equivalent under negative indexing rules.
+ int32 rank = c->Rank(cur);
+ int32 axis;
+ TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank + 1, &axis));
+
+ // Copy all dimensions over, inserting a dimension of value #inputs
+ // at <axis>.
+ std::vector<const Dimension*> dims;
+ int index = 0;
+ while (index < axis) dims.push_back(c->Dim(cur, index++));
+ dims.push_back(c->CreateDim(c->num_inputs()));
+ while (index < rank) dims.push_back(c->Dim(cur, index++));
+
+ c->set_output(0, c->CreateShape(dims));
+ return Status::OK();
+ }))
.Doc(R"doc(
Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
@@ -61,6 +111,29 @@ REGISTER_OP("Unpack")
.Attr("num: int >= 0")
.Attr("T: type")
.Attr("axis: int = 0")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ const Shape* s = c->input(0);
+ const Shape* out;
+ if (c->RankKnown(s)) {
+      // Determine the axis that will be removed, converting a negative
+      // axis to its positive equivalent under negative indexing rules.
+ int32 rank = c->Rank(s);
+ int32 axis;
+ TF_RETURN_IF_ERROR(GetAxisForPackAndUnpack(c, rank, &axis));
+
+ // Copy all dimensions, removing the <axis> dimension.
+ std::vector<const Dimension*> dims;
+ for (int i = 0; i < rank; ++i) {
+ if (i != axis) dims.push_back(c->Dim(s, i));
+ }
+ out = c->CreateShape(dims);
+ } else {
+ // All outputs are the same shape, but it's not known.
+ out = c->CreateUnknownShape();
+ }
+ for (int i = 0; i < c->num_outputs(); ++i) c->set_output(i, out);
+ return Status::OK();
+ }))
.Doc(R"doc(
Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors.
@@ -154,6 +227,18 @@ REGISTER_OP("Const")
.Output("output: dtype")
.Attr("value: tensor")
.Attr("dtype: type")
+ .SetShapeFn(OpShapeInferenceFn([](InferenceContext* c) {
+ const TensorProto* proto = nullptr;
+ TF_RETURN_IF_ERROR(c->GetAttr("value", &proto));
+ TF_RETURN_IF_ERROR(TensorShape::IsValidShape(proto->tensor_shape()));
+ TensorShape shape(proto->tensor_shape());
+ std::vector<const Dimension*> dims;
+ for (int i = 0; i < shape.dims(); ++i) {
+ dims.push_back(c->CreateDim(shape.dim_size(i)));
+ }
+ c->set_output(0, c->CreateShape(dims));
+ return Status::OK();
+ }))
.Doc(R"doc(
Returns a constant tensor.
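
Note: GetAxisForPackAndUnpack accepts axis in [-R, R) for a rank-R result and
maps negative values to their positive equivalents. The rule, isolated as a
small self-contained sketch (NormalizeAxis is invented for illustration):

#include <cassert>

int NormalizeAxis(int axis, int rank_after_pack) {
  // Mirrors the validation and wrap-around in GetAxisForPackAndUnpack.
  assert(axis >= -rank_after_pack && axis < rank_after_pack);
  return axis < 0 ? rank_after_pack + axis : axis;
}

int main() {
  // Packing rank-2 inputs yields a rank-3 output, so valid axes are [-3, 3).
  assert(NormalizeAxis(-3, 3) == 0);  // equivalent to axis 0
  assert(NormalizeAxis(-1, 3) == 2);  // insert as the last dimension
  return 0;
}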
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
new file mode 100644
index 0000000000..19dfa29358
--- /dev/null
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -0,0 +1,137 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(ArrayOpsTest, Pack_ShapeFn) {
+ std::unique_ptr<NodeDef> def_storage(new NodeDef);
+ NodeDef* def = def_storage.get();
+ auto set_axis = [def](int axis) {
+ TF_CHECK_OK(NodeDefBuilder("test", "Pack")
+ .Input({{"a", 0, DT_FLOAT}})
+ .Attr("axis", axis)
+ .Finalize(def));
+ };
+ const char op[] = "Pack";
+
+ set_axis(0);
+ INFER_OK_WITH_DEF(op, def, "?;?;?", "?");
+
+ for (int axis : {0, -3}) {
+ set_axis(axis);
+ INFER_OK_WITH_DEF(op, def, "?;?", "?");
+ INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[3,d0_0|d1_0,d0_1|d1_1]");
+ INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[3,d1_0,d0_1|d1_1]");
+ INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[3,d1_0,d1_1]");
+ }
+ for (int axis : {1, -2}) {
+ set_axis(axis);
+ INFER_OK_WITH_DEF(op, def, "?;?", "?");
+ INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[d0_0|d1_0,3,d0_1|d1_1]");
+ INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[d1_0,3,d0_1|d1_1]");
+ INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[d1_0,3,d1_1]");
+ }
+ for (int axis : {2, -1}) {
+ set_axis(axis);
+ INFER_OK_WITH_DEF(op, def, "?;?", "?");
+ INFER_OK_WITH_DEF(op, def, "[1,3];[1,3];?", "[d0_0|d1_0,d0_1|d1_1,3]");
+ INFER_OK_WITH_DEF(op, def, "[?,3];[1,3];?", "[d1_0,d0_1|d1_1,3]");
+ INFER_OK_WITH_DEF(op, def, "[?,?];[1,3];?", "[d1_0,d1_1,3]");
+ }
+
+ set_axis(-4);
+ INFER_ERROR_WITH_DEF("Invalid axis: -4; must be in [-3,3)", op, def,
+ "[1,3];[1,3];?");
+ set_axis(3);
+ INFER_ERROR_WITH_DEF("Invalid axis: 3; must be in [-3,3)", op, def,
+ "[1,3];[1,3];?");
+
+ set_axis(0);
+ INFER_ERROR_WITH_DEF(("Shapes must be equal rank, but are 3 and 2"
+ "\n\tFrom merging shape 0 with other shapes."),
+ op, def, "[1,2,3];?;[1,4]");
+}
+
+TEST(ArrayOpsTest, UnPack_ShapeFn) {
+ std::unique_ptr<NodeDef> def_storage(new NodeDef);
+ NodeDef* def = def_storage.get();
+ auto set_axis = [def](int axis) {
+ TF_CHECK_OK(NodeDefBuilder("test", "Unpack")
+ .Input("a", 0, DT_FLOAT)
+ .Attr("axis", axis)
+ .Finalize(def));
+ };
+ const char op[] = "Unpack";
+
+ set_axis(0);
+ INFER_OK_WITH_DEF(op, def, "?;?;?", "?");
+
+ for (int axis : {0, -3}) {
+ set_axis(axis);
+ INFER_OK_WITH_DEF(op, def, "?", "?");
+ INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_1,d0_2]");
+ INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_1,d0_2]");
+ }
+ for (int axis : {1, -2}) {
+ set_axis(axis);
+ INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_0,d0_2]");
+ INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_0,d0_2]");
+ }
+ for (int axis : {2, -1}) {
+ set_axis(axis);
+ INFER_OK_WITH_DEF(op, def, "[1,2,3]", "[d0_0,d0_1]");
+ INFER_OK_WITH_DEF(op, def, "[?,?,?]", "[d0_0,d0_1]");
+ }
+
+ set_axis(-4);
+ INFER_ERROR_WITH_DEF("Invalid axis: -4; must be in [-3,3)", op, def,
+ "[1,2,3]");
+ set_axis(3);
+ INFER_ERROR_WITH_DEF("Invalid axis: 3; must be in [-3,3)", op, def,
+ "[1,2,3]");
+}
+
+TEST(ArrayOpsTest, Const_ShapeFn) {
+ std::unique_ptr<NodeDef> def_storage(new NodeDef);
+ NodeDef* def = def_storage.get();
+ TensorProto tensor_proto;
+ auto* shape_proto = tensor_proto.mutable_tensor_shape();
+ auto rebuild_node_def = [def, &tensor_proto]() {
+ TF_CHECK_OK(NodeDefBuilder("test", "Const")
+ .Attr("value", tensor_proto)
+ .Finalize(def));
+ };
+ const char op[] = "Const";
+
+ TensorShape{}.AsProto(shape_proto);
+ rebuild_node_def();
+ INFER_OK_WITH_DEF(op, def, "", "[]");
+ TensorShape{1, 2, 3, 4}.AsProto(shape_proto);
+ rebuild_node_def();
+ INFER_OK_WITH_DEF(op, def, "", "[1,2,3,4]");
+
+ shape_proto->add_dim()->set_size(-1);
+ rebuild_node_def();
+ INFER_ERROR_WITH_DEF("Shape [1,2,3,4,-1] has negative dimensions", op, def,
+ "");
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
index adaa47ab8c..2dba61efe7 100644
--- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
@@ -20209,6 +20209,34 @@ op {
}
}
op {
+ name: "SigmoidGrad"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Sign"
input_arg {
name: "x"
@@ -24558,6 +24586,34 @@ op {
}
}
op {
+ name: "TanhGrad"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "TemporaryVariable"
output_arg {
name: "ref"
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 0f9ee4942a..b220a2d2d6 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -238,6 +238,13 @@ tf.complex_abs(x) ==> [5.25594902, 6.60492229]
.Attr("T: {half, float, double, complex64, complex128}") \
.SetShapeFn(OpShapeInferenceFn(shape_inference::UnchangedShape))
+#define UNARY_GRADIENT_COMPLEX() \
+ Input("x: T") \
+ .Input("y: T") \
+ .Output("z: T") \
+ .Attr("T: {half, float, double, complex64, complex128}") \
+ .SetShapeFn(OpShapeInferenceFn(shape_inference::UnchangedShape))
+
REGISTER_OP("Neg")
.UNARY()
.Doc(R"doc(
@@ -292,6 +299,13 @@ REGISTER_OP("Tanh")
Computes hyperbolic tangent of `x` element-wise.
)doc");
+REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX().Doc(R"doc(
+Computes the gradient for the tanh of `x` wrt its input.
+
+Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy`
+is the corresponding input gradient.
+)doc");
+
REGISTER_OP("Lgamma")
.UNARY_REAL()
.Doc(R"doc(
@@ -325,6 +339,13 @@ Computes sigmoid of `x` element-wise.
Specifically, `y = 1 / (1 + exp(-x))`.
)doc");
+REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX().Doc(R"doc(
+Computes the gradient of the sigmoid of `x` wrt its input.
+
+Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
+`dy` is the corresponding input gradient.
+)doc");
+
REGISTER_OP("Sin")
.UNARY_COMPLEX()
.Doc(R"doc(
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 474516bf4c..afd6507b0d 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -11797,6 +11797,36 @@ op {
description: "Specifically, `y = 1 / (1 + exp(-x))`."
}
op {
+ name: "SigmoidGrad"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ summary: "Computes the gradient of the sigmoid of `x` wrt its input."
+ description: "Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and\n`dy` is the corresponding input gradient."
+}
+op {
name: "Sign"
input_arg {
name: "x"
@@ -14644,6 +14674,36 @@ op {
summary: "Computes hyperbolic tangent of `x` element-wise."
}
op {
+ name: "TanhGrad"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ summary: "Computes the gradient for the tanh of `x` wrt its input."
+ description: "Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy`\nis the corresponding input gradient."
+}
+op {
name: "TemporaryVariable"
output_arg {
name: "ref"
diff --git a/tensorflow/core/platform/default/tracing.cc b/tensorflow/core/platform/default/tracing.cc
index 7910e97db9..422564fb3e 100644
--- a/tensorflow/core/platform/default/tracing.cc
+++ b/tensorflow/core/platform/default/tracing.cc
@@ -15,8 +15,6 @@ limitations under the License.
#include "tensorflow/core/platform/tracing.h"
-#include <unistd.h>
-
namespace tensorflow {
namespace port {
@@ -26,21 +24,6 @@ void Tracing::RegisterEvent(EventCategory id, const char* name) {
void Tracing::Initialize() {}
-static bool TryGetEnv(const char* name, const char** value) {
- *value = getenv(name);
- return *value != nullptr && (*value)[0] != '\0';
-}
-
-const char* Tracing::LogDir() {
- const char* dir;
- if (TryGetEnv("TEST_TMPDIR", &dir)) return dir;
- if (TryGetEnv("TMP", &dir)) return dir;
- if (TryGetEnv("TMPDIR", &dir)) return dir;
- dir = "/tmp";
- if (access(dir, R_OK | W_OK | X_OK) == 0) return dir;
- return "."; // Default to current directory.
-}
-
static bool DoInit() {
Tracing::Initialize();
return true;
diff --git a/tensorflow/core/platform/posix/tracing.cc b/tensorflow/core/platform/posix/tracing.cc
new file mode 100644
index 0000000000..1d1aa53f2c
--- /dev/null
+++ b/tensorflow/core/platform/posix/tracing.cc
@@ -0,0 +1,40 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/tracing.h"
+
+#include <stdlib.h>
+#include <unistd.h>
+
+namespace tensorflow {
+namespace port {
+
+static bool TryGetEnv(const char* name, const char** value) {
+ *value = getenv(name);
+ return *value != nullptr && (*value)[0] != '\0';
+}
+
+const char* Tracing::LogDir() {
+ const char* dir;
+ if (TryGetEnv("TEST_TMPDIR", &dir)) return dir;
+ if (TryGetEnv("TMP", &dir)) return dir;
+ if (TryGetEnv("TMPDIR", &dir)) return dir;
+ dir = "/tmp";
+ if (access(dir, R_OK | W_OK | X_OK) == 0) return dir;
+ return "."; // Default to current directory.
+}
+
+} // namespace port
+} // namespace tensorflow