From d126d8fe4159b68feadd4145f0e2a1f90d7ca502 Mon Sep 17 00:00:00 2001 From: huangyi Date: Sun, 23 Oct 2011 13:22:18 +0800 Subject: make encoder and decoder invertable, and tests to ensure that. --- Network/DNS/Internal.hs | 46 ++++++++++++- Network/DNS/Query.hs | 96 ++++++++++++++++++++++---- Network/DNS/Resolver.hs | 4 +- Network/DNS/Response.hs | 20 ++++-- Network/DNS/StateBinary.hs | 8 ++- Network/DNS/Types.hs | 11 +-- TestProtocol.hs | 163 +++++++++++++++++++++++++++++++++++++++++++++ dns.cabal | 2 +- 8 files changed, 319 insertions(+), 31 deletions(-) create mode 100644 TestProtocol.hs diff --git a/Network/DNS/Internal.hs b/Network/DNS/Internal.hs index 8274afe..0368bd4 100644 --- a/Network/DNS/Internal.hs +++ b/Network/DNS/Internal.hs @@ -151,7 +151,15 @@ defaultQuery :: DNSFormat defaultQuery = DNSFormat { header = DNSHeader { identifier = 0 - , flags = undefined + , flags = DNSFlags { + qOrR = QR_Query + , opcode = OP_STD + , authAnswer = False + , trunCation = False + , recDesired = True + , recAvailable = False + , rcode = NoErr + } , qdCount = 0 , anCount = 0 , nsCount = 0 @@ -162,3 +170,39 @@ defaultQuery = DNSFormat { , authority = [] , additional = [] } + +defaultResponse :: DNSFormat +defaultResponse = + let hd = header defaultQuery + flg = flags hd + in defaultQuery { + header = hd { + flags = flg { + qOrR = QR_Response + , authAnswer = True + , recAvailable = True + } + } + } + +responseA :: Int -> Question -> IPv4 -> DNSFormat +responseA ident q ip = + let hd = header defaultResponse + dom = qname q + an = ResourceRecord dom A 300 4 (RD_A ip) + in defaultResponse { + header = hd { identifier=ident, qdCount = 1, anCount = 1 } + , question = [q] + , answer = [an] + } + +responseAAAA :: Int -> Question -> IPv6 -> DNSFormat +responseAAAA ident q ip = + let hd = header defaultResponse + dom = qname q + an = ResourceRecord dom AAAA 300 16 (RD_AAAA ip) + in defaultResponse { + header = hd { identifier=ident, qdCount = 1, anCount = 1 } + , question = [q] + , answer = [an] + } diff --git a/Network/DNS/Query.hs b/Network/DNS/Query.hs index 30c7e24..0d164b4 100644 --- a/Network/DNS/Query.hs +++ b/Network/DNS/Query.hs @@ -1,19 +1,28 @@ -module Network.DNS.Query (composeQuery) where +{-# LANGUAGE RecordWildCards #-} +module Network.DNS.Query (composeQuery, composeDNSFormat) where import qualified Data.ByteString.Lazy.Char8 as BL (ByteString) import qualified Data.ByteString as BS (unpack) import qualified Data.ByteString.Char8 as BS (length, split, null) +import Blaze.ByteString.Builder.ByteString (writeByteString) import Network.DNS.StateBinary import Network.DNS.Internal import Data.Monoid +import Control.Monad.State +import Data.Bits +import Data.Word +import Data.IP (+++) :: Monoid a => a -> a -> a (+++) = mappend ---------------------------------------------------------------- +composeDNSFormat :: DNSFormat -> BL.ByteString +composeDNSFormat fmt = runSPut (encodeDNSFormat fmt) + composeQuery :: Int -> [Question] -> BL.ByteString -composeQuery idt qs = runSPut (encodeQuery qry) +composeQuery idt qs = composeDNSFormat qry where hdr = header defaultQuery qry = defaultQuery { @@ -26,12 +35,18 @@ composeQuery idt qs = runSPut (encodeQuery qry) ---------------------------------------------------------------- -encodeQuery :: DNSFormat -> SPut -encodeQuery fmt = encodeHeader hdr - +++ encodeQuestion qs +encodeDNSFormat :: DNSFormat -> SPut +encodeDNSFormat fmt = encodeHeader hdr + +++ mconcat (map encodeQuestion qs) + +++ mconcat (map encodeRR an) + +++ mconcat (map encodeRR au) + +++ mconcat (map encodeRR ad) where hdr = header fmt qs = question fmt + an = answer fmt + au = authority fmt + ad = additional fmt encodeHeader :: DNSHeader -> SPut encodeHeader hdr = encodeIdentifier (identifier hdr) @@ -48,16 +63,69 @@ encodeHeader hdr = encodeIdentifier (identifier hdr) decodeArCount = putInt16 encodeFlags :: DNSFlags -> SPut -encodeFlags _ = put16 0x0100 -- xxx - -encodeQuestion :: [Question] -> SPut -encodeQuestion qs = encodeDomain dom - +++ putInt16 (typeToInt typ) - +++ put16 1 +encodeFlags DNSFlags{..} = put16 word where - q = head qs - dom = qname q - typ = qtype q + word16 :: Enum a => a -> Word16 + word16 = toEnum . fromEnum + + set :: Word16 -> State Word16 () + set byte = modify (.|. byte) + + st :: State Word16 () + st = sequence_ + [ set (word16 rcode) + , when recAvailable $ set (bit 7) + , when recDesired $ set (bit 8) + , when trunCation $ set (bit 9) + , when authAnswer $ set (bit 10) + , set (word16 opcode `shiftL` 11) + , when (qOrR==QR_Response) $ set (bit 15) + ] + + word = execState st 0 + +encodeQuestion :: Question -> SPut +encodeQuestion Question{..} = + encodeDomain qname + +++ putInt16 (typeToInt qtype) + +++ put16 1 + +encodeRR :: ResourceRecord -> SPut +encodeRR ResourceRecord{..} = + mconcat + [ encodeDomain rrname + , putInt16 (typeToInt rrtype) + , put16 1 + , putInt32 rrttl + , putInt16 rdlen + , encodeRDATA rdata + ] + +encodeRDATA :: RDATA -> SPut +encodeRDATA rd = case rd of + (RD_A ip) -> mconcat $ map putInt8 (fromIPv4 ip) + (RD_AAAA ip) -> mconcat $ map putInt16 (fromIPv6 ip) + (RD_NS dom) -> encodeDomain dom + (RD_CNAME dom) -> encodeDomain dom + (RD_PTR dom) -> encodeDomain dom + (RD_MX prf dom) -> mconcat [putInt16 prf, encodeDomain dom] + (RD_TXT txt) -> writeByteString txt + (RD_OTH bytes) -> mconcat $ map putInt8 bytes + (RD_SOA d1 d2 serial refresh retry expire min') -> mconcat $ + [ encodeDomain d1 + , encodeDomain d2 + , putInt32 serial + , putInt32 refresh + , putInt32 retry + , putInt32 expire + , putInt32 min' + ] + (RD_SRV prio weight port dom) -> mconcat $ + [ putInt16 prio + , putInt16 weight + , putInt16 port + , encodeDomain dom + ] ---------------------------------------------------------------- diff --git a/Network/DNS/Resolver.hs b/Network/DNS/Resolver.hs index 5919088..b3182f6 100644 --- a/Network/DNS/Resolver.hs +++ b/Network/DNS/Resolver.hs @@ -21,7 +21,7 @@ module Network.DNS.Resolver ( -- ** Intermediate data type for resolver , ResolvSeed, makeResolvSeed -- ** Type and function for resolver - , Resolver, withResolver + , Resolver(..), withResolver -- ** Looking up functions , lookup, lookupRaw ) where @@ -128,7 +128,7 @@ makeAddrInfo addr = do argument. 'withResolver' should be passed to 'forkIO'. -} -withResolver :: ResolvSeed -> (Resolver -> IO ()) -> IO () +withResolver :: ResolvSeed -> (Resolver -> IO a) -> IO a withResolver seed func = do let ai = addrInfo seed sock <- socket (addrFamily ai) (addrSocketType ai) (addrProtocol ai) diff --git a/Network/DNS/Response.hs b/Network/DNS/Response.hs index 36f6203..abcb8ba 100644 --- a/Network/DNS/Response.hs +++ b/Network/DNS/Response.hs @@ -1,6 +1,6 @@ {-# LANGUAGE OverloadedStrings #-} -module Network.DNS.Response (responseIter, parseResponse) where +module Network.DNS.Response (responseIter, parseResponse, runDNSFormat, runDNSFormat_) where import Control.Applicative import Control.Monad @@ -12,13 +12,21 @@ import Network.DNS.Internal import Network.DNS.StateBinary import Data.Enumerator (Enumerator, Iteratee, run_, ($$)) import Data.ByteString (ByteString) +import qualified Data.ByteString.Lazy as BL -responseIter :: Iteratee ByteString IO (DNSFormat, PState) -responseIter = runSGet decodeResponse +runDNSFormat :: BL.ByteString -> Either String (DNSFormat, PState) +runDNSFormat bs = runSGet decodeResponse bs -parseResponse :: Enumerator ByteString IO (a,b) - -> Iteratee ByteString IO (a,b) - -> IO a +runDNSFormat_ :: BL.ByteString -> Either String DNSFormat +runDNSFormat_ bs = fst <$> runDNSFormat bs + +responseIter :: Monad m => Iteratee ByteString m (DNSFormat, PState) +responseIter = iterSGet decodeResponse + +parseResponse :: (Functor m, Monad m) + => Enumerator ByteString m (a,b) + -> Iteratee ByteString m (a,b) + -> m a parseResponse enum iter = fst <$> run_ (enum $$ iter) ---------------------------------------------------------------- diff --git a/Network/DNS/StateBinary.hs b/Network/DNS/StateBinary.hs index 05f8468..e44ee98 100644 --- a/Network/DNS/StateBinary.hs +++ b/Network/DNS/StateBinary.hs @@ -5,6 +5,7 @@ import Control.Applicative import Control.Monad.State import Data.Attoparsec import Data.Attoparsec.Enumerator +import qualified Data.Attoparsec.Lazy as AL import Data.ByteString (ByteString) import qualified Data.ByteString as BS (unpack) import qualified Data.ByteString.Lazy as BL (ByteString) @@ -114,8 +115,11 @@ getNByteString n = lift (take n) <* addPosition n initialState :: PState initialState = PState IM.empty 0 -runSGet :: SGet a -> Iteratee ByteString IO (a, PState) -runSGet parser = iterParser (runStateT parser initialState) +iterSGet :: Monad m => SGet a -> Iteratee ByteString m (a, PState) +iterSGet parser = iterParser (runStateT parser initialState) + +runSGet :: SGet a -> BL.ByteString -> Either String (a, PState) +runSGet parser bs = AL.eitherResult $ AL.parse (runStateT parser initialState) bs runSPut :: SPut -> BL.ByteString runSPut = toLazyByteString . fromWrite diff --git a/Network/DNS/Types.hs b/Network/DNS/Types.hs index 58832e0..ce27046 100644 --- a/Network/DNS/Types.hs +++ b/Network/DNS/Types.hs @@ -8,18 +8,19 @@ module Network.DNS.Types ( -- * TYPE , TYPE (..), intToType, typeToInt, toType -- * DNS Format - , DNSFormat, header, question, answer, authority, additional + , DNSFormat (DNSFormat), header, question, answer, authority, additional -- * DNS Header - , DNSHeader, identifier, flags, qdCount, anCount, nsCount, arCount + , DNSHeader (DNSHeader), identifier, flags, qdCount, anCount, nsCount, arCount -- * DNS Flags - , DNSFlags, qOrR, opcode, authAnswer, trunCation, recDesired, recAvailable, rcode + , DNSFlags (DNSFlags), qOrR, opcode, authAnswer, trunCation, recDesired, recAvailable, rcode -- * DNS Body , QorR (..) , OPCODE (..) , RCODE (..) - , ResourceRecord, rrname, rrtype, rrttl, rdlen, rdata - , Question, qname, qtype, makeQuestion + , ResourceRecord (ResourceRecord), rrname, rrtype, rrttl, rdlen, rdata + , Question (Question), qname, qtype, makeQuestion , RDATA (..) + , responseA, responseAAAA ) where import Network.DNS.Internal diff --git a/TestProtocol.hs b/TestProtocol.hs new file mode 100644 index 0000000..bde01dd --- /dev/null +++ b/TestProtocol.hs @@ -0,0 +1,163 @@ +{-# LANGUAGE OverloadedStrings #-} + +module TestProtocol where + +import Network.DNS +import Network.DNS.Internal +import Network.DNS.Query +import Network.DNS.Response +import Data.IP +import Control.Monad.State +import Test.Framework (defaultMain, testGroup, Test) +import Test.Framework.Providers.HUnit +import Test.HUnit hiding (Test) + +tests :: [Test] +tests = + [ testGroup "Test case" + [ testCase "QueryA" (test_Format queryA) + , testCase "QueryAAAA" (test_Format queryAAAA) + , testCase "ResponseA" (test_Format responseA) + ] + ] + +defaultHeader :: DNSHeader +defaultHeader = header defaultQuery + +queryA :: DNSFormat +queryA = defaultQuery + { header = defaultHeader + { identifier = 1000 + , qdCount = 1 + } + , question = [makeQuestion "www.mew.org." A] + } + +queryAAAA :: DNSFormat +queryAAAA = defaultQuery + { header = defaultHeader + { identifier = 1000 + , qdCount = 1 + } + , question = [makeQuestion "www.mew.org." AAAA] + } + +responseA :: DNSFormat +responseA = DNSFormat { header = DNSHeader { identifier = 61046 + , flags = DNSFlags { qOrR = QR_Response + , opcode = OP_STD + , authAnswer = False + , trunCation = False + , recDesired = True + , recAvailable = True + , rcode = NoErr + } + , qdCount = 1 + , anCount = 8 + , nsCount = 2 + , arCount = 4 + } + , question = [Question {qname = "492056364.qzone.qq.com.", qtype = A}] + , answer = [ ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [119, 147, 15, 122] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [119, 147, 79, 106] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [183, 60, 55, 43] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [183, 60, 55, 107] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [113, 108, 7, 172] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [113, 108, 7, 174] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [113, 108, 7, 175] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [119, 147, 15, 100] + } + ] + , authority = [ ResourceRecord { rrname = "qzone.qq.com." + , rrtype = NS + , rrttl = 45919 + , rdlen = 10 + , rdata = RD_NS "ns-tel2.qq.com." + } + , ResourceRecord { rrname = "qzone.qq.com." + , rrtype = NS + , rrttl = 45919 + , rdlen = 10 + , rdata = RD_NS "ns-tel1.qq.com." + } + ] + , additional = [ ResourceRecord { rrname = "ns-tel1.qq.com." + , rrtype = A + , rrttl = 46520 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [121, 14, 73, 115] + } + , ResourceRecord { rrname = "ns-tel2.qq.com." + , rrtype = A + , rrttl = 2890 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [222, 73, 76, 226] + } + , ResourceRecord { rrname = "ns-tel2.qq.com." + , rrtype = A + , rrttl = 2890 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [183, 60, 3, 202] + } + , ResourceRecord { rrname = "ns-tel2.qq.com." + , rrtype = A + , rrttl = 2890 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [218, 30, 72, 180] + } + ] + } + +assertEither :: (a -> String) -> Either a b -> IO () +assertEither f e = either (assertFailure . f) (const $ return ()) e + +test_Format :: DNSFormat -> IO () +test_Format fmt = do + assertEither id result + let (Right fmt') = result + assertEqual "fail" fmt fmt' + where + bs = composeDNSFormat fmt + result = runResponse_ bs + +main :: IO () +main = defaultMain tests diff --git a/dns.cabal b/dns.cabal index 1069339..1d1856a 100644 --- a/dns.cabal +++ b/dns.cabal @@ -12,7 +12,7 @@ Description: DNS libary. Currently only resolver side Category: Network Cabal-Version: >= 1.6 Build-Type: Simple -Extra-Source-Files: Test.hs +Extra-Source-Files: Test.hs, TestProtocol.hs library if impl(ghc >= 6.12) GHC-Options: -Wall -fno-warn-unused-do-bind -- cgit v1.2.3