diff --git a/data/signal.py b/data/signal.py
index 1cb7e89..bb13f64 100644
--- a/data/signal.py
+++ b/data/signal.py
@@ -1,9 +1,11 @@
'''
'''
+import copy
import random
-import talib
import numpy
+import talib
+
class Signal:
@@ -16,6 +18,9 @@ class Signal:
def get_text(self):
pass
+ def mutate(self, factor):
+ return copy.deepcopy(self)
+
class PriceComparisonSignalGenerator:
def __init__(self):
@@ -59,7 +64,7 @@ class PriceComparisonSignal(Signal):
elif self.lhs == PriceComparisonSignal.CLOSE:
lhs = series.get_close(index - self.lhs_shift)
else:
- raise Exception('Invalid lhs type')
+ raise Exception('Invalid lhs type: ' + str(self.lhs))
if self.rhs == PriceComparisonSignal.OPEN:
rhs = series.get_open(index - self.rhs_shift)
@@ -92,6 +97,18 @@ class PriceComparisonSignal(Signal):
else:
return "??"
+ def mutate(self, factor):
+ mutation_type = random.randint(0, 3)
+ if mutation_type == 0:
+ self.lhs = random.randint(PriceComparisonSignal.OPEN, PriceComparisonSignal.CLOSE)
+ elif mutation_type == 1:
+ self.rhs = random.randint(PriceComparisonSignal.OPEN, PriceComparisonSignal.CLOSE)
+ elif mutation_type == 2:
+ self.lhs_shift = max(0, self.lhs_shift + random.randint(-int(10 * factor + 1), int(10 * factor + 1)))
+ elif mutation_type == 3:
+ self.rhs_shift = max(0, self.rhs_shift + random.randint(-int(10 * factor + 1), int(10 * factor + 1)))
+
+
class RsiSignalGenerator:
@@ -99,20 +116,22 @@ class RsiSignalGenerator:
pass
def generate(self):
+ shift = random.randint(0, 5)
period = random.randint(2, 30)
threshold = random.randrange(1, 9) * 10
ineq_type = random.randint(RsiSignal.LT, RsiSignal.GT)
- return RsiSignal(period, threshold, ineq_type)
+ return RsiSignal(period, threshold, ineq_type, shift)
class RsiSignal(Signal):
LT = 0
GT = 1
- def __init__(self, period, threshold, inequality_type):
+ def __init__(self, period, threshold, inequality_type, shift):
self.period = period
self.threshold = threshold
self.inequality_type = inequality_type
+ self.shift = shift
if inequality_type == RsiSignal.LT:
self.inequality_sign_str = '<'
else:
@@ -124,7 +143,10 @@ class RsiSignal(Signal):
rsi = talib.RSI(closes, self.period)
result = [self.calc_signal(v) for v in rsi]
- return result
+ if self.shift == 0:
+ return result
+ else:
+ return [False] * self.shift + result[:-self.shift]
def calc_signal(self, value):
if self.inequality_type == RsiSignal.LT:
@@ -133,10 +155,17 @@ class RsiSignal(Signal):
return value > self.threshold
def get_text(self):
- return "rsi(c, " + str(self.period) + ') ' + self.inequality_sign_str + ' ' + str(self.threshold)
+ return "rsi(c, {:d})[{:d}] {:s} {:d}".format(self.period, self.shift, self.inequality_sign_str, int(self.threshold))
+
+ def mutate(self, factor):
+ mutation_type = random.randint(0, 2)
+ if mutation_type == 0:
+ self.period = max(2, self.period + random.randint(-int(10 * factor + 1), int(10 * factor + 1)))
+ elif mutation_type == 1:
+ self.threshold = random.randrange(1, 9) * 10
+ elif mutation_type == 2:
+ random.randint(RsiSignal.LT, RsiSignal.GT)
-
-
class AtrSignalGenerator:
def __init__(self):
@@ -157,7 +186,7 @@ class AtrSignal(Signal):
self.period = period
self.threshold_factor = threshold_factor
self.inequality_type = inequality_type
- if inequality_type == RsiSignal.LT:
+ if inequality_type == AtrSignal.LT:
self.inequality_sign_str = '<'
else:
self.inequality_sign_str = '>'
@@ -179,8 +208,454 @@ class AtrSignal(Signal):
return value > self.threshold_factor * close
def get_text(self):
- return "atr(" + str(self.period) + ') ' + self.inequality_sign_str + ' close[0] * ' + "{:.3f}".format(self.threshold_factor)
+ return "atr(" + str(self.period) + ') ' + self.inequality_sign_str + ' close[0] * ' + "{:.3f}".format(self.threshold_factor)
+
+ def mutate(self, factor):
+ mutation_type = random.randint(0, 2)
+ if mutation_type == 0:
+ self.period = max(2, self.period + random.randint(-int(10 * factor + 1), int(10 * factor + 1)))
+ elif mutation_type == 1:
+ self.threshold_factor = random.randint(1, 30) * 0.001
+ elif mutation_type == 2:
+ random.randint(AtrSignal.LT, AtrSignal.GT)
+
+class AtrDeltaSignalGenerator:
+
+ def __init__(self):
+ pass
+
+ def generate(self):
+ period = random.randint(2, 30)
+ threshold = random.randint(1, 15) * 0.2
+ ineq_type = random.randint(AtrDeltaSignal.LT, AtrDeltaSignal.GT)
+ sign = random.randint(AtrDeltaSignal.PLUS, AtrDeltaSignal.MINUS)
+ return AtrDeltaSignal(period, threshold, ineq_type, sign)
+
+class AtrDeltaSignal(Signal):
+
+ LT = 0
+ GT = 1
+
+ PLUS = 0
+ MINUS = 1
+
+ def __init__(self, period, threshold_factor, inequality_type, sign):
+ self.period = period
+ self.threshold_factor = threshold_factor
+ self.inequality_type = inequality_type
+ if inequality_type == AtrDeltaSignal.LT:
+ self.inequality_sign_str = '<'
+ else:
+ self.inequality_sign_str = '>'
+
+ self.sign = sign
+ if self.sign == AtrDeltaSignal.PLUS:
+ self.sign_str = '+'
+ else:
+ self.sign_str = '-'
+
+ def calculate(self, series):
+ closes = numpy.array([series.get_close(i) for i in range(0, series.length())])
+ highs = numpy.array([series.get_high(i) for i in range(0, series.length())])
+ lows = numpy.array([series.get_low(i) for i in range(0, series.length())])
+
+ atr = talib.ATR(highs, lows, closes, self.period)
+
+ result = [False]
+
+ for i in range(1, len(closes)):
+ result.append(self.calc_signal(atr[i], closes[i], closes[i - 1]))
+
+ return result
+
+ def calc_signal(self, value, c1, c2):
+ if self.sign == AtrDeltaSignal.PLUS:
+ if self.inequality_type == AtrDeltaSignal.LT:
+ return c1 < c2 + value * self.threshold_factor
+ else:
+ return c1 > c2 + value * self.threshold_factor
+ else:
+ if self.inequality_type == AtrDeltaSignal.LT:
+ return c1 < c2 - value * self.threshold_factor
+ else:
+ return c1 > c2 - value * self.threshold_factor
+
+ def get_text(self):
+ return 'close[0] {:s} close[1] {:s} atr({:d}) * {:f}'.format(self.inequality_sign_str, self.sign_str, self.period, self.threshold_factor)
+
+ def mutate(self, factor):
+ mutation_type = random.randint(0, 2)
+ if mutation_type == 0:
+ self.period = max(2, self.period + random.randint(-int(10 * factor + 1), int(10 * factor + 1)))
+ elif mutation_type == 1:
+ self.threshold_factor = random.randint(1, 30) * 0.001
+ elif mutation_type == 2:
+ random.randint(AtrDeltaSignal.LT, AtrDeltaSignal.GT)
+
+
+class DayOfWeekSignalGenerator:
+
+ def __init__(self):
+ pass
+
+ def generate(self):
+ dow = random.randint(0, 6)
+ return DayOfWeekSignal(dow)
+
+
+class DayOfWeekSignal(Signal):
+
+ def __init__(self, day_of_week):
+ self.day_of_week = day_of_week
+
+ def calculate(self, series):
+ result = []
+ for i in range(0, series.length()):
+ result.append(series.get_dt(i).date().weekday() == self.day_of_week)
+
+ return result
+
+ def get_text(self):
+ return "day_of_week == " + self.dow_str(self.day_of_week)
+
+ def dow_str(self, dow):
+ return ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday'][dow]
+
+ def mutate(self, factor):
+ self.day_of_week = random.randint(0, 6)
+
+
+class DayOfMonthSignalGenerator:
+
+ def __init__(self):
+ pass
+
+ def generate(self):
+ month_day = random.randint(1, 31)
+ ineq_type = random.randint(DayOfMonthSignal.LT, DayOfMonthSignal.GT)
+ return DayOfMonthSignal(month_day, ineq_type)
+
+class DayOfMonthSignal(Signal):
+
+ LT = 0
+ GT = 1
+ def __init__(self, day_of_month, inequality_type):
+ self.day_of_month = day_of_month
+ self.inequality_type = inequality_type
+ if inequality_type == CrtdrSignal.LT:
+ self.inequality_sign_str = '<'
+ else:
+ self.inequality_sign_str = '>'
-
\ No newline at end of file
+ def calculate(self, series):
+ result = []
+ for i in range(0, series.length()):
+ if self.inequality_type == DayOfMonthSignal.LT:
+ result.append(series.get_dt(i).date().day < self.day_of_month)
+ else:
+ result.append(series.get_dt(i).date().day > self.day_of_month)
+
+ return result
+
+ def get_text(self):
+ return "day_of_month {:s} {:d}".format(self.inequality_sign_str, self.day_of_month)
+
+ def mutate(self, factor):
+ self.day_of_month = random.randint(1, 31)
+
+
+class CrtdrSignalGenerator:
+
+ def __init__(self):
+ pass
+
+ def generate(self):
+ shift = random.randint(0, 10)
+ ineq_type = random.randint(CrtdrSignal.LT, CrtdrSignal.GT)
+ threshold = 0.05 * random.randint(1, 19)
+ return CrtdrSignal(shift, threshold, ineq_type)
+
+class CrtdrSignal(Signal):
+
+ LT = 0
+ GT = 1
+
+ def __init__(self, shift, threshold, inequality_type):
+ self.shift = shift
+ self.threshold = threshold
+ self.inequality_type = inequality_type
+ if inequality_type == CrtdrSignal.LT:
+ self.inequality_sign_str = '<'
+ else:
+ self.inequality_sign_str = '>'
+
+ def calculate(self, series):
+ result = []
+ for i in range(0, series.length()):
+ try:
+ h = series.get_high(i - self.shift)
+ l = series.get_low(i - self.shift)
+ c = series.get_close(i - self.shift)
+
+ if h > l:
+ if self.inequality_type == CrtdrSignal.LT:
+ result.append((c - l) / (h - l) < self.threshold)
+ else:
+ result.append((c - l) / (h - l) > self.threshold)
+ else:
+ result.append(False)
+ except IndexError:
+ result.append(False)
+
+ return result
+
+ def get_text(self):
+ return 'crtdr[{:d}] {:s} {:.2f}'.format(self.shift, self.inequality_sign_str, self.threshold)
+
+ def mutate(self, factor):
+ mutation_type = random.randint(0, 2)
+ if mutation_type == 0:
+ self.shift = max(0, self.shift + random.randint(-int(10 * factor + 1), int(10 * factor + 1)))
+ elif mutation_type == 1:
+ self.threshold = 0.05 * random.randint(1, 19)
+ elif mutation_type == 2:
+ random.randint(CrtdrSignal.LT, CrtdrSignal.GT)
+
+
+class SmaSignalGenerator:
+
+ def __init__(self):
+ pass
+
+ def generate(self):
+ lhs = random.randint(SmaSignal.OPEN, SmaSignal.CLOSE)
+ period = random.randint(2, 30)
+ rhs = random.randint(SmaSignal.OPEN, SmaSignal.CLOSE)
+ ineq_sign = random.randint(SmaSignal.LT, SmaSignal.GT)
+ return SmaSignal(period, lhs, rhs, ineq_sign)
+
+class SmaSignal(Signal):
+
+ OPEN = 0
+ HIGH = 1
+ LOW = 2
+ CLOSE = 3
+
+ LT = 0
+ GT = 1
+
+ def __init__(self, period, lhs, rhs, inequality_type):
+ self.period = period
+ self.lhs = lhs
+ self.rhs = rhs
+ self.inequality_type = inequality_type
+ if inequality_type == SmaSignal.LT:
+ self.inequality_sign_str = '<'
+ else:
+ self.inequality_sign_str = '>'
+
+ def calculate(self, series):
+ if self.lhs == SmaSignal.OPEN:
+ lhs = numpy.array([series.get_open(i) for i in range(0, series.length())])
+ elif self.lhs == SmaSignal.HIGH:
+ lhs = numpy.array([series.get_high(i) for i in range(0, series.length())])
+ elif self.lhs == SmaSignal.LOW:
+ lhs = numpy.array([series.get_low(i) for i in range(0, series.length())])
+ elif self.lhs == SmaSignal.CLOSE:
+ lhs = numpy.array([series.get_close(i) for i in range(0, series.length())])
+
+ if self.rhs == SmaSignal.OPEN:
+ rhs = numpy.array([series.get_open(i) for i in range(0, series.length())])
+ elif self.rhs == SmaSignal.HIGH:
+ rhs = numpy.array([series.get_high(i) for i in range(0, series.length())])
+ elif self.rhs == SmaSignal.LOW:
+ rhs = numpy.array([series.get_low(i) for i in range(0, series.length())])
+ elif self.rhs == SmaSignal.CLOSE:
+ rhs = numpy.array([series.get_close(i) for i in range(0, series.length())])
+
+ rhs_sma = talib.SMA(rhs, self.period)
+
+ result = [self.calc_signal(l, r) for (l, r) in zip(lhs, rhs_sma)]
+ return result
+
+ def calc_signal(self, l, r):
+ if self.inequality_type == SmaSignal.LT:
+ return l < r
+ else:
+ return l > r
+
+ def get_text(self):
+ return '{:s}[0] {:s} sma({:s}, {:d})'.format(self.component_to_str(self.lhs), self.inequality_sign_str, self.component_to_str(self.rhs), self.period)
+
+ def component_to_str(self, component):
+ if component == SmaSignal.OPEN:
+ return "open"
+ elif component == SmaSignal.HIGH:
+ return "high"
+ elif component == SmaSignal.LOW:
+ return "low"
+ elif component == SmaSignal.CLOSE:
+ return "close"
+ else:
+ return "??"
+
+
+class CciSignalGenerator:
+
+ def __init__(self):
+ pass
+
+ def generate(self):
+ period = random.randint(2, 30)
+ shift = random.randint(0, 5)
+ threshold = random.randint(1, 30) * 10
+ ineq_type = random.randint(CciSignal.LT, CciSignal.GT)
+ return CciSignal(shift, period, threshold, ineq_type)
+
+class CciSignal(Signal):
+
+ LT = 0
+ GT = 1
+
+ def __init__(self, shift, period, threshold, inequality_type):
+ self.shift = shift
+ self.period = period
+ self.threshold = threshold
+ self.inequality_type = inequality_type
+ if inequality_type == CciSignal.LT:
+ self.inequality_sign_str = '<'
+ else:
+ self.inequality_sign_str = '>'
+
+ def calculate(self, series):
+ closes = numpy.array([series.get_close(i) for i in range(0, series.length())])
+ highs = numpy.array([series.get_high(i) for i in range(0, series.length())])
+ lows = numpy.array([series.get_low(i) for i in range(0, series.length())])
+
+ cci = talib.CCI(highs, lows, closes, self.period)
+
+ result = [self.calc_signal(v) for v in cci]
+
+ if self.shift == 0:
+ return result
+ else:
+ return [False] * self.shift + result[:-self.shift]
+
+ def calc_signal(self, value):
+ if self.inequality_type == CciSignal.LT:
+ return value < self.threshold
+ else:
+ return value > self.threshold
+
+ def get_text(self):
+ return "cci({:d})[{:d}] {:s} {:f}".format(self.period, self.shift, self.inequality_sign_str, self.threshold)
+
+ def mutate(self, factor):
+ mutation_type = random.randint(0, 2)
+ if mutation_type == 0:
+ self.period = max(2, self.period + random.randint(-int(10 * factor + 1), int(10 * factor + 1)))
+ elif mutation_type == 1:
+ self.threshold_factor = random.randint(1, 30) * 0.001
+ elif mutation_type == 2:
+ random.randint(AtrSignal.LT, AtrSignal.GT)
+
+class BbandsSignalGenerator:
+
+ def __init__(self):
+ pass
+
+ def generate(self):
+ lhs = random.randint(BbandsSignal.OPEN, BbandsSignal.CLOSE)
+ period = random.randint(2, 30)
+ rhs = random.randint(BbandsSignal.OPEN, BbandsSignal.CLOSE)
+ ineq_sign = random.randint(BbandsSignal.LT, BbandsSignal.GT)
+ dev = random.randint(1, 6) * 0.5
+ band_type = random.randint(BbandsSignal.UP, BbandsSignal.DOWN)
+ return BbandsSignal(period, lhs, rhs, dev, ineq_sign, band_type)
+
+
+class BbandsSignal(Signal):
+
+ OPEN = 0
+ HIGH = 1
+ LOW = 2
+ CLOSE = 3
+
+ LT = 0
+ GT = 1
+
+ UP = 0
+ DOWN = 1
+
+ def __init__(self, period, lhs, rhs, dev, inequality_type, band_type):
+ self.period = period
+ self.lhs = lhs
+ self.rhs = rhs
+ self.dev = dev
+ self.inequality_type = inequality_type
+ self.band_type = band_type
+
+ if inequality_type == SmaSignal.LT:
+ self.inequality_sign_str = '<'
+ else:
+ self.inequality_sign_str = '>'
+
+ if inequality_type == BbandsSignal.UP:
+ self.band_type_str = 'up'
+ else:
+ self.band_type_str = 'down'
+
+ def calculate(self, series):
+ if self.lhs == BbandsSignal.OPEN:
+ lhs = numpy.array([series.get_open(i) for i in range(0, series.length())])
+ elif self.lhs == BbandsSignal.HIGH:
+ lhs = numpy.array([series.get_high(i) for i in range(0, series.length())])
+ elif self.lhs == BbandsSignal.LOW:
+ lhs = numpy.array([series.get_low(i) for i in range(0, series.length())])
+ elif self.lhs == BbandsSignal.CLOSE:
+ lhs = numpy.array([series.get_close(i) for i in range(0, series.length())])
+
+ if self.rhs == BbandsSignal.OPEN:
+ rhs = numpy.array([series.get_open(i) for i in range(0, series.length())])
+ elif self.rhs == BbandsSignal.HIGH:
+ rhs = numpy.array([series.get_high(i) for i in range(0, series.length())])
+ elif self.rhs == BbandsSignal.LOW:
+ rhs = numpy.array([series.get_low(i) for i in range(0, series.length())])
+ elif self.rhs == BbandsSignal.CLOSE:
+ rhs = numpy.array([series.get_close(i) for i in range(0, series.length())])
+
+ (upband, _, downband) = talib.BBANDS(rhs, self.period, self.dev, self.dev)
+
+
+
+ if self.band_type == BbandsSignal.UP:
+ band = upband
+ else:
+ band = downband
+
+ result = [self.calc_signal(l, r) for (l, r) in zip(lhs, band)]
+ return result
+
+ def calc_signal(self, l, r):
+ if self.inequality_type == BbandsSignal.LT:
+ return l < r
+ else:
+ return l > r
+
+ def get_text(self):
+ return '{:s}[0] {:s} bband({:s}, {:s}, {:.2f}, {:d})'.format(self.component_to_str(self.lhs), self.inequality_sign_str, self.component_to_str(self.rhs), self.band_type_str, self.dev, self.period)
+
+ def component_to_str(self, component):
+ if component == BbandsSignal.OPEN:
+ return "open"
+ elif component == BbandsSignal.HIGH:
+ return "high"
+ elif component == BbandsSignal.LOW:
+ return "low"
+ elif component == BbandsSignal.CLOSE:
+ return "close"
+ else:
+ return "??"
+
\ No newline at end of file
diff --git a/execution/executor.py b/execution/executor.py
index 9ba9b7a..d2f67ca 100644
--- a/execution/executor.py
+++ b/execution/executor.py
@@ -15,7 +15,7 @@ class Executor(object):
self.series = series
self.max_hold_bars = 1
- def execute(self, signals):
+ def execute(self, signals, long=True):
self.trades = []
sig_vectors = []
vec_length = 0
@@ -30,6 +30,7 @@ class Executor(object):
in_trade = False
current_entry_price = None
bar_counter = 0
+ entry_bar = 0
for i in range(0, vec_length):
if not in_trade:
@@ -42,11 +43,16 @@ class Executor(object):
if has_signal and i + 1 < vec_length:
in_trade = True
current_entry_price = self.series.get_open(i + 1)
+ entry_bar = i + 1
bar_counter = 0
else:
bar_counter += 1
if bar_counter >= self.max_hold_bars:
in_trade = False
- self.trades.append(Trade(current_entry_price, self.series.get_close(i)))
+ if long:
+ trade_dir = Trade.LONG
+ else:
+ trade_dir = Trade.SHORT
+ self.trades.append(Trade(current_entry_price, self.series.get_close(i), entry_bar, i, trade_dir))
return self.trades
\ No newline at end of file
diff --git a/execution/trade.py b/execution/trade.py
index c24b7da..c108874 100644
--- a/execution/trade.py
+++ b/execution/trade.py
@@ -8,7 +8,7 @@ class Trade:
LONG = 1
SHORT = 2
- def __init__(self, entry_price, exit_price, direction = LONG):
+ def __init__(self, entry_price, exit_price, entry_bar, exit_bar, direction = LONG):
'''
Constructor
'''
@@ -16,6 +16,8 @@ class Trade:
self.direction = direction
self.entry_price = entry_price
self.exit_price = exit_price
+ self.entry_bar = entry_bar
+ self.exit_bar = exit_bar
def pnl(self):
diff --git a/solver/solver.py b/solver/solver.py
index ecea999..38c7bb9 100644
--- a/solver/solver.py
+++ b/solver/solver.py
@@ -6,17 +6,23 @@ from execution.executor import Executor
import random
from math import inf
import numpy
+from PyQt5.Qt import pyqtSignal, QObject
+from PyQt5 import QtCore
+import talib
-class Solver():
+class Solver(QObject):
'''
'''
-
+
+ progress = pyqtSignal(int, int, name='progress')
+ done = pyqtSignal(list, name='done')
def __init__(self, series):
'''
Constructor
'''
+ super().__init__(None)
self.series = series
self.executor = Executor(series)
@@ -24,30 +30,49 @@ class Solver():
def add_generator(self, generator):
self.generators.append(generator)
+
+ @QtCore.pyqtSlot(dict)
+ def solve(self, params):
+ max_signals = 5
+ max_strategies = params.get('num_strategies', 1000)
+ results = []
- def solve(self):
- max_signals = 3
- max_strategies = 1000
- self.results = []
+ min_trades = params.get('min_trades', 0)
+ min_win_rate = params.get('min_win_rate', 0)
+ min_sharpe = params.get('min_sharpe', 0)
- for x in range(0, max_strategies):
+ is_long = params.get('direction', 'long') == 'long'
+ print(is_long)
+ while len(results) < max_strategies:
sig_num = random.randint(1, max_signals)
strategy = []
for i in range(0, sig_num):
strategy.append(random.choice(self.generators).generate())
- trades = self.executor.execute(strategy)
- result = self.evaluate_trades(trades)
- result['display_name'] = ' && '.join([signal.get_text() for signal in strategy])
- self.results.append(result)
+ trades = self.executor.execute(strategy, is_long)
+ if len(trades) >= min_trades:
+ result = self.evaluate_trades(trades)
+ if result['win_percentage'] > min_win_rate and result['sharpe'] > min_sharpe:
+ result['strategy'] = strategy
+ result['display_name'] = ' && '.join([signal.get_text() for signal in strategy])
+ result['trades'] = trades
+ results.append(result)
+ self.progress.emit(len(results), max_strategies)
- return self.results
+ self.done.emit(results)
def evaluate_trades(self, trades):
result = {}
profits = [x.pnl() for x in trades]
+ total_won = len(list(filter(lambda x: x.pnl() > 0, trades)))
+
+ if len(trades) > 0:
+ result['win_percentage'] = total_won / len(trades) * 100
+ else:
+ result['win_percentage'] = 0
+
result['trades_number'] = len(trades)
result['total_pnl'] = sum(profits)
diff --git a/ui/mainwindow.py b/ui/mainwindow.py
index b7fbec9..1524755 100644
--- a/ui/mainwindow.py
+++ b/ui/mainwindow.py
@@ -7,8 +7,13 @@ from .ui_mainwindow import Ui_MainWindow
from solver.solver import Solver
from data.series import Series
from data.signal import PriceComparisonSignalGenerator, RsiSignalGenerator,\
- AtrSignalGenerator
-from PyQt5.Qt import Qt
+ AtrSignalGenerator, DayOfWeekSignalGenerator, CrtdrSignalGenerator,\
+ AtrDeltaSignalGenerator, SmaSignalGenerator, DayOfMonthSignalGenerator,\
+ CciSignalGenerator, BbandsSignalGenerator
+from PyQt5.Qt import Qt, QFileDialog, QThread, Q_ARG, QMetaObject
+
+import pyqtgraph
+import numpy
class MainWindow(QMainWindow, Ui_MainWindow):
'''
@@ -20,24 +25,79 @@ class MainWindow(QMainWindow, Ui_MainWindow):
'''
super().__init__(parent)
self.setupUi(self)
-
+ self.work_thread = QThread()
+ self.work_thread.start()
+
+ def browse(self):
+ fname = QFileDialog.getOpenFileName(self, 'Open file')
+ if fname[0] != '':
+ self.e_filename.setText(fname[0])
+
+ def go(self):
+ self.tw_strategies.clear()
self.series = Series()
- self.series.load_from_finam_csv('/home/asakul/tmp/daily/RTSI_20000101_20171231_daily.csv')
+ self.series.load_from_finam_csv(self.e_filename.text())
self.solver = Solver(self.series)
self.solver.add_generator(PriceComparisonSignalGenerator())
self.solver.add_generator(RsiSignalGenerator())
self.solver.add_generator(AtrSignalGenerator())
- results = self.solver.solve()
+ self.solver.add_generator(AtrDeltaSignalGenerator())
+ self.solver.add_generator(DayOfWeekSignalGenerator())
+ self.solver.add_generator(DayOfMonthSignalGenerator())
+ self.solver.add_generator(SmaSignalGenerator())
+ self.solver.add_generator(CrtdrSignalGenerator())
+ self.solver.add_generator(CciSignalGenerator())
+ self.solver.add_generator(BbandsSignalGenerator())
+ params = { 'num_strategies' : self.sb_strategiesNum.value() }
+ if self.cb_minTradesFilter.isChecked():
+ params['min_trades'] = self.sb_minTrades.value()
+
+ if self.cb_minWinRate.isChecked():
+ params['min_win_rate'] = self.sb_minWinRate.value()
+ if self.cb_minSharpe.isChecked():
+ params['min_sharpe'] = self.sb_minSharpe.value()
+
+ if self.rb_long.isChecked():
+ params['direction'] = 'long'
+ else:
+ params['direction'] = 'short'
+
+ self.solver.done.connect(self.done)
+ self.solver.progress.connect(self.progress)
+
+ self.solver.moveToThread(self.work_thread)
+ QMetaObject.invokeMethod(self.solver, 'solve', Q_ARG(dict, params))
+ #results = self.solver.solve(params)
+
+ def done(self, results):
for result in results:
item = QTreeWidgetItem(self.tw_strategies)
item.setText(0, result['display_name'])
item.setText(1, str(result['trades_number']))
- item.setText(2, str(result['total_pnl']))
+ item.setText(2, "{:.4f}".format(result['total_pnl']))
item.setText(3, "{:.2f}".format(result['profit_factor']))
item.setText(4, "{:.2f}".format(result['sharpe']))
item.setText(5, "{:.2f}%".format(result['avg_percentage']))
+ item.setText(6, "{:.2f}%".format(result['win_percentage']))
item.setData(0, Qt.UserRole + 1, result)
+ for i in range(0, 7):
+ self.tw_strategies.resizeColumnToContents(i)
+
+ def progress(self, current, total):
+ if current < total:
+ self.pb_progress.setValue(float(current) / total * 100)
+ else:
+ self.pb_progress.setValue(100)
+
+
def strategyClicked(self, item, column):
- result = item.getData(0, Qt.UserRole + 1)
\ No newline at end of file
+ result = item.data(0, Qt.UserRole + 1)
+ pnl = numpy.cumsum([trade.pnl() for trade in result['trades']])
+ xs = [trade.entry_bar for trade in result['trades']]
+ pyqtgraph.plot(xs, pnl)
+
+ for trade in result['trades']:
+ print(trade.entry_bar, trade.entry_price, trade.exit_price)
+
diff --git a/ui/mainwindow.ui b/ui/mainwindow.ui
index 5b7eff9..b036c26 100644
--- a/ui/mainwindow.ui
+++ b/ui/mainwindow.ui
@@ -6,7 +6,7 @@
0
0
- 800
+ 1041
600
@@ -27,7 +27,87 @@
2
- -
+
-
+
+
+ Min. sharpe
+
+
+
+ -
+
+
+ 0
+
+
+
+ -
+
+
+ 100
+
+
+ 50
+
+
+
+ -
+
+
+ -
+
+
+ #Strategies to generate:
+
+
+
+ -
+
+
+ 1
+
+
+ 10000
+
+
+ 100
+
+
+
+ -
+
+
+ Go
+
+
+
+ -
+
+
+ 100
+
+
+ 10000
+
+
+ 1000
+
+
+
+ -
+
+
+ Qt::Horizontal
+
+
+
+ 40
+ 20
+
+
+
+
+ -
true
@@ -62,6 +142,59 @@
Avg. %
+
+
+ Win %
+
+
+
+
+ -
+
+
+ Long
+
+
+ true
+
+
+
+ -
+
+
+ Data source:
+
+
+
+ -
+
+
+ Browse
+
+
+
+ -
+
+
+ Min. trades:
+
+
+
+ -
+
+
+ Min. win rate
+
+
+
+ -
+
+
+ -
+
+
+ Short
+
@@ -71,7 +204,7 @@
0
0
- 800
+ 1041
27
@@ -96,8 +229,42 @@
+
+ b_browse
+ clicked()
+ MainWindow
+ browse()
+
+
+ 776
+ 42
+
+
+ 706
+ 56
+
+
+
+
+ e_go
+ clicked()
+ MainWindow
+ go()
+
+
+ 768
+ 69
+
+
+ 680
+ 84
+
+
+
strategyClicked(QTreeWidgetItem*,int)
+ browse()
+ go()
diff --git a/ui/ui_mainwindow.py b/ui/ui_mainwindow.py
index 91bed5a..36fb68c 100644
--- a/ui/ui_mainwindow.py
+++ b/ui/ui_mainwindow.py
@@ -11,18 +11,75 @@ from PyQt5 import QtCore, QtGui, QtWidgets
class Ui_MainWindow(object):
def setupUi(self, MainWindow):
MainWindow.setObjectName("MainWindow")
- MainWindow.resize(800, 600)
+ MainWindow.resize(1041, 600)
self.centralwidget = QtWidgets.QWidget(MainWindow)
self.centralwidget.setObjectName("centralwidget")
self.gridLayout = QtWidgets.QGridLayout(self.centralwidget)
self.gridLayout.setContentsMargins(2, 2, 2, 2)
self.gridLayout.setObjectName("gridLayout")
+ self.cb_minSharpe = QtWidgets.QCheckBox(self.centralwidget)
+ self.cb_minSharpe.setObjectName("cb_minSharpe")
+ self.gridLayout.addWidget(self.cb_minSharpe, 1, 6, 1, 1)
+ self.pb_progress = QtWidgets.QProgressBar(self.centralwidget)
+ self.pb_progress.setProperty("value", 0)
+ self.pb_progress.setObjectName("pb_progress")
+ self.gridLayout.addWidget(self.pb_progress, 5, 0, 1, 10)
+ self.sb_minWinRate = QtWidgets.QSpinBox(self.centralwidget)
+ self.sb_minWinRate.setMaximum(100)
+ self.sb_minWinRate.setProperty("value", 50)
+ self.sb_minWinRate.setObjectName("sb_minWinRate")
+ self.gridLayout.addWidget(self.sb_minWinRate, 1, 5, 1, 1)
+ self.e_filename = QtWidgets.QLineEdit(self.centralwidget)
+ self.e_filename.setObjectName("e_filename")
+ self.gridLayout.addWidget(self.e_filename, 0, 1, 1, 8)
+ self.label_2 = QtWidgets.QLabel(self.centralwidget)
+ self.label_2.setObjectName("label_2")
+ self.gridLayout.addWidget(self.label_2, 1, 0, 1, 1)
+ self.sb_minTrades = QtWidgets.QSpinBox(self.centralwidget)
+ self.sb_minTrades.setMinimum(1)
+ self.sb_minTrades.setMaximum(10000)
+ self.sb_minTrades.setProperty("value", 100)
+ self.sb_minTrades.setObjectName("sb_minTrades")
+ self.gridLayout.addWidget(self.sb_minTrades, 1, 3, 1, 1)
+ self.e_go = QtWidgets.QPushButton(self.centralwidget)
+ self.e_go.setObjectName("e_go")
+ self.gridLayout.addWidget(self.e_go, 1, 9, 1, 1)
+ self.sb_strategiesNum = QtWidgets.QSpinBox(self.centralwidget)
+ self.sb_strategiesNum.setMinimum(100)
+ self.sb_strategiesNum.setMaximum(10000)
+ self.sb_strategiesNum.setProperty("value", 1000)
+ self.sb_strategiesNum.setObjectName("sb_strategiesNum")
+ self.gridLayout.addWidget(self.sb_strategiesNum, 1, 1, 1, 1)
+ spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
+ self.gridLayout.addItem(spacerItem, 1, 8, 1, 1)
self.tw_strategies = QtWidgets.QTreeWidget(self.centralwidget)
self.tw_strategies.setObjectName("tw_strategies")
- self.gridLayout.addWidget(self.tw_strategies, 0, 0, 1, 1)
+ self.gridLayout.addWidget(self.tw_strategies, 4, 0, 1, 10)
+ self.rb_long = QtWidgets.QRadioButton(self.centralwidget)
+ self.rb_long.setChecked(True)
+ self.rb_long.setObjectName("rb_long")
+ self.gridLayout.addWidget(self.rb_long, 2, 0, 1, 1)
+ self.label = QtWidgets.QLabel(self.centralwidget)
+ self.label.setObjectName("label")
+ self.gridLayout.addWidget(self.label, 0, 0, 1, 1)
+ self.b_browse = QtWidgets.QPushButton(self.centralwidget)
+ self.b_browse.setObjectName("b_browse")
+ self.gridLayout.addWidget(self.b_browse, 0, 9, 1, 1)
+ self.cb_minTradesFilter = QtWidgets.QCheckBox(self.centralwidget)
+ self.cb_minTradesFilter.setObjectName("cb_minTradesFilter")
+ self.gridLayout.addWidget(self.cb_minTradesFilter, 1, 2, 1, 1)
+ self.cb_minWinRate = QtWidgets.QCheckBox(self.centralwidget)
+ self.cb_minWinRate.setObjectName("cb_minWinRate")
+ self.gridLayout.addWidget(self.cb_minWinRate, 1, 4, 1, 1)
+ self.sb_minSharpe = QtWidgets.QDoubleSpinBox(self.centralwidget)
+ self.sb_minSharpe.setObjectName("sb_minSharpe")
+ self.gridLayout.addWidget(self.sb_minSharpe, 1, 7, 1, 1)
+ self.rb_short = QtWidgets.QRadioButton(self.centralwidget)
+ self.rb_short.setObjectName("rb_short")
+ self.gridLayout.addWidget(self.rb_short, 3, 0, 1, 1)
MainWindow.setCentralWidget(self.centralwidget)
self.menubar = QtWidgets.QMenuBar(MainWindow)
- self.menubar.setGeometry(QtCore.QRect(0, 0, 800, 27))
+ self.menubar.setGeometry(QtCore.QRect(0, 0, 1041, 27))
self.menubar.setObjectName("menubar")
MainWindow.setMenuBar(self.menubar)
self.statusbar = QtWidgets.QStatusBar(MainWindow)
@@ -31,11 +88,16 @@ class Ui_MainWindow(object):
self.retranslateUi(MainWindow)
self.tw_strategies.itemClicked['QTreeWidgetItem*','int'].connect(MainWindow.strategyClicked)
+ self.b_browse.clicked.connect(MainWindow.browse)
+ self.e_go.clicked.connect(MainWindow.go)
QtCore.QMetaObject.connectSlotsByName(MainWindow)
def retranslateUi(self, MainWindow):
_translate = QtCore.QCoreApplication.translate
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
+ self.cb_minSharpe.setText(_translate("MainWindow", "Min. sharpe"))
+ self.label_2.setText(_translate("MainWindow", "#Strategies to generate:"))
+ self.e_go.setText(_translate("MainWindow", "Go"))
self.tw_strategies.setSortingEnabled(True)
self.tw_strategies.headerItem().setText(0, _translate("MainWindow", "Strategy"))
self.tw_strategies.headerItem().setText(1, _translate("MainWindow", "Trades #"))
@@ -43,4 +105,11 @@ class Ui_MainWindow(object):
self.tw_strategies.headerItem().setText(3, _translate("MainWindow", "PF"))
self.tw_strategies.headerItem().setText(4, _translate("MainWindow", "Sharpe"))
self.tw_strategies.headerItem().setText(5, _translate("MainWindow", "Avg. %"))
+ self.tw_strategies.headerItem().setText(6, _translate("MainWindow", "Win %"))
+ self.rb_long.setText(_translate("MainWindow", "Long"))
+ self.label.setText(_translate("MainWindow", "Data source:"))
+ self.b_browse.setText(_translate("MainWindow", "Browse"))
+ self.cb_minTradesFilter.setText(_translate("MainWindow", "Min. trades:"))
+ self.cb_minWinRate.setText(_translate("MainWindow", "Min. win rate"))
+ self.rb_short.setText(_translate("MainWindow", "Short"))