diff --git a/qs-tunnel.cabal b/qs-tunnel.cabal index 2aa8b41..18482e9 100644 --- a/qs-tunnel.cabal +++ b/qs-tunnel.cabal @@ -16,11 +16,11 @@ extra-source-files: README.md executable qs-tunnel hs-source-dirs: src main-is: Main.hs + ghc-options: -threaded -rtsopts -with-rtsopts=-N -Wall -Werror -Wno-type-defaults default-language: Haskell2010 build-depends: base >= 4.7 && < 5 , libatrade , aeson - , monad-loops , zeromq4-haskell , zeromq4-haskell-zap , text @@ -28,3 +28,4 @@ executable qs-tunnel , time , hslogger , optparse-applicative + , errors diff --git a/src/Main.hs b/src/Main.hs index 6fee6ce..0ca7682 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -6,26 +6,20 @@ import Data.Aeson import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as B8 import qualified Data.ByteString.Lazy as BL -import Data.IORef import qualified Data.List as L import Data.List.NonEmpty import qualified Data.Text as T import Data.Time.Clock -import ATrade.QuoteSource.Client -import ATrade.QuoteSource.Server - import Control.Applicative -import Control.Concurrent import Control.Monad -import Control.Monad.Loops import System.IO import System.Log.Formatter import System.Log.Handler (setFormatter) import System.Log.Handler.Simple import System.Log.Logger -import System.ZMQ4 +import System.ZMQ4 hiding (events) import System.ZMQ4.ZAP import Options.Applicative @@ -46,31 +40,35 @@ instance FromJSON UpstreamConfig where UpstreamConfig <$> o .: "endpoint" <*> o .:? "certificate" + parseJSON _ = fail "Expected object" data Config = Config { - confDownstreamEp :: T.Text, - confDownstreamCertificatePath :: Maybe FilePath, - confClientCertificates :: [FilePath], - confWhitelistIps :: [T.Text], - confBlacklistIps :: [T.Text], - confUpstreams :: [UpstreamConfig], - confTimeout :: Integer + confDownstreamEp :: T.Text, + confDownstreamCertificatePath :: Maybe FilePath, + confClientCertificates :: [FilePath], + confWhitelistIps :: [T.Text], + confBlacklistIps :: [T.Text], + confUpstreams :: [UpstreamConfig], + confUpstreamClientCertificatePath :: Maybe FilePath, + confTimeout :: Integer } deriving (Show, Eq) instance FromJSON Config where parseJSON (Object o) = - Config <$> - o .: "downstream" <*> - o .:? "downstream_certificate" <*> - o .: "client_certificates" <*> - o .:? "whitelist" .!= [] <*> - o .:? "blacklist" .!= [] <*> - o .: "upstreams" <*> + Config <$> + o .: "downstream" <*> + o .:? "downstream_certificate" <*> + o .: "client_certificates" <*> + o .:? "whitelist" .!= [] <*> + o .:? "blacklist" .!= [] <*> + o .: "upstreams" <*> + o .: "upstream_client_certificate" <*> o .: "timeout" parseJSON _ = fail "Expected object" +initLogging :: IO () initLogging = do handler <- streamHandler stderr DEBUG >>= (\x -> return $ @@ -101,6 +99,7 @@ main = do ( fullDesc <> progDesc "Quotesource tunnel" ) +runWithConfig :: Config -> IO () runWithConfig conf = do withContext $ \ctx -> withZapHandler ctx $ \zap -> do @@ -108,62 +107,97 @@ runWithConfig conf = do setZapDomain (restrict "global") downstream zapSetBlacklist zap "global" $ confBlacklistIps conf zapSetWhitelist zap "global" $ confWhitelistIps conf - bind downstream $ T.unpack $ confDownstreamEp conf case (confDownstreamCertificatePath conf) of Just certPath -> do eCert <- loadCertificateFromFile certPath case eCert of - Left err -> errorM "main" $ "Unable to load certificate: " ++ certPath + Left err -> errorM "main" $ "Unable to load certificate: " ++ certPath ++ "; " ++ err Right cert -> do - zapSetServerCertificate cert downstream + setCurveServer True downstream + zapApplyCertificate cert downstream forM_ (confClientCertificates conf) (addCertificate zap) _ -> return () + bind downstream $ T.unpack $ confDownstreamEp conf - forM_ (confUpstreams conf) $ \upstreamConf -> forkIO $ do - forever $ withSocket ctx Sub $ \upstream -> do - infoM "main" $ "Connecting to: " ++ (T.unpack $ ucEndpoint upstreamConf) - case (ucCertificatePath upstreamConf) of - Just certPath -> do - eCert <- loadCertificateFromFile certPath - case eCert of - Left err -> errorM "main" $ "Unable to load certificate: " ++ certPath - Right cert -> zapApplyCertificate cert upstream - _ -> return () - connect upstream $ T.unpack $ ucEndpoint upstreamConf - subscribe upstream B.empty - now <- getCurrentTime - lastHeartbeat <- newIORef now - lastHeartbeatSent <- newIORef now - infoM "main" "Starting proxy loop" - whileM (notTimeout lastHeartbeat conf) $ do - evs <- poll 200 [Sock upstream [In] Nothing] - sendHeartbeatIfNeeded lastHeartbeatSent downstream - unless (null (L.head evs)) $ do - incoming <- receiveMulti upstream - case incoming of - x:xs -> do - now <- getCurrentTime - writeIORef lastHeartbeat now - sendMulti downstream $ x :| xs - _ -> return () - forever $ threadDelay 100000 + upstreamCert <- case confUpstreamClientCertificatePath conf of + Just fp -> do + ec <- loadCertificateFromFile fp + case ec of + Left err -> do + errorM "main" $ "Unable to load certificate: " ++ fp ++ "; " ++ err + return Nothing + Right cert -> return $ Just cert + _ -> return Nothing + now <- getCurrentTime + infoM "main" "Creating sockets" + sockets <- forM (confUpstreams conf) $ \upstreamConf -> do + infoM "main" $ "Creating: " ++ (T.unpack $ ucEndpoint upstreamConf) + s <- socket ctx Sub + maybeSc <- case (ucCertificatePath upstreamConf) of + Just certPath -> do + eCert <- loadCertificateFromFile certPath + case eCert of + Left err -> do + errorM "main" $ "Unable to load certificate: " ++ certPath ++ "; " ++ err + return Nothing + Right cert -> return $ Just cert + _ -> return Nothing + maybeCc <- case upstreamCert of + Just cert -> return $ Just cert + Nothing -> return Nothing + + infoM "main" $ "Connecting: " ++ (T.unpack $ ucEndpoint upstreamConf) + case (maybeSc, maybeCc) of + (Just serverCert, Just clientCert) -> do + zapSetServerCertificate serverCert s + zapApplyCertificate clientCert s + _ -> return () + + connect s $ T.unpack $ ucEndpoint upstreamConf + subscribe s B.empty + return (s, ucEndpoint upstreamConf, maybeSc, maybeCc, now) + + infoM "main" "Starting main loop" + go ctx downstream sockets now where - notTimeout ref conf = do + go ctx downstream sockets lastHeartbeat = do + events <- poll 200 $ fmap (\(s, _, _, _, _) -> Sock s [In] Nothing) sockets + let z = L.zip sockets events now <- getCurrentTime - lastHb <- readIORef ref - return $ now `diffUTCTime` lastHb < (fromInteger . confTimeout) conf + sockets' <- forM z $ \((s, ep, maybeSc, maybeCc, lastActivity), evts) -> do + if (not . null $ evts) + then do + incoming <- receiveMulti s + case incoming of + x:xs -> sendMulti downstream $ x :| xs + _ -> return () + return (s, ep, maybeSc, maybeCc, now) + else do + if now `diffUTCTime` lastActivity < (fromInteger . confTimeout) conf + then return (s, ep, maybeSc, maybeCc, lastActivity) + else do + close s + debugM "main" $ "Reconnecting: " ++ T.unpack ep + newS <- socket ctx Sub + case (maybeSc, maybeCc) of + (Just serverCert, Just clientCert) -> do + zapSetServerCertificate serverCert newS + zapApplyCertificate clientCert newS + _ -> return () + connect newS $ T.unpack ep + subscribe newS B.empty + return (newS, ep, maybeSc, maybeCc, now) - sendHeartbeatIfNeeded lastHbSent sock = do - now <- getCurrentTime - last <- readIORef lastHbSent - when (now `diffUTCTime` last > 1) $ do - send sock [] $ B8.pack "SYSTEM#HEARTBEAT" - writeIORef lastHbSent now + if (now `diffUTCTime` lastHeartbeat < 1) + then go ctx downstream sockets' lastHeartbeat + else do + send downstream [] $ B8.pack "SYSTEM#HEARTBEAT" + go ctx downstream sockets' now addCertificate zap clientCertPath = do eClientCert <- loadCertificateFromFile clientCertPath case eClientCert of - Left err -> errorM "main" $ "Unable to load client certificate: " ++ clientCertPath + Left err -> errorM "main" $ "Unable to load client certificate: " ++ clientCertPath ++ "; " ++ err Right clientCert -> zapAddClientCertificate zap "global" clientCert