aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--btls.cabal3
-rw-r--r--src/BTLS/BoringSSL/HKDF.chs5
-rw-r--r--src/BTLS/Buffer.hs25
-rw-r--r--src/Codec/Crypto/HKDF.hs33
-rw-r--r--tests/BTLS/Assertions.hs27
-rw-r--r--tests/Codec/Crypto/HKDFTests.hs12
-rw-r--r--tests/Data/HMACTests.hs11
7 files changed, 80 insertions, 36 deletions
diff --git a/btls.cabal b/btls.cabal
index d0c4442..1ff34f0 100644
--- a/btls.cabal
+++ b/btls.cabal
@@ -125,7 +125,8 @@ test-suite tests
-threaded
-optl-Wl,-z,relro -optl-Wl,-z,now -optl-Wl,-s
main-is: Tests.hs
- other-modules: Codec.Crypto.HKDFTests
+ other-modules: BTLS.Assertions
+ , Codec.Crypto.HKDFTests
, Data.DigestTests
, Data.Digest.HashTests
, Data.Digest.MD5Tests
diff --git a/src/BTLS/BoringSSL/HKDF.chs b/src/BTLS/BoringSSL/HKDF.chs
index 1a28ccc..8ad3df2 100644
--- a/src/BTLS/BoringSSL/HKDF.chs
+++ b/src/BTLS/BoringSSL/HKDF.chs
@@ -22,15 +22,14 @@ import Foreign.C.Types
{#import BTLS.BoringSSL.Base#}
import BTLS.Buffer (unsafeUseAsCBuffer)
-import BTLS.Result
#include <openssl/hkdf.h>
{#fun HKDF_extract as hkdfExtract
{ id `Ptr CUChar', id `Ptr CULong', `Ptr EVPMD'
, unsafeUseAsCBuffer* `ByteString'&, unsafeUseAsCBuffer* `ByteString'& }
- -> `()' requireSuccess*-#}
+ -> `Int'#}
{#fun HKDF_expand as hkdfExpand
{ id `Ptr CUChar', `Int', `Ptr EVPMD', unsafeUseAsCBuffer* `ByteString'&
- , unsafeUseAsCBuffer* `ByteString'& } -> `()' requireSuccess*-#}
+ , unsafeUseAsCBuffer* `ByteString'& } -> `Int'#}
diff --git a/src/BTLS/Buffer.hs b/src/BTLS/Buffer.hs
index 7168a10..354c787 100644
--- a/src/BTLS/Buffer.hs
+++ b/src/BTLS/Buffer.hs
@@ -15,9 +15,11 @@
module BTLS.Buffer
( unsafeUseAsCBuffer
, packCUStringLen
- , onBufferOfMaxSize
+ , onBufferOfMaxSize, onBufferOfMaxSize'
) where
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.Except (ExceptT, runExceptT)
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Unsafe as ByteString
@@ -44,9 +46,22 @@ onBufferOfMaxSize ::
=> Int
-> (Ptr CUChar -> Ptr size -> IO ())
-> IO ByteString
-onBufferOfMaxSize maxSize f =
+onBufferOfMaxSize maxSize f = do
+ Right r <- onBufferOfMaxSize' maxSize (compose2 lift f)
+ return r
+
+-- | Like 'onBufferOfMaxSize' but may fail.
+onBufferOfMaxSize' ::
+ (Integral size, Storable size)
+ => Int
+ -> (Ptr CUChar -> Ptr size -> ExceptT e IO ())
+ -> IO (Either e ByteString)
+onBufferOfMaxSize' maxSize f =
allocaArray maxSize $ \pOut ->
- alloca $ \pOutLen -> do
+ alloca $ \pOutLen -> runExceptT $ do
f pOut pOutLen
- outLen <- peek pOutLen
- packCUStringLen (pOut, outLen)
+ outLen <- lift $ peek pOutLen
+ lift $ packCUStringLen (pOut, outLen)
+
+compose2 :: (r -> r') -> (a -> b -> r) -> a -> b -> r'
+compose2 f g = \a b -> f (g a b)
diff --git a/src/Codec/Crypto/HKDF.hs b/src/Codec/Crypto/HKDF.hs
index bc1ea61..31d0be3 100644
--- a/src/Codec/Crypto/HKDF.hs
+++ b/src/Codec/Crypto/HKDF.hs
@@ -48,16 +48,23 @@ module Codec.Crypto.HKDF
-- to do so, specify the empty string as the associated data.
, AssociatedData(AssociatedData)
+ -- * Error handling
+ , Error
+
-- * Legacy functions
, md5
) where
+import Control.Monad ((>=>))
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Trans.Except (runExceptT)
import Foreign (allocaArray)
import Foreign.Marshal.Unsafe (unsafeLocalState)
import BTLS.BoringSSL.Digest (evpMaxMDSize)
import BTLS.BoringSSL.HKDF
-import BTLS.Buffer (onBufferOfMaxSize, packCUStringLen)
+import BTLS.Buffer (onBufferOfMaxSize', packCUStringLen)
+import BTLS.Result (Error, check)
import BTLS.Types
( Algorithm(Algorithm), AssociatedData(AssociatedData), Salt(Salt)
, SecretKey(SecretKey), noSalt
@@ -66,7 +73,7 @@ import Data.Digest (md5, sha1, sha224, sha256, sha384, sha512)
-- | Computes an HKDF. It is defined by
--
--- prop> hkdf md salt info len = expand md info len . extract md salt
+-- prop> hkdf md salt info len = extract md salt >=> expand md info len
--
-- but may be faster than calling the two functions individually.
hkdf ::
@@ -75,16 +82,16 @@ hkdf ::
-> AssociatedData
-> Int -- ^ The length of the derived key, in bytes.
-> SecretKey
- -> SecretKey
-hkdf md salt info outLen = expand md info outLen . extract md salt
+ -> Either [Error] SecretKey
+hkdf md salt info outLen = extract md salt >=> expand md info outLen
-- | Computes an HKDF pseudorandom key (PRK).
-extract :: Algorithm -> Salt -> SecretKey -> SecretKey
+extract :: Algorithm -> Salt -> SecretKey -> Either [Error] SecretKey
extract (Algorithm md) (Salt salt) (SecretKey secret) =
- SecretKey $
+ fmap SecretKey $
unsafeLocalState $
- onBufferOfMaxSize evpMaxMDSize $ \pOutKey pOutLen -> do
- hkdfExtract pOutKey pOutLen md secret salt
+ onBufferOfMaxSize' evpMaxMDSize $ \pOutKey pOutLen ->
+ check $ hkdfExtract pOutKey pOutLen md secret salt
-- | Computes HKDF output key material (OKM).
expand ::
@@ -92,10 +99,10 @@ expand ::
-> AssociatedData
-> Int -- ^ The length of the OKM, in bytes.
-> SecretKey
- -> SecretKey
+ -> Either [Error] SecretKey
expand (Algorithm md) (AssociatedData info) outLen (SecretKey secret) =
- SecretKey $
+ fmap SecretKey $
unsafeLocalState $
- allocaArray outLen $ \pOutKey -> do
- hkdfExpand pOutKey outLen md secret info
- packCUStringLen (pOutKey, outLen)
+ allocaArray outLen $ \pOutKey -> runExceptT $ do
+ check $ hkdfExpand pOutKey outLen md secret info
+ lift $ packCUStringLen (pOutKey, outLen)
diff --git a/tests/BTLS/Assertions.hs b/tests/BTLS/Assertions.hs
new file mode 100644
index 0000000..23ceb8d
--- /dev/null
+++ b/tests/BTLS/Assertions.hs
@@ -0,0 +1,27 @@
+-- Copyright 2018 Google LLC
+--
+-- Licensed under the Apache License, Version 2.0 (the "License"); you may not
+-- use this file except in compliance with the License. You may obtain a copy of
+-- the License at
+--
+-- https://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+-- WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+-- License for the specific language governing permissions and limitations under
+-- the License.
+
+module BTLS.Assertions
+ ( isRightAndHolds
+ ) where
+
+import Control.Monad (unless)
+import Test.Tasty.HUnit (Assertion, assertFailure)
+
+isRightAndHolds :: (Eq a, Show a, Show e) => Either e a -> a -> Assertion
+actual@(Left _) `isRightAndHolds` _ =
+ assertFailure ("expected: Right _\n but got: " ++ show actual)
+Right actual `isRightAndHolds` expected =
+ unless (expected == actual) $
+ assertFailure ("expected: Right " ++ show expected ++ "\n but got: Right " ++ show actual)
diff --git a/tests/Codec/Crypto/HKDFTests.hs b/tests/Codec/Crypto/HKDFTests.hs
index 44a41cd..d4cfe30 100644
--- a/tests/Codec/Crypto/HKDFTests.hs
+++ b/tests/Codec/Crypto/HKDFTests.hs
@@ -21,8 +21,9 @@ import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Base16 as ByteString.Base16
import qualified Data.ByteString.Char8 as ByteString.Char8
import Test.Tasty (TestTree, testGroup)
-import Test.Tasty.HUnit ((@?=), testCase)
+import Test.Tasty.HUnit (testCase)
+import BTLS.Assertions (isRightAndHolds)
import Codec.Crypto.HKDF (AssociatedData(AssociatedData), Salt(Salt), SecretKey(SecretKey), noSalt)
import qualified Codec.Crypto.HKDF as HKDF
import Data.Digest (sha1, sha256)
@@ -91,10 +92,11 @@ testRFC5869 = testGroup "RFC 5869 examples"
]
where
t name hash ikm salt info len prk okm =
- testGroup name [ testCase "hkdf" $ HKDF.hkdf hash salt info len ikm @?= okm
- , testCase "extract" $ HKDF.extract hash salt ikm @?= prk
- , testCase "expand" $ HKDF.expand hash info len prk @?= okm
- ]
+ testGroup name
+ [ testCase "hkdf" $ HKDF.hkdf hash salt info len ikm `isRightAndHolds` okm
+ , testCase "extract" $ HKDF.extract hash salt ikm `isRightAndHolds` prk
+ , testCase "expand" $ HKDF.expand hash info len prk `isRightAndHolds` okm
+ ]
hex :: ByteString -> ByteString
hex s =
diff --git a/tests/Data/HMACTests.hs b/tests/Data/HMACTests.hs
index 10500e4..55e795e 100644
--- a/tests/Data/HMACTests.hs
+++ b/tests/Data/HMACTests.hs
@@ -16,13 +16,13 @@
module Data.HMACTests (tests) where
-import Control.Monad (unless)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as ByteString.Lazy
import qualified Data.ByteString.Lazy.Char8 as ByteString.Lazy.Char8
import Test.Tasty (TestTree, testGroup)
-import Test.Tasty.HUnit (Assertion, assertFailure, testCase)
+import Test.Tasty.HUnit (testCase)
+import BTLS.Assertions (isRightAndHolds)
import Data.Digest (md5, sha1, sha224, sha256, sha384, sha512)
import Data.HMAC (Error, SecretKey(SecretKey), hmac)
@@ -191,10 +191,3 @@ hmacSha224 key bytes = show <$> hmac sha224 key bytes
hmacSha256 key bytes = show <$> hmac sha256 key bytes
hmacSha384 key bytes = show <$> hmac sha384 key bytes
hmacSha512 key bytes = show <$> hmac sha512 key bytes
-
-isRightAndHolds :: (Eq a, Show a, Show e) => Either e a -> a -> Assertion
-actual@(Left _) `isRightAndHolds` _ =
- assertFailure ("expected: Right _\n but got: " ++ show actual)
-Right actual `isRightAndHolds` expected =
- unless (expected == actual) $
- assertFailure ("expected: Right " ++ show expected ++ "\n but got: Right " ++ show actual)