From a9bc030a3a9eb876c4ca00b40d68013ed91a8d66 Mon Sep 17 00:00:00 2001
From: Alexander Polcyn <apolcyn@google.com>
Date: Thu, 11 Aug 2016 12:05:55 -0700
Subject: add mutex wrapper around sending and modifying of initial metadata

---
 src/ruby/lib/grpc/generic/active_call.rb  |  37 ++++---
 src/ruby/spec/generic/active_call_spec.rb | 159 ++++++++++--------------------
 2 files changed, 80 insertions(+), 116 deletions(-)

(limited to 'src')

diff --git a/src/ruby/lib/grpc/generic/active_call.rb b/src/ruby/lib/grpc/generic/active_call.rb
index d43a9e7a4b..f9c41f0c0e 100644
--- a/src/ruby/lib/grpc/generic/active_call.rb
+++ b/src/ruby/lib/grpc/generic/active_call.rb
@@ -113,11 +113,17 @@ module GRPC
 
       fail(ArgumentError, 'Already sent md') if started && metadata_to_send
       @metadata_to_send = metadata_to_send || {} unless started
+      @send_initial_md_mutex = Mutex.new
     end
 
+    # Sends the initial metadata that has yet to be sent.
+    # Fails if metadata has already been sent for this call.
     def send_initial_metadata
-      fail 'Already sent metadata' if @metadata_sent
-      start_call(@metadata_to_send)
+      @send_initial_md_mutex.synchronize do
+        fail('Already send initial metadata') if @metadata_sent
+        @metadata_tag = ActiveCall.client_invoke(@call, @metadata_to_send)
+        @metadata_sent = true
+      end
     end
 
     # output_metadata are provides access to hash that can be used to
@@ -195,7 +201,7 @@ module GRPC
     # @param marshalled [false, true] indicates if the object is already
     # marshalled.
     def remote_send(req, marshalled = false)
-      start_call(@metadata_to_send) unless @metadata_sent
+      send_initial_metadata unless @metadata_sent
       GRPC.logger.debug("sending #{req}, marshalled? #{marshalled}")
       payload = marshalled ? req : @marshal.call(req)
       @call.run_batch(SEND_MESSAGE => payload)
@@ -211,7 +217,7 @@ module GRPC
     # list, mulitple metadata for its key are sent
     def send_status(code = OK, details = '', assert_finished = false,
                     metadata: {})
-      start_call unless @metadata_sent
+      send_initial_metadata unless @metadata_sent
       ops = {
         SEND_STATUS_FROM_SERVER => Struct::Status.new(code, details, metadata)
       }
@@ -312,7 +318,8 @@ module GRPC
     # a list, multiple metadata for its key are sent
     # @return [Object] the response received from the server
     def request_response(req, metadata: {})
-      start_call(metadata)
+      merge_metadata_to_send(metadata) &&
+        send_initial_metadata unless @metadata_sent
       remote_send(req)
       writes_done(false)
       response = remote_read
@@ -336,7 +343,8 @@ module GRPC
     # a list, multiple metadata for its key are sent
     # @return [Object] the response received from the server
     def client_streamer(requests, metadata: {})
-      start_call(metadata)
+      merge_metadata_to_send(metadata) &&
+        send_initial_metadata unless @metadata_sent
       requests.each { |r| remote_send(r) }
       writes_done(false)
       response = remote_read
@@ -362,7 +370,8 @@ module GRPC
     # a list, multiple metadata for its key are sent
     # @return [Enumerator|nil] a response Enumerator
     def server_streamer(req, metadata: {})
-      start_call(metadata)
+      merge_metadata_to_send(metadata) &&
+        send_initial_metadata unless @metadata_sent
       remote_send(req)
       writes_done(false)
       replies = enum_for(:each_remote_read_then_finish)
@@ -401,7 +410,8 @@ module GRPC
     # a list, multiple metadata for its key are sent
     # @return [Enumerator, nil] a response Enumerator
     def bidi_streamer(requests, metadata: {}, &blk)
