From: <jd...@us...> - 2009-04-14 14:29:36
|
Revision: 7041 http://matplotlib.svn.sourceforge.net/matplotlib/?rev=7041&view=rev Author: jdh2358 Date: 2009-04-14 14:29:31 +0000 (Tue, 14 Apr 2009) Log Message: ----------- added mpl_toolkits.mplot3d Modified Paths: -------------- trunk/matplotlib/CHANGELOG trunk/matplotlib/doc/users/credits.rst trunk/matplotlib/doc/users/toolkits.rst trunk/matplotlib/examples/pylab_examples/finance_work2.py trunk/matplotlib/examples/tests/backend_driver.py trunk/matplotlib/setup.py Added Paths: ----------- trunk/matplotlib/examples/mplot3d/ trunk/matplotlib/examples/mplot3d/demo.py trunk/matplotlib/lib/mpl_toolkits/mplot3d/ trunk/matplotlib/lib/mpl_toolkits/mplot3d/__init__.py trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py trunk/matplotlib/lib/mpl_toolkits/mplot3d/proj3d.py Modified: trunk/matplotlib/CHANGELOG =================================================================== --- trunk/matplotlib/CHANGELOG 2009-04-13 03:02:41 UTC (rev 7040) +++ trunk/matplotlib/CHANGELOG 2009-04-14 14:29:31 UTC (rev 7041) @@ -1,6 +1,11 @@ ====================================================================== -2008-04-12 Release 0.98.5.3 at r7038 +2009-04-14 Added Jonathan Taylor's and Reinier Heeres' port of John + Porter's mplot3d to svn trunk. Package is in + mpl_toolkits.mplot3d and the demo is examples/mplot3d/demo.py. + Thanks Reinier + + 2009-04-06 The pdf backend now escapes newlines and linefeeds in strings. Fixes sf bug #2708559; thanks to Tiago Pereira for the report. Modified: trunk/matplotlib/doc/users/credits.rst =================================================================== --- trunk/matplotlib/doc/users/credits.rst 2009-04-13 03:02:41 UTC (rev 7040) +++ trunk/matplotlib/doc/users/credits.rst 2009-04-14 14:29:31 UTC (rev 7041) @@ -166,4 +166,9 @@ base. He also rewrote the transformation infrastructure to support custom projections and scales. 
+John Porter, Jonathan Taylor and Reinier Heeres + John Porter wrote the mplot3d module for basic 3D plotting in + matplotlib, and Jonathan Taylor and Reinier Heeres ported it to the + refactored transform trunk. + Modified: trunk/matplotlib/doc/users/toolkits.rst =================================================================== --- trunk/matplotlib/doc/users/toolkits.rst 2009-04-13 03:02:41 UTC (rev 7040) +++ trunk/matplotlib/doc/users/toolkits.rst 2009-04-14 14:29:31 UTC (rev 7041) @@ -37,10 +37,19 @@ Natgrid ======== - + mpl_toolkits.natgrid is an interface to natgrid C library for gridding irregularly spaced data. This requires a separate installation of the natgrid toolkit from the sourceforge `download <http://sourceforge.net/project/showfiles.php?group_id=80706&package_id=142792>`_ page. - + +.. _toolkit_mplot3d: + +mplot3d +=========== + +mpl_toolkits.mplot3d provides some basic 3D plotting (scatter, surf, +line, mesh) tools. It is not the fastest or most feature-complete 3D +library out there, but it ships with matplotlib and thus may be a +lighter-weight solution for some use cases. 
Added: trunk/matplotlib/examples/mplot3d/demo.py =================================================================== --- trunk/matplotlib/examples/mplot3d/demo.py (rev 0) +++ trunk/matplotlib/examples/mplot3d/demo.py 2009-04-14 14:29:31 UTC (rev 7041) @@ -0,0 +1,138 @@ +import random +import numpy as np +import matplotlib.pyplot as plt +import mpl_toolkits.mplot3d.axes3d as axes3d +from matplotlib.colors import Normalize, colorConverter + +def test_scatter(): + f = plt.figure() + ax = axes3d.Axes3D(f) + + n = 100 + for c,zl,zh in [('r',-50,-25),('b',-30,-5)]: + xs,ys,zs = zip(* + [(random.randrange(23,32), + random.randrange(100), + random.randrange(zl,zh) + ) for i in range(n)]) + ax.scatter3D(xs,ys,zs, c=c) + + ax.set_xlabel('------------ X Label --------------------') + ax.set_ylabel('------------ Y Label --------------------') + ax.set_zlabel('------------ Z Label --------------------') + +def test_wire(): + f = plt.figure() + ax = axes3d.Axes3D(f) + + X,Y,Z = axes3d.get_test_data(0.05) + ax.plot_wireframe(X,Y,Z, rstride=10,cstride=10) + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + +def test_surface(): + f = plt.figure() + ax = axes3d.Axes3D(f) + + X,Y,Z = axes3d.get_test_data(0.05) + ax.plot_surface(X,Y,Z, rstride=10,cstride=10) + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + +def test_contour(): + f = plt.figure() + ax = axes3d.Axes3D(f) + + X,Y,Z = axes3d.get_test_data(0.05) + cset = ax.contour3D(X,Y,Z) + ax.clabel(cset, fontsize=9, inline=1) + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + +def test_contourf(): + f = plt.figure() + ax = axes3d.Axes3D(f) + + X,Y,Z = axes3d.get_test_data(0.05) + cset = ax.contourf3D(X,Y,Z) + ax.clabel(cset, fontsize=9, inline=1) + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + +def test_plot(): + f = plt.figure() + ax = axes3d.Axes3D(f) + + xs = np.arange(0,4*np.pi+0.1,0.1) + ys = np.sin(xs) + ax.plot(xs,ys, label='zl') + ax.plot(xs,ys+max(xs),label='zh') + 
ax.plot(xs,ys,dir='x', label='xl') + ax.plot(xs,ys,dir='x', z=max(xs),label='xh') + ax.plot(xs,ys,dir='y', label='yl') + ax.plot(xs,ys,dir='y', z=max(xs), label='yh') + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.legend() + +def test_polys(): + f = plt.figure() + ax = axes3d.Axes3D(f) + + cc = lambda arg: colorConverter.to_rgba(arg, alpha=0.6) + + xs = np.arange(0,10,0.4) + verts = [] + zs = [0.0,1.0,2.0,3.0] + for z in zs: + ys = [random.random() for x in xs] + ys[0],ys[-1] = 0,0 + verts.append(zip(xs,ys)) + + from matplotlib.collections import PolyCollection + poly = PolyCollection(verts, facecolors = [cc('r'),cc('g'),cc('b'), + cc('y')]) + poly.set_alpha(0.7) + ax.add_collection(poly,zs=zs,dir='y') + + ax.set_xlim(0,10) + ax.set_ylim(-1,4) + ax.set_zlim(0,1) + +def test_scatter2D(): + f = plt.figure() + ax = axes3d.Axes3D(f) + + xs = [random.random() for i in range(20)] + ys = [random.random() for x in xs] + ax.scatter(xs, ys) + ax.scatter(xs, ys, dir='y', c='r') + ax.scatter(xs, ys, dir='x', c='g') + +def test_bar2D(): + f = plt.figure() + ax = axes3d.Axes3D(f) + + for c,z in zip(['r','g','b', 'y'],[30,20,10,0]): + xs = np.arange(20) + ys = [random.random() for x in xs] + ax.bar(xs, ys, z=z, dir='y', color=c, alpha=0.8) + +if __name__ == "__main__": + + test_scatter() + test_wire() + test_surface() + test_contour() + test_contourf() + test_plot() + test_polys() + test_scatter2D() +# test_bar2D() + + plt.show() Modified: trunk/matplotlib/examples/pylab_examples/finance_work2.py =================================================================== --- trunk/matplotlib/examples/pylab_examples/finance_work2.py 2009-04-13 03:02:41 UTC (rev 7040) +++ trunk/matplotlib/examples/pylab_examples/finance_work2.py 2009-04-14 14:29:31 UTC (rev 7041) @@ -216,10 +216,21 @@ +class MyLocator(mticker.MaxNLocator): + def __init__(self, *args, **kwargs): + mticker.MaxNLocator.__init__(self, *args, **kwargs) + + def __call__(self, *args, **kwargs): + return 
mticker.MaxNLocator.__call__(self, *args, **kwargs) + # at most 5 ticks, pruning the upper and lower so they don't overlap # with other ticks -ax2.yaxis.set_major_locator(mticker.MaxNLocator(5, prune='both')) -ax3.yaxis.set_major_locator(mticker.MaxNLocator(5, prune='both')) +#ax2.yaxis.set_major_locator(mticker.MaxNLocator(5, prune='both')) +#ax3.yaxis.set_major_locator(mticker.MaxNLocator(5, prune='both')) + +ax2.yaxis.set_major_locator(MyLocator(5, prune='both')) +ax3.yaxis.set_major_locator(MyLocator(5, prune='both')) + plt.show() Modified: trunk/matplotlib/examples/tests/backend_driver.py =================================================================== --- trunk/matplotlib/examples/tests/backend_driver.py 2009-04-13 03:02:41 UTC (rev 7040) +++ trunk/matplotlib/examples/tests/backend_driver.py 2009-04-14 14:29:31 UTC (rev 7041) @@ -240,6 +240,11 @@ ] +mplot3d_dir = os.path.join('..', 'mplot3d') +mplot3d_files = [ + 'demo.py', + ] + # dict from dir to files we know we don't want to test (eg examples # not using pyplot, examples requiring user input, animation examples, # examples that may only work in certain environs (usetex examples?), @@ -271,7 +276,8 @@ files = ( [os.path.join(api_dir, fname) for fname in api_files] + [os.path.join(pylab_dir, fname) for fname in pylab_files] + - [os.path.join(units_dir, fname) for fname in units_files] + [os.path.join(units_dir, fname) for fname in units_files] + + [os.path.join(mplot3d_dir, fname) for fname in mplot3d_files] ) # tests known to fail on a given backend Added: trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py =================================================================== --- trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py (rev 0) +++ trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py 2009-04-14 14:29:31 UTC (rev 7041) @@ -0,0 +1,269 @@ +#!/usr/bin/python +# art3d.py, original mplot3d version by John Porter +# Parts rewritten by Reinier Heeres <re...@he...> + +from matplotlib import lines, text, 
path as mpath +from matplotlib.collections import Collection, LineCollection, \ + PolyCollection, PatchCollection +from matplotlib.patches import Patch, Rectangle +from matplotlib.colors import Normalize +from matplotlib import transforms + +import types +import numpy as np +import proj3d + +class Text3D(text.Text): + + def __init__(self, x=0, y=0, z=0, text='', dir='z'): + text.Text.__init__(self, x, y, text) + self.set_3d_properties(z, dir) + + def set_3d_properties(self, z=0, dir='z'): + x, y = self.get_position() + self._position3d = juggle_axes(x, y, z, dir) + + def draw(self, renderer): + x, y, z = self._position3d + x, y, z = proj3d.proj_transform(x, y, z, renderer.M) + self.set_position(x, y) + text.Text.draw(self, renderer) + +def text_2d_to_3d(obj, z=0, dir='z'): + """Convert a Text to a Text3D object.""" + obj.__class__ = Text3D + obj.set_3d_properties(z, dir) + +class Line3D(lines.Line2D): + + def __init__(self, xs, ys, zs, *args, **kwargs): + lines.Line2D.__init__(self, [], [], *args, **kwargs) + self._verts3d = xs, ys, zs + + def set_3d_properties(self, zs=0, dir='z'): + xs = self.get_xdata() + ys = self.get_ydata() + try: + zs = float(zs) + zs = [zs for x in xs] + except: + pass + self._verts3d = juggle_axes(xs, ys, zs, dir) + + def draw(self, renderer): + xs3d, ys3d, zs3d = self._verts3d + xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M) + self.set_data(xs, ys) + lines.Line2D.draw(self, renderer) + +def line_2d_to_3d(line, z=0, dir='z'): + line.__class__ = Line3D + line.set_3d_properties(z, dir) + +def path_to_3d_segment(path, z=0, dir='z'): + '''Convert a path to a 3d segment.''' + seg = [] + for (pathseg, code) in path.iter_segments(): + seg.append(pathseg) + seg3d = [juggle_axes(x, y, z, dir) for (x, y) in seg] + return seg3d + +def paths_to_3d_segments(paths, zs=0, dir='z'): + '''Convert paths from a collection object to 3d segments.''' + + try: + zs = float(zs) + zs = [zs for i in range(len(paths))] + except: + pass + + segments 
= [] + for path, z in zip(paths, zs): + segments.append(path_to_3d_segment(path, z, dir)) + return segments + +class Line3DCollection(LineCollection): + + def __init__(self, segments, *args, **kwargs): + LineCollection.__init__(self, segments, *args, **kwargs) + + def set_segments(self, segments): + self._segments3d = segments + LineCollection.set_segments(self, []) + + def draw(self, renderer): + xyslist = [ + proj3d.proj_trans_points(points, renderer.M) for points in + self._segments3d] + segments_2d = [zip(xs,ys) for (xs,ys,zs) in xyslist] + LineCollection.set_segments(self, segments_2d) + LineCollection.draw(self, renderer) + +def line_collection_2d_to_3d(col, z=0, dir='z'): + """Convert a LineCollection to a Line3DCollection object.""" + segments3d = paths_to_3d_segments(col.get_paths(), z, dir) + col.__class__ = Line3DCollection + col.set_segments(segments3d) + +class Patch3D(Patch): + + def __init__(self, *args, **kwargs): + zs = kwargs.pop('zs', []) + dir = kwargs.pop('dir', 'z') + Patch.__init__(self, *args, **kwargs) + self.set_3d_properties(zs, dir) + + def set_3d_properties(self, verts, z=0, dir='z'): + self._segment3d = [juggle_axes(x, y, z, dir) for (x, y) in verts] + self._facecolor3d = Patch.get_facecolor(self) + + def get_path(self): + return self._path2d + + def get_facecolor(self): + return self._facecolor2d + + def draw(self, renderer): + s = self._segment3d + xs, ys, zs = zip(*s) + vxs,vys,vzs,vis = proj3d.proj_transform_clip(xs,ys,zs, renderer.M) + self._path2d = mpath.Path(zip(vxs, vys)) + # FIXME: coloring + self._facecolor2d = self._facecolor3d + Patch.draw(self, renderer) + +def patch_2d_to_3d(patch, z=0, dir='z'): + """Convert a Patch to a Patch3D object.""" + verts = patch.get_verts() + patch.__class__ = Patch3D + patch.set_3d_properties(verts, z, dir) + +class Patch3DCollection(PatchCollection): + + def __init__(self, *args, **kwargs): + PatchCollection.__init__(self, *args, **kwargs) + + def set_3d_properties(self, zs, dir): + xs, ys = 
zip(*self.get_offsets()) + self._offsets3d = juggle_axes(xs, ys, zs, dir) + self._facecolor3d = self.get_facecolor() + self._edgecolor3d = self.get_edgecolor() + + def draw(self, renderer): + xs,ys,zs = self._offsets3d + vxs,vys,vzs,vis = proj3d.proj_transform_clip(xs,ys,zs, renderer.M) + #FIXME: mpl allows us no way to unset the collection alpha value + self._alpha = None + self.set_facecolors(zalpha(self._facecolor3d, vzs)) + self.set_edgecolors(zalpha(self._edgecolor3d, vzs)) + PatchCollection.set_offsets(self, zip(vxs, vys)) + PatchCollection.draw(self, renderer) + +def patch_collection_2d_to_3d(col, zs=0, dir='z'): + """Convert a PatchCollection to a Patch3DCollection object.""" + col.__class__ = Patch3DCollection + col.set_3d_properties(zs, dir) + +class Poly3DCollection(PolyCollection): + + def __init__(self, verts, *args, **kwargs): + PolyCollection.__init__(self, verts, *args, **kwargs) + self.set_3d_properties() + + def get_vector(self, segments3d): + """optimise points for projection""" + si = 0 + ei = 0 + segis = [] + points = [] + for p in segments3d: + points.extend(p) + ei = si+len(p) + segis.append((si,ei)) + si = ei + xs,ys,zs = zip(*points) + ones = np.ones(len(xs)) + self._vec = np.array([xs,ys,zs,ones]) + self._segis = segis + + def set_verts(self, verts, closed=True): + self.get_vector(verts) + # 2D verts will be updated at draw time + PolyCollection.set_verts(self, [], closed) + + def set_3d_properties(self): + self._zsort = 1 + self._facecolors3d = PolyCollection.get_facecolors(self) + self._edgecolors3d = self.get_edgecolors() + + def get_facecolors(self): + return self._facecolors2d + get_facecolor = get_facecolors + + def draw(self, renderer): + txs, tys, tzs, tis = proj3d.proj_transform_vec_clip(self._vec, renderer.M) + xyslist = [(txs[si:ei], tys[si:ei], tzs[si:ei], tis[si:ei]) \ + for si, ei in self._segis] + colors = self._facecolors3d + # + # if required sort by depth (furthest drawn first) + if self._zsort: + z_segments_2d = 
[(min(zs),max(tis),zip(xs,ys),c) for + (xs,ys,zs,tis),c in zip(xyslist,colors)] + z_segments_2d.sort() + z_segments_2d.reverse() + else: + raise ValueError, "whoops" + segments_2d = [s for z,i,s,c in z_segments_2d if i] + colors = [c for z,i,s,c in z_segments_2d if i] + PolyCollection.set_verts(self, segments_2d) + self._facecolors2d = colors + return Collection.draw(self, renderer) + +def poly_collection_2d_to_3d(col, zs=None, dir='z'): + """Convert a PolyCollection to a Poly3DCollection object.""" + segments_3d = paths_to_3d_segments(col.get_paths(), zs, dir) + col.__class__ = Poly3DCollection + col.set_verts(segments_3d) + col.set_3d_properties() + +def juggle_axes(xs,ys,zs, dir): + """ + Depending on the direction of the plot re-order the axis. + This is so that 2d plots can be plotted along any direction. + """ + if dir == 'x': return zs,xs,ys + elif dir == 'y': return xs,zs,ys + else: return xs,ys,zs + +def iscolor(c): + try: + return (len(c) == 4 or len(c) == 3) and hasattr(c[0], '__float__') + except (IndexError): + return False + +def get_colors(c, num): + """Stretch the color argument to provide the required number num""" + + if type(c)==type("string"): + c = colors.colorConverter.to_rgba(colors) + + if iscolor(c): + return [c] * num + if len(c) == num: + return c + elif iscolor(c): + return [c] * num + elif iscolor(c[0]): + return [c[0]] * num + else: + raise ValueError, 'unknown color format %s' % c + +def zalpha(colors, zs): + """Modify the alphas of the color list according to depth""" + colors = get_colors(colors,len(zs)) + norm = Normalize(min(zs),max(zs)) + sats = 1 - norm(zs)*0.7 + colors = [(c[0],c[1],c[2],c[3]*s) for c,s in zip(colors,sats)] + return colors + Added: trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py =================================================================== --- trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py (rev 0) +++ trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py 2009-04-14 14:29:31 UTC (rev 7041) @@ -0,0 
+1,868 @@ +#!/usr/bin/python +# axes3d.py, original mplot3d version by John Porter +# Created: 23 Sep 2005 +# Parts fixed by Reinier Heeres <re...@he...> + +""" +3D projection glued onto 2D Axes. + +Axes3D +""" + +from matplotlib import pyplot as plt +import random + +from matplotlib.axes import Axes +from matplotlib import cbook +from matplotlib.transforms import Bbox +import numpy as np +from matplotlib.colors import Normalize, colorConverter + +import art3d +import proj3d +import axis3d + +def sensible_format_data(self, value): + """Used to generate more comprehensible numbers in status bar""" + if abs(value) > 1e4 or abs(value)<1e-3: + s = '%1.4e'% value + return self._formatSciNotation(s) + else: + return '%4.3f' % value + +def unit_bbox(): + box = Bbox(np.array([[0,0],[1,1]])) + return box + +class Axes3DI(Axes): + """Wrap an Axes object + + The x,y data coordinates, which are manipulated by set_xlim and + set_ylim are used as the target view coordinates by the 3D + transformations. These coordinates are mostly invisible to the + outside world. + + set_w_xlim, set_w_ylim and set_w_zlim manipulate the 3D world + coordinates which are scaled to represent the data and are stored + in the xy_dataLim, zz_datalim bboxes. + + The axes representing the x,y,z world dimensions are self.w_xaxis, + self.w_yaxis and self.w_zaxis. They can probably be controlled in + more or less the normal ways. 
+ """ + def __init__(self, fig, rect=[0.0, 0.0, 1.0, 1.0], *args, **kwargs): + self.fig = fig + self.cids = [] + + azim = kwargs.pop('azim', -60) + elev = kwargs.pop('elev', 30) + + self.xy_viewLim = unit_bbox() + self.zz_viewLim = unit_bbox() + self.xy_dataLim = unit_bbox() + self.zz_dataLim = unit_bbox() + # inihibit autoscale_view until the axises are defined + # they can't be defined until Axes.__init__ has been called + self.view_init(elev, azim) + self._ready = 0 + Axes.__init__(self, self.fig, rect, + frameon=True, + xticks=[], yticks=[], *args, **kwargs) + + self.M = None + + self._ready = 1 + self.mouse_init() + self.create_axes() + self.set_top_view() + + self.axesPatch.set_linewidth(0) + self.fig.add_axes(self) + + def set_top_view(self): + # this happens to be the right view for the viewing coordinates + # moved up and to the left slightly to fit labels and axes + xdwl = (0.95/self.dist) + xdw = (0.9/self.dist) + ydwl = (0.95/self.dist) + ydw = (0.9/self.dist) + # + self.set_xlim(-xdwl,xdw) + self.set_ylim(-ydwl,ydw) + + def really_set_xlim(self, vmin, vmax): + self.viewLim.intervalx().set_bounds(vmin, vmax) + + def really_set_ylim(self, vmin, vmax): + self.viewLim.intervaly().set_bounds(vmin, vmax) + + def vlim_argument(self, get_lim, *args): + if not args: + vmin,vmax = get_lim() + elif len(args)==2: + vmin,vmax = args + elif len(args)==1: + vmin,vmax = args[0] + return vmin,vmax + + def nset_xlim(self, *args): + raise + vmin,vmax = self.vlim_argument(self.get_xlim) + print 'xlim', vmin,vmax + + def nset_ylim(self, *args): + vmin,vmax = self.vlim_argument(self.get_ylim) + print 'ylim', vmin,vmax + + def create_axes(self): + self.w_xaxis = axis3d.XAxis('x',self.xy_viewLim.intervalx, + self.xy_dataLim.intervalx, self) + self.w_yaxis = axis3d.YAxis('y',self.xy_viewLim.intervaly, + self.xy_dataLim.intervaly, self) + self.w_zaxis = axis3d.ZAxis('z',self.zz_viewLim.intervalx, + self.zz_dataLim.intervalx, self) + + def unit_cube(self,vals=None): + 
minx,maxx,miny,maxy,minz,maxz = vals or self.get_w_lims() + xs,ys,zs = ([minx,maxx,maxx,minx,minx,maxx,maxx,minx], + [miny,miny,maxy,maxy,miny,miny,maxy,maxy], + [minz,minz,minz,minz,maxz,maxz,maxz,maxz]) + return zip(xs,ys,zs) + + def tunit_cube(self,vals=None,M=None): + if M is None: + M = self.M + xyzs = self.unit_cube(vals) + tcube = proj3d.proj_points(xyzs,M) + return tcube + + def tunit_edges(self, vals=None,M=None): + tc = self.tunit_cube(vals,M) + edges = [(tc[0],tc[1]), + (tc[1],tc[2]), + (tc[2],tc[3]), + (tc[3],tc[0]), + + (tc[0],tc[4]), + (tc[1],tc[5]), + (tc[2],tc[6]), + (tc[3],tc[7]), + + (tc[4],tc[5]), + (tc[5],tc[6]), + (tc[6],tc[7]), + (tc[7],tc[4])] + return edges + + def draw(self, renderer): + # draw the background patch + self.axesPatch.draw(renderer) + self._frameon = False + + # add the projection matrix to the renderer + self.M = self.get_proj() + renderer.M = self.M + renderer.vvec = self.vvec + renderer.eye = self.eye + renderer.get_axis_position = self.get_axis_position + + self.w_xaxis.draw(renderer) + self.w_yaxis.draw(renderer) + self.w_zaxis.draw(renderer) + Axes.draw(self, renderer) + + def get_axis_position(self): + vals = self.get_w_lims() + tc = self.tunit_cube(vals,self.M) + xhigh = tc[1][2]>tc[2][2] + yhigh = tc[3][2]>tc[2][2] + zhigh = tc[0][2]>tc[2][2] + return xhigh,yhigh,zhigh + + def update_datalim(self, xys): + pass + + def update_datalim_numerix(self, x, y): + pass + + def auto_scale_xyz(self, X,Y,Z=None,had_data=None): + x,y,z = map(np.asarray, (X,Y,Z)) + try: + x,y = x.flatten(),y.flatten() + if Z is not None: + z = z.flatten() + except AttributeError: + raise + + # This updates the bounding boxes as to keep a record as + # to what the minimum sized rectangular volume holds the + # data. + self.xy_dataLim.update_from_data_xy(np.array([x, y]).T, not had_data) + if z is not None: + self.zz_dataLim.update_from_data_xy(np.array([z, z]).T, not had_data) + + # Let autoscale_view figure out how to use this data. 
+ self.autoscale_view() + + def autoscale_view(self, scalex=True, scaley=True, scalez=True): + # This method looks at the rectanglular volume (see above) + # of data and decides how to scale the view portal to fit it. + + self.set_top_view() + if not self._ready: return + + if not self.get_autoscale_on(): return + if scalex: + self.set_w_xlim(self.xy_dataLim.intervalx) + if scaley: + self.set_w_ylim(self.xy_dataLim.intervaly) + if scalez: + self.set_w_zlim(self.zz_dataLim.intervalx) + + def get_w_lims(self): + '''Get 3d world limits.''' + minpy,maxx = self.get_w_xlim() + miny,maxy = self.get_w_ylim() + minz,maxz = self.get_w_zlim() + return minpy,maxx,miny,maxy,minz,maxz + + def _determine_lims(self, xmin=None, xmax=None, *args, **kwargs): + if xmax is None and cbook.iterable(xmin): + xmin, xmax = xmin + return (xmin, xmax) + + def set_w_zlim(self, *args, **kwargs): + '''Set 3d z limits.''' + lims = self._determine_lims(*args, **kwargs) + self.zz_viewLim.intervalx = lims + return lims + + def set_w_xlim(self, *args, **kwargs): + '''Set 3d x limits.''' + lims = self._determine_lims(*args, **kwargs) + self.xy_viewLim.intervalx = lims + return lims + + def set_w_ylim(self, *args, **kwargs): + '''Set 3d y limits.''' + lims = self._determine_lims(*args, **kwargs) + self.xy_viewLim.intervaly = lims + return lims + + def get_w_zlim(self): + return self.zz_viewLim.intervalx + + def get_w_xlim(self): + return self.xy_viewLim.intervalx + + def get_w_ylim(self): + return self.xy_viewLim.intervaly + + def pany(self, numsteps): + print 'numsteps', numsteps + + def panpy(self, numsteps): + print 'numsteps', numsteps + + def view_init(self, elev, azim): + self.dist = 10 + self.elev = elev + self.azim = azim + + def get_proj(self): + """Create the projection matrix from the current viewing + position. + + elev stores the elevation angle in the z plane + azim stores the azimuth angle in the x,y plane + + dist is the distance of the eye viewing point from the object + point. 
+ + """ + relev,razim = np.pi * self.elev/180, np.pi * self.azim/180 + + xmin,xmax = self.get_w_xlim() + ymin,ymax = self.get_w_ylim() + zmin,zmax = self.get_w_zlim() + + # transform to uniform world coordinates 0-1.0,0-1.0,0-1.0 + worldM = proj3d.world_transformation(xmin,xmax, + ymin,ymax, + zmin,zmax) + + # look into the middle of the new coordinates + R = np.array([0.5,0.5,0.5]) + # + xp = R[0] + np.cos(razim)*np.cos(relev)*self.dist + yp = R[1] + np.sin(razim)*np.cos(relev)*self.dist + zp = R[2] + np.sin(relev)*self.dist + E = np.array((xp, yp, zp)) + # + self.eye = E + self.vvec = R - E + self.vvec = self.vvec / proj3d.mod(self.vvec) + + if abs(relev) > np.pi/2: + # upside down + V = np.array((0,0,-1)) + else: + V = np.array((0,0,1)) + zfront,zback = -self.dist,self.dist + + viewM = proj3d.view_transformation(E,R,V) + perspM = proj3d.persp_transformation(zfront,zback) + M0 = np.dot(viewM,worldM) + M = np.dot(perspM,M0) + return M + + def mouse_init(self): + self.button_pressed = None + canv = self.figure.canvas + if canv != None: + c1 = canv.mpl_connect('motion_notify_event', self.on_move) + c2 = canv.mpl_connect('button_press_event', self.button_press) + c3 = canv.mpl_connect('button_release_event', self.button_release) + self.cids = [c1, c2, c3] + + def cla(self): + # Disconnect the various events we set. + for cid in self.cids: + self.figure.canvas.mpl_disconnect(cid) + self.cids = [] + Axes.cla(self) + + def button_press(self, event): + self.button_pressed = event.button + self.sx,self.sy = event.xdata,event.ydata + + def button_release(self, event): + self.button_pressed = None + + def format_xdata(self, x): + """ + Return x string formatted. 
This function will use the attribute + self.fmt_xdata if it is callable, else will fall back on the xaxis + major formatter + """ + try: return self.fmt_xdata(x) + except TypeError: + fmt = self.w_xaxis.get_major_formatter() + return sensible_format_data(fmt,x) + + def format_ydata(self, y): + """ + Return y string formatted. This function will use the attribute + self.fmt_ydata if it is callable, else will fall back on the yaxis + major formatter + """ + try: return self.fmt_ydata(y) + except TypeError: + fmt = self.w_yaxis.get_major_formatter() + return sensible_format_data(fmt,y) + + def format_zdata(self, z): + """ + Return y string formatted. This function will use the attribute + self.fmt_ydata if it is callable, else will fall back on the yaxis + major formatter + """ + try: return self.fmt_zdata(z) + except (AttributeError,TypeError): + fmt = self.w_zaxis.get_major_formatter() + return sensible_format_data(fmt,z) + + def format_coord(self, xd, yd): + """Given the 2D view coordinates attempt to guess a 3D coordinate + + Looks for the nearest edge to the point and then assumes that the point is + at the same z location as the nearest point on the edge. 
+ """ + + if self.M is None: + return '' + + if self.button_pressed == 1: + return 'azimuth=%d deg, elevation=%d deg ' % (self.azim, self.elev) + # ignore xd and yd and display angles instead + + p = (xd,yd) + edges = self.tunit_edges() + #lines = [proj3d.line2d(p0,p1) for (p0,p1) in edges] + ldists = [(proj3d.line2d_seg_dist(p0,p1,p),i) for i,(p0,p1) in enumerate(edges)] + ldists.sort() + # nearest edge + edgei = ldists[0][1] + # + p0,p1 = edges[edgei] + + # scale the z value to match + x0,y0,z0 = p0 + x1,y1,z1 = p1 + d0 = np.hypot(x0-xd,y0-yd) + d1 = np.hypot(x1-xd,y1-yd) + dt = d0+d1 + z = d1/dt * z0 + d0/dt * z1 + #print 'mid', edgei, d0, d1, z0, z1, z + + x,y,z = proj3d.inv_transform(xd,yd,z,self.M) + + xs = self.format_xdata(x) + ys = self.format_ydata(y) + zs = self.format_ydata(z) + return 'x=%s, y=%s, z=%s'%(xs,ys,zs) + + def on_move(self, event): + """Mouse moving + + button-1 rotates + button-3 zooms + """ + if not self.button_pressed: + return + + if self.M is None: + return + # this shouldn't be called before the graph has been drawn for the first time! + x, y = event.xdata, event.ydata + + # In case the mouse is out of bounds. + if x == None: + + return + dx,dy = x-self.sx,y-self.sy + x0,x1 = self.get_xlim() + y0,y1 = self.get_ylim() + w = (x1-x0) + h = (y1-y0) + self.sx,self.sy = x,y + + if self.button_pressed == 1: + # rotate viewing point + # get the x and y pixel coords + if dx == 0 and dy == 0: return + # + self.elev = axis3d.norm_angle(self.elev - (dy/h)*180) + self.azim = axis3d.norm_angle(self.azim - (dx/w)*180) + self.get_proj() + self.figure.canvas.draw() + elif self.button_pressed == 2: + # pan view + # project xv,yv,zv -> xw,yw,zw + # pan + # + pass + elif self.button_pressed == 3: + # zoom view + # hmmm..this needs some help from clipping.... 
+ minpy,maxx,miny,maxy,minz,maxz = self.get_w_lims() + df = 1-((h - dy)/h) + dx = (maxx-minpy)*df + dy = (maxy-miny)*df + dz = (maxz-minz)*df + self.set_w_xlim(minpy-dx,maxx+dx) + self.set_w_ylim(miny-dy,maxy+dy) + self.set_w_zlim(minz-dz,maxz+dz) + self.get_proj() + self.figure.canvas.draw() + + def set_xlabel(self, xlabel, fontdict=None, **kwargs): + #par = cbook.popd(kwargs, 'par',None) + #label.set_par(par) + # + label = self.w_xaxis.get_label() + label.set_text(xlabel) + if fontdict is not None: label.update(fontdict) + label.update(kwargs) + return label + + def set_ylabel(self, ylabel, fontdict=None, **kwargs): + label = self.w_yaxis.get_label() + label.set_text(ylabel) + if fontdict is not None: label.update(fontdict) + label.update(kwargs) + return label + + def set_zlabel(self, zlabel, fontdict=None, **kwargs): + label = self.w_zaxis.get_label() + label.set_text(zlabel) + if fontdict is not None: label.update(fontdict) + label.update(kwargs) + return label + + def plot(self, *args, **kwargs): + had_data = self.has_data() + + zval = kwargs.pop( 'z', 0) + zdir = kwargs.pop('dir', 'z') + lines = Axes.plot(self, *args, **kwargs) + for line in lines: + art3d.line_2d_to_3d(line, z=zval, dir=zdir) + + xs = lines[0].get_xdata() + ys = lines[0].get_ydata() + zs = [zval for x in xs] + xs,ys,zs = art3d.juggle_axes(xs,ys,zs,zdir) + self.auto_scale_xyz(xs,ys,zs, had_data) + return lines + + def plot3D(self, xs, ys, zs, *args, **kwargs): + had_data = self.has_data() + lines = Axes.plot(self, xs,ys, *args, **kwargs) + if len(lines)==1: + line = lines[0] + art3d.line_2d_to_3d(line, zs) + self.auto_scale_xyz(xs,ys,zs, had_data) + return lines + + plot3d=plot3D + + def plot_surface(self, X, Y, Z, *args, **kwargs): + had_data = self.has_data() + + rows, cols = Z.shape + tX,tY,tZ = np.transpose(X), np.transpose(Y), np.transpose(Z) + rstride = kwargs.pop('rstride', 10) + cstride = kwargs.pop('cstride', 10) + # + polys = [] + boxes = [] + for rs in np.arange(0,rows-1,rstride): 
+ for cs in np.arange(0,cols-1,cstride): + ps = [] + corners = [] + for a,ta in [(X,tX),(Y,tY),(Z,tZ)]: + ztop = a[rs][cs:min(cols,cs+cstride+1)] + zleft = ta[min(cols-1,cs+cstride)][rs:min(rows,rs+rstride+1)] + zbase = a[min(rows-1,rs+rstride)][cs:min(cols,cs+cstride+1):] + zbase = zbase[::-1] + zright = ta[cs][rs:min(rows,rs+rstride+1):] + zright = zright[::-1] + corners.append([ztop[0],ztop[-1],zbase[0],zbase[-1]]) + z = np.concatenate((ztop,zleft,zbase,zright)) + ps.append(z) + boxes.append(map(np.array,zip(*corners))) + polys.append(zip(*ps)) + # + lines = [] + shade = [] + for box in boxes: + n = proj3d.cross(box[0]-box[1], + box[0]-box[2]) + n = n/proj3d.mod(n)*5 + shade.append(np.dot(n,[-1,-1,0.5])) + lines.append((box[0],n+box[0])) + # + color = np.array([0,0,1,1]) + norm = Normalize(min(shade),max(shade)) + colors = [color * (0.5+norm(v)*0.5) for v in shade] + for c in colors: c[3] = 1 + polyc = art3d.Poly3DCollection(polys, facecolors=colors, *args, **kwargs) + polyc._zsort = 1 + self.add_collection(polyc) + # + self.auto_scale_xyz(X,Y,Z, had_data) + return polyc + + def plot_wireframe(self, X, Y, Z, *args, **kwargs): + rstride = kwargs.pop("rstride", 1) + cstride = kwargs.pop("cstride", 1) + + had_data = self.has_data() + rows,cols = Z.shape + + tX,tY,tZ = np.transpose(X), np.transpose(Y), np.transpose(Z) + + rii = [i for i in range(0,rows,rstride)]+[rows-1] + cii = [i for i in range(0,cols,cstride)]+[cols-1] + xlines = [X[i] for i in rii] + ylines = [Y[i] for i in rii] + zlines = [Z[i] for i in rii] + # + txlines = [tX[i] for i in cii] + tylines = [tY[i] for i in cii] + tzlines = [tZ[i] for i in cii] + # + lines = [zip(xl,yl,zl) for xl,yl,zl in zip(xlines,ylines,zlines)] + lines += [zip(xl,yl,zl) for xl,yl,zl in zip(txlines,tylines,tzlines)] + linec = self.add_lines(lines, *args, **kwargs) + + self.auto_scale_xyz(X,Y,Z, had_data) + return linec + + def contour3D(self, X, Y, Z, *args, **kwargs): + had_data = self.has_data() + cset = self.contour(X, Y, 
Z, *args, **kwargs) + for z, linec in zip(cset.levels, cset.collections): + zl = [] + art3d.line_collection_2d_to_3d(linec, z) + self.auto_scale_xyz(X,Y,Z, had_data) + return cset + + def clabel(self, *args, **kwargs): +# r = Axes.clabel(self, *args, **kwargs) + return None + + def contourf3D(self, X, Y, Z, *args, **kwargs): + had_data = self.has_data() + + cset = self.contourf(X, Y, Z, *args, **kwargs) + levels = cset.levels + colls = cset.collections + + for z1,z2,linec in zip(levels,levels[1:],colls): + zs = [z1] * (len(linec.get_paths()[0])/2) + zs += [z2] * (len(linec.get_paths()[0])/2) + art3d.poly_collection_2d_to_3d(linec, zs) + self.auto_scale_xyz(X,Y,Z, had_data) + return cset + + def scatter3D(self, xs, ys, zs, *args, **kwargs): + had_data = self.has_data() + patches = Axes.scatter(self,xs,ys,*args,**kwargs) + patches = art3d.patch_collection_2d_to_3d(patches, zs) + self.auto_scale_xyz(xs,ys,zs, had_data) + return patches + scatter3d = scatter3D + + def add_lines(self, lines, *args, **kwargs): + linec = art3d.Line3DCollection(lines, *args, **kwargs) + self.add_collection(linec) + return linec + """ + def text3D(self, x,y,z,s, *args, **kwargs): + text = Axes.text(self,x,y,s,*args,**kwargs) + art3d.wrap_text(text,z) + return text + """ + def ahvline(self, x,y): + pass + + def ahvxplane(self, x): + pass + + def ahvyplane(self, y): + pass + +class Scaler: + def __init__(self, points): + self.inpoints = points + self.drawpoints = None + + def update(self, lims): + for x,y,z in self.points: + pass + +class Axes3D: + """ + Wrapper for Axes3DI + + Provides set_xlim, set_ylim etc. + + 2D functions can be caught here and mapped + to their 3D approximations. + + This should probably be the case for plot etc... 
+ """ + def __init__(self, fig, *args, **kwargs): + self.__dict__['wrapped'] = Axes3DI(fig, *args, **kwargs) + + def set_xlim(self, *args, **kwargs): + self.wrapped.set_w_xlim(*args, **kwargs) + + def set_ylim(self, *args, **kwargs): + self.wrapped.set_w_ylim(*args, **kwargs) + + def set_zlim(self, *args, **kwargs): + self.wrapped.set_w_zlim(*args, **kwargs) + + def __getattr__(self, k): + return getattr(self.wrapped,k) + + def __setattr__(self, k,v): + return setattr(self.wrapped,k,v) + + def add_collection(self, polys, zs=None, dir='z'): + art3d.poly_collection_2d_to_3d(polys, zs=zs, dir=dir) + self.add_3DCollection(polys) + + def add_3DCollection(self, patches): + self.wrapped.add_collection(patches) + + def text(self, x,y, text, *args,**kwargs): + self.wrapped.text3D(x,y,0,text,*args,**kwargs) + + def scatter(self, xs,ys,zs=None,dir='z',*args,**kwargs): + patches = self.wrapped.scatter(xs,ys,*args,**kwargs) + if zs is None: + zs = [0]*len(xs) + art3d.patch_collection_2d_to_3d(patches, zs=zs, dir=dir) + return patches + + def bar(self, left, height, z=0, dir='z', *args, **kwargs): + had_data = self.has_data() + patches = self.wrapped.bar(left, height, *args, **kwargs) + verts = [] + for p in patches: + vs = p.get_verts() + zs = [z]*len(vs) + verts += vs.tolist() + art3d.patch_2d_to_3d(p, zs[0], dir) + if 'alpha' in kwargs: + p.set_alpha(kwargs['alpha']) + xs,ys = zip(*verts) + zs = [z]*len(xs) + xs,ys,zs=art3d.juggle_axes(xs,ys,zs,dir) + self.wrapped.auto_scale_xyz(xs,ys,zs, had_data) + return patches + +def test_scatter(): + f = plt.figure() + ax = Axes3D(f) + + n = 100 + for c,zl,zh in [('r',-50,-25),('b',-30,-5)]: + xs,ys,zs = zip(* + [(random.randrange(23,32), + random.randrange(100), + random.randrange(zl,zh) + ) for i in range(n)]) + ax.scatter3D(xs,ys,zs, c=c) + + ax.set_xlabel('------------ X Label --------------------') + ax.set_ylabel('------------ Y Label --------------------') + ax.set_zlabel('------------ Z Label --------------------') + +def 
get_test_data(delta=0.05): + from matplotlib.mlab import bivariate_normal + x = y = np.arange(-3.0, 3.0, delta) + X, Y = np.meshgrid(x,y) + + Z1 = bivariate_normal(X, Y, 1.0, 1.0, 0.0, 0.0) + Z2 = bivariate_normal(X, Y, 1.5, 0.5, 1, 1) + Z = Z2-Z1 + + X = X * 10 + Y = Y * 10 + Z = Z * 500 + return X,Y,Z + +def test_wire(): + f = plt.figure() + ax = Axes3D(f) + + X,Y,Z = get_test_data(0.05) + ax.plot_wireframe(X,Y,Z, rstride=10,cstride=10) + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + +def test_surface(): + f = plt.figure() + ax = Axes3D(f) + + X,Y,Z = get_test_data(0.05) + ax.plot_surface(X,Y,Z, rstride=10,cstride=10) + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + +def test_contour(): + f = plt.figure() + ax = Axes3D(f) + + X,Y,Z = get_test_data(0.05) + cset = ax.contour3D(X,Y,Z) + ax.clabel(cset, fontsize=9, inline=1) + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + +def test_contourf(): + f = plt.figure() + ax = Axes3D(f) + + X,Y,Z = get_test_data(0.05) + cset = ax.contourf3D(X,Y,Z) + ax.clabel(cset, fontsize=9, inline=1) + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + +def test_plot(): + f = plt.figure() + ax = Axes3D(f) + + xs = np.arange(0,4*np.pi+0.1,0.1) + ys = np.sin(xs) + ax.plot(xs,ys, label='zl') + ax.plot(xs,ys+max(xs),label='zh') + ax.plot(xs,ys,dir='x', label='xl') + ax.plot(xs,ys,dir='x', z=max(xs),label='xh') + ax.plot(xs,ys,dir='y', label='yl') + ax.plot(xs,ys,dir='y', z=max(xs), label='yh') + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + ax.legend() + +def test_polys(): + f = plt.figure() + ax = Axes3D(f) + + cc = lambda arg: colorConverter.to_rgba(arg, alpha=0.6) + + xs = np.arange(0,10,0.4) + verts = [] + zs = [0.0,1.0,2.0,3.0] + for z in zs: + ys = [random.random() for x in xs] + ys[0],ys[-1] = 0,0 + verts.append(zip(xs,ys)) + + from matplotlib.collections import PolyCollection + poly = PolyCollection(verts, facecolors = [cc('r'),cc('g'),cc('b'), + cc('y')]) 
+ poly.set_alpha(0.7) + ax.add_collection(poly,zs=zs,dir='y') + + ax.set_xlim(0,10) + ax.set_ylim(-1,4) + ax.set_zlim(0,1) + +def test_scatter2D(): + f = plt.figure() + ax = Axes3D(f) + + xs = [random.random() for i in range(20)] + ys = [random.random() for x in xs] + ax.scatter(xs, ys) + ax.scatter(xs, ys, dir='y', c='r') + ax.scatter(xs, ys, dir='x', c='g') + +def test_bar2D(): + f = plt.figure() + ax = Axes3D(f) + + for c,z in zip(['r','g','b', 'y'],[30,20,10,0]): + xs = np.arange(20) + ys = [random.random() for x in xs] + ax.bar(xs, ys, z=z, dir='y', color=c, alpha=0.8) + +if __name__ == "__main__": + import pylab + import axis3d; reload(axis3d); + import art3d; reload(art3d); + import proj3d; reload(proj3d); + + test_scatter() + test_wire() + test_surface() + test_contour() + test_contourf() + test_plot() + test_polys() + test_scatter2D() +# test_bar2D() + + pylab.show() Added: trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py =================================================================== --- trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py (rev 0) +++ trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py 2009-04-14 14:29:31 UTC (rev 7041) @@ -0,0 +1,310 @@ +#!/usr/bin/python +# axis3d.py, original mplot3d version by John Porter +# Created: 23 Sep 2005 +# Parts rewritten by Reinier Heeres <re...@he...> + +import math +import copy + +from matplotlib import lines +from matplotlib import axis +from matplotlib import patches +from matplotlib import text + +import art3d +import proj3d + +import numpy as np + +def norm_angle(a): + """Return angle between -180 and +180""" + a = (a+360)%360 + if a > 180: a = a-360 + return a + +def norm_text_angle(a): + """Return angle between -90 and +90""" + a = (a + 180) % 180 + if a > 90: + a = a - 180 + return a + +def get_flip_min_max(coord, index, mins, maxs): + if coord[index] == mins[index]: + return maxs[index] + else: + return mins[index] + +def move_from_center(coord, centers, deltas, axmask=(True, True, True)): + 
'''Return a coordinate that is moved by "deltas" away from the center.''' + ret = copy.copy(coord) + for i in range(3): + if not axmask[i]: + continue + if coord[i] < centers[i]: + coord[i] -= deltas[i] + else: + coord[i] += deltas[i] + return coord + +def tick_update_position(tick, tickxs, tickys, labelpos): + '''Update tick line and label position and style.''' + + for (label, on) in ((tick.label1, tick.label1On), \ + (tick.label2, tick.label2On)): + if on: + label.set_position(labelpos) + + tick.tick1On, tick.tick2On = True, False + tick.tick1line.set_linestyle('-') + tick.tick1line.set_marker('') + tick.tick1line.set_data(tickxs, tickys) + +class Axis(axis.XAxis): + + # These points from the unit cube make up the x, y and z-planes + _PLANES = ( + (0, 3, 7, 4), (1, 2, 6, 5), # yz planes + (0, 1, 5, 4), (3, 2, 6, 7), # xz planes + (0, 1, 2, 3), (4, 5, 6, 7), # xy planes + ) + + # Some properties for the axes + _AXINFO = { + 'x': {'i': 0, 'tickdir': 1, + 'color': (0.95, 0.95, 0.95, 0.5)}, + 'y': {'i': 1, 'tickdir': 0, + 'color': (0.90, 0.90, 0.90, 0.5)}, + 'z': {'i': 2, 'tickdir': 0, + 'color': (0.925, 0.925, 0.925, 0.5)}, + } + + def __init__(self, adir, v_intervalx, d_intervalx, axes, *args, **kwargs): + # adir identifies which axes this is + self.adir = adir + # data and viewing intervals for this direction + self.d_interval = d_intervalx + self.v_interval = v_intervalx + # + axis.XAxis.__init__(self, axes, *args, **kwargs) + self.line = lines.Line2D(xdata=(0,0),ydata=(0,0), + linewidth=0.75, + color=(0,0,0,0), + antialiased=True, + ) + + # Store dummy data in Polygon object + self.has_pane = True + self.pane = patches.Polygon(np.array([[0,0],[0,1],[1,0],[0,0]]), + alpha=0.8, + facecolor=(1,1,1,0), + edgecolor=(1,1,1,0)) + + self.axes._set_artist_props(self.line) + self.axes._set_artist_props(self.pane) + self.gridlines = art3d.Line3DCollection([], ) + self.axes._set_artist_props(self.gridlines) + self.axes._set_artist_props(self.label) + self.label._transform 
= self.axes.transData + self.set_rotate_label(kwargs.get('rotate_label', None)) + + def get_tick_positions(self): + majorTicks = self.get_major_ticks() + majorLocs = self.major.locator() + self.major.formatter.set_locs(majorLocs) + majorLabels = [self.major.formatter(val, i) for i, val in enumerate(majorLocs)] + return majorLabels,majorLocs + + def get_major_ticks(self): + ticks = axis.XAxis.get_major_ticks(self) + for t in ticks: + def update_coords(renderer,self=t.label1): + return text_update_coords(self, renderer) + # Text overrides setattr so need this to force new method + t.tick1line.set_transform(self.axes.transData) + t.tick2line.set_transform(self.axes.transData) + t.gridline.set_transform(self.axes.transData) + t.label1.set_transform(self.axes.transData) + t.label2.set_transform(self.axes.transData) + return ticks + + def set_pane(self, xys, color): + if self.has_pane: + xys = np.asarray(xys) + xys = xys[:,:2] + self.pane.xy = xys + self.pane.set_edgecolor(color) + self.pane.set_facecolor(color) + self.pane.set_alpha(color[-1]) + + def set_rotate_label(self, val): + ''' + Whether to rotate the axis label: True, False or None. + If set to None the label will be rotated if longer than 4 chars. 
+ ''' + self._rotate_label = val + + def get_rotate_label(self, text): + if self._rotate_label is not None: + return self._rotate_label + else: + return len(text) > 4 + + def draw(self, renderer): + self.label._transform = self.axes.transData + renderer.open_group('axis3d') + + # code from XAxis + majorTicks = self.get_major_ticks() + majorLocs = self.major.locator() + + # filter locations here so that no extra grid lines are drawn + interval = self.get_view_interval() + majorLocs = [loc for loc in majorLocs if interval[0] < loc < interval[1]] + self.major.formatter.set_locs(majorLocs) + majorLabels = [self.major.formatter(val, i) + for i, val in enumerate(majorLocs)] + + # Determine bounds + minx,maxx,miny,maxy,minz,maxz = self.axes.get_w_lims() + mins = (minx, miny, minz) + maxs = (maxx, maxy, maxz) + centers = [(maxv + minv) / 2 for minv, maxv in zip(mins, maxs)] + deltas = [(maxv - minv) / 12 for minv, maxv in zip(mins, maxs)] + mins = [minv - delta / 4 for minv, delta in zip(mins, deltas)] + maxs = [maxv + delta / 4 for maxv, delta in zip(maxs, deltas)] + + # Determine which planes should be visible by the avg z value + vals = mins[0], maxs[0], mins[1], maxs[1], mins[2], maxs[2] + tc = self.axes.tunit_cube(vals,renderer.M) + avgz = [tc[p1][2] + tc[p2][2] + tc[p3][2] + tc[p4][2] for \ + p1, p2, p3, p4 in self._PLANES] + highs = [avgz[2*i] < avgz[2*i+1] for i in range(3)] + + # Draw plane + info = self._AXINFO[self.adir] + index = info['i'] + if not highs[index]: + plane = self._PLANES[2 * index] + else: + plane = self._PLANES[2 * index + 1] + xys = [tc[p] for p in plane] + self.set_pane(xys, info['color']) + self.pane.draw(renderer) + + # Determine grid lines + minmax = [] + for i, val in enumerate(highs): + if val: + minmax.append(maxs[i]) + else: + minmax.append(mins[i]) + + # Draw main axis line + juggled = art3d.juggle_axes(0, 2, 1, self.adir) + edgep1 = copy.copy(minmax) + edgep1[juggled[0]] = get_flip_min_max(edgep1, juggled[0], mins, maxs) + edgep2 = 
copy.copy(edgep1) + edgep2[juggled[1]] = get_flip_min_max(edgep2, juggled[1], mins, maxs) + pep = proj3d.proj_trans_points([edgep1, edgep2], renderer.M) + self.line.set_data((pep[0][0], pep[0][1]), (pep[1][0], pep[1][1])) + self.line.draw(renderer) + + # Grid points where the planes meet + xyz0 = [] + for val in majorLocs: + coord = copy.copy(minmax) + coord[index] = val + xyz0.append(coord) + + # Draw labels + dy = pep[1][1] - pep[1][0] + dx = pep[0][1] - pep[0][0] + lxyz = [(v1 + v2) / 2 for v1, v2 in zip(edgep1, edgep2)] + labeldeltas = [1.3 * x for x in deltas] + lxyz = move_from_center(lxyz, centers, labeldeltas) + tlx,tly,tlz = proj3d.proj_transform(lxyz[0], lxyz[1], lxyz[2], renderer.M) + self.label.set_position((tlx, tly)) + if self.get_rotate_label(self.label.get_text()): + angle = norm_text_angle(math.degrees(math.atan2(dy, dx))) + self.label.set_rotation(angle) + self.label.set_va('center') + self.label.draw(renderer) + + # Grid points at end of one plane + xyz1 = copy.deepcopy(xyz0) + newindex = (index + 1) % 3 + newval = get_flip_min_max(xyz1[0], newindex, mins, maxs) + for i in range(len(majorLocs)): + xyz1[i][newindex] = newval + + # Grid points at end of the other plane + xyz2 = copy.deepcopy(xyz0) + newindex = (index + 2) % 3 + newval = get_flip_min_max(xyz2[0], newindex, mins, maxs) + for i in range(len(majorLocs)): + xyz2[i][newindex] = newval + + lines = zip(xyz1, xyz0, xyz2) + self.gridlines.set_segments(lines) + self.gridlines.set_color([(0.9,0.9,0.9,1)] * len(lines)) + self.gridlines.draw(renderer) + + # Draw ticks + tickdir = info['tickdir'] + tickdelta = deltas[tickdir] + if highs[tickdir]: + ticksign = 1 + else: + ticksign = -1 + + for tick, loc, label in zip(majorTicks, majorLocs, majorLabels): + if tick is None: + continue + + # Get tick line positions + pos = copy.copy(edgep1) + pos[index] = loc + pos[tickdir] = edgep1[tickdir] + 0.1 * ticksign * tickdelta + x1, y1, z1 = proj3d.proj_transform(pos[0], pos[1], pos[2], renderer.M) + 
pos[tickdir] = edgep1[tickdir] - 0.2 * ticksign * tickdelta + x2, y2, z2 = proj3d.proj_transform(pos[0], pos[1], pos[2], renderer.M) + + # Get position of label + labeldeltas = [0.6 * x for x in deltas] + axmask = [True, True, True] + axmask[index] = False + pos[tickdir] = edgep1[tickdir] + pos = move_from_center(pos, centers, labeldeltas, axmask) + lx, ly, lz = proj3d.proj_transform(pos[0], pos[1], pos[2], renderer.M) + + tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly)) + tick.set_label1(label) + tick.set_label2(label) + tick.draw(renderer) + + renderer.close_group('axis3d') + + def get_view_interval(self): + """return the Interval instance for this axis view limits""" + return self.v_interval + +# Each type of axis should be looking in a different place for its +# current data limits so we do this with classes. I think there is +# a lot more that I can and should move down into these classes also. + +class XAxis(Axis): + def get_data_interval(self): + 'return the Interval instance for this axis data limits' + return self.axes.xy_dataLim.intervalx + + +class YAxis(Axis): + def get_data_interval(self): + 'return the Interval instance for this axis data limits' + return self.axes.xy_dataLim.intervaly + +class ZAxis(Axis): + def get_data_interval(self): + 'return the Interval instance for this axis data limits' + return self.axes.zz_dataLim.intervalx Added: trunk/matplotlib/lib/mpl_toolkits/mplot3d/proj3d.py =================================================================== --- trunk/matplotlib/lib/mpl_toolkits/mplot3d/proj3d.py (rev 0) +++ trunk/matplotlib/lib/mpl_toolkits/mplot3d/proj3d.py 2009-04-14 14:29:31 UTC (rev 7041) @@ -0,0 +1,286 @@ +#!/usr/bin/python +# 3dproj.py +# +""" +Various transforms used for by the 3D code +""" + +from matplotlib.collections import LineCollection +from matplotlib.patches import Circle +import numpy as np +import numpy.linalg as linalg + +def _hide_cross(a,b): + """ + Cross product of two vectors + A x B = <Ay*Bz - Az*By, 
Az*Bx - Ax*Bz, Ax*By - Ay*Bx> + a x b = [a2b3 - a3b2, a3b1 - a1b3, a1b2 - a2b1] + """ + return np.array([a[1]*b[2]-a[2]*b[1],a[2]*b[0]-a[0]*b[2],a[0]*b[1] - a[1]*b[0]]) +cross = _hide_cross + +def line2d(p0,p1): + """ + Return 2D equation of line in the form ax+by+c = 0 + """ + # x + x1 = 0 + x0,y0 = p0[:2] + x1,y1 = p1[:2] + # + if x0==x1: + a = -1 + b = 0 + c = x1 + elif y0==y1: + a = 0 + b = 1 + c = -y1 + else: + a = (y0-y1) + b = (x0-x1) + c = (x0*y1 - x1*y0) + return a,b,c + +def line2d_dist(l, p): + """ + Distance from line to point + line is a tuple of coefficients a,b,c + """ + a,b,c = l + x0,y0 = p + return abs((a*x0 + b*y0 + c)/np.sqrt(a**2+b**2)) + + +def line2d_seg_dist(p1,p2, p0): + """distance(s) from line defined by p1 - p2 to point(s) p0 + + p0[0] = x(s) + p0[1] = y(s) + + intersection point p = p1 + u*(p2-p1) + and intersection point lies within segement if u is between 0 and 1 + """ + + x21 = p2[0] - p1[0] + y21 = p2[1] - p1[1] + x01 = np.asarray(p0[0]) - p1[0] + y01 = np.asarray(p0[1]) - p1[1] + + u = (x01*x21 + y01*y21)/float(abs(x21**2 + y21**2)) + u = np.clip(u, 0, 1) + d = np.sqrt((x01 - u*x21)**2 + (y01 - u*y21)**2) + + return d + + +def test_lines_dists(): + ax = pylab.gca() + + xs,ys = (0,30),(20,150) + pylab.plot(xs,ys) + points = zip(xs,ys) + p0,p1 = points + + xs,ys = (0,0,20,30),(100,150,30,200) + pylab.scatter(xs,ys) + # + dist = line2d_seg_dist(p0,p1,(xs[0],ys[0])) + dist = line2d_seg_dist(p0,p1,np.array((xs,ys))) + for x,y,d in zip(xs,ys,dist): + c = Circle((x,y),d,fill=0) + ax.add_patch(c) + # + pylab.xlim(-200,... [truncated message content] |