From bd90cbe9038cf1b96313dce68aa6f9f0c6c4f98c Mon Sep 17 00:00:00 2001
From: Philipp Balzarek
Date: Wed, 21 Aug 2013 15:04:15 +0200
Subject: [PATCH] add tlsConnect function
---
source/Network/Xmpp.hs | 6 +++--
source/Network/Xmpp/Stream.hs | 50 +++++++++++++++++------------------
source/Network/Xmpp/Tls.hs | 21 +++++++++++++++
source/Network/Xmpp/Types.hs | 2 +-
4 files changed, 51 insertions(+), 28 deletions(-)
diff --git a/source/Network/Xmpp.hs b/source/Network/Xmpp.hs
index 5de4093..6ceb0e4 100644
--- a/source/Network/Xmpp.hs
+++ b/source/Network/Xmpp.hs
@@ -32,6 +32,8 @@ module Network.Xmpp
, StreamConfiguration(..)
, SessionConfiguration(..)
, ConnectionDetails(..)
+ , closeConnection
+ , endSession
-- TODO: Close session, etc.
-- ** Authentication handlers
-- | The use of 'scramSha1' is /recommended/, but 'digestMd5' might be
@@ -39,8 +41,6 @@ module Network.Xmpp
, scramSha1
, plain
, digestMd5
- , closeConnection
- , endSession
-- * Addressing
-- | A JID (historically: Jabber ID) is XMPPs native format
-- for addressing entities in the network. It is somewhat similar to an e-mail
@@ -184,6 +184,7 @@ module Network.Xmpp
, AuthOtherFailure )
, SaslHandler
, ConnectionState(..)
+ , connectTls
) where
import Network.Xmpp.Concurrent
@@ -191,3 +192,4 @@ import Network.Xmpp.Sasl
import Network.Xmpp.Sasl.Types
import Network.Xmpp.Stanza
import Network.Xmpp.Types
+import Network.Xmpp.Tls
diff --git a/source/Network/Xmpp/Stream.hs b/source/Network/Xmpp/Stream.hs
index ffee9e2..c011582 100644
--- a/source/Network/Xmpp/Stream.hs
+++ b/source/Network/Xmpp/Stream.hs
@@ -570,42 +570,42 @@ connect realm config = do
liftIO $ hSetBuffering h' NoBuffering
return . Just $ handleToStreamHandle h'
UseSrv host -> do
- h <- connectSrv host
+ h <- connectSrv (resolvConf config) host
case h of
Nothing -> return Nothing
Just h' -> do
liftIO $ hSetBuffering h' NoBuffering
return . Just $ handleToStreamHandle h'
UseRealm -> do
- h <- connectSrv realm
+ h <- connectSrv (resolvConf config) realm
case h of
Nothing -> return Nothing
Just h' -> do
liftIO $ hSetBuffering h' NoBuffering
return . Just $ handleToStreamHandle h'
- UseConnection mkC -> Just <$> liftIO mkC
-
- where
- connectSrv host = do
- case checkHostName (Text.pack host) of
- Just host' -> do
- resolvSeed <- lift $ makeResolvSeed (resolvConf config)
- lift $ debugM "Pontarius.Xmpp" "Performing SRV lookup..."
- srvRecords <- srvLookup host' resolvSeed
- case srvRecords of
- Nothing -> do
- lift $ debugM "Pontarius.Xmpp"
- "No SRV records, using fallback process."
- lift $ resolvAndConnectTcp resolvSeed (BSC8.pack $ host)
- 5222
- Just srvRecords' -> do
- lift $ debugM "Pontarius.Xmpp"
- "SRV records found, performing A/AAAA lookups."
- lift $ resolvSrvsAndConnectTcp resolvSeed srvRecords'
- Nothing -> do
- lift $ errorM "Pontarius.Xmpp"
- "The hostname could not be validated."
- throwError XmppIllegalTcpDetails
+ UseConnection mkC -> Just <$> mkC
+
+connectSrv :: ResolvConf -> String -> ErrorT XmppFailure IO (Maybe Handle)
+connectSrv config host = do
+ case checkHostName (Text.pack host) of
+ Just host' -> do
+ resolvSeed <- lift $ makeResolvSeed config
+ lift $ debugM "Pontarius.Xmpp" "Performing SRV lookup..."
+ srvRecords <- srvLookup host' resolvSeed
+ case srvRecords of
+ Nothing -> do
+ lift $ debugM "Pontarius.Xmpp"
+ "No SRV records, using fallback process."
+ lift $ resolvAndConnectTcp resolvSeed (BSC8.pack $ host)
+ 5222
+ Just srvRecords' -> do
+ lift $ debugM "Pontarius.Xmpp"
+ "SRV records found, performing A/AAAA lookups."
+ lift $ resolvSrvsAndConnectTcp resolvSeed srvRecords'
+ Nothing -> do
+ lift $ errorM "Pontarius.Xmpp"
+ "The hostname could not be validated."
+ throwError XmppIllegalTcpDetails
-- Connects to a list of addresses and ports. Surpresses any exceptions from
-- connectTcp.
diff --git a/source/Network/Xmpp/Tls.hs b/source/Network/Xmpp/Tls.hs
index 548d07b..e2a08e4 100644
--- a/source/Network/Xmpp/Tls.hs
+++ b/source/Network/Xmpp/Tls.hs
@@ -15,6 +15,7 @@ import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import Data.IORef
import Data.XML.Types
+import Network.DNS.Resolver (ResolvConf)
import Network.TLS
import Network.Xmpp.Stream
import Network.Xmpp.Types
@@ -154,3 +155,23 @@ mkReadBuffer recv = do
writeIORef buffer rest
return result
return read'
+
+-- | Connect to an XMPP server and secure the connection with TLS before
+-- starting the XMPP streams
+connectTls :: ResolvConf -- ^ Resolv conf to use (try defaultResolvConf as a
+ -- default)
+ -> TLSParams -- ^ TLS parameters to use when securing the connection
+ -> String -- ^ Host to use when connecting (will be resolved
+ -- using SRV records)
+ -> ErrorT XmppFailure IO StreamHandle
+connectTls config params host = do
+ h <- connectSrv config host >>= \h' -> case h' of
+ Nothing -> throwError TcpConnectionFailure
+ Just h'' -> return h''
+ let hand = handleToStreamHandle h
+ (_raw, _snk, psh, recv, ctx) <- tlsinit params $ mkBackend hand
+ return $ StreamHandle { streamSend = catchPush . psh
+ , streamReceive = recv
+ , streamFlush = contextFlush ctx
+ , streamClose = bye ctx >> streamClose hand
+ }
diff --git a/source/Network/Xmpp/Types.hs b/source/Network/Xmpp/Types.hs
index 7f5b545..3396cf0 100644
--- a/source/Network/Xmpp/Types.hs
+++ b/source/Network/Xmpp/Types.hs
@@ -1007,7 +1007,7 @@ instance Exception InvalidXmppXml
data ConnectionDetails = UseRealm -- ^ Use realm to resolv host
| UseSrv HostName -- ^ Use this hostname for a SRV lookup
| UseHost HostName PortID -- ^ Use specified host
- | UseConnection (IO StreamHandle)
+ | UseConnection (ErrorT XmppFailure IO StreamHandle)
-- | Configuration settings related to the stream.
data StreamConfiguration =