diff --git a/source/Network/Xmpp.hs b/source/Network/Xmpp.hs index 82e02ac..cfb2c14 100644 --- a/source/Network/Xmpp.hs +++ b/source/Network/Xmpp.hs @@ -138,8 +138,6 @@ module Network.Xmpp , sendIQ' , answerIQ , listenIQChan - , iqRequestPayload - , iqResultPayload -- * Errors , StanzaError(..) , StanzaErrorType(..) @@ -157,10 +155,8 @@ module Network.Xmpp , AuthOtherFailure ) ) where -import Network import Network.Xmpp.Concurrent import Network.Xmpp.Utilities import Network.Xmpp.Sasl import Network.Xmpp.Sasl.Types -import Network.Xmpp.Tls import Network.Xmpp.Types diff --git a/source/Network/Xmpp/Concurrent.hs b/source/Network/Xmpp/Concurrent.hs index 9ff6ccc..772ca34 100644 --- a/source/Network/Xmpp/Concurrent.hs +++ b/source/Network/Xmpp/Concurrent.hs @@ -18,36 +18,28 @@ import Control.Applicative((<$>),(<*>)) import Control.Concurrent import Control.Concurrent.STM import Control.Monad +import Control.Monad.Error import qualified Data.ByteString as BS -import Data.IORef import qualified Data.Map as Map import Data.Maybe -import Data.Maybe (fromMaybe) import Data.Text as Text import Data.XML.Types import Network -import qualified Network.TLS as TLS import Network.Xmpp.Concurrent.Basic import Network.Xmpp.Concurrent.IQ import Network.Xmpp.Concurrent.Message import Network.Xmpp.Concurrent.Monad import Network.Xmpp.Concurrent.Presence import Network.Xmpp.Concurrent.Threads -import Network.Xmpp.Concurrent.Threads import Network.Xmpp.Concurrent.Types import Network.Xmpp.Marshal import Network.Xmpp.Sasl -import Network.Xmpp.Sasl.Mechanisms import Network.Xmpp.Sasl.Types import Network.Xmpp.Stream import Network.Xmpp.Tls import Network.Xmpp.Types import Network.Xmpp.Utilities -import Control.Monad.Error -import Data.Default -import System.Log.Logger -import Control.Monad.State.Strict runHandlers :: (TChan Stanza) -> [StanzaHandler] -> Stanza -> IO () runHandlers _ [] _ = return () @@ -96,7 +88,7 @@ handleIQ iqHands outC sta = atomically $ do _ <- tryPutTMVar tmvar answer -- Don't block. writeTVar handlers (byNS, byID') where - iqID (Left err) = iqErrorID err + iqID (Left err') = iqErrorID err' iqID (Right iq') = iqResultID iq' -- | Creates and initializes a new Xmpp context. @@ -104,21 +96,21 @@ newSession :: Stream -> SessionConfiguration -> IO (Either XmppFailure Session) newSession stream config = runErrorT $ do outC <- lift newTChanIO stanzaChan <- lift newTChanIO - iqHandlers <- lift $ newTVarIO (Map.empty, Map.empty) + iqHands <- lift $ newTVarIO (Map.empty, Map.empty) eh <- lift $ newTVarIO $ EventHandlers { connectionClosedHandler = sessionClosedHandler config } let stanzaHandler = runHandlers outC $ Prelude.concat [ [toChan stanzaChan] , extraStanzaHandlers config - , [handleIQ iqHandlers] + , [handleIQ iqHands] ] - (kill, wLock, streamState, readerThread) <- ErrorT $ startThreadsWith stanzaHandler eh stream + (kill, wLock, streamState, reader) <- ErrorT $ startThreadsWith stanzaHandler eh stream writer <- lift $ forkIO $ writeWorker outC wLock idGen <- liftIO $ sessionStanzaIDs config return $ Session { stanzaCh = stanzaChan , outCh = outC - , iqHandlers = iqHandlers + , iqHandlers = iqHands , writeRef = wLock - , readerThread = readerThread + , readerThread = reader , idGenerator = idGen , streamRef = streamState , eventHandlers = eh diff --git a/source/Network/Xmpp/Concurrent/IQ.hs b/source/Network/Xmpp/Concurrent/IQ.hs index bd79061..d41e8cf 100644 --- a/source/Network/Xmpp/Concurrent/IQ.hs +++ b/source/Network/Xmpp/Concurrent/IQ.hs @@ -4,8 +4,6 @@ module Network.Xmpp.Concurrent.IQ where import Control.Concurrent (forkIO, threadDelay) import Control.Concurrent.STM import Control.Monad -import Control.Monad.IO.Class -import Control.Monad.Trans.Reader import qualified Data.Map as Map import Data.Text (Text) @@ -90,17 +88,17 @@ answerIQ :: IQRequestTicket -> Session -> IO Bool answerIQ (IQRequestTicket - sentRef + sRef (IQRequest iqid from _to lang _tp bd)) answer session = do let response = case answer of Left err -> IQErrorS $ IQError iqid Nothing from lang err (Just bd) Right res -> IQResultS $ IQResult iqid Nothing from lang res atomically $ do - sent <- readTVar sentRef + sent <- readTVar sRef case sent of False -> do - writeTVar sentRef True + writeTVar sRef True writeTChan (outCh session) response return True diff --git a/source/Network/Xmpp/Concurrent/Message.hs b/source/Network/Xmpp/Concurrent/Message.hs index 543303c..234484c 100644 --- a/source/Network/Xmpp/Concurrent/Message.hs +++ b/source/Network/Xmpp/Concurrent/Message.hs @@ -3,9 +3,7 @@ module Network.Xmpp.Concurrent.Message where import Network.Xmpp.Concurrent.Types import Control.Concurrent.STM -import Data.IORef import Network.Xmpp.Types -import Network.Xmpp.Concurrent.Types import Network.Xmpp.Concurrent.Basic -- | Read an element from the inbound stanza channel, discardes any diff --git a/source/Network/Xmpp/Concurrent/Monad.hs b/source/Network/Xmpp/Concurrent/Monad.hs index 5a1d627..9a61745 100644 --- a/source/Network/Xmpp/Concurrent/Monad.hs +++ b/source/Network/Xmpp/Concurrent/Monad.hs @@ -60,15 +60,15 @@ import Network.Xmpp.Stream -- | Executes a function to update the event handlers. modifyHandlers :: (EventHandlers -> EventHandlers) -> Session -> IO () -modifyHandlers f session = atomically $ modifyTVar (eventHandlers session) f +modifyHandlers f session = atomically $ modifyTVar_ (eventHandlers session) f where -- Borrowing modifyTVar from -- http://hackage.haskell.org/packages/archive/stm/2.4/doc/html/src/Control-Concurrent-STM-TVar.html -- as it's not available in GHC 7.0. - modifyTVar :: TVar a -> (a -> a) -> STM () - modifyTVar var f = do + modifyTVar_ :: TVar a -> (a -> a) -> STM () + modifyTVar_ var g = do x <- readTVar var - writeTVar var (f x) + writeTVar var (g x) -- | Sets the handler to be executed when the server connection is closed. setConnectionClosedHandler_ :: (XmppFailure -> Session -> IO ()) -> Session -> IO () diff --git a/source/Network/Xmpp/Concurrent/Presence.hs b/source/Network/Xmpp/Concurrent/Presence.hs index d9cfc6e..cb6a502 100644 --- a/source/Network/Xmpp/Concurrent/Presence.hs +++ b/source/Network/Xmpp/Concurrent/Presence.hs @@ -2,7 +2,6 @@ module Network.Xmpp.Concurrent.Presence where import Control.Concurrent.STM -import Data.IORef import Network.Xmpp.Types import Network.Xmpp.Concurrent.Types import Network.Xmpp.Concurrent.Basic diff --git a/source/Network/Xmpp/Concurrent/Threads.hs b/source/Network/Xmpp/Concurrent/Threads.hs index f1ce15d..5c0b03b 100644 --- a/source/Network/Xmpp/Concurrent/Threads.hs +++ b/source/Network/Xmpp/Concurrent/Threads.hs @@ -4,25 +4,18 @@ module Network.Xmpp.Concurrent.Threads where -import Network.Xmpp.Types - import Control.Applicative((<$>)) import Control.Concurrent import Control.Concurrent.STM import qualified Control.Exception.Lifted as Ex import Control.Monad -import Control.Monad.IO.Class +import Control.Monad.Error import Control.Monad.State.Strict - import qualified Data.ByteString as BS +import GHC.IO (unsafeUnmask) import Network.Xmpp.Concurrent.Types import Network.Xmpp.Stream - -import Control.Concurrent.STM.TMVar - -import GHC.IO (unsafeUnmask) - -import Control.Monad.Error +import Network.Xmpp.Types import System.Log.Logger -- Worker to read stanzas from the stream and concurrently distribute them to @@ -38,8 +31,8 @@ readWorker onStanza onConnectionClosed stateRef = -- necessarily be interruptible s <- atomically $ do s@(Stream con) <- readTMVar stateRef - state <- streamConnectionState <$> readTMVar con - when (state == Closed) + scs <- streamConnectionState <$> readTMVar con + when (scs == Closed) retry return s allowInterrupt @@ -55,7 +48,7 @@ readWorker onStanza onConnectionClosed stateRef = ] case res of Nothing -> return () -- Caught an exception, nothing to do. TODO: Can this happen? - Just (Left e) -> return () + Just (Left _) -> return () Just (Right sta) -> onStanza sta where -- Defining an Control.Exception.allowInterrupt equivalent for GHC 7 @@ -85,19 +78,19 @@ startThreadsWith :: (Stanza -> IO ()) TMVar Stream, ThreadId)) startThreadsWith stanzaHandler eh con = do - read <- withStream' (gets $ streamSend . streamHandle >>= \d -> return $ Right d) con - case read of + rd <- withStream' (gets $ streamSend . streamHandle >>= \d -> return $ Right d) con + case rd of Left e -> return $ Left e Right read' -> do writeLock <- newTMVarIO read' conS <- newTMVarIO con -- lw <- forkIO $ writeWorker outC writeLock cp <- forkIO $ connPersist writeLock - rd <- forkIO $ readWorker stanzaHandler (noCon eh) conS - return $ Right ( killConnection writeLock [rd, cp] + rdw <- forkIO $ readWorker stanzaHandler (noCon eh) conS + return $ Right ( killConnection writeLock [rdw, cp] , writeLock , conS - , rd + , rdw ) where killConnection writeLock threads = liftIO $ do diff --git a/source/Network/Xmpp/Concurrent/Types.hs b/source/Network/Xmpp/Concurrent/Types.hs index 008d853..4a4b2e5 100644 --- a/source/Network/Xmpp/Concurrent/Types.hs +++ b/source/Network/Xmpp/Concurrent/Types.hs @@ -3,19 +3,13 @@ module Network.Xmpp.Concurrent.Types where -import qualified Control.Exception.Lifted as Ex import Control.Concurrent import Control.Concurrent.STM - +import qualified Control.Exception.Lifted as Ex import qualified Data.ByteString as BS -import Data.Typeable - -import Network.Xmpp.Types - -import Data.IORef import qualified Data.Map as Map import Data.Text (Text) - +import Data.Typeable import Network.Xmpp.Types -- | Handlers to be run when the Xmpp session ends and when the Xmpp connection is diff --git a/source/Network/Xmpp/IM/Presence.hs b/source/Network/Xmpp/IM/Presence.hs index 512da70..773c04d 100644 --- a/source/Network/Xmpp/IM/Presence.hs +++ b/source/Network/Xmpp/IM/Presence.hs @@ -2,7 +2,6 @@ module Network.Xmpp.IM.Presence where -import Data.Text(Text) import Network.Xmpp.Types -- | An empty presence. diff --git a/source/Network/Xmpp/Internal.hs b/source/Network/Xmpp/Internal.hs index 60f7fbc..c06d06e 100644 --- a/source/Network/Xmpp/Internal.hs +++ b/source/Network/Xmpp/Internal.hs @@ -29,7 +29,7 @@ module Network.Xmpp.Internal , pushStanza , pullStanza , pushIQ - , SaslHandler(..) + , SaslHandler , StanzaID(..) ) @@ -37,9 +37,6 @@ module Network.Xmpp.Internal import Network.Xmpp.Stream import Network.Xmpp.Sasl -import Network.Xmpp.Sasl.Common import Network.Xmpp.Sasl.Types import Network.Xmpp.Tls import Network.Xmpp.Types -import Network.Xmpp.Stream -import Network.Xmpp.Marshal diff --git a/source/Network/Xmpp/Sasl.hs b/source/Network/Xmpp/Sasl.hs index cab4c6d..d445cb9 100644 --- a/source/Network/Xmpp/Sasl.hs +++ b/source/Network/Xmpp/Sasl.hs @@ -1,6 +1,6 @@ {-# OPTIONS_HADDOCK hide #-} {-# LANGUAGE NoMonomorphismRestriction, OverloadedStrings #-} - +-- -- Submodule for functionality related to SASL negotation: -- authentication functions, SASL functionality, bind functionality, -- and the legacy `{urn:ietf:params:xml:ns:xmpp-session}session' @@ -14,51 +14,17 @@ module Network.Xmpp.Sasl , auth ) where -import Control.Applicative -import Control.Arrow (left) -import Control.Monad import Control.Monad.Error import Control.Monad.State.Strict -import Data.Maybe (fromJust, isJust) - -import qualified Crypto.Classes as CC - -import qualified Data.Binary as Binary -import qualified Data.ByteString.Base64 as B64 -import qualified Data.ByteString.Char8 as BS8 -import qualified Data.ByteString.Lazy as BL -import qualified Data.Digest.Pure.MD5 as MD5 -import qualified Data.List as L -import Data.Word (Word8) - -import qualified Data.Text as Text import Data.Text (Text) -import qualified Data.Text.Encoding as Text - -import Network.Xmpp.Stream -import Network.Xmpp.Types - -import System.Log.Logger (debugM, errorM) -import qualified System.Random as Random - -import Network.Xmpp.Sasl.Types -import Network.Xmpp.Sasl.Mechanisms - -import Control.Concurrent.STM.TMVar - -import Control.Exception - import Data.XML.Pickle import Data.XML.Types - -import Network.Xmpp.Types import Network.Xmpp.Marshal - -import Control.Monad.State(modify) - -import Control.Concurrent.STM.TMVar - -import Control.Monad.Error +import Network.Xmpp.Sasl.Mechanisms +import Network.Xmpp.Sasl.Types +import Network.Xmpp.Stream +import Network.Xmpp.Types +import System.Log.Logger (debugM, errorM, infoM) -- | Uses the first supported mechanism to authenticate, if any. Updates the -- state with non-password credentials and restarts the stream upon @@ -105,16 +71,18 @@ auth :: [SaslHandler] -> Stream -> IO (Either XmppFailure (Maybe AuthFailure)) auth mechanisms resource con = runErrorT $ do - ErrorT $ xmppSasl mechanisms con - jid <- ErrorT $ xmppBind resource con - ErrorT $ flip withStream con $ do - s <- get - case establishSession $ streamConfiguration s of - False -> return $ Right Nothing - True -> do - _ <- lift $ startSession con - return $ Right Nothing - return Nothing + mbAuthFail <- ErrorT $ xmppSasl mechanisms con + case mbAuthFail of + Nothing -> do + _jid <- ErrorT $ xmppBind resource con + ErrorT $ flip withStream con $ do + s <- get + case establishSession $ streamConfiguration s of + False -> return $ Right Nothing + True -> do + _ <-liftIO $ startSession con + return $ Right Nothing + f -> return f -- Produces a `bind' element, optionally wrapping a resource. bindBody :: Maybe Text -> Element @@ -137,16 +105,19 @@ xmppBind rsrc c = runErrorT $ do let jid = unpickleElem xpJid b case jid of Right jid' -> do - lift $ debugM "Pontarius.XMPP" $ "xmppBind: JID unpickled: " ++ show jid' - ErrorT $ withStream (do - modify $ \s -> s{streamJid = Just jid'} - return $ Right jid') c -- not pretty + lift $ infoM "Pontarius.XMPP" $ "Bound JID: " ++ show jid' + _ <- lift $ withStream ( do + modify $ \s -> + s{streamJid = Just jid'} + return $ Right ()) + c return jid' - otherwise -> do - lift $ errorM "Pontarius.XMPP" $ "xmppBind: JID could not be unpickled from: " - ++ show b + _ -> do + lift $ errorM "Pontarius.XMPP" + $ "xmppBind: JID could not be unpickled from: " + ++ show b throwError $ XmppOtherFailure - otherwise -> do + _ -> do lift $ errorM "Pontarius.XMPP" "xmppBind: IQ error received." throwError XmppOtherFailure where @@ -164,15 +135,6 @@ sessionXml = pickleElem (xpElemBlank "{urn:ietf:params:xml:ns:xmpp-session}session") () -sessionIQ :: Stanza -sessionIQ = IQRequestS $ IQRequest { iqRequestID = "sess" - , iqRequestFrom = Nothing - , iqRequestTo = Nothing - , iqRequestLangTag = Nothing - , iqRequestType = Set - , iqRequestPayload = sessionXml - } - -- Sends the session IQ set element and waits for an answer. Throws an error if -- if an IQ error stanza is returned from the server. startSession :: Stream -> IO Bool diff --git a/source/Network/Xmpp/Sasl/Common.hs b/source/Network/Xmpp/Sasl/Common.hs index 3a5382c..47f8744 100644 --- a/source/Network/Xmpp/Sasl/Common.hs +++ b/source/Network/Xmpp/Sasl/Common.hs @@ -4,28 +4,23 @@ module Network.Xmpp.Sasl.Common where -import Network.Xmpp.Types - import Control.Applicative ((<$>)) import Control.Monad.Error -import Control.Monad.State.Class - import qualified Data.Attoparsec.ByteString.Char8 as AP import Data.Bits import qualified Data.ByteString as BS import qualified Data.ByteString.Base64 as B64 -import Data.Maybe (fromMaybe) import Data.Maybe (maybeToList) import qualified Data.Text as Text import qualified Data.Text.Encoding as Text import Data.Word (Word8) import Data.XML.Pickle import Data.XML.Types - -import Network.Xmpp.Stream +import Network.Xmpp.Marshal import Network.Xmpp.Sasl.StringPrep import Network.Xmpp.Sasl.Types -import Network.Xmpp.Marshal +import Network.Xmpp.Stream +import Network.Xmpp.Types import qualified System.Random as Random @@ -66,9 +61,9 @@ pairs = AP.parseOnly . flip AP.sepBy1 (void $ AP.char ',') $ do AP.skipSpace name <- AP.takeWhile1 (/= '=') _ <- AP.char '=' - quote <- ((AP.char '"' >> return True) `mplus` return False) + qt <- ((AP.char '"' >> return True) `mplus` return False) content <- AP.takeWhile1 (AP.notInClass [',', '"']) - when quote . void $ AP.char '"' + when qt . void $ AP.char '"' return (name, content) -- Failure element pickler. @@ -108,19 +103,20 @@ xpSaslElement = xpAlt saslSel quote :: BS.ByteString -> BS.ByteString quote x = BS.concat ["\"",x,"\""] -saslInit :: Text.Text -> Maybe BS.ByteString -> ErrorT AuthFailure (StateT StreamState IO) Bool +saslInit :: Text.Text -> Maybe BS.ByteString -> ErrorT AuthFailure (StateT StreamState IO) () saslInit mechanism payload = do r <- lift . pushElement . saslInitE mechanism $ Text.decodeUtf8 . B64.encode <$> payload case r of - Left e -> throwError $ AuthStreamFailure e - Right b -> return b + Right True -> return () + Right False -> throwError $ AuthStreamFailure XmppNoStream + Left e -> throwError $ AuthStreamFailure e -- | Pull the next element. pullSaslElement :: ErrorT AuthFailure (StateT StreamState IO) SaslElement pullSaslElement = do - r <- lift $ pullUnpickle (xpEither xpFailure xpSaslElement) - case r of + mbse <- lift $ pullUnpickle (xpEither xpFailure xpSaslElement) + case mbse of Left e -> throwError $ AuthStreamFailure e Right (Left e) -> throwError $ AuthSaslFailure e Right (Right r) -> return r @@ -173,13 +169,13 @@ toPairs ctext = case pairs ctext of Right r -> return r -- | Send a SASL response element. The content will be base64-encoded. -respond :: Maybe BS.ByteString -> ErrorT AuthFailure (StateT StreamState IO) Bool +respond :: Maybe BS.ByteString -> ErrorT AuthFailure (StateT StreamState IO) () respond m = do r <- lift . pushElement . saslResponseE . fmap (Text.decodeUtf8 . B64.encode) $ m case r of Left e -> throwError $ AuthStreamFailure e - Right b -> return b - + Right False -> throwError $ AuthStreamFailure XmppNoStream + Right True -> return () -- | Run the appropriate stringprep profiles on the credentials. -- May fail with 'AuthStringPrepFailure' diff --git a/source/Network/Xmpp/Sasl/Mechanisms/DigestMd5.hs b/source/Network/Xmpp/Sasl/Mechanisms/DigestMd5.hs index 7e7aca4..36e87eb 100644 --- a/source/Network/Xmpp/Sasl/Mechanisms/DigestMd5.hs +++ b/source/Network/Xmpp/Sasl/Mechanisms/DigestMd5.hs @@ -5,37 +5,21 @@ module Network.Xmpp.Sasl.Mechanisms.DigestMd5 ( digestMd5 ) where -import Control.Applicative -import Control.Arrow (left) -import Control.Monad import Control.Monad.Error import Control.Monad.State.Strict -import Data.Maybe (fromJust, isJust) - import qualified Crypto.Classes as CC - import qualified Data.Binary as Binary +import qualified Data.ByteString as BS import qualified Data.ByteString.Base64 as B64 import qualified Data.ByteString.Char8 as BS8 import qualified Data.ByteString.Lazy as BL import qualified Data.Digest.Pure.MD5 as MD5 import qualified Data.List as L - -import qualified Data.Text as Text import Data.Text (Text) import qualified Data.Text.Encoding as Text - -import Data.XML.Pickle - -import qualified Data.ByteString as BS - -import Data.XML.Types - -import Network.Xmpp.Stream -import Network.Xmpp.Types import Network.Xmpp.Sasl.Common -import Network.Xmpp.Sasl.StringPrep import Network.Xmpp.Sasl.Types +import Network.Xmpp.Types @@ -43,19 +27,19 @@ xmppDigestMd5 :: Text -- ^ Authentication identity (authzid or username) -> Maybe Text -- ^ Authorization identity (authcid) -> Text -- ^ Password (authzid) -> ErrorT AuthFailure (StateT StreamState IO) () -xmppDigestMd5 authcid authzid password = do - (ac, az, pw) <- prepCredentials authcid authzid password +xmppDigestMd5 authcid' authzid' password' = do + (ac, az, pw) <- prepCredentials authcid' authzid' password' Just address <- gets streamAddress xmppDigestMd5' address ac az pw where xmppDigestMd5' :: Text -> Text -> Maybe Text -> Text -> ErrorT AuthFailure (StateT StreamState IO) () - xmppDigestMd5' hostname authcid authzid password = do + xmppDigestMd5' hostname authcid _authzid password = do -- TODO: use authzid? -- Push element and receive the challenge. _ <- saslInit "DIGEST-MD5" Nothing -- TODO: Check boolean? - pairs <- toPairs =<< saslFromJust =<< pullChallenge + prs <- toPairs =<< saslFromJust =<< pullChallenge cnonce <- liftIO $ makeNonce - _b <- respond . Just $ createResponse hostname pairs cnonce - challenge2 <- pullFinalMessage + _b <- respond . Just $ createResponse hostname prs cnonce + _challenge2 <- pullFinalMessage return () where -- Produce the response to the challenge. @@ -63,19 +47,19 @@ xmppDigestMd5 authcid authzid password = do -> Pairs -> BS.ByteString -- nonce -> BS.ByteString - createResponse hostname pairs cnonce = let - Just qop = L.lookup "qop" pairs -- TODO: proper handling - Just nonce = L.lookup "nonce" pairs + createResponse hname prs cnonce = let + Just qop = L.lookup "qop" prs -- TODO: proper handling + Just nonce = L.lookup "nonce" prs uname_ = Text.encodeUtf8 authcid passwd_ = Text.encodeUtf8 password -- Using Int instead of Word8 for random 1.0.0.0 (GHC 7) -- compatibility. nc = "00000001" - digestURI = "xmpp/" `BS.append` Text.encodeUtf8 hostname + digestURI = "xmpp/" `BS.append` Text.encodeUtf8 hname digest = md5Digest uname_ - (lookup "realm" pairs) + (lookup "realm" prs) passwd_ digestURI nc @@ -84,7 +68,7 @@ xmppDigestMd5 authcid authzid password = do cnonce response = BS.intercalate "," . map (BS.intercalate "=") $ [["username", quote uname_]] ++ - case L.lookup "realm" pairs of + case L.lookup "realm" prs of Just realm -> [["realm" , quote realm ]] Nothing -> [] ++ [ ["nonce" , quote nonce ] @@ -115,8 +99,8 @@ xmppDigestMd5 authcid authzid password = do -> BS8.ByteString -> BS8.ByteString -> BS8.ByteString - md5Digest uname realm password digestURI nc qop nonce cnonce = - let ha1 = hash [ hashRaw [uname, maybe "" id realm, password] + md5Digest uname realm pwd digestURI nc qop nonce cnonce = + let ha1 = hash [ hashRaw [uname, maybe "" id realm, pwd] , nonce , cnonce ] diff --git a/source/Network/Xmpp/Sasl/Mechanisms/Plain.hs b/source/Network/Xmpp/Sasl/Mechanisms/Plain.hs index fa35be7..0c32793 100644 --- a/source/Network/Xmpp/Sasl/Mechanisms/Plain.hs +++ b/source/Network/Xmpp/Sasl/Mechanisms/Plain.hs @@ -8,51 +8,22 @@ module Network.Xmpp.Sasl.Mechanisms.Plain ( plain ) where -import Control.Applicative -import Control.Arrow (left) -import Control.Monad import Control.Monad.Error import Control.Monad.State.Strict -import Data.Maybe (fromJust, isJust) - -import qualified Crypto.Classes as CC - -import qualified Data.Binary as Binary -import qualified Data.ByteString.Base64 as B64 -import qualified Data.ByteString.Char8 as BS8 -import qualified Data.ByteString.Lazy as BL -import qualified Data.Digest.Pure.MD5 as MD5 -import qualified Data.List as L -import Data.Word (Word8) - -import qualified Data.Text as Text -import Data.Text (Text) -import qualified Data.Text.Encoding as Text - -import Data.XML.Pickle - import qualified Data.ByteString as BS - -import Data.XML.Types - -import Network.Xmpp.Stream -import Network.Xmpp.Types - -import qualified System.Random as Random - -import Data.Maybe (fromMaybe) import qualified Data.Text as Text - +import qualified Data.Text.Encoding as Text import Network.Xmpp.Sasl.Common import Network.Xmpp.Sasl.Types +import Network.Xmpp.Types -- TODO: stringprep xmppPlain :: Text.Text -- ^ Password -> Maybe Text.Text -- ^ Authorization identity (authzid) -> Text.Text -- ^ Authentication identity (authcid) -> ErrorT AuthFailure (StateT StreamState IO) () -xmppPlain authcid authzid password = do - (ac, az, pw) <- prepCredentials authcid authzid password +xmppPlain authcid' authzid' password = do + (ac, az, pw) <- prepCredentials authcid' authzid' password _ <- saslInit "PLAIN" ( Just $ plainMessage ac az pw) _ <- pullSuccess return () @@ -63,15 +34,15 @@ xmppPlain authcid authzid password = do -> Maybe Text.Text -- Authentication identity (authcid) -> Text.Text -- Password -> BS.ByteString -- The PLAIN message - plainMessage authcid authzid passwd = BS.concat $ - [ authzid' - , "\NUL" - , Text.encodeUtf8 $ authcid - , "\NUL" - , Text.encodeUtf8 $ passwd - ] + plainMessage authcid _authzid passwd = BS.concat $ + [ authzid'' + , "\NUL" + , Text.encodeUtf8 $ authcid + , "\NUL" + , Text.encodeUtf8 $ passwd + ] where - authzid' = maybe "" Text.encodeUtf8 authzid + authzid'' = maybe "" Text.encodeUtf8 authzid' plain :: Text.Text -- ^ authentication ID (username) -> Maybe Text.Text -- ^ authorization ID diff --git a/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs b/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs index 84535dc..c7b2572 100644 --- a/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs +++ b/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs @@ -8,32 +8,20 @@ module Network.Xmpp.Sasl.Mechanisms.Scram import Control.Applicative ((<$>)) import Control.Monad.Error -import Control.Monad.Trans (liftIO) +import Control.Monad.State.Strict import qualified Crypto.Classes as Crypto import qualified Crypto.HMAC as Crypto import qualified Crypto.Hash.SHA1 as Crypto -import Data.Binary(Binary,encode) import qualified Data.ByteString as BS import qualified Data.ByteString.Base64 as B64 import Data.ByteString.Char8 as BS8 (unpack) -import qualified Data.ByteString.Lazy as LBS import Data.List (foldl1', genericTake) - -import qualified Data.Binary.Builder as Build - -import Data.Maybe (maybeToList) import qualified Data.Text as Text import qualified Data.Text.Encoding as Text -import Data.Word(Word8) - import Network.Xmpp.Sasl.Common -import Network.Xmpp.Sasl.StringPrep import Network.Xmpp.Sasl.Types import Network.Xmpp.Types - -import Control.Monad.State.Strict - -- | A nicer name for undefined, for use as a dummy token to determin -- the hash function to use hashToken :: (Crypto.Hash ctx hash) => hash @@ -50,18 +38,18 @@ scram :: (Crypto.Hash ctx hash) -> Maybe Text.Text -- ^ Authorization ID -> Text.Text -- ^ Password -> ErrorT AuthFailure (StateT StreamState IO) () -scram hashToken authcid authzid password = do +scram hToken authcid authzid password = do (ac, az, pw) <- prepCredentials authcid authzid password - scramhelper hashToken ac az pw + scramhelper ac az pw where - scramhelper hashToken authcid authzid' password = do + scramhelper authcid' authzid' pwd = do cnonce <- liftIO $ makeNonce - saslInit "SCRAM-SHA-1" (Just $ cFirstMessage cnonce) + _ <- saslInit "SCRAM-SHA-1" (Just $ cFirstMessage cnonce) sFirstMessage <- saslFromJust =<< pullChallenge - pairs <- toPairs sFirstMessage - (nonce, salt, ic) <- fromPairs pairs cnonce + prs <- toPairs sFirstMessage + (nonce, salt, ic) <- fromPairs prs cnonce let (cfm, v) = cFinalMessageAndVerifier nonce salt ic sFirstMessage cnonce - respond $ Just cfm + _ <- respond $ Just cfm finalPairs <- toPairs =<< saslFromJust =<< pullFinalMessage unless (lookup "v" finalPairs == Just v) $ throwError AuthOtherFailure -- TODO: Log return () @@ -71,27 +59,27 @@ scram hashToken authcid authzid password = do encode _hashtoken = Crypto.encode hash :: BS.ByteString -> BS.ByteString - hash str = encode hashToken $ Crypto.hash' str + hash str = encode hToken $ Crypto.hash' str hmac :: BS.ByteString -> BS.ByteString -> BS.ByteString - hmac key str = encode hashToken $ Crypto.hmac' (Crypto.MacKey key) str + hmac key str = encode hToken $ Crypto.hmac' (Crypto.MacKey key) str - authzid :: Maybe BS.ByteString - authzid = (\z -> "a=" +++ Text.encodeUtf8 z) <$> authzid' + authzid'' :: Maybe BS.ByteString + authzid'' = (\z -> "a=" +++ Text.encodeUtf8 z) <$> authzid' gs2CbindFlag :: BS.ByteString gs2CbindFlag = "n" -- we don't support channel binding yet gs2Header :: BS.ByteString gs2Header = merge $ [ gs2CbindFlag - , maybe "" id authzid + , maybe "" id authzid'' , "" ] - cbindData :: BS.ByteString - cbindData = "" -- we don't support channel binding yet + -- cbindData :: BS.ByteString + -- cbindData = "" -- we don't support channel binding yet cFirstMessageBare :: BS.ByteString -> BS.ByteString - cFirstMessageBare cnonce = merge [ "n=" +++ Text.encodeUtf8 authcid + cFirstMessageBare cnonce = merge [ "n=" +++ Text.encodeUtf8 authcid' , "r=" +++ cnonce] cFirstMessage :: BS.ByteString -> BS.ByteString cFirstMessage cnonce = gs2Header +++ cFirstMessageBare cnonce @@ -99,13 +87,13 @@ scram hashToken authcid authzid password = do fromPairs :: Pairs -> BS.ByteString -> ErrorT AuthFailure (StateT StreamState IO) (BS.ByteString, BS.ByteString, Integer) - fromPairs pairs cnonce | Just nonce <- lookup "r" pairs - , cnonce `BS.isPrefixOf` nonce - , Just salt' <- lookup "s" pairs - , Right salt <- B64.decode salt' - , Just ic <- lookup "i" pairs - , [(i,"")] <- reads $ BS8.unpack ic - = return (nonce, salt, i) + fromPairs prs cnonce | Just nonce <- lookup "r" prs + , cnonce `BS.isPrefixOf` nonce + , Just salt' <- lookup "s" prs + , Right salt <- B64.decode salt' + , Just ic <- lookup "i" prs + , [(i,"")] <- reads $ BS8.unpack ic + = return (nonce, salt, i) fromPairs _ _ = throwError $ AuthOtherFailure -- TODO: Log cFinalMessageAndVerifier :: BS.ByteString @@ -126,7 +114,7 @@ scram hashToken authcid authzid password = do , "r=" +++ nonce] saltedPassword :: BS.ByteString - saltedPassword = hi (Text.encodeUtf8 password) salt ic + saltedPassword = hi (Text.encodeUtf8 pwd) salt ic clientKey :: BS.ByteString clientKey = hmac saltedPassword "Client Key" @@ -154,9 +142,9 @@ scram hashToken authcid authzid password = do -- helper hi :: BS.ByteString -> BS.ByteString -> Integer -> BS.ByteString - hi str salt ic = foldl1' xorBS (genericTake ic us) + hi str slt ic' = foldl1' xorBS (genericTake ic' us) where - u1 = hmac str (salt +++ (BS.pack [0,0,0,1])) + u1 = hmac str (slt +++ (BS.pack [0,0,0,1])) us = iterate (hmac str) u1 scramSha1 :: Text.Text -- ^ username diff --git a/source/Network/Xmpp/Sasl/StringPrep.hs b/source/Network/Xmpp/Sasl/StringPrep.hs index cff48a6..81f5117 100644 --- a/source/Network/Xmpp/Sasl/StringPrep.hs +++ b/source/Network/Xmpp/Sasl/StringPrep.hs @@ -4,27 +4,34 @@ module Network.Xmpp.Sasl.StringPrep where import Text.StringPrep import qualified Data.Set as Set -import Data.Text(singleton) +import Data.Text(Text, singleton) +nonAsciiSpaces :: Set.Set Char nonAsciiSpaces = Set.fromList [ '\x00A0', '\x1680', '\x2000', '\x2001', '\x2002' , '\x2003', '\x2004', '\x2005', '\x2006', '\x2007' , '\x2008', '\x2009', '\x200A', '\x200B', '\x202F' , '\x205F', '\x3000' ] +toSpace :: Char -> Text toSpace x = if x `Set.member` nonAsciiSpaces then " " else singleton x +saslPrepQuery :: StringPrepProfile saslPrepQuery = Profile [b1, toSpace] True [c12, c21, c22, c3, c4, c5, c6, c7, c8, c9] True +saslPrepStore :: StringPrepProfile saslPrepStore = Profile [b1, toSpace] True [a1, c12, c21, c22, c3, c4, c5, c6, c7, c8, c9] True +normalizePassword :: Text -> Maybe Text normalizePassword = runStringPrep saslPrepStore -normalizeUsername = runStringPrep saslPrepQuery \ No newline at end of file + +normalizeUsername :: Text -> Maybe Text +normalizeUsername = runStringPrep saslPrepQuery diff --git a/source/Network/Xmpp/Stream.hs b/source/Network/Xmpp/Stream.hs index 9077d5b..1ee3266 100644 --- a/source/Network/Xmpp/Stream.hs +++ b/source/Network/Xmpp/Stream.hs @@ -7,27 +7,26 @@ module Network.Xmpp.Stream where -import Control.Applicative ((<$>), (<*>)) +import Control.Applicative ((<$>)) import Control.Concurrent (forkIO, threadDelay) import Control.Concurrent.STM import qualified Control.Exception as Ex -import Control.Exception.Base import qualified Control.Exception.Lifted as ExL import Control.Monad import Control.Monad.Error -import Control.Monad.IO.Class -import Control.Monad.Reader import Control.Monad.State.Strict -import Control.Monad.Trans.Class +import Control.Monad.Trans.Resource as R import Data.ByteString (ByteString) import qualified Data.ByteString as BS -import Data.ByteString.Base64 import qualified Data.ByteString.Char8 as BSC8 import Data.Conduit import Data.Conduit.Binary as CB import qualified Data.Conduit.Internal as DCI import qualified Data.Conduit.List as CL -import Data.Maybe (fromJust, isJust, isNothing) +import Data.IP +import Data.List +import Data.Maybe +import Data.Ord import Data.Text (Text) import qualified Data.Text as Text import Data.Void (Void) @@ -35,27 +34,18 @@ import Data.XML.Pickle import Data.XML.Types import qualified GHC.IO.Exception as GIE import Network +import Network.DNS hiding (encode, lookup) import Network.Xmpp.Marshal import Network.Xmpp.Types import System.IO import System.IO.Error (tryIOError) import System.Log.Logger +import System.Random (randomRIO) import Text.XML.Stream.Parse as XP import Text.XML.Unresolved(InvalidEventStream(..)) -import Control.Monad.Trans.Resource as R import Network.Xmpp.Utilities -import Network.DNS hiding (encode, lookup) - -import Data.Ord -import Data.Maybe -import Data.List -import Data.IP -import System.Random - -import qualified Network.Socket as NS - -- "readMaybe" definition, as readMaybe is not introduced in the `base' package -- until version 4.6. readMaybe_ :: (Read a) => String -> Maybe a @@ -73,6 +63,17 @@ lmb :: [t] -> Maybe [t] lmb [] = Nothing lmb x = Just x +pushing :: MonadIO m => + m (Either XmppFailure Bool) + -> ErrorT XmppFailure m () +pushing m = do + res <- ErrorT m + case res of + True -> return () + False -> do + liftIO $ debugM "Pontarius.Xmpp" "Failed to send data." + throwError XmppOtherFailure + -- Unpickles and returns a stream element. streamUnpickleElem :: PU [Node] a -> Element @@ -115,33 +116,34 @@ openElementFromEvents = do startStream :: StateT StreamState IO (Either XmppFailure ()) startStream = runErrorT $ do lift $ lift $ debugM "Pontarius.Xmpp" "Starting stream..." - state <- lift $ get + st <- lift $ get -- Set the `from' (which is also the expected to) attribute depending on the -- state of the stream. - let expectedTo = case ( streamConnectionState state - , toJid $ streamConfiguration state) of - (Plain, (Just (jid, True))) -> Just jid - (Secured, (Just (jid, _))) -> Just jid - (Plain, Nothing) -> Nothing - (Secured, Nothing) -> Nothing - case streamAddress state of + let expectedTo = case ( streamConnectionState st + , toJid $ streamConfiguration st) of + (Plain , (Just (jid, True))) -> Just jid + (Plain , _ ) -> Nothing + (Secured, (Just (jid, _ ))) -> Just jid + (Secured, Nothing ) -> Nothing + (Closed , _ ) -> Nothing + case streamAddress st of Nothing -> do lift $ lift $ errorM "Pontarius.XMPP" "Server sent no hostname." throwError XmppOtherFailure - Just address -> lift $ do - pushXmlDecl - pushOpenElement $ + Just address -> do + pushing pushXmlDecl + pushing . pushOpenElement $ pickleElem xpStream ( "1.0" , expectedTo , Just (Jid Nothing address Nothing) , Nothing - , preferredLang $ streamConfiguration state + , preferredLang $ streamConfiguration st ) response <- ErrorT $ runEventsSink $ runErrorT $ streamS expectedTo case response of Left e -> throwError e -- Successful unpickling of stream element. - Right (Right (ver, from, to, id, lt, features)) + Right (Right (ver, from, to, sid, lt, features)) | (Text.unpack ver) /= "1.0" -> closeStreamWithError StreamUnsupportedVersion Nothing "Unknown version" @@ -149,7 +151,7 @@ startStream = runErrorT $ do closeStreamWithError StreamInvalidXml Nothing "Stream has no language tag" -- If `from' is set, we verify that it's the correct one. TODO: Should we check against the realm instead? - | isJust from && (from /= Just (Jid Nothing (fromJust $ streamAddress state) Nothing)) -> + | isJust from && (from /= Just (Jid Nothing (fromJust $ streamAddress st) Nothing)) -> closeStreamWithError StreamInvalidFrom Nothing "Stream from is invalid" | to /= expectedTo -> @@ -158,12 +160,12 @@ startStream = runErrorT $ do | otherwise -> do modify (\s -> s{ streamFeatures = features , streamLang = lt - , streamId = id + , streamId = sid , streamFrom = from } ) return () -- Unpickling failed - we investigate the element. - Right (Left (Element name attrs children)) + Right (Left (Element name attrs _children)) | (nameLocalName name /= "stream") -> closeStreamWithError StreamInvalidXml Nothing "Root element is not stream" @@ -180,10 +182,10 @@ startStream = runErrorT $ do closeStreamWithError :: StreamErrorCondition -> Maybe Element -> String -> ErrorT XmppFailure (StateT StreamState IO) () closeStreamWithError sec el msg = do - lift . pushElement . pickleElem xpStreamError + void . lift . pushElement . pickleElem xpStreamError $ StreamErrorInfo sec Nothing el - lift $ closeStreams' - lift $ lift $ errorM "Pontarius.XMPP" $ "closeStreamWithError: " ++ msg + void . lift $ closeStreams' + liftIO $ errorM "Pontarius.XMPP" $ "closeStreamWithError: " ++ msg throwError XmppOtherFailure checkchildren children = let to' = lookup "to" children @@ -207,12 +209,12 @@ startStream = runErrorT $ do "" safeRead x = case reads $ Text.unpack x of [] -> Nothing - [(y,_),_] -> Just y + ((y,_):_) -> Just y flattenAttrs :: [(Name, [Content])] -> [(Name, Text.Text)] -flattenAttrs attrs = Prelude.map (\(name, content) -> +flattenAttrs attrs = Prelude.map (\(name, cont) -> ( name - , Text.concat $ Prelude.map uncontentify content) + , Text.concat $ Prelude.map uncontentify cont) ) attrs where @@ -230,11 +232,11 @@ restartStream = do modify (\s -> s{streamEventSource = newSource }) startStream where - loopRead read = do - bs <- liftIO (read 4096) + loopRead rd = do + bs <- liftIO (rd 4096) if BS.null bs then return () - else yield bs >> loopRead read + else yield bs >> loopRead rd -- Reads the (partial) stream:stream and the server features from the stream. -- Returns the (unvalidated) stream attributes, the unparsed element, or @@ -248,12 +250,12 @@ streamS :: Maybe Jid -> StreamSink (Either Element ( Text , Maybe Text , Maybe LangTag , StreamFeatures )) -streamS expectedTo = do - header <- xmppStreamHeader - case header of - Right (version, from, to, id, langTag) -> do +streamS _expectedTo = do -- TODO: check expectedTo + streamHeader <- xmppStreamHeader + case streamHeader of + Right (version, from, to, sid, lTag) -> do features <- xmppStreamFeatures - return $ Right (version, from, to, id, langTag, features) + return $ Right (version, from, to, sid, lTag, features) Left el -> return $ Left el where xmppStreamHeader :: StreamSink (Either Element (Text, Maybe Jid, Maybe Jid, Maybe Text.Text, Maybe LangTag)) @@ -281,7 +283,7 @@ openStream :: HostName -> StreamConfiguration -> IO (Either XmppFailure (Stream) openStream realm config = runErrorT $ do lift $ debugM "Pontarius.XMPP" "Opening stream..." stream' <- createStream realm config - result <- liftIO $ withStream startStream stream' + ErrorT . liftIO $ withStream startStream stream' return stream' -- | Send "" and wait for the server to finish processing and to @@ -290,14 +292,15 @@ openStream realm config = runErrorT $ do closeStreams :: Stream -> IO (Either XmppFailure [Element]) closeStreams = withStream closeStreams' +closeStreams' :: StateT StreamState IO (Either XmppFailure [Element]) closeStreams' = do lift $ debugM "Pontarius.XMPP" "Closing stream..." send <- gets (streamSend . streamHandle) cc <- gets (streamClose . streamHandle) - liftIO $ send "" + void . liftIO $ send "" void $ liftIO $ forkIO $ do threadDelay 3000000 -- TODO: Configurable value - (Ex.try cc) :: IO (Either Ex.SomeException ()) + void ((Ex.try cc) :: IO (Either Ex.SomeException ())) return () collectElems [] where @@ -379,8 +382,8 @@ pullElement = do -- Pulls an element and unpickles it. pullUnpickle :: PU [Node] a -> StateT StreamState IO (Either XmppFailure a) pullUnpickle p = do - elem <- pullElement - case elem of + el <- pullElement + case el of Left e -> return $ Left e Right elem' -> do let res = unpickleElem p elem' @@ -491,17 +494,17 @@ connect realm config = do UseSrv host -> connectSrv host UseRealm -> connectSrv realm where - connectSrv realm = do - case checkHostName (Text.pack realm) of - Just realm' -> do + connectSrv host = do + case checkHostName (Text.pack host) of + Just host' -> do resolvSeed <- lift $ makeResolvSeed (resolvConf config) lift $ debugM "Pontarius.Xmpp" "Performing SRV lookup..." - srvRecords <- srvLookup realm' resolvSeed + srvRecords <- srvLookup host' resolvSeed case srvRecords of Nothing -> do lift $ debugM "Pontarius.Xmpp" "No SRV records, using fallback process." - lift $ resolvAndConnectTcp resolvSeed (BSC8.pack $ realm) + lift $ resolvAndConnectTcp resolvSeed (BSC8.pack $ host) 5222 Just srvRecords' -> do lift $ debugM "Pontarius.Xmpp" @@ -517,10 +520,10 @@ connect realm config = do connectTcp :: [(HostName, PortID)] -> IO (Maybe Handle) connectTcp [] = return Nothing connectTcp ((address, port):remainder) = do - result <- try $ (do + result <- Ex.try $ (do debugM "Pontarius.Xmpp" $ "Connecting to " ++ address ++ " on port " ++ (show port) ++ "." - connectTo address port) :: IO (Either IOException Handle) + connectTo address port) :: IO (Either Ex.IOException Handle) case result of Right handle -> do debugM "Pontarius.Xmpp" "Successfully connected to HostName." @@ -534,23 +537,25 @@ connectTcp ((address, port):remainder) = do -- Surpresses all IO exceptions. resolvAndConnectTcp :: ResolvSeed -> Domain -> Int -> IO (Maybe Handle) resolvAndConnectTcp resolvSeed domain port = do - aaaaResults <- (try $ rethrowErrorCall $ withResolver resolvSeed $ - \resolver -> lookupAAAA resolver domain) :: IO (Either IOException (Maybe [IPv6])) + aaaaResults <- (Ex.try $ rethrowErrorCall $ withResolver resolvSeed $ + \resolver -> lookupAAAA resolver domain) :: IO (Either Ex.IOException (Maybe [IPv6])) handle <- case aaaaResults of Right Nothing -> return Nothing Right (Just ipv6s) -> connectTcp $ - map (\ipv6 -> ( show ipv6 + map (\ip -> ( show ip , PortNumber $ fromIntegral port)) ipv6s - Left e -> return Nothing + Left _e -> return Nothing case handle of Nothing -> do - aResults <- (try $ rethrowErrorCall $ withResolver resolvSeed $ - \resolver -> lookupA resolver domain) :: IO (Either IOException (Maybe [IPv4])) + aResults <- (Ex.try $ rethrowErrorCall $ withResolver resolvSeed $ + \resolver -> lookupA resolver domain) :: IO (Either Ex.IOException (Maybe [IPv4])) handle' <- case aResults of + Left _ -> return Nothing Right Nothing -> return Nothing + Right (Just ipv4s) -> connectTcp $ - map (\ipv4 -> (show ipv4 + map (\ip -> (show ip , PortNumber $ fromIntegral port)) ipv4s @@ -574,29 +579,30 @@ resolvSrvsAndConnectTcp resolvSeed ((domain, port):remaining) = do -- exceptions and rethrows them as IOExceptions. rethrowErrorCall :: IO a -> IO a rethrowErrorCall action = do - result <- try action + result <- Ex.try action case result of Right result' -> return result' - Left (ErrorCall e) -> ioError $ userError $ "rethrowErrorCall: " ++ e - Left e -> throwIO e + Left (Ex.ErrorCall e) -> Ex.ioError $ userError + $ "rethrowErrorCall: " ++ e -- Provides a list of A(AAA) names and port numbers upon a successful -- DNS-SRV request, or `Nothing' if the DNS-SRV request failed. srvLookup :: Text -> ResolvSeed -> ErrorT XmppFailure IO (Maybe [(Domain, Int)]) srvLookup realm resolvSeed = ErrorT $ do - result <- try $ rethrowErrorCall $ withResolver resolvSeed $ \resolver -> do + result <- Ex.try $ rethrowErrorCall $ withResolver resolvSeed + $ \resolver -> do srvResult <- lookupSRV resolver $ BSC8.pack $ "_xmpp-client._tcp." ++ (Text.unpack realm) ++ "." case srvResult of - Just srvResult -> do - debugM "Pontarius.Xmpp" $ "SRV result: " ++ (show srvResult) - -- Get [(Domain, PortNumber)] of SRV request, if any. - srvResult' <- orderSrvResult srvResult - return $ Just $ Prelude.map (\(_, _, port, domain) -> (domain, port)) srvResult' - -- The service is not available at this domain. - -- Sorts the records based on the priority value. Just [(_, _, _, ".")] -> do debugM "Pontarius.Xmpp" $ "\".\" SRV result returned." return $ Just [] + Just srvResult' -> do + debugM "Pontarius.Xmpp" $ "SRV result: " ++ (show srvResult') + -- Get [(Domain, PortNumber)] of SRV request, if any. + orderedSrvResult <- orderSrvResult srvResult' + return $ Just $ Prelude.map (\(_, _, port, domain) -> (domain, port)) orderedSrvResult + -- The service is not available at this domain. + -- Sorts the records based on the priority value. Nothing -> do debugM "Pontarius.Xmpp" "No SRV result returned." return Nothing @@ -627,7 +633,7 @@ srvLookup realm resolvSeed = ErrorT $ do orderSublist sublist = do -- Compute the running sum, as well as the total sum of -- the sublist. Add the running sum to the SRV tuples. - let (total, sublist') = Data.List.mapAccumL (\total (priority, weight, port, domain) -> (total + weight, (priority, weight, port, domain, total + weight))) 0 sublist + let (total, sublist') = Data.List.mapAccumL (\total' (priority, weight, port, domain) -> (total' + weight, (priority, weight, port, domain, total' + weight))) 0 sublist -- Choose a random number between 0 and the total sum -- (inclusive). randomNumber <- randomRIO (0, total) @@ -636,11 +642,11 @@ srvLookup realm resolvSeed = ErrorT $ do let (beginning, ((priority, weight, port, domain, _):end)) = Data.List.break (\(_, _, _, _, running) -> randomNumber <= running) sublist' -- Remove the running total number from the remaining -- elements. - let sublist'' = Data.List.map (\(priority, weight, port, domain, _) -> (priority, weight, port, domain)) (Data.List.concat [beginning, end]) + let sublist'' = Data.List.map (\(priority', weight', port', domain', _) -> (priority', weight', port', domain')) (Data.List.concat [beginning, end]) -- Repeat the ordering procedure on the remaining -- elements. - tail <- orderSublist sublist'' - return $ ((priority, weight, port, domain):tail) + rest <- orderSublist sublist'' + return $ ((priority, weight, port, domain):rest) -- Closes the connection and updates the XmppConMonad Stream state. -- killStream :: Stream -> IO (Either ExL.SomeException ()) @@ -661,23 +667,24 @@ pushIQ :: StanzaID -> Element -> Stream -> IO (Either XmppFailure (Either IQError IQResult)) -pushIQ iqID to tp lang body stream = do - pushStanza (IQRequestS $ IQRequest iqID Nothing to lang tp body) stream - res <- pullStanza stream +pushIQ iqID to tp lang body stream = runErrorT $ do + pushing $ pushStanza + (IQRequestS $ IQRequest iqID Nothing to lang tp body) stream + res <- lift $ pullStanza stream case res of - Left e -> return $ Left e - Right (IQErrorS e) -> return $ Right $ Left e + Left e -> throwError e + Right (IQErrorS e) -> return $ Left e Right (IQResultS r) -> do unless (iqID == iqResultID r) $ liftIO $ do - errorM "Pontarius.XMPP" $ "pushIQ: ID mismatch (" ++ (show iqID) ++ " /= " ++ (show $ iqResultID r) ++ ")." - ExL.throwIO XmppOtherFailure + liftIO $ errorM "Pontarius.XMPP" $ "pushIQ: ID mismatch (" ++ (show iqID) ++ " /= " ++ (show $ iqResultID r) ++ ")." + liftIO $ ExL.throwIO XmppOtherFailure -- TODO: Log: ("In sendIQ' IDs don't match: " ++ show iqID ++ -- " /= " ++ show (iqResultID r) ++ " .") - return $ Right $ Right r + return $ Right r _ -> do - errorM "Pontarius.XMPP" $ "pushIQ: Unexpected stanza type." - return . Left $ XmppOtherFailure + liftIO $ errorM "Pontarius.XMPP" $ "pushIQ: Unexpected stanza type." + throwError XmppOtherFailure debugConduit :: Pipe l ByteString ByteString u IO b debugConduit = forever $ do @@ -695,7 +702,9 @@ elements = do Just (EventBeginElement n as) -> do goE n as >>= yield elements - Just (EventEndElement streamName) -> lift $ R.monadThrow StreamEnd + -- This might be an XML error if the end element tag is not + -- "". TODO: We might want to check this at a later time + Just (EventEndElement _) -> lift $ R.monadThrow StreamEnd Nothing -> return () _ -> lift $ R.monadThrow $ InvalidXmppXml $ "not an element: " ++ show x where @@ -705,8 +714,8 @@ elements = do go front = do x <- f case x of - Left x -> return $ (x, front []) - Right y -> go (front . (:) y) + Left l -> return $ (l, front []) + Right r -> go (front . (:) r) goE n as = do (y, ns) <- many' goN if y == Just (EventEndElement n) @@ -730,11 +739,8 @@ elements = do compressNodes $ NodeContent (ContentText $ x `Text.append` y) : z compressNodes (x:xs) = x : compressNodes xs - streamName :: Name - streamName = (Name "stream" (Just "http://etherx.jabber.org/streams") (Just "stream")) - withStream :: StateT StreamState IO (Either XmppFailure c) -> Stream -> IO (Either XmppFailure c) -withStream action (Stream stream) = bracketOnError +withStream action (Stream stream) = Ex.bracketOnError (atomically $ takeTMVar stream ) (atomically . putTMVar stream) (\s -> do diff --git a/source/Network/Xmpp/Tls.hs b/source/Network/Xmpp/Tls.hs index 71e9b8d..88b56f1 100644 --- a/source/Network/Xmpp/Tls.hs +++ b/source/Network/Xmpp/Tls.hs @@ -4,7 +4,6 @@ module Network.Xmpp.Tls where -import Control.Concurrent.STM.TMVar import qualified Control.Exception.Lifted as Ex import Control.Monad import Control.Monad.Error @@ -14,16 +13,14 @@ import qualified Data.ByteString as BS import qualified Data.ByteString.Char8 as BSC8 import qualified Data.ByteString.Lazy as BL import Data.Conduit -import qualified Data.Conduit.Binary as CB import Data.IORef -import Data.Typeable import Data.XML.Types import Network.TLS -import Network.TLS.Extra import Network.Xmpp.Stream import Network.Xmpp.Types import System.Log.Logger (debugM, errorM) +mkBackend :: StreamHandle -> Backend mkBackend con = Backend { backendSend = \bs -> void (streamSend con bs) , backendRecv = streamReceive con , backendFlush = streamFlush con @@ -61,31 +58,39 @@ tls con = Ex.handle (return . Left . TlsError) where startTls = do params <- gets $ tlsParams . streamConfiguration - lift $ pushElement starttlsE + sent <- ErrorT $ pushElement starttlsE + unless sent $ do + liftIO $ errorM "Pontarius.XMPP" "startTls: Could not sent stanza." + throwError XmppOtherFailure answer <- lift $ pullElement case answer of - Left e -> return $ Left e + Left e -> throwError e Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}proceed" [] []) -> - return $ Right () + return () Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}failure" _ _) -> do liftIO $ errorM "Pontarius.XMPP" "startTls: TLS initiation failed." - return . Left $ XmppOtherFailure + throwError XmppOtherFailure + Right r -> + liftIO $ errorM "Pontarius.XMPP" $ + "startTls: Unexpected element: " ++ show r hand <- gets streamHandle - (raw, _snk, psh, read, ctx) <- lift $ tlsinit params (mkBackend hand) + (_raw, _snk, psh, recv, ctx) <- lift $ tlsinit params (mkBackend hand) let newHand = StreamHandle { streamSend = catchPush . psh - , streamReceive = read - , streamFlush = contextFlush ctx - , streamClose = bye ctx >> streamClose hand - } + , streamReceive = recv + , streamFlush = contextFlush ctx + , streamClose = bye ctx >> streamClose hand + } lift $ modify ( \x -> x {streamHandle = newHand}) either (lift . Ex.throwIO) return =<< lift restartStream modify (\s -> s{streamConnectionState = Secured}) return () +client :: (MonadIO m, CPRG rng) => Params -> rng -> Backend -> m Context client params gen backend = do contextNew backend params gen -defaultParams = defaultParamsClient +xmppDefaultParams :: Params +xmppDefaultParams = defaultParamsClient tlsinit :: (MonadIO m, MonadIO m1) => TLSParams @@ -96,10 +101,10 @@ tlsinit :: (MonadIO m, MonadIO m1) => , Int -> m1 BS.ByteString , Context ) -tlsinit tlsParams backend = do +tlsinit params backend = do liftIO $ debugM "Pontarius.Xmpp.TLS" "TLS with debug mode enabled." gen <- liftIO $ getSystemRandomGen -- TODO: Find better random source? - con <- client tlsParams gen backend + con <- client params gen backend handshake con let src = forever $ do dt <- liftIO $ recvData con @@ -114,22 +119,22 @@ tlsinit tlsParams backend = do liftIO $ debugM "Pontarius.Xmpp.TLS" ("out :" ++ BSC8.unpack x) snk - read <- liftIO $ mkReadBuffer (recvData con) + readWithBuffer <- liftIO $ mkReadBuffer (recvData con) return ( src , snk , \s -> do liftIO $ debugM "Pontarius.Xmpp.TLS" ("out :" ++ BSC8.unpack s) sendData con $ BL.fromChunks [s] - , liftIO . read + , liftIO . readWithBuffer , con ) mkReadBuffer :: IO BS.ByteString -> IO (Int -> IO BS.ByteString) -mkReadBuffer read = do +mkReadBuffer recv = do buffer <- newIORef BS.empty let read' n = do nc <- readIORef buffer - bs <- if BS.null nc then read + bs <- if BS.null nc then recv else return nc let (result, rest) = BS.splitAt n bs writeIORef buffer rest diff --git a/source/Network/Xmpp/Utilities.hs b/source/Network/Xmpp/Utilities.hs index 419bd6b..c11d58e 100644 --- a/source/Network/Xmpp/Utilities.hs +++ b/source/Network/Xmpp/Utilities.hs @@ -3,76 +3,27 @@ {-# OPTIONS_HADDOCK hide #-} -module Network.Xmpp.Utilities (presTo, message, answerMessage, openElementToEvents, renderOpenElement, renderElement) where - -import Network.Xmpp.Types - -import Control.Monad.STM -import Control.Concurrent.STM.TVar -import Prelude - -import Data.XML.Types - -import qualified Data.Attoparsec.Text as AP -import qualified Data.Text as Text - +module Network.Xmpp.Utilities + ( presTo + , message + , answerMessage + , openElementToEvents + , renderOpenElement + , renderElement) + where + +import Network.Xmpp.Types +import Prelude +import Data.XML.Types import qualified Data.ByteString as BS +import Data.Conduit as C +import Data.Conduit.List as CL import qualified Data.Text as Text import qualified Data.Text.Encoding as Text import System.IO.Unsafe(unsafePerformIO) -import Data.Conduit.List as CL --- import Data.Typeable -import Control.Applicative ((<$>)) -import Control.Exception -import Control.Monad.Trans.Class - -import Data.Conduit as C -import Data.XML.Types - import qualified Text.XML.Stream.Render as TXSR import Text.XML.Unresolved as TXU - --- TODO: Not used, and should probably be removed. --- | Creates a new @IdGenerator@. Internally, it will maintain an infinite list --- of IDs ('[\'a\', \'b\', \'c\'...]'). The argument is a prefix to prepend the --- IDs with. Calling the function will extract an ID and update the generator's --- internal state so that the same ID will not be generated again. -idGenerator :: Text.Text -> IO IdGenerator -idGenerator prefix = atomically $ do - tvar <- newTVar $ ids prefix - return $ IdGenerator $ next tvar - where - -- Transactionally extract the next ID from the infinite list of IDs. - next :: TVar [Text.Text] -> IO Text.Text - next tvar = atomically $ do - list <- readTVar tvar - case list of - [] -> error "empty list in Utilities.hs" - (x:xs) -> do - writeTVar tvar xs - return x - - -- Generates an infinite and predictable list of IDs, all beginning with the - -- provided prefix. Adds the prefix to all combinations of IDs (ids'). - ids :: Text.Text -> [Text.Text] - ids p = Prelude.map (\ id -> Text.append p id) ids' - where - -- Generate all combinations of IDs, with increasing length. - ids' :: [Text.Text] - ids' = Prelude.map Text.pack $ Prelude.concatMap ids'' [1..] - -- Generates all combinations of IDs with the given length. - ids'' :: Integer -> [String] - ids'' 0 = [""] - ids'' l = [x:xs | x <- repertoire, xs <- ids'' (l - 1)] - -- Characters allowed in IDs. - repertoire :: String - repertoire = ['a'..'z'] - --- Constructs a "Version" based on the major and minor version numbers. -versionFromNumbers :: Integer -> Integer -> Version -versionFromNumbers major minor = Version major minor - -- | Add a recipient to a presence notification. presTo :: Presence -> Jid -> Presence presTo pres to = pres{presenceTo = Just to} @@ -124,4 +75,5 @@ renderElement e = Text.encodeUtf8 . Text.concat . unsafePerformIO $ CL.sourceList (elementToEvents e) $$ TXSR.renderText def =$ CL.consume where elementToEvents :: Element -> [Event] - elementToEvents e@(Element name _ _) = openElementToEvents e ++ [EventEndElement name] + elementToEvents el@(Element name _ _) = openElementToEvents el + ++ [EventEndElement name] diff --git a/source/Network/Xmpp/Xep/DataForms.hs b/source/Network/Xmpp/Xep/DataForms.hs index 9491acd..2c2b733 100644 --- a/source/Network/Xmpp/Xep/DataForms.hs +++ b/source/Network/Xmpp/Xep/DataForms.hs @@ -7,12 +7,9 @@ module Network.Xmpp.Xep.DataForms where import qualified Data.Text as Text +import Data.XML.Pickle import qualified Data.XML.Types as XML -import Data.XML.Pickle -import qualified Data.Text as Text - -import qualified Text.XML.Stream.Parse as Parse dataFormNs :: Text.Text dataFormNs = "jabber:x:data" @@ -95,12 +92,12 @@ instance Read FieldType where xpForm :: PU [XML.Node] Form -xpForm = xpWrap (\(tp , (title, instructions, fields, reported, items)) -> - Form tp title (map snd instructions) fields reported (map snd items)) - (\(Form tp title instructions fields reported items) -> +xpForm = xpWrap (\(tp , (ttl, ins, flds, rpd, its)) -> + Form tp ttl (map snd ins) flds rpd (map snd its)) + (\(Form tp ttl ins flds rpd its) -> (tp , - (title, map ((),) instructions - , fields, reported, map ((),) items))) + (ttl, map ((),) ins + , flds, rpd, map ((),) its))) $ xpElem (dataFormName "x") @@ -113,10 +110,10 @@ xpForm = xpWrap (\(tp , (title, instructions, fields, reported, items)) -> (xpElems (dataFormName "item") xpUnit xpFields)) xpFields :: PU [XML.Node] [Field] -xpFields = xpWrap (map $ \((var, tp, label),(desc, req, vals, opts)) - -> Field var label tp desc req vals opts) - (map $ \(Field var label tp desc req vals opts) - -> ((var, tp, label),(desc, req, vals, opts))) $ +xpFields = xpWrap (map $ \((var, tp, lbl),(desc, req, vals, opts)) + -> Field var lbl tp desc req vals opts) + (map $ \(Field var lbl tp desc req vals opts) + -> ((var, tp, lbl),(desc, req, vals, opts))) $ xpElems (dataFormName "field") (xp3Tuple (xpAttrImplied "var" xpId )