Browse Source

Simplifed database access & added tests

master
Denis Tereshkin 8 years ago
parent
commit
889aaa01ce
  1. 21
      mds.cabal
  2. 154
      src/ATrade/MDS/Database.hs
  3. 32
      src/ATrade/MDS/HistoryServer.hs
  4. 23
      src/ATrade/MDS/Protocol.hs
  5. 6
      src/Lib.hs
  6. 5
      stack.yaml
  7. 77
      test/Integration/Database.hs
  8. 12
      test/Integration/Spec.hs
  9. 11
      test/Spec.hs

21
mds.cabal

@ -15,10 +15,12 @@ cabal-version: >=1.10 @@ -15,10 +15,12 @@ cabal-version: >=1.10
library
hs-source-dirs: src
ghc-options: -Wall -Werror
exposed-modules: ATrade.MDS.Database
, ATrade.MDS.HistoryServer
build-depends: base >= 4.7 && < 5
, HDBC
, HDBC-postgresql
, HDBC-sqlite3
, configurator
, text
, vector
@ -28,12 +30,16 @@ library @@ -28,12 +30,16 @@ library
, monad-loops
, text-format
, zeromq4-haskell
, aeson
, safe
, bytestring
default-language: Haskell2010
other-modules: ATrade.MDS.Protocol
executable mds-exe
hs-source-dirs: app
main-is: Main.hs
ghc-options: -threaded -rtsopts -with-rtsopts=-N
ghc-options: -threaded -rtsopts -with-rtsopts=-N -Wall -Werror
build-depends: base
, mds
default-language: Haskell2010
@ -44,8 +50,19 @@ test-suite mds-test @@ -44,8 +50,19 @@ test-suite mds-test
main-is: Spec.hs
build-depends: base
, mds
, libatrade
, temporary
, datetime
, vector
, text
, time
, tasty
, tasty-hunit
ghc-options: -threaded -rtsopts -with-rtsopts=-N
default-language: Haskell2010
other-modules: Integration.Spec
, Integration.Database
extensions: OverloadedStrings
source-repository head
type: git

154
src/ATrade/MDS/Database.hs

