You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
186 lines
5.7 KiB
186 lines
5.7 KiB
|
8 years ago
|
'''
|
||
|
|
'''
|
||
|
|
import random
|
||
|
|
|
||
|
|
import talib
|
||
|
|
import numpy
|
||
|
|
|
||
|
|
class Signal:
    """Base class for trading signals.

    Every method here is a no-op that returns ``None``; concrete signals
    override ``calculate`` and ``get_text``.
    """

    def __init__(self):
        """Create a signal; the base class keeps no state."""

    def calculate(self, series):
        """Evaluate the signal over ``series``; the base class returns None."""
        return None

    def get_text(self):
        """Describe the signal as text; the base class returns None."""
        return None
|
||
|
|
|
||
|
|
class PriceComparisonSignalGenerator:
    """Factory producing randomly parameterised PriceComparisonSignal objects."""

    def __init__(self):
        """Create a generator; it keeps no state."""

    def generate(self):
        """Return a new PriceComparisonSignal with random components and shifts.

        Each side gets a component drawn from OPEN..CLOSE (inclusive) and a
        shift drawn from 0..10 (inclusive).
        """
        first = PriceComparisonSignal.OPEN
        last = PriceComparisonSignal.CLOSE

        # Draw order (lhs, lhs_shift, rhs, rhs_shift) is kept as-is so that a
        # seeded ``random`` yields the same signals as before.
        left_component = random.randint(first, last)
        left_shift = random.randint(0, 10)
        right_component = random.randint(first, last)
        right_shift = random.randint(0, 10)

        return PriceComparisonSignal(
            left_component, left_shift, right_component, right_shift)
|
||
|
|
|
||
|
|
class PriceComparisonSignal(Signal):
    """Signal that is true when one price component is below another.

    For bar ``index`` the signal evaluates
    ``lhs[index - lhs_shift] < rhs[index - rhs_shift]`` where each side is
    one of the OPEN / HIGH / LOW / CLOSE components read from the series.
    """

    # Price-component identifiers accepted for ``lhs`` and ``rhs``.
    OPEN = 0
    HIGH = 1
    LOW = 2
    CLOSE = 3

    # Component id -> (name of the series accessor, display name).
    _COMPONENTS = {
        OPEN: ('get_open', 'open'),
        HIGH: ('get_high', 'high'),
        LOW: ('get_low', 'low'),
        CLOSE: ('get_close', 'close'),
    }

    def __init__(self, lhs, lhs_shift, rhs, rhs_shift):
        """Store the comparison parameters.

        lhs, rhs -- component ids (OPEN, HIGH, LOW or CLOSE).
        lhs_shift, rhs_shift -- number of bars to look back on each side.
        """
        self.lhs = lhs
        self.lhs_shift = lhs_shift
        self.rhs = rhs
        self.rhs_shift = rhs_shift

    def calculate(self, series):
        """Return a list with the signal value for every bar of ``series``.

        ``series`` must provide ``length()`` and the ``get_open`` /
        ``get_high`` / ``get_low`` / ``get_close`` accessors taking an index.
        """
        return [self.calculate_at_index(series, i)
                for i in range(series.length())]

    def calculate_at_index(self, series, index):
        """Return the signal value at ``index``.

        Returns False when either shifted index falls outside the series.
        Raises ValueError (a subclass of Exception) when ``lhs`` or ``rhs``
        is not a known component id.
        """
        try:
            lhs = self._read_component(
                series, self.lhs, index - self.lhs_shift, 'lhs')
            rhs = self._read_component(
                series, self.rhs, index - self.rhs_shift, 'rhs')
        except IndexError:
            return False
        return lhs < rhs

    def _read_component(self, series, component, position, side):
        """Read ``component`` of the bar at ``position`` from ``series``.

        ``side`` ('lhs' or 'rhs') is only used in the error message.
        """
        entry = PriceComparisonSignal._COMPONENTS.get(component)
        if entry is None:
            raise ValueError('Invalid ' + side + ' type')
        if position < 0:
            # A negative index would wrap around to the end of a Python
            # sequence instead of raising, so reject it explicitly and let
            # the caller treat it like any other out-of-range bar.
            raise IndexError('bar index before start of series')
        return getattr(series, entry[0])(position)

    def get_text(self):
        """Return the comparison as text, e.g. ``open[1] < close[0]``."""
        return '{}[{}] < {}[{}]'.format(
            self.component_to_str(self.lhs), self.lhs_shift,
            self.component_to_str(self.rhs), self.rhs_shift)

    def component_to_str(self, component):
        """Return the display name of ``component`` or "??" if unknown."""
        entry = PriceComparisonSignal._COMPONENTS.get(component)
        return entry[1] if entry is not None else "??"
