Browse Source

change reader conduit to buffered source

kill reader thread when end of stream is reached
master
Philipp Balzarek 13 years ago
parent
commit
a7ac1e59e2
  1. 17
      source/Network/Xmpp/Concurrent/Threads.hs
  2. 73
      source/Network/Xmpp/Stream.hs
  3. 2
      source/Network/Xmpp/Types.hs
  4. 15
      source/Network/Xmpp/Utilities.hs
  5. 120
      tests/Tests.hs

17
source/Network/Xmpp/Concurrent/Threads.hs

@ -23,9 +23,10 @@ import System.Log.Logger
readWorker :: (Stanza -> IO ()) readWorker :: (Stanza -> IO ())
-> (XmppFailure -> IO ()) -> (XmppFailure -> IO ())
-> TMVar Stream -> TMVar Stream
-> IO a -> IO ()
readWorker onStanza onConnectionClosed stateRef = readWorker onStanza onConnectionClosed stateRef = Ex.mask_ go
Ex.mask_ . forever $ do where
go = do
res <- Ex.catches ( do res <- Ex.catches ( do
-- we don't know whether pull will -- we don't know whether pull will
-- necessarily be interruptible -- necessarily be interruptible
@ -47,10 +48,12 @@ readWorker onStanza onConnectionClosed stateRef =
return Nothing return Nothing
] ]
case res of case res of
Nothing -> return () -- Caught an exception, nothing to do. TODO: Can this happen? Nothing -> go -- Caught an exception, nothing to do. TODO: Can this happen?
Just (Left _) -> return () Just (Left e) -> do
Just (Right sta) -> onStanza sta infoM "Pontarius.Xmpp.Reader" $
where "Connection died: " ++ show e
onConnectionClosed e
Just (Right sta) -> onStanza sta >> go
-- Defining an Control.Exception.allowInterrupt equivalent for GHC 7 -- Defining an Control.Exception.allowInterrupt equivalent for GHC 7
-- compatibility. -- compatibility.
allowInterrupt :: IO () allowInterrupt :: IO ()

73
source/Network/Xmpp/Stream.hs