@ -2,129 +2,93 @@ @@ -2,129 +2,93 @@
module ATrade.MDS.Database (
DatabaseConfig(..),
DatabaseInterface(..),
startDatabase,
stopDatabase
MdsHandle,
initDatabase,
closeDatabase,
getData,
putData,
TimeInterval(..),
Timeframe(..),
timeframeDaily,
timeframeHour,
timeframeMinute
) where
import qualified Data.Configurator as C
import qualified Data.Text as T
import qualified Data.Text.Lazy as TL
import Data.Text.Format
import qualified Data.Vector as V
import ATrade.Types
import Data.Time.Clock
import Data.Time.Clock.POSIX
import Data.Maybe
import Control.Concurrent.MVar
import Control.Concurrent
import System.Log.Logger
import Database.HDBC
import Database.HDBC.PostgreSQL
import Database.HDBC.Sqlite3
import Control.Monad
import Control.Monad.Loops
data TimeInterval = TimeInterval UTCTime UTCTime
data Timeframe = Timeframe Int
timeframeDaily = Timeframe 86400
timeframeHour = Timeframe 3600
timeframeMinute = Timeframe 60
timeframeDaily :: Int -> Timeframe
timeframeDaily days = Timeframe (days * 86400)
timeframeHour :: Int -> Timeframe
timeframeHour hours = Timeframe (hours * 3600)
timeframeMinute :: Int -> Timeframe
timeframeMinute mins = Timeframe (mins * 60)
data DatabaseCommand = DBGet TickerId TimeInterval Timeframe | DBPut TickerId TimeInterval Timeframe (V.Vector Bar)
data DatabaseResponse = DBOk | DBData [(TimeInterval, V.Vector Bar)] | DBError T.Text
data DatabaseConfig = DatabaseConfig {
dbHost :: T.Text,
dbPath :: T.Text,
dbDatabase :: T.Text,
dbUser :: T.Text,
dbPassword :: T.Text
} deriving (Show, Eq)
data DatabaseInterface = DatabaseInterface {
tid :: ThreadId,
getData :: TickerId -> TimeInterval -> Timeframe -> IO [(TimeInterval, V.Vector Bar)],
putData :: TickerId -> TimeInterval -> Timeframe -> V.Vector Bar -> IO ()
}
type MdsHandle = Connection
startDatabase :: DatabaseConfig -> IO DatabaseInterface
startDatabase config = do
conn <- connectPostgreSQL (mkConnectionString config)
initDatabase :: DatabaseConfig -> IO MdsHandle
initDatabase config = do
conn <- connectSqlite3 (T.unpack $ dbPath config)
makeSchema conn
cmdVar <- newEmptyMVar
respVar <- newEmptyMVar
compVar <- newEmptyMVar
tid <- forkFinally (dbThread conn cmdVar respVar) (cleanup conn cmdVar respVar compVar)
return DatabaseInterface {
tid = tid,
getData = doGetData cmdVar respVar,
putData = doPutData cmdVar respVar }
return conn
where
makeSchema conn = runRaw conn "CREATE TABLE IF NOT EXISTS bars (id SERIAL PRIMARY KEY, ticker TEXT, timestamp BIGINT, timeframe INTEGER, open NUMERIC(20, 10), high NUMERIC(20, 10), low NUMERIC(20, 10), close NUMERIC(20,10), volume BIGINT);"
closeDatabase :: MdsHandle -> IO ()
closeDatabase = disconnect
getData :: MdsHandle -> TickerId -> TimeInterval -> Timeframe -> IO [(TimeInterval, V.Vector Bar)]
getData db tickerId interval@(TimeInterval start end) (Timeframe tfSec) = do
rows <- quickQuery' db "SELECT timestamp, timeframe, open, high, low, close, volume FROM bars WHERE ticker == ? AND timeframe == ? AND timestamp >= ? AND timestamp <= ? ORDER BY timestamp ASC;" [(toSql. T.unpack) tickerId, toSql tfSec, (toSql . utcTimeToPOSIXSeconds) start, (toSql . utcTimeToPOSIXSeconds) end]
return [(interval, V.fromList $ mapMaybe (barFromResult tickerId) rows)]
where
makeSchema conn = runRaw conn "CREATE TABLE IF NOT EXISTS bars (id SERIAL PRIMARY KEY, ticker TEXT, timestamp BIGINT, open NUMERIC(20, 10), high NUMERIC(20, 10), low NUMERIC(20, 10), close NUMERIC(20,10), volume BIGINT);"
mkConnectionString config = TL.unpack $ format "User ID={};Password={};Host={};Port=5432;Database={}" (dbUser config, dbPassword config, dbHost config, dbDatabase config)
dbThread conn cmdVar respVar = forever $ do
cmd <- readMVar cmdVar
handleCmd conn cmd >>= putMVar respVar
whileM_ (isJust <$> tryReadMVar respVar) yield
takeMVar cmdVar
cleanup conn cmdVar respVar compVar _ = disconnect conn >> putMVar compVar ()
handleCmd conn cmd = case cmd of
DBPut tickerId (TimeInterval start end) tf@(Timeframe timeframeSecs) bars -> do
delStmt <- prepare conn "DELETE FROM bars WHERE timestamp > ? AND timestamp < ? AND ticker == ? AND timeframe == ?;"
execute delStmt [(SqlPOSIXTime . utcTimeToPOSIXSeconds) start, (SqlPOSIXTime . utcTimeToPOSIXSeconds) end, (SqlString . T.unpack) tickerId, (SqlInteger . toInteger) timeframeSecs]
stmt <- prepare conn $ "INSERT INTO bars (ticker, timeframe, timestamp, open, high, low, close, volume)" ++
" values (?, ?, ?, ?, ?, ?, ?, ?); "
executeMany stmt (map (barToSql tf) $ V.toList bars)
return DBOk
DBGet tickerId interval@(TimeInterval start end) (Timeframe timeframeSecs) -> do
rows <- quickQuery' conn "SELECT timestamp, open, high, low, close, volume FROM bars WHERE ticker == ? AND timeframe == ? AND timestamp > ? AND timestamp < ?;" [(toSql. T.unpack) tickerId, toSql timeframeSecs, (toSql . utcTimeToPOSIXSeconds) start, (toSql . utcTimeToPOSIXSeconds) end]
return $ DBData [(interval, V.fromList $ mapMaybe (barFromResult tickerId) rows)]
barFromResult ticker [ts, open, high, low, close, volume] = Just Bar {
barSecurity = ticker,
barTimestamp = fromSql ts,
barOpen = fromRational $ fromSql open,
barHigh = fromRational $ fromSql high,
barLow = fromRational $ fromSql low,
barClose = fromRational $ fromSql close,
barVolume = fromSql volume
}
barFromResult ticker [ts, _, open, high, low, close, vol] = Just Bar {
barSecurity = ticker,
barTimestamp = fromSql ts,
barOpen = fromDouble $ fromSql open,
barHigh = fromDouble $ fromSql high,
barLow = fromDouble $ fromSql low,
barClose = fromDouble $ fromSql close,
barVolume = fromSql vol
}
barFromResult _ _ = Nothing
putData :: MdsHandle -> TickerId -> TimeInterval -> Timeframe -> V.Vector Bar -> IO ()
putData db tickerId (TimeInterval start end) tf@(Timeframe tfSec) bars = do
delStmt <- prepare db "DELETE FROM bars WHERE timestamp >= ? AND timestamp <= ? AND ticker == ? AND timeframe == ?;"
void $ execute delStmt [(SqlPOSIXTime . utcTimeToPOSIXSeconds) start, (SqlPOSIXTime . utcTimeToPOSIXSeconds) end, (SqlString . T.unpack) tickerId, (SqlInteger . toInteger) tfSec]
stmt <- prepare db $ "INSERT INTO bars (ticker, timeframe, timestamp, open, high, low, close, volume)" ++
" values (?, ?, ?, ?, ?, ?, ?, ?); "
executeMany stmt (map (barToSql tf) $ V.toList bars)
where
barToSql :: Timeframe -> Bar -> [SqlValue]
barToSql (Timeframe timeframeSecs) bar = [(SqlString . T.unpack . barSecurity) bar,
(SqlInteger . toInteger) timeframeSecs,
(SqlRational . toRational . barOpen) bar,
(SqlRational . toRational . barHigh) bar,
(SqlRational . toRational . barLow) bar,
(SqlRational . toRational . barClose) bar,
(SqlRational . toRational . barVolume) bar ]
stopDatabase :: MVar () -> DatabaseInterface -> IO ()
stopDatabase compVar db = killThread (tid db) >> readMVar compVar
doGetData :: MVar DatabaseCommand -> MVar DatabaseResponse -> TickerId -> TimeInterval -> Timeframe -> IO [(TimeInterval, V.Vector Bar)]
doGetData cmdVar respVar tickerId timeInterval timeframe = do
putMVar cmdVar (DBGet tickerId timeInterval timeframe)
resp <- takeMVar respVar
case resp of
DBData x -> return x
DBError err -> do
warningM "DB.Client" $ "Error while calling getData: " ++ show err
return []
_ -> do
warningM "DB.Client" "Unexpected response"
return []
doPutData :: MVar DatabaseCommand -> MVar DatabaseResponse -> TickerId -> TimeInterval -> Timeframe -> V.Vector Bar -> IO ()
doPutData cmdVar respVar tickerId timeInterval timeframe bars = do
putMVar cmdVar (DBPut tickerId timeInterval timeframe bars)
resp <- takeMVar respVar
case resp of
DBOk -> return ()
DBError err -> do
warningM "DB.Client" $ "Error while calling putData: " ++ show err
return ()
_ -> do
warningM "DB.Client" "Unexpected response"
return ()
(SqlPOSIXTime . utcTimeToPOSIXSeconds . barTimestamp) bar,
(SqlDouble . toDouble . barOpen) bar,
(SqlDouble . toDouble . barHigh) bar,
(SqlDouble . toDouble . barLow) bar,
(SqlDouble . toDouble . barClose) bar,
(SqlInteger . barVolume) bar ]

