Browse Source

change reader conduit to buffered source

kill reader thread when end of stream is reached
master
Philipp Balzarek 13 years ago
parent
commit
a7ac1e59e2
  1. 17
      source/Network/Xmpp/Concurrent/Threads.hs
  2. 67
      source/Network/Xmpp/Stream.hs
  3. 2
      source/Network/Xmpp/Types.hs
  4. 15
      source/Network/Xmpp/Utilities.hs
  5. 120
      tests/Tests.hs

17
source/Network/Xmpp/Concurrent/Threads.hs

@ -23,9 +23,10 @@ import System.Log.Logger @@ -23,9 +23,10 @@ import System.Log.Logger
readWorker :: (Stanza -> IO ())
-> (XmppFailure -> IO ())
-> TMVar Stream
-> IO a
readWorker onStanza onConnectionClosed stateRef =
Ex.mask_ . forever $ do
-> IO ()
readWorker onStanza onConnectionClosed stateRef = Ex.mask_ go
where
go = do
res <- Ex.catches ( do
-- we don't know whether pull will
-- necessarily be interruptible
@ -47,10 +48,12 @@ readWorker onStanza onConnectionClosed stateRef = @@ -47,10 +48,12 @@ readWorker onStanza onConnectionClosed stateRef =
return Nothing
]
case res of
Nothing -> return () -- Caught an exception, nothing to do. TODO: Can this happen?
Just (Left _) -> return ()
Just (Right sta) -> onStanza sta
where
Nothing -> go -- Caught an exception, nothing to do. TODO: Can this happen?
Just (Left e) -> do
infoM "Pontarius.Xmpp.Reader" $
"Connection died: " ++ show e
onConnectionClosed e
Just (Right sta) -> onStanza sta >> go
-- Defining an Control.Exception.allowInterrupt equivalent for GHC 7
-- compatibility.
allowInterrupt :: IO ()

67
source/Network/Xmpp/Stream.hs

