You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

186 lines
5.7 KiB

'''
'''
import random
import talib
import numpy
class Signal:
    """Common interface for the signals defined in this module.

    Subclasses override ``calculate`` to evaluate the signal over a price
    series and ``get_text`` to describe it. The base implementations hold no
    state and simply return ``None``.
    """

    def __init__(self):
        """Initialise the base signal; nothing is stored."""

    def calculate(self, series):
        """Evaluate the signal over *series*; the base class returns ``None``."""
        return None

    def get_text(self):
        """Describe the signal as text; the base class returns ``None``."""
        return None
class PriceComparisonSignalGenerator:
    """Factory for randomly parameterised ``PriceComparisonSignal`` objects."""

    def __init__(self):
        """Create the generator; it keeps no state of its own."""

    def generate(self):
        """Build a ``PriceComparisonSignal`` with random operands.

        Each side gets a price component drawn uniformly from
        ``PriceComparisonSignal.OPEN``..``PriceComparisonSignal.CLOSE``
        (inclusive) and a look-back shift drawn from 0..10 (inclusive).
        The left side is drawn before the right side.
        """
        def draw_operand():
            # Component first, then shift: keeps the RNG call order fixed.
            component = random.randint(PriceComparisonSignal.OPEN, PriceComparisonSignal.CLOSE)
            shift = random.randint(0, 10)
            return component, shift

        left_component, left_shift = draw_operand()
        right_component, right_shift = draw_operand()
        return PriceComparisonSignal(left_component, left_shift, right_component, right_shift)
class PriceComparisonSignal(Signal):
    """Signal that is true when one shifted price component is below another.

    At bar ``index`` the signal evaluates::

        <lhs component>[index - lhs_shift] < <rhs component>[index - rhs_shift]

    Components are the OPEN/HIGH/LOW/CLOSE constants below and are read from
    the series via ``get_open``/``get_high``/``get_low``/``get_close``.
    """

    OPEN = 0
    HIGH = 1
    LOW = 2
    CLOSE = 3

    def __init__(self, lhs, lhs_shift, rhs, rhs_shift):
        """Store both operands of the comparison.

        Args:
            lhs: component constant (OPEN/HIGH/LOW/CLOSE) for the left side.
            lhs_shift: number of bars the left side looks back.
            rhs: component constant (OPEN/HIGH/LOW/CLOSE) for the right side.
            rhs_shift: number of bars the right side looks back.
        """
        self.lhs = lhs
        self.lhs_shift = lhs_shift
        self.rhs = rhs
        self.rhs_shift = rhs_shift

    def calculate(self, series):
        """Evaluate the signal at every bar ``0 .. series.length() - 1``.

        Returns:
            A list with one comparison result per bar.
        """
        return [self.calculate_at_index(series, i) for i in range(series.length())]

    def calculate_at_index(self, series, index):
        """Evaluate the signal at a single bar.

        Returns:
            The result of ``lhs < rhs``. Returns False when either shifted
            position lies before the first bar, or when the series raises
            IndexError.

        Raises:
            Exception: if ``lhs`` or ``rhs`` is not one of the component
                constants ('Invalid lhs type' / 'Invalid rhs type').
        """
        get_lhs = self._accessor(series, self.lhs, 'lhs')
        get_rhs = self._accessor(series, self.rhs, 'rhs')
        lhs_pos = index - self.lhs_shift
        rhs_pos = index - self.rhs_shift
        # A negative position means the look-back reaches before the first bar.
        # If the series is backed by a Python sequence, a negative index would
        # silently wrap to the END of the data (look-ahead) instead of raising
        # IndexError. Treat it as out of range explicitly.
        if lhs_pos < 0 or rhs_pos < 0:
            return False
        try:
            return get_lhs(lhs_pos) < get_rhs(rhs_pos)
        except IndexError:
            return False

    @staticmethod
    def _accessor(series, component, side):
        """Return the bound ``series.get_*`` method for *component*.

        *side* ('lhs' or 'rhs') is only used to build the error message.
        """
        if component == PriceComparisonSignal.OPEN:
            return series.get_open
        if component == PriceComparisonSignal.HIGH:
            return series.get_high
        if component == PriceComparisonSignal.LOW:
            return series.get_low
        if component == PriceComparisonSignal.CLOSE:
            return series.get_close
        raise Exception('Invalid ' + side + ' type')

    def get_text(self):
        """Render the signal, e.g. ``'open[2] < close[0]'``."""
        return '{}[{}] < {}[{}]'.format(
            self.component_to_str(self.lhs), self.lhs_shift,
            self.component_to_str(self.rhs), self.rhs_shift,
        )

    def component_to_str(self, component):
        """Return the lowercase name of *component*, or ``'??'`` if unknown."""
        if component == PriceComparisonSignal.OPEN:
            return "open"
        elif component == PriceComparisonSignal.HIGH:
            return "high"
        elif component == PriceComparisonSignal.LOW:
            return "low"
        elif component == PriceComparisonSignal.CLOSE:
            return "close"
        else:
            return "??"