32
src/ATrade/MDS/HistoryServer.hs

@ -1,14 +1,42 @@ @@ -1,14 +1,42 @@
module ATrade.MDS.HistoryServer (
HistoryServer,
startHistoryServer
) where
import System.ZMQ4
import ATrade.MDS.Database
import ATrade.MDS.Protocol
import Control.Concurrent
import Control.Monad
import Data.Aeson
import Data.List.NonEmpty
import qualified Data.Vector as V
import Safe
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
data HistoryServer = HistoryServer ThreadId
}
startHistoryServer :: DatabaseInterface -> Context -> IO HistoryServer
startHistoryServer :: MdsHandle -> Context -> IO HistoryServer
startHistoryServer db ctx = do
sock <- socket ctx Router
tid <- forkIO $ serve db sock
return $ HistoryServer tid
serve :: (Sender a, Receiver a) => MdsHandle -> Socket a -> IO ()
serve db sock = forever $ do
rq <- receiveMulti sock
let maybeCmd = (BL.fromStrict <$> rq `atMay` 2) >>= decode
case (headMay rq, maybeCmd) of
(Just peerId, Just cmd) -> handleCmd peerId cmd
_ -> return ()
where
handleCmd :: B.ByteString -> MDSRequest -> IO ()
handleCmd peerId cmd = case cmd of
rq -> do
qdata <- getData db (rqTicker rq) (TimeInterval (rqFrom rq) (rqTo rq)) (Timeframe (rqTimeframe rq))
bytes <- serializeBars $ V.concat $ fmap snd qdata
sendMulti sock $ peerId :| B.empty : bytes
serializeBars = undefined

