From d126d8fe4159b68feadd4145f0e2a1f90d7ca502 Mon Sep 17 00:00:00 2001 From: huangyi Date: Sun, 23 Oct 2011 13:22:18 +0800 Subject: make encoder and decoder invertable, and tests to ensure that. --- Network/DNS/Internal.hs | 46 ++++++++++++- Network/DNS/Query.hs | 96 ++++++++++++++++++++++---- Network/DNS/Resolver.hs | 4 +- Network/DNS/Response.hs | 20 ++++-- Network/DNS/StateBinary.hs | 8 ++- Network/DNS/Types.hs | 11 +-- TestProtocol.hs | 163 +++++++++++++++++++++++++++++++++++++++++++++ dns.cabal | 2 +- 8 files changed, 319 insertions(+), 31 deletions(-) create mode 100644 TestProtocol.hs diff --git a/Network/DNS/Internal.hs b/Network/DNS/Internal.hs index 8274afe..0368bd4 100644 --- a/Network/DNS/Internal.hs +++ b/Network/DNS/Internal.hs @@ -151,7 +151,15 @@ defaultQuery :: DNSFormat defaultQuery = DNSFormat { header = DNSHeader { identifier = 0 - , flags = undefined + , flags = DNSFlags { + qOrR = QR_Query + , opcode = OP_STD + , authAnswer = False + , trunCation = False + , recDesired = True + , recAvailable = False + , rcode = NoErr + } , qdCount = 0 , anCount = 0 , nsCount = 0 @@ -162,3 +170,39 @@ defaultQuery = DNSFormat { , authority = [] , additional = [] } + +defaultResponse :: DNSFormat +defaultResponse = + let hd = header defaultQuery + flg = flags hd + in defaultQuery { + header = hd { + flags = flg { + qOrR = QR_Response + , authAnswer = True + , recAvailable = True + } + } + } + +responseA :: Int -> Question -> IPv4 -> DNSFormat +responseA ident q ip = + let hd = header defaultResponse + dom = qname q + an = ResourceRecord dom A 300 4 (RD_A ip) + in defaultResponse { + header = hd { identifier=ident, qdCount = 1, anCount = 1 } + , question = [q] + , answer = [an] + } + +responseAAAA :: Int -> Question -> IPv6 -> DNSFormat +responseAAAA ident q ip = + let hd = header defaultResponse + dom = qname q + an = ResourceRecord dom AAAA 300 16 (RD_AAAA ip) + in defaultResponse { + header = hd { identifier=ident, qdCount = 1, anCount = 1 } + , question = [q] + , answer = [an] + } diff --git a/Network/DNS/Query.hs b/Network/DNS/Query.hs index 30c7e24..0d164b4 100644 --- a/Network/DNS/Query.hs +++ b/Network/DNS/Query.hs @@ -1,19 +1,28 @@ -module Network.DNS.Query (composeQuery) where +{-# LANGUAGE RecordWildCards #-} +module Network.DNS.Query (composeQuery, composeDNSFormat) where import qualified Data.ByteString.Lazy.Char8 as BL (ByteString) import qualified Data.ByteString as BS (unpack) import qualified Data.ByteString.Char8 as BS (length, split, null) +import Blaze.ByteString.Builder.ByteString (writeByteString) import Network.DNS.StateBinary import Network.DNS.Internal import Data.Monoid +import Control.Monad.State +import Data.Bits +import Data.Word +import Data.IP (+++) :: Monoid a => a -> a -> a (+++) = mappend ---------------------------------------------------------------- +composeDNSFormat :: DNSFormat -> BL.ByteString +composeDNSFormat fmt = runSPut (encodeDNSFormat fmt) + composeQuery :: Int -> [Question] -> BL.ByteString -composeQuery idt qs = runSPut (encodeQuery qry) +composeQuery idt qs = composeDNSFormat qry where hdr = header defaultQuery qry = defaultQuery { @@ -26,12 +35,18 @@ composeQuery idt qs = runSPut (encodeQuery qry) ---------------------------------------------------------------- -encodeQuery :: DNSFormat -> SPut -encodeQuery fmt = encodeHeader hdr - +++ encodeQuestion qs +encodeDNSFormat :: DNSFormat -> SPut +encodeDNSFormat fmt = encodeHeader hdr + +++ mconcat (map encodeQuestion qs) + +++ mconcat (map encodeRR an) + +++ mconcat (map encodeRR au) + +++ mconcat (map encodeRR ad) where hdr = header fmt qs = question fmt + an = answer fmt + au = authority fmt + ad = additional fmt encodeHeader :: DNSHeader -> SPut encodeHeader hdr = encodeIdentifier (identifier hdr) @@ -48,16 +63,69 @@ encodeHeader hdr = encodeIdentifier (identifier hdr) decodeArCount = putInt16 encodeFlags :: DNSFlags -> SPut -encodeFlags _ = put16 0x0100 -- xxx - -encodeQuestion :: [Question] -> SPut -encodeQuestion qs = encodeDomain dom - +++ putInt16 (typeToInt typ) - +++ put16 1 +encodeFlags DNSFlags{..} = put16 word where - q = head qs - dom = qname q - typ = qtype q + word16 :: Enum a => a -> Word16 + word16 = toEnum . fromEnum + + set :: Word16 -> State Word16 () + set byte = modify (.|. byte) + + st :: State Word16 () + st = sequence_ + [ set (word16 rcode) + , when recAvailable $ set (bit 7) + , when recDesired $ set (bit 8) + , when trunCation $ set (bit 9) + , when authAnswer $ set (bit 10) + , set (word16 opcode `shiftL` 11) + , when (qOrR==QR_Response) $ set (bit 15) + ] + + word = execState st 0 + +encodeQuestion :: Question -> SPut +encodeQuestion Question{..} = + encodeDomain qname + +++ putInt16 (typeToInt qtype) + +++ put16 1 + +encodeRR :: ResourceRecord -> SPut +encodeRR ResourceRecord{..} = + mconcat + [ encodeDomain rrname + , putInt16 (typeToInt rrtype) + , put16 1 + , putInt32 rrttl + , putInt16 rdlen + , encodeRDATA rdata + ] + +encodeRDATA :: RDATA -> SPut +encodeRDATA rd = case rd of + (RD_A ip) -> mconcat $ map putInt8 (fromIPv4 ip) + (RD_AAAA ip) -> mconcat $ map putInt16 (fromIPv6 ip) + (RD_NS dom) -> encodeDomain dom + (RD_CNAME dom) -> encodeDomain dom + (RD_PTR dom) -> encodeDomain dom + (RD_MX prf dom) -> mconcat [putInt16 prf, encodeDomain dom] + (RD_TXT txt) -> writeByteString txt + (RD_OTH bytes) -> mconcat $ map putInt8 bytes + (RD_SOA d1 d2 serial refresh retry expire min') -> mconcat $ + [ encodeDomain d1 + , encodeDomain d2 + , putInt32 serial + , putInt32 refresh + , putInt32 retry + , putInt32 expire + , putInt32 min' + ] + (RD_SRV prio weight port dom) -> mconcat $ + [ putInt16 prio + , putInt16 weight + , putInt16 port + , encodeDomain dom + ] ---------------------------------------------------------------- diff --git a/Network/DNS/Resolver.hs b/Network/DNS/Resolver.hs index 5919088..b3182f6 100644 --- a/Network/DNS/Resolver.hs +++ b/Network/DNS/Resolver.hs @@ -21,7 +21,7 @@ module Network.DNS.Resolver ( -- ** Intermediate data type for resolver , ResolvSeed, makeResolvSeed -- ** Type and function for resolver - , Resolver, withResolver + , Resolver(..), withResolver -- ** Looking up functions , lookup, lookupRaw ) where @@ -128,7 +128,7 @@ makeAddrInfo addr = do argument. 'withResolver' should be passed to 'forkIO'. -} -withResolver :: ResolvSeed -> (Resolver -> IO ()) -> IO () +withResolver :: ResolvSeed -> (Resolver -> IO a) -> IO a withResolver seed func = do let ai = addrInfo seed sock <- socket (addrFamily ai) (addrSocketType ai) (addrProtocol ai) diff --git a/Network/DNS/Response.hs b/Network/DNS/Response.hs index 36f6203..abcb8ba 100644 --- a/Network/DNS/Response.hs +++ b/Network/DNS/Response.hs @@ -1,6 +1,6 @@ {-# LANGUAGE OverloadedStrings #-} -module Network.DNS.Response (responseIter, parseResponse) where +module Network.DNS.Response (responseIter, parseResponse, runDNSFormat, runDNSFormat_) where import Control.Applicative import Control.Monad @@ -12,13 +12,21 @@ import Network.DNS.Internal import Network.DNS.StateBinary import Data.Enumerator (Enumerator, Iteratee, run_, ($$)) import Data.ByteString (ByteString) +import qualified Data.ByteString.Lazy as BL -responseIter :: Iteratee ByteString IO (DNSFormat, PState) -responseIter = runSGet decodeResponse +runDNSFormat :: BL.ByteString -> Either String (DNSFormat, PState) +runDNSFormat bs = runSGet decodeResponse bs -parseResponse :: Enumerator ByteString IO (a,b) - -> Iteratee ByteString IO (a,b) - -> IO a +runDNSFormat_ :: BL.ByteString -> Either String DNSFormat +runDNSFormat_ bs = fst <$> runDNSFormat bs + +responseIter :: Monad m => Iteratee ByteString m (DNSFormat, PState) +responseIter = iterSGet decodeResponse + +parseResponse :: (Functor m, Monad m) + => Enumerator ByteString m (a,b) + -> Iteratee ByteString m (a,b) + -> m a parseResponse enum iter = fst <$> run_ (enum $$ iter) ---------------------------------------------------------------- diff --git a/Network/DNS/StateBinary.hs b/Network/DNS/StateBinary.hs index 05f8468..e44ee98 100644 --- a/Network/DNS/StateBinary.hs +++ b/Network/DNS/StateBinary.hs @@ -5,6 +5,7 @@ import Control.Applicative import Control.Monad.State import Data.Attoparsec import Data.Attoparsec.Enumerator +import qualified Data.Attoparsec.Lazy as AL import Data.ByteString (ByteString) import qualified Data.ByteString as BS (unpack) import qualified Data.ByteString.Lazy as BL (ByteString) @@ -114,8 +115,11 @@ getNByteString n = lift (take n) <* addPosition n initialState :: PState initialState = PState IM.empty 0 -runSGet :: SGet a -> Iteratee ByteString IO (a, PState) -runSGet parser = iterParser (runStateT parser initialState) +iterSGet :: Monad m => SGet a -> Iteratee ByteString m (a, PState) +iterSGet parser = iterParser (runStateT parser initialState) + +runSGet :: SGet a -> BL.ByteString -> Either String (a, PState) +runSGet parser bs = AL.eitherResult $ AL.parse (runStateT parser initialState) bs runSPut :: SPut -> BL.ByteString runSPut = toLazyByteString . fromWrite diff --git a/Network/DNS/Types.hs b/Network/DNS/Types.hs index 58832e0..ce27046 100644 --- a/Network/DNS/Types.hs +++ b/Network/DNS/Types.hs @@ -8,18 +8,19 @@ module Network.DNS.Types ( -- * TYPE , TYPE (..), intToType, typeToInt, toType -- * DNS Format - , DNSFormat, header, question, answer, authority, additional + , DNSFormat (DNSFormat), header, question, answer, authority, additional -- * DNS Header - , DNSHeader, identifier, flags, qdCount, anCount, nsCount, arCount + , DNSHeader (DNSHeader), identifier, flags, qdCount, anCount, nsCount, arCount -- * DNS Flags - , DNSFlags, qOrR, opcode, authAnswer, trunCation, recDesired, recAvailable, rcode + , DNSFlags (DNSFlags), qOrR, opcode, authAnswer, trunCation, recDesired, recAvailable, rcode -- * DNS Body , QorR (..) , OPCODE (..) , RCODE (..) - , ResourceRecord, rrname, rrtype, rrttl, rdlen, rdata - , Question, qname, qtype, makeQuestion + , ResourceRecord (ResourceRecord), rrname, rrtype, rrttl, rdlen, rdata + , Question (Question), qname, qtype, makeQuestion , RDATA (..) + , responseA, responseAAAA ) where import Network.DNS.Internal diff --git a/TestProtocol.hs b/TestProtocol.hs new file mode 100644 index 0000000..bde01dd --- /dev/null +++ b/TestProtocol.hs @@ -0,0 +1,163 @@ +{-# LANGUAGE OverloadedStrings #-} + +module TestProtocol where + +import Network.DNS +import Network.DNS.Internal +import Network.DNS.Query +import Network.DNS.Response +import Data.IP +import Control.Monad.State +import Test.Framework (defaultMain, testGroup, Test) +import Test.Framework.Providers.HUnit +import Test.HUnit hiding (Test) + +tests :: [Test] +tests = + [ testGroup "Test case" + [ testCase "QueryA" (test_Format queryA) + , testCase "QueryAAAA" (test_Format queryAAAA) + , testCase "ResponseA" (test_Format responseA) + ] + ] + +defaultHeader :: DNSHeader +defaultHeader = header defaultQuery + +queryA :: DNSFormat +queryA = defaultQuery + { header = defaultHeader + { identifier = 1000 + , qdCount = 1 + } + , question = [makeQuestion "www.mew.org." A] + } + +queryAAAA :: DNSFormat +queryAAAA = defaultQuery + { header = defaultHeader + { identifier = 1000 + , qdCount = 1 + } + , question = [makeQuestion "www.mew.org." AAAA] + } + +responseA :: DNSFormat +responseA = DNSFormat { header = DNSHeader { identifier = 61046 + , flags = DNSFlags { qOrR = QR_Response + , opcode = OP_STD + , authAnswer = False + , trunCation = False + , recDesired = True + , recAvailable = True + , rcode = NoErr + } + , qdCount = 1 + , anCount = 8 + , nsCount = 2 + , arCount = 4 + } + , question = [Question {qname = "492056364.qzone.qq.com.", qtype = A}] + , answer = [ ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [119, 147, 15, 122] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [119, 147, 79, 106] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [183, 60, 55, 43] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [183, 60, 55, 107] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [113, 108, 7, 172] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [113, 108, 7, 174] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [113, 108, 7, 175] + } + , ResourceRecord { rrname = "492056364.qzone.qq.com." + , rrtype = A + , rrttl = 568 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [119, 147, 15, 100] + } + ] + , authority = [ ResourceRecord { rrname = "qzone.qq.com." + , rrtype = NS + , rrttl = 45919 + , rdlen = 10 + , rdata = RD_NS "ns-tel2.qq.com." + } + , ResourceRecord { rrname = "qzone.qq.com." + , rrtype = NS + , rrttl = 45919 + , rdlen = 10 + , rdata = RD_NS "ns-tel1.qq.com." + } + ] + , additional = [ ResourceRecord { rrname = "ns-tel1.qq.com." + , rrtype = A + , rrttl = 46520 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [121, 14, 73, 115] + } + , ResourceRecord { rrname = "ns-tel2.qq.com." + , rrtype = A + , rrttl = 2890 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [222, 73, 76, 226] + } + , ResourceRecord { rrname = "ns-tel2.qq.com." + , rrtype = A + , rrttl = 2890 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [183, 60, 3, 202] + } + , ResourceRecord { rrname = "ns-tel2.qq.com." + , rrtype = A + , rrttl = 2890 + , rdlen = 4 + , rdata = RD_A $ toIPv4 [218, 30, 72, 180] + } + ] + } + +assertEither :: (a -> String) -> Either a b -> IO () +assertEither f e = either (assertFailure . f) (const $ return ()) e + +test_Format :: DNSFormat -> IO () +test_Format fmt = do + assertEither id result + let (Right fmt') = result + assertEqual "fail" fmt fmt' + where + bs = composeDNSFormat fmt + result = runResponse_ bs + +main :: IO () +main = defaultMain tests diff --git a/dns.cabal b/dns.cabal index 1069339..1d1856a 100644 --- a/dns.cabal +++ b/dns.cabal @@ -12,7 +12,7 @@ Description: DNS libary. Currently only resolver side Category: Network Cabal-Version: >= 1.6 Build-Type: Simple -Extra-Source-Files: Test.hs +Extra-Source-Files: Test.hs, TestProtocol.hs library if impl(ghc >= 6.12) GHC-Options: -Wall -fno-warn-unused-do-bind -- cgit v1.2.3 From f16d70af84c736b986153727e2bbcb11ec5da7bd Mon Sep 17 00:00:00 2001 From: huangyi Date: Sun, 23 Oct 2011 13:23:10 +0800 Subject: make a simple dns proxy server to demonstrate protocol encoder and decoder --- SimpleServer.hs | 106 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 SimpleServer.hs diff --git a/SimpleServer.hs b/SimpleServer.hs new file mode 100644 index 0000000..b01d864 --- /dev/null +++ b/SimpleServer.hs @@ -0,0 +1,106 @@ +{-# LANGUAGE RecordWildCards, OverloadedStrings #-} + +import System.Environment +import Debug.Trace +import Control.Monad +import Control.Concurrent +import Control.Applicative +import Data.Monoid +import Data.Maybe +import qualified Data.ByteString as S +import Data.ByteString.Lazy hiding (putStrLn, filter, length) +import System.Timeout +import Network.BSD +import Network.DNS hiding (lookup) +import Network.DNS.Response +import Network.DNS.Query +import Network.Socket hiding (recvFrom) +import Network.Socket.ByteString +import Network.Socket.Enumerator +import Data.Default + +data Conf = Conf { + bufSize :: Int + , timeOut :: Int + , realDNS :: HostName +} + +instance Default Conf where + def = Conf { + bufSize = 512 + , timeOut = 3 * 1000 * 1000 + , realDNS = "192.168.1.1" + } + +timeout' :: String -> Int -> IO a -> IO (Maybe a) +timeout' msg tm io = do + result <- timeout tm io + maybe (putStrLn msg) (const $ return ()) result + return result + +proxyRequest :: Conf -> ResolvConf -> DNSFormat -> IO (Maybe DNSFormat) +proxyRequest Conf{..} rc req = do + let + worker Resolver{..} = do + let packet = mconcat . toChunks $ composeDNSFormat req + sendAll dnsSock packet + let responseEnum = enumSocket dnsBufsize dnsSock + parseResponse responseEnum responseIter + + rs <- makeResolvSeed rc + withResolver rs $ \r -> do + (>>= check) <$> timeout' "proxy timeout" timeOut (worker r) + where + ident = identifier . header $ req + check :: DNSFormat -> Maybe DNSFormat + check rsp = let hdr = header rsp + in if identifier hdr == ident + then Just rsp + else trace "identifier not match" Nothing + +{-- + - 先尝试本地查询,查询不到就代理到真正的dns服务器 + --} +handleRequest :: Conf -> ResolvConf -> DNSFormat -> IO (Maybe DNSFormat) +handleRequest conf rc req = maybe (proxyRequest conf rc req) (trace "return A record" $ return . Just) mresponse + where + filterA = filter ((==A) . qtype) + mresponse = do + let ident = identifier . header $ req + q <- listToMaybe . filterA . question $ req + let dom = qname q + ip <- lookup dom hosts + return $ responseA ident q ip + hosts = [ ("proxy.com.", "127.0.0.1") + --, ("*.proxy.com", "127.0.0.1") + ] + +handlePacket :: Conf -> Socket -> SockAddr -> S.ByteString -> IO () +handlePacket conf@Conf{..} sock addr bs = case runDNSFormat_ (fromChunks [bs]) of + Right req -> do + print req + let rc = defaultResolvConf { resolvInfo = RCHostName realDNS } + mrsp <- handleRequest conf rc req + print mrsp + case mrsp of + Just rsp -> + let packet = mconcat . toChunks $ composeDNSFormat rsp + in timeout' "send timeout" timeOut (sendAllTo sock packet addr) >> + print (S.length packet) >> + return () + Nothing -> return () + Left msg -> putStrLn msg + +main :: IO () +main = withSocketsDo $ do + dns <- fromMaybe (realDNS def) . listToMaybe <$> getArgs + let conf = def { realDNS=dns } + addrinfos <- getAddrInfo + (Just (defaultHints {addrFlags = [AI_PASSIVE]})) + Nothing (Just "domain") + addrinfo <- maybe (fail "no addr info") return (listToMaybe addrinfos) + sock <- socket (addrFamily addrinfo) Datagram defaultProtocol + bindSocket sock (addrAddress addrinfo) + forever $ do + (bs, addr) <- recvFrom sock (bufSize conf) + forkIO $ handlePacket conf sock addr bs -- cgit v1.2.3 From 89d6ab583274e7e10a69bc915b0e48cfdbc6207a Mon Sep 17 00:00:00 2001 From: huangyi Date: Sun, 23 Oct 2011 22:03:27 +0800 Subject: add domain compress --- Network/DNS/Query.hs | 32 +++++++++++++++---------- Network/DNS/StateBinary.hs | 60 +++++++++++++++++++++++++++++++++++++++------- TestProtocol.hs | 20 ++++++++-------- 3 files changed, 81 insertions(+), 31 deletions(-) diff --git a/Network/DNS/Query.hs b/Network/DNS/Query.hs index 0d164b4..3ebd5e0 100644 --- a/Network/DNS/Query.hs +++ b/Network/DNS/Query.hs @@ -2,9 +2,7 @@ module Network.DNS.Query (composeQuery, composeDNSFormat) where import qualified Data.ByteString.Lazy.Char8 as BL (ByteString) -import qualified Data.ByteString as BS (unpack) -import qualified Data.ByteString.Char8 as BS (length, split, null) -import Blaze.ByteString.Builder.ByteString (writeByteString) +import qualified Data.ByteString.Char8 as BS (length, null, break, drop) import Network.DNS.StateBinary import Network.DNS.Internal import Data.Monoid @@ -109,7 +107,7 @@ encodeRDATA rd = case rd of (RD_CNAME dom) -> encodeDomain dom (RD_PTR dom) -> encodeDomain dom (RD_MX prf dom) -> mconcat [putInt16 prf, encodeDomain dom] - (RD_TXT txt) -> writeByteString txt + (RD_TXT txt) -> putByteString txt (RD_OTH bytes) -> mconcat $ map putInt8 bytes (RD_SOA d1 d2 serial refresh retry expire min') -> mconcat $ [ encodeDomain d1 @@ -130,13 +128,23 @@ encodeRDATA rd = case rd of ---------------------------------------------------------------- encodeDomain :: Domain -> SPut -encodeDomain dom = foldr ((+++) . encodeSubDomain) (put8 0) $ zip ls ss +encodeDomain dom | BS.null dom = put8 0 +encodeDomain dom = do + mpos <- wsPop dom + cur <- gets wsPosition + case mpos of + Just pos -> encodePointer pos + Nothing -> wsPush dom cur >> + mconcat [ encodePartialDomain hd + , encodeDomain tl + ] where - ss = filter (not . BS.null) $ BS.split '.' dom - ls = map BS.length ss + (hd, tl') = BS.break (=='.') dom + tl = if BS.null tl' then tl' else BS.drop 1 tl' -encodeSubDomain :: (Int, Domain) -> SPut -encodeSubDomain (len,sub) = putInt8 len - +++ foldr ((+++) . put8) mempty ss - where - ss = BS.unpack sub +encodePointer :: Int -> SPut +encodePointer pos = let w = (pos .|. 0xc000) in putInt16 w + +encodePartialDomain :: Domain -> SPut +encodePartialDomain sub = putInt8 (BS.length sub) + +++ putByteString sub diff --git a/Network/DNS/StateBinary.hs b/Network/DNS/StateBinary.hs index e44ee98..6898d3b 100644 --- a/Network/DNS/StateBinary.hs +++ b/Network/DNS/StateBinary.hs @@ -1,43 +1,85 @@ +{-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-} module Network.DNS.StateBinary where import Blaze.ByteString.Builder import Control.Applicative import Control.Monad.State +import Data.Monoid import Data.Attoparsec import Data.Attoparsec.Enumerator import qualified Data.Attoparsec.Lazy as AL import Data.ByteString (ByteString) -import qualified Data.ByteString as BS (unpack) +import qualified Data.ByteString as BS (unpack, length) import qualified Data.ByteString.Lazy as BL (ByteString) import Data.Enumerator (Iteratee) import Data.Int import Data.IntMap (IntMap) import qualified Data.IntMap as IM (insert, lookup, empty) +import Data.Map (Map) +import qualified Data.Map as M (insert, lookup, empty) import Data.Word import Network.DNS.Types import Prelude hiding (lookup, take) ---------------------------------------------------------------- -type SPut = Write +type SPut = State WState Write + +data WState = WState { + wsDomain :: Map Domain Int + , wsPosition :: Int +} + +initialWState :: WState +initialWState = WState M.empty 0 + +instance Monoid SPut where + mempty = return mempty + mappend a b = mconcat <$> sequence [a, b] put8 :: Word8 -> SPut -put8 = writeWord8 +put8 = fixedSized 1 writeWord8 put16 :: Word16 -> SPut -put16 = writeWord16be +put16 = fixedSized 2 writeWord16be put32 :: Word32 -> SPut -put32 = writeWord32be +put32 = fixedSized 4 writeWord32be putInt8 :: Int -> SPut -putInt8 = writeInt8 . fromIntegral +putInt8 = fixedSized 1 (writeInt8 . fromIntegral) putInt16 :: Int -> SPut -putInt16 = writeInt16be . fromIntegral +putInt16 = fixedSized 2 (writeInt16be . fromIntegral) putInt32 :: Int -> SPut -putInt32 = writeInt32be . fromIntegral +putInt32 = fixedSized 4 (writeInt32be . fromIntegral) + +putByteString :: ByteString -> SPut +putByteString = writeSized BS.length writeByteString + +addPositionW :: Int -> State WState () +addPositionW n = do + (WState m cur) <- get + put $ WState m (cur+n) + +fixedSized :: Int -> (a -> Write) -> a -> SPut +fixedSized n f a = do addPositionW n + return (f a) + +writeSized :: Show a => (a -> Int) -> (a -> Write) -> a -> SPut +writeSized n f a = do addPositionW (n a) + return (f a) + +wsPop :: Domain -> State WState (Maybe Int) +wsPop dom = do + doms <- gets wsDomain + return $ M.lookup dom doms + +wsPush :: Domain -> Int -> State WState () +wsPush dom pos = do + (WState m cur) <- get + put $ WState (M.insert dom pos m) cur ---------------------------------------------------------------- @@ -122,4 +164,4 @@ runSGet :: SGet a -> BL.ByteString -> Either String (a, PState) runSGet parser bs = AL.eitherResult $ AL.parse (runStateT parser initialState) bs runSPut :: SPut -> BL.ByteString -runSPut = toLazyByteString . fromWrite +runSPut = toLazyByteString . fromWrite . flip evalState initialWState diff --git a/TestProtocol.hs b/TestProtocol.hs index bde01dd..cef00ee 100644 --- a/TestProtocol.hs +++ b/TestProtocol.hs @@ -15,17 +15,17 @@ import Test.HUnit hiding (Test) tests :: [Test] tests = [ testGroup "Test case" - [ testCase "QueryA" (test_Format queryA) - , testCase "QueryAAAA" (test_Format queryAAAA) - , testCase "ResponseA" (test_Format responseA) + [ testCase "QueryA" (test_Format testQueryA) + , testCase "QueryAAAA" (test_Format testQueryAAAA) + , testCase "ResponseA" (test_Format $ testResponseA) ] ] defaultHeader :: DNSHeader defaultHeader = header defaultQuery -queryA :: DNSFormat -queryA = defaultQuery +testQueryA :: DNSFormat +testQueryA = defaultQuery { header = defaultHeader { identifier = 1000 , qdCount = 1 @@ -33,8 +33,8 @@ queryA = defaultQuery , question = [makeQuestion "www.mew.org." A] } -queryAAAA :: DNSFormat -queryAAAA = defaultQuery +testQueryAAAA :: DNSFormat +testQueryAAAA = defaultQuery { header = defaultHeader { identifier = 1000 , qdCount = 1 @@ -42,8 +42,8 @@ queryAAAA = defaultQuery , question = [makeQuestion "www.mew.org." AAAA] } -responseA :: DNSFormat -responseA = DNSFormat { header = DNSHeader { identifier = 61046 +testResponseA :: DNSFormat +testResponseA = DNSFormat { header = DNSHeader { identifier = 61046 , flags = DNSFlags { qOrR = QR_Response , opcode = OP_STD , authAnswer = False @@ -157,7 +157,7 @@ test_Format fmt = do assertEqual "fail" fmt fmt' where bs = composeDNSFormat fmt - result = runResponse_ bs + result = runDNSFormat_ bs main :: IO () main = defaultMain tests -- cgit v1.2.3