Skip to content

Commit

Permalink
Merge pull request matplotlib#391 from deltreey/mav_shift
Browse files Browse the repository at this point in the history
add ability to shift moving average on plots
  • Loading branch information
DanielGoldfarb authored May 13, 2021
2 parents 3c02aea d22a8cc commit d777f43
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 14 deletions.
43 changes: 33 additions & 10 deletions src/mplfinance/_arg_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 103,45 @@ def _get_valid_plot_types(plottype=None):


def _mav_validator(mav_value):
'''
'''
Value for mav (moving average) keyword may be:
scalar int greater than 1, or tuple of ints, or list of ints (greater than 1).
tuple or list limited to length of 7 moving averages (to keep the plot clean).
scalar int greater than 1, or tuple of ints, or list of ints (each greater than 1)
or a dict of `period` and `shift` each of which may be:
scalar int, or tuple of ints, or list of ints: each `period` int must be greater than 1
'''
if isinstance(mav_value,int) and mav_value > 1:
def _valid_mav(value, is_period=True):
if not isinstance(value,(tuple,list,int)):
return False
if isinstance(value,int):
return (value >= 2 or not is_period)
# Must be a tuple or list here:
for num in value:
if not isinstance(num,int) or (is_period and num < 2):
return False
return True
elif not isinstance(mav_value,tuple) and not isinstance(mav_value,list):

if not isinstance(mav_value,(tuple,list,int,dict)):
return False

if not len(mav_value) < 8:
if not isinstance(mav_value,dict):
return _valid_mav(mav_value)

else: #isinstance(mav_value,dict)
if 'period' not in mav_value: return False

period = mav_value['period']
if not _valid_mav(period): return False

if 'shift' not in mav_value: return True

shift = mav_value['shift']
if not _valid_mav(shift, False): return False
if isinstance(period,int) and isinstance(shift,int): return True
if isinstance(period,(tuple,list)) and isinstance(shift,(tuple,list)):
if len(period) != len(shift): return False
return True
return False
for num in mav_value:
if not isinstance(num,int) and num > 1:
return False
return True


def _hlines_validator(value):
if isinstance(value,dict):
Expand Down
2 changes: 1 addition & 1 deletion src/mplfinance/_version.py
Original file line number Diff line number Diff line change
@@ -1,5 1,5 @@

version_info = (0, 12, 7, 'alpha', 17)
version_info = (0, 12, 7, 'alpha', 18)

_specifier_ = {'alpha': 'a','beta': 'b','candidate': 'rc','final': ''}

Expand Down
13 changes: 10 additions & 3 deletions src/mplfinance/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,8 978,12 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
mavgs = config['mav']
mavp_list = []
if mavgs is not None:
shift = None
if isinstance(mavgs,dict):
shift = mavgs['shift']
mavgs = mavgs['period']
if isinstance(mavgs,int):
mavgs = mavgs, # convert to tuple
mavgs = mavgs, # convert to tuple
if len(mavgs) > 7:
mavgs = mavgs[0:7] # take at most 7

Expand All @@ -988,8 992,11 @@ def _plot_mav(ax,config,xdates,prices,apmav=None,apwidth=None):
else:
mavc = None

for mav in mavgs:
mavprices = pd.Series(prices).rolling(mav).mean().values
for idx,mav in enumerate(mavgs):
mean = pd.Series(prices).rolling(mav).mean()
if shift is not None:
mean = mean.shift(periods=shift[idx])
mavprices = mean.values
lw = config['_width_config']['line_width']
if mavc:
ax.plot(xdates, mavprices, linewidth=lw, color=next(mavc))
Expand Down
Binary file added tests/reference_images/addplot12.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
20 changes: 20 additions & 0 deletions tests/test_addplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 354,23 @@ def test_addplot11(bolldata):
print('result=',result)
assert result is None

def test_addplot12(bolldata):

df = bolldata

fname = base '12.png'
tname = os.path.join(tdir,fname)
rname = os.path.join(refd,fname)

mpf.plot(df,type='candle',volume=True,savefig=tname,mav={'period':(20,40,60), 'shift': [5,10,20]})

tsize = os.path.getsize(tname)
print(glob.glob(tname),'[',tsize,'bytes',']')

rsize = os.path.getsize(rname)
print(glob.glob(rname),'[',rsize,'bytes',']')

result = compare_images(rname,tname,tol=IMGCOMP_TOLERANCE)
if result is not None:
print('result=',result)
assert result is None

0 comments on commit d777f43

Please sign in to comment.