diff --git a/source/Network/Xmpp/Sasl.hs b/source/Network/Xmpp/Sasl.hs index 4d8a952..9940a5c 100644 --- a/source/Network/Xmpp/Sasl.hs +++ b/source/Network/Xmpp/Sasl.hs @@ -38,6 +38,7 @@ import qualified Data.Text.Encoding as Text import Network.Xmpp.Stream import Network.Xmpp.Types +import System.Log.Logger (debugM) import qualified System.Random as Random import Network.Xmpp.Sasl.Types @@ -45,19 +46,19 @@ import Network.Xmpp.Sasl.Mechanisms import Control.Concurrent.STM.TMVar -import Control.Exception +import Control.Exception -import Data.XML.Pickle -import Data.XML.Types +import Data.XML.Pickle +import Data.XML.Types -import Network.Xmpp.Types -import Network.Xmpp.Marshal +import Network.Xmpp.Types +import Network.Xmpp.Marshal -import Control.Monad.State(modify) +import Control.Monad.State(modify) -import Control.Concurrent.STM.TMVar +import Control.Concurrent.STM.TMVar -import Control.Monad.Error +import Control.Monad.Error -- | Uses the first supported mechanism to authenticate, if any. Updates the -- state with non-password credentials and restarts the stream upon @@ -67,7 +68,7 @@ xmppSasl :: [SaslHandler] -- ^ Acceptable authentication mechanisms and their -- corresponding handlers -> TMVar Stream -> IO (Either XmppFailure (Maybe AuthFailure)) -xmppSasl handlers stream = (flip withStream stream) $ do +xmppSasl handlers = withStream $ do -- Chooses the first mechanism that is acceptable by both the client and the -- server. mechanisms <- gets $ streamSaslMechanisms . streamFeatures @@ -77,7 +78,13 @@ xmppSasl handlers stream = (flip withStream stream) $ do cs <- gets streamState case cs of Closed -> return . Left $ XmppNoStream - _ -> lift $ handler stream + _ -> do + r <- runErrorT handler + case r of + Left ae -> return $ Right $ Just ae + Right a -> do + _ <- runErrorT $ ErrorT restartStream + return $ Right $ Nothing -- | Authenticate to the server using the first matching method and bind a -- resource. @@ -86,8 +93,11 @@ auth :: [SaslHandler] -> TMVar Stream -> IO (Either XmppFailure (Maybe AuthFailure)) auth mechanisms resource con = runErrorT $ do + liftIO $ debugM "Pontarius.Xmpp" "pre-auth" ErrorT $ xmppSasl mechanisms con + liftIO $ debugM "Pontarius.Xmpp" "auth done" jid <- lift $ xmppBind resource con + liftIO $ debugM "Pontarius.Xmpp" $ "bound resource" ++ show jid lift $ startSession con return Nothing diff --git a/source/Network/Xmpp/Sasl/Mechanisms/DigestMd5.hs b/source/Network/Xmpp/Sasl/Mechanisms/DigestMd5.hs index f75df3e..bca3ab5 100644 --- a/source/Network/Xmpp/Sasl/Mechanisms/DigestMd5.hs +++ b/source/Network/Xmpp/Sasl/Mechanisms/DigestMd5.hs @@ -37,7 +37,7 @@ import Network.Xmpp.Sasl.Common import Network.Xmpp.Sasl.StringPrep import Network.Xmpp.Sasl.Types -import Control.Concurrent.STM + xmppDigestMd5 :: Text -- ^ Authentication identity (authzid or username) -> Maybe Text -- ^ Authorization identity (authcid) @@ -127,25 +127,6 @@ digestMd5 :: Text -- ^ Authentication identity (authcid or username) -> Maybe Text -- ^ Authorization identity (authzid) -> Text -- ^ Password -> SaslHandler -digestMd5 authcid authzid password = - ( "DIGEST-MD5" - , \stream -> do - stream_ <- atomically $ readTMVar stream - r <- runErrorT $ do - -- Alrighty! The problem here is that `scramSha1' runs in the - -- `IO (Either XmppFailure (Maybe AuthFailure))' monad, while we need - -- to call an `ErrorT AuthFailure (StateT Stream IO) ()' calculation. - -- The key is to use `mapErrorT', which is called with the following - -- ypes: - -- - -- mapErrorT :: (StateT Stream IO (Either AuthError ()) -> IO (Either AuthError ())) - -- -> ErrorT AuthError (StateT Stream IO) () - -- -> ErrorT AuthError IO () - mapErrorT - (\s -> runStateT s stream_ >>= \(r, _) -> return r) - (xmppDigestMd5 authcid authzid password) - case r of - Left (AuthStreamFailure e) -> return $ Left e - Left e -> return $ Right $ Just e - Right () -> return $ Right $ Nothing - ) +digestMd5 authcid authzid password = ( "DIGEST-MD5" + , xmppDigestMd5 authcid authzid password + ) diff --git a/source/Network/Xmpp/Sasl/Mechanisms/Plain.hs b/source/Network/Xmpp/Sasl/Mechanisms/Plain.hs index 545dd21..3e85a50 100644 --- a/source/Network/Xmpp/Sasl/Mechanisms/Plain.hs +++ b/source/Network/Xmpp/Sasl/Mechanisms/Plain.hs @@ -46,8 +46,6 @@ import qualified Data.Text as Text import Network.Xmpp.Sasl.Common import Network.Xmpp.Sasl.Types -import Control.Concurrent.STM - -- TODO: stringprep xmppPlain :: Text.Text -- ^ Password -> Maybe Text.Text -- ^ Authorization identity (authzid) @@ -79,27 +77,4 @@ plain :: Text.Text -- ^ authentication ID (username) -> Maybe Text.Text -- ^ authorization ID -> Text.Text -- ^ password -> SaslHandler -plain authcid authzid passwd = - ( "PLAIN" - , \stream -> do - stream_ <- atomically $ readTMVar stream - r <- runErrorT $ do - -- Alrighty! The problem here is that `scramSha1' runs in the - -- `IO (Either XmppFailure (Maybe AuthFailure))' monad, while we need - -- to call an `ErrorT AuthFailure (StateT Stream IO) ()' calculation. - -- The key is to use `mapErrorT', which is called with the following - -- ypes: - -- - -- mapErrorT :: (StateT Stream IO (Either AuthError ()) -> IO (Either AuthError ())) - -- -> ErrorT AuthError (StateT Stream IO) () - -- -> ErrorT AuthError IO () - mapErrorT - (\s -> runStateT s stream_ >>= \(r, _) -> return r) - (xmppPlain authcid authzid passwd) - case r of - Left (AuthStreamFailure e) -> return $ Left e - Left e -> return $ Right $ Just e - Right () -> return $ Right $ Nothing - ) - - +plain authcid authzid passwd = ("PLAIN", xmppPlain authcid authzid passwd) diff --git a/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs b/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs index 618ffb9..c9905e8 100644 --- a/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs +++ b/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs @@ -31,8 +31,8 @@ import Network.Xmpp.Sasl.StringPrep import Network.Xmpp.Sasl.Types import Network.Xmpp.Types + import Control.Monad.State.Strict -import Control.Concurrent.STM -- | A nicer name for undefined, for use as a dummy token to determin -- the hash function to use @@ -164,24 +164,6 @@ scramSha1 :: Text.Text -- ^ username -> Text.Text -- ^ password -> SaslHandler scramSha1 authcid authzid passwd = - ( "SCRAM-SHA-1" - , \stream -> do - stream_ <- atomically $ readTMVar stream - r <- runErrorT $ do - -- Alrighty! The problem here is that `scramSha1' runs in the - -- `IO (Either XmppFailure (Maybe AuthFailure))' monad, while we need - -- to call an `ErrorT AuthFailure (StateT Stream IO) ()' calculation. - -- The key is to use `mapErrorT', which is called with the following - -- ypes: - -- - -- mapErrorT :: (StateT Stream IO (Either AuthError ()) -> IO (Either AuthError ())) - -- -> ErrorT AuthError (StateT Stream IO) () - -- -> ErrorT AuthError IO () - mapErrorT - (\s -> runStateT s stream_ >>= \(r, _) -> return r) - (scram (hashToken :: Crypto.SHA1) authcid authzid passwd) - case r of - Left (AuthStreamFailure e) -> return $ Left e - Left e -> return $ Right $ Just e - Right () -> return $ Right $ Nothing + ("SCRAM-SHA-1" + , scram (hashToken :: Crypto.SHA1) authcid authzid passwd ) diff --git a/source/Network/Xmpp/Sasl/Types.hs b/source/Network/Xmpp/Sasl/Types.hs index fbdd408..e418cd2 100644 --- a/source/Network/Xmpp/Sasl/Types.hs +++ b/source/Network/Xmpp/Sasl/Types.hs @@ -6,7 +6,6 @@ import Control.Monad.State.Strict import Data.ByteString(ByteString) import qualified Data.Text as Text import Network.Xmpp.Types -import Control.Concurrent.STM -- | Signals a (non-fatal) SASL authentication error condition. data AuthFailure = -- | No mechanism offered by the server was matched @@ -35,4 +34,4 @@ type Pairs = [(ByteString, ByteString)] -- | Tuple defining the SASL Handler's name, and a SASL mechanism computation. -- The SASL mechanism is a stateful @Stream@ computation, which has the -- possibility of resulting in an authentication error. -type SaslHandler = (Text.Text, (TMVar Stream -> IO (Either XmppFailure (Maybe AuthFailure)))) +type SaslHandler = (Text.Text, ErrorT AuthFailure (StateT Stream IO) ())