From 30e57ebc21aaf4b326d02f7ba9d9022c3fbf24e7 Mon Sep 17 00:00:00 2001 From: Philipp Balzarek Date: Mon, 4 Jan 2016 13:59:07 +0100 Subject: [PATCH] fix stream input logger (#100) Logger tried to decode UTF8 at packet boundaries --- pontarius-xmpp.cabal | 1 + source/Network/Xmpp/Stream.hs | 40 ++++++++++++++++++++++++++--------- tests/Main.hs | 2 ++ tests/Tests/Stream.hs | 15 +++++++++---- 4 files changed, 44 insertions(+), 14 deletions(-) diff --git a/pontarius-xmpp.cabal b/pontarius-xmpp.cabal index dcd9d19..beecc81 100644 --- a/pontarius-xmpp.cabal +++ b/pontarius-xmpp.cabal @@ -156,6 +156,7 @@ Test-Suite tests , Tests.Arbitrary.Xmpp , Tests.Parsers , Tests.Picklers + , Tests.Stream ghc-options: -Wall -O2 -fno-warn-orphans Test-Suite doctest diff --git a/source/Network/Xmpp/Stream.hs b/source/Network/Xmpp/Stream.hs index 780c32f..758798f 100644 --- a/source/Network/Xmpp/Stream.hs +++ b/source/Network/Xmpp/Stream.hs @@ -31,6 +31,7 @@ import Data.Ord import Data.Text (Text) import qualified Data.Text as Text import qualified Data.Text.Encoding as Text +import qualified Data.Text.Encoding.Error as Text import Data.Void (Void) import Data.XML.Pickle import Data.XML.Types @@ -240,22 +241,41 @@ restartStream = do -- Creates a conduit from a StreamHandle -sourceStreamHandle :: (MonadIO m, MonadError XmppFailure m) +sourceStreamHandleRaw :: (MonadIO m, MonadError XmppFailure m) => StreamHandle -> ConduitM i ByteString m () -sourceStreamHandle s = loopRead $ streamReceive s +sourceStreamHandleRaw s = forever . read $ streamReceive s where - loopRead rd = do + read rd = do bs' <- liftIO (rd 4096) bs <- case bs' of Left e -> throwError e Right r -> return r - if BS.null bs - then return () - else do - liftIO $ debugM "Pontarius.Xmpp" $ "in: " ++ - (Text.unpack . Text.decodeUtf8 $ bs) - yield bs - loopRead rd + yield bs + +sourceStreamHandle :: (MonadIO m, MonadError XmppFailure m) + => StreamHandle -> ConduitM i ByteString m () +sourceStreamHandle sh = sourceStreamHandleRaw sh $= logInput + +logInput :: MonadIO m => ConduitM ByteString ByteString m () +logInput = go Nothing + where + go mbDec = do + mbBs <- await + case mbBs of + Nothing -> return () + Just bs -> do + let decode = case mbDec of + Nothing -> Text.streamDecodeUtf8With Text.lenientDecode + Just d -> d + (Text.Some out leftover cont) = decode bs + cont' = if BS.null leftover + then Nothing + else Just cont + unless (Text.null out) $ + liftIO $ debugM "Pontarius.Xmpp" + $ "in: " ++ Text.unpack out + yield bs + go cont' -- We buffer sources because we don't want to lose data when multiple -- xml-entities are sent with the same packet and we don't want to eternally diff --git a/tests/Main.hs b/tests/Main.hs index 775ba6d..154abf5 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -4,8 +4,10 @@ import Test.Tasty import Tests.Parsers import Tests.Picklers +import Tests.Stream main :: IO () main = defaultMain $ testGroup "root" [ parserTests , picklerTests + , streamTests ] diff --git a/tests/Tests/Stream.hs b/tests/Tests/Stream.hs index fae939b..cc03f09 100644 --- a/tests/Tests/Stream.hs +++ b/tests/Tests/Stream.hs @@ -3,13 +3,15 @@ module Tests.Stream where +import Control.Monad.Trans import Data.Conduit import qualified Data.Conduit.List as CL import Data.XML.Types import Test.Hspec -import Test.Tasty.TH import Test.Tasty +import Test.Tasty.HUnit import Test.Tasty.Hspec +import Test.Tasty.TH import Network.Xmpp.Stream @@ -27,7 +29,7 @@ junk = [ EventBeginDocument beginElem = EventBeginElement "foo" [] -case_ThrowOutJunk = do +case_ThrowOutJunk = hspec . describe "throwOutJunk" $ do it "drops everything but EvenBeginElement" $ do res <- CL.sourceList junk $$ throwOutJunk >> await res `shouldBe` Nothing @@ -36,5 +38,10 @@ case_ThrowOutJunk = do $$ throwOutJunk >> CL.consume res `shouldBe` (beginElem : junk) -testStreams :: TestTree -testStreams = $testGroupGenerator +case_LogInput = hspec . describe "logInput" $ do + it "Can handle split UTF8 codepoints" $ do + res <- CL.sourceList ["\209","\136"] $= logInput $$ CL.consume + res `shouldBe` ["\209","\136"] + +streamTests :: TestTree +streamTests = $testGroupGenerator