23
src/ATrade/MDS/Protocol.hs

@ -0,0 +1,23 @@ @@ -0,0 +1,23 @@
{-# LANGUAGE DeriveGeneric #-}
module ATrade.MDS.Protocol (
MDSRequest(..)
) where
import GHC.Generics
import ATrade.Types
import Data.Aeson
import Data.Time.Clock
data MDSRequest = RequestData {
rqTicker :: TickerId,
rqFrom :: UTCTime,
rqTo :: UTCTime,
rqTimeframe :: Int
} deriving (Generic, Show, Eq)
instance ToJSON MDSRequest
instance FromJSON MDSRequest

6
src/Lib.hs

@ -1,6 +0,0 @@ @@ -1,6 +0,0 @@
module Lib
( someFunc
) where
someFunc :: IO ()
someFunc = putStrLn "someFunc"

5
stack.yaml

@ -15,7 +15,7 @@ @@ -15,7 +15,7 @@
# resolver:
# name: custom-snapshot
# location: "./custom-snapshot.yaml"
resolver: lts-7.4
resolver: lts-11.9
# User packages to be built.
# Various formats can be used as shown in the example below.
@ -38,9 +38,10 @@ resolver: lts-7.4 @@ -38,9 +38,10 @@ resolver: lts-7.4
packages:
- '.'
- '../libatrade'
- '../zeromq4-haskell-zap'
# Dependency packages to be pulled from upstream that are not in the resolver
# (e.g., acme-missiles-0.3)
extra-deps: ["HDBC-postgresql-2.3.2.4", "datetime-0.3.1"]
extra-deps: ["HDBC-sqlite3-2.3.3.1", "datetime-0.3.1"]
# Override default flag values for local packages and extra-deps
flags: {}

77
test/Integration/Database.hs

@ -0,0 +1,77 @@ @@ -0,0 +1,77 @@
module Integration.Database (
testDatabase
) where
import Test.Tasty
import Test.Tasty.HUnit
import ATrade.MDS.Database
import ATrade.Types
import Control.Exception
import Data.DateTime
import Data.Time.Clock
import qualified Data.Text as T
import qualified Data.Vector as V
import System.IO.Temp
testDatabase :: TestTree
testDatabase = testGroup "Database tests" [ testOpenClose
, testPutGet
, testGetReturnsSorted ]
testOpenClose :: TestTree
testOpenClose = testCase "Open/Close" $
withSystemTempDirectory "test" $ \fp -> do
let dbConfig = DatabaseConfig (T.pack $ fp ++ "/test.db") T.empty T.empty T.empty
db <- initDatabase dbConfig
closeDatabase db
bar :: UTCTime -> Price -> Price -> Price -> Price -> Integer -> Bar
bar dt o h l c v = Bar { barSecurity = "FOO",
barTimestamp = dt,
barOpen = o,
barHigh = h,
barLow = l,
barClose = c,
barVolume = v }
testPutGet :: TestTree
testPutGet = testCase "Put/Get" $
withSystemTempDirectory "test" $ \fp -> do
let dbConfig = DatabaseConfig (T.pack $ fp ++ "/test.db") T.empty T.empty T.empty
bracket (initDatabase dbConfig) closeDatabase $ \db -> do
putData db "FOO" interval (timeframeMinute 1) bars
retrievedBars <- (snd . head) <$> getData db "FOO" interval (timeframeMinute 1)
assertEqual "Retreived bars are different from saved" bars retrievedBars
where
interval = TimeInterval (fromGregorian 2010 1 1 12 0 0) (fromGregorian 2010 1 1 12 5 0)
bars = V.fromList $ [
bar (fromGregorian 2010 1 1 12 0 0) 10 11 9 10 1,
bar (fromGregorian 2010 1 1 12 1 0) 12 15 9 10 1,
bar (fromGregorian 2010 1 1 12 2 0) 13 15 9 12 1
]
testGetReturnsSorted :: TestTree
testGetReturnsSorted = testCase "Get returns sorted vector" $
withSystemTempDirectory "test" $ \fp -> do
let dbConfig = DatabaseConfig (T.pack $ fp ++ "/test.db") T.empty T.empty T.empty
bracket (initDatabase dbConfig) closeDatabase $ \db -> do
putData db "FOO" interval (timeframeMinute 1) bars
retrievedBars <- (snd . head) <$> getData db "FOO" interval (timeframeMinute 1)
assertEqual "Retreived bars are not sorted" sortedBars retrievedBars
where
interval = TimeInterval (fromGregorian 2010 1 1 12 0 0) (fromGregorian 2010 1 1 12 5 0)
bars = V.fromList $ [
bar (fromGregorian 2010 1 1 12 0 0) 10 11 9 10 1,
bar (fromGregorian 2010 1 1 12 2 0) 13 15 9 12 1,
bar (fromGregorian 2010 1 1 12 1 0) 12 15 9 10 1
]
sortedBars = V.fromList $ [
bar (fromGregorian 2010 1 1 12 0 0) 10 11 9 10 1,
bar (fromGregorian 2010 1 1 12 1 0) 12 15 9 10 1,
bar (fromGregorian 2010 1 1 12 2 0) 13 15 9 12 1
]

12
test/Integration/Spec.hs

@ -0,0 +1,12 @@ @@ -0,0 +1,12 @@
module Integration.Spec (
integrationTests
) where
import Test.Tasty
import Test.Tasty.HUnit
import Integration.Database
integrationTests = testGroup "Integration tests" [ testDatabase ]

11
test/Spec.hs

@ -1,2 +1,11 @@ @@ -1,2 +1,11 @@
import Test.Tasty
import Test.Tasty.HUnit
import Integration.Spec
main :: IO ()
main = putStrLn "Test suite not yet implemented"
main = defaultMain tests
tests :: TestTree
tests = testGroup "Tests" [ integrationTests ]

Loading…
Cancel
Save