-      start_call(metadata) unless @metadata_sent
+      merge_metadata_to_send(metadata) &&
+        send_initial_metadata unless @metadata_sent
       bd = BidiCall.new(@call,
                         @marshal,
                         @unmarshal,
@@ -444,9 +454,14 @@ module GRPC
       @op_notifier.notify(self)
     end
 
+    # Add to the metadata that will be sent from the server.
+    # Fails if metadata has already been sent.
+    # Unused by client calls.
     def merge_metadata_to_send(new_metadata = {})
-      fail('cant change metadata after already sent') if @metadata_sent
-      @metadata_to_send.merge!(new_metadata)
+      @send_initial_md_mutex.synchronize do
+        fail('cant change metadata after already sent') if @metadata_sent
+        @metadata_to_send.merge!(new_metadata)
+      end
     end
 
     private
@@ -456,7 +471,7 @@ module GRPC
     # a list, multiple metadata for its key are sent
     def start_call(metadata = {})
       return if @metadata_sent
-      @metadata_tag = ActiveCall.client_invoke(@call, metadata)
+      merge_metadata_to_send(metadata) && send_initial_metadata
       @metadata_sent = true
     end
 
diff --git a/src/ruby/spec/generic/active_call_spec.rb b/src/ruby/spec/generic/active_call_spec.rb
index 0c72be9a98..79f739e8fa 100644
--- a/src/ruby/spec/generic/active_call_spec.rb
+++ b/src/ruby/spec/generic/active_call_spec.rb
@@ -242,7 +242,12 @@ describe GRPC::ActiveCall do
   describe '#merge_metadata_to_send', merge_metadata_to_send: true do
     it 'adds to existing metadata when there is existing metadata to send' do
       call = make_test_call
-      starting_metadata = { k1: 'key1_val', k2: 'key2_val' }
+      starting_metadata = {
+        k1: 'key1_val',
+        k2: 'key2_val',
+        k3: 'key3_val'
+      }
+
       @client_call = ActiveCall.new(
         call,
         @pass_through, @pass_through,
@@ -253,13 +258,13 @@ describe GRPC::ActiveCall do
       expect(@client_call.metadata_to_send).to eq(starting_metadata)
 
       @client_call.merge_metadata_to_send(
-        k3: 'key3_val',
+        k3: 'key3_new_val',
         k4: 'key4_val')
 
       expected_md_to_send = {
         k1: 'key1_val',
         k2: 'key2_val',
-        k3: 'key3_val',
+        k3: 'key3_new_val',
         k4: 'key4_val' }
 
       expect(@client_call.metadata_to_send).to eq(expected_md_to_send)
@@ -269,23 +274,6 @@ describe GRPC::ActiveCall do
       expect(@client_call.metadata_to_send).to eq(expected_md_to_send)
     end
 
-    it 'overrides existing metadata if adding metadata with an existing key' do
-      call = make_test_call
-      starting_metadata = { k1: 'key1_val', k2: 'key2_val' }
-      @client_call = ActiveCall.new(
-        call,
-        @pass_through,
-        @pass_through,
-        deadline,
-        started: false,
-        metadata_to_send: starting_metadata)
-
-      expect(@client_call.metadata_to_send).to eq(starting_metadata)
-      @client_call.merge_metadata_to_send(k1: 'key1_new_val')
-      expect(@client_call.metadata_to_send).to eq(k1: 'key1_new_val',
-                                                  k2: 'key2_val')
-    end
-
     it 'fails when initial metadata has already been sent' do
       call = make_test_call
       @client_call = ActiveCall.new(
@@ -530,121 +518,82 @@ describe GRPC::ActiveCall do
   end
 
   # Test sending of the initial metadata in #run_server_bidi
-  # from the server handler both implicitly and explicitly,
-  # when the server handler function has one argument and two arguments
-  describe '#run_server_bidi sanity tests', run_server_bidi: true do
-    it 'sends the initial metadata implicitly if not already sent' do
-      requests = ['first message', 'second message']
-      server_to_client_metadata = { 'test_key' => 'test_val' }
-      server_status = OK
+  # from the server handler both implicitly and explicitly.
+  describe '#run_server_bidi metadata sending tests', run_server_bidi: true do
+    before(:each) do
+      @requests = ['first message', 'second message']
+      @server_to_client_metadata = { 'test_key' => 'test_val' }
+      @server_status = OK
 
-      client_call = make_test_call
-      client_call.run_batch(CallOps::SEND_INITIAL_METADATA => {})
+      @client_call = make_test_call
+      @client_call.run_batch(CallOps::SEND_INITIAL_METADATA => {})
 
       recvd_rpc = @server.request_call
       recvd_call = recvd_rpc.call
-      server_call = ActiveCall.new(recvd_call,
-                                   @pass_through,
-                                   @pass_through,
-                                   deadline,
-                                   metadata_received: true,
-                                   started: false,
-                                   metadata_to_send: server_to_client_metadata)
-
-      # Server handler that doesn't have access to a "call"
-      # It echoes the requests
-      fake_gen_each_reply_with_no_call_param = proc do |msgs|
-        msgs
-      end
-
-      server_thread = Thread.new do
-        server_call.run_server_bidi(
-          fake_gen_each_reply_with_no_call_param)
-        server_call.send_status(server_status)
-      end
+      @server_call = ActiveCall.new(
+        recvd_call,
+        @pass_through,
+        @pass_through,
+        deadline,
+        metadata_received: true,
+        started: false,
+        metadata_to_send: @server_to_client_metadata)
+    end
 
+    after(:each) do
       # Send the requests and send a close so the server can send a status
-      requests.each do |message|
-        client_call.run_batch(CallOps::SEND_MESSAGE => message)
+      @requests.each do |message|
+        @client_call.run_batch(CallOps::SEND_MESSAGE => message)
       end
-      client_call.run_batch(CallOps::SEND_CLOSE_FROM_CLIENT => nil)
+      @client_call.run_batch(CallOps::SEND_CLOSE_FROM_CLIENT => nil)
 
-      server_thread.join
+      @server_thread.join
 
       # Expect that initial metadata was sent,
       # the requests were echoed, and a status was sent
-      batch_result = client_call.run_batch(
+      batch_result = @client_call.run_batch(
         CallOps::RECV_INITIAL_METADATA => nil)
-      expect(batch_result.metadata).to eq(server_to_client_metadata)
+      expect(batch_result.metadata).to eq(@server_to_client_metadata)
 
-      requests.each do |message|
-        batch_result = client_call.run_batch(
+      @requests.each do |message|
+        batch_result = @client_call.run_batch(
           CallOps::RECV_MESSAGE => nil)
         expect(batch_result.message).to eq(message)
       end
 
-      batch_result = client_call.run_batch(
+      batch_result = @client_call.run_batch(
         CallOps::RECV_STATUS_ON_CLIENT => nil)
-      expect(batch_result.status.code).to eq(server_status)
+      expect(batch_result.status.code).to eq(@server_status)
     end
 
-    it 'sends the metadata when sent explicitly and not already sent' do
-      requests = ['first message', 'second message']
-      server_to_client_metadata = { 'test_key' => 'test_val' }
-      server_status = OK
-
-      client_call = make_test_call
-      client_call.run_batch(CallOps::SEND_INITIAL_METADATA => {})
+    it 'sends the initial metadata implicitly if not already sent' do
+      # Server handler that doesn't have access to a "call"
+      # It echoes the requests
+      fake_gen_each_reply_with_no_call_param = proc do |msgs|
+        msgs
+      end
 
-      recvd_rpc = @server.request_call
-      recvd_call = recvd_rpc.call
-      server_call = ActiveCall.new(recvd_call,
-                                   @pass_through,
-                                   @pass_through,
-                                   deadline,
-                                   metadata_received: true,
-                                   started: false)
+      @server_thread = Thread.new do
+        @server_call.run_server_bidi(
+          fake_gen_each_reply_with_no_call_param)
+        @server_call.send_status(@server_status)
+      end
+    end
 
+    it 'sends the metadata when sent explicitly and not already sent' do
       # Fake server handler that has access to a "call" object and
-      # uses it to explicitly update and sent the initial metadata
+      # uses it to explicitly update and send the initial metadata
       fake_gen_each_reply_with_call_param = proc do |msgs, call_param|
-        call_param.merge_metadata_to_send(server_to_client_metadata)
+        call_param.merge_metadata_to_send(@server_to_client_metadata)
         call_param.send_initial_metadata
         msgs
       end
 
-      server_thread = Thread.new do
-        server_call.run_server_bidi(
+      @server_thread = Thread.new do
+        @server_call.run_server_bidi(
           fake_gen_each_reply_with_call_param)
-        server_call.send_status(server_status)
-      end
-
-      # Send requests and a close from the client so the server
-      # can send a status
-      requests.each do |message|
-        client_call.run_batch(
-          CallOps::SEND_MESSAGE => message)
-      end
-      client_call.run_batch(
-        CallOps::SEND_CLOSE_FROM_CLIENT => nil)
-
-      server_thread.join
-
-      # Verify that the correct metadata was sent, the requests
-      # were echoed, and the correct status was sent
-      batch_result = client_call.run_batch(
-        CallOps::RECV_INITIAL_METADATA => nil)
-      expect(batch_result.metadata).to eq(server_to_client_metadata)
-
-      requests.each do |message|
-        batch_result = client_call.run_batch(
-          CallOps::RECV_MESSAGE => nil)
-        expect(batch_result.message).to eq(message)
+        @server_call.send_status(@server_status)
       end
-
-      batch_result = client_call.run_batch(
-        CallOps::RECV_STATUS_ON_CLIENT => nil)
-      expect(batch_result.status.code).to eq(server_status)
     end
   end
 
-- 
cgit v1.2.3