Skip to content

Commit

Permalink
ENH: first cut at margins for pivot_table. testing still needed, #114
Browse files Browse the repository at this point in the history
  • Loading branch information
wesm committed Dec 12, 2011
1 parent 6a079a0 commit f57770c
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 4 deletions.
20 changes: 20 additions & 0 deletions pandas/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,6 1202,26 @@ def swaplevel(self, i, j):
return MultiIndex(levels=new_levels, labels=new_labels,
names=new_names)

def reorder_levels(self, order):
    """
    Rearrange levels using input order. May not drop or duplicate levels

    Parameters
    ----------
    order : sequence of int
        Positions of the existing levels, listed in the order in which
        they should appear in the result. Must be a permutation of
        ``range(self.nlevels)``.

    Returns
    -------
    MultiIndex
        Index built from this index's levels, labels and names, each
        rearranged according to ``order``.

    Raises
    ------
    ValueError
        If ``order`` is not a permutation of ``range(self.nlevels)``.
    """
    order = list(order)

    # Validate explicitly instead of with ``assert`` so the check is not
    # stripped under ``python -O``. The length test rejects duplicated
    # levels, which a set comparison alone would let through.
    if (len(order) != self.nlevels or
            set(order) != set(range(self.nlevels))):
        raise ValueError('New order must be permutation of range(%d)' %
                         self.nlevels)

    new_levels = [self.levels[i] for i in order]
    new_labels = [self.labels[i] for i in order]
    new_names = [self.names[i] for i in order]

    return MultiIndex(levels=new_levels, labels=new_labels,
                      names=new_names)

def __getslice__(self, i, j):
    """Handle ``self[i:j]`` by delegating to ``__getitem__`` with a slice."""
    bounds = slice(i, j)
    return self.__getitem__(bounds)

Expand Down
60 changes: 56 additions & 4 deletions pandas/tools/pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 2,7 @@
import numpy as np

def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
fill_value=None):
fill_value=None, margins=False):
"""
Create a spreadsheet-style pivot table as a DataFrame. The levels in the
pivot table will be stored in MultiIndex objects (hierarchical indexes) on
Expand All @@ -19,6 19,8 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
aggfunc : function, default numpy.mean
fill_value : scalar, default None
Value to replace missing values with
margins : boolean, default False
Add all row / columns (e.g. for subtotal / grand totals)
Examples
--------
Expand Down Expand Up @@ -59,15 61,14 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
else:
values_multi = False
values = [values]
else:
values = list(data.columns.drop(keys))

if values_passed:
data = data[keys values]

grouped = data.groupby(keys)

if values_passed and not values_multi:
grouped = grouped[values[0]]

agged = grouped.agg(aggfunc)

table = agged
Expand All @@ -77,10 78,61 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc=np.mean,
if fill_value is not None:
table = table.fillna(value=fill_value)

if margins:
table = _add_margins(table, data, values, rows=rows,
cols=cols, aggfunc=aggfunc)

# discard the top level
if values_passed and not values_multi:
table = table[values[0]]

return table

# Attach pivot_table to DataFrame so it can also be called as a method,
# e.g. ``frame.pivot_table(values=..., rows=..., cols=...)``.
DataFrame.pivot_table = pivot_table

def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):
    """
    Add 'All' margin columns and/or an 'All' margin row to a pivoted table.

    Parameters
    ----------
    table : DataFrame
        Already-pivoted result to extend.
    data : DataFrame
        Source data; it is re-grouped here to compute the margins.
    values : list
        Names of the value columns in ``data``.
    rows : list, optional
        Row grouping keys. When not None, one 'All' column is added to each
        top-level column group of ``table``.
    cols : list, optional
        Column grouping keys. When not None, an 'All' row is appended.
    aggfunc : function, default numpy.mean
        Aggregation applied when computing every margin.

    Returns
    -------
    DataFrame
        ``table`` itself when both ``rows`` and ``cols`` are None, otherwise
        the table extended with the computed margins.
    """
    # NOTE(review): the ``rows`` branch calls len(cols) and the ``cols``
    # branch calls len(rows), so both are presumably lists (possibly empty)
    # by the time this runs -- confirm against the caller.
    if rows is not None:
        col_margin = data[rows + values].groupby(rows).agg(aggfunc)

        # need to "interleave" the margins: each top-level column group
        # gets its own 'All' column, then the groups are joined back up
        table_pieces = []
        margin_keys = []
        for key, piece in table.groupby(level=0, axis=1):
            # pad with '' so the key has one entry per column level
            all_key = (key, 'All') + ('',) * (len(cols) - 1)
            piece[all_key] = col_margin[key]
            table_pieces.append(piece)
            margin_keys.append(all_key)

        result = table_pieces[0]
        for piece in table_pieces[1:]:
            result = result.join(piece)
    else:
        result = table
        margin_keys = []

    # aggregate of each value column over all of ``data``
    grand_margin = data[values].apply(aggfunc)

    if cols is not None:
        row_margin = data[cols + values].groupby(cols).agg(aggfunc)
        row_margin = row_margin.stack()

        # slight hack: move the innermost level (added by stack) to the
        # front; list() keeps the concatenation valid where range is lazy
        new_order = [len(cols)] + list(range(len(cols)))
        row_margin.index = row_margin.index.reorder_levels(new_order)

        # label of the appended row, padded with '' per extra row level
        key = ('All',) + ('',) * (len(rows) - 1)

        row_margin = row_margin.reindex(result.columns)
        # populate grand margin
        for k in margin_keys:
            row_margin[k] = grand_margin[k[0]]

        margin_dummy = DataFrame(row_margin, columns=[key]).T
        result = result.append(margin_dummy)

    return result

def _convert_by(by):
if by is None:
by = []
Expand Down
6 changes: 6 additions & 0 deletions pandas/tools/tests/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 27,9 @@ def test_pivot_table(self):
cols= 'C'
table = pivot_table(self.data, values='D', rows=rows, cols=cols)

table2 = self.data.pivot_table(values='D', rows=rows, cols=cols)
assert_frame_equal(table, table2)

# this works
pivot_table(self.data, values='D', rows=rows)

Expand Down Expand Up @@ -57,6 60,9 @@ def test_pivot_multi_values(self):
rows='A', cols=['B', 'C'], fill_value=0)
assert_frame_equal(result, expected)

def test_margins(self):
    # TODO: placeholder -- nothing is asserted yet for
    # pivot_table(..., margins=True).
    pass

if __name__ == '__main__':
import nose
nose.runmodule(argv=[__file__,'-vvs','-x','--pdb', '--pdb-failure'],
Expand Down

0 comments on commit f57770c

Please sign in to comment.