From 89d6ab583274e7e10a69bc915b0e48cfdbc6207a Mon Sep 17 00:00:00 2001 From: huangyi Date: Sun, 23 Oct 2011 22:03:27 +0800 Subject: add domain compress --- Network/DNS/Query.hs | 32 +++++++++++++++---------- Network/DNS/StateBinary.hs | 60 +++++++++++++++++++++++++++++++++++++++------- TestProtocol.hs | 20 ++++++++-------- 3 files changed, 81 insertions(+), 31 deletions(-) diff --git a/Network/DNS/Query.hs b/Network/DNS/Query.hs index 0d164b4..3ebd5e0 100644 --- a/Network/DNS/Query.hs +++ b/Network/DNS/Query.hs @@ -2,9 +2,7 @@ module Network.DNS.Query (composeQuery, composeDNSFormat) where import qualified Data.ByteString.Lazy.Char8 as BL (ByteString) -import qualified Data.ByteString as BS (unpack) -import qualified Data.ByteString.Char8 as BS (length, split, null) -import Blaze.ByteString.Builder.ByteString (writeByteString) +import qualified Data.ByteString.Char8 as BS (length, null, break, drop) import Network.DNS.StateBinary import Network.DNS.Internal import Data.Monoid @@ -109,7 +107,7 @@ encodeRDATA rd = case rd of (RD_CNAME dom) -> encodeDomain dom (RD_PTR dom) -> encodeDomain dom (RD_MX prf dom) -> mconcat [putInt16 prf, encodeDomain dom] - (RD_TXT txt) -> writeByteString txt + (RD_TXT txt) -> putByteString txt (RD_OTH bytes) -> mconcat $ map putInt8 bytes (RD_SOA d1 d2 serial refresh retry expire min') -> mconcat $ [ encodeDomain d1 @@ -130,13 +128,23 @@ encodeRDATA rd = case rd of ---------------------------------------------------------------- encodeDomain :: Domain -> SPut -encodeDomain dom = foldr ((+++) . encodeSubDomain) (put8 0) $ zip ls ss +encodeDomain dom | BS.null dom = put8 0 +encodeDomain dom = do + mpos <- wsPop dom + cur <- gets wsPosition + case mpos of + Just pos -> encodePointer pos + Nothing -> wsPush dom cur >> + mconcat [ encodePartialDomain hd + , encodeDomain tl + ] where - ss = filter (not . BS.null) $ BS.split '.' dom - ls = map BS.length ss + (hd, tl') = BS.break (=='.') dom + tl = if BS.null tl' then tl' else BS.drop 1 tl' -encodeSubDomain :: (Int, Domain) -> SPut -encodeSubDomain (len,sub) = putInt8 len - +++ foldr ((+++) . put8) mempty ss - where - ss = BS.unpack sub +encodePointer :: Int -> SPut +encodePointer pos = let w = (pos .|. 0xc000) in putInt16 w + +encodePartialDomain :: Domain -> SPut +encodePartialDomain sub = putInt8 (BS.length sub) + +++ putByteString sub diff --git a/Network/DNS/StateBinary.hs b/Network/DNS/StateBinary.hs index e44ee98..6898d3b 100644 --- a/Network/DNS/StateBinary.hs +++ b/Network/DNS/StateBinary.hs @@ -1,43 +1,85 @@ +{-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-} module Network.DNS.StateBinary where import Blaze.ByteString.Builder import Control.Applicative import Control.Monad.State +import Data.Monoid import Data.Attoparsec import Data.Attoparsec.Enumerator import qualified Data.Attoparsec.Lazy as AL import Data.ByteString (ByteString) -import qualified Data.ByteString as BS (unpack) +import qualified Data.ByteString as BS (unpack, length) import qualified Data.ByteString.Lazy as BL (ByteString) import Data.Enumerator (Iteratee) import Data.Int import Data.IntMap (IntMap) import qualified Data.IntMap as IM (insert, lookup, empty) +import Data.Map (Map) +import qualified Data.Map as M (insert, lookup, empty) import Data.Word import Network.DNS.Types import Prelude hiding (lookup, take) ---------------------------------------------------------------- -type SPut = Write +type SPut = State WState Write + +data WState = WState { + wsDomain :: Map Domain Int + , wsPosition :: Int +} + +initialWState :: WState +initialWState = WState M.empty 0 + +instance Monoid SPut where + mempty = return mempty + mappend a b = mconcat <$> sequence [a, b] put8 :: Word8 -> SPut -put8 = writeWord8 +put8 = fixedSized 1 writeWord8 put16 :: Word16 -> SPut -put16 = writeWord16be +put16 = fixedSized 2 writeWord16be put32 :: Word32 -> SPut -put32 = writeWord32be +put32 = fixedSized 4 writeWord32be putInt8 :: Int -> SPut -putInt8 = writeInt8 . fromIntegral +putInt8 = fixedSized 1 (writeInt8 . fromIntegral) putInt16 :: Int -> SPut -putInt16 = writeInt16be . fromIntegral +putInt16 = fixedSized 2 (writeInt16be . fromIntegral) putInt32 :: Int -> SPut -putInt32 = writeInt32be . fromIntegral +putInt32 = fixedSized 4 (writeInt32be . fromIntegral) + +putByteString :: ByteString -> SPut +putByteString = writeSized BS.length writeByteString + +addPositionW :: Int -> State WState () +addPositionW n = do + (WState m cur) <- get + put $ WState m (cur+n) + +fixedSized :: Int -> (a -> Write) -> a -> SPut +fixedSized n f a = do addPositionW n + return (f a) + +writeSized :: Show a => (a -> Int) -> (a -> Write) -> a -> SPut +writeSized n f a = do addPositionW (n a) + return (f a) + +wsPop :: Domain -> State WState (Maybe Int) +wsPop dom = do + doms <- gets wsDomain + return $ M.lookup dom doms + +wsPush :: Domain -> Int -> State WState () +wsPush dom pos = do + (WState m cur) <- get + put $ WState (M.insert dom pos m) cur ---------------------------------------------------------------- @@ -122,4 +164,4 @@ runSGet :: SGet a -> BL.ByteString -> Either String (a, PState) runSGet parser bs = AL.eitherResult $ AL.parse (runStateT parser initialState) bs runSPut :: SPut -> BL.ByteString -runSPut = toLazyByteString . fromWrite +runSPut = toLazyByteString . fromWrite . flip evalState initialWState diff --git a/TestProtocol.hs b/TestProtocol.hs index bde01dd..cef00ee 100644 --- a/TestProtocol.hs +++ b/TestProtocol.hs @@ -15,17 +15,17 @@ import Test.HUnit hiding (Test) tests :: [Test] tests = [ testGroup "Test case" - [ testCase "QueryA" (test_Format queryA) - , testCase "QueryAAAA" (test_Format queryAAAA) - , testCase "ResponseA" (test_Format responseA) + [ testCase "QueryA" (test_Format testQueryA) + , testCase "QueryAAAA" (test_Format testQueryAAAA) + , testCase "ResponseA" (test_Format $ testResponseA) ] ] defaultHeader :: DNSHeader defaultHeader = header defaultQuery -queryA :: DNSFormat -queryA = defaultQuery +testQueryA :: DNSFormat +testQueryA = defaultQuery { header = defaultHeader { identifier = 1000 , qdCount = 1 @@ -33,8 +33,8 @@ queryA = defaultQuery , question = [makeQuestion "www.mew.org." A] } -queryAAAA :: DNSFormat -queryAAAA = defaultQuery +testQueryAAAA :: DNSFormat +testQueryAAAA = defaultQuery { header = defaultHeader { identifier = 1000 , qdCount = 1 @@ -42,8 +42,8 @@ queryAAAA = defaultQuery , question = [makeQuestion "www.mew.org." AAAA] } -responseA :: DNSFormat -responseA = DNSFormat { header = DNSHeader { identifier = 61046 +testResponseA :: DNSFormat +testResponseA = DNSFormat { header = DNSHeader { identifier = 61046 , flags = DNSFlags { qOrR = QR_Response , opcode = OP_STD , authAnswer = False @@ -157,7 +157,7 @@ test_Format fmt = do assertEqual "fail" fmt fmt' where bs = composeDNSFormat fmt - result = runResponse_ bs + result = runDNSFormat_ bs main :: IO () main = defaultMain tests -- cgit v1.2.3