From: <sa...@us...> - 2008-03-19 18:07:59
|
Revision: 5007 http://matplotlib.svn.sourceforge.net/matplotlib/?rev=5007&view=rev Author: sameerd Date: 2008-03-19 11:07:49 -0700 (Wed, 19 Mar 2008) Log Message: ----------- Added outerjoin, leftjoin and rightjoin support to rec_join Modified Paths: -------------- trunk/matplotlib/lib/matplotlib/mlab.py Added Paths: ----------- trunk/matplotlib/examples/rec_join_demo.py Added: trunk/matplotlib/examples/rec_join_demo.py =================================================================== --- trunk/matplotlib/examples/rec_join_demo.py (rev 0) +++ trunk/matplotlib/examples/rec_join_demo.py 2008-03-19 18:07:49 UTC (rev 5007) @@ -0,0 +1,27 @@ +import numpy as np +import matplotlib.mlab as mlab + + +r = mlab.csv2rec('data/aapl.csv') +r.sort() +r1 = r[-10:] + +# Create a new array +r2 = np.empty(12, dtype=[('date', '|O4'), ('high', np.float), + ('marker', np.float)]) +r2 = r2.view(np.recarray) +r2.date = r.date[-17:-5] +r2.high = r.high[-17:-5] +r2.marker = np.arange(12) + +print "r1:" +print mlab.rec2txt(r1) +print "r2:" +print mlab.rec2txt(r2) + +defaults = {'marker':-1, 'close':np.NaN, 'low':-4444.} + +for s in ('inner', 'outer', 'leftouter'): + rec = mlab.rec_join(['date', 'high'], r1, r2, + jointype=s, defaults=defaults) + print "\n%sjoin :\n%s" % (s, mlab.rec2txt(rec)) Modified: trunk/matplotlib/lib/matplotlib/mlab.py =================================================================== --- trunk/matplotlib/lib/matplotlib/mlab.py 2008-03-19 14:36:57 UTC (rev 5006) +++ trunk/matplotlib/lib/matplotlib/mlab.py 2008-03-19 18:07:49 UTC (rev 5007) @@ -2044,12 +2044,19 @@ return npy.rec.fromarrays(arrays, names=names) -def rec_join(key, r1, r2): + +def rec_join(key, r1, r2, jointype='inner', defaults=None): """ join record arrays r1 and r2 on key; key is a tuple of field names. 
if r1 and r2 have equal values on all the keys in the key tuple, then their fields will be merged into a new record array containing the intersection of the fields of r1 and r2 + + The jointype keyword can be 'inner', 'outer', 'leftouter'. + To do a rightouter join just reverse r1 and r2. + + The defaults keyword is a dictionary filled with + {column_name:default_value} pairs. """ for name in key: @@ -2067,17 +2074,22 @@ r1keys = set(r1d.keys()) r2keys = set(r2d.keys()) - keys = r1keys & r2keys + common_keys = r1keys & r2keys - r1ind = npy.array([r1d[k] for k in keys]) - r2ind = npy.array([r2d[k] for k in keys]) + r1ind = npy.array([r1d[k] for k in common_keys]) + r2ind = npy.array([r2d[k] for k in common_keys]) - # Make sure that the output rows have the same relative order as r1 - sortind = r1ind.argsort() + common_len = len(common_keys) + left_len = right_len = 0 + if jointype == "outer" or jointype == "leftouter": + left_keys = r1keys.difference(r2keys) + left_ind = npy.array([r1d[k] for k in left_keys]) + left_len = len(left_ind) + if jointype == "outer": + right_keys = r2keys.difference(r1keys) + right_ind = npy.array([r2d[k] for k in right_keys]) + right_len = len(right_ind) - r1 = r1[r1ind[sortind]] - r2 = r2[r2ind[sortind]] - r2 = rec_drop_fields(r2, r1.dtype.names) @@ -2103,13 +2115,31 @@ [desc for desc in r2.dtype.descr if desc[0] not in key ] ) - newrec = npy.empty(len(r1), dtype=newdtype) + newrec = npy.empty(common_len + left_len + right_len, dtype=newdtype) + + if jointype != 'inner' and defaults is not None: # fill in the defaults enmasse + newrec_fields = newrec.dtype.fields.keys() + for k, v in defaults.items(): + if k in newrec_fields: + newrec[k] = v + for field in r1.dtype.names: - newrec[field] = r1[field] + newrec[field][:common_len] = r1[field][r1ind] + if jointype == "outer" or jointype == "leftouter": + newrec[field][common_len:(common_len+left_len)] = r1[field][left_ind] for field in r2.dtype.names: - newrec[field] = r2[field] + 
newrec[field][:common_len] = r2[field][r2ind] + if jointype == "outer": + newrec[field][-right_len:] = r2[field][right_ind[right_ind.argsort()]] + # sort newrec using the same order as r1 + sort_indices = r1ind.copy() + if jointype == "outer" or jointype == "leftouter": + sort_indices = npy.append(sort_indices, left_ind) + newrec[:(common_len+left_len)] = newrec[sort_indices.argsort()] + + return newrec.view(npy.recarray) This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site. |