@ -142,9 +142,7 @@ startStream = runErrorT $ do
) )
response <- ErrorT $ runEventsSink $ runErrorT $ streamS expectedTo response <- ErrorT $ runEventsSink $ runErrorT $ streamS expectedTo
case response of case response of
Left e -> throwError e Right (ver, from, to, sid, lt, features)
-- Successful unpickling of stream element.
Right (Right (ver, from, to, sid, lt, features))
| (Text.unpack ver) /= "1.0" -> | (Text.unpack ver) /= "1.0" ->
closeStreamWithError StreamUnsupportedVersion Nothing closeStreamWithError StreamUnsupportedVersion Nothing
"Unknown version" "Unknown version"
@ -174,7 +172,7 @@ startStream = runErrorT $ do
} ) } )
return () return ()
-- Unpickling failed - we investigate the element. -- Unpickling failed - we investigate the element.
Right (Left (Element name attrs _children)) Left (Element name attrs _children)
| (nameLocalName name /= "stream") -> | (nameLocalName name /= "stream") ->
closeStreamWithError StreamInvalidXml Nothing closeStreamWithError StreamInvalidXml Nothing
"Root element is not stream" "Root element is not stream"
@ -236,11 +234,11 @@ flattenAttrs attrs = Prelude.map (\(name, cont) ->
-- and calls xmppStartStream. -- and calls xmppStartStream.
restartStream :: StateT StreamState IO (Either XmppFailure ()) restartStream :: StateT StreamState IO (Either XmppFailure ())
restartStream = do restartStream = do
lift $ debugM "Pontarius.XMPP" "Restarting stream..." liftIO $ debugM "Pontarius.XMPP" "Restarting stream..."
raw <- gets (streamReceive . streamHandle) raw <- gets (streamReceive . streamHandle)
let newSource = DCI.ResumableSource (loopRead raw $= XP.parseBytes def) let newSource =loopRead raw $= XP.parseBytes def
(return ()) buffered <- liftIO . bufferSrc $ newSource
modify (\s -> s{streamEventSource = newSource }) modify (\s -> s{streamEventSource = buffered })
startStream startStream
where where
loopRead rd = do loopRead rd = do
@ -253,6 +251,29 @@ restartStream = do
yield bs yield bs
loopRead rd loopRead rd
-- We buffer sources because we don't want to lose data when multiple
-- xml-entities are sent with the same packet and we don't want to eternally
-- block the StreamState while waiting for data to arrive
bufferSrc :: MonadIO m => Source IO o -> IO (ConduitM i o m ())
bufferSrc src = do
ref <- newTMVarIO $ DCI.ResumableSource src (return ())
let go = do
dt <- liftIO $ Ex.bracketOnError (atomically $ takeTMVar ref)
(\_ -> atomically . putTMVar ref $
DCI.ResumableSource zeroSource
(return ())
)
(\s -> do
(s', dt) <- s $$++ CL.head
atomically $ putTMVar ref s'
return dt
)
case dt of
Nothing -> return ()
Just d -> yield d >> go
return go
-- Reads the (partial) stream:stream and the server features from the stream. -- Reads the (partial) stream:stream and the server features from the stream.
-- Returns the (unvalidated) stream attributes, the unparsed element, or -- Returns the (unvalidated) stream attributes, the unparsed element, or
-- throwError throws a `XmppOtherFailure' (if something other than an element -- throwError throws a `XmppOtherFailure' (if something other than an element
@ -388,23 +409,21 @@ pushOpenElement e = do
-- `Connect-and-resumes' the given sink to the stream source, and pulls a -- `Connect-and-resumes' the given sink to the stream source, and pulls a
-- `b' value. -- `b' value.
runEventsSink :: Sink Event IO b -> StateT StreamState IO (Either XmppFailure b) runEventsSink :: Sink Event IO b -> StateT StreamState IO b
runEventsSink snk = do -- TODO: Wrap exceptions? runEventsSink snk = do -- TODO: Wrap exceptions?
src <- gets streamEventSource src <- gets streamEventSource
(src', r) <- lift $ src $$++ snk r <- liftIO $ src $$ snk
modify (\s -> s{streamEventSource = src'}) return r
return $ Right r
pullElement :: StateT StreamState IO (Either XmppFailure Element) pullElement :: StateT StreamState IO (Either XmppFailure Element)
pullElement = do pullElement = do
ExL.catches (do ExL.catches (do
e <- runEventsSink (elements =$ await) e <- runEventsSink (elements =$ await)
case e of case e of
Left f -> return $ Left f Nothing -> do
Right Nothing -> do lift $ errorM "Pontarius.XMPP" "pullElement: Stream ended."
lift $ errorM "Pontarius.XMPP" "pullElement: No element."
return . Left $ XmppOtherFailure return . Left $ XmppOtherFailure
Right (Just r) -> return $ Right r Just r -> return $ Right r
) )
[ ExL.Handler (\StreamEnd -> return $ Left StreamEndFailure) [ ExL.Handler (\StreamEnd -> return $ Left StreamEndFailure)
, ExL.Handler (\(InvalidXmppXml s) -- Invalid XML `Event' encountered, or missing element close tag , ExL.Handler (\(InvalidXmppXml s) -- Invalid XML `Event' encountered, or missing element close tag
@ -463,7 +482,7 @@ xmppNoStream = StreamState {
, streamFlush = return () , streamFlush = return ()
, streamClose = return () , streamClose = return ()
} }
, streamEventSource = DCI.ResumableSource zeroSource (return ()) , streamEventSource = zeroSource
, streamFeatures = StreamFeatures Nothing [] [] , streamFeatures = StreamFeatures Nothing [] []
, streamAddress = Nothing , streamAddress = Nothing
, streamFrom = Nothing , streamFrom = Nothing
@ -472,11 +491,11 @@ xmppNoStream = StreamState {
, streamJid = Nothing , streamJid = Nothing
, streamConfiguration = def , streamConfiguration = def
} }
where
zeroSource :: Source IO output zeroSource :: Source IO output
zeroSource = liftIO $ do zeroSource = liftIO $ do
errorM "Pontarius.Xmpp" "zeroSource utilized." errorM "Pontarius.Xmpp" "zeroSource"
ExL.throwIO XmppOtherFailure ExL.throwIO XmppOtherFailure
createStream :: HostName -> StreamConfiguration -> ErrorT XmppFailure IO (Stream) createStream :: HostName -> StreamConfiguration -> ErrorT XmppFailure IO (Stream)
createStream realm config = do createStream realm config = do
@ -486,9 +505,9 @@ createStream realm config = do
debugM "Pontarius.Xmpp" "Acquired handle." debugM "Pontarius.Xmpp" "Acquired handle."
debugM "Pontarius.Xmpp" "Setting NoBuffering mode on handle." debugM "Pontarius.Xmpp" "Setting NoBuffering mode on handle."
hSetBuffering h NoBuffering hSetBuffering h NoBuffering
let eSource = DCI.ResumableSource eSource <- liftIO . bufferSrc $
((sourceHandle h $= logConduit) $= XP.parseBytes def) (sourceHandle h $= logConduit) $= XP.parseBytes def
(return ())
let hand = StreamHandle { streamSend = \d -> catchPush $ BS.hPut h d let hand = StreamHandle { streamSend = \d -> catchPush $ BS.hPut h d
, streamReceive = \n -> BS.hGetSome h n , streamReceive = \n -> BS.hGetSome h n
, streamFlush = hFlush h , streamFlush = hFlush h
@ -795,5 +814,5 @@ withStream' action (Stream stream) = do
return r return r
mkStream :: StreamState -> IO (Stream) mkStream :: StreamState -> IO Stream
mkStream con = Stream `fmap` (atomically $ newTMVar con) mkStream con = Stream `fmap` atomically (newTMVar con)

2
source/Network/Xmpp/Types.hs

@ -794,7 +794,7 @@ data StreamState = StreamState
-- | Functions to send, receive, flush, and close on the stream -- | Functions to send, receive, flush, and close on the stream
, streamHandle :: StreamHandle , streamHandle :: StreamHandle
-- | Event conduit source, and its associated finalizer -- | Event conduit source, and its associated finalizer
, streamEventSource :: ResumableSource IO Event , streamEventSource :: Source IO Event
-- | Stream features advertised by the server -- | Stream features advertised by the server
, streamFeatures :: !StreamFeatures -- TODO: Maybe? , streamFeatures :: !StreamFeatures -- TODO: Maybe?
-- | The hostname or IP specified for the connection -- | The hostname or IP specified for the connection

15
source/Network/Xmpp/Utilities.hs

@ -8,10 +8,14 @@ module Network.Xmpp.Utilities
, renderOpenElement , renderOpenElement
, renderElement , renderElement
, checkHostName , checkHostName
, withTMVar
) )
where where
import Control.Applicative ((<|>)) import Control.Applicative ((<|>))
import Control.Concurrent.STM
import Control.Exception
import Control.Monad.State.Strict
import qualified Data.Attoparsec.Text as AP import qualified Data.Attoparsec.Text as AP
import qualified Data.ByteString as BS import qualified Data.ByteString as BS
import Data.Conduit as C import Data.Conduit as C
@ -25,6 +29,17 @@ import System.IO.Unsafe(unsafePerformIO)
import qualified Text.XML.Stream.Render as TXSR import qualified Text.XML.Stream.Render as TXSR
import Text.XML.Unresolved as TXU import Text.XML.Unresolved as TXU
-- | Apply f with the content of tv as state, restoring the original value when an
-- exception occurs
withTMVar :: TMVar a -> (a -> IO (c, a)) -> IO c
withTMVar tv f = bracketOnError (atomically $ takeTMVar tv)
(atomically . putTMVar tv)
(\s -> do
(x, s') <- f s
atomically $ putTMVar tv s'
return x
)
openElementToEvents :: Element -> [Event] openElementToEvents :: Element -> [Event]
openElementToEvents (Element name as ns) = EventBeginElement name as : goN ns [] openElementToEvents (Element name as ns) = EventBeginElement name as : goN ns []
where where

120
tests/Tests.hs

@ -1,4 +1,4 @@
{-# LANGUAGE PackageImports, OverloadedStrings, NoMonomorphismRestriction #-} {-# LANGUAGE OverloadedStrings, NoMonomorphismRestriction #-}
module Example where module Example where
import Control.Concurrent import Control.Concurrent
@ -17,24 +17,27 @@ import Data.XML.Types
import Network import Network
import Network.Xmpp import Network.Xmpp
import Network.Xmpp.Concurrent.Channels
import Network.Xmpp.IM.Presence import Network.Xmpp.IM.Presence
import Network.Xmpp.Pickle import Network.Xmpp.Internal
import Network.Xmpp.Marshal
import Network.Xmpp.Types import Network.Xmpp.Types
import qualified Network.Xmpp.Xep.InbandRegistration as IBR -- import qualified Network.Xmpp.Xep.InbandRegistration as IBR
import Data.Default (def)
import qualified Network.Xmpp.Xep.ServiceDiscovery as Disco import qualified Network.Xmpp.Xep.ServiceDiscovery as Disco
import System.Environment import System.Environment
import Text.XML.Stream.Elements import System.Log.Logger
testUser1 :: Jid testUser1 :: Jid
testUser1 = read "testuser1@species64739.dyndns.org/bot1" testUser1 = "echo1@species64739.dyndns.org/bot"
testUser2 :: Jid testUser2 :: Jid
testUser2 = read "testuser2@species64739.dyndns.org/bot2" testUser2 = "echo2@species64739.dyndns.org/bot"
supervisor :: Jid supervisor :: Jid
supervisor = read "uart14@species64739.dyndns.org" supervisor = "uart14@species64739.dyndns.org"
config = def{sessionStreamConfiguration
= def{connectionDetails = UseHost "localhost" (PortNumber 5222)}}
testNS :: Text testNS :: Text
testNS = "xmpp:library:test" testNS = "xmpp:library:test"
@ -60,7 +63,7 @@ payloadP = xpWrap (\((counter,flag) , message) -> Payload counter flag message)
invertPayload (Payload count flag message) = Payload (count + 1) (not flag) (Text.reverse message) invertPayload (Payload count flag message) = Payload (count + 1) (not flag) (Text.reverse message)
iqResponder context = do iqResponder context = do
chan' <- listenIQChan Get testNS context chan' <- listenIQChan Set testNS context
chan <- case chan' of chan <- case chan' of
Left _ -> liftIO $ putStrLn "Channel was already taken" Left _ -> liftIO $ putStrLn "Channel was already taken"
>> error "hanging up" >> error "hanging up"
@ -72,14 +75,12 @@ iqResponder context = do
let answerPayload = invertPayload payload let answerPayload = invertPayload payload
let answerBody = pickleElem payloadP answerPayload let answerBody = pickleElem payloadP answerPayload
unless (payloadCounter payload == 3) . void $ unless (payloadCounter payload == 3) . void $
answerIQ next (Right $ Just answerBody) context answerIQ next (Right $ Just answerBody)
when (payloadCounter payload == 10) $ do
threadDelay 1000000
endContext (session context)
autoAccept :: Xmpp () autoAccept :: Xmpp ()
autoAccept context = forever $ do autoAccept context = forever $ do
st <- waitForPresence isPresenceSubscribe context st <- waitForPresence (\p -> presenceType p == Just Subscribe) context
sendPresence (presenceSubscribed (fromJust $ presenceFrom st)) context sendPresence (presenceSubscribed (fromJust $ presenceFrom st)) context
simpleMessage :: Jid -> Text -> Message simpleMessage :: Jid -> Text -> Message
@ -111,23 +112,23 @@ expect debug x y context | x == y = debug "Ok."
wait3 :: MonadIO m => m () wait3 :: MonadIO m => m ()
wait3 = liftIO $ threadDelay 1000000 wait3 = liftIO $ threadDelay 1000000
discoTest debug context = do -- discoTest debug context = do
q <- Disco.queryInfo "species64739.dyndns.org" Nothing context -- q <- Disco.queryInfo "species64739.dyndns.org" Nothing context
case q of -- case q of
Left (Disco.DiscoXMLError el e) -> do -- Left (Disco.DiscoXMLError el e) -> do
debug (ppElement el) -- debug (ppElement el)
debug (Text.unpack $ ppUnpickleError e) -- debug (Text.unpack $ ppUnpickleError e)
debug (show $ length $ elementNodes el) -- debug (show $ length $ elementNodes el)
x -> debug $ show x -- x -> debug $ show x
q <- Disco.queryItems "species64739.dyndns.org" -- q <- Disco.queryItems "species64739.dyndns.org"
(Just "http://jabber.org/protocol/commands") context -- (Just "http://jabber.org/protocol/commands") context
case q of -- case q of
Left (Disco.DiscoXMLError el e) -> do -- Left (Disco.DiscoXMLError el e) -> do
debug (ppElement el) -- debug (ppElement el)
debug (Text.unpack $ ppUnpickleError e) -- debug (Text.unpack $ ppUnpickleError e)
debug (show $ length $ elementNodes el) -- debug (show $ length $ elementNodes el)
x -> debug $ show x -- x -> debug $ show x
iqTest debug we them context = do iqTest debug we them context = do
forM [1..10] $ \count -> do forM [1..10] $ \count -> do
@ -135,7 +136,7 @@ iqTest debug we them context = do
let payload = Payload count (even count) (Text.pack $ show count) let payload = Payload count (even count) (Text.pack $ show count)
let body = pickleElem payloadP payload let body = pickleElem payloadP payload
debug "sending" debug "sending"
answer <- sendIQ' (Just them) Get Nothing body context answer <- sendIQ' (Just them) Set Nothing body context
case answer of case answer of
IQResponseResult r -> do IQResponseResult r -> do
debug "received" debug "received"
@ -147,16 +148,12 @@ iqTest debug we them context = do
IQResponseError e -> do IQResponseError e -> do
debug $ "Error in packet: " ++ show count debug $ "Error in packet: " ++ show count
liftIO $ threadDelay 100000 liftIO $ threadDelay 100000
sendUser "All tests done" context -- sendUser "All tests done" context
debug "ending session" debug "ending session"
fork action context = do -- ibrTest debug uname pw = IBR.registerWith [ (IBR.Username, "testuser2")
context' <- forkSession context -- , (IBR.Password, "pwd")
forkIO $ action context' -- ] >>= debug . show
ibrTest debug uname pw = IBR.registerWith [ (IBR.Username, "testuser2")
, (IBR.Password, "pwd")
] >>= debug . show
runMain :: (String -> STM ()) -> Int -> Bool -> IO () runMain :: (String -> STM ()) -> Int -> Bool -> IO ()
@ -166,50 +163,23 @@ runMain debug number multi = do
0 -> (testUser2, testUser1,False) 0 -> (testUser2, testUser1,False)
let debug' = liftIO . atomically . let debug' = liftIO . atomically .
debug . (("Thread " ++ show number ++ ":") ++) debug . (("Thread " ++ show number ++ ":") ++)
context <- newSession
setConnectionClosedHandler (\e s -> do
debug' $ "connection lost because " ++ show e
endContext s) (session context)
debug' "running" debug' "running"
flip withConnection (session context) $ Ex.catch (do Right context <- session (Text.unpack $ domainpart we)
debug' "connect" (Just ([scramSha1 (fromJust $ localpart we) Nothing "pwd"], resourcepart we))
connect "localhost" (PortNumber 5222) "species64739.dyndns.org" config
-- debug' "tls start"
startTLS exampleParams
debug' "ibr start"
-- ibrTest debug' (localpart we) "pwd"
-- debug' "ibr end"
saslResponse <- simpleAuth
(fromJust $ localpart we) "pwd" (resourcepart we)
case saslResponse of
Right _ -> return ()
Left e -> error $ show e
debug' "session standing"
features <- other `liftM` gets sFeatures
liftIO . void $ forM features $ \f -> debug' $ ppElement f
)
(\e -> debug' $ show (e ::Ex.SomeException))
sendPresence presenceOnline context sendPresence presenceOnline context
thread1 <- fork autoAccept context thread1 <- forkIO $ autoAccept =<< dupSession context
sendPresence (presenceSubscribe them) context sendPresence (presenceSubscribe them) context
thread2 <- fork iqResponder context thread2 <- forkIO $ iqResponder =<< dupSession context
when active $ do when active $ do
liftIO $ threadDelay 1000000 -- Wait for the other thread to go online liftIO $ threadDelay 1000000 -- Wait for the other thread to go online
-- discoTest debug' -- discoTest debug'
when multi $ iqTest debug' we them context when multi $ iqTest debug' we them context
closeConnection (session context)
killThread thread1 killThread thread1
killThread thread2 killThread thread2
return () return ()
liftIO . threadDelay $ 10^6 liftIO . threadDelay $ 10^6
-- unless multi . void . withConnection $ IBR.unregister -- unless multi . void . withConnection $ IBR.unregister
unless multi . void $ fork (\s -> forever $ do
pullMessage s >>= debug' . show
putStrLn ""
putStrLn ""
)
context
liftIO . forever $ threadDelay 1000000 liftIO . forever $ threadDelay 1000000
return () return ()
@ -221,4 +191,6 @@ run i multi = do
runMain debugOut (2 + i) multi runMain debugOut (2 + i) multi
main = run 0 True main = do
updateGlobalLogger "Pontarius.Xmpp" $ setLevel DEBUG
run 0 True

Loading…
Cancel
Save