class RsiSignalGenerator:
    """Factory for randomly parameterised ``RsiSignal`` objects."""

    def __init__(self):
        """Create the generator; it keeps no state of its own."""

    def generate(self):
        """Build an ``RsiSignal`` with a random period, level and direction.

        The period is drawn from 2..30 (inclusive). The direction is
        ``RsiSignal.LT`` or ``RsiSignal.GT``.
        """
        rsi_period = random.randint(2, 30)
        # randrange excludes its upper bound, so the level is one of 10, 20, ..., 80.
        # NOTE(review): 90 is never produced; confirm whether randint(1, 9) was intended.
        level = 10 * random.randrange(1, 9)
        direction = random.randint(RsiSignal.LT, RsiSignal.GT)
        return RsiSignal(rsi_period, level, direction)
class RsiSignal(Signal):
    """Signal comparing the RSI of close prices against a fixed threshold.

    At each bar the signal is ``rsi < threshold`` when ``inequality_type`` is
    ``LT``. Any other ``inequality_type`` gives ``rsi > threshold``.
    """

    LT = 0
    GT = 1

    def __init__(self, period, threshold, inequality_type):
        """Store the RSI parameters.

        Args:
            period: look-back period passed to ``talib.RSI``.
            threshold: level the RSI value is compared against.
            inequality_type: ``RsiSignal.LT`` for "<"; anything else means ">".
        """
        self.period = period
        self.threshold = threshold
        self.inequality_type = inequality_type
        self.inequality_sign_str = '<' if inequality_type == RsiSignal.LT else '>'

    def calculate(self, series):
        """Evaluate the signal at every bar of *series*.

        Returns:
            A list of bools, one per bar. TA-Lib reports the warm-up bars as
            NaN, and NaN compares False either way, so those bars are False.
        """
        # TA-Lib's wrappers only accept float64 ("double") arrays. Force the
        # dtype so integer price data is converted instead of rejected.
        closes = numpy.array(
            [series.get_close(i) for i in range(series.length())],
            dtype=numpy.float64,
        )
        rsi = talib.RSI(closes, self.period)
        return [self.calc_signal(v) for v in rsi]

    def calc_signal(self, value):
        """Compare one RSI *value* with the threshold and return a plain bool."""
        # bool() turns numpy.bool_ into a Python bool, matching the other signals.
        if self.inequality_type == RsiSignal.LT:
            return bool(value < self.threshold)
        return bool(value > self.threshold)

    def get_text(self):
        """Render the signal, e.g. ``'rsi(c, 14) < 30'``."""
        return "rsi(c, " + str(self.period) + ') ' + self.inequality_sign_str + ' ' + str(self.threshold)
class AtrSignalGenerator:
    """Factory for randomly parameterised ``AtrSignal`` objects."""

    def __init__(self):
        """Create the generator; it keeps no state of its own."""

    def generate(self):
        """Build an ``AtrSignal`` with a random period, factor and direction.

        The period is drawn from 2..30 (inclusive). The threshold factor is
        k * 0.001 for k in 1..30 (inclusive). The direction is
        ``AtrSignal.LT`` or ``AtrSignal.GT``.
        """
        lookback = random.randint(2, 30)
        factor = 0.001 * random.randint(1, 30)
        direction = random.randint(AtrSignal.LT, AtrSignal.GT)
        return AtrSignal(lookback, factor, direction)
class AtrSignal(Signal):
    """Signal comparing the ATR with a fraction of the current close.

    At each bar the signal is ``atr < threshold_factor * close`` when
    ``inequality_type`` is ``LT``. Any other value gives
    ``atr > threshold_factor * close``.
    """

    LT = 0
    GT = 1

    def __init__(self, period, threshold_factor, inequality_type):
        """Store the ATR parameters.

        Args:
            period: look-back period passed to ``talib.ATR``.
            threshold_factor: multiplier applied to the close price to form
                the comparison level.
            inequality_type: ``AtrSignal.LT`` for "<"; anything else means ">".
        """
        self.period = period
        self.threshold_factor = threshold_factor
        self.inequality_type = inequality_type
        # Compare against this class's own constant, not RsiSignal's.
        self.inequality_sign_str = '<' if inequality_type == AtrSignal.LT else '>'

    def calculate(self, series):
        """Evaluate the signal at every bar of *series*.

        Returns:
            A list of bools, one per bar. TA-Lib reports the warm-up bars as
            NaN, and NaN compares False either way, so those bars are False.
        """
        bars = range(series.length())
        # TA-Lib's wrappers only accept float64 ("double") arrays. Force the
        # dtype so integer price data is converted instead of rejected.
        closes = numpy.array([series.get_close(i) for i in bars], dtype=numpy.float64)
        highs = numpy.array([series.get_high(i) for i in bars], dtype=numpy.float64)
        lows = numpy.array([series.get_low(i) for i in bars], dtype=numpy.float64)
        atr = talib.ATR(highs, lows, closes, self.period)
        return [self.calc_signal(v, c) for v, c in zip(atr, closes)]

    def calc_signal(self, value, close):
        """Compare one ATR *value* with ``threshold_factor * close``.

        Returns a plain bool.
        """
        level = self.threshold_factor * close
        # bool() turns numpy.bool_ into a Python bool, matching the other signals.
        if self.inequality_type == AtrSignal.LT:
            return bool(value < level)
        return bool(value > level)

    def get_text(self):
        """Render the signal, e.g. ``'atr(14) < close[0] * 0.010'``."""
        return "atr(" + str(self.period) + ') ' + self.inequality_sign_str + ' close[0] * ' + "{:.3f}".format(self.threshold_factor)