|
|
|
|
@ -6,26 +6,20 @@ import Data.Aeson
@@ -6,26 +6,20 @@ import Data.Aeson
|
|
|
|
|
import qualified Data.ByteString as B |
|
|
|
|
import qualified Data.ByteString.Char8 as B8 |
|
|
|
|
import qualified Data.ByteString.Lazy as BL |
|
|
|
|
import Data.IORef |
|
|
|
|
import qualified Data.List as L |
|
|
|
|
import Data.List.NonEmpty |
|
|
|
|
import qualified Data.Text as T |
|
|
|
|
import Data.Time.Clock |
|
|
|
|
|
|
|
|
|
import ATrade.QuoteSource.Client |
|
|
|
|
import ATrade.QuoteSource.Server |
|
|
|
|
|
|
|
|
|
import Control.Applicative |
|
|
|
|
import Control.Concurrent |
|
|
|
|
import Control.Monad |
|
|
|
|
import Control.Monad.Loops |
|
|
|
|
|
|
|
|
|
import System.IO |
|
|
|
|
import System.Log.Formatter |
|
|
|
|
import System.Log.Handler (setFormatter) |
|
|
|
|
import System.Log.Handler.Simple |
|
|
|
|
import System.Log.Logger |
|
|
|
|
import System.ZMQ4 |
|
|
|
|
import System.ZMQ4 hiding (events) |
|
|
|
|
import System.ZMQ4.ZAP |
|
|
|
|
|
|
|
|
|
import Options.Applicative |
|
|
|
|
@ -46,6 +40,7 @@ instance FromJSON UpstreamConfig where
@@ -46,6 +40,7 @@ instance FromJSON UpstreamConfig where
|
|
|
|
|
UpstreamConfig <$> |
|
|
|
|
o .: "endpoint" <*> |
|
|
|
|
o .:? "certificate" |
|
|
|
|
parseJSON _ = fail "Expected object" |
|
|
|
|
|
|
|
|
|
data Config = Config |
|
|
|
|
{ |
|
|
|
|
@ -55,6 +50,7 @@ data Config = Config
@@ -55,6 +50,7 @@ data Config = Config
|
|
|
|
|
confWhitelistIps :: [T.Text], |
|
|
|
|
confBlacklistIps :: [T.Text], |
|
|
|
|
confUpstreams :: [UpstreamConfig], |
|
|
|
|
confUpstreamClientCertificatePath :: Maybe FilePath, |
|
|
|
|
confTimeout :: Integer |
|
|
|
|
} deriving (Show, Eq) |
|
|
|
|
|
|
|
|
|
@ -67,10 +63,12 @@ instance FromJSON Config where
@@ -67,10 +63,12 @@ instance FromJSON Config where
|
|
|
|
|
o .:? "whitelist" .!= [] <*> |
|
|
|
|
o .:? "blacklist" .!= [] <*> |
|
|
|
|
o .: "upstreams" <*> |
|
|
|
|
o .: "upstream_client_certificate" <*> |
|
|
|
|
o .: "timeout" |
|
|
|
|
|
|
|
|
|
parseJSON _ = fail "Expected object" |
|
|
|
|
|
|
|
|
|
initLogging :: IO () |
|
|
|
|
initLogging = do |
|
|
|
|
handler <- streamHandler stderr DEBUG >>= |
|
|
|
|
(\x -> return $ |
|
|
|
|
@ -101,6 +99,7 @@ main = do
@@ -101,6 +99,7 @@ main = do
|
|
|
|
|
( fullDesc |
|
|
|
|
<> progDesc "Quotesource tunnel" ) |
|
|
|
|
|
|
|
|
|
runWithConfig :: Config -> IO () |
|
|
|
|
runWithConfig conf = do |
|
|
|
|
withContext $ \ctx -> |
|
|
|
|
withZapHandler ctx $ \zap -> do |
|
|
|
|
@ -108,62 +107,97 @@ runWithConfig conf = do
@@ -108,62 +107,97 @@ runWithConfig conf = do
|
|
|
|
|
setZapDomain (restrict "global") downstream |
|
|
|
|
zapSetBlacklist zap "global" $ confBlacklistIps conf |
|
|
|
|
zapSetWhitelist zap "global" $ confWhitelistIps conf |
|
|
|
|
bind downstream $ T.unpack $ confDownstreamEp conf |
|
|
|
|
case (confDownstreamCertificatePath conf) of |
|
|
|
|
Just certPath -> do |
|
|
|
|
eCert <- loadCertificateFromFile certPath |
|
|
|
|
case eCert of |
|
|
|
|
Left err -> errorM "main" $ "Unable to load certificate: " ++ certPath |
|
|
|
|
Left err -> errorM "main" $ "Unable to load certificate: " ++ certPath ++ "; " ++ err |
|
|
|
|
Right cert -> do |
|
|
|
|
zapSetServerCertificate cert downstream |
|
|
|
|
setCurveServer True downstream |
|
|
|
|
zapApplyCertificate cert downstream |
|
|
|
|
forM_ (confClientCertificates conf) (addCertificate zap) |
|
|
|
|
_ -> return () |
|
|
|
|
bind downstream $ T.unpack $ confDownstreamEp conf |
|
|
|
|
|
|
|
|
|
forM_ (confUpstreams conf) $ \upstreamConf -> forkIO $ do |
|
|
|
|
forever $ withSocket ctx Sub $ \upstream -> do |
|
|
|
|
infoM "main" $ "Connecting to: " ++ (T.unpack $ ucEndpoint upstreamConf) |
|
|
|
|
case (ucCertificatePath upstreamConf) of |
|
|
|
|
upstreamCert <- case confUpstreamClientCertificatePath conf of |
|
|
|
|
Just fp -> do |
|
|
|
|
ec <- loadCertificateFromFile fp |
|
|
|
|
case ec of |
|
|
|
|
Left err -> do |
|
|
|
|
errorM "main" $ "Unable to load certificate: " ++ fp ++ "; " ++ err |
|
|
|
|
return Nothing |
|
|
|
|
Right cert -> return $ Just cert |
|
|
|
|
_ -> return Nothing |
|
|
|
|
now <- getCurrentTime |
|
|
|
|
infoM "main" "Creating sockets" |
|
|
|
|
sockets <- forM (confUpstreams conf) $ \upstreamConf -> do |
|
|
|
|
infoM "main" $ "Creating: " ++ (T.unpack $ ucEndpoint upstreamConf) |
|
|
|
|
s <- socket ctx Sub |
|
|
|
|
maybeSc <- case (ucCertificatePath upstreamConf) of |
|
|
|
|
Just certPath -> do |
|
|
|
|
eCert <- loadCertificateFromFile certPath |
|
|
|
|
case eCert of |
|
|
|
|
Left err -> errorM "main" $ "Unable to load certificate: " ++ certPath |
|
|
|
|
Right cert -> zapApplyCertificate cert upstream |
|
|
|
|
Left err -> do |
|
|
|
|
errorM "main" $ "Unable to load certificate: " ++ certPath ++ "; " ++ err |
|
|
|
|
return Nothing |
|
|
|
|
Right cert -> return $ Just cert |
|
|
|
|
_ -> return Nothing |
|
|
|
|
maybeCc <- case upstreamCert of |
|
|
|
|
Just cert -> return $ Just cert |
|
|
|
|
Nothing -> return Nothing |
|
|
|
|
|
|
|
|
|
infoM "main" $ "Connecting: " ++ (T.unpack $ ucEndpoint upstreamConf) |
|
|
|
|
case (maybeSc, maybeCc) of |
|
|
|
|
(Just serverCert, Just clientCert) -> do |
|
|
|
|
zapSetServerCertificate serverCert s |
|
|
|
|
zapApplyCertificate clientCert s |
|
|
|
|
_ -> return () |
|
|
|
|
connect upstream $ T.unpack $ ucEndpoint upstreamConf |
|
|
|
|
subscribe upstream B.empty |
|
|
|
|
|
|
|
|
|
connect s $ T.unpack $ ucEndpoint upstreamConf |
|
|
|
|
subscribe s B.empty |
|
|
|
|
return (s, ucEndpoint upstreamConf, maybeSc, maybeCc, now) |
|
|
|
|
|
|
|
|
|
infoM "main" "Starting main loop" |
|
|
|
|
go ctx downstream sockets now |
|
|
|
|
where |
|
|
|
|
go ctx downstream sockets lastHeartbeat = do |
|
|
|
|
events <- poll 200 $ fmap (\(s, _, _, _, _) -> Sock s [In] Nothing) sockets |
|
|
|
|
let z = L.zip sockets events |
|
|
|
|
now <- getCurrentTime |
|
|
|
|
lastHeartbeat <- newIORef now |
|
|
|
|
lastHeartbeatSent <- newIORef now |
|
|
|
|
infoM "main" "Starting proxy loop" |
|
|
|
|
whileM (notTimeout lastHeartbeat conf) $ do |
|
|
|
|
evs <- poll 200 [Sock upstream [In] Nothing] |
|
|
|
|
sendHeartbeatIfNeeded lastHeartbeatSent downstream |
|
|
|
|
unless (null (L.head evs)) $ do |
|
|
|
|
incoming <- receiveMulti upstream |
|
|
|
|
sockets' <- forM z $ \((s, ep, maybeSc, maybeCc, lastActivity), evts) -> do |
|
|
|
|
if (not . null $ evts) |
|
|
|
|
then do |
|
|
|
|
incoming <- receiveMulti s |
|
|
|
|
case incoming of |
|
|
|
|
x:xs -> do |
|
|
|
|
now <- getCurrentTime |
|
|
|
|
writeIORef lastHeartbeat now |
|
|
|
|
sendMulti downstream $ x :| xs |
|
|
|
|
x:xs -> sendMulti downstream $ x :| xs |
|
|
|
|
_ -> return () |
|
|
|
|
forever $ threadDelay 100000 |
|
|
|
|
where |
|
|
|
|
notTimeout ref conf = do |
|
|
|
|
now <- getCurrentTime |
|
|
|
|
lastHb <- readIORef ref |
|
|
|
|
return $ now `diffUTCTime` lastHb < (fromInteger . confTimeout) conf |
|
|
|
|
return (s, ep, maybeSc, maybeCc, now) |
|
|
|
|
else do |
|
|
|
|
if now `diffUTCTime` lastActivity < (fromInteger . confTimeout) conf |
|
|
|
|
then return (s, ep, maybeSc, maybeCc, lastActivity) |
|
|
|
|
else do |
|
|
|
|
close s |
|
|
|
|
debugM "main" $ "Reconnecting: " ++ T.unpack ep |
|
|
|
|
newS <- socket ctx Sub |
|
|
|
|
case (maybeSc, maybeCc) of |
|
|
|
|
(Just serverCert, Just clientCert) -> do |
|
|
|
|
zapSetServerCertificate serverCert newS |
|
|
|
|
zapApplyCertificate clientCert newS |
|
|
|
|
_ -> return () |
|
|
|
|
connect newS $ T.unpack ep |
|
|
|
|
subscribe newS B.empty |
|
|
|
|
return (newS, ep, maybeSc, maybeCc, now) |
|
|
|
|
|
|
|
|
|
sendHeartbeatIfNeeded lastHbSent sock = do |
|
|
|
|
now <- getCurrentTime |
|
|
|
|
last <- readIORef lastHbSent |
|
|
|
|
when (now `diffUTCTime` last > 1) $ do |
|
|
|
|
send sock [] $ B8.pack "SYSTEM#HEARTBEAT" |
|
|
|
|
writeIORef lastHbSent now |
|
|
|
|
if (now `diffUTCTime` lastHeartbeat < 1) |
|
|
|
|
then go ctx downstream sockets' lastHeartbeat |
|
|
|
|
else do |
|
|
|
|
send downstream [] $ B8.pack "SYSTEM#HEARTBEAT" |
|
|
|
|
go ctx downstream sockets' now |
|
|
|
|
|
|
|
|
|
addCertificate zap clientCertPath = do |
|
|
|
|
eClientCert <- loadCertificateFromFile clientCertPath |
|
|
|
|
case eClientCert of |
|
|
|
|
Left err -> errorM "main" $ "Unable to load client certificate: " ++ clientCertPath |
|
|
|
|
Left err -> errorM "main" $ "Unable to load client certificate: " ++ clientCertPath ++ "; " ++ err |
|
|
|
|
Right clientCert -> zapAddClientCertificate zap "global" clientCert |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|