diff --git a/mds.cabal b/mds.cabal index 49139e3..b770c91 100644 --- a/mds.cabal +++ b/mds.cabal @@ -15,10 +15,12 @@ cabal-version: >=1.10 library hs-source-dirs: src + ghc-options: -Wall -Werror exposed-modules: ATrade.MDS.Database + , ATrade.MDS.HistoryServer build-depends: base >= 4.7 && < 5 , HDBC - , HDBC-postgresql + , HDBC-sqlite3 , configurator , text , vector @@ -28,12 +30,16 @@ library , monad-loops , text-format , zeromq4-haskell + , aeson + , safe + , bytestring default-language: Haskell2010 + other-modules: ATrade.MDS.Protocol executable mds-exe hs-source-dirs: app main-is: Main.hs - ghc-options: -threaded -rtsopts -with-rtsopts=-N + ghc-options: -threaded -rtsopts -with-rtsopts=-N -Wall -Werror build-depends: base , mds default-language: Haskell2010 @@ -44,8 +50,19 @@ test-suite mds-test main-is: Spec.hs build-depends: base , mds + , libatrade + , temporary + , datetime + , vector + , text + , time + , tasty + , tasty-hunit ghc-options: -threaded -rtsopts -with-rtsopts=-N default-language: Haskell2010 + other-modules: Integration.Spec + , Integration.Database + extensions: OverloadedStrings source-repository head type: git diff --git a/src/ATrade/MDS/Database.hs b/src/ATrade/MDS/Database.hs index 97afb9e..85add5d 100644 --- a/src/ATrade/MDS/Database.hs +++ b/src/ATrade/MDS/Database.hs @@ -2,129 +2,93 @@ module ATrade.MDS.Database ( DatabaseConfig(..), - DatabaseInterface(..), - startDatabase, - stopDatabase + MdsHandle, + initDatabase, + closeDatabase, + getData, + putData, + TimeInterval(..), + Timeframe(..), + timeframeDaily, + timeframeHour, + timeframeMinute ) where -import qualified Data.Configurator as C import qualified Data.Text as T -import qualified Data.Text.Lazy as TL -import Data.Text.Format import qualified Data.Vector as V import ATrade.Types import Data.Time.Clock import Data.Time.Clock.POSIX import Data.Maybe -import Control.Concurrent.MVar -import Control.Concurrent -import System.Log.Logger import Database.HDBC -import Database.HDBC.PostgreSQL +import Database.HDBC.Sqlite3 import Control.Monad -import Control.Monad.Loops + data TimeInterval = TimeInterval UTCTime UTCTime data Timeframe = Timeframe Int -timeframeDaily = Timeframe 86400 -timeframeHour = Timeframe 3600 -timeframeMinute = Timeframe 60 +timeframeDaily :: Int -> Timeframe +timeframeDaily days = Timeframe (days * 86400) + +timeframeHour :: Int -> Timeframe +timeframeHour hours = Timeframe (hours * 3600) + +timeframeMinute :: Int -> Timeframe +timeframeMinute mins = Timeframe (mins * 60) -data DatabaseCommand = DBGet TickerId TimeInterval Timeframe | DBPut TickerId TimeInterval Timeframe (V.Vector Bar) -data DatabaseResponse = DBOk | DBData [(TimeInterval, V.Vector Bar)] | DBError T.Text data DatabaseConfig = DatabaseConfig { - dbHost :: T.Text, + dbPath :: T.Text, dbDatabase :: T.Text, dbUser :: T.Text, dbPassword :: T.Text } deriving (Show, Eq) -data DatabaseInterface = DatabaseInterface { - tid :: ThreadId, - getData :: TickerId -> TimeInterval -> Timeframe -> IO [(TimeInterval, V.Vector Bar)], - putData :: TickerId -> TimeInterval -> Timeframe -> V.Vector Bar -> IO () -} +type MdsHandle = Connection -startDatabase :: DatabaseConfig -> IO DatabaseInterface -startDatabase config = do - conn <- connectPostgreSQL (mkConnectionString config) +initDatabase :: DatabaseConfig -> IO MdsHandle +initDatabase config = do + conn <- connectSqlite3 (T.unpack $ dbPath config) makeSchema conn - cmdVar <- newEmptyMVar - respVar <- newEmptyMVar - compVar <- newEmptyMVar - tid <- forkFinally (dbThread conn cmdVar respVar) (cleanup conn cmdVar respVar compVar) - return DatabaseInterface { - tid = tid, - getData = doGetData cmdVar respVar, - putData = doPutData cmdVar respVar } + return conn + where + makeSchema conn = runRaw conn "CREATE TABLE IF NOT EXISTS bars (id SERIAL PRIMARY KEY, ticker TEXT, timestamp BIGINT, timeframe INTEGER, open NUMERIC(20, 10), high NUMERIC(20, 10), low NUMERIC(20, 10), close NUMERIC(20,10), volume BIGINT);" + +closeDatabase :: MdsHandle -> IO () +closeDatabase = disconnect + +getData :: MdsHandle -> TickerId -> TimeInterval -> Timeframe -> IO [(TimeInterval, V.Vector Bar)] +getData db tickerId interval@(TimeInterval start end) (Timeframe tfSec) = do + rows <- quickQuery' db "SELECT timestamp, timeframe, open, high, low, close, volume FROM bars WHERE ticker == ? AND timeframe == ? AND timestamp >= ? AND timestamp <= ? ORDER BY timestamp ASC;" [(toSql. T.unpack) tickerId, toSql tfSec, (toSql . utcTimeToPOSIXSeconds) start, (toSql . utcTimeToPOSIXSeconds) end] + return [(interval, V.fromList $ mapMaybe (barFromResult tickerId) rows)] where - makeSchema conn = runRaw conn "CREATE TABLE IF NOT EXISTS bars (id SERIAL PRIMARY KEY, ticker TEXT, timestamp BIGINT, open NUMERIC(20, 10), high NUMERIC(20, 10), low NUMERIC(20, 10), close NUMERIC(20,10), volume BIGINT);" - mkConnectionString config = TL.unpack $ format "User ID={};Password={};Host={};Port=5432;Database={}" (dbUser config, dbPassword config, dbHost config, dbDatabase config) - dbThread conn cmdVar respVar = forever $ do - cmd <- readMVar cmdVar - handleCmd conn cmd >>= putMVar respVar - whileM_ (isJust <$> tryReadMVar respVar) yield - takeMVar cmdVar - cleanup conn cmdVar respVar compVar _ = disconnect conn >> putMVar compVar () - handleCmd conn cmd = case cmd of - DBPut tickerId (TimeInterval start end) tf@(Timeframe timeframeSecs) bars -> do - delStmt <- prepare conn "DELETE FROM bars WHERE timestamp > ? AND timestamp < ? AND ticker == ? AND timeframe == ?;" - execute delStmt [(SqlPOSIXTime . utcTimeToPOSIXSeconds) start, (SqlPOSIXTime . utcTimeToPOSIXSeconds) end, (SqlString . T.unpack) tickerId, (SqlInteger . toInteger) timeframeSecs] - stmt <- prepare conn $ "INSERT INTO bars (ticker, timeframe, timestamp, open, high, low, close, volume)" ++ - " values (?, ?, ?, ?, ?, ?, ?, ?); " - executeMany stmt (map (barToSql tf) $ V.toList bars) - return DBOk - DBGet tickerId interval@(TimeInterval start end) (Timeframe timeframeSecs) -> do - rows <- quickQuery' conn "SELECT timestamp, open, high, low, close, volume FROM bars WHERE ticker == ? AND timeframe == ? AND timestamp > ? AND timestamp < ?;" [(toSql. T.unpack) tickerId, toSql timeframeSecs, (toSql . utcTimeToPOSIXSeconds) start, (toSql . utcTimeToPOSIXSeconds) end] - return $ DBData [(interval, V.fromList $ mapMaybe (barFromResult tickerId) rows)] - barFromResult ticker [ts, open, high, low, close, volume] = Just Bar { - barSecurity = ticker, - barTimestamp = fromSql ts, - barOpen = fromRational $ fromSql open, - barHigh = fromRational $ fromSql high, - barLow = fromRational $ fromSql low, - barClose = fromRational $ fromSql close, - barVolume = fromSql volume - } + barFromResult ticker [ts, _, open, high, low, close, vol] = Just Bar { + barSecurity = ticker, + barTimestamp = fromSql ts, + barOpen = fromDouble $ fromSql open, + barHigh = fromDouble $ fromSql high, + barLow = fromDouble $ fromSql low, + barClose = fromDouble $ fromSql close, + barVolume = fromSql vol + } barFromResult _ _ = Nothing +putData :: MdsHandle -> TickerId -> TimeInterval -> Timeframe -> V.Vector Bar -> IO () +putData db tickerId (TimeInterval start end) tf@(Timeframe tfSec) bars = do + delStmt <- prepare db "DELETE FROM bars WHERE timestamp >= ? AND timestamp <= ? AND ticker == ? AND timeframe == ?;" + void $ execute delStmt [(SqlPOSIXTime . utcTimeToPOSIXSeconds) start, (SqlPOSIXTime . utcTimeToPOSIXSeconds) end, (SqlString . T.unpack) tickerId, (SqlInteger . toInteger) tfSec] + stmt <- prepare db $ "INSERT INTO bars (ticker, timeframe, timestamp, open, high, low, close, volume)" ++ + " values (?, ?, ?, ?, ?, ?, ?, ?); " + executeMany stmt (map (barToSql tf) $ V.toList bars) + where barToSql :: Timeframe -> Bar -> [SqlValue] barToSql (Timeframe timeframeSecs) bar = [(SqlString . T.unpack . barSecurity) bar, (SqlInteger . toInteger) timeframeSecs, - (SqlRational . toRational . barOpen) bar, - (SqlRational . toRational . barHigh) bar, - (SqlRational . toRational . barLow) bar, - (SqlRational . toRational . barClose) bar, - (SqlRational . toRational . barVolume) bar ] - -stopDatabase :: MVar () -> DatabaseInterface -> IO () -stopDatabase compVar db = killThread (tid db) >> readMVar compVar - -doGetData :: MVar DatabaseCommand -> MVar DatabaseResponse -> TickerId -> TimeInterval -> Timeframe -> IO [(TimeInterval, V.Vector Bar)] -doGetData cmdVar respVar tickerId timeInterval timeframe = do - putMVar cmdVar (DBGet tickerId timeInterval timeframe) - resp <- takeMVar respVar - case resp of - DBData x -> return x - DBError err -> do - warningM "DB.Client" $ "Error while calling getData: " ++ show err - return [] - _ -> do - warningM "DB.Client" "Unexpected response" - return [] - -doPutData :: MVar DatabaseCommand -> MVar DatabaseResponse -> TickerId -> TimeInterval -> Timeframe -> V.Vector Bar -> IO () -doPutData cmdVar respVar tickerId timeInterval timeframe bars = do - putMVar cmdVar (DBPut tickerId timeInterval timeframe bars) - resp <- takeMVar respVar - case resp of - DBOk -> return () - DBError err -> do - warningM "DB.Client" $ "Error while calling putData: " ++ show err - return () - _ -> do - warningM "DB.Client" "Unexpected response" - return () + (SqlPOSIXTime . utcTimeToPOSIXSeconds . barTimestamp) bar, + (SqlDouble . toDouble . barOpen) bar, + (SqlDouble . toDouble . barHigh) bar, + (SqlDouble . toDouble . barLow) bar, + (SqlDouble . toDouble . barClose) bar, + (SqlInteger . barVolume) bar ] diff --git a/src/ATrade/MDS/HistoryServer.hs b/src/ATrade/MDS/HistoryServer.hs index 37e925c..05a91bd 100644 --- a/src/ATrade/MDS/HistoryServer.hs +++ b/src/ATrade/MDS/HistoryServer.hs @@ -1,14 +1,42 @@ module ATrade.MDS.HistoryServer ( + HistoryServer, + startHistoryServer ) where import System.ZMQ4 import ATrade.MDS.Database +import ATrade.MDS.Protocol import Control.Concurrent +import Control.Monad +import Data.Aeson +import Data.List.NonEmpty +import qualified Data.Vector as V +import Safe +import qualified Data.ByteString as B +import qualified Data.ByteString.Lazy as BL data HistoryServer = HistoryServer ThreadId -} -startHistoryServer :: DatabaseInterface -> Context -> IO HistoryServer +startHistoryServer :: MdsHandle -> Context -> IO HistoryServer startHistoryServer db ctx = do - + sock <- socket ctx Router + tid <- forkIO $ serve db sock + return $ HistoryServer tid + +serve :: (Sender a, Receiver a) => MdsHandle -> Socket a -> IO () +serve db sock = forever $ do + rq <- receiveMulti sock + let maybeCmd = (BL.fromStrict <$> rq `atMay` 2) >>= decode + case (headMay rq, maybeCmd) of + (Just peerId, Just cmd) -> handleCmd peerId cmd + _ -> return () + where + handleCmd :: B.ByteString -> MDSRequest -> IO () + handleCmd peerId cmd = case cmd of + rq -> do + qdata <- getData db (rqTicker rq) (TimeInterval (rqFrom rq) (rqTo rq)) (Timeframe (rqTimeframe rq)) + bytes <- serializeBars $ V.concat $ fmap snd qdata + sendMulti sock $ peerId :| B.empty : bytes + serializeBars = undefined + diff --git a/src/ATrade/MDS/Protocol.hs b/src/ATrade/MDS/Protocol.hs new file mode 100644 index 0000000..5060d7a --- /dev/null +++ b/src/ATrade/MDS/Protocol.hs @@ -0,0 +1,23 @@ +{-# LANGUAGE DeriveGeneric #-} + +module ATrade.MDS.Protocol ( + MDSRequest(..) +) where + +import GHC.Generics + +import ATrade.Types + +import Data.Aeson +import Data.Time.Clock + +data MDSRequest = RequestData { + rqTicker :: TickerId, + rqFrom :: UTCTime, + rqTo :: UTCTime, + rqTimeframe :: Int +} deriving (Generic, Show, Eq) + +instance ToJSON MDSRequest +instance FromJSON MDSRequest + diff --git a/src/Lib.hs b/src/Lib.hs deleted file mode 100644 index d36ff27..0000000 --- a/src/Lib.hs +++ /dev/null @@ -1,6 +0,0 @@ -module Lib - ( someFunc - ) where - -someFunc :: IO () -someFunc = putStrLn "someFunc" diff --git a/stack.yaml b/stack.yaml index db9f144..66cbf3c 100644 --- a/stack.yaml +++ b/stack.yaml @@ -15,7 +15,7 @@ # resolver: # name: custom-snapshot # location: "./custom-snapshot.yaml" -resolver: lts-7.4 +resolver: lts-11.9 # User packages to be built. # Various formats can be used as shown in the example below. @@ -38,9 +38,10 @@ resolver: lts-7.4 packages: - '.' - '../libatrade' +- '../zeromq4-haskell-zap' # Dependency packages to be pulled from upstream that are not in the resolver # (e.g., acme-missiles-0.3) -extra-deps: ["HDBC-postgresql-2.3.2.4", "datetime-0.3.1"] +extra-deps: ["HDBC-sqlite3-2.3.3.1", "datetime-0.3.1"] # Override default flag values for local packages and extra-deps flags: {} diff --git a/test/Integration/Database.hs b/test/Integration/Database.hs new file mode 100644 index 0000000..bfe644b --- /dev/null +++ b/test/Integration/Database.hs @@ -0,0 +1,77 @@ + +module Integration.Database ( + testDatabase +) where + +import Test.Tasty +import Test.Tasty.HUnit + +import ATrade.MDS.Database + +import ATrade.Types +import Control.Exception +import Data.DateTime +import Data.Time.Clock +import qualified Data.Text as T +import qualified Data.Vector as V +import System.IO.Temp + +testDatabase :: TestTree +testDatabase = testGroup "Database tests" [ testOpenClose + , testPutGet + , testGetReturnsSorted ] + +testOpenClose :: TestTree +testOpenClose = testCase "Open/Close" $ + withSystemTempDirectory "test" $ \fp -> do + let dbConfig = DatabaseConfig (T.pack $ fp ++ "/test.db") T.empty T.empty T.empty + db <- initDatabase dbConfig + closeDatabase db + + +bar :: UTCTime -> Price -> Price -> Price -> Price -> Integer -> Bar +bar dt o h l c v = Bar { barSecurity = "FOO", + barTimestamp = dt, + barOpen = o, + barHigh = h, + barLow = l, + barClose = c, + barVolume = v } + +testPutGet :: TestTree +testPutGet = testCase "Put/Get" $ + withSystemTempDirectory "test" $ \fp -> do + let dbConfig = DatabaseConfig (T.pack $ fp ++ "/test.db") T.empty T.empty T.empty + bracket (initDatabase dbConfig) closeDatabase $ \db -> do + putData db "FOO" interval (timeframeMinute 1) bars + retrievedBars <- (snd . head) <$> getData db "FOO" interval (timeframeMinute 1) + assertEqual "Retreived bars are different from saved" bars retrievedBars + + where + interval = TimeInterval (fromGregorian 2010 1 1 12 0 0) (fromGregorian 2010 1 1 12 5 0) + bars = V.fromList $ [ + bar (fromGregorian 2010 1 1 12 0 0) 10 11 9 10 1, + bar (fromGregorian 2010 1 1 12 1 0) 12 15 9 10 1, + bar (fromGregorian 2010 1 1 12 2 0) 13 15 9 12 1 + ] + +testGetReturnsSorted :: TestTree +testGetReturnsSorted = testCase "Get returns sorted vector" $ + withSystemTempDirectory "test" $ \fp -> do + let dbConfig = DatabaseConfig (T.pack $ fp ++ "/test.db") T.empty T.empty T.empty + bracket (initDatabase dbConfig) closeDatabase $ \db -> do + putData db "FOO" interval (timeframeMinute 1) bars + retrievedBars <- (snd . head) <$> getData db "FOO" interval (timeframeMinute 1) + assertEqual "Retreived bars are not sorted" sortedBars retrievedBars + where + interval = TimeInterval (fromGregorian 2010 1 1 12 0 0) (fromGregorian 2010 1 1 12 5 0) + bars = V.fromList $ [ + bar (fromGregorian 2010 1 1 12 0 0) 10 11 9 10 1, + bar (fromGregorian 2010 1 1 12 2 0) 13 15 9 12 1, + bar (fromGregorian 2010 1 1 12 1 0) 12 15 9 10 1 + ] + sortedBars = V.fromList $ [ + bar (fromGregorian 2010 1 1 12 0 0) 10 11 9 10 1, + bar (fromGregorian 2010 1 1 12 1 0) 12 15 9 10 1, + bar (fromGregorian 2010 1 1 12 2 0) 13 15 9 12 1 + ] diff --git a/test/Integration/Spec.hs b/test/Integration/Spec.hs new file mode 100644 index 0000000..a3380e2 --- /dev/null +++ b/test/Integration/Spec.hs @@ -0,0 +1,12 @@ + +module Integration.Spec ( + integrationTests +) where + +import Test.Tasty +import Test.Tasty.HUnit + +import Integration.Database + +integrationTests = testGroup "Integration tests" [ testDatabase ] + diff --git a/test/Spec.hs b/test/Spec.hs index cd4753f..ee2a8a2 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -1,2 +1,11 @@ + +import Test.Tasty +import Test.Tasty.HUnit + +import Integration.Spec + main :: IO () -main = putStrLn "Test suite not yet implemented" +main = defaultMain tests + +tests :: TestTree +tests = testGroup "Tests" [ integrationTests ]