|
||
|
|
|
||
|
|
|
||
|
|
class RsiSignalGenerator:
    """Factory producing randomly parameterised RsiSignal objects."""

    def __init__(self):
        """Create a generator; it keeps no state."""

    def generate(self):
        """Return a new RsiSignal with random period, threshold and direction.

        The period is drawn from 2..30 (inclusive) and the threshold from
        10, 20, ..., 80.
        """
        rsi_period = random.randint(2, 30)
        # randrange excludes its upper bound, so the multiplier is 1..8.
        # NOTE(review): 90 is never produced -- confirm this is intended.
        rsi_threshold = 10 * random.randrange(1, 9)
        direction = random.randint(RsiSignal.LT, RsiSignal.GT)
        return RsiSignal(rsi_period, rsi_threshold, direction)
|
||
|
|
|
||
|
|
class RsiSignal(Signal):
    """Signal comparing the RSI of the close prices with a fixed threshold."""

    # Direction of the comparison between the RSI and the threshold.
    LT = 0
    GT = 1

    def __init__(self, period, threshold, inequality_type):
        """Store the RSI parameters.

        period -- time period passed to ``talib.RSI``.
        threshold -- value the RSI is compared with.
        inequality_type -- RsiSignal.LT for ``rsi < threshold``; any other
            value is treated as ``rsi > threshold``.
        """
        self.period = period
        self.threshold = threshold
        self.inequality_type = inequality_type
        self.inequality_sign_str = (
            '<' if inequality_type == RsiSignal.LT else '>')

    def calculate(self, series):
        """Return a list with the signal value for every bar of ``series``.

        ``series`` must provide ``length()`` and ``get_close(index)``.
        """
        # Force a float array: TA-Lib only accepts float64 input, so a
        # series holding integer prices would otherwise be rejected.
        closes = numpy.array(
            [series.get_close(i) for i in range(series.length())],
            dtype=float)

        rsi = talib.RSI(closes, self.period)

        # NOTE(review): the leading warm-up values are presumably NaN, for
        # which both comparisons evaluate to False -- confirm against TA-Lib.
        return [self.calc_signal(value) for value in rsi]

    def calc_signal(self, value):
        """Compare one RSI ``value`` with the threshold."""
        if self.inequality_type == RsiSignal.LT:
            return value < self.threshold
        return value > self.threshold

    def get_text(self):
        """Return the rule as text, e.g. ``rsi(c, 14) < 30``."""
        return 'rsi(c, {}) {} {}'.format(
            self.period, self.inequality_sign_str, self.threshold)
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
class AtrSignalGenerator:
    """Factory producing randomly parameterised AtrSignal objects."""

    def __init__(self):
        """Create a generator; it keeps no state."""

    def generate(self):
        """Return a new AtrSignal with random period, factor and direction.

        The period is drawn from 2..30 (inclusive) and the threshold factor
        is an integer from 1..30 (inclusive) multiplied by 0.001.
        """
        atr_period = random.randint(2, 30)
        factor = random.randint(1, 30) * 0.001
        direction = random.randint(AtrSignal.LT, AtrSignal.GT)
        return AtrSignal(atr_period, factor, direction)
|
||
|
|
|
||
|
|
class AtrSignal(Signal):
    """Signal comparing the ATR with a fraction of the close price.

    For every bar the signal evaluates ``atr < threshold_factor * close``
    (or ``>``, depending on the inequality type).
    """

    # Direction of the comparison between the ATR and the scaled close.
    LT = 0
    GT = 1

    def __init__(self, period, threshold_factor, inequality_type):
        """Store the ATR parameters.

        period -- time period passed to ``talib.ATR``.
        threshold_factor -- multiplier applied to the close price.
        inequality_type -- AtrSignal.LT for ``atr < factor * close``; any
            other value is treated as ``atr > factor * close``.
        """
        self.period = period
        self.threshold_factor = threshold_factor
        self.inequality_type = inequality_type
        # Compare against this class's own constant rather than RsiSignal's,
        # which only matched because both happen to be 0.
        self.inequality_sign_str = (
            '<' if inequality_type == AtrSignal.LT else '>')

    def calculate(self, series):
        """Return a list with the signal value for every bar of ``series``.

        ``series`` must provide ``length()`` and the ``get_close`` /
        ``get_high`` / ``get_low`` accessors taking an index.
        """
        indices = range(series.length())

        # Force float arrays: TA-Lib only accepts float64 input, so a series
        # holding integer prices would otherwise be rejected.
        closes = numpy.array(
            [series.get_close(i) for i in indices], dtype=float)
        highs = numpy.array(
            [series.get_high(i) for i in indices], dtype=float)
        lows = numpy.array(
            [series.get_low(i) for i in indices], dtype=float)

        atr = talib.ATR(highs, lows, closes, self.period)

        return [self.calc_signal(value, close)
                for value, close in zip(atr, closes)]

    def calc_signal(self, value, close):
        """Compare one ATR ``value`` with ``threshold_factor * close``."""
        limit = self.threshold_factor * close
        if self.inequality_type == AtrSignal.LT:
            return value < limit
        return value > limit

    def get_text(self):
        """Return the rule as text, e.g. ``atr(14) < close[0] * 0.010``."""
        return 'atr({}) {} close[0] * {:.3f}'.format(
            self.period, self.inequality_sign_str, self.threshold_factor)
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
|