-rw-r--r--  tensorflow/c/eager/BUILD                          |  2
-rw-r--r--  tensorflow/c/eager/c_api.cc                       | 57
-rw-r--r--  tensorflow/c/eager/c_api.h                        | 14
-rw-r--r--  tensorflow/contrib/rpc/python/kernel_tests/BUILD  |  1
-rw-r--r--  tensorflow/core/kernels/dynamic_stitch_op.cc      |  1
-rw-r--r--  tensorflow/python/eager/function.py               | 18
-rw-r--r--  tensorflow/python/eager/function_test.py          |  3
-rw-r--r--  tensorflow/python/framework/test_util.py          | 24
-rw-r--r--  tensorflow/python/kernel_tests/BUILD              |  5
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py    | 24
-rw-r--r--  tensorflow/python/pywrap_tfe.i                    |  2
11 files changed, 145 insertions(+), 6 deletions(-)
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index a2d96357ac..3e14c10727 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -41,6 +41,8 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
+ # TODO(b/74620627): move this here
+ "//tensorflow/python:cpp_shape_inference_proto_cc",
],
}) + select({
"//tensorflow:with_xla_support": [
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 393851d13c..369342b142 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -49,6 +49,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
+#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
using tensorflow::int64;
using tensorflow::string;
@@ -1015,6 +1016,62 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
ctx->context.RunMetadataProto()->Clear();
}
+void TFE_GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
+ TF_Buffer* output_proto,
+ TF_Status* status) {
+ tensorflow::Node* node = &output.oper->node;
+ tensorflow::CppShapeInferenceResult::HandleData handle_data;
+ handle_data.set_is_set(true);
+ {
+ tensorflow::mutex_lock l(graph->mu);
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(node);
+ CHECK(ic != nullptr);
+ CHECK_LT(output.index, ic->num_outputs());
+ const auto* shapes_and_types =
+ ic->output_handle_shapes_and_types(output.index);
+ if (shapes_and_types == nullptr) {
+ output_proto->data = nullptr;
+ output_proto->length = 0;
+ output_proto->data_deallocator = nullptr;
+ return;
+ }
+
+ for (const auto& p : *shapes_and_types) {
+ auto* out_shape_and_type = handle_data.add_shape_and_type();
+ ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
+ out_shape_and_type->set_dtype(p.dtype);
+ }
+ }
+ status->status = MessageToBuffer(handle_data, output_proto);
+}
+
+void TFE_SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
+ const void* proto, size_t proto_len,
+ TF_Status* status) {
+ tensorflow::CppShapeInferenceResult::HandleData handle_data;
+ if (!handle_data.ParseFromArray(proto, proto_len)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Couldn't deserialize HandleData proto");
+ return;
+ }
+ DCHECK(handle_data.is_set());
+
+ tensorflow::mutex_lock l(graph->mu);
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(&output.oper->node);
+
+ std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
+ for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
+ tensorflow::shape_inference::ShapeHandle shape;
+ status->status =
+ ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
+ if (!status->status.ok()) return;
+ shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
+ }
+ ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
+}
+
namespace {
TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
TF_Status* status) {
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 3926c22ce1..15ac0f376c 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -329,6 +329,20 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
TF_Buffer* buf,
TF_Status* status);
+// Returns the serialized CppShapeInferenceResult::HandleData proto for
+// `output` if it's a resource tensor, otherwise returns an empty buffer.
+TF_CAPI_EXPORT extern void TFE_GetResourceHandleShapeAndType(
+ TF_Graph* graph, TF_Output output, TF_Buffer* output_proto,
+ TF_Status* status);
+
+// Sets `output` based on `proto`, which should be a serialized
+// CppShapeInferenceResult::HandleData proto.
+TF_CAPI_EXPORT extern void TFE_SetResourceHandleShapeAndType(TF_Graph* graph,
+ TF_Output output,
+ const void* proto,
+ size_t proto_len,
+ TF_Status* status);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
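
For reference, a minimal sketch of how these two entry points can be driven from Python through the SWIG bindings added at the end of this change. The helper name `_copy_handle_data` is hypothetical; the individual calls mirror the `resource_variable_ops.py` and `function.py` hunks below and assume both tensors are graph (non-eager) resource tensors.

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.util import compat

def _copy_handle_data(src_tensor, dst_tensor):
  """Copies HandleData from one graph resource tensor to another."""
  # Read the serialized HandleData proto attached to the source tensor.
  with c_api_util.tf_buffer() as buf:
    pywrap_tensorflow.TFE_GetResourceHandleShapeAndType(
        src_tensor.graph._c_graph, src_tensor._as_tf_output(), buf)  # pylint: disable=protected-access
    data = pywrap_tensorflow.TF_GetBuffer(buf)
  handle_data = (
      cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
          compat.as_bytes(data)))
  # Attach the same HandleData to the destination tensor.
  pywrap_tensorflow.TFE_SetResourceHandleShapeAndType(
      dst_tensor.graph._c_graph, dst_tensor._as_tf_output(),  # pylint: disable=protected-access
      handle_data.SerializeToString())
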
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
index 2311c15a68..f3e6731213 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
@@ -28,6 +28,7 @@ py_library(
py_library(
name = "rpc_op_test_base",
srcs = ["rpc_op_test_base.py"],
+ tags = ["notsan"],
deps = [
":test_example_proto_py",
"//tensorflow/contrib/proto",
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
index f018499f6c..b01db91720 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -326,6 +326,7 @@ struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> {
ParallelDynamicStitchOpCPU<type>)
TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
+TF_CALL_variant(REGISTER_DYNAMIC_STITCH);
#undef REGISTER_DYNAMIC_STITCH
#if GOOGLE_CUDA
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 5168ad3b18..0f1170bb42 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -38,6 +38,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@@ -69,9 +70,22 @@ def capture_value(tensor_map, value, dtype, name):
captured_value = graph_placeholder(
dtype=dtype or value.dtype, shape=value.shape, name=name)
if captured_value.dtype == dtypes_module.resource:
- handle_data = value._handle_data # pylint: disable=protected-access
- captured_value._handle_data = handle_data # pylint: disable=protected-access
+ if ops._USE_C_SHAPES: # pylint: disable=protected-access
+ if isinstance(value, ops.EagerTensor):
+ handle_data = value._handle_data # pylint: disable=protected-access
+ else:
+ handle_data = resource_variable_ops.get_resource_handle_data(value)
+ else:
+ handle_data = value._handle_data # pylint: disable=protected-access
if handle_data is not None and handle_data.is_set:
+ # pylint: disable=protected-access
+ if ops._USE_C_SHAPES:
+ pywrap_tensorflow.TFE_SetResourceHandleShapeAndType(
+ captured_value.graph._c_graph, captured_value._as_tf_output(),
+ handle_data.SerializeToString())
+ else:
+ captured_value._handle_data = handle_data
+ # pylint: enable=protected-access
# Ensure that shapes and dtypes are propagated.
shapes, types = zip(*[(pair.shape, pair.dtype)
for pair in handle_data.shape_and_type])
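
To see what this propagation buys, a small illustrative sketch in the style of `function_test.py` below; it assumes eager execution is enabled, and the expected static shape of (1, 3) is an expectation rather than something asserted by this change itself.

from tensorflow.python.eager import function
from tensorflow.python.ops import resource_variable_ops

v = resource_variable_ops.ResourceVariable([[1.0, 2.0, 3.0]])

@function.defun
def read():
  t = v.read_value()
  # During tracing, `v.handle` is captured as a resource placeholder. With the
  # HandleData copied onto it above, shape inference already knows the
  # variable's dtype and shape, so `t` should have static shape (1, 3) rather
  # than <unknown>.
  print("static shape while tracing:", t.shape)
  return t

read()
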
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 65dde75e60..1828c987f4 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -41,6 +42,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.training import gradient_descent
+@test_util.with_c_shapes
class FunctionTest(test.TestCase):
def testBasic(self):
@@ -615,6 +617,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual([[[[4.0]]]], y.numpy())
+@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):
def testBasic(self):
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 70e70abc06..f954b9d6c7 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -464,6 +464,30 @@ def with_c_api(cls):
return cls
+def with_c_shapes(cls):
+ """Adds methods that call original methods but with C API shapes enabled.
+
+ Note this enables C shapes in new methods after running the test class's
+ setup method.
+
+ Args:
+ cls: class to decorate
+
+ Returns:
+ cls with new test methods added
+ """
+ # If C shapes are already enabled, don't do anything. Some tests break if the
+ # same test is run twice, so this allows us to turn on the C shapes by default
+ # without breaking these tests.
+ if ops._USE_C_SHAPES:
+ return cls
+
+ for name, value in cls.__dict__.copy().items():
+ if callable(value) and name.startswith("test"):
+ setattr(cls, name + "WithCShapes", enable_c_shapes(value))
+ return cls
+
+
def assert_no_new_pyobjects_executing_eagerly(f):
"""Decorator for asserting that no new Python objects persist after a test.
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 3aedd70f8c..9440f2a4f9 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1604,7 +1604,10 @@ cuda_py_test(
"//tensorflow/python:variables",
],
shard_count = 4,
- tags = ["noasan"],
+ tags = [
+ "noasan",
+ "notap",
+ ],
)
cuda_py_test(
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 49dd7f9948..4d26b2f46e 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -24,6 +24,8 @@ from tensorflow.core.framework import variable_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
+from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -41,6 +43,19 @@ from tensorflow.python.training import checkpointable
from tensorflow.python.util import compat
+def get_resource_handle_data(graph_op):
+ assert ops._USE_C_SHAPES # pylint: disable=protected-access
+ assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck
+
+ with c_api_util.tf_buffer() as buf:
+ pywrap_tensorflow.TFE_GetResourceHandleShapeAndType(
+ graph_op.graph._c_graph, graph_op._as_tf_output(), buf) # pylint: disable=protected-access
+ data = pywrap_tensorflow.TF_GetBuffer(buf)
+
+ return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
+ compat.as_bytes(data))
+
+
def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
"""Creates a variable handle with information to do shape inference."""
container = ops.get_default_graph()._container # pylint: disable=protected-access
@@ -73,9 +88,12 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
# shape inference doesn't run in eager mode we copy this data here for when
# the handle is captured by an eager mode function.
# pylint: disable=protected-access
- if h._handle_data is None:
- ops.set_shape_and_handle_data_for_outputs(h.op)
- handle._handle_data = h._handle_data
+ if ops._USE_C_SHAPES:
+ handle._handle_data = get_resource_handle_data(h)
+ else:
+ if h._handle_data is None:
+ ops.set_shape_and_handle_data_for_outputs(h.op)
+ handle._handle_data = h._handle_data
# pylint: enable=protected-access
# Clean up our reference cycles to avoid making the garbage collector run.
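
A hedged usage sketch for the new `get_resource_handle_data` helper; it assumes graph mode with `ops._USE_C_SHAPES` turned on, and the assertions are expectations rather than something this change tests directly.

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

with ops.Graph().as_default():
  v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="v")
  handle_data = resource_variable_ops.get_resource_handle_data(v.handle)
  # For a variable handle, the proto should carry one (shape, dtype) entry
  # describing the variable itself.
  assert handle_data.is_set
  assert (handle_data.shape_and_type[0].dtype ==
          dtypes.float32.as_datatype_enum)
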
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 5ee55301df..0982a67dee 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -59,6 +59,8 @@ limitations under the License.
%rename("%s") TFE_ContextOptionsSetAsync;
%rename("%s") TFE_DeleteContextOptions;
%rename("%s") TFE_Py_TensorShapeSlice;
+%rename("%s") TFE_GetResourceHandleShapeAndType;
+%rename("%s") TFE_SetResourceHandleShapeAndType;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"
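
With these %rename entries the two C functions become visible to the Python wrapper; a quick, purely illustrative smoke check:

from tensorflow.python import pywrap_tensorflow

assert hasattr(pywrap_tensorflow, "TFE_GetResourceHandleShapeAndType")
assert hasattr(pywrap_tensorflow, "TFE_SetResourceHandleShapeAndType")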