@ -142,9 +142,7 @@ startStream = runErrorT $ do @@ -142,9 +142,7 @@ startStream = runErrorT $ do
)
response <- ErrorT $ runEventsSink $ runErrorT $ streamS expectedTo
case response of
Left e -> throwError e
-- Successful unpickling of stream element.
Right (Right (ver, from, to, sid, lt, features))
Right (ver, from, to, sid, lt, features)
| (Text.unpack ver) /= "1.0" ->
closeStreamWithError StreamUnsupportedVersion Nothing
"Unknown version"
@ -174,7 +172,7 @@ startStream = runErrorT $ do @@ -174,7 +172,7 @@ startStream = runErrorT $ do
} )
return ()
-- Unpickling failed - we investigate the element.
Right (Left (Element name attrs _children))
Left (Element name attrs _children)
| (nameLocalName name /= "stream") ->
closeStreamWithError StreamInvalidXml Nothing
"Root element is not stream"
@ -236,11 +234,11 @@ flattenAttrs attrs = Prelude.map (\(name, cont) -> @@ -236,11 +234,11 @@ flattenAttrs attrs = Prelude.map (\(name, cont) ->
-- and calls xmppStartStream.
restartStream :: StateT StreamState IO (Either XmppFailure ())
restartStream = do
lift $ debugM "Pontarius.XMPP" "Restarting stream..."
liftIO $ debugM "Pontarius.XMPP" "Restarting stream..."
raw <- gets (streamReceive . streamHandle)
let newSource = DCI.ResumableSource (loopRead raw $= XP.parseBytes def)
(return ())
modify (\s -> s{streamEventSource = newSource })
let newSource =loopRead raw $= XP.parseBytes def
buffered <- liftIO . bufferSrc $ newSource
modify (\s -> s{streamEventSource = buffered })
startStream
where
loopRead rd = do
@ -253,6 +251,29 @@ restartStream = do @@ -253,6 +251,29 @@ restartStream = do
yield bs
loopRead rd
-- We buffer sources because we don't want to lose data when multiple
-- xml-entities are sent with the same packet and we don't want to eternally
-- block the StreamState while waiting for data to arrive
bufferSrc :: MonadIO m => Source IO o -> IO (ConduitM i o m ())
bufferSrc src = do
ref <- newTMVarIO $ DCI.ResumableSource src (return ())
let go = do
dt <- liftIO $ Ex.bracketOnError (atomically $ takeTMVar ref)
(\_ -> atomically . putTMVar ref $
DCI.ResumableSource zeroSource
(return ())
)
(\s -> do
(s', dt) <- s $$++ CL.head
atomically $ putTMVar ref s'
return dt
)
case dt of
Nothing -> return ()
Just d -> yield d >> go
return go
-- Reads the (partial) stream:stream and the server features from the stream.
-- Returns the (unvalidated) stream attributes, the unparsed element, or
-- throwError throws a `XmppOtherFailure' (if something other than an element
@ -388,23 +409,21 @@ pushOpenElement e = do @@ -388,23 +409,21 @@ pushOpenElement e = do
-- `Connect-and-resumes' the given sink to the stream source, and pulls a
-- `b' value.
runEventsSink :: Sink Event IO b -> StateT StreamState IO (Either XmppFailure b)
runEventsSink :: Sink Event IO b -> StateT StreamState IO b
runEventsSink snk = do -- TODO: Wrap exceptions?
src <- gets streamEventSource
(src', r) <- lift $ src $$++ snk
modify (\s -> s{streamEventSource = src'})
return $ Right r
r <- liftIO $ src $$ snk
return r
pullElement :: StateT StreamState IO (Either XmppFailure Element)
pullElement = do
ExL.catches (do
e <- runEventsSink (elements =$ await)
case e of
Left f -> return $ Left f
Right Nothing -> do
lift $ errorM "Pontarius.XMPP" "pullElement: No element."
Nothing -> do
lift $ errorM "Pontarius.XMPP" "pullElement: Stream ended."
return . Left $ XmppOtherFailure
Right (Just r) -> return $ Right r
Just r -> return $ Right r
)
[ ExL.Handler (\StreamEnd -> return $ Left StreamEndFailure)
, ExL.Handler (\(InvalidXmppXml s) -- Invalid XML `Event' encountered, or missing element close tag
@ -463,7 +482,7 @@ xmppNoStream = StreamState { @@ -463,7 +482,7 @@ xmppNoStream = StreamState {
, streamFlush = return ()
, streamClose = return ()
}
, streamEventSource = DCI.ResumableSource zeroSource (return ())
, streamEventSource = zeroSource
, streamFeatures = StreamFeatures Nothing [] []
, streamAddress = Nothing
, streamFrom = Nothing
@ -472,10 +491,10 @@ xmppNoStream = StreamState { @@ -472,10 +491,10 @@ xmppNoStream = StreamState {
, streamJid = Nothing
, streamConfiguration = def
}
where
zeroSource :: Source IO output
zeroSource = liftIO $ do
errorM "Pontarius.Xmpp" "zeroSource utilized."
errorM "Pontarius.Xmpp" "zeroSource"
ExL.throwIO XmppOtherFailure
createStream :: HostName -> StreamConfiguration -> ErrorT XmppFailure IO (Stream)
@ -486,9 +505,9 @@ createStream realm config = do @@ -486,9 +505,9 @@ createStream realm config = do
debugM "Pontarius.Xmpp" "Acquired handle."
debugM "Pontarius.Xmpp" "Setting NoBuffering mode on handle."
hSetBuffering h NoBuffering
let eSource = DCI.ResumableSource
((sourceHandle h $= logConduit) $= XP.parseBytes def)
(return ())
eSource <- liftIO . bufferSrc $
(sourceHandle h $= logConduit) $= XP.parseBytes def
let hand = StreamHandle { streamSend = \d -> catchPush $ BS.hPut h d
, streamReceive = \n -> BS.hGetSome h n
, streamFlush = hFlush h
@ -795,5 +814,5 @@ withStream' action (Stream stream) = do @@ -795,5 +814,5 @@ withStream' action (Stream stream) = do
return r
mkStream :: StreamState -> IO (Stream)
mkStream con = Stream `fmap` (atomically $ newTMVar con)
mkStream :: StreamState -> IO Stream
mkStream con = Stream `fmap` atomically (newTMVar con)

2
source/Network/Xmpp/Types.hs

@ -794,7 +794,7 @@ data StreamState = StreamState @@ -794,7 +794,7 @@ data StreamState = StreamState
-- | Functions to send, receive, flush, and close on the stream
, streamHandle :: StreamHandle
-- | Event conduit source, and its associated finalizer
, streamEventSource :: ResumableSource IO Event
, streamEventSource :: Source IO Event
-- | Stream features advertised by the server
, streamFeatures :: !StreamFeatures -- TODO: Maybe?
-- | The hostname or IP specified for the connection

15
source/Network/Xmpp/Utilities.hs

@ -8,10 +8,14 @@ module Network.Xmpp.Utilities @@ -8,10 +8,14 @@ module Network.Xmpp.Utilities
, renderOpenElement
, renderElement
, checkHostName
, withTMVar
)
where
import Control.Applicative ((<|>))
import Control.Concurrent.STM
import Control.Exception
import Control.Monad.State.Strict
import qualified Data.Attoparsec.Text as AP
import qualified Data.ByteString as BS
import Data.Conduit as C
@ -25,6 +29,17 @@ import System.IO.Unsafe(unsafePerformIO) @@ -25,6 +29,17 @@ import System.IO.Unsafe(unsafePerformIO)
import qualified Text.XML.Stream.Render as TXSR
import Text.XML.Unresolved as TXU
-- | Apply f with the content of tv as state, restoring the original value when an
-- exception occurs
withTMVar :: TMVar a -> (a -> IO (c, a)) -> IO c
withTMVar tv f = bracketOnError (atomically $ takeTMVar tv)
(atomically . putTMVar tv)
(\s -> do
(x, s') <- f s
atomically $ putTMVar tv s'
return x
)
openElementToEvents :: Element -> [Event]
openElementToEvents (Element name as ns) = EventBeginElement name as : goN ns []
where

120
tests/Tests.hs

@ -1,4 +1,4 @@ @@ -1,4 +1,4 @@
{-# LANGUAGE PackageImports, OverloadedStrings, NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedStrings, NoMonomorphismRestriction #-}
module Example where
import Control.Concurrent
@ -17,24 +17,27 @@ import Data.XML.Types @@ -17,24 +17,27 @@ import Data.XML.Types
import Network
import Network.Xmpp
import Network.Xmpp.Concurrent.Channels
import Network.Xmpp.IM.Presence
import Network.Xmpp.Pickle
import Network.Xmpp.Internal
import Network.Xmpp.Marshal
import Network.Xmpp.Types
import qualified Network.Xmpp.Xep.InbandRegistration as IBR
-- import qualified Network.Xmpp.Xep.InbandRegistration as IBR
import Data.Default (def)
import qualified Network.Xmpp.Xep.ServiceDiscovery as Disco
import System.Environment
import Text.XML.Stream.Elements
import System.Log.Logger
testUser1 :: Jid
testUser1 = read "testuser1@species64739.dyndns.org/bot1"
testUser1 = "echo1@species64739.dyndns.org/bot"
testUser2 :: Jid
testUser2 = read "testuser2@species64739.dyndns.org/bot2"
testUser2 = "echo2@species64739.dyndns.org/bot"
supervisor :: Jid
supervisor = read "uart14@species64739.dyndns.org"
supervisor = "uart14@species64739.dyndns.org"
config = def{sessionStreamConfiguration
= def{connectionDetails = UseHost "localhost" (PortNumber 5222)}}
testNS :: Text
testNS = "xmpp:library:test"
@ -60,7 +63,7 @@ payloadP = xpWrap (\((counter,flag) , message) -> Payload counter flag message) @@ -60,7 +63,7 @@ payloadP = xpWrap (\((counter,flag) , message) -> Payload counter flag message)
invertPayload (Payload count flag message) = Payload (count + 1) (not flag) (Text.reverse message)
iqResponder context = do
chan' <- listenIQChan Get testNS context
chan' <- listenIQChan Set testNS context
chan <- case chan' of
Left _ -> liftIO $ putStrLn "Channel was already taken"
>> error "hanging up"
@ -72,14 +75,12 @@ iqResponder context = do @@ -72,14 +75,12 @@ iqResponder context = do
let answerPayload = invertPayload payload
let answerBody = pickleElem payloadP answerPayload
unless (payloadCounter payload == 3) . void $
answerIQ next (Right $ Just answerBody) context
when (payloadCounter payload == 10) $ do
threadDelay 1000000
endContext (session context)
answerIQ next (Right $ Just answerBody)
autoAccept :: Xmpp ()
autoAccept context = forever $ do
st <- waitForPresence isPresenceSubscribe context
st <- waitForPresence (\p -> presenceType p == Just Subscribe) context
sendPresence (presenceSubscribed (fromJust $ presenceFrom st)) context
simpleMessage :: Jid -> Text -> Message
@ -111,23 +112,23 @@ expect debug x y context | x == y = debug "Ok." @@ -111,23 +112,23 @@ expect debug x y context | x == y = debug "Ok."
wait3 :: MonadIO m => m ()
wait3 = liftIO $ threadDelay 1000000
discoTest debug context = do
q <- Disco.queryInfo "species64739.dyndns.org" Nothing context
case q of
Left (Disco.DiscoXMLError el e) -> do
debug (ppElement el)
debug (Text.unpack $ ppUnpickleError e)
debug (show $ length $ elementNodes el)
x -> debug $ show x
q <- Disco.queryItems "species64739.dyndns.org"
(Just "http://jabber.org/protocol/commands") context
case q of
Left (Disco.DiscoXMLError el e) -> do
debug (ppElement el)
debug (Text.unpack $ ppUnpickleError e)
debug (show $ length $ elementNodes el)
x -> debug $ show x
-- discoTest debug context = do
-- q <- Disco.queryInfo "species64739.dyndns.org" Nothing context
-- case q of
-- Left (Disco.DiscoXMLError el e) -> do
-- debug (ppElement el)
-- debug (Text.unpack $ ppUnpickleError e)
-- debug (show $ length $ elementNodes el)
-- x -> debug $ show x
-- q <- Disco.queryItems "species64739.dyndns.org"
-- (Just "http://jabber.org/protocol/commands") context
-- case q of
-- Left (Disco.DiscoXMLError el e) -> do
-- debug (ppElement el)
-- debug (Text.unpack $ ppUnpickleError e)
-- debug (show $ length $ elementNodes el)
-- x -> debug $ show x
iqTest debug we them context = do
forM [1..10] $ \count -> do
@ -135,7 +136,7 @@ iqTest debug we them context = do @@ -135,7 +136,7 @@ iqTest debug we them context = do
let payload = Payload count (even count) (Text.pack $ show count)
let body = pickleElem payloadP payload
debug "sending"
answer <- sendIQ' (Just them) Get Nothing body context
answer <- sendIQ' (Just them) Set Nothing body context
case answer of
IQResponseResult r -> do
debug "received"
@ -147,16 +148,12 @@ iqTest debug we them context = do @@ -147,16 +148,12 @@ iqTest debug we them context = do
IQResponseError e -> do
debug $ "Error in packet: " ++ show count
liftIO $ threadDelay 100000
sendUser "All tests done" context
-- sendUser "All tests done" context
debug "ending session"
fork action context = do
context' <- forkSession context
forkIO $ action context'
ibrTest debug uname pw = IBR.registerWith [ (IBR.Username, "testuser2")
, (IBR.Password, "pwd")
] >>= debug . show
-- ibrTest debug uname pw = IBR.registerWith [ (IBR.Username, "testuser2")
-- , (IBR.Password, "pwd")
-- ] >>= debug . show
runMain :: (String -> STM ()) -> Int -> Bool -> IO ()
@ -166,50 +163,23 @@ runMain debug number multi = do @@ -166,50 +163,23 @@ runMain debug number multi = do
0 -> (testUser2, testUser1,False)
let debug' = liftIO . atomically .
debug . (("Thread " ++ show number ++ ":") ++)
context <- newSession
setConnectionClosedHandler (\e s -> do
debug' $ "connection lost because " ++ show e
endContext s) (session context)
debug' "running"
flip withConnection (session context) $ Ex.catch (do
debug' "connect"
connect "localhost" (PortNumber 5222) "species64739.dyndns.org"
-- debug' "tls start"
startTLS exampleParams
debug' "ibr start"
-- ibrTest debug' (localpart we) "pwd"
-- debug' "ibr end"
saslResponse <- simpleAuth
(fromJust $ localpart we) "pwd" (resourcepart we)
case saslResponse of
Right _ -> return ()
Left e -> error $ show e
debug' "session standing"
features <- other `liftM` gets sFeatures
liftIO . void $ forM features $ \f -> debug' $ ppElement f
)
(\e -> debug' $ show (e ::Ex.SomeException))
Right context <- session (Text.unpack $ domainpart we)
(Just ([scramSha1 (fromJust $ localpart we) Nothing "pwd"], resourcepart we))
config
sendPresence presenceOnline context
thread1 <- fork autoAccept context
thread1 <- forkIO $ autoAccept =<< dupSession context
sendPresence (presenceSubscribe them) context
thread2 <- fork iqResponder context
thread2 <- forkIO $ iqResponder =<< dupSession context
when active $ do
liftIO $ threadDelay 1000000 -- Wait for the other thread to go online
-- discoTest debug'
when multi $ iqTest debug' we them context
closeConnection (session context)
killThread thread1
killThread thread2
return ()
liftIO . threadDelay $ 10^6
-- unless multi . void . withConnection $ IBR.unregister
unless multi . void $ fork (\s -> forever $ do
pullMessage s >>= debug' . show
putStrLn ""
putStrLn ""
)
context
liftIO . forever $ threadDelay 1000000
return ()
@ -221,4 +191,6 @@ run i multi = do @@ -221,4 +191,6 @@ run i multi = do
runMain debugOut (2 + i) multi
main = run 0 True
main = do
updateGlobalLogger "Pontarius.Xmpp" $ setLevel DEBUG
run 0 True

Loading…
Cancel
Save