|
From: <he...@us...> - 2009-12-10 00:03:13
|
Revision: 8015
http://matplotlib.svn.sourceforge.net/matplotlib/?rev=8015&view=rev
Author: heeres
Date: 2009-12-10 00:03:03 +0000 (Thu, 10 Dec 2009)
Log Message:
-----------
mplot3d updates:
* Fix scatter markers
* Add facecolor support for plot_surface
* Fix the drawing order of the X/Y/Z panes (draw all panes before any axis)
* Add examples (animations, colored surface)
Modified Paths:
--------------
trunk/matplotlib/examples/mplot3d/bars3d_demo.py
trunk/matplotlib/examples/mplot3d/hist3d_demo.py
trunk/matplotlib/examples/mplot3d/scatter3d_demo.py
trunk/matplotlib/examples/mplot3d/surface3d_demo.py
trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py
trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py
trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py
Added Paths:
-----------
trunk/matplotlib/examples/mplot3d/rotate_axes3d_demo.py
trunk/matplotlib/examples/mplot3d/surface3d_demo3.py
trunk/matplotlib/examples/mplot3d/wire3d_animation_demo.py
Modified: trunk/matplotlib/examples/mplot3d/bars3d_demo.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/bars3d_demo.py 2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/examples/mplot3d/bars3d_demo.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -7,8 +7,13 @@
for c, z in zip(['r', 'g', 'b', 'y'], [30, 20, 10, 0]):
xs = np.arange(20)
ys = np.random.rand(20)
- ax.bar(xs, ys, zs=z, zdir='y', color=c, alpha=0.8)
+ # You can provide either a single color or an array. To demonstrate this,
+ # the first bar of each set will be colored cyan.
+ cs = [c] * len(xs)
+ cs[0] = 'c'
+ ax.bar(xs, ys, zs=z, zdir='y', color=cs, alpha=0.8)
+
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
Modified: trunk/matplotlib/examples/mplot3d/hist3d_demo.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/hist3d_demo.py 2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/examples/mplot3d/hist3d_demo.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -16,6 +16,7 @@
dx = 0.5 * np.ones_like(zpos)
dy = dx.copy()
dz = hist.flatten()
+
ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color='b')
plt.show()
Added: trunk/matplotlib/examples/mplot3d/rotate_axes3d_demo.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/rotate_axes3d_demo.py (rev 0)
+++ trunk/matplotlib/examples/mplot3d/rotate_axes3d_demo.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -0,0 +1,15 @@
+from mpl_toolkits.mplot3d import axes3d
+import matplotlib.pyplot as plt
+import numpy as np
+
+plt.ion()
+
+fig = plt.figure()
+ax = axes3d.Axes3D(fig)
+X, Y, Z = axes3d.get_test_data(0.1)
+ax.plot_wireframe(X, Y, Z, rstride=5, cstride=5)
+
+for angle in range(0, 360):
+ ax.view_init(30, angle)
+ plt.draw()
+
Modified: trunk/matplotlib/examples/mplot3d/scatter3d_demo.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/scatter3d_demo.py 2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/examples/mplot3d/scatter3d_demo.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -2,18 +2,17 @@
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
-
def randrange(n, vmin, vmax):
return (vmax-vmin)*np.random.rand(n) + vmin
fig = plt.figure()
ax = Axes3D(fig)
n = 100
-for c, zl, zh in [('r', -50, -25), ('b', -30, -5)]:
+for c, m, zl, zh in [('r', 'o', -50, -25), ('b', '^', -30, -5)]:
xs = randrange(n, 23, 32)
ys = randrange(n, 0, 100)
zs = randrange(n, zl, zh)
- ax.scatter(xs, ys, zs, c=c)
+ ax.scatter(xs, ys, zs, c=c, marker=m)
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
Modified: trunk/matplotlib/examples/mplot3d/surface3d_demo.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/surface3d_demo.py 2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/examples/mplot3d/surface3d_demo.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -1,5 +1,6 @@
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
+from matplotlib.ticker import LinearLocator, FixedLocator, FormatStrFormatter
import matplotlib.pyplot as plt
import numpy as np
@@ -10,7 +11,14 @@
X, Y = np.meshgrid(X, Y)
R = np.sqrt(X**2 + Y**2)
Z = np.sin(R)
-ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.jet)
+surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.jet,
+ linewidth=0, antialiased=False)
+ax.set_zlim3d(-1.01, 1.01)
+ax.w_zaxis.set_major_locator(LinearLocator(10))
+ax.w_zaxis.set_major_formatter(FormatStrFormatter('%.03f'))
+
+fig.colorbar(surf, shrink=0.5, aspect=5)
+
plt.show()
Added: trunk/matplotlib/examples/mplot3d/surface3d_demo3.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/surface3d_demo3.py (rev 0)
+++ trunk/matplotlib/examples/mplot3d/surface3d_demo3.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -0,0 +1,31 @@
+from mpl_toolkits.mplot3d import Axes3D
+from matplotlib import cm
+from matplotlib.ticker import LinearLocator, FixedLocator, FormatStrFormatter
+import matplotlib.pyplot as plt
+import numpy as np
+
+fig = plt.figure()
+ax = Axes3D(fig)
+X = np.arange(-5, 5, 0.25)
+xlen = len(X)
+Y = np.arange(-5, 5, 0.25)
+ylen = len(Y)
+X, Y = np.meshgrid(X, Y)
+R = np.sqrt(X**2 + Y**2)
+Z = np.sin(R)
+
+colortuple = ('y', 'b')
+colors = np.empty(X.shape, dtype=str)
+for y in range(ylen):
+ for x in range(xlen):
+ colors[x, y] = colortuple[(x + y) % len(colortuple)]
+
+surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=colors,
+ linewidth=0, antialiased=False)
+
+ax.set_zlim3d(-1.01, 1.01)
+ax.w_zaxis.set_major_locator(LinearLocator(10))
+ax.w_zaxis.set_major_formatter(FormatStrFormatter('%.03f'))
+
+plt.show()
+
Added: trunk/matplotlib/examples/mplot3d/wire3d_animation_demo.py
===================================================================
--- trunk/matplotlib/examples/mplot3d/wire3d_animation_demo.py (rev 0)
+++ trunk/matplotlib/examples/mplot3d/wire3d_animation_demo.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -0,0 +1,34 @@
+from mpl_toolkits.mplot3d import axes3d
+import matplotlib.pyplot as plt
+import numpy as np
+import time
+
+def generate(X, Y, phi):
+ R = 1 - np.sqrt(X**2 + Y**2)
+ return np.cos(2 * np.pi * X + phi) * R
+
+plt.ion()
+fig = plt.figure()
+ax = axes3d.Axes3D(fig)
+
+xs = np.linspace(-1, 1, 50)
+ys = np.linspace(-1, 1, 50)
+X, Y = np.meshgrid(xs, ys)
+Z = generate(X, Y, 0.0)
+
+wframe = None
+tstart = time.time()
+for phi in np.linspace(0, 360 / 2 / np.pi, 100):
+
+ oldcol = wframe
+
+ Z = generate(X, Y, phi)
+ wframe = ax.plot_wireframe(X, Y, Z, rstride=2, cstride=2)
+
+ # Remove old line collection before drawing
+ if oldcol is not None:
+ ax.collections.remove(oldcol)
+
+ plt.draw()
+
+print 'FPS: %f' % (100 / (time.time() - tstart))
Modified: trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py
===================================================================
--- trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py 2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/lib/mpl_toolkits/mplot3d/art3d.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -274,6 +274,7 @@
def __init__(self, *args, **kwargs):
PatchCollection.__init__(self, *args, **kwargs)
+ self._old_draw = lambda x: PatchCollection.draw(self, x)
def set_3d_properties(self, zs, zdir):
xs, ys = zip(*self.get_offsets())
@@ -293,10 +294,15 @@
return min(vzs)
def draw(self, renderer):
- PatchCollection.draw(self, renderer)
+ self._old_draw(renderer)
def patch_collection_2d_to_3d(col, zs=0, zdir='z'):
"""Convert a PatchCollection to a Patch3DCollection object."""
+
+ # The tricky part here is that there are several classes that are
+ # derived from PatchCollection. We need to use the right draw method.
+ col._old_draw = col.draw
+
col.__class__ = Patch3DCollection
col.set_3d_properties(zs, zdir)
Modified: trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py
===================================================================
--- trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py 2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/lib/mpl_toolkits/mplot3d/axes3d.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -13,7 +13,7 @@
from matplotlib.transforms import Bbox
from matplotlib import collections
import numpy as np
-from matplotlib.colors import Normalize, colorConverter
+from matplotlib.colors import Normalize, colorConverter, LightSource
import art3d
import proj3d
@@ -37,6 +37,21 @@
"""
def __init__(self, fig, rect=None, *args, **kwargs):
+ '''
+ Build an :class:`Axes3D` instance in
+ :class:`~matplotlib.figure.Figure` *fig* with
+ *rect=[left, bottom, width, height]* in
+ :class:`~matplotlib.figure.Figure` coordinates
+
+ Optional keyword arguments:
+
+ ================ =========================================
+ Keyword Description
+ ================ =========================================
+ *azim* Azimuthal viewing angle (default -60)
+ *elev* Elevation viewing angle (default 30)
+ '''
+
if rect is None:
rect = [0.0, 0.0, 1.0, 1.0]
self.fig = fig
@@ -146,9 +161,12 @@
for i, (z, patch) in enumerate(zlist):
patch.zorder = i
- self.w_xaxis.draw(renderer)
- self.w_yaxis.draw(renderer)
- self.w_zaxis.draw(renderer)
+ axes = (self.w_xaxis, self.w_yaxis, self.w_zaxis)
+ for ax in axes:
+ ax.draw_pane(renderer)
+ for ax in axes:
+ ax.draw(renderer)
+
Axes.draw(self, renderer)
def get_axis_position(self):
@@ -322,8 +340,9 @@
self.grid(rcParams['axes3d.grid'])
def _button_press(self, event):
- self.button_pressed = event.button
- self.sx, self.sy = event.xdata, event.ydata
+ if event.inaxes == self:
+ self.button_pressed = event.button
+ self.sx, self.sy = event.xdata, event.ydata
def _button_release(self, event):
self.button_pressed = None
@@ -565,6 +584,12 @@
*cstride* Array column stride (step size)
*color* Color of the surface patches
*cmap* A colormap for the surface patches.
+ *facecolors* Face colors for the individual patches
+ *norm* An instance of Normalize to map values to colors
+ *vmin* Minimum value to map
+ *vmax* Maximum value to map
+ *shade* Whether to shade the facecolors, default:
+ false when cmap specified, true otherwise
========== ================================================
'''
@@ -575,13 +600,28 @@
rstride = kwargs.pop('rstride', 10)
cstride = kwargs.pop('cstride', 10)
- color = kwargs.pop('color', 'b')
- color = np.array(colorConverter.to_rgba(color))
+ if 'facecolors' in kwargs:
+ fcolors = kwargs.pop('facecolors')
+ else:
+ color = np.array(colorConverter.to_rgba(kwargs.pop('color', 'b')))
+ fcolors = None
+
cmap = kwargs.get('cmap', None)
+ norm = kwargs.pop('norm', None)
+ vmin = kwargs.pop('vmin', None)
+ vmax = kwargs.pop('vmax', None)
+ linewidth = kwargs.get('linewidth', None)
+ shade = kwargs.pop('shade', cmap is None)
+ lightsource = kwargs.pop('lightsource', None)
+ # Shade the data
+ if shade and cmap is not None and fcolors is not None:
+ fcolors = self._shade_colors_lightsource(Z, cmap, lightsource)
+
polys = []
normals = []
- avgz = []
+ #colset contains the data for coloring: either average z or the facecolor
+ colset = []
for rs in np.arange(0, rows-1, rstride):
for cs in np.arange(0, cols-1, cstride):
ps = []
@@ -609,19 +649,38 @@
lastp = p
avgzsum += p[2]
polys.append(ps2)
- avgz.append(avgzsum / len(ps2))
- v1 = np.array(ps2[0]) - np.array(ps2[1])
- v2 = np.array(ps2[2]) - np.array(ps2[0])
- normals.append(np.cross(v1, v2))
+ if fcolors is not None:
+ colset.append(fcolors[rs][cs])
+ else:
+ colset.append(avgzsum / len(ps2))
+ # Only need vectors to shade if no cmap
+ if cmap is None and shade:
+ v1 = np.array(ps2[0]) - np.array(ps2[1])
+ v2 = np.array(ps2[2]) - np.array(ps2[0])
+ normals.append(np.cross(v1, v2))
+
polyc = art3d.Poly3DCollection(polys, *args, **kwargs)
- if cmap is not None:
- polyc.set_array(np.array(avgz))
- polyc.set_linewidth(0)
+
+ if fcolors is not None:
+ if shade:
+ colset = self._shade_colors(colset, normals)
+ polyc.set_facecolors(colset)
+ polyc.set_edgecolors(colset)
+ elif cmap:
+ colset = np.array(colset)
+ polyc.set_array(colset)
+ if vmin is not None or vmax is not None:
+ polyc.set_clim(vmin, vmax)
+ if norm is not None:
+ polyc.set_norm(norm)
else:
- colors = self._shade_colors(color, normals)
- polyc.set_facecolors(colors)
+ if shade:
+ colset = self._shade_colors(color, normals)
+ else:
+ colset = color
+ polyc.set_facecolors(colset)
self.add_collection(polyc)
self.auto_scale_xyz(X, Y, Z, had_data)
@@ -643,24 +702,39 @@
return normals
def _shade_colors(self, color, normals):
+ '''
+ Shade *color* using normal vectors given by *normals*.
+ *color* can also be an array of the same length as *normals*.
+ '''
+
shade = []
for n in normals:
- n = n / proj3d.mod(n) * 5
+ n = n / proj3d.mod(n)
shade.append(np.dot(n, [-1, -1, 0.5]))
shade = np.array(shade)
mask = ~np.isnan(shade)
if len(shade[mask]) > 0:
- norm = Normalize(min(shade[mask]), max(shade[mask]))
- color = color.copy()
- color[3] = 1
- colors = [color * (0.5 + norm(v) * 0.5) for v in shade]
+ norm = Normalize(min(shade[mask]), max(shade[mask]))
+ if art3d.iscolor(color):
+ color = color.copy()
+ color[3] = 1
+ colors = [color * (0.5 + norm(v) * 0.5) for v in shade]
+ else:
+ colors = [np.array(colorConverter.to_rgba(c)) * \
+ (0.5 + norm(v) * 0.5) \
+ for c, v in zip(color, shade)]
else:
- colors = color.copy()
+ colors = color.copy()
return colors
+ def _shade_colors_lightsource(self, data, cmap, lightsource):
+ if lightsource is None:
+ lightsource = LightSource(azdeg=135, altdeg=55)
+ return lightsource.shade(data, cmap)
+
def plot_wireframe(self, X, Y, Z, *args, **kwargs):
'''
Plot a 3D wireframe.
Modified: trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py
===================================================================
--- trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py 2009-12-09 20:29:10 UTC (rev 8014)
+++ trunk/matplotlib/lib/mpl_toolkits/mplot3d/axis3d.py 2009-12-10 00:03:03 UTC (rev 8015)
@@ -75,7 +75,7 @@
maxis.XAxis.__init__(self, axes, *args, **kwargs)
self.line = mlines.Line2D(xdata=(0, 0), ydata=(0, 0),
linewidth=0.75,
- color=(0,0, 0,0),
+ color=(0, 0, 0, 1),
antialiased=True,
)
@@ -100,8 +100,8 @@
majorLabels = [self.major.formatter(val, i) for i, val in enumerate(majorLocs)]
return majorLabels, majorLocs
- def get_major_ticks(self):
- ticks = maxis.XAxis.get_major_ticks(self)
+ def get_major_ticks(self, numticks=None):
+ ticks = maxis.XAxis.get_major_ticks(self, numticks)
for t in ticks:
t.tick1line.set_transform(self.axes.transData)
t.tick2line.set_transform(self.axes.transData)
@@ -132,23 +132,7 @@
else:
return len(text) > 4
- def draw(self, renderer):
- self.label._transform = self.axes.transData
- renderer.open_group('axis3d')
-
- # code from XAxis
- majorTicks = self.get_major_ticks()
- majorLocs = self.major.locator()
-
- # filter locations here so that no extra grid lines are drawn
- interval = self.get_view_interval()
- majorLocs = [loc for loc in majorLocs if \
- interval[0] < loc < interval[1]]
- self.major.formatter.set_locs(majorLocs)
- majorLabels = [self.major.formatter(val, i)
- for i, val in enumerate(majorLocs)]
-
- # Determine bounds
+ def _get_coord_info(self, renderer):
minx, maxx, miny, maxy, minz, maxz = self.axes.get_w_lims()
mins = np.array((minx, miny, minz))
maxs = np.array((maxx, maxy, maxz))
@@ -157,15 +141,19 @@
mins = mins - deltas / 4.
maxs = maxs + deltas / 4.
- # Determine which planes should be visible by the avg z value
vals = mins[0], maxs[0], mins[1], maxs[1], mins[2], maxs[2]
tc = self.axes.tunit_cube(vals, renderer.M)
- #raise RuntimeError('WTF: p1=%s'%p1)
avgz = [tc[p1][2] + tc[p2][2] + tc[p3][2] + tc[p4][2] for \
p1, p2, p3, p4 in self._PLANES]
highs = np.array([avgz[2*i] < avgz[2*i+1] for i in range(3)])
- # Draw plane
+ return mins, maxs, centers, deltas, tc, highs
+
+ def draw_pane(self, renderer):
+ renderer.open_group('pane3d')
+
+ mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer)
+
info = self._AXINFO[self.adir]
index = info['i']
if not highs[index]:
@@ -176,6 +164,29 @@
self.set_pane(xys, info['color'])
self.pane.draw(renderer)
+ renderer.close_group('pane3d')
+
+ def draw(self, renderer):
+ self.label._transform = self.axes.transData
+ renderer.open_group('axis3d')
+
+ # code from XAxis
+ majorTicks = self.get_major_ticks()
+ majorLocs = self.major.locator()
+
+ info = self._AXINFO[self.adir]
+ index = info['i']
+
+ # filter locations here so that no extra grid lines are drawn
+ interval = self.get_view_interval()
+ majorLocs = [loc for loc in majorLocs if \
+ interval[0] < loc < interval[1]]
+ self.major.formatter.set_locs(majorLocs)
+ majorLabels = [self.major.formatter(val, i)
+ for i, val in enumerate(majorLocs)]
+
+ mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer)
+
# Determine grid lines
minmax = np.where(highs, maxs, mins)
This was sent by the SourceForge.net collaborative development platform, the world's largest Open Source development site.
|