diff --git a/src/ATrade/Broker/Server.hs b/src/ATrade/Broker/Server.hs index c2c90f4..38e1253 100644 --- a/src/ATrade/Broker/Server.hs +++ b/src/ATrade/Broker/Server.hs @@ -8,6 +8,7 @@ module ATrade.Broker.Server ( import ATrade.Types import ATrade.Broker.Protocol import System.ZMQ4 +import Data.List.NonEmpty import qualified Data.Map as M import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as BL @@ -37,7 +38,8 @@ data BrokerServerState = BrokerServerState { lastPacket :: M.Map B.ByteString (RequestSqnum, B.ByteString), pendingNotifications :: [(Notification, UTCTime)], -- List of tuples (Order with new state, Time when notification enqueued) brokers :: [BrokerInterface], - completionMvar :: MVar () + completionMvar :: MVar (), + orderIdCounter :: OrderId } data BrokerServerHandle = BrokerServerHandle ThreadId (MVar ()) @@ -54,7 +56,8 @@ startBrokerServer brokers c ep = do lastPacket = M.empty, pendingNotifications = [], brokers = brokers, - completionMvar = compMv + completionMvar = compMv, + orderIdCounter = 1 } BrokerServerHandle <$> forkIO (brokerServerThread state) <*> pure compMv @@ -63,7 +66,7 @@ brokerServerThread state = finally brokerServerThread' cleanup where brokerServerThread' = forever $ do sock <- bsSocket <$> readIORef state - receiveMulti sock >>= handleMessage + receiveMulti sock >>= handleMessage >>= sendMessage sock cleanup = do sock <- bsSocket <$> readIORef state @@ -71,18 +74,27 @@ brokerServerThread state = finally brokerServerThread' cleanup mv <- completionMvar <$> readIORef state putMVar mv () - handleMessage :: [B.ByteString] -> IO () + handleMessage :: [B.ByteString] -> IO (B.ByteString, BrokerServerResponse) handleMessage [peerId, _, payload] = do bros <- brokers <$> readIORef state case decode . BL.fromStrict $ payload of Just (RequestSubmitOrder sqnum order) -> case findBroker (orderAccountId order) bros of - Just bro -> submitOrder bro order - Nothing -> return () - Nothing -> return () + Just bro -> do + oid <- nextOrderId + submitOrder bro order { orderId = oid } + return (peerId, ResponseOrderSubmitted oid) + + Nothing -> error "foobar" + Nothing -> error "foobar" + handleMessage x = do + warningM "Broker.Server" ("Invalid packet received: " ++ show x) + error "foobar" + + sendMessage sock (peerId, resp) = sendMulti sock (peerId :| [B.empty, BL.toStrict . encode $ resp]) - handleMessage x = warningM "Broker.Server" ("Invalid packet received: " ++ show x) findBroker account = L.find (L.elem account . accounts) + nextOrderId = atomicModifyIORef' state (\s -> ( s {orderIdCounter = 1 + orderIdCounter s}, orderIdCounter s)) stopBrokerServer :: BrokerServerHandle -> IO () diff --git a/test/TestBrokerServer.hs b/test/TestBrokerServer.hs index 5f13d13..6fc8b21 100644 --- a/test/TestBrokerServer.hs +++ b/test/TestBrokerServer.hs @@ -84,12 +84,17 @@ testBrokerServerSubmitOrder = testCase "Broker Server submits order" $ withConte orderQuantity = 10, orderOperation = Buy } - bracket (startBrokerServer [mockBroker] ctx ep) stopBrokerServer (\broS -> + bracket (startBrokerServer [mockBroker] ctx ep) stopBrokerServer (\broS -> withSocket ctx Req (\sock -> do connect sock (T.unpack ep) send sock [] (BL.toStrict . encode $ RequestSubmitOrder 1 order) threadDelay 10000 s <- readIORef broState (length . orders) s @?= 1 + resp <- decode . BL.fromStrict <$> receive sock + case resp of + Just (ResponseOrderSubmitted _) -> return () + Nothing -> assertFailure "Invalid response" + ))) - +