diff --git a/source/Network/Xmpp/Sasl.hs b/source/Network/Xmpp/Sasl.hs index d338c0c..fff8bb2 100644 --- a/source/Network/Xmpp/Sasl.hs +++ b/source/Network/Xmpp/Sasl.hs @@ -67,7 +67,7 @@ xmppSasl :: [SaslHandler] -- ^ Acceptable authentication mechanisms and their -- corresponding handlers -> TMVar Stream -> IO (Either XmppFailure (Maybe AuthFailure)) -xmppSasl handlers = withStream $ do +xmppSasl handlers stream = (flip withStream stream) $ do -- Chooses the first mechanism that is acceptable by both the client and the -- server. mechanisms <- gets $ streamSaslMechanisms . streamFeatures @@ -77,13 +77,7 @@ xmppSasl handlers = withStream $ do cs <- gets streamState case cs of Closed -> return . Right $ Just AuthNoStream - _ -> do - r <- runErrorT handler - case r of - Left ae -> return $ Right $ Just ae - Right a -> do - _ <- runErrorT $ ErrorT restartStream - return $ Right $ Nothing + _ -> lift $ handler stream -- | Authenticate to the server using the first matching method and bind a -- resource. diff --git a/source/Network/Xmpp/Sasl/Mechanisms/DigestMd5.hs b/source/Network/Xmpp/Sasl/Mechanisms/DigestMd5.hs index bca3ab5..f75df3e 100644 --- a/source/Network/Xmpp/Sasl/Mechanisms/DigestMd5.hs +++ b/source/Network/Xmpp/Sasl/Mechanisms/DigestMd5.hs @@ -37,7 +37,7 @@ import Network.Xmpp.Sasl.Common import Network.Xmpp.Sasl.StringPrep import Network.Xmpp.Sasl.Types - +import Control.Concurrent.STM xmppDigestMd5 :: Text -- ^ Authentication identity (authzid or username) -> Maybe Text -- ^ Authorization identity (authcid) @@ -127,6 +127,25 @@ digestMd5 :: Text -- ^ Authentication identity (authcid or username) -> Maybe Text -- ^ Authorization identity (authzid) -> Text -- ^ Password -> SaslHandler -digestMd5 authcid authzid password = ( "DIGEST-MD5" - , xmppDigestMd5 authcid authzid password - ) +digestMd5 authcid authzid password = + ( "DIGEST-MD5" + , \stream -> do + stream_ <- atomically $ readTMVar stream + r <- runErrorT $ do + -- Alrighty! The problem here is that `scramSha1' runs in the + -- `IO (Either XmppFailure (Maybe AuthFailure))' monad, while we need + -- to call an `ErrorT AuthFailure (StateT Stream IO) ()' calculation. + -- The key is to use `mapErrorT', which is called with the following + -- ypes: + -- + -- mapErrorT :: (StateT Stream IO (Either AuthError ()) -> IO (Either AuthError ())) + -- -> ErrorT AuthError (StateT Stream IO) () + -- -> ErrorT AuthError IO () + mapErrorT + (\s -> runStateT s stream_ >>= \(r, _) -> return r) + (xmppDigestMd5 authcid authzid password) + case r of + Left (AuthStreamFailure e) -> return $ Left e + Left e -> return $ Right $ Just e + Right () -> return $ Right $ Nothing + ) diff --git a/source/Network/Xmpp/Sasl/Mechanisms/Plain.hs b/source/Network/Xmpp/Sasl/Mechanisms/Plain.hs index 3e85a50..545dd21 100644 --- a/source/Network/Xmpp/Sasl/Mechanisms/Plain.hs +++ b/source/Network/Xmpp/Sasl/Mechanisms/Plain.hs @@ -46,6 +46,8 @@ import qualified Data.Text as Text import Network.Xmpp.Sasl.Common import Network.Xmpp.Sasl.Types +import Control.Concurrent.STM + -- TODO: stringprep xmppPlain :: Text.Text -- ^ Password -> Maybe Text.Text -- ^ Authorization identity (authzid) @@ -77,4 +79,27 @@ plain :: Text.Text -- ^ authentication ID (username) -> Maybe Text.Text -- ^ authorization ID -> Text.Text -- ^ password -> SaslHandler -plain authcid authzid passwd = ("PLAIN", xmppPlain authcid authzid passwd) +plain authcid authzid passwd = + ( "PLAIN" + , \stream -> do + stream_ <- atomically $ readTMVar stream + r <- runErrorT $ do + -- Alrighty! The problem here is that `scramSha1' runs in the + -- `IO (Either XmppFailure (Maybe AuthFailure))' monad, while we need + -- to call an `ErrorT AuthFailure (StateT Stream IO) ()' calculation. + -- The key is to use `mapErrorT', which is called with the following + -- ypes: + -- + -- mapErrorT :: (StateT Stream IO (Either AuthError ()) -> IO (Either AuthError ())) + -- -> ErrorT AuthError (StateT Stream IO) () + -- -> ErrorT AuthError IO () + mapErrorT + (\s -> runStateT s stream_ >>= \(r, _) -> return r) + (xmppPlain authcid authzid passwd) + case r of + Left (AuthStreamFailure e) -> return $ Left e + Left e -> return $ Right $ Just e + Right () -> return $ Right $ Nothing + ) + + diff --git a/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs b/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs index 4262c63..809a95b 100644 --- a/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs +++ b/source/Network/Xmpp/Sasl/Mechanisms/Scram.hs @@ -31,8 +31,8 @@ import Network.Xmpp.Sasl.StringPrep import Network.Xmpp.Sasl.Types import Network.Xmpp.Types - import Control.Monad.State.Strict +import Control.Concurrent.STM -- | A nicer name for undefined, for use as a dummy token to determin -- the hash function to use @@ -164,6 +164,24 @@ scramSha1 :: Text.Text -- ^ username -> Text.Text -- ^ password -> SaslHandler scramSha1 authcid authzid passwd = - ("SCRAM-SHA-1" - , scram (hashToken :: Crypto.SHA1) authcid authzid passwd + ( "SCRAM-SHA-1" + , \stream -> do + stream_ <- atomically $ readTMVar stream + r <- runErrorT $ do + -- Alrighty! The problem here is that `scramSha1' runs in the + -- `IO (Either XmppFailure (Maybe AuthFailure))' monad, while we need + -- to call an `ErrorT AuthFailure (StateT Stream IO) ()' calculation. + -- The key is to use `mapErrorT', which is called with the following + -- ypes: + -- + -- mapErrorT :: (StateT Stream IO (Either AuthError ()) -> IO (Either AuthError ())) + -- -> ErrorT AuthError (StateT Stream IO) () + -- -> ErrorT AuthError IO () + mapErrorT + (\s -> runStateT s stream_ >>= \(r, _) -> return r) + (scram (hashToken :: Crypto.SHA1) authcid authzid passwd) + case r of + Left (AuthStreamFailure e) -> return $ Left e + Left e -> return $ Right $ Just e + Right () -> return $ Right $ Nothing ) diff --git a/source/Network/Xmpp/Sasl/Types.hs b/source/Network/Xmpp/Sasl/Types.hs index c341585..8aea51e 100644 --- a/source/Network/Xmpp/Sasl/Types.hs +++ b/source/Network/Xmpp/Sasl/Types.hs @@ -6,6 +6,7 @@ import Control.Monad.State.Strict import Data.ByteString(ByteString) import qualified Data.Text as Text import Network.Xmpp.Types +import Control.Concurrent.STM data AuthFailure = AuthXmlFailure | AuthNoAcceptableMechanism [Text.Text] -- ^ Wraps mechanisms @@ -32,4 +33,4 @@ type Pairs = [(ByteString, ByteString)] -- | Tuple defining the SASL Handler's name, and a SASL mechanism computation. -- The SASL mechanism is a stateful @Stream@ computation, which has the -- possibility of resulting in an authentication error. -type SaslHandler = (Text.Text, ErrorT AuthFailure (StateT Stream IO) ()) +type SaslHandler = (Text.Text, (TMVar Stream -> IO (Either XmppFailure (Maybe AuthFailure))))