From a7ac1e59e2dd851664084d1bee06c1e5cb4c1420 Mon Sep 17 00:00:00 2001
From: Philipp Balzarek
Date: Sun, 19 May 2013 12:18:46 +0200
Subject: [PATCH] change reader conduit to buffered source kill reader thread
when end of stream is reached
---
source/Network/Xmpp/Concurrent/Threads.hs | 17 +--
source/Network/Xmpp/Stream.hs | 73 ++++++++-----
source/Network/Xmpp/Types.hs | 2 +-
source/Network/Xmpp/Utilities.hs | 15 +++
tests/Tests.hs | 120 +++++++++-------------
5 files changed, 118 insertions(+), 109 deletions(-)
diff --git a/source/Network/Xmpp/Concurrent/Threads.hs b/source/Network/Xmpp/Concurrent/Threads.hs
index 7a73a09..81b3867 100644
--- a/source/Network/Xmpp/Concurrent/Threads.hs
+++ b/source/Network/Xmpp/Concurrent/Threads.hs
@@ -23,9 +23,10 @@ import System.Log.Logger
readWorker :: (Stanza -> IO ())
-> (XmppFailure -> IO ())
-> TMVar Stream
- -> IO a
-readWorker onStanza onConnectionClosed stateRef =
- Ex.mask_ . forever $ do
+ -> IO ()
+readWorker onStanza onConnectionClosed stateRef = Ex.mask_ go
+ where
+ go = do
res <- Ex.catches ( do
-- we don't know whether pull will
-- necessarily be interruptible
@@ -47,10 +48,12 @@ readWorker onStanza onConnectionClosed stateRef =
return Nothing
]
case res of
- Nothing -> return () -- Caught an exception, nothing to do. TODO: Can this happen?
- Just (Left _) -> return ()
- Just (Right sta) -> onStanza sta
- where
+ Nothing -> go -- Caught an exception, nothing to do. TODO: Can this happen?
+ Just (Left e) -> do
+ infoM "Pontarius.Xmpp.Reader" $
+ "Connection died: " ++ show e
+ onConnectionClosed e
+ Just (Right sta) -> onStanza sta >> go
-- Defining an Control.Exception.allowInterrupt equivalent for GHC 7
-- compatibility.
allowInterrupt :: IO ()
diff --git a/source/Network/Xmpp/Stream.hs b/source/Network/Xmpp/Stream.hs
index 88a9e81..6bc9832 100644
--- a/source/Network/Xmpp/Stream.hs
+++ b/source/Network/Xmpp/Stream.hs
@@ -142,9 +142,7 @@ startStream = runErrorT $ do
)
response <- ErrorT $ runEventsSink $ runErrorT $ streamS expectedTo
case response of
- Left e -> throwError e
- -- Successful unpickling of stream element.
- Right (Right (ver, from, to, sid, lt, features))
+ Right (ver, from, to, sid, lt, features)
| (Text.unpack ver) /= "1.0" ->
closeStreamWithError StreamUnsupportedVersion Nothing
"Unknown version"
@@ -174,7 +172,7 @@ startStream = runErrorT $ do
} )
return ()
-- Unpickling failed - we investigate the element.
- Right (Left (Element name attrs _children))
+ Left (Element name attrs _children)
| (nameLocalName name /= "stream") ->
closeStreamWithError StreamInvalidXml Nothing
"Root element is not stream"
@@ -236,11 +234,11 @@ flattenAttrs attrs = Prelude.map (\(name, cont) ->
-- and calls xmppStartStream.
restartStream :: StateT StreamState IO (Either XmppFailure ())
restartStream = do
- lift $ debugM "Pontarius.XMPP" "Restarting stream..."
+ liftIO $ debugM "Pontarius.XMPP" "Restarting stream..."
raw <- gets (streamReceive . streamHandle)
- let newSource = DCI.ResumableSource (loopRead raw $= XP.parseBytes def)
- (return ())
- modify (\s -> s{streamEventSource = newSource })
+ let newSource =loopRead raw $= XP.parseBytes def
+ buffered <- liftIO . bufferSrc $ newSource
+ modify (\s -> s{streamEventSource = buffered })
startStream
where
loopRead rd = do
@@ -253,6 +251,29 @@ restartStream = do
yield bs
loopRead rd
+-- We buffer sources because we don't want to lose data when multiple
+-- xml-entities are sent with the same packet and we don't want to eternally
+-- block the StreamState while waiting for data to arrive
+bufferSrc :: MonadIO m => Source IO o -> IO (ConduitM i o m ())
+bufferSrc src = do
+ ref <- newTMVarIO $ DCI.ResumableSource src (return ())
+ let go = do
+ dt <- liftIO $ Ex.bracketOnError (atomically $ takeTMVar ref)
+ (\_ -> atomically . putTMVar ref $
+ DCI.ResumableSource zeroSource
+ (return ())
+ )
+ (\s -> do
+ (s', dt) <- s $$++ CL.head
+ atomically $ putTMVar ref s'
+ return dt
+ )
+ case dt of
+ Nothing -> return ()
+ Just d -> yield d >> go
+ return go
+
+
-- Reads the (partial) stream:stream and the server features from the stream.
-- Returns the (unvalidated) stream attributes, the unparsed element, or
-- throwError throws a `XmppOtherFailure' (if something other than an element
@@ -388,23 +409,21 @@ pushOpenElement e = do
-- `Connect-and-resumes' the given sink to the stream source, and pulls a
-- `b' value.
-runEventsSink :: Sink Event IO b -> StateT StreamState IO (Either XmppFailure b)
+runEventsSink :: Sink Event IO b -> StateT StreamState IO b
runEventsSink snk = do -- TODO: Wrap exceptions?
src <- gets streamEventSource
- (src', r) <- lift $ src $$++ snk
- modify (\s -> s{streamEventSource = src'})
- return $ Right r
+ r <- liftIO $ src $$ snk
+ return r
pullElement :: StateT StreamState IO (Either XmppFailure Element)
pullElement = do
ExL.catches (do
e <- runEventsSink (elements =$ await)
case e of
- Left f -> return $ Left f
- Right Nothing -> do
- lift $ errorM "Pontarius.XMPP" "pullElement: No element."
+ Nothing -> do
+ lift $ errorM "Pontarius.XMPP" "pullElement: Stream ended."
return . Left $ XmppOtherFailure
- Right (Just r) -> return $ Right r
+ Just r -> return $ Right r
)
[ ExL.Handler (\StreamEnd -> return $ Left StreamEndFailure)
, ExL.Handler (\(InvalidXmppXml s) -- Invalid XML `Event' encountered, or missing element close tag
@@ -463,7 +482,7 @@ xmppNoStream = StreamState {
, streamFlush = return ()
, streamClose = return ()
}
- , streamEventSource = DCI.ResumableSource zeroSource (return ())
+ , streamEventSource = zeroSource
, streamFeatures = StreamFeatures Nothing [] []
, streamAddress = Nothing
, streamFrom = Nothing
@@ -472,11 +491,11 @@ xmppNoStream = StreamState {
, streamJid = Nothing
, streamConfiguration = def
}
- where
- zeroSource :: Source IO output
- zeroSource = liftIO $ do
- errorM "Pontarius.Xmpp" "zeroSource utilized."
- ExL.throwIO XmppOtherFailure
+
+zeroSource :: Source IO output
+zeroSource = liftIO $ do
+ errorM "Pontarius.Xmpp" "zeroSource"
+ ExL.throwIO XmppOtherFailure
createStream :: HostName -> StreamConfiguration -> ErrorT XmppFailure IO (Stream)
createStream realm config = do
@@ -486,9 +505,9 @@ createStream realm config = do
debugM "Pontarius.Xmpp" "Acquired handle."
debugM "Pontarius.Xmpp" "Setting NoBuffering mode on handle."
hSetBuffering h NoBuffering
- let eSource = DCI.ResumableSource
- ((sourceHandle h $= logConduit) $= XP.parseBytes def)
- (return ())
+ eSource <- liftIO . bufferSrc $
+ (sourceHandle h $= logConduit) $= XP.parseBytes def
+
let hand = StreamHandle { streamSend = \d -> catchPush $ BS.hPut h d
, streamReceive = \n -> BS.hGetSome h n
, streamFlush = hFlush h
@@ -795,5 +814,5 @@ withStream' action (Stream stream) = do
return r
-mkStream :: StreamState -> IO (Stream)
-mkStream con = Stream `fmap` (atomically $ newTMVar con)
+mkStream :: StreamState -> IO Stream
+mkStream con = Stream `fmap` atomically (newTMVar con)
diff --git a/source/Network/Xmpp/Types.hs b/source/Network/Xmpp/Types.hs
index 6720061..4d9e097 100644
--- a/source/Network/Xmpp/Types.hs
+++ b/source/Network/Xmpp/Types.hs
@@ -794,7 +794,7 @@ data StreamState = StreamState
-- | Functions to send, receive, flush, and close on the stream
, streamHandle :: StreamHandle
-- | Event conduit source, and its associated finalizer
- , streamEventSource :: ResumableSource IO Event
+ , streamEventSource :: Source IO Event
-- | Stream features advertised by the server
, streamFeatures :: !StreamFeatures -- TODO: Maybe?
-- | The hostname or IP specified for the connection
diff --git a/source/Network/Xmpp/Utilities.hs b/source/Network/Xmpp/Utilities.hs
index 6d4cee3..87d9a91 100644
--- a/source/Network/Xmpp/Utilities.hs
+++ b/source/Network/Xmpp/Utilities.hs
@@ -8,10 +8,14 @@ module Network.Xmpp.Utilities
, renderOpenElement
, renderElement
, checkHostName
+ , withTMVar
)
where
import Control.Applicative ((<|>))
+import Control.Concurrent.STM
+import Control.Exception
+import Control.Monad.State.Strict
import qualified Data.Attoparsec.Text as AP
import qualified Data.ByteString as BS
import Data.Conduit as C
@@ -25,6 +29,17 @@ import System.IO.Unsafe(unsafePerformIO)
import qualified Text.XML.Stream.Render as TXSR
import Text.XML.Unresolved as TXU
+-- | Apply f with the content of tv as state, restoring the original value when an
+-- exception occurs
+withTMVar :: TMVar a -> (a -> IO (c, a)) -> IO c
+withTMVar tv f = bracketOnError (atomically $ takeTMVar tv)
+ (atomically . putTMVar tv)
+ (\s -> do
+ (x, s') <- f s
+ atomically $ putTMVar tv s'
+ return x
+ )
+
openElementToEvents :: Element -> [Event]
openElementToEvents (Element name as ns) = EventBeginElement name as : goN ns []
where
diff --git a/tests/Tests.hs b/tests/Tests.hs
index 48f49e8..f9f5867 100644
--- a/tests/Tests.hs
+++ b/tests/Tests.hs
@@ -1,4 +1,4 @@
-{-# LANGUAGE PackageImports, OverloadedStrings, NoMonomorphismRestriction #-}
+{-# LANGUAGE OverloadedStrings, NoMonomorphismRestriction #-}
module Example where
import Control.Concurrent
@@ -17,24 +17,27 @@ import Data.XML.Types
import Network
import Network.Xmpp
-import Network.Xmpp.Concurrent.Channels
import Network.Xmpp.IM.Presence
-import Network.Xmpp.Pickle
+import Network.Xmpp.Internal
+import Network.Xmpp.Marshal
import Network.Xmpp.Types
-import qualified Network.Xmpp.Xep.InbandRegistration as IBR
+-- import qualified Network.Xmpp.Xep.InbandRegistration as IBR
+import Data.Default (def)
import qualified Network.Xmpp.Xep.ServiceDiscovery as Disco
-
import System.Environment
-import Text.XML.Stream.Elements
+import System.Log.Logger
testUser1 :: Jid
-testUser1 = read "testuser1@species64739.dyndns.org/bot1"
+testUser1 = "echo1@species64739.dyndns.org/bot"
testUser2 :: Jid
-testUser2 = read "testuser2@species64739.dyndns.org/bot2"
+testUser2 = "echo2@species64739.dyndns.org/bot"
supervisor :: Jid
-supervisor = read "uart14@species64739.dyndns.org"
+supervisor = "uart14@species64739.dyndns.org"
+
+config = def{sessionStreamConfiguration
+ = def{connectionDetails = UseHost "localhost" (PortNumber 5222)}}
testNS :: Text
testNS = "xmpp:library:test"
@@ -60,7 +63,7 @@ payloadP = xpWrap (\((counter,flag) , message) -> Payload counter flag message)
invertPayload (Payload count flag message) = Payload (count + 1) (not flag) (Text.reverse message)
iqResponder context = do
- chan' <- listenIQChan Get testNS context
+ chan' <- listenIQChan Set testNS context
chan <- case chan' of
Left _ -> liftIO $ putStrLn "Channel was already taken"
>> error "hanging up"
@@ -72,14 +75,12 @@ iqResponder context = do
let answerPayload = invertPayload payload
let answerBody = pickleElem payloadP answerPayload
unless (payloadCounter payload == 3) . void $
- answerIQ next (Right $ Just answerBody) context
- when (payloadCounter payload == 10) $ do
- threadDelay 1000000
- endContext (session context)
+ answerIQ next (Right $ Just answerBody)
+
autoAccept :: Xmpp ()
autoAccept context = forever $ do
- st <- waitForPresence isPresenceSubscribe context
+ st <- waitForPresence (\p -> presenceType p == Just Subscribe) context
sendPresence (presenceSubscribed (fromJust $ presenceFrom st)) context
simpleMessage :: Jid -> Text -> Message
@@ -111,23 +112,23 @@ expect debug x y context | x == y = debug "Ok."
wait3 :: MonadIO m => m ()
wait3 = liftIO $ threadDelay 1000000
-discoTest debug context = do
- q <- Disco.queryInfo "species64739.dyndns.org" Nothing context
- case q of
- Left (Disco.DiscoXMLError el e) -> do
- debug (ppElement el)
- debug (Text.unpack $ ppUnpickleError e)
- debug (show $ length $ elementNodes el)
- x -> debug $ show x
-
- q <- Disco.queryItems "species64739.dyndns.org"
- (Just "http://jabber.org/protocol/commands") context
- case q of
- Left (Disco.DiscoXMLError el e) -> do
- debug (ppElement el)
- debug (Text.unpack $ ppUnpickleError e)
- debug (show $ length $ elementNodes el)
- x -> debug $ show x
+-- discoTest debug context = do
+-- q <- Disco.queryInfo "species64739.dyndns.org" Nothing context
+-- case q of
+-- Left (Disco.DiscoXMLError el e) -> do
+-- debug (ppElement el)
+-- debug (Text.unpack $ ppUnpickleError e)
+-- debug (show $ length $ elementNodes el)
+-- x -> debug $ show x
+
+-- q <- Disco.queryItems "species64739.dyndns.org"
+-- (Just "http://jabber.org/protocol/commands") context
+-- case q of
+-- Left (Disco.DiscoXMLError el e) -> do
+-- debug (ppElement el)
+-- debug (Text.unpack $ ppUnpickleError e)
+-- debug (show $ length $ elementNodes el)
+-- x -> debug $ show x
iqTest debug we them context = do
forM [1..10] $ \count -> do
@@ -135,7 +136,7 @@ iqTest debug we them context = do
let payload = Payload count (even count) (Text.pack $ show count)
let body = pickleElem payloadP payload
debug "sending"
- answer <- sendIQ' (Just them) Get Nothing body context
+ answer <- sendIQ' (Just them) Set Nothing body context
case answer of
IQResponseResult r -> do
debug "received"
@@ -147,16 +148,12 @@ iqTest debug we them context = do
IQResponseError e -> do
debug $ "Error in packet: " ++ show count
liftIO $ threadDelay 100000
- sendUser "All tests done" context
+-- sendUser "All tests done" context
debug "ending session"
-fork action context = do
- context' <- forkSession context
- forkIO $ action context'
-
-ibrTest debug uname pw = IBR.registerWith [ (IBR.Username, "testuser2")
- , (IBR.Password, "pwd")
- ] >>= debug . show
+-- ibrTest debug uname pw = IBR.registerWith [ (IBR.Username, "testuser2")
+-- , (IBR.Password, "pwd")
+-- ] >>= debug . show
runMain :: (String -> STM ()) -> Int -> Bool -> IO ()
@@ -166,50 +163,23 @@ runMain debug number multi = do
0 -> (testUser2, testUser1,False)
let debug' = liftIO . atomically .
debug . (("Thread " ++ show number ++ ":") ++)
- context <- newSession
-
- setConnectionClosedHandler (\e s -> do
- debug' $ "connection lost because " ++ show e
- endContext s) (session context)
debug' "running"
- flip withConnection (session context) $ Ex.catch (do
- debug' "connect"
- connect "localhost" (PortNumber 5222) "species64739.dyndns.org"
--- debug' "tls start"
- startTLS exampleParams
- debug' "ibr start"
- -- ibrTest debug' (localpart we) "pwd"
- -- debug' "ibr end"
- saslResponse <- simpleAuth
- (fromJust $ localpart we) "pwd" (resourcepart we)
- case saslResponse of
- Right _ -> return ()
- Left e -> error $ show e
- debug' "session standing"
- features <- other `liftM` gets sFeatures
- liftIO . void $ forM features $ \f -> debug' $ ppElement f
- )
- (\e -> debug' $ show (e ::Ex.SomeException))
+ Right context <- session (Text.unpack $ domainpart we)
+ (Just ([scramSha1 (fromJust $ localpart we) Nothing "pwd"], resourcepart we))
+ config
sendPresence presenceOnline context
- thread1 <- fork autoAccept context
+ thread1 <- forkIO $ autoAccept =<< dupSession context
sendPresence (presenceSubscribe them) context
- thread2 <- fork iqResponder context
+ thread2 <- forkIO $ iqResponder =<< dupSession context
when active $ do
liftIO $ threadDelay 1000000 -- Wait for the other thread to go online
-- discoTest debug'
when multi $ iqTest debug' we them context
- closeConnection (session context)
killThread thread1
killThread thread2
return ()
liftIO . threadDelay $ 10^6
-- unless multi . void . withConnection $ IBR.unregister
- unless multi . void $ fork (\s -> forever $ do
- pullMessage s >>= debug' . show
- putStrLn ""
- putStrLn ""
- )
- context
liftIO . forever $ threadDelay 1000000
return ()
@@ -221,4 +191,6 @@ run i multi = do
runMain debugOut (2 + i) multi
-main = run 0 True
+main = do
+ updateGlobalLogger "Pontarius.Xmpp" $ setLevel DEBUG
+ run 0 True