diff --git a/qs-tunnel.cabal b/qs-tunnel.cabal index 334eeb9..243a3ba 100644 --- a/qs-tunnel.cabal +++ b/qs-tunnel.cabal @@ -1,5 +1,5 @@ name: qs-tunnel -version: 0.1.0.0 +version: 0.2.0.0 synopsis: Quotesource tunnel proxy -- description: homepage: https://github.com/asakul/qs-tunnel#readme @@ -7,7 +7,7 @@ license: BSD3 license-file: LICENSE author: Denis Tereshkin maintainer: denis@kasan.ws -copyright: 2017 Denis Tereshkin +copyright: 2017-2019 Denis Tereshkin category: Web build-type: Simple cabal-version: >=1.10 diff --git a/src/Main.hs b/src/Main.hs index 3b82686..5f82d4d 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -2,46 +2,63 @@ module Main where -import qualified Data.Text as T -import qualified Data.ByteString as B -import qualified Data.ByteString.Char8 as B8 -import qualified Data.ByteString.Lazy as BL -import qualified Data.List as L -import Data.List.NonEmpty -import Data.IORef -import Data.Time.Clock -import Data.Aeson - -import ATrade.QuoteSource.Server -import ATrade.QuoteSource.Client - -import Control.Monad -import Control.Monad.Loops - -import System.IO -import System.Log.Logger -import System.Log.Handler.Simple -import System.Log.Handler (setFormatter) -import System.Log.Formatter -import System.ZMQ4 -import System.ZMQ4.ZAP +import Data.Aeson +import qualified Data.ByteString as B +import qualified Data.ByteString.Char8 as B8 +import qualified Data.ByteString.Lazy as BL +import Data.IORef +import qualified Data.List as L +import Data.List.NonEmpty +import qualified Data.Text as T +import Data.Time.Clock + +import ATrade.QuoteSource.Client +import ATrade.QuoteSource.Server + +import Control.Concurrent +import Control.Monad +import Control.Monad.Loops + +import System.IO +import System.Log.Formatter +import System.Log.Handler (setFormatter) +import System.Log.Handler.Simple +import System.Log.Logger +import System.ZMQ4 +import System.ZMQ4.ZAP + +data UpstreamConfig = UpstreamConfig + { + ucEndpoint :: T.Text, + ucCertificatePath :: Maybe FilePath + } deriving (Show, Eq) + +instance FromJSON UpstreamConfig where + parseJSON (Object o) = + UpstreamConfig <$> + o .: "endpoint" <*> + o .:? "certificate" data Config = Config { - confDownstreamEp :: T.Text, - confWhitelistIps :: [T.Text], - confBlacklistIps :: [T.Text], - confUpstreamEp :: T.Text, - confTimeout :: Integer + confDownstreamEp :: T.Text, + confDownstreamCertificatePath :: Maybe FilePath, + confClientCertificates :: [FilePath], + confWhitelistIps :: [T.Text], + confBlacklistIps :: [T.Text], + confUpstreams :: [UpstreamConfig], + confTimeout :: Integer } deriving (Show, Eq) instance FromJSON Config where parseJSON (Object o) = - Config <$> - o .: "downstream" <*> - o .:? "whitelist" .!= [] <*> - o .:? "blacklist" .!= [] <*> - o .: "upstream" <*> + Config <$> + o .: "downstream" <*> + o .:? "downstream_certificate" <*> + o .: "client_certificates" <*> + o .:? "whitelist" .!= [] <*> + o .:? "blacklist" .!= [] <*> + o .: "upstreams" <*> o .: "timeout" parseJSON _ = fail "Expected object" @@ -59,46 +76,73 @@ main :: IO () main = do initLogging infoM "main" "Starting" - eConf <- eitherDecode . BL.fromStrict <$> B.readFile "qs-tunnel.conf" + eConf <- eitherDecode . BL.fromStrict <$> B.readFile "qs-tunnel.conf" case eConf of Left errMsg -> error errMsg - Right conf -> runWithConfig conf + Right conf -> runWithConfig conf runWithConfig conf = do withContext $ \ctx -> - withSocket ctx Pub $ \downstream -> do - bind downstream $ T.unpack $ confDownstreamEp conf - - forever $ withSocket ctx Sub $ \upstream -> do - infoM "main" $ "Connecting to: " ++ (T.unpack $ confUpstreamEp conf) - connect upstream $ T.unpack $ confUpstreamEp conf - subscribe upstream B.empty - now <- getCurrentTime - lastHeartbeat <- newIORef now - lastHeartbeatSent <- newIORef now - infoM "main" "Starting proxy loop" - whileM (notTimeout lastHeartbeat conf) $ do - evs <- poll 200 [Sock upstream [In] Nothing] - sendHeartbeatIfNeeded lastHeartbeatSent downstream - unless (null (L.head evs)) $ do - incoming <- receiveMulti upstream - case incoming of - x:xs -> do - now <- getCurrentTime - writeIORef lastHeartbeat now - sendMulti downstream $ x :| xs + withZapHandler ctx $ \zap -> do + withSocket ctx Pub $ \downstream -> do + setZapDomain (restrict "global") downstream + zapSetBlacklist zap "global" $ confBlacklistIps conf + zapSetWhitelist zap "global" $ confWhitelistIps conf + bind downstream $ T.unpack $ confDownstreamEp conf + case (confDownstreamCertificatePath conf) of + Just certPath -> do + eCert <- loadCertificateFromFile certPath + case eCert of + Left err -> errorM "main" $ "Unable to load certificate: " ++ certPath + Right cert -> do + zapSetServerCertificate cert downstream + forM_ (confClientCertificates conf) (addCertificate zap) + _ -> return () + + forM_ (confUpstreams conf) $ \upstreamConf -> forkIO $ do + forever $ withSocket ctx Sub $ \upstream -> do + infoM "main" $ "Connecting to: " ++ (T.unpack $ ucEndpoint upstreamConf) + case (ucCertificatePath upstreamConf) of + Just certPath -> do + eCert <- loadCertificateFromFile certPath + case eCert of + Left err -> errorM "main" $ "Unable to load certificate: " ++ certPath + Right cert -> zapApplyCertificate cert upstream _ -> return () - where - notTimeout ref conf = do - now <- getCurrentTime - lastHb <- readIORef ref - return $ now `diffUTCTime` lastHb < (fromInteger . confTimeout) conf - - sendHeartbeatIfNeeded lastHbSent sock = do - now <- getCurrentTime - last <- readIORef lastHbSent - when (now `diffUTCTime` last > 1) $ do - send sock [] $ B8.pack "SYSTEM#HEARTBEAT" - writeIORef lastHbSent now - + connect upstream $ T.unpack $ ucEndpoint upstreamConf + subscribe upstream B.empty + now <- getCurrentTime + lastHeartbeat <- newIORef now + lastHeartbeatSent <- newIORef now + infoM "main" "Starting proxy loop" + whileM (notTimeout lastHeartbeat conf) $ do + evs <- poll 200 [Sock upstream [In] Nothing] + sendHeartbeatIfNeeded lastHeartbeatSent downstream + unless (null (L.head evs)) $ do + incoming <- receiveMulti upstream + case incoming of + x:xs -> do + now <- getCurrentTime + writeIORef lastHeartbeat now + sendMulti downstream $ x :| xs + _ -> return () + where + notTimeout ref conf = do + now <- getCurrentTime + lastHb <- readIORef ref + return $ now `diffUTCTime` lastHb < (fromInteger . confTimeout) conf + + sendHeartbeatIfNeeded lastHbSent sock = do + now <- getCurrentTime + last <- readIORef lastHbSent + when (now `diffUTCTime` last > 1) $ do + send sock [] $ B8.pack "SYSTEM#HEARTBEAT" + writeIORef lastHbSent now + + addCertificate zap clientCertPath = do + eClientCert <- loadCertificateFromFile clientCertPath + case eClientCert of + Left err -> errorM "main" $ "Unable to load client certificate: " ++ clientCertPath + Right clientCert -> zapAddClientCertificate zap "global" clientCert +